//! Wrappers around the smol runtime for use with arti.
//!
//! This crate defines a slim API around our async runtime so that we
//! can swap it out easily.

            
5
/// Types used for networking (smol implementation).
6
pub(crate) mod net {
7
    use super::SmolRuntime;
8
    use crate::{impls, traits};
9
    use async_trait::async_trait;
10
    use futures::stream::{self, Stream};
11
    use paste::paste;
12
    use smol::Async;
13
    #[cfg(unix)]
14
    use smol::net::unix::{UnixListener, UnixStream};
15
    use smol::net::{TcpListener, TcpStream, UdpSocket as SmolUdpSocket};
16
    use std::io::Result as IoResult;
17
    use std::net::SocketAddr;
18
    use std::pin::Pin;
19
    use std::task::{Context, Poll};
20
    use tor_general_addr::unix;
21
    use tracing::instrument;
22

            
23
    /// Provide wrapper for different stream types
24
    /// (e.g async_net::TcpStream and async_net::unix::UnixStream).
25
    macro_rules! impl_stream {
26
        { $kind:ident, $addr:ty } => { paste! {
27

            
28
            /// A `Stream` of incoming streams.
29
            pub struct [<Incoming $kind Streams>] {
30
                /// Underlying stream of incoming connections.
31
                inner: Pin<Box<dyn Stream<Item = IoResult<([<$kind Stream>], $addr)>> + Send + Sync>>,
32
            }
33

            
34
            impl [<Incoming $kind Streams>] {
35
                /// Create a new `Incoming*Streams` from a listener.
36
6
                pub fn from_listener(lis: [<$kind Listener>]) -> Self {
37
19
                    let stream = stream::unfold(lis, |lis| async move {
38
16
                        let result = lis.accept().await;
39
16
                        Some((result, lis))
40
32
                    });
41
6
                    Self {
42
6
                        inner: Box::pin(stream),
43
6
                    }
44
6
                }
45
            }
46

            
47
            impl Stream for [<Incoming $kind Streams>] {
48
                type Item = IoResult<([<$kind Stream>], $addr)>;
49

            
50
22
                fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
51
22
                    self.inner.as_mut().poll_next(cx)
52
22
                }
53
            }
54

            
55
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
56
                type Stream = [<$kind Stream>];
57
                type Incoming = [<Incoming $kind Streams>];
58

            
59
6
                fn incoming(self) -> Self::Incoming {
60
6
                    [<Incoming $kind Streams>]::from_listener(self)
61
6
                }
62

            
63
6
                fn local_addr(&self) -> IoResult<$addr> {
64
6
                    [<$kind Listener>]::local_addr(self)
65
6
                }
66
            }
67
        }}
68
    }
69

            
70
    impl_stream! { Tcp, SocketAddr }
71
    #[cfg(unix)]
72
    impl_stream! { Unix, unix::SocketAddr }
73

            
74
    #[async_trait]
75
    impl traits::NetStreamProvider<SocketAddr> for SmolRuntime {
76
        type Stream = TcpStream;
77
        type Listener = TcpListener;
78

            
79
        #[instrument(skip_all, level = "trace")]
80
        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
81
            TcpStream::connect(addr).await
82
        }
83

            
84
6
        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
85
            // Use an implementation that's the same across all runtimes.
86
            // The socket is already non-blocking, so `Async` doesn't need to set as non-blocking
87
            // again. If it *were* to be blocking, then I/O operations would block in async
88
            // contexts, which would lead to deadlocks.
89
            Ok(Async::new_nonblocking(impls::tcp_listen(addr)?)?.into())
90
6
        }
91
    }
92

            
93
    #[cfg(unix)]
94
    #[async_trait]
95
    impl traits::NetStreamProvider<unix::SocketAddr> for SmolRuntime {
96
        type Stream = UnixStream;
97
        type Listener = UnixListener;
98

            
99
        #[instrument(skip_all, level = "trace")]
100
        async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
101
            let path = addr
102
                .as_pathname()
103
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
104
            UnixStream::connect(path).await
105
        }
106

            
107
        async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
108
            let path = addr
109
                .as_pathname()
110
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
111
            UnixListener::bind(path)
112
        }
113
    }
114

            
115
    #[cfg(not(unix))]
116
    crate::impls::impl_unix_non_provider! { SmolRuntime }
117

            
118
    #[async_trait]
119
    impl traits::UdpProvider for SmolRuntime {
120
        type UdpSocket = UdpSocket;
121

            
122
4
        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
123
            SmolUdpSocket::bind(addr)
124
                .await
125
4
                .map(|socket| UdpSocket { socket })
126
4
        }
127
    }
128

            
129
    /// Wrapper for `SmolUdpSocket`.
130
    // Required to implement `traits::UdpSocket`.
131
    pub struct UdpSocket {
132
        /// The underlying socket.
133
        socket: SmolUdpSocket,
134
    }
135

            
136
    #[async_trait]
137
    impl traits::UdpSocket for UdpSocket {
138
2
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
139
            self.socket.recv_from(buf).await
140
2
        }
141

            
142
2
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
143
            self.socket.send_to(buf, target).await
144
2
        }
145

            
146
4
        fn local_addr(&self) -> IoResult<SocketAddr> {
147
4
            self.socket.local_addr()
148
4
        }
149
    }
150

            
151
    impl traits::StreamOps for TcpStream {
152
        fn set_tcp_notsent_lowat(&self, lowat: u32) -> IoResult<()> {
153
            impls::streamops::set_tcp_notsent_lowat(self, lowat)
154
        }
155

            
156
        #[cfg(target_os = "linux")]
157
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
158
            Box::new(impls::streamops::TcpSockFd::from_fd(self))
159
        }
160
    }
161

            
162
    #[cfg(unix)]
163
    impl traits::StreamOps for UnixStream {
164
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
165
            Err(traits::UnsupportedStreamOp::new(
166
                "set_tcp_notsent_lowat",
167
                "unsupported on Unix streams",
168
            )
169
            .into())
170
        }
171
    }
172
}
173

            
174
// ==============================
175

            
176
use crate::traits::*;
177
use futures::task::{FutureObj, Spawn, SpawnError};
178
use futures::{Future, FutureExt};
179
use std::pin::Pin;
180
use std::time::Duration;
181

            
182
/// Type to wrap `smol::Executor`.
183
#[derive(Clone)]
184
pub struct SmolRuntime {
185
    /// Instance of the smol executor we own.
186
    executor: std::sync::Arc<smol::Executor<'static>>,
187
}
188

            
189
/// Construct new instance of the smol runtime.
190
//
191
// TODO: Make SmolRuntime multi-threaded.
192
1046
pub fn create_runtime() -> SmolRuntime {
193
1046
    SmolRuntime {
194
1046
        executor: std::sync::Arc::new(smol::Executor::new()),
195
1046
    }
196
1046
}
197

            
198
impl SleepProvider for SmolRuntime {
199
    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
200
908
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
201
908
        Box::pin(async_io::Timer::after(duration).map(|_| ()))
202
908
    }
203
}
204

            
205
impl ToplevelBlockOn for SmolRuntime {
206
316
    fn block_on<F: Future>(&self, f: F) -> F::Output {
207
316
        smol::block_on(self.executor.run(f))
208
316
    }
209
}
210

            
211
impl Blocking for SmolRuntime {
212
    type ThreadHandle<T: Send + 'static> = blocking::Task<T>;
213

            
214
    fn spawn_blocking<F, T>(&self, f: F) -> blocking::Task<T>
215
    where
216
        F: FnOnce() -> T + Send + 'static,
217
        T: Send + 'static,
218
    {
219
        smol::unblock(f)
220
    }
221

            
222
    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
223
        smol::block_on(self.executor.run(f))
224
    }
225
}
226

            
227
impl Spawn for SmolRuntime {
228
1252
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
229
1252
        self.executor.spawn(future).detach();
230
1252
        Ok(())
231
1252
    }
232
}