//! Re-exports of the smol runtime for use with arti.
//! This crate defines a slim API around our async runtime so that we
//! can swap it out easily.

            
5
/// Types used for networking (smol implementation).
6
pub(crate) mod net {
7
    use super::SmolRuntime;
8
    use crate::network::{TcpConnectOptions, TcpListenOptions};
9
    #[cfg(unix)]
10
    use crate::network::{UnixConnectOptions, UnixListenOptions};
11
    use crate::{impls, traits};
12
    use async_trait::async_trait;
13
    use futures::stream::{self, Stream};
14
    use paste::paste;
15
    use smol::Async;
16
    #[cfg(unix)]
17
    use smol::net::unix::{UnixListener, UnixStream};
18
    use smol::net::{TcpListener, TcpStream, UdpSocket as SmolUdpSocket};
19
    use std::io::Result as IoResult;
20
    use std::net::SocketAddr;
21
    use std::pin::Pin;
22
    use std::task::{Context, Poll};
23
    use tor_general_addr::unix;
24
    use tracing::instrument;
25

            
26
    /// Provide wrapper for different stream types
27
    /// (e.g async_net::TcpStream and async_net::unix::UnixStream).
28
    macro_rules! impl_stream {
29
        { $kind:ident, $addr:ty } => { paste! {
30

            
31
            /// A `Stream` of incoming streams.
32
            pub struct [<Incoming $kind Streams>] {
33
                /// Underlying stream of incoming connections.
34
                inner: Pin<Box<dyn Stream<Item = IoResult<([<$kind Stream>], $addr)>> + Send + Sync>>,
35
            }
36

            
37
            impl [<Incoming $kind Streams>] {
38
                /// Create a new `Incoming*Streams` from a listener.
39
6
                pub fn from_listener(lis: [<$kind Listener>]) -> Self {
40
19
                    let stream = stream::unfold(lis, |lis| async move {
41
16
                        let result = lis.accept().await;
42
16
                        Some((result, lis))
43
32
                    });
44
6
                    Self {
45
6
                        inner: Box::pin(stream),
46
6
                    }
47
6
                }
48
            }
49

            
50
            impl Stream for [<Incoming $kind Streams>] {
51
                type Item = IoResult<([<$kind Stream>], $addr)>;
52

            
53
22
                fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
54
22
                    self.inner.as_mut().poll_next(cx)
55
22
                }
56
            }
57

            
58
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
59
                type Stream = [<$kind Stream>];
60
                type Incoming = [<Incoming $kind Streams>];
61

            
62
6
                fn incoming(self) -> Self::Incoming {
63
6
                    [<Incoming $kind Streams>]::from_listener(self)
64
6
                }
65

            
66
6
                fn local_addr(&self) -> IoResult<$addr> {
67
6
                    [<$kind Listener>]::local_addr(self)
68
6
                }
69
            }
70
        }}
71
    }
72

            
73
    impl_stream! { Tcp, SocketAddr }
74
    #[cfg(unix)]
75
    impl_stream! { Unix, unix::SocketAddr }
76

            
77
    #[async_trait]
78
    impl traits::NetStreamProvider<SocketAddr> for SmolRuntime {
79
        type Stream = TcpStream;
80
        type Listener = TcpListener;
81
        type ConnectOptions = TcpConnectOptions;
82
        type ListenOptions = TcpListenOptions;
83

            
84
        #[instrument(skip_all, level = "trace")]
85
        async fn connect(
86
            &self,
87
            addr: &SocketAddr,
88
            options: &Self::ConnectOptions,
89
        ) -> IoResult<Self::Stream> {
90
            // The smol runtime uses async-io internally.
91
            let stream = impls::tcp_async_io_connect(addr, options).await?;
92

            
93
            // The socket is already non-blocking,
94
            // so `Async` doesn't need to set as non-blocking again.
95
            Ok(Async::new_nonblocking(stream)?.into())
96
        }
97

            
98
        async fn listen(
99
            &self,
100
            addr: &SocketAddr,
101
            options: &Self::ListenOptions,
102
6
        ) -> IoResult<Self::Listener> {
103
            // Use an implementation that's the same across all runtimes.
104
            // The socket is already non-blocking, so `Async` doesn't need to set as non-blocking
105
            // again. If it *were* to be blocking, then I/O operations would block in async
106
            // contexts, which would lead to deadlocks.
107
            Ok(Async::new_nonblocking(impls::tcp_listen(addr, options)?)?.into())
108
6
        }
109
    }
110

            
111
    #[cfg(unix)]
112
    #[async_trait]
113
    impl traits::NetStreamProvider<unix::SocketAddr> for SmolRuntime {
114
        type Stream = UnixStream;
115
        type Listener = UnixListener;
116
        type ConnectOptions = UnixConnectOptions;
117
        type ListenOptions = UnixListenOptions;
118

            
119
        #[instrument(skip_all, level = "trace")]
120
        async fn connect(
121
            &self,
122
            addr: &unix::SocketAddr,
123
            options: &Self::ConnectOptions,
124
        ) -> IoResult<Self::Stream> {
125
            // Will fail to compile if we add options without handling them here.
126
            let UnixConnectOptions {} = options;
127

            
128
            let path = addr
129
                .as_pathname()
130
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
131
            UnixStream::connect(path).await
132
        }
133

            
134
        async fn listen(
135
            &self,
136
            addr: &unix::SocketAddr,
137
            options: &Self::ListenOptions,
138
        ) -> IoResult<Self::Listener> {
139
            // Will fail to compile if we add options without handling them here.
140
            let UnixListenOptions {} = options;
141

            
142
            let path = addr
143
                .as_pathname()
144
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
145
            UnixListener::bind(path)
146
        }
147
    }
148

            
149
    #[cfg(not(unix))]
150
    crate::impls::impl_unix_non_provider! { SmolRuntime }
151

            
152
    #[async_trait]
153
    impl traits::UdpProvider for SmolRuntime {
154
        type UdpSocket = UdpSocket;
155

            
156
4
        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
157
            SmolUdpSocket::bind(addr)
158
                .await
159
4
                .map(|socket| UdpSocket { socket })
160
4
        }
161
    }
162

            
163
    /// Wrapper for `SmolUdpSocket`.
164
    // Required to implement `traits::UdpSocket`.
165
    pub struct UdpSocket {
166
        /// The underlying socket.
167
        socket: SmolUdpSocket,
168
    }
169

            
170
    #[async_trait]
171
    impl traits::UdpSocket for UdpSocket {
172
2
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
173
            self.socket.recv_from(buf).await
174
2
        }
175

            
176
2
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
177
            self.socket.send_to(buf, target).await
178
2
        }
179

            
180
4
        fn local_addr(&self) -> IoResult<SocketAddr> {
181
4
            self.socket.local_addr()
182
4
        }
183
    }
184

            
185
    impl traits::StreamOps for TcpStream {
186
        fn set_tcp_notsent_lowat(&self, lowat: u32) -> IoResult<()> {
187
            impls::streamops::set_tcp_notsent_lowat(self, lowat)
188
        }
189

            
190
        #[cfg(target_os = "linux")]
191
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
192
            Box::new(impls::streamops::TcpSockFd::from_fd(self))
193
        }
194
    }
195

            
196
    #[cfg(unix)]
197
    impl traits::StreamOps for UnixStream {
198
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
199
            Err(traits::UnsupportedStreamOp::new(
200
                "set_tcp_notsent_lowat",
201
                "unsupported on Unix streams",
202
            )
203
            .into())
204
        }
205
    }
206
}
207

            
208
// ==============================
209

            
210
use crate::traits::*;
211
use futures::task::{FutureObj, Spawn, SpawnError};
212
use futures::{Future, FutureExt};
213
use std::pin::Pin;
214
use std::time::Duration;
215

            
216
/// Type to wrap `smol::Executor`.
217
#[derive(Clone)]
218
pub struct SmolRuntime {
219
    /// Instance of the smol executor we own.
220
    executor: std::sync::Arc<smol::Executor<'static>>,
221
}
222

            
223
/// Construct new instance of the smol runtime.
224
//
225
// TODO: Make SmolRuntime multi-threaded.
226
1046
pub fn create_runtime() -> SmolRuntime {
227
1046
    SmolRuntime {
228
1046
        executor: std::sync::Arc::new(smol::Executor::new()),
229
1046
    }
230
1046
}
231

            
232
impl SleepProvider for SmolRuntime {
233
    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
234
908
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
235
908
        Box::pin(async_io::Timer::after(duration).map(|_| ()))
236
908
    }
237
}
238

            
239
impl ToplevelBlockOn for SmolRuntime {
240
316
    fn block_on<F: Future>(&self, f: F) -> F::Output {
241
316
        smol::block_on(self.executor.run(f))
242
316
    }
243
}
244

            
245
impl Blocking for SmolRuntime {
246
    type ThreadHandle<T: Send + 'static> = blocking::Task<T>;
247

            
248
    fn spawn_blocking<F, T>(&self, f: F) -> blocking::Task<T>
249
    where
250
        F: FnOnce() -> T + Send + 'static,
251
        T: Send + 'static,
252
    {
253
        smol::unblock(f)
254
    }
255

            
256
    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
257
        smol::block_on(self.executor.run(f))
258
    }
259
}
260

            
261
impl Spawn for SmolRuntime {
262
1252
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
263
1252
        self.executor.spawn(future).detach();
264
1252
        Ok(())
265
1252
    }
266
}