1
//! Functions for applying the correct weights to relays when choosing
2
//! a relay at random.
3
//!
4
//! The weight to use when picking a relay depends on several factors:
5
//!
6
//! - The relay's *apparent bandwidth*.  (This is ideally measured by a set of
7
//!   bandwidth authorities, but if no bandwidth authorities are running (as on
8
//!   a test network), we might fall back either to relays' self-declared
9
//!   values, or we might treat all relays as having equal bandwidth.)
10
//! - The role that we're selecting a relay to play.  (See [`WeightRole`]).
11
//! - The flags that a relay has in the consensus, and their scarcity.  If a
12
//!   relay provides particularly scarce functionality, we might choose not to
13
//!   use it for other roles, or to use it less commonly for them.
14

            
15
use crate::ConsensusRelays;
16
use crate::params::NetParameters;
17
use bitflags::bitflags;
18
use tor_netdoc::doc::netstatus::{self, MdConsensus, MdRouterStatus, NetParams};
19

            
20
/// Helper: Calculate the function we should use to find initial relay
21
/// bandwidths.
22
10242
fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23
10242
where
24
10242
    I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25
{
26
10789
    let has_measured = weights.clone().any(|w| w.is_measured());
27
10489
    let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28
10791
    let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29

            
30
10242
    if !has_nonzero {
31
        // If every value is zero, we should just pretend everything has
32
        // bandwidth == 1.
33
53
        BandwidthFn::Uniform
34
10189
    } else if !has_measured {
35
        // If there are no measured values, then we can look at unmeasured
36
        // weights.
37
100
        BandwidthFn::IncludeUnmeasured
38
10089
    } else if has_nonzero_measured {
39
        // Otherwise, there are measured values; we should look at those only, if
40
        // any of them is nonzero.
41
10087
        BandwidthFn::MeasuredOnly
42
    } else {
43
        // This is a bit of an ugly case: We have measured values, but they're
44
        // all zero.  If this happens, the bandwidth authorities exist but they
45
        // very confused: we should fall back to uniform weighting.
46
2
        BandwidthFn::Uniform
47
    }
48
10242
}
49

            
50
/// Internal: the rule used to find the base bandwidth of each relay.
///
/// One value applies to a whole directory; it is derived from the
/// bandwidth weights listed in the consensus.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum BandwidthFn {
    /// The consensus has no usable weights (there are none at all, or the
    /// ones we would look at are all zero): weight every relay as 1.
    Uniform,
    /// The consensus has no measured weights: use each relay's listed
    /// weight, measured or not.
    IncludeUnmeasured,
    /// The consensus has measured relays: use only measured weights, and
    /// count unmeasured relays as zero.
    MeasuredOnly,
}
64

            
65
impl BandwidthFn {
66
    /// Apply this function to the measured or unmeasured bandwidth
67
    /// of a single relay.
68
25017478
    fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69
        use BandwidthFn::*;
70
        use netstatus::RelayWeight::*;
71
25017478
        match (self, w) {
72
2650
            (Uniform, _) => 1,
73
4510
            (IncludeUnmeasured, Unmeasured(u)) => *u,
74
2
            (IncludeUnmeasured, Measured(m)) => *m,
75
124
            (MeasuredOnly, Unmeasured(_)) => 0,
76
25010192
            (MeasuredOnly, Measured(m)) => *m,
77
            (_, _) => 0,
78
        }
79
25017478
    }
80
}
81

            
82
/// Possible ways to weight relays when selecting them at random.
///
/// A relay's weight is a function of its bandwidth, and that function
/// depends on how scarce the relevant "kind" of bandwidth is.  For
/// example, if Exit bandwidth is rare, then Exits should be less likely
/// to get chosen for the middle hop of a path.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum WeightRole {
    /// Selecting a relay to use as a guard.
    Guard,
    /// Selecting a relay to use as a middle relay in a circuit.
    Middle,
    /// Selecting a relay to use to deliver traffic to the internet.
    Exit,
    /// Selecting a relay for a one-hop BEGIN_DIR directory request.
    BeginDir,
    /// Selecting a relay with no additional weight beyond its bandwidth.
    Unweighted,
    /// Selecting a relay for use as a hidden service introduction point.
    HsIntro,
    /// Selecting a relay for use as a hidden service rendezvous point.
    HsRend,
}
106

            
107
/// Description for how to weight a single kind of relay for each WeightRole.
#[derive(Clone, Debug, Copy)]
struct RelayWeight {
    /// How to weight this kind of relay when picking a guard relay.
    as_guard: u32,
    /// How to weight this kind of relay when picking a middle relay.
    as_middle: u32,
    /// How to weight this kind of relay when picking an exit relay.
    as_exit: u32,
    /// How to weight this kind of relay when picking a one-hop BEGIN_DIR.
    as_dir: u32,
}

impl std::ops::Mul<u32> for RelayWeight {
    type Output = Self;

    /// Multiply each per-role weight by `rhs`.
    ///
    /// Each product saturates at `u32::MAX` instead of overflowing.  The
    /// factors ultimately come from a consensus document, so an oversized
    /// value must neither panic (as unchecked arithmetic does in debug
    /// builds) nor silently wrap around to a small weight (as it does in
    /// release builds).
    fn mul(self, rhs: u32) -> Self {
        RelayWeight {
            as_guard: self.as_guard.saturating_mul(rhs),
            as_middle: self.as_middle.saturating_mul(rhs),
            as_exit: self.as_exit.saturating_mul(rhs),
            as_dir: self.as_dir.saturating_mul(rhs),
        }
    }
}
impl std::ops::Div<u32> for RelayWeight {
    type Output = Self;

    /// Divide each per-role weight by `rhs`, rounding towards zero.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` is zero.
    fn div(self, rhs: u32) -> Self {
        RelayWeight {
            as_guard: self.as_guard / rhs,
            as_middle: self.as_middle / rhs,
            as_exit: self.as_exit / rhs,
            as_dir: self.as_dir / rhs,
        }
    }
}
142

            
143
impl RelayWeight {
144
    /// Return the largest weight that we give for this kind of relay.
145
    // The unwrap() is safe because array is nonempty.
146
    #[allow(clippy::unwrap_used)]
147
81872
    fn max_weight(&self) -> u32 {
148
81872
        [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
149
81872
            .iter()
150
81872
            .max()
151
81872
            .copied()
152
81872
            .unwrap()
153
81872
    }
154
    /// Return the weight we should give this kind of relay's
155
    /// bandwidth for a given role.
156
24630481
    fn for_role(&self, role: WeightRole) -> u32 {
157
24630481
        match role {
158
1301046
            WeightRole::Guard => self.as_guard,
159
14942030
            WeightRole::Middle => self.as_middle,
160
7587065
            WeightRole::Exit => self.as_exit,
161
386989
            WeightRole::BeginDir => self.as_dir,
162
26362
            WeightRole::HsIntro => self.as_middle, // TODO SPEC is this right?
163
            WeightRole::HsRend => self.as_middle,  // TODO SPEC is this right?
164
386989
            WeightRole::Unweighted => 1,
165
        }
166
24630481
    }
167
}
168

            
169
bitflags! {
    /// A kind of relay, for the purposes of selecting a relay by weight.
    ///
    /// A relay either has or lacks each of the Guard, Exit, and V2Dir
    /// flags, which gives 8 kinds of relays in total.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct WeightKind: u8 {
        /// Bit set in a weightkind for relays with the Guard flag.
        const GUARD = 0b001;
        /// Bit set in a weightkind for relays with the Exit flag.
        const EXIT = 0b010;
        /// Bit set in a weightkind for relays with the V2Dir flag.
        const DIR = 0b100;
    }
}
184

            
185
impl WeightKind {
186
    /// Return the appropriate WeightKind for a relay.
187
24630479
    fn for_rs(rs: &MdRouterStatus) -> Self {
188
24630479
        let mut r = WeightKind::empty();
189
24630479
        if rs.is_flagged_guard() {
190
12653792
            r |= WeightKind::GUARD;
191
12653792
        }
192
24630479
        if rs.is_flagged_exit() {
193
15327305
            r |= WeightKind::EXIT;
194
15327305
        }
195
24630479
        if rs.is_flagged_v2dir() {
196
24630073
            r |= WeightKind::DIR;
197
24630073
        }
198
24630479
        r
199
24630479
    }
200
    /// Return the index to use for this kind of a relay within a WeightSet.
201
24630481
    fn idx(self) -> usize {
202
24630481
        self.bits() as usize
203
24630481
    }
204
}
205

            
206
/// Information derived from a consensus to use when picking relays by
207
/// weighted bandwidth.
208
#[derive(Debug, Clone)]
209
pub(crate) struct WeightSet {
210
    /// How to find the bandwidth to use when picking a relay by weighted
211
    /// bandwidth.
212
    ///
213
    /// (This tells us us whether to count unmeasured relays, whether
214
    /// to look at bandwidths at all, etc.)
215
    bandwidth_fn: BandwidthFn,
216
    /// Number of bits that we need to right-shift our weighted products
217
    /// so that their sum won't overflow u64::MAX.
218
    //
219
    // TODO: Perhaps we should use f64 to hold our weights instead,
220
    // so we don't need to keep this ad-hoc fixed-point implementation?
221
    // If we did so, we won't have to worry about overflows.
222
    // (When we call choose_multiple_weighted, it already converts into
223
    // f64 internally.  (Though choose_weighted doesn't.))
224
    // Before making this change, however,
225
    // we should think a little about performance and precision.
226
    shift: u8,
227
    /// A set of RelayWeight values, indexed by [`WeightKind::idx`], used
228
    /// to weight different kinds of relays.
229
    w: [RelayWeight; 8],
230
}
231

            
232
impl WeightSet {
233
    /// Find the actual 64-bit weight to use for a given routerstatus when
234
    /// considering it for a given role.
235
    ///
236
    /// NOTE: This function _does not_ consider whether the relay in question
237
    /// actually matches the given role.  For example, if `role` is Guard
238
    /// we don't check whether or not `rs` actually has the Guard flag.
239
24630469
    pub(crate) fn weight_rs_for_role(&self, rs: &MdRouterStatus, role: WeightRole) -> u64 {
240
24630469
        self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
241
24630469
    }
242

            
243
    /// Find the 64-bit weight to report for a relay of `kind` whose weight in
244
    /// the consensus is `relay_weight` when using it for `role`.
245
24630481
    fn weight_bw_for_role(
246
24630481
        &self,
247
24630481
        kind: WeightKind,
248
24630481
        relay_weight: &netstatus::RelayWeight,
249
24630481
        role: WeightRole,
250
24630481
    ) -> u64 {
251
24630481
        let ws = &self.w[kind.idx()];
252

            
253
24630481
        let router_bw = self.bandwidth_fn.apply(relay_weight);
254
        // Note a subtlety here: we multiply the two values _before_
255
        // we shift, to improve accuracy.  We know that this will be
256
        // safe, since the inputs are both u32, and so cannot overflow
257
        // a u64.
258
24630481
        let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
259
24630481
        router_weight >> self.shift
260
24630481
    }
261

            
262
    /// Compute the correct WeightSet for a provided MdConsensus.
263
10232
    pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
264
32251
        let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
265
10232
        let weight_scale = params.bw_weight_scale.into();
266

            
267
10232
        let total_bw = consensus
268
10232
            .c_relays()
269
10232
            .iter()
270
387213
            .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
271
10232
            .sum();
272
10232
        let p = consensus.bandwidth_weights();
273

            
274
10232
        Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
275
10232
    }
276

            
277
    /// Compute the correct WeightSet given a bandwidth function, a
278
    /// weight-scaling parameter, a total amount of bandwidth for all
279
    /// relays in the consensus, and a set of bandwidth parameters.
280
10234
    fn from_parts(
281
10234
        bandwidth_fn: BandwidthFn,
282
10234
        total_bw: u64,
283
10234
        weight_scale: u32,
284
10234
        p: &NetParams<i32>,
285
10234
    ) -> Self {
286
        /// Find a single RelayWeight, given the names that its bandwidth
287
        /// parameters have. The `g` parameter is the weight as a guard, the
288
        /// `m` parameter is the weight as a middle relay, the `e` parameter is
289
        /// the weight as an exit, and the `d` parameter is the weight as a
290
        /// directory.
291
        #[allow(clippy::many_single_char_names)]
292
40936
        fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
293
40936
            RelayWeight {
294
40936
                as_guard: w_param(p, g),
295
40936
                as_middle: w_param(p, m),
296
40936
                as_exit: w_param(p, e),
297
40936
                as_dir: w_param(p, d),
298
40936
            }
299
40936
        }
300

            
301
        // Prevent division by zero in case we're called with a bogus
302
        // input.  (That shouldn't be possible.)
303
10234
        let weight_scale = weight_scale.max(1);
304

            
305
        // For non-V2Dir relays, we have names for most of their weights.
306
        //
307
        // (There is no Wge, since we only use Guard relays as guards.  By the
308
        // same logic, Wme has no reason to exist, but according to the spec it
309
        // does.)
310
10234
        let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
311
10234
        let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
312
10234
        let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
313
10234
        let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
314

            
315
        // Note that the positions of the elements in this array need to
316
        // match the values returned by WeightKind.as_idx().
317
10234
        let w = [
318
10234
            w_none,
319
10234
            w_guard,
320
10234
            w_exit,
321
10234
            w_both,
322
10234
            // The V2Dir values are the same as the non-V2Dir values, except
323
10234
            // each is multiplied by an additional factor.
324
10234
            //
325
10234
            // (We don't need to check for overflow here, since the
326
10234
            // authorities make sure that the inputs don't get too big.)
327
10234
            (w_none * w_param(p, "Wmb")) / weight_scale,
328
10234
            (w_guard * w_param(p, "Wgb")) / weight_scale,
329
10234
            (w_exit * w_param(p, "Web")) / weight_scale,
330
10234
            (w_both * w_param(p, "Wdb")) / weight_scale,
331
10234
        ];
332

            
333
        // This is the largest weight value.
334
        // The unwrap() is safe because `w` is nonempty.
335
        #[allow(clippy::unwrap_used)]
336
10234
        let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
337

            
338
        // We want "shift" such that (total * w_max) >> shift <= u64::max
339
10234
        let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
340

            
341
10234
        WeightSet {
342
10234
            bandwidth_fn,
343
10234
            shift,
344
10234
            w,
345
10234
        }
346
10234
    }
347

            
348
    /// Assert that we have correctly computed our shift values so that
349
    /// our total weighted bws do not exceed u64::MAX.
350
10232
    fn validate(self, consensus: &MdConsensus) -> Self {
351
        use WeightRole::*;
352
51160
        for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
353
51160
            let _: u64 = consensus
354
51160
                .c_relays()
355
51160
                .iter()
356
1936065
                .map(|rs| self.weight_rs_for_role(rs, role))
357
1936065
                .fold(0_u64, |a, b| {
358
1934925
                    a.checked_add(b)
359
1934925
                        .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
360
1934925
                });
361
        }
362
10232
        self
363
10232
    }
364
}
365

            
366
/// The value to use for a weight parameter that is absent.
///
/// (When the consensus lists no weights at all, setting every one of them
/// to 1 and just using the bandwidths is correct.  When only _some_ are
/// listed, the spec doesn't say what to do, but this behavior appears
/// reasonable.)
const DFLT_WEIGHT: i32 = 1;
373

            
374
/// Return the weight param named 'kwd' in p.
375
///
376
/// Returns DFLT_WEIGHT if there is no such parameter, and 0
377
/// if `kwd` is "---".
378
204680
fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
379
204680
    if kwd == "---" {
380
10234
        0
381
    } else {
382
194446
        clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
383
    }
384
204680
}
385

            
386
/// Convert `inp` to a u32, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    // (The spec says that we might encounter negative values here, though
    // we never actually generate them, and don't plan to generate them.)
    //
    // The conversion can only fail for negative inputs.
    u32::try_from(inp).unwrap_or(0)
}
392

            
393
/// Compute a 'shift' value such that `(a * b) >> shift` will be contained
394
/// inside 64 bits.
395
10242
fn calculate_shift(a: u64, b: u64) -> u32 {
396
10242
    let bits_for_product = log2_upper(a) + log2_upper(b);
397
10242
    bits_for_product.saturating_sub(64)
398
10242
}
399

            
400
/// Return an upper bound for the log2 of n.
///
/// The value returned is the number of bits needed to represent `n`
/// (so 0 for `n == 0`).  That is an overestimate whenever n is a power of
/// two, but that doesn't much matter for the uses we're giving it here.
fn log2_upper(n: u64) -> u32 {
    let unused_bits = n.leading_zeros();
    u64::BITS - unused_bits
}
407

            
408
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, MdRouterStatusBuilder};
    use tor_netdoc::types::relay_flags::{RelayFlag, RelayFlags};
    use web_time_compat::SystemTimeExt;

    /// Negative inputs clamp to zero; everything else passes through.
    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    /// `log2_upper` reports the number of bits needed to hold its input.
    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7); // a little buggy but harmless.
    }

    /// The computed shift is zero when the product fits in 64 bits, and
    /// otherwise is large enough that the shifted product does fit.
    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(
            ((432_u64 << 40) >> 38)
                .checked_mul(7777_u64 << 40)
                .is_some()
        );
    }

    /// Each combination of measured/unmeasured and zero/nonzero weights
    /// selects the expected bandwidth function.
    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // This corresponds to an open question in
        // `pick_bandwidth_fn`, about what to do when the only nonzero
        // weights are unmeasured.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    /// Each bandwidth function maps measured and unmeasured weights to the
    /// expected base bandwidth.
    #[test]
    fn t_apply_bwfn() {
        use BandwidthFn::*;
        use netstatus::RelayWeight::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    // From a fairly recent Tor consensus.
    const TESTVEC_PARAMS: &str = "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    /// A WeightSet built directly from the test-vector parameters has the
    /// expected table entries and produces the expected per-role weights.
    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        // A V2Dir entry is the non-V2Dir entry scaled by an extra factor:
        // here Wee * Web / weight_scale = 10000 * 10000 / 10000.
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::DIR).bits() as usize].as_exit,
            10000
        );

        // Under MeasuredOnly, an unmeasured relay has no weight at all.
        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        // Now try those last few with routerstatuses.
        let rs = rs_builder()
            .set_flags(RelayFlag::Guard | RelayFlag::V2Dir)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a routerstatus builder set up to deliver a routerstatus
    /// with most features disabled.
    fn rs_builder() -> MdRouterStatusBuilder {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    /// Relay flags on a routerstatus map onto the matching WeightKind bits.
    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlag::Exit).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlag::Guard).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlag::V2Dir).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    /// A WeightSet built from a whole consensus ignores unmeasured
    /// bandwidth when measured bandwidth is present.
    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::get();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // We're going to add a huge amount of unmeasured bandwidth,
        // and a reasonable amount of measured bandwidth.
        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into()) // random id
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into()) // random id
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlag::Guard | RelayFlag::Exit)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}