//! Re-exports of the tokio runtime for use with arti.
//!
//! This module helps define a slim API around our async runtime so that we
//! can easily swap it out.

/// Types used for networking (tokio implementation)
pub(crate) mod net {
8
    use crate::{impls, traits};
9
    use async_trait::async_trait;
10
    #[cfg(unix)]
11
    use tor_general_addr::unix;
12

            
13
    pub(crate) use tokio_crate::net::{
14
        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
15
    };
16
    #[cfg(unix)]
17
    pub(crate) use tokio_crate::net::{
18
        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
19
    };
20

            
21
    use futures::io::{AsyncRead, AsyncWrite};
22
    use paste::paste;
23
    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
24

            
25
    use std::io::Result as IoResult;
26
    use std::net::SocketAddr;
27
    use std::pin::Pin;
28
    use std::task::{Context, Poll};
29

            
30
    /// Provide a set of network stream wrappers for a single stream type.
31
    macro_rules! stream_impl {
32
        {
33
            $kind:ident,
34
            $addr:ty,
35
            $cvt_addr:ident
36
        } => {paste!{
37
            /// Wrapper for Tokio's
38
            #[doc = stringify!($kind)]
39
            /// streams,
40
            /// that implements the standard
41
            /// AsyncRead and AsyncWrite.
42
            pub struct [<$kind Stream>] {
43
                /// Underlying tokio_util::compat::Compat wrapper.
44
                s: Compat<[<Tokio $kind Stream>]>,
45
            }
46
            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
47
66
                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
48
66
                    let s = s.compat();
49
66
                    [<$kind Stream>] { s }
50
66
                }
51
            }
52
            impl AsyncRead for  [<$kind Stream>] {
53
210
                fn poll_read(
54
210
                    mut self: Pin<&mut Self>,
55
210
                    cx: &mut Context<'_>,
56
210
                    buf: &mut [u8],
57
210
                ) -> Poll<IoResult<usize>> {
58
210
                    Pin::new(&mut self.s).poll_read(cx, buf)
59
210
                }
60
            }
61
            impl AsyncWrite for  [<$kind Stream>] {
62
72
                fn poll_write(
63
72
                    mut self: Pin<&mut Self>,
64
72
                    cx: &mut Context<'_>,
65
72
                    buf: &[u8],
66
72
                ) -> Poll<IoResult<usize>> {
67
72
                    Pin::new(&mut self.s).poll_write(cx, buf)
68
72
                }
69
60
                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
70
60
                    Pin::new(&mut self.s).poll_flush(cx)
71
60
                }
72
8
                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
73
8
                    Pin::new(&mut self.s).poll_close(cx)
74
8
                }
75
            }
76

            
77
            /// Wrap a Tokio
78
            #[doc = stringify!($kind)]
79
            /// Listener to behave as a futures::io::TcpListener.
80
            pub struct [<$kind Listener>] {
81
                /// The underlying listener.
82
                pub(super) lis: [<Tokio $kind Listener>],
83
            }
84

            
85
            /// Asynchronous stream that yields incoming connections from a
86
            #[doc = stringify!($kind)]
87
            /// Listener.
88
            ///
89
            /// This is analogous to async_std::net::Incoming.
90
            pub struct [<Incoming $kind Streams>] {
91
                /// Reference to the underlying listener.
92
                pub(super) lis: [<Tokio $kind Listener>],
93
            }
94

            
95
            impl futures::stream::Stream for [<Incoming $kind Streams>] {
96
                type Item = IoResult<([<$kind Stream>], $addr)>;
97

            
98
40
                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99
40
                    match self.lis.poll_accept(cx) {
100
30
                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
101
                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
102
10
                        Poll::Pending => Poll::Pending,
103
                    }
104
40
                }
105
            }
106
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
107
                type Stream = [<$kind Stream>];
108
                type Incoming = [<Incoming $kind Streams>];
109
10
                fn incoming(self) -> Self::Incoming {
110
10
                    [<Incoming $kind Streams>] { lis: self.lis }
111
10
                }
112
10
                fn local_addr(&self) -> IoResult<$addr> {
113
10
                    $cvt_addr(self.lis.local_addr()?)
114
10
                }
115
            }
116
        }}
117
    }
118

            
119
    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
120
    ///
121
    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
122
    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
123
    #[cfg(unix)]
124
    #[allow(clippy::needless_pass_by_value)]
125
    fn try_cvt_tokio_unix_addr(
126
        addr: tokio_crate::net::unix::SocketAddr,
127
    ) -> IoResult<unix::SocketAddr> {
128
        if addr.is_unnamed() {
129
            crate::unix::new_unnamed_socketaddr()
130
        } else if let Some(p) = addr.as_pathname() {
131
            unix::SocketAddr::from_pathname(p)
132
        } else {
133
            Err(crate::unix::UnsupportedAfUnixAddressType.into())
134
        }
135
    }
136

            
137
    /// Wrapper for (not) converting std::net::SocketAddr to itself.
138
    #[allow(clippy::unnecessary_wraps)]
139
40
    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
140
40
        Ok(addr)
141
40
    }
142

            
143
    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
144
    #[cfg(unix)]
145
    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
146

            
147
    /// Wrap a Tokio UdpSocket
148
    pub struct UdpSocket {
149
        /// The underelying UdpSocket
150
        socket: TokioUdpSocket,
151
    }
152

            
153
    impl UdpSocket {
154
        /// Bind a UdpSocket
155
12
        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
156
8
            TokioUdpSocket::bind(addr)
157
8
                .await
158
8
                .map(|socket| UdpSocket { socket })
159
8
        }
160
    }
161

            
162
    #[async_trait]
163
    impl traits::UdpSocket for UdpSocket {
164
4
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
165
            self.socket.recv_from(buf).await
166
4
        }
167

            
168
4
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
169
            self.socket.send_to(buf, target).await
170
4
        }
171

            
172
8
        fn local_addr(&self) -> IoResult<SocketAddr> {
173
8
            self.socket.local_addr()
174
8
        }
175
    }
176

            
177
    impl traits::StreamOps for TcpStream {
178
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
179
            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
180
        }
181

            
182
        #[cfg(target_os = "linux")]
183
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
184
            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
185
        }
186
    }
187

            
188
    #[cfg(unix)]
189
    impl traits::StreamOps for UnixStream {
190
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
191
            Err(traits::UnsupportedStreamOp::new(
192
                "set_tcp_notsent_lowat",
193
                "unsupported on Unix streams",
194
            )
195
            .into())
196
        }
197
    }
198
}

// ==============================

use crate::network::{TcpConnectOptions, TcpListenOptions};
203
#[cfg(unix)]
204
use crate::network::{UnixConnectOptions, UnixListenOptions};
205
use crate::traits::*;
206
use async_trait::async_trait;
207
use futures::Future;
208
use std::io::Result as IoResult;
209
use std::time::Duration;
210
#[cfg(unix)]
211
use tor_general_addr::unix;
212
use tracing::instrument;
213

            
214
impl SleepProvider for TokioRuntimeHandle {
215
    type SleepFuture = tokio_crate::time::Sleep;
216
13296
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
217
13296
        tokio_crate::time::sleep(duration)
218
13296
    }
219
}
220

            
221
#[async_trait]
222
impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
223
    type Stream = net::TcpStream;
224
    type Listener = net::TcpListener;
225
    type ConnectOptions = TcpConnectOptions;
226
    type ListenOptions = TcpListenOptions;
227

            
228
    #[instrument(skip_all, level = "trace")]
229
    async fn connect(
230
        &self,
231
        addr: &std::net::SocketAddr,
232
        options: &Self::ConnectOptions,
233
    ) -> IoResult<Self::Stream> {
234
        // The socket before connect() has been called.
235
        let socket = super::tcp_pre_connect(addr, options)?;
236

            
237
        // It might seem a little weird to convert the `socket2::Socket` to a std `TcpStream` before
238
        // it's connected, but this is the approach recommended by tokio.
239
        //
240
        // https://docs.rs/tokio/latest/tokio/net/struct.TcpSocket.html#method.from_std_stream
241
        //
242
        // > Converts a `std::net::TcpStream` into a `TcpSocket`. The provided socket must not have
243
        // > been connected prior to calling this function. This function is typically used together
244
        // > with crates such as socket2 to configure socket options that are not available on
245
        // > `TcpSocket`.
246
        //
247
        // The socket will already be non-blocking.
248
        let socket = std::net::TcpStream::from(socket);
249
        let socket = tokio_crate::net::TcpSocket::from_std_stream(socket);
250

            
251
        // Let tokio handle the connection.
252
        let socket = socket.connect(*addr).await?;
253

            
254
        Ok(socket.into())
255
    }
256
    async fn listen(
257
        &self,
258
        addr: &std::net::SocketAddr,
259
        options: &Self::ListenOptions,
260
10
    ) -> IoResult<Self::Listener> {
261
        // Use an implementation that's the same across all runtimes.
262
        let lis = net::TokioTcpListener::from_std(super::tcp_listen(addr, options)?)?;
263

            
264
        Ok(net::TcpListener { lis })
265
10
    }
266
}
267

            
268
#[cfg(unix)]
269
#[async_trait]
270
impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
271
    type Stream = net::UnixStream;
272
    type Listener = net::UnixListener;
273
    type ConnectOptions = UnixConnectOptions;
274
    type ListenOptions = UnixListenOptions;
275

            
276
    #[instrument(skip_all, level = "trace")]
277
    async fn connect(
278
        &self,
279
        addr: &unix::SocketAddr,
280
        options: &Self::ConnectOptions,
281
    ) -> IoResult<Self::Stream> {
282
        // Will fail to compile if we add options without handling them here.
283
        let UnixConnectOptions {} = options;
284

            
285
        let path = addr
286
            .as_pathname()
287
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
288
        let s = net::TokioUnixStream::connect(path).await?;
289
        Ok(s.into())
290
    }
291
    async fn listen(
292
        &self,
293
        addr: &unix::SocketAddr,
294
        options: &Self::ListenOptions,
295
    ) -> IoResult<Self::Listener> {
296
        // Will fail to compile if we add options without handling them here.
297
        let UnixListenOptions {} = options;
298

            
299
        let path = addr
300
            .as_pathname()
301
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
302
        let lis = net::TokioUnixListener::bind(path)?;
303
        Ok(net::UnixListener { lis })
304
    }
305
}
306

            
307
// On non-Unix targets, delegate AF_UNIX support for this runtime to the shared
// `impl_unix_non_provider!` macro.
// NOTE(review): presumably this supplies a provider that reports AF_UNIX as
// unsupported — confirm in `crate::impls`.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
309

            
310
#[async_trait]
311
impl crate::traits::UdpProvider for TokioRuntimeHandle {
312
    type UdpSocket = net::UdpSocket;
313

            
314
8
    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
315
        net::UdpSocket::bind(*addr).await
316
8
    }
317
}
318

            
319
/// Create and return a new Tokio multithreaded runtime.
320
15954
pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
321
15954
    let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
322
15954
    Ok(runtime.into())
323
15954
}
324

            
325
/// Wrapper around a Handle to a tokio runtime.
326
///
327
/// Ideally, this type would go away, and we would just use
328
/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
329
/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
330
/// to define a new type here.
331
///
332
/// # Limitations
333
///
334
/// Note that Arti requires that the runtime should have working implementations
335
/// for Tokio's time, net, and io facilities, but we have no good way to check
336
/// that when creating this object.
337
#[derive(Clone, Debug)]
338
pub struct TokioRuntimeHandle {
339
    /// If present, the tokio executor that we've created (and which we own).
340
    ///
341
    /// We never access this directly; only through `handle`.  We keep it here
342
    /// so that our Runtime types can be agnostic about whether they own the
343
    /// executor.
344
    owned: Option<async_executors::TokioTp>,
345
    /// The underlying Handle.
346
    handle: tokio_crate::runtime::Handle,
347
}
348

            
349
impl TokioRuntimeHandle {
350
    /// Wrap a tokio runtime handle into a format that Arti can use.
351
    ///
352
    /// # Limitations
353
    ///
354
    /// Note that Arti requires that the runtime should have working
355
    /// implementations for Tokio's time, net, and io facilities, but we have
356
    /// no good way to check that when creating this object.
357
264
    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
358
264
        handle.into()
359
264
    }
360

            
361
    /// Return true if this handle owns the executor that it points to.
362
    pub fn is_owned(&self) -> bool {
363
        self.owned.is_some()
364
    }
365
}
366

            
367
impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
368
264
    fn from(handle: tokio_crate::runtime::Handle) -> Self {
369
264
        Self {
370
264
            owned: None,
371
264
            handle,
372
264
        }
373
264
    }
374
}
375

            
376
impl From<async_executors::TokioTp> for TokioRuntimeHandle {
377
15954
    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
378
16245
        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
379
15954
        Self {
380
15954
            owned: Some(owner),
381
15954
            handle,
382
15954
        }
383
15954
    }
384
}
385

            
386
impl ToplevelBlockOn for TokioRuntimeHandle {
387
    #[track_caller]
388
450
    fn block_on<F: Future>(&self, f: F) -> F::Output {
389
450
        self.handle.block_on(f)
390
450
    }
391
}
392

            
393
impl Blocking for TokioRuntimeHandle {
394
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
395

            
396
    #[track_caller]
397
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
398
    where
399
        F: FnOnce() -> T + Send + 'static,
400
        T: Send + 'static,
401
    {
402
        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
403
    }
404

            
405
    #[track_caller]
406
    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
407
        self.handle.block_on(future)
408
    }
409

            
410
    #[track_caller]
411
    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
412
    where
413
        F: FnOnce() -> T + Send + 'static,
414
        T: Send + 'static,
415
    {
416
        let r = tokio_crate::task::block_in_place(f);
417
        std::future::ready(r)
418
    }
419
}
420

            
421
impl futures::task::Spawn for TokioRuntimeHandle {
    /// Spawn `future` onto this runtime as a detached task.
    ///
    /// This implementation always returns `Ok(())`.
    #[track_caller]
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        // Nobody awaits the task, so discard its JoinHandle immediately;
        // dropping it leaves the task detached and running.
        drop(self.handle.spawn(future));
        Ok(())
    }
}