1
//! Implement the default transport, which opens TCP connections (optionally
//! through an outbound proxy) using a happy-eyeballs-style parallel algorithm.
3

            
4
use std::{net::SocketAddr, time::Duration};
5

            
6
use async_trait::async_trait;
7
use futures::{FutureExt, StreamExt, TryFutureExt, stream::FuturesUnordered};
8
use safelog::sensitive as sv;
9
use tor_error::bad_api_usage;
10
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11
use tor_proto::peer::PeerAddr;
12
use tor_rtcompat::{NetStreamProvider, Runtime};
13
use tracing::{instrument, trace};
14

            
15
use crate::{Error, err::ConnectError};
16

            
17
/// A default transport object that opens TCP connections for a
18
/// `ChannelMethod::Direct`.
19
///
20
/// It opens almost-simultaneous parallel TCP connections to each address, and
21
/// chooses the first one to succeed.
22
#[derive(Clone, Debug)]
23
pub(crate) struct DefaultTransport<R: Runtime> {
24
    /// The runtime that we use for connecting.
25
    runtime: R,
26
    /// The outbound proxy to use, if any
27
    outbound_proxy: Option<crate::config::ProxyProtocol>,
28
}
29

            
30
impl<R: Runtime> DefaultTransport<R> {
31
    /// Construct a new DefaultTransport
32
60
    pub(crate) fn new(runtime: R, outbound_proxy: Option<crate::config::ProxyProtocol>) -> Self {
33
60
        Self {
34
60
            runtime,
35
60
            outbound_proxy,
36
60
        }
37
60
    }
38
}
39

            
40
#[async_trait]
41
impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
42
    type Stream = <R as NetStreamProvider>::Stream;
43

            
44
    /// Implements the transport: makes a TCP connection (possibly
45
    /// tunneled over whatever protocol) if possible.
46
    #[instrument(skip_all, level = "trace")]
47
    async fn connect(&self, target: &OwnedChanTarget) -> crate::Result<(PeerAddr, Self::Stream)> {
48
        let direct_addrs: Vec<_> = match target.chan_method() {
49
            ChannelMethod::Direct(addrs) => addrs,
50
            #[allow(unreachable_patterns)]
51
            _ => {
52
                return Err(Error::UnusableTarget(bad_api_usage!(
53
                    "Used default transport implementation for an unsupported transport."
54
                )));
55
            }
56
        };
57

            
58
        trace!("Launching direct connection for {}", target);
59

            
60
        let (stream, addr) =
61
            connect_to_one(&self.runtime, &direct_addrs, &self.outbound_proxy).await?;
62
        Ok((addr.into(), stream))
63
    }
64
}
65

            
66
/// Time to wait between starting parallel connections to the same relay.
///
/// This is the "Connection Attempt Delay" of RFC 8305: that RFC recommends a
/// default of 250 ms with a minimum of 100 ms, and we use a value in between.
///
/// This is a `const` rather than a `static`: it is a plain value that needs
/// neither a fixed address nor interior mutability.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
68

            
69
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
70
///
71
/// This implements a basic version of RFC 8305 "happy eyeballs".
72
#[instrument(skip_all, level = "trace")]
73
30
async fn connect_to_one<R: Runtime>(
74
30
    rt: &R,
75
30
    addrs: &[SocketAddr],
76
30
    outbound_proxy: &Option<crate::config::ProxyProtocol>,
77
30
) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
78
    // We need *some* addresses to connect to.
79
    if addrs.is_empty() {
80
        return Err(Error::UnusableTarget(bad_api_usage!(
81
            "No addresses for chosen relay"
82
        )));
83
    }
84

            
85
    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
86
    // attempts to connect to the address using the runtime (where i is the
87
    // array index). Shove all of these into a `FuturesUnordered`, polling them
88
    // simultaneously and returning the results in completion order.
89
    //
90
    // This is basically the concurrent-connection stuff from RFC 8305, ish.
91
    // TODO(eta): sort the addresses first?
92
    let mut connections = addrs
93
        .iter()
94
        .enumerate()
95
52
        .map(|(i, a)| {
96
52
            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
97
52
            let proxy = outbound_proxy.clone();
98
52
            delay.then(move |_| {
99
40
                tracing::debug!("Connecting to {}", a);
100
40
                let a = *a;
101
40
                async move {
102
40
                    let stream = if let Some(ref protocol) = proxy {
103
                        // Use proxy - extract address and protocol details
104
                        let target = tor_linkspec::PtTargetAddr::IpPort(a);
105
                        match protocol {
106
                            crate::config::ProxyProtocol::Socks {
107
                                version,
108
                                auth,
109
                                addr,
110
                            } => {
111
                                let proto = super::proxied::Protocol::Socks(*version, auth.clone());
112
                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
113
                            }
114
                            crate::config::ProxyProtocol::HttpConnect { addr, credentials } => {
115
                                // Wrap credentials in Sensitive to avoid accidental logging.
116
                                let auth = credentials.as_ref().map(|cred| {
117
                                    (
118
                                        safelog::Sensitive::new(cred.username.clone()),
119
                                        safelog::Sensitive::new(
120
                                            cred.password.clone().unwrap_or_default(),
121
                                        ),
122
                                    )
123
                                });
124
                                let proto = super::proxied::Protocol::HttpConnect { auth };
125
                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
126
                            }
127
                        }
128
                    } else {
129
                        // Direct connection
130
40
                        rt.connect(&a).await?
131
                    };
132
20
                    Ok((stream, a))
133
30
                }
134
40
                .map_err(move |e: ConnectError| (e, a))
135
40
            })
136
52
        })
137
        .collect::<FuturesUnordered<_>>();
138

            
139
    let mut ret = None;
140
    let mut errors: Vec<(ConnectError, SocketAddr)> = vec![];
141

            
142
    while let Some(result) = connections.next().await {
143
        match result {
144
            Ok(s) => {
145
                // We got a stream (and address).
146
                ret = Some(s);
147
                break;
148
            }
149
            Err((e, a)) => {
150
                // We got a failure on one of the streams. Store the error.
151
                // TODO(eta): ideally we'd start the next connection attempt immediately.
152
                errors.push((e, a));
153
            }
154
        }
155
    }
156

            
157
    // Ensure we don't continue trying to make connections.
158
    drop(connections);
159

            
160
    ret.ok_or_else(|| Error::Connect {
161
2
        addresses: errors
162
2
            .into_iter()
163
2
            .map(|(e, a)| (sv(a.to_string()), e))
164
2
            .collect(),
165
2
    })
166
24
}
167

            
168
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{SleepProviderExt, test_with_one_runtime};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    #[test]
    fn test_connect_one() {
        let client_addr = "192.0.1.16".parse().unwrap();
        // We'll put a "relay" at this address.
        let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // We'll put nothing at this address, to generate errors.
        let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole at this address, to generate timeouts.
        let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // We'll put a "relay" at this address too.
        let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // Stub out the internet so that this connection can work.
            let network = MockNetwork::new();

            // Set up a client and server runtime with a given IP
            let client_rt = network
                .builder()
                .add_address(client_addr)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(addr1.ip())
                .add_address(addr4.ip())
                .runtime(rt.clone());
            let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
            let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(addr3).unwrap();

            // No addresses? Can't succeed, and the target is reported as unusable.
            let failure = connect_to_one(&client_rt, &[], &None).await;
            assert!(matches!(failure, Err(Error::UnusableTarget(_))));

            // Connect to a set of addresses including addr1? That's a success.
            for addresses in [
                &[addr1][..],
                &[addr1, addr2][..],
                &[addr2, addr1][..],
                &[addr1, addr3][..],
                &[addr3, addr1][..],
                &[addr1, addr2, addr3][..],
                &[addr3, addr2, addr1][..],
            ] {
                let (_conn, addr) = connect_to_one(&client_rt, addresses, &None).await.unwrap();
                assert_eq!(addr, addr1);
            }

            // Connect to a set of addresses that doesn't include addr1? That's an
            // error of one kind or another: a connection error when only addr2
            // is listed, or a timeout whenever the addr3 black hole is listed.
            for addresses in [
                &[addr2][..],
                &[addr2, addr3][..],
                &[addr3, addr2][..],
                &[addr3][..],
            ] {
                let expect_timeout = addresses.contains(&addr3);
                let failure = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses, &None),
                    )
                    .await;
                if expect_timeout {
                    assert!(failure.is_err());
                } else {
                    assert!(matches!(failure.unwrap(), Err(Error::Connect { .. })));
                }
            }

            // Connect to addr1 and addr4? The one listed first should win,
            // since its attempt is started first.
            let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4], &None)
                .await
                .unwrap();
            assert_eq!(addr, addr1);
            let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1], &None)
                .await
                .unwrap();
            assert_eq!(addr, addr4);
        });
    }
}