1
//! Exit policies: match patterns of addresses and/or ports.
2
//!
3
//! Every Tor relay has a set of address:port combinations that it
4
//! actually allows connections to.  The set, abstractly, is the
5
//! relay's "exit policy".
6
//!
7
//! Address policies can be transmitted in two forms.  One is a "full
8
//! policy", which includes a list of rules that are applied in order
9
//! to represent addresses and ports.  We represent this with the
10
//! AddrPolicy type.
11
//!
12
//! In microdescriptors, and for IPv6 policies, policies are just
13
//! given as a list of ports for which _most_ addresses are permitted.
14
//! We represent this kind of policy with the PortPolicy type.
15
//!
16
//! TODO: This module probably belongs in a crate of its own, with
17
//! possibly only the parsing code in this crate.
18

            
19
mod addrpolicy;
20
mod portpolicy;
21

            
22
use std::str::FromStr;
23
use std::{collections::BTreeSet, fmt::Display};
24
use thiserror::Error;
25

            
26
pub use addrpolicy::{AddrPolicy, AddrPortPattern};
27
pub use portpolicy::PortPolicy;
28

            
29
use crate::NormalItemArgument;
30
#[cfg(feature = "parse2")]
31
use crate::parse2::{ArgumentError, ArgumentStream, ItemArgumentParseable};
32

            
33
/// Error from an unparsable or invalid policy.
34
#[derive(Debug, Error, Clone, PartialEq, Eq)]
35
#[non_exhaustive]
36
pub enum PolicyError {
37
    /// A port was not a number in the range 1..65535
38
    #[error("Invalid port")]
39
    InvalidPort,
40
    /// A port range had its starting-point higher than its ending point.
41
    #[error("Invalid port range")]
42
    InvalidRange,
43
    /// An address could not be interpreted.
44
    #[error("Invalid address")]
45
    InvalidAddress,
46
    /// Tried to use a bitmask with the address "*".
47
    #[error("mask with star")]
48
    MaskWithStar,
49
    /// A bit mask was out of range.
50
    #[error("invalid mask")]
51
    InvalidMask,
52
    /// A policy could not be parsed for some other reason.
53
    #[error("Invalid policy")]
54
    InvalidPolicy,
55
}
56

            
57
/// An inclusive span of consecutively numbered TCP or UDP ports.
///
/// The constructors in this module only ever produce ranges whose first port
/// is nonzero and no greater than their last port.
///
/// # Example
/// ```
/// use tor_netdoc::types::policy::PortRange;
///
/// let r: PortRange = "22-8000".parse().unwrap();
/// assert!(r.contains(128));
/// assert!(r.contains(22));
/// assert!(r.contains(8000));
///
/// assert!(! r.contains(21));
/// assert!(! r.contains(8001));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Both fields are private, so code outside this module (and its submodules)
// cannot build or destructure a `PortRange` directly anyway.
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// Lowest port that belongs to the range.
    lo: u16,
    /// Highest port that belongs to the range.
    hi: u16,
}
79

            
80
impl PortRange {
81
    /// Create a new port range spanning from lo to hi, asserting that
82
    /// the correct invariants hold.
83
73346
    fn new_unchecked(lo: u16, hi: u16) -> Self {
84
73346
        assert!(lo != 0);
85
73346
        assert!(lo <= hi);
86
73346
        PortRange { lo, hi }
87
73346
    }
88
    /// Create a port range containing all ports.
89
25570
    pub fn new_all() -> Self {
90
25570
        PortRange::new_unchecked(1, 65535)
91
25570
    }
92
    /// Create a new PortRange.
93
    ///
94
    /// The Portrange contains all ports between `lo` and `hi` inclusive.
95
    ///
96
    /// Returns None if lo is greater than hi, or if either is zero.
97
617101
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
98
617101
        if lo != 0 && lo <= hi {
99
617095
            Some(PortRange { lo, hi })
100
        } else {
101
6
            None
102
        }
103
617101
    }
104
    /// Return true if a port is in this range.
105
5464
    pub fn contains(&self, port: u16) -> bool {
106
5464
        self.lo <= port && port <= self.hi
107
5464
    }
108
    /// Return true if this range contains all ports.
109
24
    pub fn is_all(&self) -> bool {
110
24
        self.lo == 1 && self.hi == 65535
111
24
    }
112

            
113
    /// Helper for binary search: compare this range to a port.
114
    ///
115
    /// This range is "equal" to all ports that it contains.  It is
116
    /// "greater" than all ports that precede its starting point, and
117
    /// "less" than all ports that follow its ending point.
118
29568384
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
119
        use std::cmp::Ordering::*;
120
29568384
        if port < self.lo {
121
3320808
            Greater
122
26247576
        } else if port <= self.hi {
123
19880565
            Equal
124
        } else {
125
6367011
            Less
126
        }
127
29568384
    }
128
}
129

            
130
/// A PortRange is displayed as a number if it contains a single port,
131
/// and as a start point and end point separated by a dash if it contains
132
/// more than one port.
133
impl Display for PortRange {
134
44
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
44
        if self.lo == self.hi {
136
18
            write!(f, "{}", self.lo)
137
        } else {
138
26
            write!(f, "{}-{}", self.lo, self.hi)
139
        }
140
44
    }
141
}
142

            
143
impl FromStr for PortRange {
144
    type Err = PolicyError;
145
617071
    fn from_str(s: &str) -> Result<Self, PolicyError> {
146
617071
        let idx = s.find('-');
147
        // Find "lo" and "hi".
148
617071
        let (lo, hi) = if let Some(pos) = idx {
149
            // This is a range; parse each part.
150
            (
151
379390
                s[..pos]
152
379390
                    .parse::<u16>()
153
379390
                    .map_err(|_| PolicyError::InvalidPort)?,
154
379382
                s[pos + 1..]
155
379382
                    .parse::<u16>()
156
379382
                    .map_err(|_| PolicyError::InvalidPort)?,
157
            )
158
        } else {
159
            // There was no hyphen, so try to parse this range as a singleton.
160
237681
            let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
161
237663
            (v, v)
162
        };
163
617041
        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
164
617071
    }
165
}
166

            
167
impl NormalItemArgument for PortRange {}
168

            
169
/// A collection of port ranges in a sorted order.
170
///
171
/// Please use this when storing multiple port ranges because it optimizies
172
/// them storage wise.
173
// TODO: We should rewrite most of this, the implementation has lots of
174
// potential for off-by-one errors and such.
175
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
176
// Invariant:
177
//
178
// The `PortRange`s are valid, nonoverlapping, non-abutting, and sorted.
179
struct PortRanges(Vec<PortRange>);
180

            
181
impl PortRanges {
182
    /// Creates a new [`PortRanges`] collection with no elements in it.
183
518442
    fn new() -> Self {
184
518442
        Self(Vec::new())
185
518442
    }
186

            
187
    /// Checks whether there are no ranges in this instance.
188
19501280
    fn is_empty(&self) -> bool {
189
19501280
        self.0.is_empty()
190
19501280
    }
191

            
192
    /// Adds a new range into this [`PortRanges`].
193
    ///
194
    /// The ranges must be valid, nonoverlapping, and pushed in a monotonically increasing order,
195
    /// meaning that inserting `400-500,450-600` or `400-500,500-600` are
196
    /// invalid, whereas `400-500,501-600` and `400-500,501-600` are.
197
664330
    fn push_ordered(&mut self, item: PortRange) -> Result<(), PolicyError> {
198
664330
        if let Some(prev) = self.0.last() {
199
            // TODO SPEC: We don't enforce this in Tor, but we probably
200
            // should.  See torspec#60.
201
150708
            if prev.hi >= item.lo {
202
16
                return Err(PolicyError::InvalidPolicy);
203
150692
            } else if prev.hi == item.lo - 1 {
204
                // We compress a-b,(b+1)-c into a-c.
205
20
                let r = PortRange::new_unchecked(prev.lo, item.hi);
206
20
                self.0.pop();
207
20
                self.0.push(r);
208
20
                return Ok(());
209
150672
            }
210
513622
        }
211

            
212
664294
        self.0.push(item);
213
664294
        Ok(())
214
664330
    }
215

            
216
    /// Checks whether `port` is contained in a range.
217
    ///
218
    /// Whether this means if `port` is allowed or rejected depends on the
219
    /// surroundings (such as which field this `PortRage` is in,
220
    /// or an associated [`RuleKind`]).
221
32963388
    fn contains(&self, port: u16) -> bool {
222
33147512
        debug_assert!(self.0.is_sorted_by(|a, b| a.lo < b.lo));
223
32963388
        self.0
224
33521514
            .binary_search_by(|range| range.compare_to_port(port))
225
32963388
            .is_ok()
226
32963388
    }
227

            
228
    /// Inverts a [`PortRanges`].
229
    ///
230
    /// For example, a [`PortRanges`] of `80-443` would become `1-79,444-65535`.
231
261954
    fn invert(&mut self) {
232
261954
        let mut prev_hi = 0;
233
261954
        let mut new_allowed = Vec::new();
234
261966
        for entry in &self.0 {
235
            // ports prev_hi+1 through entry.lo-1 were rejected.  We should
236
            // make them allowed.
237
261966
            if entry.lo > prev_hi + 1 {
238
22
                new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
239
261948
            }
240
261966
            prev_hi = entry.hi;
241
        }
242
261954
        if prev_hi < 65535 {
243
14
            new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
244
261940
        }
245
261954
        self.0 = new_allowed;
246
261954
    }
247

            
248
    /// Returns an iterator for [`PortRanges`].
249
10
    fn iter(&self) -> impl Iterator<Item = &PortRange> {
250
10
        self.0.iter()
251
10
    }
252
}
253

            
254
impl FromIterator<u16> for PortRanges {
255
19684
    fn from_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
256
        // Collect all ports into a BTreeSet to have them sorted and deduped.
257
19684
        let ports = iter.into_iter().collect::<BTreeSet<_>>();
258
19684
        let mut ports = ports.into_iter().peekable();
259

            
260
19684
        let mut out = Self::new();
261
19684
        let mut current_min = None;
262
97788
        while let Some(port) = ports.next() {
263
78104
            if current_min.is_none() {
264
47720
                current_min = Some(port);
265
47720
            }
266
78104
            if let Some(next_port) = ports.peek().copied() {
267
                // We do not have to worry about port == 65535, because then
268
                // ports.peek() will be None, as each item in the BTreeSet is
269
                // ordered and unique, implying that there won't be a successor
270
                // to a port == 65535.
271
63192
                if next_port != port + 1 {
272
32808
                    let _ = out.push_ordered(PortRange::new_unchecked(
273
32808
                        current_min.expect("Don't have min port number"),
274
32808
                        port,
275
32808
                    ));
276
32808
                    current_min = None;
277
32928
                }
278
14912
            } else {
279
14912
                let _ = out.push_ordered(PortRange::new_unchecked(
280
14912
                    current_min.expect("Don't have min port number"),
281
14912
                    port,
282
14912
                ));
283
14912
            }
284
        }
285

            
286
19684
        out
287
19684
    }
288
}
289

            
290
// There is deliberately no Display implementation for PortRanges because this
291
// highly depends on the semantic wrapper around it.  For example, an empty
292
// PortRanges may either be represented as `reject 1-65535` or `accept 1-65535`
293
// depending on the context.
294

            
295
impl FromStr for PortRanges {
296
    type Err = PolicyError;
297

            
298
498720
    fn from_str(s: &str) -> Result<Self, Self::Err> {
299
        // Pitfall: Do not use a clever iterator here because we need the result
300
        // of .push() in order to avoid things such as `30-19`.
301
498720
        let mut ranges = Self::new();
302
616624
        for range in s.split(',') {
303
616624
            ranges.push_ordered(range.parse()?)?;
304
        }
305
498690
        Ok(ranges)
306
498720
    }
307
}
308

            
309
#[cfg(feature = "parse2")]
impl ItemArgumentParseable for PortRanges {
    /// Parse a [`PortRanges`] from the next argument in `args`, if any.
    ///
    /// A list of port ranges is syntactically one (comma-separated) argument,
    /// even though it denotes several ranges.  When no argument remains, an
    /// empty [`PortRanges`] is produced.  A malformed list is reported as
    /// [`ArgumentError::Invalid`].
    fn from_args<'s>(args: &mut ArgumentStream<'s>) -> Result<Self, ArgumentError> {
        let Some(arg) = args.next() else {
            return Ok(Self::new());
        };
        Self::from_str(arg).map_err(|_| ArgumentError::Invalid)
    }
}
320

            
321
/// A kind of policy rule: either accepts or rejects addresses
322
/// matching a pattern.
323
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, derive_more::FromStr)]
324
#[display(rename_all = "lowercase")]
325
#[from_str(rename_all = "lowercase")]
326
#[allow(clippy::exhaustive_enums)]
327
pub enum RuleKind {
328
    /// A rule that accepts matching address:port combinations.
329
    Accept,
330
    /// A rule that rejects matching address:port combinations.
331
    Reject,
332
}
333

            
334
impl NormalItemArgument for RuleKind {}
335

            
336
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::Result;
    #[test]
    fn parse_portrange() -> Result<()> {
        assert_eq!(
            "1-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!(
            "01-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!("1-65535".parse::<PortRange>()?, PortRange::new_all());
        assert_eq!(
            "10-30".parse::<PortRange>()?,
            PortRange::new(10, 30).unwrap()
        );
        assert_eq!(
            "9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );
        assert_eq!(
            "9001-9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );

        assert!("hello".parse::<PortRange>().is_err());
        assert!("0".parse::<PortRange>().is_err());
        assert!("65536".parse::<PortRange>().is_err());
        assert!("65537".parse::<PortRange>().is_err());
        assert!("1-2-3".parse::<PortRange>().is_err());
        assert!("10-5".parse::<PortRange>().is_err());
        assert!("1-".parse::<PortRange>().is_err());
        assert!("-2".parse::<PortRange>().is_err());
        assert!("-".parse::<PortRange>().is_err());
        assert!("*".parse::<PortRange>().is_err());
        Ok(())
    }

    #[test]
    fn pr_manip() {
        assert!(PortRange::new_all().is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());

        assert!(PortRange::new_all().contains(1));
        assert!(PortRange::new_all().contains(65535));
        assert!(PortRange::new_all().contains(7777));

        assert!(PortRange::new(20, 30).unwrap().contains(20));
        assert!(PortRange::new(20, 30).unwrap().contains(25));
        assert!(PortRange::new(20, 30).unwrap().contains(30));
        assert!(!PortRange::new(20, 30).unwrap().contains(19));
        assert!(!PortRange::new(20, 30).unwrap().contains(31));

        use std::cmp::Ordering::*;
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(7), Greater);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(20), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(25), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(30), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(100), Less);
    }

    #[test]
    fn pr_fmt() {
        fn chk(a: u16, b: u16, s: &str) {
            let pr = PortRange::new(a, b).unwrap();
            assert_eq!(format!("{}", pr), s);
        }

        chk(1, 65535, "1-65535");
        chk(10, 20, "10-20");
        chk(20, 20, "20");
    }

    #[test]
    fn port_ranges() {
        const INPUT: &str = "22,80,443,8000-9000,9002";
        let ranges = PortRanges::from_str(INPUT).unwrap();
        assert_eq!(
            ranges.0,
            [
                PortRange::new(22, 22).unwrap(),
                PortRange::new(80, 80).unwrap(),
                PortRange::new(443, 443).unwrap(),
                PortRange::new(8000, 9000).unwrap(),
                PortRange::new(9002, 9002).unwrap(),
            ]
        );
        assert!(ranges.contains(22));
        assert!(ranges.contains(80));
        assert!(ranges.contains(443));
        assert!(ranges.contains(8000));
        assert!(ranges.contains(8500));
        assert!(ranges.contains(9000));
        assert!(!ranges.contains(9001));
        assert!(ranges.contains(9002));

        let mut ranges_inverse = ranges.clone();
        ranges_inverse.invert();
        assert_eq!(
            ranges_inverse.0,
            [
                PortRange::new(1, 21).unwrap(),
                PortRange::new(23, 79).unwrap(),
                PortRange::new(81, 442).unwrap(),
                PortRange::new(444, 7999).unwrap(),
                PortRange::new(9001, 9001).unwrap(),
                PortRange::new(9003, 65535).unwrap(),
            ]
        );

        #[cfg(feature = "parse2")]
        {
            use crate::parse2::{self, ParseInput};

            #[derive(derive_deftly::Deftly)]
            #[derive_deftly(NetdocParseable)]
            struct Dummy {
                #[deftly(netdoc(single_arg))]
                dummy: PortRanges,
            }
            let ranges2 =
                parse2::parse_netdoc::<Dummy>(&ParseInput::new(&format!("dummy {INPUT}\n"), ""))
                    .unwrap();
            assert_eq!(ranges, ranges2.dummy);
        }
    }

    #[test]
    fn port_ranges_push_ordered() {
        let mut ranges = PortRanges::new();
        assert!(ranges.is_empty());
        ranges
            .push_ordered(PortRange::new(400, 500).unwrap())
            .unwrap();
        // A range starting right after the previous one is merged into it.
        ranges
            .push_ordered(PortRange::new(501, 600).unwrap())
            .unwrap();
        assert_eq!(ranges.0, [PortRange::new(400, 600).unwrap()]);

        // Overlapping, touching, and out-of-order ranges are rejected.
        assert_eq!(
            ranges.push_ordered(PortRange::new(600, 700).unwrap()),
            Err(PolicyError::InvalidPolicy)
        );
        assert_eq!(
            ranges.push_ordered(PortRange::new(100, 200).unwrap()),
            Err(PolicyError::InvalidPolicy)
        );
        assert_eq!(ranges.0, [PortRange::new(400, 600).unwrap()]);

        // A range with a gap is appended separately.
        ranges
            .push_ordered(PortRange::new(602, 602).unwrap())
            .unwrap();
        assert_eq!(
            ranges.0,
            [
                PortRange::new(400, 600).unwrap(),
                PortRange::new(602, 602).unwrap(),
            ]
        );
        assert!(!ranges.is_empty());

        assert!("30-19".parse::<PortRanges>().is_err());
        assert!("80,22".parse::<PortRanges>().is_err());
        assert!("22,22".parse::<PortRanges>().is_err());
        assert!("".parse::<PortRanges>().is_err());
        assert_eq!(
            "1-10,11-20".parse::<PortRanges>().unwrap().0,
            [PortRange::new(1, 20).unwrap()]
        );
    }

    #[test]
    fn port_ranges_from_iter() {
        let ranges: PortRanges = [443_u16, 22, 80, 81, 82, 22, 65535, 65534]
            .into_iter()
            .collect();
        assert_eq!(
            ranges.0,
            [
                PortRange::new(22, 22).unwrap(),
                PortRange::new(80, 82).unwrap(),
                PortRange::new(443, 443).unwrap(),
                PortRange::new(65534, 65535).unwrap(),
            ]
        );
        assert_eq!(ranges.iter().count(), 4);

        let none: PortRanges = std::iter::empty::<u16>().collect();
        assert!(none.is_empty());
    }

    #[test]
    fn port_ranges_invert_edges() {
        let mut ranges = PortRanges::new();
        ranges.invert();
        assert_eq!(ranges.0, [PortRange::new_all()]);
        ranges.invert();
        assert!(ranges.is_empty());
    }
}