1
//! Support for streams and listeners on `general::SocketAddr`.
2

            
3
use async_trait::async_trait;
4
use futures::{AsyncRead, AsyncWrite, StreamExt as _, stream};
5
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
6
use std::net;
7
use std::task::Poll;
8
use std::{pin::Pin, task::Context};
9
use tor_general_addr::unix;
10
use tracing::instrument;
11

            
12
use crate::{NetStreamListener, NetStreamProvider, StreamOps};
13
use tor_general_addr::general;
14

            
15
pub use general::{AddrParseError, SocketAddr};
16

            
17
/// Helper trait to allow us to create a type-erased stream.
18
///
19
/// (Rust doesn't allow "dyn AsyncRead + AsyncWrite")
20
trait ReadAndWrite: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
21
impl<T> ReadAndWrite for T where T: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
22

            
23
/// A stream returned by a `NetStreamProvider<GeneralizedAddr>`
24
pub struct Stream(Pin<Box<dyn ReadAndWrite>>);
25
impl AsyncRead for Stream {
26
    fn poll_read(
27
        mut self: Pin<&mut Self>,
28
        cx: &mut Context<'_>,
29
        buf: &mut [u8],
30
    ) -> Poll<IoResult<usize>> {
31
        self.0.as_mut().poll_read(cx, buf)
32
    }
33
}
34
impl AsyncWrite for Stream {
35
    fn poll_write(
36
        mut self: Pin<&mut Self>,
37
        cx: &mut Context<'_>,
38
        buf: &[u8],
39
    ) -> Poll<IoResult<usize>> {
40
        self.0.as_mut().poll_write(cx, buf)
41
    }
42

            
43
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
44
        self.0.as_mut().poll_flush(cx)
45
    }
46

            
47
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
48
        self.0.as_mut().poll_close(cx)
49
    }
50
}
51

            
52
impl StreamOps for Stream {
53
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
54
        self.0.set_tcp_notsent_lowat(notsent_lowat)
55
    }
56

            
57
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
58
        self.0.new_handle()
59
    }
60
}
61

            
62
/// The type of the result from an [`IncomingStreams`].
63
type StreamItem = IoResult<(Stream, general::SocketAddr)>;
64

            
65
/// A stream of incoming connections on a [`general::Listener`](Listener).
66
pub struct IncomingStreams(Pin<Box<dyn stream::Stream<Item = StreamItem> + Send + Sync>>);
67

            
68
impl stream::Stream for IncomingStreams {
69
    type Item = IoResult<(Stream, general::SocketAddr)>;
70

            
71
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72
        self.0.as_mut().poll_next(cx)
73
    }
74
}
75

            
76
/// A listener returned by a `NetStreamProvider<general::SocketAddr>`.
77
pub struct Listener {
78
    /// The `futures::Stream` of incoming network streams.
79
    streams: IncomingStreams,
80
    /// The local address on which we're listening.
81
    local_addr: general::SocketAddr,
82
}
83

            
84
impl NetStreamListener<general::SocketAddr> for Listener {
85
    type Stream = Stream;
86
    type Incoming = IncomingStreams;
87

            
88
    fn incoming(self) -> IncomingStreams {
89
        self.streams
90
    }
91

            
92
    fn local_addr(&self) -> IoResult<general::SocketAddr> {
93
        Ok(self.local_addr.clone())
94
    }
95
}
96

            
97
/// Use `provider` to launch a `NetStreamListener` at `address`, and wrap that listener
98
/// as a `Listener`.
99
async fn abstract_listener_on<ADDR, P>(
100
    provider: &P,
101
    address: &ADDR,
102
    options: &P::ListenOptions,
103
) -> IoResult<Listener>
104
where
105
    P: NetStreamProvider<ADDR>,
106
    general::SocketAddr: From<ADDR>,
107
{
108
    let lis = provider.listen(address, options).await?;
109
    let local_addr = general::SocketAddr::from(lis.local_addr()?);
110
    let streams = lis.incoming().map(|result| {
111
        result.map(|(socket, addr)| (Stream(Box::pin(socket)), general::SocketAddr::from(addr)))
112
    });
113
    let streams = IncomingStreams(Box::pin(streams));
114
    Ok(Listener {
115
        streams,
116
        local_addr,
117
    })
118
}
119

            
120
#[async_trait]
121
impl<T> NetStreamProvider<general::SocketAddr> for T
122
where
123
    T: NetStreamProvider<net::SocketAddr> + NetStreamProvider<unix::SocketAddr>,
124
{
125
    type Stream = Stream;
126
    type Listener = Listener;
127
    // TODO: If unix sockets ever support `CommonConnectOptions`,
128
    // we could accept these common options and convert to the appropriate type.
129
    type ConnectOptions = ();
130
    // TODO: If unix sockets ever support `CommonListenOptions`,
131
    // we could accept these common options and convert to the appropriate type.
132
    type ListenOptions = ();
133

            
134
    #[instrument(skip_all, level = "trace")]
135
    async fn connect(
136
        &self,
137
        addr: &general::SocketAddr,
138
        (): &Self::ConnectOptions,
139
    ) -> IoResult<Stream> {
140
        use general::SocketAddr as G;
141
        match addr {
142
            G::Inet(a) => {
143
                let options = Default::default();
144
                Ok(Stream(Box::pin(self.connect(a, &options).await?)))
145
            }
146
            G::Unix(a) => {
147
                let options = Default::default();
148
                Ok(Stream(Box::pin(self.connect(a, &options).await?)))
149
            }
150
            other => Err(IoError::new(
151
                IoErrorKind::InvalidInput,
152
                UnsupportedAddress(other.clone()),
153
            )),
154
        }
155
    }
156
    async fn listen(
157
        &self,
158
        addr: &general::SocketAddr,
159
        (): &Self::ListenOptions,
160
    ) -> IoResult<Listener> {
161
        use general::SocketAddr as G;
162
        match addr {
163
            G::Inet(a) => abstract_listener_on(self, a, &Default::default()).await,
164
            G::Unix(a) => abstract_listener_on(self, a, &Default::default()).await,
165
            other => Err(IoError::new(
166
                IoErrorKind::InvalidInput,
167
                UnsupportedAddress(other.clone()),
168
            )),
169
        }
170
    }
171
}
172

            
173
/// Tried to use a [`general::SocketAddr`] that `tor-rtcompat` didn't understand.
174
#[derive(Clone, Debug, thiserror::Error)]
175
#[error("Socket address {0:?} is not supported by tor-rtcompat")]
176
pub struct UnsupportedAddress(general::SocketAddr);