1
//! Implement the default transport, which opens TCP connections (directly, or
//! through a configured outbound proxy) using a happy-eyeballs style parallel
//! algorithm.
3

            
4
use std::{net::SocketAddr, time::Duration};
5

            
6
use async_trait::async_trait;
7
use futures::{FutureExt, StreamExt, TryFutureExt, stream::FuturesUnordered};
8
use safelog::sensitive as sv;
9
use tor_error::bad_api_usage;
10
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11
use tor_proto::peer::PeerAddr;
12
use tor_rtcompat::{NetStreamProvider, Runtime};
13
use tracing::{instrument, trace};
14

            
15
use crate::{Error, err::ConnectError};
16

            
17
/// A default transport object that opens TCP connections for a
18
/// `ChannelMethod::Direct`.
19
///
20
/// It opens almost-simultaneous parallel TCP connections to each address, and
21
/// chooses the first one to succeed.
22
#[derive(Clone, Debug)]
23
pub(crate) struct DefaultTransport<R: Runtime> {
24
    /// The runtime that we use for connecting.
25
    runtime: R,
26
    /// The outbound proxy to use, if any
27
    outbound_proxy: Option<crate::config::ProxyProtocol>,
28
}
29

            
30
impl<R: Runtime> DefaultTransport<R> {
31
    /// Construct a new DefaultTransport
32
38
    pub(crate) fn new(runtime: R, outbound_proxy: Option<crate::config::ProxyProtocol>) -> Self {
33
38
        Self {
34
38
            runtime,
35
38
            outbound_proxy,
36
38
        }
37
38
    }
38
}
39

            
40
#[async_trait]
41
impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
42
    type Stream = <R as NetStreamProvider>::Stream;
43

            
44
    /// Implements the transport: makes a TCP connection (possibly
45
    /// tunneled over whatever protocol) if possible.
46
    #[instrument(skip_all, level = "trace")]
47
    async fn connect(&self, target: &OwnedChanTarget) -> crate::Result<(PeerAddr, Self::Stream)> {
48
        let direct_addrs: Vec<_> = match target.chan_method() {
49
            ChannelMethod::Direct(addrs) => addrs,
50
            #[allow(unreachable_patterns)]
51
            _ => {
52
                return Err(Error::UnusableTarget(bad_api_usage!(
53
                    "Used default transport implementation for an unsupported transport."
54
                )));
55
            }
56
        };
57

            
58
        trace!("Launching direct connection for {}", target);
59

            
60
        let (stream, addr) =
61
            connect_to_one(&self.runtime, &direct_addrs, &self.outbound_proxy).await?;
62
        Ok((addr.into(), stream))
63
    }
64
}
65

            
66
/// Time to wait between starting parallel connections to the same relay.
///
/// The attempt for the address at index `i` is started after waiting
/// `i * CONNECTION_DELAY`, so earlier addresses get a head start.
///
/// This is a plain value that never needs a fixed memory address, so it is a
/// `const` rather than a `static`.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
68

            
69
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
70
///
71
/// This implements a basic version of RFC 8305 "happy eyeballs".
72
#[instrument(skip_all, level = "trace")]
73
30
async fn connect_to_one<R: Runtime>(
74
30
    rt: &R,
75
30
    addrs: &[SocketAddr],
76
30
    outbound_proxy: &Option<crate::config::ProxyProtocol>,
77
30
) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
78
    // We need *some* addresses to connect to.
79
    if addrs.is_empty() {
80
        return Err(Error::UnusableTarget(bad_api_usage!(
81
            "No addresses for chosen relay"
82
        )));
83
    }
84

            
85
    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
86
    // attempts to connect to the address using the runtime (where i is the
87
    // array index). Shove all of these into a `FuturesUnordered`, polling them
88
    // simultaneously and returning the results in completion order.
89
    //
90
    // This is basically the concurrent-connection stuff from RFC 8305, ish.
91
    // TODO(eta): sort the addresses first?
92
    let mut connections = addrs
93
        .iter()
94
        .enumerate()
95
52
        .map(|(i, a)| {
96
52
            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
97
52
            let proxy = outbound_proxy.clone();
98
52
            delay.then(move |_| {
99
40
                tracing::debug!("Connecting to {}", a);
100
40
                let a = *a;
101
40
                async move {
102
40
                    let stream = if let Some(ref protocol) = proxy {
103
                        // Use proxy - extract address and protocol details
104
                        let target = tor_linkspec::PtTargetAddr::IpPort(a);
105
                        match protocol {
106
                            crate::config::ProxyProtocol::Socks {
107
                                version,
108
                                auth,
109
                                addr,
110
                            } => {
111
                                let proto = super::proxied::Protocol::Socks(*version, auth.clone());
112
                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
113
                            }
114
                            crate::config::ProxyProtocol::HttpConnect { addr, credentials } => {
115
                                // Wrap credentials in Sensitive to avoid accidental logging.
116
                                let auth = credentials.as_ref().map(|cred| {
117
                                    (
118
                                        safelog::Sensitive::new(cred.username.clone()),
119
                                        safelog::Sensitive::new(
120
                                            cred.password.clone().unwrap_or_default(),
121
                                        ),
122
                                    )
123
                                });
124
                                let proto = super::proxied::Protocol::HttpConnect { auth };
125
                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
126
                            }
127
                        }
128
                    } else {
129
                        // Direct connection
130
                        // We don't (yet) use any custom options on the socket.
131
40
                        let connect_options = Default::default();
132
40
                        rt.connect(&a, &connect_options).await?
133
                    };
134
20
                    Ok((stream, a))
135
30
                }
136
40
                .map_err(move |e: ConnectError| (e, a))
137
40
            })
138
52
        })
139
        .collect::<FuturesUnordered<_>>();
140

            
141
    let mut ret = None;
142
    let mut errors: Vec<(ConnectError, SocketAddr)> = vec![];
143

            
144
    while let Some(result) = connections.next().await {
145
        match result {
146
            Ok(s) => {
147
                // We got a stream (and address).
148
                ret = Some(s);
149
                break;
150
            }
151
            Err((e, a)) => {
152
                // We got a failure on one of the streams. Store the error.
153
                // TODO(eta): ideally we'd start the next connection attempt immediately.
154
                errors.push((e, a));
155
            }
156
        }
157
    }
158

            
159
    // Ensure we don't continue trying to make connections.
160
    drop(connections);
161

            
162
    ret.ok_or_else(|| Error::Connect {
163
2
        addresses: errors
164
2
            .into_iter()
165
2
            .map(|(e, a)| (sv(a.to_string()), e))
166
2
            .collect(),
167
2
    })
168
24
}
169

            
170
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{SleepProviderExt, test_with_one_runtime};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` on a mock network containing two listening
    /// "relays", an address where nothing listens, and a black hole.
    #[test]
    fn test_connect_one() {
        let client_addr = "192.0.1.16".parse().unwrap();
        // A "relay" listens at this address.
        let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // Nothing listens here, so connecting produces an error.
        let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole here, so connecting times out.
        let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // A second "relay" listens at this address.
        let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // Stub out the internet so that these connections can work.
            let network = MockNetwork::new();

            // One runtime for the client, and one for the server that owns
            // both relay addresses.
            let client_rt = network
                .builder()
                .add_address(client_addr)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(addr1.ip())
                .add_address(addr4.ip())
                .runtime(rt.clone());

            // Listen on both relay addresses, keeping the listeners alive
            // until the end of the test.
            let listen_options = Default::default();
            let mut listeners = Vec::new();
            for relay in [addr1, addr4] {
                let listener = server_rt
                    .mock_net()
                    .listen(&relay, &listen_options)
                    .await
                    .unwrap();
                listeners.push(listener);
            }

            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(addr3).unwrap();

            // No addresses? Can't succeed.
            assert!(connect_to_one(&client_rt, &[], &None).await.is_err());

            // Any set of addresses that includes addr1 (and not addr4) must
            // succeed, and must end up connected to addr1.
            let with_relay: [&[SocketAddr]; 7] = [
                &[addr1],
                &[addr1, addr2],
                &[addr2, addr1],
                &[addr1, addr3],
                &[addr3, addr1],
                &[addr1, addr2, addr3],
                &[addr3, addr2, addr1],
            ];
            for addresses in with_relay {
                let (_conn, winner) = connect_to_one(&client_rt, addresses, &None)
                    .await
                    .unwrap();
                assert_eq!(winner, addr1);
            }

            // A set of addresses with no relay in it must fail: either with an
            // error, or (when the black hole is involved) by timing out.
            let without_relay: [&[SocketAddr]; 4] =
                [&[addr2], &[addr2, addr3], &[addr3, addr2], &[addr3]];
            for addresses in without_relay {
                let outcome = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses, &None),
                    )
                    .await;
                match addresses.contains(&addr3) {
                    // The black hole never answers, so the timeout fires first.
                    true => assert!(outcome.is_err()),
                    // Every address refused us, so connect_to_one itself failed.
                    false => assert!(outcome.unwrap().is_err()),
                }
            }

            // With two working relays, whichever is listed first should win.
            let (_first_conn, winner) = connect_to_one(&client_rt, &[addr1, addr4], &None)
                .await
                .unwrap();
            assert_eq!(winner, addr1);
            let (_second_conn, winner) = connect_to_one(&client_rt, &[addr4, addr1], &None)
                .await
                .unwrap();
            assert_eq!(winner, addr4);
        });
    }
}