//! Glue for using the async_std runtime with arti: trait implementations and
//! a runtime constructor.
//!
//! This crate helps define a slim API around our async runtime so that we
//! can easily swap it out.
//!
//! We'll probably want to support tokio as well in the future.

            
8
/// Types used for networking (async_std implementation)
9
mod net {
10
    use crate::network::{TcpConnectOptions, TcpListenOptions};
11
    #[cfg(unix)]
12
    use crate::network::{UnixConnectOptions, UnixListenOptions};
13
    use crate::{impls, traits};
14

            
15
    use async_std_crate::net::{TcpListener, TcpStream, UdpSocket as StdUdpSocket};
16
    #[cfg(unix)]
17
    use async_std_crate::os::unix::net::{UnixListener, UnixStream};
18
    use async_trait::async_trait;
19
    use futures::stream::{self, Stream};
20
    use paste::paste;
21
    use std::io::Result as IoResult;
22
    use std::net::SocketAddr;
23
    use std::pin::Pin;
24
    use std::task::{Context, Poll};
25
    #[cfg(unix)]
26
    use tor_general_addr::unix;
27
    use tracing::instrument;
28

            
29
    /// Implement NetStreamProvider-related functionality for a single address type.
30
    macro_rules! impl_stream {
31
        { $kind:ident, $addr:ty } => {paste!{
32
            /// A `Stream` of incoming streams.
33
            ///
34
            /// Differs from the output of `*Listener::incoming` in that this
35
            /// struct is a real type, and that it returns a stream and an address
36
            /// for each input.
37
            pub struct [<Incoming $kind Streams>] {
38
                /// Underlying stream of incoming connections.
39
                inner: Pin<Box<dyn Stream<Item = IoResult<([<$kind Stream>], $addr)>> + Send + Sync>>,
40
            }
41
            impl [<Incoming $kind Streams>] {
42
                /// Create a new IncomingStreams from a Listener.
43
6
                pub fn from_listener(lis: [<$kind Listener>]) -> [<Incoming $kind Streams>] {
44
19
                    let stream = stream::unfold(lis, |lis| async move {
45
16
                        let result = lis.accept().await;
46
16
                        Some((result, lis))
47
32
                    });
48
6
                    Self {
49
6
                        inner: Box::pin(stream),
50
6
                    }
51
6
                }
52
            }
53
            impl Stream for [< Incoming $kind Streams >] {
54
                type Item = IoResult<([<$kind Stream>], $addr)>;
55

            
56
22
                fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
57
22
                    self.inner.as_mut().poll_next(cx)
58
22
                }
59
            }
60
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
61
                type Stream = [<$kind Stream>];
62
                type Incoming = [<Incoming $kind Streams>];
63
6
                fn incoming(self) -> [<Incoming $kind Streams>] {
64
6
                    [<Incoming $kind Streams>]::from_listener(self)
65
6
                }
66
6
                fn local_addr(&self) -> IoResult<$addr> {
67
6
                    [<$kind Listener>]::local_addr(self)
68
6
                }
69
            }
70
        }}
71
    }
72

            
73
    impl_stream! { Tcp, std::net::SocketAddr }
74
    #[cfg(unix)]
75
    impl_stream! { Unix, unix::SocketAddr}
76

            
77
    #[async_trait]
78
    impl traits::NetStreamProvider<std::net::SocketAddr> for async_executors::AsyncStd {
79
        type Stream = TcpStream;
80
        type Listener = TcpListener;
81
        type ConnectOptions = TcpConnectOptions;
82
        type ListenOptions = TcpListenOptions;
83
        #[instrument(skip_all, level = "trace")]
84
        async fn connect(
85
            &self,
86
            addr: &SocketAddr,
87
            options: &Self::ConnectOptions,
88
        ) -> IoResult<Self::Stream> {
89
            // The async-std runtime uses async-io internally.
90
            let stream = impls::tcp_async_io_connect(addr, options).await?;
91
            Ok(stream.into())
92
        }
93
        async fn listen(
94
            &self,
95
            addr: &SocketAddr,
96
            options: &Self::ListenOptions,
97
6
        ) -> IoResult<Self::Listener> {
98
            // Use an implementation that's the same across all runtimes.
99
            Ok(impls::tcp_listen(addr, options)?.into())
100
6
        }
101
    }
102

            
103
    #[cfg(unix)]
104
    #[async_trait]
105
    impl traits::NetStreamProvider<unix::SocketAddr> for async_executors::AsyncStd {
106
        type Stream = UnixStream;
107
        type Listener = UnixListener;
108
        type ConnectOptions = UnixConnectOptions;
109
        type ListenOptions = UnixListenOptions;
110
        #[instrument(skip_all, level = "trace")]
111
        async fn connect(
112
            &self,
113
            addr: &unix::SocketAddr,
114
            options: &Self::ConnectOptions,
115
        ) -> IoResult<Self::Stream> {
116
            // Will fail to compile if we add options without handling them here.
117
            let UnixConnectOptions {} = options;
118

            
119
            let path = addr
120
                .as_pathname()
121
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
122
            UnixStream::connect(path).await
123
        }
124
        async fn listen(
125
            &self,
126
            addr: &unix::SocketAddr,
127
            options: &Self::ListenOptions,
128
        ) -> IoResult<Self::Listener> {
129
            // Will fail to compile if we add options without handling them here.
130
            let UnixListenOptions {} = options;
131

            
132
            let path = addr
133
                .as_pathname()
134
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
135
            UnixListener::bind(path).await
136
        }
137
    }
138

            
139
    #[cfg(not(unix))]
140
    crate::impls::impl_unix_non_provider! { async_executors::AsyncStd }
141

            
142
    #[async_trait]
143
    impl traits::UdpProvider for async_executors::AsyncStd {
144
        type UdpSocket = UdpSocket;
145

            
146
4
        async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
147
            StdUdpSocket::bind(*addr)
148
                .await
149
4
                .map(|socket| UdpSocket { socket })
150
4
        }
151
    }
152

            
153
    /// Wrap a AsyncStd UdpSocket
154
    pub struct UdpSocket {
155
        /// The underlying UdpSocket
156
        socket: StdUdpSocket,
157
    }
158

            
159
    #[async_trait]
160
    impl traits::UdpSocket for UdpSocket {
161
2
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
162
            self.socket.recv_from(buf).await
163
2
        }
164

            
165
2
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
166
            self.socket.send_to(buf, target).await
167
2
        }
168

            
169
4
        fn local_addr(&self) -> IoResult<SocketAddr> {
170
4
            self.socket.local_addr()
171
4
        }
172
    }
173

            
174
    impl traits::StreamOps for TcpStream {
175
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
176
            impls::streamops::set_tcp_notsent_lowat(self, notsent_lowat)
177
        }
178

            
179
        #[cfg(target_os = "linux")]
180
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
181
            Box::new(impls::streamops::TcpSockFd::from_fd(self))
182
        }
183
    }
184

            
185
    #[cfg(unix)]
186
    impl traits::StreamOps for UnixStream {
187
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
188
            Err(traits::UnsupportedStreamOp::new(
189
                "set_tcp_notsent_lowat",
190
                "unsupported on Unix streams",
191
            )
192
            .into())
193
        }
194
    }
195
}
196

            
197
// ==============================
198

            
199
use futures::{Future, FutureExt};
200
use std::pin::Pin;
201
use std::time::Duration;
202

            
203
use crate::traits::*;
204

            
205
/// Create and return a new `async_std` runtime.
206
1052
pub fn create_runtime() -> async_executors::AsyncStd {
207
1052
    async_executors::AsyncStd::new()
208
1052
}
209

            
210
impl SleepProvider for async_executors::AsyncStd {
211
    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
212
748
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
213
748
        Box::pin(async_io::Timer::after(duration).map(|_| ()))
214
748
    }
215
}
216

            
217
impl ToplevelBlockOn for async_executors::AsyncStd {
218
318
    fn block_on<F: Future>(&self, f: F) -> F::Output {
219
318
        async_executors::AsyncStd::block_on(f)
220
318
    }
221
}
222

            
223
impl Blocking for async_executors::AsyncStd {
224
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
225

            
226
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
227
    where
228
        F: FnOnce() -> T + Send + 'static,
229
        T: Send + 'static,
230
    {
231
        async_executors::SpawnBlocking::spawn_blocking(&self, f)
232
    }
233

            
234
    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
235
        async_executors::AsyncStd::block_on(f)
236
    }
237
}