1
//! Implement the default transport, which opens TCP connections (directly, or
//! through a configured outbound proxy) using a happy-eyeballs style parallel
//! algorithm.
3

            
4
use std::{net::SocketAddr, sync::Arc, time::Duration};
5

            
6
use async_trait::async_trait;
7
use futures::{FutureExt, StreamExt, TryFutureExt, stream::FuturesUnordered};
8
use safelog::sensitive as sv;
9
use tor_error::bad_api_usage;
10
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11
use tor_proto::peer::PeerAddr;
12
use tor_rtcompat::{NetStreamProvider, Runtime};
13
use tracing::{instrument, trace};
14

            
15
use crate::Error;
16

            
17
/// A default transport object that opens TCP connections for a
18
/// `ChannelMethod::Direct`.
19
///
20
/// It opens almost-simultaneous parallel TCP connections to each address, and
21
/// chooses the first one to succeed.
22
#[derive(Clone, Debug)]
23
pub(crate) struct DefaultTransport<R: Runtime> {
24
    /// The runtime that we use for connecting.
25
    runtime: R,
26
    /// The outbound proxy to use, if any
27
    outbound_proxy: Option<crate::config::ProxyProtocol>,
28
}
29

            
30
impl<R: Runtime> DefaultTransport<R> {
31
    /// Construct a new DefaultTransport
32
60
    pub(crate) fn new(runtime: R, outbound_proxy: Option<crate::config::ProxyProtocol>) -> Self {
33
60
        Self {
34
60
            runtime,
35
60
            outbound_proxy,
36
60
        }
37
60
    }
38
}
39

            
40
#[async_trait]
41
impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
42
    type Stream = <R as NetStreamProvider>::Stream;
43

            
44
    /// Implements the transport: makes a TCP connection (possibly
45
    /// tunneled over whatever protocol) if possible.
46
    #[instrument(skip_all, level = "trace")]
47
    async fn connect(&self, target: &OwnedChanTarget) -> crate::Result<(PeerAddr, Self::Stream)> {
48
        let direct_addrs: Vec<_> = match target.chan_method() {
49
            ChannelMethod::Direct(addrs) => addrs,
50
            #[allow(unreachable_patterns)]
51
            _ => {
52
                return Err(Error::UnusableTarget(bad_api_usage!(
53
                    "Used default transport implementation for an unsupported transport."
54
                )));
55
            }
56
        };
57

            
58
        trace!("Launching direct connection for {}", target);
59

            
60
        let (stream, addr) =
61
            connect_to_one(&self.runtime, &direct_addrs, &self.outbound_proxy).await?;
62
        Ok((addr.into(), stream))
63
    }
64
}
65

            
66
/// Time to wait between starting parallel connections to the same relay.
///
/// The `i`th candidate address (counting from zero) is attempted after `i`
/// multiples of this delay.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
68

            
69
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
70
///
71
/// This implements a basic version of RFC 8305 "happy eyeballs".
72
#[instrument(skip_all, level = "trace")]
73
30
async fn connect_to_one<R: Runtime>(
74
30
    rt: &R,
75
30
    addrs: &[SocketAddr],
76
30
    outbound_proxy: &Option<crate::config::ProxyProtocol>,
77
30
) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
78
    // We need *some* addresses to connect to.
79
    if addrs.is_empty() {
80
        return Err(Error::UnusableTarget(bad_api_usage!(
81
            "No addresses for chosen relay"
82
        )));
83
    }
84

            
85
    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
86
    // attempts to connect to the address using the runtime (where i is the
87
    // array index). Shove all of these into a `FuturesUnordered`, polling them
88
    // simultaneously and returning the results in completion order.
89
    //
90
    // This is basically the concurrent-connection stuff from RFC 8305, ish.
91
    // TODO(eta): sort the addresses first?
92
    let mut connections = addrs
93
        .iter()
94
        .enumerate()
95
52
        .map(|(i, a)| {
96
52
            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
97
52
            let proxy = outbound_proxy.clone();
98
52
            delay.then(move |_| {
99
40
                tracing::debug!("Connecting to {}", a);
100
40
                let a = *a;
101
40
                async move {
102
40
                    let stream = if let Some(ref protocol) = proxy {
103
                        // Use proxy - extract address and protocol details
104
                        let (proxy_addr, version, auth) = match protocol {
105
                            crate::config::ProxyProtocol::Socks {
106
                                version,
107
                                auth,
108
                                addr,
109
                            } => (*addr, *version, auth.clone()),
110
                        };
111
                        let target = tor_linkspec::PtTargetAddr::IpPort(a);
112
                        let proto = super::proxied::Protocol::Socks(version, auth);
113
                        super::proxied::connect_via_proxy(rt, &proxy_addr, &proto, &target).await
114
                    } else {
115
                        // Direct connection
116
40
                        rt.connect(&a)
117
40
                            .await
118
30
                            .map_err(super::proxied::ProxyError::from)
119
10
                    }?;
120
20
                    Ok((stream, a))
121
30
                }
122
40
                .map_err(move |e: super::proxied::ProxyError| (e, a))
123
40
            })
124
52
        })
125
        .collect::<FuturesUnordered<_>>();
126

            
127
    let mut ret = None;
128
    let mut errors = vec![];
129

            
130
    while let Some(result) = connections.next().await {
131
        match result {
132
            Ok(s) => {
133
                // We got a stream (and address).
134
                ret = Some(s);
135
                break;
136
            }
137
            Err((e, a)) => {
138
                // We got a failure on one of the streams. Store the error.
139
                // TODO(eta): ideally we'd start the next connection attempt immediately.
140
                tor_error::warn_report!(&e, "Connection to {} failed", sv(a));
141
                errors.push((e, a));
142
            }
143
        }
144
    }
145

            
146
    // Ensure we don't continue trying to make connections.
147
    drop(connections);
148

            
149
    ret.ok_or_else(|| Error::ChannelBuild {
150
2
        addresses: errors
151
2
            .into_iter()
152
2
            .map(|(e, a)| (sv(a), Arc::new(std::io::Error::from(e))))
153
2
            .collect(),
154
2
    })
155
24
}
156

            
157
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{SleepProviderExt, test_with_one_runtime};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` on a mock network containing two working
    /// relays, an address with nothing behind it, and a black hole.
    #[test]
    fn test_connect_one() {
        let client_ip = "192.0.1.16".parse().unwrap();
        // A "relay" listens here.
        let relay_a = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // Nothing listens here, so attempts fail with an error.
        let refused = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole here, so attempts never complete.
        let blackhole = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // A second "relay" listens here.
        let relay_b = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // A fake internet, so these connections can succeed without real I/O.
            let network = MockNetwork::new();

            // Separate client and server runtimes, each bound to its own IPs.
            let client_rt = network
                .builder()
                .add_address(client_ip)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(relay_a.ip())
                .add_address(relay_b.ip())
                .runtime(rt.clone());
            let _listener_a = server_rt.mock_net().listen(&relay_a).await.unwrap();
            let _listener_b = server_rt.mock_net().listen(&relay_b).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to the black hole to time out.
            // It would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(blackhole).unwrap();

            // With no addresses at all, we cannot succeed.
            assert!(connect_to_one(&client_rt, &[], &None).await.is_err());

            // Any candidate set that contains relay_a must end up at relay_a.
            let succeeding: [&[SocketAddr]; 7] = [
                &[relay_a],
                &[relay_a, refused],
                &[refused, relay_a],
                &[relay_a, blackhole],
                &[blackhole, relay_a],
                &[relay_a, refused, blackhole],
                &[blackhole, refused, relay_a],
            ];
            for candidates in succeeding {
                let (_conn, chosen) = connect_to_one(&client_rt, candidates, &None)
                    .await
                    .unwrap();
                assert_eq!(chosen, relay_a);
            }

            // Without relay_a, every set fails. It either returns an error, or,
            // when the black hole is involved, never finishes before our timeout.
            let failing: [&[SocketAddr]; 4] = [
                &[refused],
                &[refused, blackhole],
                &[blackhole, refused],
                &[blackhole],
            ];
            for candidates in failing {
                let outcome = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, candidates, &None),
                    )
                    .await;
                if candidates.contains(&blackhole) {
                    // The black hole never answers, so the timeout fires.
                    assert!(outcome.is_err());
                } else {
                    // Finished in time, but with a connection error.
                    assert!(outcome.unwrap().is_err());
                }
            }

            // When offered two working relays, the one listed first wins.
            let (_first_conn, chosen) = connect_to_one(&client_rt, &[relay_a, relay_b], &None)
                .await
                .unwrap();
            assert_eq!(chosen, relay_a);
            let (_second_conn, chosen) = connect_to_one(&client_rt, &[relay_b, relay_a], &None)
                .await
                .unwrap();
            assert_eq!(chosen, relay_b);
        });
    }
}