//! Re-exports of the tokio runtime for use with arti.
//!
//! This crate helps define a slim API around our async runtime so that we
//! can easily swap it out.
5

            
6
/// Types used for networking (tokio implementation)
7
pub(crate) mod net {
8
    use crate::{impls, traits};
9
    use async_trait::async_trait;
10
    #[cfg(unix)]
11
    use tor_general_addr::unix;
12

            
13
    pub(crate) use tokio_crate::net::{
14
        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
15
    };
16
    #[cfg(unix)]
17
    pub(crate) use tokio_crate::net::{
18
        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
19
    };
20

            
21
    use futures::io::{AsyncRead, AsyncWrite};
22
    use paste::paste;
23
    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
24

            
25
    use std::io::Result as IoResult;
26
    use std::net::SocketAddr;
27
    use std::pin::Pin;
28
    use std::task::{Context, Poll};
29

            
30
    /// Provide a set of network stream wrappers for a single stream type.
31
    macro_rules! stream_impl {
32
        {
33
            $kind:ident,
34
            $addr:ty,
35
            $cvt_addr:ident
36
        } => {paste!{
37
            /// Wrapper for Tokio's
38
            #[doc = stringify!($kind)]
39
            /// streams,
40
            /// that implements the standard
41
            /// AsyncRead and AsyncWrite.
42
            pub struct [<$kind Stream>] {
43
                /// Underlying tokio_util::compat::Compat wrapper.
44
                s: Compat<[<Tokio $kind Stream>]>,
45
            }
46
            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
47
66
                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
48
66
                    let s = s.compat();
49
66
                    [<$kind Stream>] { s }
50
66
                }
51
            }
52
            impl AsyncRead for  [<$kind Stream>] {
53
200
                fn poll_read(
54
200
                    mut self: Pin<&mut Self>,
55
200
                    cx: &mut Context<'_>,
56
200
                    buf: &mut [u8],
57
200
                ) -> Poll<IoResult<usize>> {
58
200
                    Pin::new(&mut self.s).poll_read(cx, buf)
59
200
                }
60
            }
61
            impl AsyncWrite for  [<$kind Stream>] {
62
72
                fn poll_write(
63
72
                    mut self: Pin<&mut Self>,
64
72
                    cx: &mut Context<'_>,
65
72
                    buf: &[u8],
66
72
                ) -> Poll<IoResult<usize>> {
67
72
                    Pin::new(&mut self.s).poll_write(cx, buf)
68
72
                }
69
60
                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
70
60
                    Pin::new(&mut self.s).poll_flush(cx)
71
60
                }
72
8
                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
73
8
                    Pin::new(&mut self.s).poll_close(cx)
74
8
                }
75
            }
76

            
77
            /// Wrap a Tokio
78
            #[doc = stringify!($kind)]
79
            /// Listener to behave as a futures::io::TcpListener.
80
            pub struct [<$kind Listener>] {
81
                /// The underlying listener.
82
                pub(super) lis: [<Tokio $kind Listener>],
83
            }
84

            
85
            /// Asynchronous stream that yields incoming connections from a
86
            #[doc = stringify!($kind)]
87
            /// Listener.
88
            ///
89
            /// This is analogous to async_std::net::Incoming.
90
            pub struct [<Incoming $kind Streams>] {
91
                /// Reference to the underlying listener.
92
                pub(super) lis: [<Tokio $kind Listener>],
93
            }
94

            
95
            impl futures::stream::Stream for [<Incoming $kind Streams>] {
96
                type Item = IoResult<([<$kind Stream>], $addr)>;
97

            
98
40
                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99
40
                    match self.lis.poll_accept(cx) {
100
30
                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
101
                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
102
10
                        Poll::Pending => Poll::Pending,
103
                    }
104
40
                }
105
            }
106
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
107
                type Stream = [<$kind Stream>];
108
                type Incoming = [<Incoming $kind Streams>];
109
10
                fn incoming(self) -> Self::Incoming {
110
10
                    [<Incoming $kind Streams>] { lis: self.lis }
111
10
                }
112
10
                fn local_addr(&self) -> IoResult<$addr> {
113
10
                    $cvt_addr(self.lis.local_addr()?)
114
10
                }
115
            }
116
        }}
117
    }
118

            
119
    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
120
    ///
121
    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
122
    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
123
    #[cfg(unix)]
124
    #[allow(clippy::needless_pass_by_value)]
125
    fn try_cvt_tokio_unix_addr(
126
        addr: tokio_crate::net::unix::SocketAddr,
127
    ) -> IoResult<unix::SocketAddr> {
128
        if addr.is_unnamed() {
129
            crate::unix::new_unnamed_socketaddr()
130
        } else if let Some(p) = addr.as_pathname() {
131
            unix::SocketAddr::from_pathname(p)
132
        } else {
133
            Err(crate::unix::UnsupportedAfUnixAddressType.into())
134
        }
135
    }
136

            
137
    /// Wrapper for (not) converting std::net::SocketAddr to itself.
138
    #[allow(clippy::unnecessary_wraps)]
139
40
    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
140
40
        Ok(addr)
141
40
    }
142

            
143
    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
144
    #[cfg(unix)]
145
    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
146

            
147
    /// Wrap a Tokio UdpSocket
148
    pub struct UdpSocket {
149
        /// The underelying UdpSocket
150
        socket: TokioUdpSocket,
151
    }
152

            
153
    impl UdpSocket {
154
        /// Bind a UdpSocket
155
12
        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
156
8
            TokioUdpSocket::bind(addr)
157
8
                .await
158
8
                .map(|socket| UdpSocket { socket })
159
8
        }
160
    }
161

            
162
    #[async_trait]
163
    impl traits::UdpSocket for UdpSocket {
164
4
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
165
            self.socket.recv_from(buf).await
166
4
        }
167

            
168
4
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
169
            self.socket.send_to(buf, target).await
170
4
        }
171

            
172
8
        fn local_addr(&self) -> IoResult<SocketAddr> {
173
8
            self.socket.local_addr()
174
8
        }
175
    }
176

            
177
    impl traits::StreamOps for TcpStream {
178
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
179
            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
180
        }
181

            
182
        #[cfg(target_os = "linux")]
183
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
184
            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
185
        }
186
    }
187

            
188
    #[cfg(unix)]
189
    impl traits::StreamOps for UnixStream {
190
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
191
            Err(traits::UnsupportedStreamOp::new(
192
                "set_tcp_notsent_lowat",
193
                "unsupported on Unix streams",
194
            )
195
            .into())
196
        }
197
    }
198
}
199

            
200
// ==============================
201

            
202
use crate::traits::*;
203
use async_trait::async_trait;
204
use futures::Future;
205
use std::io::Result as IoResult;
206
use std::time::Duration;
207
#[cfg(unix)]
208
use tor_general_addr::unix;
209
use tracing::instrument;
210

            
211
impl SleepProvider for TokioRuntimeHandle {
212
    type SleepFuture = tokio_crate::time::Sleep;
213
15156
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
214
15156
        tokio_crate::time::sleep(duration)
215
15156
    }
216
}
217

            
218
#[async_trait]
219
impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
220
    type Stream = net::TcpStream;
221
    type Listener = net::TcpListener;
222

            
223
    #[instrument(skip_all, level = "trace")]
224
    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
225
        let s = net::TokioTcpStream::connect(addr).await?;
226
        Ok(s.into())
227
    }
228
10
    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
229
        // Use an implementation that's the same across all runtimes.
230
        let lis = net::TokioTcpListener::from_std(super::tcp_listen(addr)?)?;
231

            
232
        Ok(net::TcpListener { lis })
233
10
    }
234
}
235

            
236
#[cfg(unix)]
237
#[async_trait]
238
impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
239
    type Stream = net::UnixStream;
240
    type Listener = net::UnixListener;
241

            
242
    #[instrument(skip_all, level = "trace")]
243
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
244
        let path = addr
245
            .as_pathname()
246
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247
        let s = net::TokioUnixStream::connect(path).await?;
248
        Ok(s.into())
249
    }
250
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
251
        let path = addr
252
            .as_pathname()
253
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
254
        let lis = net::TokioUnixListener::bind(path)?;
255
        Ok(net::UnixListener { lis })
256
    }
257
}
258

            
259
// Without AF_UNIX support, delegate to the shared `impl_unix_non_provider!`
// macro (presumably a stub unix-address provider that reports the operation
// as unsupported).
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
261

            
262
#[async_trait]
263
impl crate::traits::UdpProvider for TokioRuntimeHandle {
264
    type UdpSocket = net::UdpSocket;
265

            
266
8
    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
267
        net::UdpSocket::bind(*addr).await
268
8
    }
269
}
270

            
271
/// Create and return a new Tokio multithreaded runtime.
272
15025
pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
273
15025
    let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
274
15025
    Ok(runtime.into())
275
15025
}
276

            
277
/// Wrapper around a Handle to a tokio runtime.
278
///
279
/// Ideally, this type would go away, and we would just use
280
/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
281
/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
282
/// to define a new type here.
283
///
284
/// # Limitations
285
///
286
/// Note that Arti requires that the runtime should have working implementations
287
/// for Tokio's time, net, and io facilities, but we have no good way to check
288
/// that when creating this object.
289
#[derive(Clone, Debug)]
290
pub struct TokioRuntimeHandle {
291
    /// If present, the tokio executor that we've created (and which we own).
292
    ///
293
    /// We never access this directly; only through `handle`.  We keep it here
294
    /// so that our Runtime types can be agnostic about whether they own the
295
    /// executor.
296
    owned: Option<async_executors::TokioTp>,
297
    /// The underlying Handle.
298
    handle: tokio_crate::runtime::Handle,
299
}
300

            
301
impl TokioRuntimeHandle {
302
    /// Wrap a tokio runtime handle into a format that Arti can use.
303
    ///
304
    /// # Limitations
305
    ///
306
    /// Note that Arti requires that the runtime should have working
307
    /// implementations for Tokio's time, net, and io facilities, but we have
308
    /// no good way to check that when creating this object.
309
319
    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
310
319
        handle.into()
311
319
    }
312

            
313
    /// Return true if this handle owns the executor that it points to.
314
    pub fn is_owned(&self) -> bool {
315
        self.owned.is_some()
316
    }
317
}
318

            
319
impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
320
319
    fn from(handle: tokio_crate::runtime::Handle) -> Self {
321
319
        Self {
322
319
            owned: None,
323
319
            handle,
324
319
        }
325
319
    }
326
}
327

            
328
impl From<async_executors::TokioTp> for TokioRuntimeHandle {
329
15025
    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
330
15309
        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
331
15025
        Self {
332
15025
            owned: Some(owner),
333
15025
            handle,
334
15025
        }
335
15025
    }
336
}
337

            
338
impl ToplevelBlockOn for TokioRuntimeHandle {
339
    #[track_caller]
340
446
    fn block_on<F: Future>(&self, f: F) -> F::Output {
341
446
        self.handle.block_on(f)
342
446
    }
343
}
344

            
345
impl Blocking for TokioRuntimeHandle {
346
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
347

            
348
    #[track_caller]
349
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
350
    where
351
        F: FnOnce() -> T + Send + 'static,
352
        T: Send + 'static,
353
    {
354
        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
355
    }
356

            
357
    #[track_caller]
358
    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
359
        self.handle.block_on(future)
360
    }
361

            
362
    #[track_caller]
363
    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
364
    where
365
        F: FnOnce() -> T + Send + 'static,
366
        T: Send + 'static,
367
    {
368
        let r = tokio_crate::task::block_in_place(f);
369
        std::future::ready(r)
370
    }
371
}
372

            
373
impl futures::task::Spawn for TokioRuntimeHandle {
    /// Spawn `future` onto this runtime as a detached task.
    ///
    /// This implementation never returns an error.
    #[track_caller]
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        // We have no use for the JoinHandle; dropping it detaches the task.
        drop(self.handle.spawn(future));
        Ok(())
    }
}