1
//! Exit policies: match patterns of addresses and/or ports.
2
//!
3
//! Every Tor relay has a set of address:port combinations that it
4
//! actually allows connections to.  The set, abstractly, is the
5
//! relay's "exit policy".
6
//!
7
//! Address policies can be transmitted in two forms.  One is a "full
8
//! policy", that includes a list of rules that are applied in order
9
//! to represent addresses and ports.  We represent this with the
10
//! AddrPolicy type.
11
//!
12
//! In microdescriptors, and for IPv6 policies, policies are just
13
//! given a list of ports for which _most_ addresses are permitted.
14
//! We represent this kind of policy with the PortPolicy type.
15
//!
16
//! TODO: This module probably belongs in a crate of its own, with
17
//! possibly only the parsing code in this crate.
18

            
19
mod addrpolicy;
20
mod portpolicy;
21

            
22
use std::fmt;
23
use std::str::FromStr;
24
use std::{collections::BTreeSet, fmt::Display};
25
use thiserror::Error;
26
use tor_basic_utils::iter_join;
27

            
28
pub use addrpolicy::{AddrPolicy, AddrPortPattern};
29
pub use portpolicy::PortPolicy;
30

            
31
use crate::NormalItemArgument;
32
use crate::parse2::{ArgumentError, ArgumentStream, ItemArgumentParseable};
33

            
34
/// Error from an unparsable or invalid policy.
35
#[derive(Debug, Error, Clone, PartialEq, Eq)]
36
#[non_exhaustive]
37
pub enum PolicyError {
38
    /// A port was not a number in the range 1..65535
39
    #[error("Invalid port")]
40
    InvalidPort,
41
    /// A port range had its starting-point higher than its ending point.
42
    #[error("Invalid port range")]
43
    InvalidRange,
44
    /// An address could not be interpreted.
45
    #[error("Invalid address")]
46
    InvalidAddress,
47
    /// Tried to use a bitmask or prefix len with the address "*".
48
    // TODO maybe rename this, we never use masks, only prefix lengths
49
    #[error("mask or prefix length with star")]
50
    MaskWithStar,
51
    /// A bit mask was out of range.
52
    // TODO maybe rename this, we never use masks, only prefix lengths
53
    #[error("invalid prefix length or mask")]
54
    InvalidMask,
55
    /// A policy could not be parsed for some other reason.
56
    #[error("Invalid policy")]
57
    InvalidPolicy,
58
}
59

            
60
/// A contiguous, inclusive span of TCP or UDP port numbers.
///
/// # Example
/// ```
/// use tor_netdoc::types::policy::PortRange;
///
/// let r: PortRange = "22-8000".parse().unwrap();
/// assert!(r.contains(128));
/// assert!(r.contains(22));
/// assert!(r.contains(8000));
///
/// assert!(! r.contains(21));
/// assert!(! r.contains(8001));
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// Lowest port that belongs to this range.
    lo: u16,
    /// Highest port that belongs to this range.
    hi: u16,
}
82

            
83
impl PortRange {
84
    /// Create a new port range spanning from lo to hi, asserting that
85
    /// the correct invariants hold.
86
72282
    const fn new_unchecked(lo: u16, hi: u16) -> Self {
87
72282
        assert!(lo != 0);
88
72282
        assert!(lo <= hi);
89
72282
        PortRange { lo, hi }
90
72282
    }
91
    /// Create a port range containing all ports.
92
24618
    pub const fn new_all() -> Self {
93
24618
        PortRange::new_unchecked(1, 65535)
94
24618
    }
95
    /// Create a new PortRange.
96
    ///
97
    /// The Portrange contains all ports between `lo` and `hi` inclusive.
98
    ///
99
    /// Returns None if lo is greater than hi, or if either is zero.
100
605964
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
101
605964
        if lo != 0 && lo <= hi {
102
605958
            Some(PortRange { lo, hi })
103
        } else {
104
6
            None
105
        }
106
605964
    }
107
    /// Return true if a port is in this range.
108
5396
    pub fn contains(&self, port: u16) -> bool {
109
5396
        self.lo <= port && port <= self.hi
110
5396
    }
111
    /// Return true if this range contains all ports.
112
50
    pub fn is_all(&self) -> bool {
113
50
        self.lo == 1 && self.hi == 65535
114
50
    }
115

            
116
    /// Helper for binary search: compare this range to a port.
117
    ///
118
    /// This range is "equal" to all ports that it contains.  It is
119
    /// "greater" than all ports that precede its starting point, and
120
    /// "less" than all ports that follow its ending point.
121
29010496
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
122
        use std::cmp::Ordering::*;
123
29010496
        if port < self.lo {
124
3258154
            Greater
125
25752342
        } else if port <= self.hi {
126
19505462
            Equal
127
        } else {
128
6246880
            Less
129
        }
130
29010496
    }
131
}
132

            
133
/// A PortRange is displayed as a number if it contains a single port,
134
/// and as a start point and end point separated by a dash if it contains
135
/// more than one port.
136
impl Display for PortRange {
137
108
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138
108
        if self.lo == self.hi {
139
32
            write!(f, "{}", self.lo)
140
        } else {
141
76
            write!(f, "{}-{}", self.lo, self.hi)
142
        }
143
108
    }
144
}
145

            
146
impl FromStr for PortRange {
147
    type Err = PolicyError;
148
605934
    fn from_str(s: &str) -> Result<Self, PolicyError> {
149
605934
        let (lo, hi) = match s.split_once('-') {
150
372714
            Some((lo, hi)) => (
151
372714
                lo.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?,
152
372706
                hi.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?,
153
            ),
154
            None => {
155
                // There was no hyphen, so try to parse this range as a singleton.
156
233220
                let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
157
233202
                (v, v)
158
            }
159
        };
160
605904
        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
161
605934
    }
162
}
163

            
164
// Marker impl with an empty body.
// NOTE(review): presumably this lets `PortRange` be handled as an ordinary
// netdoc item argument through its `FromStr`/`Display` impls — confirm
// against the documentation of `crate::NormalItemArgument`.
impl NormalItemArgument for PortRange {}
165

            
166
/// A collection of port ranges in a sorted order.
167
///
168
/// Please use this when storing multiple port ranges because it optimizies
169
/// them storage wise.
170
// TODO: We should rewrite most of this, the implementation has lots of
171
// potential for off-by-one errors and such.
172
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
173
// Invariant:
174
//
175
// The `PortRange`s are valid, nonoverlapping, non-abutting, and sorted.
176
struct PortRanges(Vec<PortRange>);
177

            
178
impl PortRanges {
179
    /// Creates a new [`PortRanges`] collection with no elements in it.
180
509798
    fn new() -> Self {
181
509798
        Self(Vec::new())
182
509798
    }
183

            
184
    /// Checks whether there are no ranges in this instance.
185
19133880
    fn is_empty(&self) -> bool {
186
19133880
        self.0.is_empty()
187
19133880
    }
188

            
189
    /// Adds a new range into this [`PortRanges`].
190
    ///
191
    /// The ranges must be valid, nonoverlapping, and pushed in a monotonically increasing order,
192
    /// meaning that inserting `400-500,450-600` or `400-500,500-600` are
193
    /// invalid, whereas `400-500,501-600` and `400-500,501-600` are.
194
653048
    fn push_ordered(&mut self, item: PortRange) -> Result<(), PolicyError> {
195
653048
        if let Some(prev) = self.0.last() {
196
            // TODO SPEC: We don't enforce this in Tor, but we probably
197
            // should.  See torspec#60.
198
148452
            if prev.hi >= item.lo {
199
16
                return Err(PolicyError::InvalidPolicy);
200
148436
            } else if prev.hi == item.lo - 1 {
201
                // We compress a-b,(b+1)-c into a-c.
202
20
                let r = PortRange::new_unchecked(prev.lo, item.hi);
203
20
                self.0.pop();
204
20
                self.0.push(r);
205
20
                return Ok(());
206
148416
            }
207
504596
        }
208

            
209
653012
        self.0.push(item);
210
653012
        Ok(())
211
653048
    }
212

            
213
    /// Checks whether `port` is contained in a range.
214
    ///
215
    /// Whether this means if `port` is allowed or rejected depends on the
216
    /// surroundings (such as which field this `PortRage` is in,
217
    /// or an associated [`RuleKind`]).
218
32341440
    fn contains(&self, port: u16) -> bool {
219
32525564
        debug_assert!(self.0.is_sorted_by(|a, b| a.lo < b.lo));
220
32341440
        self.0
221
32899566
            .binary_search_by(|range| range.compare_to_port(port))
222
32341440
            .is_ok()
223
32341440
    }
224

            
225
    /// Returns an inverted [`PortRanges`].
226
    ///
227
    /// For example, a [`PortRanges`] of `80-443` would become `1-79,444-65535`.
228
257504
    fn inverted(&self) -> PortRanges {
229
257504
        let mut prev_hi = 0;
230
257504
        let mut new_allowed = Vec::new();
231
257504
        for entry in &self.0 {
232
            // ports prev_hi+1 through entry.lo-1 were rejected.  We should
233
            // make them allowed.
234
257504
            if entry.lo > prev_hi + 1 {
235
30
                new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
236
257474
            }
237
257504
            prev_hi = entry.hi;
238
        }
239
257504
        if prev_hi < 65535 {
240
32
            new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
241
257472
        }
242
257504
        PortRanges(new_allowed)
243
257504
    }
244

            
245
    /// Inverts a [`PortRanges`] in place
246
    ///
247
    /// For example, a [`PortRanges`] of `80-443` would become `1-79,444-65535`.
248
257460
    fn invert(&mut self) {
249
257460
        *self = self.inverted();
250
257460
    }
251

            
252
    /// Returns an iterator for [`PortRanges`].
253
54
    fn iter(&self) -> impl Iterator<Item = &PortRange> + Clone {
254
54
        self.0.iter()
255
54
    }
256

            
257
    /// If set of ranges is non-empty, returns a string representation
258
    ///
259
    /// We don't provide a normal `Display` impl, because it would have to
260
    /// emit the empty string for an empty range, which would be quite odd.
261
    ///
262
    /// When displaying accept/reject ranges, the caller needs to
263
    /// choose between prepending `accept` and prepending `reject`.
264
88
    fn display(&self) -> Option<impl Display + '_> {
265
        struct DisplayWrapper<'r>(&'r PortRanges);
266

            
267
        impl Display for DisplayWrapper<'_> {
268
54
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
269
54
                write!(f, "{}", iter_join(",", self.0.iter()))
270
54
            }
271
        }
272

            
273
88
        (!self.is_empty()).then_some(DisplayWrapper(self))
274
88
    }
275
}
276

            
277
impl FromIterator<u16> for PortRanges {
278
19504
    fn from_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
279
        // Collect all ports into a BTreeSet to have them sorted and deduped.
280
19504
        let ports = iter.into_iter().collect::<BTreeSet<_>>();
281
19504
        let mut ports = ports.into_iter().peekable();
282

            
283
19504
        let mut out = Self::new();
284
19504
        let mut current_min = None;
285
97470
        while let Some(port) = ports.next() {
286
77966
            if current_min.is_none() {
287
47582
                current_min = Some(port);
288
47582
            }
289
77966
            if let Some(next_port) = ports.peek().copied() {
290
                // We do not have to worry about port == 65535, because then
291
                // ports.peek() will be None, as each item in the BTreeSet is
292
                // ordered and unique, implying that there won't be a successor
293
                // to a port == 65535.
294
63144
                if next_port != port + 1 {
295
32760
                    let _ = out.push_ordered(PortRange::new_unchecked(
296
32760
                        current_min.expect("Don't have min port number"),
297
32760
                        port,
298
32760
                    ));
299
32760
                    current_min = None;
300
32880
                }
301
14822
            } else {
302
14822
                let _ = out.push_ordered(PortRange::new_unchecked(
303
14822
                    current_min.expect("Don't have min port number"),
304
14822
                    port,
305
14822
                ));
306
14822
            }
307
        }
308

            
309
19504
        out
310
19504
    }
311
}
312

            
313
impl FromStr for PortRanges {
314
    type Err = PolicyError;
315

            
316
489784
    fn from_str(s: &str) -> Result<Self, Self::Err> {
317
        // Pitfall: Do not use a clever iterator here because we need the result
318
        // of .push() in order to avoid things such as `30-19`.
319
489784
        let mut ranges = Self::new();
320
605480
        for range in s.split(',') {
321
605480
            ranges.push_ordered(range.parse()?)?;
322
        }
323
489754
        Ok(ranges)
324
489784
    }
325
}
326

            
327
impl ItemArgumentParseable for PortRanges {
328
    /// [`PortRanges`] argument parser which is odd because port ranges are
329
    /// syntactically a single argument although semantically multiple ones.
330
510
    fn from_args<'s>(args: &mut ArgumentStream<'s>) -> Result<Self, ArgumentError> {
331
510
        args.next()
332
510
            .map(Self::from_str)
333
510
            .unwrap_or(Ok(Self::new()))
334
510
            .map_err(|_| ArgumentError::Invalid)
335
510
    }
336
}
337

            
338
/// A kind of policy rule: either accepts or rejects addresses
339
/// matching a pattern.
340
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, derive_more::FromStr)]
341
#[display(rename_all = "lowercase")]
342
#[from_str(rename_all = "lowercase")]
343
#[allow(clippy::exhaustive_enums)]
344
pub enum RuleKind {
345
    /// A rule that accepts matching address:port combinations.
346
    Accept,
347
    /// A rule that rejects matching address:port combinations.
348
    Reject,
349
}
350

            
351
// Marker impl with an empty body.
// NOTE(review): presumably this lets `RuleKind` be handled as an ordinary
// netdoc item argument through its derived `FromStr`/`Display` impls —
// confirm against the documentation of `crate::NormalItemArgument`.
impl NormalItemArgument for RuleKind {}
352

            
353
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::Result;
    use crate::parse2::{self, ParseInput};

    /// Parsing of single `PortRange`s: accepted spellings and rejected ones.
    #[test]
    fn parse_portrange() -> Result<()> {
        // Inputs that must parse, paired with the range they denote.
        let accepted = [
            ("1-100", PortRange::new(1, 100).unwrap()),
            ("01-100", PortRange::new(1, 100).unwrap()),
            ("1-65535", PortRange::new_all()),
            ("10-30", PortRange::new(10, 30).unwrap()),
            ("9001", PortRange::new(9001, 9001).unwrap()),
            ("9001-9001", PortRange::new(9001, 9001).unwrap()),
        ];
        for (text, expected) in accepted {
            assert_eq!(text.parse::<PortRange>()?, expected, "input: {text:?}");
        }

        // Inputs that must be rejected.
        let rejected = [
            "hello", "0", "65536", "65537", "1-2-3", "10-5", "1-", "-2", "-", "*",
        ];
        for text in rejected {
            assert!(text.parse::<PortRange>().is_err(), "input: {text:?}");
        }
        Ok(())
    }

    /// `is_all`, `contains` and `compare_to_port` on individual ranges.
    #[test]
    fn pr_manip() {
        let everything = PortRange::new_all();
        assert!(everything.is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());

        for port in [1, 65535, 7777] {
            assert!(everything.contains(port), "port: {port}");
        }

        let mid = PortRange::new(20, 30).unwrap();
        for port in [20, 25, 30] {
            assert!(mid.contains(port), "port: {port}");
        }
        for port in [19, 31] {
            assert!(!mid.contains(port), "port: {port}");
        }

        use std::cmp::Ordering::*;
        let orderings = [
            (7, Greater),
            (20, Equal),
            (25, Equal),
            (30, Equal),
            (100, Less),
        ];
        for (port, expected) in orderings {
            assert_eq!(mid.compare_to_port(port), expected, "port: {port}");
        }
    }

    /// Textual form of a `PortRange`.
    #[test]
    fn pr_fmt() {
        let cases = [(1, 65535, "1-65535"), (10, 20, "10-20"), (20, 20, "20")];
        for (lo, hi, text) in cases {
            let range = PortRange::new(lo, hi).unwrap();
            assert_eq!(range.to_string(), text);
        }
    }

    /// Parsing, lookup and inversion of `PortRanges`, plus parsing it as a
    /// netdoc item argument.
    #[test]
    fn port_ranges() {
        const INPUT: &str = "22,80,443,8000-9000,9002";
        // Shorthand for building the expected ranges.
        let range = |lo, hi| PortRange::new(lo, hi).unwrap();

        let ranges: PortRanges = INPUT.parse().unwrap();
        assert_eq!(
            ranges.0,
            [
                range(22, 22),
                range(80, 80),
                range(443, 443),
                range(8000, 9000),
                range(9002, 9002),
            ]
        );
        for port in [22, 80, 443, 8000, 8500, 9000] {
            assert!(ranges.contains(port), "port: {port}");
        }
        assert!(!ranges.contains(9001));
        assert!(ranges.contains(9002));

        let mut ranges_inverse = ranges.clone();
        ranges_inverse.invert();
        assert_eq!(
            ranges_inverse.0,
            [
                range(1, 21),
                range(23, 79),
                range(81, 442),
                range(444, 7999),
                range(9001, 9001),
                range(9003, 65535),
            ]
        );

        #[derive(derive_deftly::Deftly)]
        #[derive_deftly(NetdocParseable)]
        struct Dummy {
            #[deftly(netdoc(single_arg))]
            dummy: PortRanges,
        }
        let document = format!("dummy {INPUT}\n");
        let ranges2 = parse2::parse_netdoc::<Dummy>(&ParseInput::new(&document, "")).unwrap();
        assert_eq!(ranges, ranges2.dummy);
    }
}