1
//! Types and code to map circuit IDs to circuits.
2

            
3
// NOTE: This is a work in progress and I bet I'll refactor it a lot;
4
// it needs to stay opaque!
5

            
6
use crate::circuit::CircuitRxSender;
7
use crate::client::circuit::padding::{PaddingController, QueuedCellPaddingInfo};
8
use crate::{Error, Result};
9
use tor_basic_utils::RngExt;
10
use tor_cell::chancell::CircId;
11

            
12
use crate::circuit::celltypes::CreateResponse;
13
use crate::client::circuit::halfcirc::HalfCirc;
14

            
15
use oneshot_fused_workaround as oneshot;
16

            
17
use rand::Rng;
18
use rand::distr::Distribution;
19
use std::collections::{HashMap, hash_map::Entry};
20
use std::ops::{Deref, DerefMut};
21

            
22
/// Which group of circuit IDs are we allowed to allocate in this map?
23
///
24
/// If we initiated the channel, we use High circuit ids.  If we're the
25
/// responder, we use low circuit ids.
26
#[derive(Copy, Clone)]
27
pub(super) enum CircIdRange {
28
    /// Only use circuit IDs with the MSB cleared.
29
    #[allow(dead_code)] // Relays will need this.
30
    Low,
31
    /// Only use circuit IDs with the MSB set.
32
    High,
33
    // Historical note: There used to be an "All" range of circuit IDs
34
    // available to clients only.  We stopped using "All" when we moved to link
35
    // protocol version 4.
36
}
37

            
38
impl rand::distr::Distribution<CircId> for CircIdRange {
39
    /// Return a random circuit ID in the appropriate range.
40
524
    fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
41
524
        let midpoint = 0x8000_0000_u32;
42
524
        let v = match self {
43
            // 0 is an invalid value
44
256
            CircIdRange::Low => rng.gen_range_checked(1..midpoint),
45
268
            CircIdRange::High => rng.gen_range_checked(midpoint..=u32::MAX),
46
        };
47
524
        let v = v.expect("Unexpected empty range passed to gen_range_checked");
48
524
        CircId::new(v).expect("Unexpected zero value")
49
524
    }
50
}
51

            
52
/// An entry in the circuit map.  Right now, we only have "here's the
53
/// way to send cells to a given circuit", but that's likely to
54
/// change.
55
#[derive(Debug)]
56
pub(super) enum CircEnt {
57
    /// A circuit that has not yet received a CREATED cell.
58
    ///
59
    /// For this circuit, the CREATED* cell or DESTROY cell gets sent
60
    /// to the oneshot sender to tell the corresponding
61
    /// PendingClientCirc that the handshake is done.
62
    ///
63
    /// Once that's done, the `CircuitRxSender` mpsc sender will be used to send subsequent
64
    /// cells to the circuit.
65
    Opening {
66
        /// The oneshot sender on which to report a create response
67
        create_response_sender: oneshot::Sender<CreateResponse>,
68
        /// A sink which should receive all the relay cells for this circuit
69
        /// from this channel
70
        cell_sender: CircuitRxSender,
71
        /// A padding controller we should use when reporting flushed cells.
72
        padding_ctrl: PaddingController,
73
    },
74

            
75
    /// A circuit that is open and can be given relay cells.
76
    Open {
77
        /// A sink which should receive all the relay cells for this circuit
78
        /// from this channel
79
        cell_sender: CircuitRxSender,
80
        /// A padding controller we should use when reporting flushed cells.
81
        padding_ctrl: PaddingController,
82
    },
83

            
84
    /// A circuit where we have sent a DESTROY, but the other end might
85
    /// not have gotten a DESTROY yet.
86
    DestroySent(HalfCirc),
87
}
88

            
89
/// An "smart pointer" that wraps an exclusive reference
90
/// of a `CircEnt`.
91
///
92
/// When being dropped, this object updates the open or opening entries
93
/// counter of the `CircMap`.
94
pub(super) struct MutCircEnt<'a> {
95
    /// An exclusive reference to the `CircEnt`.
96
    value: &'a mut CircEnt,
97
    /// An exclusive reference to the open or opening
98
    ///  entries counter.
99
    open_count: &'a mut usize,
100
    /// True if the entry was open or opening when borrowed.
101
    was_open: bool,
102
}
103

            
104
impl<'a> Drop for MutCircEnt<'a> {
105
624
    fn drop(&mut self) {
106
624
        let is_open = !matches!(self.value, CircEnt::DestroySent(_));
107
624
        match (self.was_open, is_open) {
108
            (false, true) => *self.open_count = self.open_count.saturating_add(1),
109
            (true, false) => *self.open_count = self.open_count.saturating_sub(1),
110
624
            (_, _) => (),
111
        };
112
624
    }
113
}
114

            
115
impl<'a> Deref for MutCircEnt<'a> {
116
    type Target = CircEnt;
117
284
    fn deref(&self) -> &Self::Target {
118
284
        self.value
119
284
    }
120
}
121

            
122
impl<'a> DerefMut for MutCircEnt<'a> {
123
336
    fn deref_mut(&mut self) -> &mut Self::Target {
124
336
        self.value
125
336
    }
126
}
127

            
128
/// A map from circuit IDs to circuit entries. Each channel has one.
129
pub(super) struct CircMap {
130
    /// Map from circuit IDs to entries
131
    m: HashMap<CircId, CircEnt>,
132
    /// Rule for allocating new circuit IDs.
133
    range: CircIdRange,
134
    /// Number of open or opening entry in this map.
135
    open_count: usize,
136
}
137

            
138
impl CircMap {
139
    /// Make a new empty CircMap
140
526
    pub(super) fn new(idrange: CircIdRange) -> Self {
141
526
        CircMap {
142
526
            m: HashMap::new(),
143
526
            range: idrange,
144
526
            open_count: 0,
145
526
        }
146
526
    }
147

            
148
    /// Add a new set of elements (corresponding to a PendingClientCirc)
149
    /// to this map.
150
    ///
151
    /// On success return the allocated circuit ID.
152
524
    pub(super) fn add_ent<R: Rng>(
153
524
        &mut self,
154
524
        rng: &mut R,
155
524
        createdsink: oneshot::Sender<CreateResponse>,
156
524
        sink: CircuitRxSender,
157
524
        padding_ctrl: PaddingController,
158
524
    ) -> Result<CircId> {
159
        /// How many times do we probe for a random circuit ID before
160
        /// we assume that the range is fully populated?
161
        ///
162
        /// TODO: C tor does 64, but that is probably overkill with 4-byte circuit IDs.
163
        const N_ATTEMPTS: usize = 16;
164
524
        let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
165
524
        let circ_ent = CircEnt::Opening {
166
524
            create_response_sender: createdsink,
167
524
            cell_sender: sink,
168
524
            padding_ctrl,
169
524
        };
170
524
        for id in iter {
171
524
            let ent = self.m.entry(id);
172
524
            if let Entry::Vacant(_) = &ent {
173
524
                ent.or_insert(circ_ent);
174
524
                self.open_count += 1;
175
524
                return Ok(id);
176
            }
177
        }
178
        Err(Error::IdRangeFull)
179
524
    }
180

            
181
    /// Testing only: install an entry in this circuit map without regard
182
    /// for consistency.
183
    #[cfg(test)]
184
72
    pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
185
72
        self.m.insert(id, ent);
186
72
    }
187

            
188
    /// Return the entry for `id` in this map, if any.
189
652
    pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
190
652
        let open_count = &mut self.open_count;
191
652
        self.m.get_mut(&id).map(move |ent| MutCircEnt {
192
624
            open_count,
193
624
            was_open: !matches!(ent, CircEnt::DestroySent(_)),
194
624
            value: ent,
195
624
        })
196
652
    }
197

            
198
    /// Inform the relevant circuit's padding subsystem that a given cell has been flushed.
199
4334
    pub(super) fn note_cell_flushed(&mut self, id: CircId, info: QueuedCellPaddingInfo) {
200
4334
        let padding_ctrl = match self.m.get(&id) {
201
            Some(CircEnt::Opening { padding_ctrl, .. }) => padding_ctrl,
202
            Some(CircEnt::Open { padding_ctrl, .. }) => padding_ctrl,
203
4334
            Some(CircEnt::DestroySent(..)) | None => return,
204
        };
205
        padding_ctrl.flushed_relay_cell(info);
206
4334
    }
207

            
208
    /// See whether 'id' is an opening circuit.  If so, mark it "open" and
209
    /// return a oneshot::Sender that is waiting for its create cell.
210
30
    pub(super) fn advance_from_opening(
211
30
        &mut self,
212
30
        id: CircId,
213
30
    ) -> Result<oneshot::Sender<CreateResponse>> {
214
        // TODO: there should be a better way to do
215
        // this. hash_map::Entry seems like it could be better, but
216
        // there seems to be no way to replace the object in-place as
217
        // a consuming function of itself.
218
30
        let ok = matches!(self.m.get(&id), Some(CircEnt::Opening { .. }));
219
30
        if ok {
220
            if let Some(CircEnt::Opening {
221
2
                create_response_sender: oneshot,
222
2
                cell_sender: sink,
223
2
                padding_ctrl,
224
2
            }) = self.m.remove(&id)
225
            {
226
2
                self.m.insert(
227
2
                    id,
228
2
                    CircEnt::Open {
229
2
                        cell_sender: sink,
230
2
                        padding_ctrl,
231
2
                    },
232
                );
233
2
                Ok(oneshot)
234
            } else {
235
                panic!("internal error: inconsistent circuit state");
236
            }
237
        } else {
238
28
            Err(Error::ChanProto(
239
28
                "Unexpected CREATED* cell not on opening circuit".into(),
240
28
            ))
241
        }
242
30
    }
243

            
244
    /// Called when we have sent a DESTROY on a circuit.  Configures
245
    /// a "HalfCirc" object to track how many cells we get on this
246
    /// circuit, and to prevent us from reusing it immediately.
247
100
    pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
248
100
        if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
249
12
            if !matches!(replaced, CircEnt::DestroySent(_)) {
250
12
                // replaced an Open/Opening entry with DestroySent
251
12
                self.open_count = self.open_count.saturating_sub(1);
252
12
            }
253
88
        }
254
100
    }
255

            
256
    /// Extract the value from this map with 'id' if any
257
50
    pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
258
69
        self.m.remove(&id).map(|removed| {
259
38
            if !matches!(removed, CircEnt::DestroySent(_)) {
260
26
                self.open_count = self.open_count.saturating_sub(1);
261
26
            }
262
38
            removed
263
38
        })
264
50
    }
265

            
266
    /// Return the total number of open and opening entries in the map
267
166
    pub(super) fn open_ent_count(&self) -> usize {
268
166
        self.open_count
269
166
    }
270

            
271
    // TODO: Eventually if we want relay support, we'll need to support
272
    // circuit IDs chosen by somebody else. But for now, we don't need those.
273
}
274

            
275
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::{client::circuit::padding::new_padding, fake_mpsc};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_rtcompat::DynTimeProvider;

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();
        tor_rtcompat::test_with_one_runtime!(|runtime| async {
            let (padding_ctrl, _padding_stream) = new_padding(DynTimeProvider::new(runtime));

            assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

            for _ in 0..128 {
                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_low = map_low
                    .add_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_low) > 0);
                assert!(u32::from(id_low) < 0x80000000);
                assert!(!ids_low.contains(&id_low));
                ids_low.push(id_low);

                assert!(matches!(
                    *map_low.get_mut(id_low).unwrap(),
                    CircEnt::Opening { .. }
                ));

                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_high = map_high
                    .add_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_high) >= 0x80000000);
                assert!(!ids_high.contains(&id_high));
                ids_high.push(id_high);
            }

            // Test open / opening entry counting
            assert_eq!(128, map_low.open_ent_count());
            assert_eq!(128, map_high.open_ent_count());

            // Test remove
            assert!(map_low.get_mut(ids_low[0]).is_some());
            map_low.remove(ids_low[0]);
            assert!(map_low.get_mut(ids_low[0]).is_none());
            assert_eq!(127, map_low.open_ent_count());

            // Test DestroySent doesn't count
            map_low.destroy_sent(CircId::new(256).unwrap(), HalfCirc::new(1));
            assert_eq!(127, map_low.open_ent_count());

            // Replacing an opening entry with DestroySent stops counting it...
            map_low.destroy_sent(ids_low[1], HalfCirc::new(1));
            assert_eq!(126, map_low.open_ent_count());
            // ...and removing a DestroySent entry leaves the count alone.
            assert!(matches!(
                map_low.remove(ids_low[1]),
                Some(CircEnt::DestroySent(_))
            ));
            assert_eq!(126, map_low.open_ent_count());

            // Test advance_from_opening.

            // Good case.
            assert!(map_high.get_mut(ids_high[0]).is_some());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Opening { .. }
            ));
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_ok());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Open { .. }
            ));
            // Opening -> Open does not change the count.
            assert_eq!(128, map_high.open_ent_count());

            // Can't double-advance.
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_err());
            // A failed advance leaves the entry in place.
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Open { .. }
            ));

            // Can't advance an entry that is not there.  We know "77"
            // can't be in map_high, since we only added high circids to
            // it.
            let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
            assert!(adv.is_err());

            // Changing an entry's state through `get_mut` keeps the count in sync.
            {
                let mut ent = map_high.get_mut(ids_high[1]).unwrap();
                *ent = CircEnt::DestroySent(HalfCirc::new(1));
            }
            assert_eq!(127, map_high.open_ent_count());
            {
                let (snd, _) = fake_mpsc(8);
                let mut ent = map_high.get_mut(ids_high[1]).unwrap();
                *ent = CircEnt::Open {
                    cell_sender: snd,
                    padding_ctrl: padding_ctrl.clone(),
                };
            }
            assert_eq!(128, map_high.open_ent_count());
        });
    }
}