1
//! Abstract implementation of a channel manager
2

            
3
use crate::factory::BootstrapReporter;
4
use crate::mgr::state::{ChannelForTarget, PendingChannelHandle};
5
use crate::util::defer::Defer;
6
use crate::{ChanProvenance, ChannelConfig, ChannelUsage, Dormancy, Error, Result};
7

            
8
use async_trait::async_trait;
9
use futures::future::Shared;
10
use oneshot_fused_workaround as oneshot;
11
use std::result::Result as StdResult;
12
use std::sync::Arc;
13
use std::time::Duration;
14
use tor_error::{error_report, internal};
15
use tor_linkspec::{HasChanMethod, HasRelayIds};
16
use tor_netdir::params::NetParameters;
17
use tor_proto::channel::kist::KistParams;
18
use tor_proto::channel::params::ChannelPaddingInstructionsUpdates;
19
use tor_proto::memquota::{ChannelAccount, SpecificAccount as _, ToplevelAccount};
20
use tracing::{instrument, trace};
21

            
22
#[cfg(feature = "relay")]
23
use {safelog::Sensitive, std::net::IpAddr, tor_proto::RelayIdentities};
24

            
25
mod select;
26
mod state;
27

            
28
/// Trait to describe as much of a
29
/// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr`
30
/// needs to use.
31
pub(crate) trait AbstractChannel: HasRelayIds {
32
    /// Return true iff this channel is considered canonical by us.
33
    fn is_canonical(&self) -> bool;
34
    /// Return true if we think the peer considers this channel as canonical.
35
    fn is_canonical_to_peer(&self) -> bool;
36
    /// Return true if this channel is usable.
37
    ///
38
    /// A channel might be unusable because it is closed, because it has
39
    /// hit a bug, or for some other reason.  We don't return unusable
40
    /// channels back to the user.
41
    fn is_usable(&self) -> bool;
42
    /// Return the amount of time a channel has not been in use.
43
    /// Return None if the channel is currently in use.
44
    fn duration_unused(&self) -> Option<Duration>;
45

            
46
    /// Reparameterize this channel according to the provided `ChannelPaddingInstructionsUpdates`
47
    ///
48
    /// The changed parameters may not be implemented "immediately",
49
    /// but this will be done "reasonably soon".
50
    fn reparameterize(
51
        &self,
52
        updates: Arc<ChannelPaddingInstructionsUpdates>,
53
    ) -> tor_proto::Result<()>;
54

            
55
    /// Update the KIST parameters.
56
    ///
57
    /// The changed parameters may not be implemented "immediately",
58
    /// but this will be done "reasonably soon".
59
    fn reparameterize_kist(&self, kist_params: KistParams) -> tor_proto::Result<()>;
60

            
61
    /// Specify that this channel should do activities related to channel padding
62
    ///
63
    /// See [`Channel::engage_padding_activities`]
64
    ///
65
    /// [`Channel::engage_padding_activities`]: tor_proto::channel::Channel::engage_padding_activities
66
    fn engage_padding_activities(&self);
67
}
68

            
69
/// Trait to describe how channels-like objects are created.
70
///
71
/// This differs from [`ChannelFactory`](crate::factory::ChannelFactory) in that
72
/// it's a purely crate-internal type that we use to decouple the
73
/// AbstractChanMgr code from actual "what is a channel" concerns.
74
#[async_trait]
75
pub(crate) trait AbstractChannelFactory {
76
    /// The type of channel that this factory can build.
77
    type Channel: AbstractChannel;
78
    /// Type that explains how to build an outgoing channel.
79
    type BuildSpec: HasRelayIds + HasChanMethod;
80
    /// The type of byte stream that's required to build channels for incoming connections.
81
    type Stream;
82

            
83
    /// Construct a new channel to the destination described at `target`.
84
    ///
85
    /// This function must take care of all timeouts, error detection,
86
    /// and so on.
87
    ///
88
    /// It should not retry; that is handled at a higher level.
89
    async fn build_channel(
90
        &self,
91
        target: &Self::BuildSpec,
92
        reporter: BootstrapReporter,
93
        memquota: ChannelAccount,
94
    ) -> Result<Arc<Self::Channel>>;
95

            
96
    /// Construct a new channel for an incoming connection.
97
    #[cfg(feature = "relay")]
98
    async fn build_channel_using_incoming(
99
        &self,
100
        peer: Sensitive<std::net::SocketAddr>,
101
        stream: Self::Stream,
102
        memquota: ChannelAccount,
103
    ) -> Result<Arc<Self::Channel>>;
104
}
105

            
106
/// This is the configuration for a [`ChanMgr`](crate::ChanMgr) given to the constructor.
107
#[derive(Default)]
108
pub struct ChanMgrConfig {
109
    /// Channel configuration which usually comes from a configuration file.
110
    pub(crate) cfg: ChannelConfig,
111
    /// Relay identities needed for relay channels.
112
    #[cfg(feature = "relay")]
113
    pub(crate) identities: Option<Arc<RelayIdentities>>,
114
    /// Our address(es). When building outgoing channel, we need our addresses in order to send
115
    /// them in the NETINFO cell.
116
    #[cfg(feature = "relay")]
117
    pub(crate) my_addrs: Vec<IpAddr>,
118
    // TODO: Would be good to add more things such as NetParameters and Dormancy maybe?
119
}
120

            
121
impl ChanMgrConfig {
122
    /// Constructor.
123
396
    pub fn new(cfg: ChannelConfig) -> Self {
124
396
        Self {
125
396
            cfg,
126
396
            #[cfg(feature = "relay")]
127
396
            identities: None,
128
396
            #[cfg(feature = "relay")]
129
396
            my_addrs: Vec::new(),
130
396
        }
131
396
    }
132

            
133
    /// Set the relay identities and return itself.
134
    #[cfg(feature = "relay")]
135
    pub fn with_identities(mut self, ids: Arc<RelayIdentities>) -> Self {
136
        self.identities = Some(ids);
137
        self
138
    }
139

            
140
    /// Set our addresses that we advertise to the world.
141
    #[cfg(feature = "relay")]
142
    pub fn with_my_addrs(mut self, my_addrs: Vec<IpAddr>) -> Self {
143
        self.my_addrs = my_addrs;
144
        self
145
    }
146
}
147

            
148
/// A type- and network-agnostic implementation for [`ChanMgr`](crate::ChanMgr).
149
///
150
/// This type does the work of keeping track of open channels and pending
151
/// channel requests, launching requests as needed, waiting for pending
152
/// requests, and so forth.
153
///
154
/// The actual job of launching connections is deferred to an
155
/// `AbstractChannelFactory` type.
156
pub(crate) struct AbstractChanMgr<CF: AbstractChannelFactory> {
157
    /// All internal state held by this channel manager.
158
    ///
159
    /// The most important part is the map from relay identity to channel, or
160
    /// to pending channel status.
161
    pub(crate) channels: state::MgrState<CF>,
162

            
163
    /// A bootstrap reporter to give out when building channels.
164
    pub(crate) reporter: BootstrapReporter,
165

            
166
    /// The memory quota account that every channel will be a child of
167
    pub(crate) memquota: ToplevelAccount,
168
}
169

            
170
/// Type alias for a future that we wait on to see when a pending
171
/// channel is done or failed.
172
type Pending = Shared<oneshot::Receiver<Result<()>>>;
173

            
174
/// Type alias for the sender we notify when we complete a channel (or fail to
175
/// complete it).
176
type Sending = oneshot::Sender<Result<()>>;
177

            
178
impl<CF: AbstractChannelFactory + Clone> AbstractChanMgr<CF> {
179
    /// Make a new empty channel manager.
180
116
    pub(crate) fn new(
181
116
        connector: CF,
182
116
        config: ChannelConfig,
183
116
        dormancy: Dormancy,
184
116
        netparams: &NetParameters,
185
116
        reporter: BootstrapReporter,
186
116
        memquota: ToplevelAccount,
187
116
    ) -> Self {
188
116
        AbstractChanMgr {
189
116
            channels: state::MgrState::new(connector, config, dormancy, netparams),
190
116
            reporter,
191
116
            memquota,
192
116
        }
193
116
    }
194

            
195
    /// Run a function to modify the channel builder in this object.
196
    #[allow(unused)]
197
22
    pub(crate) fn with_mut_builder<F>(&self, func: F)
198
22
    where
199
22
        F: FnOnce(&mut CF),
200
    {
201
22
        self.channels.with_mut_builder(func);
202
22
    }
203

            
204
    /// Remove every unusable entry from this channel manager.
205
    #[cfg(test)]
206
2
    pub(crate) fn remove_unusable_entries(&self) -> Result<()> {
207
2
        self.channels.remove_unusable()
208
2
    }
209

            
210
    /// Build a channel for an incoming stream. See
211
    /// [`ChanMgr::handle_incoming`](crate::ChanMgr::handle_incoming).
212
    #[cfg(feature = "relay")]
213
    pub(crate) async fn handle_incoming(
214
        &self,
215
        src: Sensitive<std::net::SocketAddr>,
216
        stream: CF::Stream,
217
    ) -> Result<Arc<CF::Channel>> {
218
        let chan_builder = self.channels.builder();
219
        let memquota = ChannelAccount::new(&self.memquota)?;
220
        let channel = chan_builder
221
            .build_channel_using_incoming(src, stream, memquota)
222
            .await?;
223
        // Add it to our list.
224
        self.channels.add_open(channel.clone())?;
225
        Ok(channel)
226
    }
227

            
228
    /// Get a channel corresponding to the identities of `target`.
229
    ///
230
    /// If a usable channel exists with that identity, return it.
231
    ///
232
    /// If no such channel exists already, and none is in progress,
233
    /// launch a new request using `target`.
234
    ///
235
    /// If no such channel exists already, but we have one that's in
236
    /// progress, wait for it to succeed or fail.
237
    #[instrument(skip_all, level = "trace")]
238
82
    pub(crate) async fn get_or_launch(
239
82
        &self,
240
82
        target: CF::BuildSpec,
241
82
        usage: ChannelUsage,
242
82
    ) -> Result<(Arc<CF::Channel>, ChanProvenance)> {
243
        use ChannelUsage as CU;
244

            
245
        let chan = self.get_or_launch_internal(target).await?;
246

            
247
82
        match usage {
248
            CU::Dir | CU::UselessCircuit => {}
249
            CU::UserTraffic => chan.0.engage_padding_activities(),
250
        }
251

            
252
        Ok(chan)
253
82
    }
254

            
255
    /// Get a channel whose identity is `ident` - internal implementation
256
    #[allow(clippy::cognitive_complexity)]
257
    #[instrument(skip_all, level = "trace")]
258
82
    async fn get_or_launch_internal(
259
82
        &self,
260
82
        target: CF::BuildSpec,
261
82
    ) -> Result<(Arc<CF::Channel>, ChanProvenance)> {
262
        /// How many times do we try?
263
        const N_ATTEMPTS: usize = 2;
264
        let mut attempts_so_far = 0;
265
        let mut final_attempt = false;
266
        let mut provenance = ChanProvenance::Preexisting;
267

            
268
        // TODO(nickm): It would be neat to use tor_retry instead.
269
        let mut last_err = None;
270

            
271
        while attempts_so_far < N_ATTEMPTS || final_attempt {
272
            attempts_so_far += 1;
273

            
274
            // For each attempt, we _first_ look at the state of the channel map
275
            // to decide on an `Action`, and _then_ we execute that action.
276

            
277
            // First, see what state we're in, and what we should do about it.
278
            let action = self.choose_action(&target, final_attempt)?;
279

            
280
            // We are done deciding on our Action! It's time act based on the
281
            // Action that we chose.
282
            match action {
283
                // If this happens, we were trying to make one final check of our state, but
284
                // we would have had to make additional attempts.
285
                None => {
286
                    if !final_attempt {
287
                        return Err(Error::Internal(internal!(
288
                            "No action returned while not on final attempt"
289
                        )));
290
                    }
291
                    break;
292
                }
293
                // Easy case: we have an error or a channel to return.
294
                Some(Action::Return(v)) => {
295
                    trace!("Returning existing channel");
296
10
                    return v.map(|chan| (chan, provenance));
297
                }
298
                // There's an in-progress channel.  Wait for it.
299
                Some(Action::Wait(pend)) => {
300
                    trace!("Waiting for in-progress channel");
301
                    match pend.await {
302
                        Ok(Ok(())) => {
303
                            // We were waiting for a channel, and it succeeded, or it
304
                            // got cancelled.  But it might have gotten more
305
                            // identities while negotiating than it had when it was
306
                            // launched, or it might have failed to get all the
307
                            // identities we want. Check for this.
308
                            final_attempt = true;
309
                            provenance = ChanProvenance::NewlyCreated;
310
                            last_err.get_or_insert(Error::RequestCancelled);
311
                        }
312
                        Ok(Err(e)) => {
313
                            last_err = Some(e);
314
                        }
315
                        Err(_) => {
316
                            last_err =
317
                                Some(Error::Internal(internal!("channel build task disappeared")));
318
                        }
319
                    }
320
                }
321
                // We need to launch a channel.
322
                Some(Action::Launch((handle, send))) => {
323
                    trace!("Launching channel");
324
                    // If the remainder of this code returns early or is cancelled, we still want to
325
                    // clean up our pending entry in the channel map. The following closure will be
326
                    // run when dropped to ensure that it's cleaned up properly.
327
                    //
328
                    // The `remove_pending_channel` will acquire the lock within `MgrState`, but
329
                    // this won't lead to deadlocks since the lock is only ever acquired within
330
                    // methods of `MgrState`. When this `Defer` is being dropped, no other
331
                    // `MgrState` methods will be running on this thread, so the lock will not have
332
                    // already been acquired.
333
8
                    let defer_remove_pending = Defer::new(handle, |handle| {
334
8
                        if let Err(e) = self.channels.remove_pending_channel(handle) {
335
                            // Just log an error if we're unable to remove it, since there's
336
                            // nothing else we can do here, and returning the error would
337
                            // hide the actual error that we care about (the channel build
338
                            // failure).
339
                            #[allow(clippy::missing_docs_in_private_items)]
340
                            const MSG: &str = "Unable to remove the pending channel";
341
                            error_report!(internal!("{e}"), "{}", MSG);
342
8
                        }
343
8
                    });
344

            
345
                    let connector = self.channels.builder();
346
                    let memquota = ChannelAccount::new(&self.memquota)?;
347

            
348
                    let outcome = connector
349
                        .build_channel(&target, self.reporter.clone(), memquota)
350
                        .await;
351

            
352
                    match outcome {
353
                        Ok(ref chan) => {
354
                            // Replace the pending channel with the newly built channel.
355
                            let handle = defer_remove_pending.cancel();
356
                            self.channels
357
                                .upgrade_pending_channel_to_open(handle, Arc::clone(chan))?;
358
                        }
359
                        Err(_) => {
360
                            // Remove the pending channel.
361
                            drop(defer_remove_pending);
362
                        }
363
                    }
364

            
365
                    // It's okay if all the receivers went away:
366
                    // that means that nobody was waiting for this channel.
367
                    let _ignore_err = send.send(outcome.clone().map(|_| ()));
368

            
369
                    match outcome {
370
                        Ok(chan) => {
371
                            return Ok((chan, ChanProvenance::NewlyCreated));
372
                        }
373
                        Err(e) => last_err = Some(e),
374
                    }
375
                }
376
            }
377

            
378
            // End of this attempt. We will try again...
379
        }
380

            
381
        Err(last_err.unwrap_or_else(|| Error::Internal(internal!("no error was set!?"))))
382
82
    }
383

            
384
    /// Helper: based on our internal state, decide which action to take when
385
    /// asked for a channel, and update our internal state accordingly.
386
    ///
387
    /// If `final_attempt` is true, then we will not pick any action that does
388
    /// not result in an immediate result. If we would pick such an action, we
389
    /// instead return `Ok(None)`.  (We could instead have the caller detect
390
    /// such actions, but it's less efficient to construct them, insert them,
391
    /// and immediately revert them.)
392
    #[instrument(skip_all, level = "trace")]
393
94
    fn choose_action(
394
94
        &self,
395
94
        target: &CF::BuildSpec,
396
94
        final_attempt: bool,
397
94
    ) -> Result<Option<Action<CF::Channel>>> {
398
        // don't create new channels on the final attempt
399
94
        let response = self.channels.request_channel(
400
94
            target,
401
94
            /* add_new_entry_if_not_found= */ !final_attempt,
402
        );
403

            
404
94
        match response {
405
10
            Ok(Some(ChannelForTarget::Open(channel))) => Ok(Some(Action::Return(Ok(channel)))),
406
10
            Ok(Some(ChannelForTarget::Pending(pending))) => {
407
10
                if !final_attempt {
408
10
                    Ok(Some(Action::Wait(pending)))
409
                } else {
410
                    // don't return a pending channel on the final attempt
411
                    Ok(None)
412
                }
413
            }
414
74
            Ok(Some(ChannelForTarget::NewEntry((handle, send)))) => {
415
                // do not drop the handle if refactoring; see `PendingChannelHandle` for details
416
74
                Ok(Some(Action::Launch((handle, send))))
417
            }
418
            Ok(None) => Ok(None),
419
            Err(e @ Error::IdentityConflict) => Ok(Some(Action::Return(Err(e)))),
420
            Err(e) => Err(e),
421
        }
422
94
    }
423

            
424
    /// Update the netdir
425
24
    pub(crate) fn update_netparams(
426
24
        &self,
427
24
        netparams: Arc<dyn AsRef<NetParameters>>,
428
24
    ) -> StdResult<(), tor_error::Bug> {
429
24
        self.channels.reconfigure_general(None, None, netparams)
430
24
    }
431

            
432
    /// Notifies the chanmgr to be dormant like dormancy
433
42
    pub(crate) fn set_dormancy(
434
42
        &self,
435
42
        dormancy: Dormancy,
436
42
        netparams: Arc<dyn AsRef<NetParameters>>,
437
42
    ) -> StdResult<(), tor_error::Bug> {
438
42
        self.channels
439
42
            .reconfigure_general(None, Some(dormancy), netparams)
440
42
    }
441

            
442
    /// Reconfigure all channels
443
28
    pub(crate) fn reconfigure(
444
28
        &self,
445
28
        config: &ChannelConfig,
446
28
        netparams: Arc<dyn AsRef<NetParameters>>,
447
28
    ) -> StdResult<(), tor_error::Bug> {
448
28
        self.channels
449
28
            .reconfigure_general(Some(config), None, netparams)
450
28
    }
451

            
452
    /// Expire any channels that have been unused longer than
453
    /// their maximum unused duration assigned during creation.
454
    ///
455
    /// Return a duration from now until next channel expires.
456
    ///
457
    /// If all channels are in use or there are no open channels,
458
    /// return 180 seconds which is the minimum value of
459
    /// max_unused_duration.
460
36
    pub(crate) fn expire_channels(&self) -> Duration {
461
36
        self.channels.expire_channels()
462
36
    }
463

            
464
    /// Test only: return the open usable channels with a given `ident`.
465
    #[cfg(test)]
466
12
    pub(crate) fn get_nowait<'a, T>(&self, ident: T) -> Vec<Arc<CF::Channel>>
467
12
    where
468
12
        T: Into<tor_linkspec::RelayIdRef<'a>>,
469
    {
470
        use state::ChannelState::*;
471
12
        self.channels
472
12
            .with_channels(|channel_map| {
473
12
                channel_map
474
12
                    .by_id(ident)
475
12
                    .filter_map(|entry| match entry {
476
8
                        Open(ent) if ent.channel.is_usable() => Some(Arc::clone(&ent.channel)),
477
                        _ => None,
478
8
                    })
479
12
                    .collect()
480
12
            })
481
12
            .expect("Poisoned lock")
482
12
    }
483
}
484

            
485
/// Possible actions that we'll decide to take when asked for a channel.
486
#[allow(clippy::large_enum_variant)]
487
enum Action<C: AbstractChannel> {
488
    /// We found no channel.  We're going to launch a new one,
489
    /// then tell everybody about it.
490
    Launch((PendingChannelHandle, Sending)),
491
    /// We found an in-progress attempt at making a channel.
492
    /// We're going to wait for it to finish.
493
    Wait(Pending),
494
    /// We found a usable channel.  We're going to return it.
495
    Return(Result<Arc<C>>),
496
}
497

            
498
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::Error;

    use futures::join;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tor_error::bad_api_usage;
    use tor_linkspec::ChannelMethod;
    use tor_llcrypto::pk::ed25519::Ed25519Identity;
    use tor_memquota::ArcMemoryQuotaTrackerExt as _;

    use crate::ChannelUsage as CU;
    use tor_rtcompat::{Runtime, task::yield_now, test_with_one_runtime};

    // Two distinct addresses we can use in tests.
    const ADDR_A: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 443));
    const ADDR_B: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443));

    #[derive(Clone)]
    struct FakeChannelFactory<RT> {
        runtime: RT,
    }

    #[derive(Clone, Debug)]
    struct FakeChannel {
        /// The identity this channel claims.
        ed_ident: Ed25519Identity,
        /// The `mood` of the `FakeBuildSpec` this channel was built from.
        mood: char,
        /// Set once `start_closing` has been called; makes the channel unusable.
        closing: Arc<AtomicBool>,
        /// A fresh allocation per built channel.  `PartialEq` compares it by
        /// pointer, so clones of one channel are equal but distinct channels
        /// are not.
        detect_reuse: Arc<char>,
    }

    impl PartialEq for FakeChannel {
        fn eq(&self, other: &Self) -> bool {
            Arc::ptr_eq(&self.detect_reuse, &other.detect_reuse)
        }
    }

    impl AbstractChannel for FakeChannel {
        fn is_canonical(&self) -> bool {
            unimplemented!()
        }
        fn is_canonical_to_peer(&self) -> bool {
            unimplemented!()
        }
        fn is_usable(&self) -> bool {
            !self.closing.load(Ordering::SeqCst)
        }
        fn duration_unused(&self) -> Option<Duration> {
            None
        }
        fn reparameterize(
            &self,
            _updates: Arc<ChannelPaddingInstructionsUpdates>,
        ) -> tor_proto::Result<()> {
            Ok(())
        }
        fn reparameterize_kist(&self, _kist_params: KistParams) -> tor_proto::Result<()> {
            Ok(())
        }
        fn engage_padding_activities(&self) {}
    }

    impl HasRelayIds for FakeChannel {
        fn identity(
            &self,
            key_type: tor_linkspec::RelayIdType,
        ) -> Option<tor_linkspec::RelayIdRef<'_>> {
            match key_type {
                tor_linkspec::RelayIdType::Ed25519 => Some((&self.ed_ident).into()),
                _ => None,
            }
        }
    }

    impl FakeChannel {
        fn start_closing(&self) {
            self.closing.store(true, Ordering::SeqCst);
        }
    }

    impl<RT: Runtime> FakeChannelFactory<RT> {
        fn new(runtime: RT) -> Self {
            FakeChannelFactory { runtime }
        }
    }

    fn new_test_abstract_chanmgr<R: Runtime>(runtime: R) -> AbstractChanMgr<FakeChannelFactory<R>> {
        let cf = FakeChannelFactory::new(runtime);
        AbstractChanMgr::new(
            cf,
            Default::default(),
            Default::default(),
            &Default::default(),
            BootstrapReporter::fake(),
            ToplevelAccount::new_noop(),
        )
    }

    /// Fake build spec: (identity number, mood, Ed25519 identity, address).
    #[derive(Clone, Debug)]
    struct FakeBuildSpec(u32, char, Ed25519Identity, SocketAddr);

    impl HasRelayIds for FakeBuildSpec {
        fn identity(
            &self,
            key_type: tor_linkspec::RelayIdType,
        ) -> Option<tor_linkspec::RelayIdRef<'_>> {
            match key_type {
                tor_linkspec::RelayIdType::Ed25519 => Some((&self.2).into()),
                _ => None,
            }
        }
    }

    impl HasChanMethod for FakeBuildSpec {
        fn chan_method(&self) -> ChannelMethod {
            ChannelMethod::Direct(vec![self.3.clone()])
        }
    }

    /// Helper to make a fake Ed identity from a u32.
    fn u32_to_ed(n: u32) -> Ed25519Identity {
        let mut bytes = [0; 32];
        bytes[0..4].copy_from_slice(&n.to_be_bytes());
        bytes.into()
    }

    #[async_trait]
    impl<RT: Runtime> AbstractChannelFactory for FakeChannelFactory<RT> {
        type Channel = FakeChannel;
        type BuildSpec = FakeBuildSpec;
        type Stream = ();

        async fn build_channel(
            &self,
            target: &Self::BuildSpec,
            _reporter: BootstrapReporter,
            _memquota: ChannelAccount,
        ) -> Result<Arc<FakeChannel>> {
            yield_now().await;
            let FakeBuildSpec(ident, mood, id, _addr) = *target;
            let ed_ident = u32_to_ed(ident);
            assert_eq!(ed_ident, id);
            match mood {
                // '❌' or '🔥' means "never connect": fail immediately.
                '❌' | '🔥' => return Err(Error::UnusableTarget(bad_api_usage!("emoji"))),
                // '💤' means "wait 15 seconds, then succeed".
                '💤' => {
                    self.runtime.sleep(Duration::new(15, 0)).await;
                }
                _ => {}
            }
            Ok(Arc::new(FakeChannel {
                ed_ident,
                mood,
                closing: Arc::new(AtomicBool::new(false)),
                detect_reuse: Default::default(),
            }))
        }

        #[cfg(feature = "relay")]
        async fn build_channel_using_incoming(
            &self,
            _peer: Sensitive<std::net::SocketAddr>,
            _stream: Self::Stream,
            _memquota: ChannelAccount,
        ) -> Result<Arc<Self::Channel>> {
            unimplemented!()
        }
    }

    #[test]
    fn connect_one_ok() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);
            let target = FakeBuildSpec(413, '!', u32_to_ed(413), ADDR_A);
            let chan1 = mgr
                .get_or_launch(target.clone(), CU::UserTraffic)
                .await
                .unwrap()
                .0;
            let chan2 = mgr.get_or_launch(target, CU::UserTraffic).await.unwrap().0;

            assert_eq!(chan1, chan2);
            assert_eq!(mgr.get_nowait(&u32_to_ed(413)), vec![chan1]);
        });
    }

    #[test]
    fn connect_one_fail() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            // This is set up to always fail.
            let target = FakeBuildSpec(999, '❌', u32_to_ed(999), ADDR_A);
            let res1 = mgr.get_or_launch(target, CU::UserTraffic).await;
            assert!(matches!(res1, Err(Error::UnusableTarget(_))));

            assert!(mgr.get_nowait(&u32_to_ed(999)).is_empty());
        });
    }

    #[test]
    fn connect_different_address() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            // Two targets that have different addresses.
            let target1 = FakeBuildSpec(413, '!', u32_to_ed(413), ADDR_A);
            let mut target2 = target1.clone();
            target2.3 = ADDR_B;

            let chan1 = mgr.get_or_launch(target1, CU::UserTraffic).await.unwrap().0;
            let chan2 = mgr.get_or_launch(target2, CU::UserTraffic).await.unwrap().0;

            // Even with different addresses, the original channel is returned.
            assert_eq!(chan1, chan2);
            assert_eq!(mgr.get_nowait(&u32_to_ed(413)), vec![chan1]);
        });
    }

    #[test]
    fn test_concurrent() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            let usage = CU::UserTraffic;

            // TODO(nickm): figure out how to make these actually run
            // concurrently. Right now it seems that they don't actually
            // interact.
            let (ch3a, ch3b, ch44a, ch44b, ch50a, ch50b, ch86a, ch86b) = join!(
                mgr.get_or_launch(FakeBuildSpec(3, 'a', u32_to_ed(3), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(3, 'b', u32_to_ed(3), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(44, 'a', u32_to_ed(44), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(44, 'b', u32_to_ed(44), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(50, 'a', u32_to_ed(50), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(50, 'b', u32_to_ed(50), ADDR_B), usage),
                mgr.get_or_launch(FakeBuildSpec(86, '❌', u32_to_ed(86), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(86, '🔥', u32_to_ed(86), ADDR_A), usage),
            );
            let ch3a = ch3a.unwrap();
            let ch3b = ch3b.unwrap();
            let ch44a = ch44a.unwrap();
            let ch44b = ch44b.unwrap();
            let ch50a = ch50a.unwrap();
            let ch50b = ch50b.unwrap();
            let err_a = ch86a.unwrap_err();
            let err_b = ch86b.unwrap_err();

            assert_eq!(ch3a, ch3b);
            assert_eq!(ch44a, ch44b);
            assert_eq!(ch50a, ch50b);
            assert_ne!(ch44a, ch3a);

            assert!(matches!(err_a, Error::UnusableTarget(_)));
            assert!(matches!(err_b, Error::UnusableTarget(_)));
        });
    }

    #[test]
    fn unusable_entries() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            let (ch3, ch4, ch5) = join!(
                mgr.get_or_launch(FakeBuildSpec(3, 'a', u32_to_ed(3), ADDR_A), CU::UserTraffic),
                mgr.get_or_launch(FakeBuildSpec(4, 'a', u32_to_ed(4), ADDR_A), CU::UserTraffic),
                mgr.get_or_launch(FakeBuildSpec(5, 'a', u32_to_ed(5), ADDR_A), CU::UserTraffic),
            );

            let ch3 = ch3.unwrap().0;
            let _ch4 = ch4.unwrap();
            let ch5 = ch5.unwrap().0;

            ch3.start_closing();
            ch5.start_closing();

            let ch3_new = mgr
                .get_or_launch(FakeBuildSpec(3, 'b', u32_to_ed(3), ADDR_A), CU::UserTraffic)
                .await
                .unwrap()
                .0;
            assert_ne!(ch3, ch3_new);
            assert_eq!(ch3_new.mood, 'b');

            mgr.remove_unusable_entries().unwrap();

            assert!(!mgr.get_nowait(&u32_to_ed(3)).is_empty());
            assert!(!mgr.get_nowait(&u32_to_ed(4)).is_empty());
            assert!(mgr.get_nowait(&u32_to_ed(5)).is_empty());
        });
    }
}