1
//! Types and code to map circuit IDs to circuits.
2

            
3
// NOTE: This is a work in progress and I bet I'll refactor it a lot;
4
// it needs to stay opaque!
5

            
6
use crate::circuit::CircuitRxSender;
7
use crate::client::circuit::padding::{PaddingController, QueuedCellPaddingInfo};
8
use crate::{Error, Result};
9
use tor_basic_utils::RngExt;
10
use tor_cell::chancell::CircId;
11
use tor_cell::chancell::msg::DestroyReason;
12

            
13
use crate::circuit::celltypes::CreateResponse;
14
use crate::client::circuit::halfcirc::HalfCirc;
15

            
16
use oneshot_fused_workaround as oneshot;
17

            
18
use rand::Rng;
19
use rand::distr::Distribution;
20
use std::collections::{HashMap, hash_map::Entry};
21
use std::ops::{Deref, DerefMut};
22
use std::result::Result as StdResult;
23
use std::sync::Arc;
24

            
25
#[cfg(feature = "relay")]
26
use crate::relay::RelayCirc;
27

            
28
/// Which group of circuit IDs are we allowed to allocate in this map?
///
/// If we initiated the channel, we use High circuit ids.  If we're the
/// responder, we use low circuit ids.  The channel's peer allocates from
/// the other group (see `CircIdRange::is_allowed_for_peer`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum CircIdRange {
    /// Only use circuit IDs with the MSB cleared.
    ///
    /// Zero is excluded, since it is never a valid circuit ID.
    #[allow(dead_code)] // Relays will need this.
    Low,
    /// Only use circuit IDs with the MSB set.
    High,
    // Historical note: There used to be an "All" range of circuit IDs
    // available to clients only.  We stopped using "All" when we moved to link
    // protocol version 4.
}
43

            
44
impl CircIdRange {
45
    /// The range of integer circuit IDs that we are allowed to allocate.
46
    /// Prefer using other more specific methods over this one.
47
528
    const fn integer_range(&self) -> std::ops::RangeInclusive<u32> {
48
        const MIDPOINT: u32 = 0x8000_0000;
49

            
50
528
        match self {
51
            // 0 is an invalid value
52
256
            Self::Low => 1..=(MIDPOINT - 1),
53
272
            Self::High => MIDPOINT..=u32::MAX,
54
        }
55
528
    }
56

            
57
    /// Is this circuit ID allowed to be allocated by the channel's peer?
58
    pub(crate) fn is_allowed_for_peer(&self, id: CircId) -> bool {
59
        // If our range does not contain it, then it is allowed.
60
        // Note that a `CircId` never contains a value of zero,
61
        // so no need to consider it here.
62
        !self.integer_range().contains(&id.into())
63
    }
64
}
65

            
66
impl rand::distr::Distribution<CircId> for CircIdRange {
67
    /// Return a random circuit ID in the appropriate range.
68
528
    fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
69
528
        let v = rng.gen_range_checked(self.integer_range());
70
528
        let v = v.expect("Unexpected empty range passed to gen_range_checked");
71
528
        CircId::new(v).expect("Unexpected zero value")
72
528
    }
73
}
74

            
75
/// An entry in the circuit map.  Right now, we only have "here's the
76
/// way to send cells to a given circuit", but that's likely to
77
/// change.
78
#[derive(Debug)]
79
pub(super) enum CircEnt {
80
    /// An origin circuit that has not yet received a CREATED cell.
81
    ///
82
    /// For this circuit, the CREATED* cell or DESTROY cell gets sent
83
    /// to the oneshot sender to tell the corresponding
84
    /// PendingClientCirc that the handshake is done.
85
    ///
86
    /// Once that's done, the `CircuitRxSender` mpsc sender will be used to send subsequent
87
    /// cells to the circuit.
88
    Opening {
89
        /// The oneshot sender on which to report a create response
90
        create_response_sender: oneshot::Sender<CreateResponse>,
91
        /// A sink which should receive all the relay cells for this circuit
92
        /// from this channel
93
        cell_sender: CircuitRxSender,
94
        /// A padding controller we should use when reporting flushed cells.
95
        padding_ctrl: PaddingController,
96
    },
97

            
98
    /// An origin circuit (a circuit which originated here)
99
    /// that is open and can be given relay cells.
100
    OpenOrigin {
101
        /// A sink which should receive all the relay cells for this circuit
102
        /// from this channel
103
        cell_sender: CircuitRxSender,
104
        /// A padding controller we should use when reporting flushed cells.
105
        padding_ctrl: PaddingController,
106
    },
107

            
108
    /// A relay circuit (a circuit in which we are a hop on the path)
109
    /// that is open and can be given relay cells.
110
    #[cfg(feature = "relay")]
111
    OpenRelay {
112
        /// A handle to the circuit.
113
        /// TODO(relay): We need to store the `Arc<RelayCirc>` somewhere
114
        /// and currently this seems like the best place to store it.
115
        /// As we implement more functionality maybe we'll find a better place to store it,
116
        /// in which case we should consider combining the `OpenOrigin` and `OpenRelay` variants.
117
        _circ: Arc<RelayCirc>,
118
        /// A sink which should receive all the relay cells for this circuit
119
        /// from this channel
120
        cell_sender: CircuitRxSender,
121
        /// A padding controller we should use when reporting flushed cells.
122
        padding_ctrl: PaddingController,
123
    },
124

            
125
    /// A circuit where we have sent a DESTROY, but the other end might
126
    /// not have gotten a DESTROY yet.
127
    DestroySent(HalfCirc),
128
}
129

            
130
/// An "smart pointer" that wraps an exclusive reference
131
/// of a `CircEnt`.
132
///
133
/// When being dropped, this object updates the open or opening entries
134
/// counter of the `CircMap`.
135
pub(super) struct MutCircEnt<'a> {
136
    /// An exclusive reference to the `CircEnt`.
137
    value: &'a mut CircEnt,
138
    /// An exclusive reference to the open or opening
139
    ///  entries counter.
140
    open_count: &'a mut usize,
141
    /// True if the entry was open or opening when borrowed.
142
    was_open: bool,
143
}
144

            
145
impl<'a> Drop for MutCircEnt<'a> {
146
624
    fn drop(&mut self) {
147
624
        let is_open = !matches!(self.value, CircEnt::DestroySent(_));
148
624
        match (self.was_open, is_open) {
149
            (false, true) => *self.open_count = self.open_count.saturating_add(1),
150
            (true, false) => *self.open_count = self.open_count.saturating_sub(1),
151
624
            (_, _) => (),
152
        };
153
624
    }
154
}
155

            
156
impl<'a> Deref for MutCircEnt<'a> {
157
    type Target = CircEnt;
158
284
    fn deref(&self) -> &Self::Target {
159
284
        self.value
160
284
    }
161
}
162

            
163
impl<'a> DerefMut for MutCircEnt<'a> {
164
336
    fn deref_mut(&mut self) -> &mut Self::Target {
165
336
        self.value
166
336
    }
167
}
168

            
169
/// A map from circuit IDs to circuit entries. Each channel has one.
170
pub(super) struct CircMap {
171
    /// Map from circuit IDs to entries
172
    m: HashMap<CircId, CircEnt>,
173
    /// Rule for allocating new circuit IDs.
174
    range: CircIdRange,
175
    /// Number of open or opening entry in this map.
176
    open_count: usize,
177
}
178

            
179
impl CircMap {
180
    /// Make a new empty CircMap
181
535
    pub(super) fn new(idrange: CircIdRange) -> Self {
182
535
        CircMap {
183
535
            m: HashMap::new(),
184
535
            range: idrange,
185
535
            open_count: 0,
186
535
        }
187
535
    }
188

            
189
    /// Add a new set of elements (corresponding to a
190
    /// [`PendingClientTunnel`](crate::client::circuit::PendingClientTunnel))
191
    /// as an entry to this map.
192
    ///
193
    /// On success return the allocated circuit ID.
194
528
    pub(super) fn add_origin_ent<R: Rng>(
195
528
        &mut self,
196
528
        rng: &mut R,
197
528
        createdsink: oneshot::Sender<CreateResponse>,
198
528
        sink: CircuitRxSender,
199
528
        padding_ctrl: PaddingController,
200
528
    ) -> Result<CircId> {
201
        /// How many times do we probe for a random circuit ID before
202
        /// we assume that the range is fully populated?
203
        ///
204
        /// TODO: C tor does 64, but that is probably overkill with 4-byte circuit IDs.
205
        const N_ATTEMPTS: usize = 16;
206
528
        let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
207
528
        let circ_ent = CircEnt::Opening {
208
528
            create_response_sender: createdsink,
209
528
            cell_sender: sink,
210
528
            padding_ctrl,
211
528
        };
212
528
        for id in iter {
213
528
            let ent = self.m.entry(id);
214
528
            if let Entry::Vacant(_) = &ent {
215
528
                ent.or_insert(circ_ent);
216
528
                self.open_count += 1;
217
528
                return Ok(id);
218
            }
219
        }
220
        Err(Error::IdRangeFull)
221
528
    }
222

            
223
    /// Add a new set of elements (corresponding to a [`RelayCirc`]) as an entry to this map.
224
    ///
225
    /// We use [`DestroyReason`] as the return type since we very likely want to destroy the circuit
226
    /// if this fails, and not return an error and destroy the entire channel.
227
    #[cfg(feature = "relay")]
228
    pub(super) fn add_relay_ent(
229
        &mut self,
230
        circ_id: CircId,
231
        circ: Arc<RelayCirc>,
232
        sink: CircuitRxSender,
233
        padding_ctrl: PaddingController,
234
    ) -> StdResult<(), DestroyReason> {
235
        // The peer is only allowed to use a subset of the ID range.
236
        if !self.range.is_allowed_for_peer(circ_id) {
237
            return Err(DestroyReason::PROTOCOL);
238
        }
239

            
240
        let circ_ent = CircEnt::OpenRelay {
241
            _circ: circ,
242
            cell_sender: sink,
243
            padding_ctrl,
244
        };
245

            
246
        if let Entry::Vacant(ent) = self.m.entry(circ_id) {
247
            ent.insert(circ_ent);
248
            self.open_count += 1;
249
            Ok(())
250
        } else {
251
            Err(DestroyReason::PROTOCOL)
252
        }
253
    }
254

            
255
    /// Testing only: install an entry in this circuit map without regard
256
    /// for consistency.
257
    #[cfg(test)]
258
72
    pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
259
72
        self.m.insert(id, ent);
260
72
    }
261

            
262
    /// Return the entry for `id` in this map, if any.
263
652
    pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
264
652
        let open_count = &mut self.open_count;
265
652
        self.m.get_mut(&id).map(move |ent| MutCircEnt {
266
624
            open_count,
267
624
            was_open: !matches!(ent, CircEnt::DestroySent(_)),
268
624
            value: ent,
269
624
        })
270
652
    }
271

            
272
    /// Inform the relevant circuit's padding subsystem that a given cell has been flushed.
273
4302
    pub(super) fn note_cell_flushed(&mut self, id: CircId, info: QueuedCellPaddingInfo) {
274
4302
        let padding_ctrl = match self.m.get(&id) {
275
            Some(CircEnt::Opening { padding_ctrl, .. }) => padding_ctrl,
276
            Some(CircEnt::OpenOrigin { padding_ctrl, .. }) => padding_ctrl,
277
            #[cfg(feature = "relay")]
278
            Some(CircEnt::OpenRelay { padding_ctrl, .. }) => padding_ctrl,
279
4302
            Some(CircEnt::DestroySent(..)) | None => return,
280
        };
281
        padding_ctrl.flushed_relay_cell(info);
282
4302
    }
283

            
284
    /// See whether 'id' is an opening circuit.  If so, mark it "open" and
285
    /// return a oneshot::Sender that is waiting for its create cell.
286
34
    pub(super) fn advance_from_opening(
287
34
        &mut self,
288
34
        id: CircId,
289
34
    ) -> Result<oneshot::Sender<CreateResponse>> {
290
        // TODO: there should be a better way to do
291
        // this. hash_map::Entry seems like it could be better, but
292
        // there seems to be no way to replace the object in-place as
293
        // a consuming function of itself.
294
34
        let ok = matches!(self.m.get(&id), Some(CircEnt::Opening { .. }));
295
34
        if ok {
296
            if let Some(CircEnt::Opening {
297
6
                create_response_sender: oneshot,
298
6
                cell_sender: sink,
299
6
                padding_ctrl,
300
6
            }) = self.m.remove(&id)
301
            {
302
6
                self.m.insert(
303
6
                    id,
304
6
                    CircEnt::OpenOrigin {
305
6
                        cell_sender: sink,
306
6
                        padding_ctrl,
307
6
                    },
308
                );
309
6
                Ok(oneshot)
310
            } else {
311
                panic!("internal error: inconsistent circuit state");
312
            }
313
        } else {
314
28
            Err(Error::ChanProto(
315
28
                "Unexpected CREATED* cell not on opening circuit".into(),
316
28
            ))
317
        }
318
34
    }
319

            
320
    /// Called when we have sent a DESTROY on a circuit.  Configures
321
    /// a "HalfCirc" object to track how many cells we get on this
322
    /// circuit, and to prevent us from reusing it immediately.
323
114
    pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
324
114
        if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
325
16
            if !matches!(replaced, CircEnt::DestroySent(_)) {
326
16
                // replaced an Open/Opening entry with DestroySent
327
16
                self.open_count = self.open_count.saturating_sub(1);
328
16
            }
329
98
        }
330
114
    }
331

            
332
    /// Extract the value from this map with 'id' if any
333
50
    pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
334
50
        self.m.remove(&id).map(|removed| {
335
38
            if !matches!(removed, CircEnt::DestroySent(_)) {
336
26
                self.open_count = self.open_count.saturating_sub(1);
337
26
            }
338
38
            removed
339
38
        })
340
50
    }
341

            
342
    /// Return the total number of open and opening entries in the map
343
184
    pub(super) fn open_ent_count(&self) -> usize {
344
184
        self.open_count
345
184
    }
346

            
347
    // TODO: Eventually if we want relay support, we'll need to support
348
    // circuit IDs chosen by somebody else. But for now, we don't need those.
349
}
350

            
351
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::{client::circuit::padding::new_padding, fake_mpsc};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_rtcompat::DynTimeProvider;

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();
        tor_rtcompat::test_with_one_runtime!(|runtime| async {
            let (padding_ctrl, _padding_stream) = new_padding(DynTimeProvider::new(runtime));

            assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

            for _ in 0..128 {
                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_low = map_low
                    .add_origin_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_low) > 0);
                assert!(u32::from(id_low) < 0x80000000);
                assert!(!ids_low.contains(&id_low));
                ids_low.push(id_low);

                assert!(matches!(
                    *map_low.get_mut(id_low).unwrap(),
                    CircEnt::Opening { .. }
                ));

                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_high = map_high
                    .add_origin_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_high) >= 0x80000000);
                assert!(!ids_high.contains(&id_high));
                ids_high.push(id_high);
            }

            // Test open / opening entry counting
            assert_eq!(128, map_low.open_ent_count());
            assert_eq!(128, map_high.open_ent_count());

            // Test remove
            assert!(map_low.get_mut(ids_low[0]).is_some());
            map_low.remove(ids_low[0]);
            assert!(map_low.get_mut(ids_low[0]).is_none());
            assert_eq!(127, map_low.open_ent_count());

            // Test DestroySent doesn't count
            map_low.destroy_sent(CircId::new(256).unwrap(), HalfCirc::new(1));
            assert_eq!(127, map_low.open_ent_count());

            // Test advance_from_opening.

            // Good case.
            assert!(map_high.get_mut(ids_high[0]).is_some());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Opening { .. }
            ));
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_ok());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::OpenOrigin { .. }
            ));
            // Opening -> OpenOrigin does not change the open / opening count.
            assert_eq!(128, map_high.open_ent_count());

            // Can't double-advance, and a failed advance leaves the entry alone.
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_err());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::OpenOrigin { .. }
            ));

            // Can't advance a DestroySent entry either; it stays DestroySent.
            let destroyed = CircId::new(256).unwrap();
            assert!(map_low.advance_from_opening(destroyed).is_err());
            assert!(matches!(
                *map_low.get_mut(destroyed).unwrap(),
                CircEnt::DestroySent(_)
            ));
            assert_eq!(127, map_low.open_ent_count());

            // Can't advance an entry that is not there.  We know "77"
            // can't be in map_high, since we only added high circids to
            // it.
            let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
            assert!(adv.is_err());
            assert!(map_high.get_mut(CircId::new(77).unwrap()).is_none());

            // Changing an entry's state through `get_mut` keeps the
            // open / opening count in sync.
            *map_high.get_mut(ids_high[1]).unwrap() = CircEnt::DestroySent(HalfCirc::new(1));
            assert_eq!(127, map_high.open_ent_count());
            let (snd, _) = fake_mpsc(8);
            *map_high.get_mut(ids_high[1]).unwrap() = CircEnt::OpenOrigin {
                cell_sender: snd,
                padding_ctrl: padding_ctrl.clone(),
            };
            assert_eq!(128, map_high.open_ent_count());

            // destroy_sent on an open entry decrements the count, and removing
            // the resulting DestroySent entry does not decrement it again.
            map_low.destroy_sent(ids_low[1], HalfCirc::new(1));
            assert_eq!(126, map_low.open_ent_count());
            assert!(map_low.remove(ids_low[1]).is_some());
            assert_eq!(126, map_low.open_ent_count());
        });
    }

    #[test]
    fn circ_id_range_peer_rules() {
        let smallest = CircId::new(1).unwrap();
        let largest_low = CircId::new(0x7fff_ffff).unwrap();
        let smallest_high = CircId::new(0x8000_0000).unwrap();
        let largest = CircId::new(u32::MAX).unwrap();

        // IDs in the low half belong to a Low-range map; its peer may not use them.
        for id in [smallest, largest_low] {
            assert!(!CircIdRange::Low.is_allowed_for_peer(id));
            assert!(CircIdRange::High.is_allowed_for_peer(id));
        }
        // IDs in the high half belong to a High-range map; its peer may not use them.
        for id in [smallest_high, largest] {
            assert!(CircIdRange::Low.is_allowed_for_peer(id));
            assert!(!CircIdRange::High.is_allowed_for_peer(id));
        }
    }
}