//! Re-exports of the async_std runtime for use with arti.
//!
//! This crate helps define a slim API around our async runtime so that we
//! can easily swap it out.
//!
//! We'll probably want to support tokio as well in the future.
7

            
8
/// Types used for networking (async_std implementation)
9
mod net {
10
    use crate::{impls, traits};
11

            
12
    use async_std_crate::net::{TcpListener, TcpStream, UdpSocket as StdUdpSocket};
13
    #[cfg(unix)]
14
    use async_std_crate::os::unix::net::{UnixListener, UnixStream};
15
    use async_trait::async_trait;
16
    use futures::stream::{self, Stream};
17
    use paste::paste;
18
    use std::io::Result as IoResult;
19
    use std::net::SocketAddr;
20
    use std::pin::Pin;
21
    use std::task::{Context, Poll};
22
    #[cfg(unix)]
23
    use tor_general_addr::unix;
24
    use tracing::instrument;
25

            
26
    /// Generate the listener-side glue needed by `NetStreamProvider` for one
    /// address type.
    ///
    /// `$kind` is the prefix shared by the stream and listener type names
    /// (`$kind Stream`, `$kind Listener`); `$addr` is the address type that
    /// accompanies each accepted connection.
    macro_rules! impl_stream {
        { $kind:ident, $addr:ty } => {paste!{
            /// A `Stream` of incoming streams.
            ///
            /// Unlike the value returned by `*Listener::incoming`, this is a
            /// nameable type, and every item carries both the accepted stream
            /// and the address that came with it.
            pub struct [<Incoming $kind Streams>] {
                /// Boxed stream that keeps calling `accept` on the listener.
                inner: Pin<Box<dyn Stream<Item = IoResult<([<$kind Stream>], $addr)>> + Send + Sync>>,
            }
            impl [<Incoming $kind Streams>] {
                /// Build a stream of incoming connections that takes ownership of `lis`.
                pub fn from_listener(lis: [<$kind Listener>]) -> [<Incoming $kind Streams>] {
                    // The listener is threaded through `unfold` as its state.
                    // The closure always returns `Some`, so accept errors are
                    // delivered as items rather than ending the stream.
                    let accept_loop = stream::unfold(lis, |listener| async move {
                        let accepted = listener.accept().await;
                        Some((accepted, listener))
                    });
                    Self {
                        inner: Box::pin(accept_loop),
                    }
                }
            }
            impl Stream for [< Incoming $kind Streams >] {
                type Item = IoResult<([<$kind Stream>], $addr)>;

                /// Forward the poll to the boxed inner stream.
                fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
                    let this = self.get_mut();
                    this.inner.as_mut().poll_next(cx)
                }
            }
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
                type Stream = [<$kind Stream>];
                type Incoming = [<Incoming $kind Streams>];
                /// Consume the listener and turn it into a stream of connections.
                fn incoming(self) -> Self::Incoming {
                    [<Incoming $kind Streams>]::from_listener(self)
                }
                /// Report the address this listener is bound to.
                fn local_addr(&self) -> IoResult<$addr> {
                    // Named through the concrete type, as a path call rather
                    // than a method call on `self`.
                    [<$kind Listener>]::local_addr(self)
                }
            }
        }}
    }
69

            
70
    impl_stream! { Tcp, std::net::SocketAddr }
71
    #[cfg(unix)]
72
    impl_stream! { Unix, unix::SocketAddr}
73

            
74
    #[async_trait]
75
    impl traits::NetStreamProvider<std::net::SocketAddr> for async_executors::AsyncStd {
76
        type Stream = TcpStream;
77
        type Listener = TcpListener;
78
        #[instrument(skip_all, level = "trace")]
79
        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
80
            TcpStream::connect(addr).await
81
        }
82
6
        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
83
            // Use an implementation that's the same across all runtimes.
84
            Ok(impls::tcp_listen(addr)?.into())
85
6
        }
86
    }
87

            
88
    #[cfg(unix)]
89
    #[async_trait]
90
    impl traits::NetStreamProvider<unix::SocketAddr> for async_executors::AsyncStd {
91
        type Stream = UnixStream;
92
        type Listener = UnixListener;
93
        #[instrument(skip_all, level = "trace")]
94
        async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
95
            let path = addr
96
                .as_pathname()
97
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
98
            UnixStream::connect(path).await
99
        }
100
        async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
101
            let path = addr
102
                .as_pathname()
103
                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
104
            UnixListener::bind(path).await
105
        }
106
    }
107

            
108
    // On non-Unix targets the AF_UNIX provider impl for this runtime comes
    // from the shared `impl_unix_non_provider!` macro instead.
    #[cfg(not(unix))]
    impls::impl_unix_non_provider! { async_executors::AsyncStd }
110

            
111
    #[async_trait]
112
    impl traits::UdpProvider for async_executors::AsyncStd {
113
        type UdpSocket = UdpSocket;
114

            
115
4
        async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
116
            StdUdpSocket::bind(*addr)
117
                .await
118
4
                .map(|socket| UdpSocket { socket })
119
4
        }
120
    }
121

            
122
    /// Wrap a AsyncStd UdpSocket
123
    pub struct UdpSocket {
124
        /// The underlying UdpSocket
125
        socket: StdUdpSocket,
126
    }
127

            
128
    #[async_trait]
129
    impl traits::UdpSocket for UdpSocket {
130
2
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
131
            self.socket.recv_from(buf).await
132
2
        }
133

            
134
2
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
135
            self.socket.send_to(buf, target).await
136
2
        }
137

            
138
4
        fn local_addr(&self) -> IoResult<SocketAddr> {
139
4
            self.socket.local_addr()
140
4
        }
141
    }
142

            
143
    impl traits::StreamOps for TcpStream {
144
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
145
            impls::streamops::set_tcp_notsent_lowat(self, notsent_lowat)
146
        }
147

            
148
        #[cfg(target_os = "linux")]
149
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
150
            Box::new(impls::streamops::TcpSockFd::from_fd(self))
151
        }
152
    }
153

            
154
    #[cfg(unix)]
155
    impl traits::StreamOps for UnixStream {
156
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
157
            Err(traits::UnsupportedStreamOp::new(
158
                "set_tcp_notsent_lowat",
159
                "unsupported on Unix streams",
160
            )
161
            .into())
162
        }
163
    }
164
}
165

            
166
// ==============================
167

            
168
use futures::{Future, FutureExt};
169
use std::pin::Pin;
170
use std::time::Duration;
171

            
172
use crate::traits::*;
173

            
174
/// Create and return a new `async_std` runtime.
175
1052
pub fn create_runtime() -> async_executors::AsyncStd {
176
1052
    async_executors::AsyncStd::new()
177
1052
}
178

            
179
impl SleepProvider for async_executors::AsyncStd {
180
    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
181
738
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
182
738
        Box::pin(async_io::Timer::after(duration).map(|_| ()))
183
738
    }
184
}
185

            
186
impl ToplevelBlockOn for async_executors::AsyncStd {
187
318
    fn block_on<F: Future>(&self, f: F) -> F::Output {
188
318
        async_executors::AsyncStd::block_on(f)
189
318
    }
190
}
191

            
192
impl Blocking for async_executors::AsyncStd {
193
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
194

            
195
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
196
    where
197
        F: FnOnce() -> T + Send + 'static,
198
        T: Send + 'static,
199
    {
200
        async_executors::SpawnBlocking::spawn_blocking(&self, f)
201
    }
202

            
203
    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
204
        async_executors::AsyncStd::block_on(f)
205
    }
206
}