1
//! Abstract implementation of a channel manager
2

            
3
use crate::factory::BootstrapReporter;
4
use crate::mgr::state::{ChannelForTarget, PendingChannelHandle};
5
use crate::{ChanProvenance, ChannelConfig, ChannelUsage, Dormancy, Error, Result};
6

            
7
use async_trait::async_trait;
8
use futures::future::Shared;
9
use oneshot_fused_workaround as oneshot;
10
use std::result::Result as StdResult;
11
use std::sync::Arc;
12
use std::time::Duration;
13
use tor_error::{error_report, internal};
14
use tor_linkspec::{HasChanMethod, HasRelayIds};
15
use tor_netdir::params::NetParameters;
16
use tor_proto::channel::kist::KistParams;
17
use tor_proto::channel::params::ChannelPaddingInstructionsUpdates;
18
use tor_proto::memquota::{ChannelAccount, SpecificAccount as _, ToplevelAccount};
19
use tracing::{instrument, trace};
20

            
21
#[cfg(feature = "relay")]
22
use {safelog::Sensitive, std::net::SocketAddr, tor_proto::RelayChannelAuthMaterial};
23

            
24
mod select;
25
mod state;
26

            
27
/// Trait to describe as much of a
28
/// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr`
29
/// needs to use.
30
pub(crate) trait AbstractChannel: HasRelayIds {
31
    /// Return true iff this channel is considered canonical by us.
32
    fn is_canonical(&self) -> bool;
33
    /// Return true if we think the peer considers this channel as canonical.
34
    fn is_canonical_to_peer(&self) -> bool;
35
    /// Return true if this channel is usable.
36
    ///
37
    /// A channel might be unusable because it is closed, because it has
38
    /// hit a bug, or for some other reason.  We don't return unusable
39
    /// channels back to the user.
40
    fn is_usable(&self) -> bool;
41
    /// Return the amount of time a channel has not been in use.
42
    /// Return None if the channel is currently in use.
43
    fn duration_unused(&self) -> Option<Duration>;
44

            
45
    /// Reparameterize this channel according to the provided `ChannelPaddingInstructionsUpdates`
46
    ///
47
    /// The changed parameters may not be implemented "immediately",
48
    /// but this will be done "reasonably soon".
49
    fn reparameterize(
50
        &self,
51
        updates: Arc<ChannelPaddingInstructionsUpdates>,
52
    ) -> tor_proto::Result<()>;
53

            
54
    /// Update the KIST parameters.
55
    ///
56
    /// The changed parameters may not be implemented "immediately",
57
    /// but this will be done "reasonably soon".
58
    fn reparameterize_kist(&self, kist_params: KistParams) -> tor_proto::Result<()>;
59

            
60
    /// Specify that this channel should do activities related to channel padding
61
    ///
62
    /// See [`Channel::engage_padding_activities`]
63
    ///
64
    /// [`Channel::engage_padding_activities`]: tor_proto::channel::Channel::engage_padding_activities
65
    fn engage_padding_activities(&self);
66
}
67

            
68
/// Trait to describe how channels-like objects are created.
69
///
70
/// This differs from [`ChannelFactory`](crate::factory::ChannelFactory) in that
71
/// it's a purely crate-internal type that we use to decouple the
72
/// AbstractChanMgr code from actual "what is a channel" concerns.
73
#[async_trait]
74
pub(crate) trait AbstractChannelFactory {
75
    /// The type of channel that this factory can build.
76
    type Channel: AbstractChannel;
77
    /// Type that explains how to build an outgoing channel.
78
    type BuildSpec: HasRelayIds + HasChanMethod;
79
    /// The type of byte stream that's required to build channels for incoming connections.
80
    type Stream;
81

            
82
    /// Construct a new channel to the destination described at `target`.
83
    ///
84
    /// This function must take care of all timeouts, error detection,
85
    /// and so on.
86
    ///
87
    /// It should not retry; that is handled at a higher level.
88
    async fn build_channel(
89
        &self,
90
        target: &Self::BuildSpec,
91
        reporter: BootstrapReporter,
92
        memquota: ChannelAccount,
93
    ) -> Result<Arc<Self::Channel>>;
94

            
95
    /// Construct a new channel for an incoming connection.
96
    #[cfg(feature = "relay")]
97
    async fn build_channel_using_incoming(
98
        &self,
99
        peer: Sensitive<std::net::SocketAddr>,
100
        stream: Self::Stream,
101
        memquota: ChannelAccount,
102
    ) -> Result<Arc<Self::Channel>>;
103
}
104

            
105
/// This is the configuration for a [`ChanMgr`](crate::ChanMgr) given to the constructor.
106
#[derive(Default)]
107
pub struct ChanMgrConfig {
108
    /// Channel configuration which usually comes from a configuration file.
109
    pub(crate) cfg: ChannelConfig,
110
    /// Relay authentication key material for relay channels.
111
    #[cfg(feature = "relay")]
112
    pub(crate) auth_material: Option<Arc<RelayChannelAuthMaterial>>,
113
    /// Our address(es). When building outgoing channel, we need our addresses in order to send
114
    /// them in the NETINFO cell. It will also be used to validate initiator channel target.
115
    #[cfg(feature = "relay")]
116
    pub(crate) my_addrs: Vec<SocketAddr>,
117
    // TODO: Would be good to add more things such as NetParameters and Dormancy maybe?
118
}
119

            
120
impl ChanMgrConfig {
121
    /// Constructor.
122
    pub fn new(cfg: ChannelConfig) -> Self {
123
        Self {
124
            cfg,
125
            #[cfg(feature = "relay")]
126
            auth_material: None,
127
            #[cfg(feature = "relay")]
128
            my_addrs: Vec::new(),
129
        }
130
    }
131

            
132
    /// Set the relay channel authentication key material and return itself.
133
    #[cfg(feature = "relay")]
134
    pub fn with_auth_material(mut self, auth_material: Arc<RelayChannelAuthMaterial>) -> Self {
135
        self.auth_material = Some(auth_material);
136
        self
137
    }
138

            
139
    /// Set our addresses that we advertise to the world.
140
    #[cfg(feature = "relay")]
141
    pub fn with_my_addrs(mut self, my_addrs: Vec<SocketAddr>) -> Self {
142
        self.my_addrs = my_addrs;
143
        self
144
    }
145
}
146

            
147
/// A type- and network-agnostic implementation for [`ChanMgr`](crate::ChanMgr).
148
///
149
/// This type does the work of keeping track of open channels and pending
150
/// channel requests, launching requests as needed, waiting for pending
151
/// requests, and so forth.
152
///
153
/// The actual job of launching connections is deferred to an
154
/// `AbstractChannelFactory` type.
155
pub(crate) struct AbstractChanMgr<CF: AbstractChannelFactory> {
156
    /// All internal state held by this channel manager.
157
    ///
158
    /// The most important part is the map from relay identity to channel, or
159
    /// to pending channel status.
160
    pub(crate) channels: state::MgrState<CF>,
161

            
162
    /// A bootstrap reporter to give out when building channels.
163
    pub(crate) reporter: BootstrapReporter,
164

            
165
    /// The memory quota account that every channel will be a child of
166
    pub(crate) memquota: ToplevelAccount,
167

            
168
    /// Metrics counters / gauges / histograms.
169
    #[cfg(feature = "metrics")]
170
    pub(crate) metrics: ChanMgrMetrics,
171
}
172

            
173
/// Holder for every metrics counter / gauge / histogram that we use.
///
/// These are created once and kept in the [`AbstractChanMgr`], so that we
/// don't pay the performance cost of re-registering counters.
#[cfg(feature = "metrics")]
pub(crate) struct ChanMgrMetrics {
    /// Inbound channels that were built successfully.
    pub(crate) inbound_channels_built_success: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::UnusableTarget`] error.
    pub(crate) inbound_channels_built_failure_unusable_target: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::PendingFailed`] error.
    pub(crate) inbound_channels_built_failure_pending_failed: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::ChanTimeout`] error.
    pub(crate) inbound_channels_built_failure_chan_timeout: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Proto`] error.
    pub(crate) inbound_channels_built_failure_proto: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Io`] error.
    pub(crate) inbound_channels_built_failure_io: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Connect`] error.
    pub(crate) inbound_channels_built_failure_connect: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Spawn`] error.
    pub(crate) inbound_channels_built_failure_spawn: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::MissingId`] error.
    pub(crate) inbound_channels_built_failure_missing_id: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::IdentityConflict`] error.
    pub(crate) inbound_channels_built_failure_identity_conflict: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::NoSuchTransport`] error.
    pub(crate) inbound_channels_built_failure_no_such_transport: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::RequestCancelled`] error.
    pub(crate) inbound_channels_built_failure_request_cancelled: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Pt`] error.
    pub(crate) inbound_channels_built_failure_pt: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Memquota`] error.
    pub(crate) inbound_channels_built_failure_memquota: metrics::Counter,
    /// Inbound channel builds that failed with an [`Error::Internal`] error.
    pub(crate) inbound_channels_built_failure_internal: metrics::Counter,
}
210

            
211
#[cfg(feature = "metrics")]
impl ChanMgrMetrics {
    /// Create a new instance of [`ChanMgrMetrics`], registering one
    /// `arti_chanmgr_channels_built` counter per outcome.
    pub(crate) fn new() -> Self {
        // Every counter below shares its name, description, unit and
        // "direction" label; only the "result" (and, for failures, the
        // "error") label differs.  This macro expands to exactly the
        // `metrics::counter!` invocation that would otherwise be written out
        // by hand for each field.  `tt` is used for the error label so the
        // literal is forwarded to `metrics::counter!` unchanged.
        macro_rules! inbound_counter {
            (success) => {
                metrics::counter!(
                    description: "Total number of channels built",
                    unit: metrics::Unit::Count,
                    "arti_chanmgr_channels_built",
                    "result" => "success",
                    "direction" => "inbound",
                )
            };
            (failure: $error:tt) => {
                metrics::counter!(
                    description: "Total number of channels built",
                    unit: metrics::Unit::Count,
                    "arti_chanmgr_channels_built",
                    "result" => "failure",
                    "direction" => "inbound",
                    "error" => $error,
                )
            };
        }

        ChanMgrMetrics {
            inbound_channels_built_success: inbound_counter!(success),
            inbound_channels_built_failure_unusable_target: inbound_counter!(
                failure: "unusable_target"
            ),
            inbound_channels_built_failure_pending_failed: inbound_counter!(
                failure: "pending_failed"
            ),
            inbound_channels_built_failure_chan_timeout: inbound_counter!(
                failure: "chan_timeout"
            ),
            inbound_channels_built_failure_proto: inbound_counter!(failure: "proto"),
            inbound_channels_built_failure_io: inbound_counter!(failure: "io"),
            inbound_channels_built_failure_connect: inbound_counter!(failure: "connect"),
            inbound_channels_built_failure_spawn: inbound_counter!(failure: "spawn"),
            inbound_channels_built_failure_missing_id: inbound_counter!(failure: "missing_id"),
            inbound_channels_built_failure_identity_conflict: inbound_counter!(
                failure: "identity_conflict"
            ),
            inbound_channels_built_failure_no_such_transport: inbound_counter!(
                failure: "no_such_transport"
            ),
            inbound_channels_built_failure_request_cancelled: inbound_counter!(
                failure: "request_cancelled"
            ),
            inbound_channels_built_failure_pt: inbound_counter!(failure: "pt"),
            inbound_channels_built_failure_memquota: inbound_counter!(failure: "memquota"),
            inbound_channels_built_failure_internal: inbound_counter!(failure: "internal"),
        }
    }

    /// Add one to whichever inbound_channels_built counter corresponds to `result`.
    pub(crate) fn increment_inbound_channels_built<R>(&self, result: &Result<R>) {
        // Pick the counter first, then bump it in a single place.  The inner
        // match deliberately has no wildcard arm, so that adding an `Error`
        // variant forces a decision about which counter it belongs to.
        let counter = match result {
            Ok(_) => &self.inbound_channels_built_success,
            Err(error) => match error {
                Error::UnusableTarget(_) => &self.inbound_channels_built_failure_unusable_target,
                Error::PendingFailed { .. } => &self.inbound_channels_built_failure_pending_failed,
                Error::ChanTimeout { .. } => &self.inbound_channels_built_failure_chan_timeout,
                Error::Proto { .. } => &self.inbound_channels_built_failure_proto,
                Error::Io { .. } => &self.inbound_channels_built_failure_io,
                Error::Connect { .. } => &self.inbound_channels_built_failure_connect,
                Error::Spawn { .. } => &self.inbound_channels_built_failure_spawn,
                Error::MissingId => &self.inbound_channels_built_failure_missing_id,
                Error::IdentityConflict => &self.inbound_channels_built_failure_identity_conflict,
                Error::NoSuchTransport(_) => {
                    &self.inbound_channels_built_failure_no_such_transport
                }
                Error::RequestCancelled => &self.inbound_channels_built_failure_request_cancelled,
                Error::Pt(_) => &self.inbound_channels_built_failure_pt,
                Error::Memquota(_) => &self.inbound_channels_built_failure_memquota,
                Error::Internal(_) => &self.inbound_channels_built_failure_internal,
            },
        };
        counter.increment(1);
    }
}
371

            
372
/// Type alias for a future that we wait on to see when a pending
373
/// channel is done or failed.
374
type Pending = Shared<oneshot::Receiver<Result<()>>>;
375

            
376
/// Type alias for the sender we notify when we complete a channel (or fail to
377
/// complete it).
378
type Sending = oneshot::Sender<Result<()>>;
379

            
380
/// Keeps a pending launch entry and its waiters in sync.
381
///
382
/// Every exit path from a launch attempt must either remove the pending entry
383
/// or upgrade it to an open channel, and must notify all waiters with the
384
/// outcome. This guard makes cancellation and early returns follow the same
385
/// cleanup path as ordinary failures.
386
struct PendingLaunchGuard<'a, CF: AbstractChannelFactory> {
387
    /// Channel state used to remove or upgrade the pending entry.
388
    channels: &'a state::MgrState<CF>,
389
    /// Handle to the pending entry, if it has not yet been removed.
390
    handle: Option<PendingChannelHandle>,
391
    /// Sender used to notify tasks waiting on this launch.
392
    send: Option<Sending>,
393
    /// Result to report to the waiters if the launch ends here.
394
    result: Result<()>,
395
}
396

            
397
impl<'a, CF: AbstractChannelFactory> PendingLaunchGuard<'a, CF> {
398
    /// Create a new guard for a pending launch.
399
82
    fn new(channels: &'a state::MgrState<CF>, handle: PendingChannelHandle, send: Sending) -> Self {
400
82
        Self {
401
82
            channels,
402
82
            handle: Some(handle),
403
82
            send: Some(send),
404
82
            result: Err(Error::RequestCancelled),
405
82
        }
406
82
    }
407

            
408
    /// Record the result that should be reported to any waiters.
409
78
    fn note_result(&mut self, result: Result<()>) {
410
78
        self.result = result;
411
78
    }
412

            
413
    /// Replace the pending channel with an open one.
414
70
    fn upgrade_pending_channel_to_open(&mut self, channel: Arc<CF::Channel>) -> Result<()> {
415
70
        let handle = self
416
70
            .handle
417
70
            .take()
418
70
            .expect("pending launch guard lost its handle before upgrade");
419
70
        self.channels
420
70
            .upgrade_pending_channel_to_open(handle, channel)
421
70
    }
422
}
423

            
424
impl<'a, CF: AbstractChannelFactory> Drop for PendingLaunchGuard<'a, CF> {
425
82
    fn drop(&mut self) {
426
82
        if let Some(handle) = self.handle.take() {
427
12
            if let Err(e) = self.channels.remove_pending_channel(handle) {
428
                // Just log an error if we're unable to remove it, since there's
429
                // nothing else we can do here, and returning the error would
430
                // hide the actual error that we care about (the channel build
431
                // failure).
432
                #[allow(clippy::missing_docs_in_private_items)]
433
                const MSG: &str = "Unable to remove the pending channel";
434
                error_report!(internal!("{e}"), "{}", MSG);
435
12
            }
436
70
        }
437

            
438
82
        if let Some(send) = self.send.take() {
439
82
            // It's okay if all the receivers went away:
440
82
            // that means that nobody was waiting for this channel.
441
82
            let _ignore_err = send.send(self.result.clone());
442
82
        }
443
82
    }
444
}
445

            
446
impl<CF: AbstractChannelFactory + Clone> AbstractChanMgr<CF> {
447
    /// Make a new empty channel manager.
448
98
    pub(crate) fn new(
449
98
        connector: CF,
450
98
        config: ChannelConfig,
451
98
        dormancy: Dormancy,
452
98
        netparams: &NetParameters,
453
98
        reporter: BootstrapReporter,
454
98
        memquota: ToplevelAccount,
455
98
    ) -> Self {
456
98
        AbstractChanMgr {
457
98
            channels: state::MgrState::new(connector, config, dormancy, netparams),
458
98
            reporter,
459
98
            memquota,
460
98
            #[cfg(feature = "metrics")]
461
98
            metrics: ChanMgrMetrics::new(),
462
98
        }
463
98
    }
464

            
465
    /// Run a function to modify the channel builder in this object.
466
    #[allow(unused)]
467
    pub(crate) fn with_mut_builder<F>(&self, func: F)
468
    where
469
        F: FnOnce(&mut CF),
470
    {
471
        self.channels.with_mut_builder(func);
472
    }
473

            
474
    /// Test only: drop all of the unusable entries held by this channel manager.
    #[cfg(test)]
    pub(crate) fn remove_unusable_entries(&self) -> Result<()> {
        self.channels.remove_unusable()
    }
479

            
480
    /// Build a channel on top of an incoming stream, and register it as open.
    ///
    /// See [`ChanMgr::handle_incoming`](crate::ChanMgr::handle_incoming).
    #[cfg(feature = "relay")]
    pub(crate) async fn handle_incoming(
        &self,
        src: Sensitive<std::net::SocketAddr>,
        stream: CF::Stream,
    ) -> Result<Arc<CF::Channel>> {
        let factory = self.channels.builder();
        let account = ChannelAccount::new(&self.memquota)?;
        let incoming = factory
            .build_channel_using_incoming(src, stream, account)
            .await?;
        // Record the new channel in our list before handing it back.
        self.channels.add_open(Arc::clone(&incoming))?;
        Ok(incoming)
    }
497

            
498
    /// Get a channel corresponding to the identities of `target`.
499
    ///
500
    /// If a usable channel exists with that identity, return it.
501
    ///
502
    /// If no such channel exists already, and none is in progress,
503
    /// launch a new request using `target`.
504
    ///
505
    /// If no such channel exists already, but we have one that's in
506
    /// progress, wait for it to succeed or fail.
507
    #[instrument(skip_all, level = "trace")]
508
92
    pub(crate) async fn get_or_launch(
509
92
        &self,
510
92
        target: CF::BuildSpec,
511
92
        usage: ChannelUsage,
512
92
    ) -> Result<(Arc<CF::Channel>, ChanProvenance)> {
513
        use ChannelUsage as CU;
514

            
515
        let chan = self.get_or_launch_internal(target).await?;
516

            
517
92
        match usage {
518
            CU::Dir | CU::UselessCircuit => {}
519
            CU::UserTraffic => chan.0.engage_padding_activities(),
520
        }
521

            
522
        Ok(chan)
523
88
    }
524

            
525
    /// Get a channel whose identity is `ident` - internal implementation
526
    #[allow(clippy::cognitive_complexity)]
527
    #[instrument(skip_all, level = "trace")]
528
92
    async fn get_or_launch_internal(
529
92
        &self,
530
92
        target: CF::BuildSpec,
531
92
    ) -> Result<(Arc<CF::Channel>, ChanProvenance)> {
532
        /// How many times do we try?
533
        const N_ATTEMPTS: usize = 2;
534
        let mut attempts_so_far = 0;
535
        let mut final_attempt = false;
536
        let mut provenance = ChanProvenance::Preexisting;
537

            
538
        // TODO(nickm): It would be neat to use tor_retry instead.
539
        let mut last_err = None;
540

            
541
        while attempts_so_far < N_ATTEMPTS || final_attempt {
542
            attempts_so_far += 1;
543

            
544
            // For each attempt, we _first_ look at the state of the channel map
545
            // to decide on an `Action`, and _then_ we execute that action.
546

            
547
            // First, see what state we're in, and what we should do about it.
548
            let action = self.choose_action(&target, final_attempt)?;
549

            
550
            // We are done deciding on our Action! It's time act based on the
551
            // Action that we chose.
552
            match action {
553
                // If this happens, we were trying to make one final check of our state, but
554
                // we would have had to make additional attempts.
555
                None => {
556
                    if !final_attempt {
557
                        return Err(Error::Internal(internal!(
558
                            "No action returned while not on final attempt"
559
                        )));
560
                    }
561
                    break;
562
                }
563
                // Easy case: we have an error or a channel to return.
564
                Some(Action::Return(v)) => {
565
                    trace!("Returning existing channel");
566
10
                    return v.map(|chan| (chan, provenance));
567
                }
568
                // There's an in-progress channel.  Wait for it.
569
                Some(Action::Wait(pend)) => {
570
                    trace!("Waiting for in-progress channel");
571
                    match pend.await {
572
                        Ok(Ok(())) => {
573
                            // We were waiting for a channel, and it succeeded, or it
574
                            // got cancelled.  But it might have gotten more
575
                            // identities while negotiating than it had when it was
576
                            // launched, or it might have failed to get all the
577
                            // identities we want. Check for this.
578
                            final_attempt = true;
579
                            provenance = ChanProvenance::NewlyCreated;
580
                            last_err.get_or_insert(Error::RequestCancelled);
581
                        }
582
                        Ok(Err(e)) => {
583
                            last_err = Some(e);
584
                        }
585
                        Err(_) => {
586
                            last_err =
587
                                Some(Error::Internal(internal!("channel build task disappeared")));
588
                        }
589
                    }
590
                }
591
                // We need to launch a channel.
592
                Some(Action::Launch((handle, send))) => {
593
                    trace!("Launching channel");
594
                    let connector = self.channels.builder();
595
                    let mut launch = PendingLaunchGuard::new(&self.channels, handle, send);
596
                    let memquota = match ChannelAccount::new(&self.memquota) {
597
                        Ok(memquota) => memquota,
598
                        Err(e) => {
599
                            let e: Error = e.into();
600
                            launch.note_result(Err(e.clone()));
601
                            return Err(e);
602
                        }
603
                    };
604

            
605
                    let outcome = connector
606
                        .build_channel(&target, self.reporter.clone(), memquota)
607
                        .await;
608

            
609
                    match outcome {
610
                        Ok(ref chan) => {
611
                            // Replace the pending channel with the newly built channel.
612
                            match launch.upgrade_pending_channel_to_open(Arc::clone(chan)) {
613
                                Ok(()) => launch.note_result(Ok(())),
614
                                Err(e) => {
615
                                    launch.note_result(Err(e.clone()));
616
                                    return Err(e);
617
                                }
618
                            }
619
                        }
620
                        Err(_) => {
621
                            launch.note_result(outcome.clone().map(|_| ()));
622
                        }
623
                    }
624

            
625
                    match outcome {
626
                        Ok(chan) => {
627
                            return Ok((chan, ChanProvenance::NewlyCreated));
628
                        }
629
                        Err(e) => last_err = Some(e),
630
                    }
631
                }
632
            }
633

            
634
            // End of this attempt. We will try again...
635
        }
636

            
637
        Err(last_err.unwrap_or_else(|| Error::Internal(internal!("no error was set!?"))))
638
88
    }
639

            
640
    /// Helper: based on our internal state, decide which action to take when
641
    /// asked for a channel, and update our internal state accordingly.
642
    ///
643
    /// If `final_attempt` is true, then we will not pick any action that does
644
    /// not result in an immediate result. If we would pick such an action, we
645
    /// instead return `Ok(None)`.  (We could instead have the caller detect
646
    /// such actions, but it's less efficient to construct them, insert them,
647
    /// and immediately revert them.)
648
    #[instrument(skip_all, level = "trace")]
649
108
    fn choose_action(
650
108
        &self,
651
108
        target: &CF::BuildSpec,
652
108
        final_attempt: bool,
653
108
    ) -> Result<Option<Action<CF::Channel>>> {
654
        // don't create new channels on the final attempt
655
108
        let response = self.channels.request_channel(
656
108
            target,
657
108
            /* add_new_entry_if_not_found= */ !final_attempt,
658
        );
659

            
660
108
        match response {
661
10
            Ok(Some(ChannelForTarget::Open(channel))) => Ok(Some(Action::Return(Ok(channel)))),
662
16
            Ok(Some(ChannelForTarget::Pending(pending))) => {
663
16
                if !final_attempt {
664
16
                    Ok(Some(Action::Wait(pending)))
665
                } else {
666
                    // don't return a pending channel on the final attempt
667
                    Ok(None)
668
                }
669
            }
670
82
            Ok(Some(ChannelForTarget::NewEntry((handle, send)))) => {
671
                // do not drop the handle if refactoring; see `PendingChannelHandle` for details
672
82
                Ok(Some(Action::Launch((handle, send))))
673
            }
674
            Ok(None) => Ok(None),
675
            Err(e @ Error::IdentityConflict) => Ok(Some(Action::Return(Err(e)))),
676
            Err(e) => Err(e),
677
        }
678
108
    }
679

            
680
    /// Update the netdir
681
24
    pub(crate) fn update_netparams(
682
24
        &self,
683
24
        netparams: Arc<dyn AsRef<NetParameters>>,
684
24
    ) -> StdResult<(), tor_error::Bug> {
685
24
        self.channels.reconfigure_general(None, None, netparams)
686
24
    }
687

            
688
    /// Notifies the chanmgr to be dormant like dormancy
689
24
    pub(crate) fn set_dormancy(
690
24
        &self,
691
24
        dormancy: Dormancy,
692
24
        netparams: Arc<dyn AsRef<NetParameters>>,
693
24
    ) -> StdResult<(), tor_error::Bug> {
694
24
        self.channels
695
24
            .reconfigure_general(None, Some(dormancy), netparams)
696
24
    }
697

            
698
    /// Reconfigure all channels
699
24
    pub(crate) fn reconfigure(
700
24
        &self,
701
24
        config: &ChannelConfig,
702
24
        netparams: Arc<dyn AsRef<NetParameters>>,
703
24
    ) -> StdResult<(), tor_error::Bug> {
704
24
        self.channels
705
24
            .reconfigure_general(Some(config), None, netparams)
706
24
    }
707

            
708
    /// Expire any channels that have been unused longer than
709
    /// their maximum unused duration assigned during creation.
710
    ///
711
    /// Return a duration from now until next channel expires.
712
    ///
713
    /// If all channels are in use or there are no open channels,
714
    /// return 180 seconds which is the minimum value of
715
    /// max_unused_duration.
716
    pub(crate) fn expire_channels(&self) -> Duration {
717
        self.channels.expire_channels()
718
    }
719

            
720
    /// Test only: collect every channel filed under `ident` that is both
    /// open and usable.
    #[cfg(test)]
    pub(crate) fn get_nowait<'a, T>(&self, ident: T) -> Vec<Arc<CF::Channel>>
    where
        T: Into<tor_linkspec::RelayIdRef<'a>>,
    {
        self.channels
            .with_channels(|map| {
                let mut usable = Vec::new();
                for entry in map.by_id(ident) {
                    // Entries that are not open, or not usable, are skipped.
                    if let state::ChannelState::Open(open) = entry {
                        if open.channel.is_usable() {
                            usable.push(Arc::clone(&open.channel));
                        }
                    }
                }
                usable
            })
            .expect("Poisoned lock")
    }
739
}
740

            
741
/// Possible actions that we'll decide to take when asked for a channel.
742
#[allow(clippy::large_enum_variant)]
743
enum Action<C: AbstractChannel> {
744
    /// We found no channel.  We're going to launch a new one,
745
    /// then tell everybody about it.
746
    Launch((PendingChannelHandle, Sending)),
747
    /// We found an in-progress attempt at making a channel.
748
    /// We're going to wait for it to finish.
749
    Wait(Pending),
750
    /// We found a usable channel.  We're going to return it.
751
    Return(Result<Arc<C>>),
752
}
753

            
754
#[cfg(test)]
755
mod test {
756
    // @@ begin test lint list maintained by maint/add_warning @@
757
    #![allow(clippy::bool_assert_comparison)]
758
    #![allow(clippy::clone_on_copy)]
759
    #![allow(clippy::dbg_macro)]
760
    #![allow(clippy::mixed_attributes_style)]
761
    #![allow(clippy::print_stderr)]
762
    #![allow(clippy::print_stdout)]
763
    #![allow(clippy::single_char_pattern)]
764
    #![allow(clippy::unwrap_used)]
765
    #![allow(clippy::unchecked_time_subtraction)]
766
    #![allow(clippy::useless_vec)]
767
    #![allow(clippy::needless_pass_by_value)]
768
    #![allow(clippy::string_slice)] // See arti#2571
769
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
770
    use super::*;
771
    use crate::Error;
772

            
773
    use futures::{join, poll};
774
    use std::error::Error as StdError;
775
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
776
    use std::sync::Arc;
777
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
778
    use std::time::Duration;
779
    use tor_error::bad_api_usage;
780
    use tor_linkspec::ChannelMethod;
781
    use tor_llcrypto::pk::ed25519::Ed25519Identity;
782
    use tor_memquota::ArcMemoryQuotaTrackerExt as _;
783

            
784
    use crate::ChannelUsage as CU;
785
    use tor_rtcompat::{Runtime, task::yield_now, test_with_one_runtime};
786

            
787
    // Two distinct addresses we can use in tests.
    const ADDR_A: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 443));
    const ADDR_B: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443));
790

            
791
    /// Channel factory used by the tests in this module.
    ///
    /// Clones share the same `build_attempts` counter.
    #[derive(Clone)]
    struct FakeChannelFactory<RT> {
        /// Runtime used by `build_channel` to sleep when simulating a slow build.
        runtime: RT,
        /// Incremented once at the start of every `build_channel` call.
        build_attempts: Arc<AtomicUsize>,
    }
796

            
797
    #[derive(Clone, Debug)]
798
    struct FakeChannel {
799
        ed_ident: Ed25519Identity,
800
        mood: char,
801
        closing: Arc<AtomicBool>,
802
        detect_reuse: Arc<char>,
803
        // last_params: Option<ChannelPaddingInstructionsUpdates>,
804
    }
805

            
806
    impl PartialEq for FakeChannel {
807
        fn eq(&self, other: &Self) -> bool {
808
            Arc::ptr_eq(&self.detect_reuse, &other.detect_reuse)
809
        }
810
    }
811

            
812
    impl AbstractChannel for FakeChannel {
813
        fn is_canonical(&self) -> bool {
814
            unimplemented!()
815
        }
816
        fn is_canonical_to_peer(&self) -> bool {
817
            unimplemented!()
818
        }
819
        fn is_usable(&self) -> bool {
820
            !self.closing.load(Ordering::SeqCst)
821
        }
822
        fn duration_unused(&self) -> Option<Duration> {
823
            None
824
        }
825
        fn reparameterize(
826
            &self,
827
            _updates: Arc<ChannelPaddingInstructionsUpdates>,
828
        ) -> tor_proto::Result<()> {
829
            // *self.last_params.lock().unwrap() = Some((*updates).clone());
830
            match self.mood {
831
                // Build succeeds, but installing the channel into the manager fails.
832
                'r' => Err(tor_proto::Error::ChanProto(
833
                    "synthetic reparameterize failure".into(),
834
                )),
835
                _ => Ok(()),
836
            }
837
        }
838
        fn reparameterize_kist(&self, _kist_params: KistParams) -> tor_proto::Result<()> {
839
            Ok(())
840
        }
841
        fn engage_padding_activities(&self) {}
842
    }
843

            
844
    impl HasRelayIds for FakeChannel {
845
        fn identity(
846
            &self,
847
            key_type: tor_linkspec::RelayIdType,
848
        ) -> Option<tor_linkspec::RelayIdRef<'_>> {
849
            match key_type {
850
                tor_linkspec::RelayIdType::Ed25519 => Some((&self.ed_ident).into()),
851
                _ => None,
852
            }
853
        }
854
    }
855

            
856
    impl FakeChannel {
857
        fn start_closing(&self) {
858
            self.closing.store(true, Ordering::SeqCst);
859
        }
860
    }
861

            
862
    impl<RT: Runtime> FakeChannelFactory<RT> {
863
        fn new(runtime: RT, build_attempts: Arc<AtomicUsize>) -> Self {
864
            FakeChannelFactory {
865
                runtime,
866
                build_attempts,
867
            }
868
        }
869
    }
870

            
871
    fn new_test_abstract_chanmgr<R: Runtime>(runtime: R) -> AbstractChanMgr<FakeChannelFactory<R>> {
872
        new_test_abstract_chanmgr_and_build_attempts(runtime).0
873
    }
874

            
875
    /// Build a test channel manager backed by a `FakeChannelFactory`.
    ///
    /// Also returns the counter that the factory increments on every
    /// `build_channel` call, so tests can check how many builds happened.
    fn new_test_abstract_chanmgr_and_build_attempts<R: Runtime>(
        runtime: R,
    ) -> (AbstractChanMgr<FakeChannelFactory<R>>, Arc<AtomicUsize>) {
        let build_attempts = Arc::new(AtomicUsize::new(0));
        let cf = FakeChannelFactory::new(runtime, Arc::clone(&build_attempts));
        let mgr = AbstractChanMgr::new(
            cf,
            Default::default(),
            Default::default(),
            &Default::default(),
            BootstrapReporter::fake(),
            ToplevelAccount::new_noop(),
        );
        (mgr, build_attempts)
    }
890

            
891
    #[derive(Clone, Debug)]
892
    struct FakeBuildSpec(u32, char, Ed25519Identity, SocketAddr);
893

            
894
    impl HasRelayIds for FakeBuildSpec {
895
        fn identity(
896
            &self,
897
            key_type: tor_linkspec::RelayIdType,
898
        ) -> Option<tor_linkspec::RelayIdRef<'_>> {
899
            match key_type {
900
                tor_linkspec::RelayIdType::Ed25519 => Some((&self.2).into()),
901
                _ => None,
902
            }
903
        }
904
    }
905

            
906
    impl HasChanMethod for FakeBuildSpec {
907
        fn chan_method(&self) -> ChannelMethod {
908
            ChannelMethod::Direct(vec![self.3.clone()])
909
        }
910
    }
911

            
912
    /// Helper to make a fake Ed identity from a u32.
913
    fn u32_to_ed(n: u32) -> Ed25519Identity {
914
        let mut bytes = [0; 32];
915
        bytes[0..4].copy_from_slice(&n.to_be_bytes());
916
        bytes.into()
917
    }
918

            
919
    /// Return true if `needle` appears anywhere in `err`'s error chain.
920
    fn error_contains(err: &Error, needle: &str) -> bool {
921
        let mut source: Option<&(dyn StdError + 'static)> = Some(err);
922
        while let Some(err) = source {
923
            if err.to_string().contains(needle) || format!("{err:?}").contains(needle) {
924
                return true;
925
            }
926
            source = err.source();
927
        }
928
        false
929
    }
930

            
931
    #[async_trait]
932
    impl<RT: Runtime> AbstractChannelFactory for FakeChannelFactory<RT> {
933
        type Channel = FakeChannel;
934
        type BuildSpec = FakeBuildSpec;
935
        type Stream = ();
936

            
937
        async fn build_channel(
938
            &self,
939
            target: &Self::BuildSpec,
940
            _reporter: BootstrapReporter,
941
            _memquota: ChannelAccount,
942
        ) -> Result<Arc<FakeChannel>> {
943
            self.build_attempts.fetch_add(1, Ordering::SeqCst);
944
            yield_now().await;
945
            let FakeBuildSpec(ident, mood, id, _addr) = *target;
946
            let ed_ident = u32_to_ed(ident);
947
            assert_eq!(ed_ident, id);
948
            match mood {
949
                // "X" means never connect.
950
                '❌' | '🔥' => return Err(Error::UnusableTarget(bad_api_usage!("emoji"))),
951
                // "zzz" means wait for 15 seconds then succeed.
952
                '💤' => {
953
                    self.runtime.sleep(Duration::new(15, 0)).await;
954
                }
955
                _ => {}
956
            }
957
            Ok(Arc::new(FakeChannel {
958
                ed_ident,
959
                mood,
960
                closing: Arc::new(AtomicBool::new(false)),
961
                detect_reuse: Default::default(),
962
                // last_params: None,
963
            }))
964
        }
965

            
966
        #[cfg(feature = "relay")]
967
        async fn build_channel_using_incoming(
968
            &self,
969
            _peer: Sensitive<std::net::SocketAddr>,
970
            _stream: Self::Stream,
971
            _memquota: ChannelAccount,
972
        ) -> Result<Arc<Self::Channel>> {
973
            unimplemented!()
974
        }
975
    }
976

            
977
    /// Two successive requests for the same target yield the same channel,
    /// and that channel is the only open one for the identity.
    #[test]
    fn connect_one_ok() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);
            let target = FakeBuildSpec(413, '!', u32_to_ed(413), ADDR_A);
            let chan1 = mgr
                .get_or_launch(target.clone(), CU::UserTraffic)
                .await
                .unwrap()
                .0;
            let chan2 = mgr.get_or_launch(target, CU::UserTraffic).await.unwrap().0;

            assert_eq!(chan1, chan2);
            assert_eq!(mgr.get_nowait(&u32_to_ed(413)), vec![chan1]);
        });
    }
993

            
994
    /// A target whose build always fails reports the build error and
    /// leaves no open channel behind.
    #[test]
    fn connect_one_fail() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            // This is set up to always fail.
            let target = FakeBuildSpec(999, '❌', u32_to_ed(999), ADDR_A);
            let res1 = mgr.get_or_launch(target, CU::UserTraffic).await;
            assert!(matches!(res1, Err(Error::UnusableTarget(_))));
            assert!(mgr.get_nowait(&u32_to_ed(999)).is_empty());
        });
    }
    /// Asking for the same identity at a different address reuses the open channel.
    #[test]
    fn connect_different_address() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);

            // Same identity and mood; only the address differs.
            let at_addr_a = FakeBuildSpec(413, '!', u32_to_ed(413), ADDR_A);
            let at_addr_b = FakeBuildSpec(413, '!', u32_to_ed(413), ADDR_B);

            let first = mgr
                .get_or_launch(at_addr_a, CU::UserTraffic)
                .await
                .unwrap()
                .0;
            let second = mgr
                .get_or_launch(at_addr_b, CU::UserTraffic)
                .await
                .unwrap()
                .0;

            // The address change does not matter: we get the original channel back.
            assert_eq!(first, second);
            let open_now = mgr.get_nowait(&u32_to_ed(413));
            assert_eq!(open_now, vec![first]);
        });
    }
    /// Issue several pairs of requests in one `join!`, two per identity,
    /// and check that each pair agrees on its outcome.
    #[test]
    fn test_concurrent() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);
            let usage = CU::UserTraffic;
            // TODO(nickm): figure out how to make these actually run
            // concurrently. Right now it seems that they don't actually
            // interact.
            //
            // Pairs 3 and 44 differ only in mood, pair 50 also differs in
            // address, and both requests for 86 use moods that make the
            // fake factory fail.
            let (ch3a, ch3b, ch44a, ch44b, ch50a, ch50b, ch86a, ch86b) = join!(
                mgr.get_or_launch(FakeBuildSpec(3, 'a', u32_to_ed(3), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(3, 'b', u32_to_ed(3), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(44, 'a', u32_to_ed(44), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(44, 'b', u32_to_ed(44), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(50, 'a', u32_to_ed(50), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(50, 'b', u32_to_ed(50), ADDR_B), usage),
                mgr.get_or_launch(FakeBuildSpec(86, '❌', u32_to_ed(86), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(86, '🔥', u32_to_ed(86), ADDR_A), usage),
            );
            let ch3a = ch3a.unwrap();
            let ch3b = ch3b.unwrap();
            let ch44a = ch44a.unwrap();
            let ch44b = ch44b.unwrap();
            let ch50a = ch50a.unwrap();
            let ch50b = ch50b.unwrap();
            let err_a = ch86a.unwrap_err();
            let err_b = ch86b.unwrap_err();
            // Both requests for one identity must produce equal results...
            assert_eq!(ch3a, ch3b);
            assert_eq!(ch44a, ch44b);
            assert_eq!(ch50a, ch50b);
            // ...while different identities must not.
            assert_ne!(ch44a, ch3a);
            // Both requests for the failing target see the build error.
            assert!(matches!(err_a, Error::UnusableTarget(_)));
            assert!(matches!(err_b, Error::UnusableTarget(_)));
        });
    }
    /// A request that is waiting while other requests for the same target
    /// are polled and then dropped must finish with `Error::RequestCancelled`,
    /// without mentioning "channel build task disappeared" in its error chain.
    #[test]
    fn dropped_launch_reports_request_cancelled_to_waiters() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);
            // The '💤' mood makes the fake factory sleep for 15 seconds
            // before succeeding, so a single poll cannot complete a build.
            let target = FakeBuildSpec(777, '💤', u32_to_ed(777), ADDR_A);
            let usage = CU::UserTraffic;
            // First request: polled once, still pending.
            // NOTE(review): presumably this is the request that owns the launch — confirm.
            let mut owner1 = Box::pin(mgr.get_or_launch(target.clone(), usage));
            assert!(poll!(&mut owner1).is_pending());
            // Second request for the same target: also pending after one poll.
            let mut waiter = Box::pin(mgr.get_or_launch(target.clone(), usage));
            assert!(poll!(&mut waiter).is_pending());
            // Drop the first request without letting it finish.
            drop(owner1);
            // A third request is polled once; afterwards the waiter is
            // polled again and must still be pending.
            let mut owner2 = Box::pin(mgr.get_or_launch(target, usage));
            assert!(poll!(&mut owner2).is_pending());
            assert!(poll!(&mut waiter).is_pending());
            // Drop the third request too; only the waiter remains.
            drop(owner2);
            let waiter = waiter.await;
            assert!(
                matches!(&waiter, Err(Error::RequestCancelled)),
                "{waiter:?}"
            );
            if let Err(ref err) = waiter {
                assert!(!error_contains(err, "channel build task disappeared"));
            }
        });
    }
    /// When the build succeeds but `reparameterize` fails on the new channel,
    /// both requests for the target must see an `Error::Internal` whose chain
    /// mentions "failure on new channel", after exactly one build attempt.
    #[test]
    fn failed_upgrade_reports_original_error_without_owner_retry() {
        test_with_one_runtime!(|runtime| async {
            let (mgr, build_attempts) = new_test_abstract_chanmgr_and_build_attempts(runtime);
            // The 'r' mood lets the build succeed but makes
            // `FakeChannel::reparameterize` return an error.
            let target = FakeBuildSpec(778, 'r', u32_to_ed(778), ADDR_A);
            let usage = CU::UserTraffic;
            // Poll each request once so both are in flight before either finishes.
            let mut owner = Box::pin(mgr.get_or_launch(target.clone(), usage));
            assert!(poll!(&mut owner).is_pending());
            let mut waiter = Box::pin(mgr.get_or_launch(target.clone(), usage));
            assert!(poll!(&mut waiter).is_pending());
            // Drive the first request to completion and inspect its error.
            let owner = owner.await;
            assert!(matches!(&owner, Err(Error::Internal(_))), "{owner:?}");
            if let Err(ref err) = owner {
                assert!(error_contains(err, "failure on new channel"));
                assert!(!error_contains(err, "channel build task disappeared"));
            }
            // Only one build was attempted, and no open channel was left behind.
            assert_eq!(build_attempts.load(Ordering::SeqCst), 1);
            assert!(mgr.get_nowait(&u32_to_ed(778)).is_empty());
            // The second request must report the same kind of failure.
            let waiter = waiter.await;
            assert!(matches!(&waiter, Err(Error::Internal(_))), "{waiter:?}");
            if let Err(ref err) = waiter {
                assert!(error_contains(err, "failure on new channel"));
                assert!(!error_contains(err, "channel build task disappeared"));
            }
        });
    }
    /// Closing channels are never handed out again, and
    /// `remove_unusable_entries` leaves only usable channels visible.
    #[test]
    fn unusable_entries() {
        test_with_one_runtime!(|runtime| async {
            let mgr = new_test_abstract_chanmgr(runtime);
            let usage = CU::UserTraffic;

            // Open one channel for each of three identities.
            let (res3, res4, res5) = join!(
                mgr.get_or_launch(FakeBuildSpec(3, 'a', u32_to_ed(3), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(4, 'a', u32_to_ed(4), ADDR_A), usage),
                mgr.get_or_launch(FakeBuildSpec(5, 'a', u32_to_ed(5), ADDR_A), usage),
            );
            let old3 = res3.unwrap().0;
            let _kept4 = res4.unwrap();
            let old5 = res5.unwrap().0;

            // Make the channels for identities 3 and 5 unusable.
            old3.start_closing();
            old5.start_closing();

            // A new request for identity 3 must build a replacement channel.
            let replacement_spec = FakeBuildSpec(3, 'b', u32_to_ed(3), ADDR_A);
            let new3 = mgr.get_or_launch(replacement_spec, usage).await.unwrap().0;
            assert_ne!(old3, new3);
            assert_eq!(new3.mood, 'b');

            mgr.remove_unusable_entries().unwrap();

            // 3 was replaced, 4 was never closed, 5 was closed and not replaced.
            let open3 = mgr.get_nowait(&u32_to_ed(3));
            let open4 = mgr.get_nowait(&u32_to_ed(4));
            let open5 = mgr.get_nowait(&u32_to_ed(5));
            assert!(!open3.is_empty());
            assert!(!open4.is_empty());
            assert!(open5.is_empty());
        });
    }
}