//! Exit policies: match patterns of addresses and/or ports.
//!
//! Every Tor relay has a set of address:port combinations that it
//! actually allows connections to.  The set, abstractly, is the
//! relay's "exit policy".
//!
//! Address policies can be transmitted in two forms.  One is a "full
//! policy", that includes a list of rules that are applied in order
//! to represent addresses and ports.  We represent this with the
//! AddrPolicy type.
//!
//! In microdescriptors, and for IPv6 policies, policies are just
//! given a list of ports for which _most_ addresses are permitted.
//! We represent this kind of policy with the PortPolicy type.
//!
//! TODO: This module probably belongs in a crate of its own, with
//! possibly only the parsing code in this crate.
18

            
19
mod addrpolicy;
20
mod portpolicy;
21

            
22
use std::str::FromStr;
23
use std::{collections::BTreeSet, fmt::Display};
24
use thiserror::Error;
25

            
26
pub use addrpolicy::{AddrPolicy, AddrPortPattern};
27
pub use portpolicy::PortPolicy;
28

            
29
use crate::NormalItemArgument;
30
use crate::parse2::{ArgumentError, ArgumentStream, ItemArgumentParseable};
31

            
32
/// Error from an unparsable or invalid policy.
33
#[derive(Debug, Error, Clone, PartialEq, Eq)]
34
#[non_exhaustive]
35
pub enum PolicyError {
36
    /// A port was not a number in the range 1..65535
37
    #[error("Invalid port")]
38
    InvalidPort,
39
    /// A port range had its starting-point higher than its ending point.
40
    #[error("Invalid port range")]
41
    InvalidRange,
42
    /// An address could not be interpreted.
43
    #[error("Invalid address")]
44
    InvalidAddress,
45
    /// Tried to use a bitmask with the address "*".
46
    #[error("mask with star")]
47
    MaskWithStar,
48
    /// A bit mask was out of range.
49
    #[error("invalid mask")]
50
    InvalidMask,
51
    /// A policy could not be parsed for some other reason.
52
    #[error("Invalid policy")]
53
    InvalidPolicy,
54
}
55

            
56
/// A PortRange is a set of consecutively numbered TCP or UDP ports.
///
/// Both endpoints are inclusive.  A valid range never starts at port 0,
/// and its first port is never greater than its last port.
///
/// # Example
/// ```
/// use tor_netdoc::types::policy::PortRange;
///
/// let r: PortRange = "22-8000".parse().unwrap();
/// assert!(r.contains(128));
/// assert!(r.contains(22));
/// assert!(r.contains(8000));
///
/// assert!(! r.contains(21));
/// assert!(! r.contains(8001));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// The first port in this range.
    lo: u16,
    /// The last port in this range.
    hi: u16,
}

impl PortRange {
    /// Create a new port range spanning from lo to hi, asserting that
    /// the correct invariants hold.
    ///
    /// # Panics
    ///
    /// Panics if `lo` is zero, or if `lo` is greater than `hi`.
    fn new_unchecked(lo: u16, hi: u16) -> Self {
        assert!(lo != 0);
        assert!(lo <= hi);
        PortRange { lo, hi }
    }

    /// Create a port range containing all ports.
    pub fn new_all() -> Self {
        PortRange {
            lo: 1,
            hi: u16::MAX,
        }
    }

    /// Create a new PortRange.
    ///
    /// The PortRange contains all ports between `lo` and `hi` inclusive.
    ///
    /// Returns None if lo is greater than hi, or if either is zero.
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
        // `hi == 0` is covered too: it implies either `lo == 0` or `lo > hi`.
        (lo != 0 && lo <= hi).then_some(PortRange { lo, hi })
    }

    /// Return true if a port is in this range.
    pub fn contains(&self, port: u16) -> bool {
        (self.lo..=self.hi).contains(&port)
    }

    /// Return true if this range contains all ports.
    pub fn is_all(&self) -> bool {
        self.lo == 1 && self.hi == u16::MAX
    }

    /// Helper for binary search: compare this range to a port.
    ///
    /// This range is "equal" to all ports that it contains.  It is
    /// "greater" than all ports that precede its starting point, and
    /// "less" than all ports that follow its ending point.
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        if port < self.lo {
            Ordering::Greater
        } else if port > self.hi {
            Ordering::Less
        } else {
            Ordering::Equal
        }
    }
}
128

            
129
/// A PortRange is displayed as a number if it contains a single port,
130
/// and as a start point and end point separated by a dash if it contains
131
/// more than one port.
132
impl Display for PortRange {
133
44
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134
44
        if self.lo == self.hi {
135
18
            write!(f, "{}", self.lo)
136
        } else {
137
26
            write!(f, "{}-{}", self.lo, self.hi)
138
        }
139
44
    }
140
}
141

            
142
impl FromStr for PortRange {
143
    type Err = PolicyError;
144
629166
    fn from_str(s: &str) -> Result<Self, PolicyError> {
145
629166
        let (lo, hi) = match s.split_once('-') {
146
387006
            Some((lo, hi)) => (
147
387006
                lo.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?,
148
386998
                hi.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?,
149
            ),
150
            None => {
151
                // There was no hyphen, so try to parse this range as a singleton.
152
242160
                let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
153
242142
                (v, v)
154
            }
155
        };
156
629136
        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
157
629166
    }
158
}
159

            
160
// Empty impl: `PortRange` takes `NormalItemArgument` as-is, overriding nothing.
// NOTE(review): presumably this is what lets a `PortRange` be parsed as an
// ordinary netdoc item argument — confirm against the trait's definition.
impl NormalItemArgument for PortRange {}
161

            
162
/// A collection of port ranges in a sorted order.
163
///
164
/// Please use this when storing multiple port ranges because it optimizies
165
/// them storage wise.
166
// TODO: We should rewrite most of this, the implementation has lots of
167
// potential for off-by-one errors and such.
168
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
169
// Invariant:
170
//
171
// The `PortRange`s are valid, nonoverlapping, non-abutting, and sorted.
172
struct PortRanges(Vec<PortRange>);
173

            
174
impl PortRanges {
175
    /// Creates a new [`PortRanges`] collection with no elements in it.
176
528950
    fn new() -> Self {
177
528950
        Self(Vec::new())
178
528950
    }
179

            
180
    /// Checks whether there are no ranges in this instance.
181
19869688
    fn is_empty(&self) -> bool {
182
19869688
        self.0.is_empty()
183
19869688
    }
184

            
185
    /// Adds a new range into this [`PortRanges`].
186
    ///
187
    /// The ranges must be valid, nonoverlapping, and pushed in a monotonically increasing order,
188
    /// meaning that inserting `400-500,450-600` or `400-500,500-600` are
189
    /// invalid, whereas `400-500,501-600` and `400-500,501-600` are.
190
676548
    fn push_ordered(&mut self, item: PortRange) -> Result<(), PolicyError> {
191
676548
        if let Some(prev) = self.0.last() {
192
            // TODO SPEC: We don't enforce this in Tor, but we probably
193
            // should.  See torspec#60.
194
152968
            if prev.hi >= item.lo {
195
16
                return Err(PolicyError::InvalidPolicy);
196
152952
            } else if prev.hi == item.lo - 1 {
197
                // We compress a-b,(b+1)-c into a-c.
198
20
                let r = PortRange::new_unchecked(prev.lo, item.hi);
199
20
                self.0.pop();
200
20
                self.0.push(r);
201
20
                return Ok(());
202
152932
            }
203
523580
        }
204

            
205
676512
        self.0.push(item);
206
676512
        Ok(())
207
676548
    }
208

            
209
    /// Checks whether `port` is contained in a range.
210
    ///
211
    /// Whether this means if `port` is allowed or rejected depends on the
212
    /// surroundings (such as which field this `PortRage` is in,
213
    /// or an associated [`RuleKind`]).
214
33585336
    fn contains(&self, port: u16) -> bool {
215
33769460
        debug_assert!(self.0.is_sorted_by(|a, b| a.lo < b.lo));
216
33585336
        self.0
217
34143462
            .binary_search_by(|range| range.compare_to_port(port))
218
33585336
            .is_ok()
219
33585336
    }
220

            
221
    /// Inverts a [`PortRanges`].
222
    ///
223
    /// For example, a [`PortRanges`] of `80-443` would become `1-79,444-65535`.
224
267348
    fn invert(&mut self) {
225
267348
        let mut prev_hi = 0;
226
267348
        let mut new_allowed = Vec::new();
227
267360
        for entry in &self.0 {
228
            // ports prev_hi+1 through entry.lo-1 were rejected.  We should
229
            // make them allowed.
230
267360
            if entry.lo > prev_hi + 1 {
231
22
                new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
232
267338
            }
233
267360
            prev_hi = entry.hi;
234
        }
235
267348
        if prev_hi < 65535 {
236
14
            new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
237
267334
        }
238
267348
        self.0 = new_allowed;
239
267348
    }
240

            
241
    /// Returns an iterator for [`PortRanges`].
242
10
    fn iter(&self) -> impl Iterator<Item = &PortRange> {
243
10
        self.0.iter()
244
10
    }
245
}
246

            
247
impl FromIterator<u16> for PortRanges {
248
19864
    fn from_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
249
        // Collect all ports into a BTreeSet to have them sorted and deduped.
250
19864
        let ports = iter.into_iter().collect::<BTreeSet<_>>();
251
19864
        let mut ports = ports.into_iter().peekable();
252

            
253
19864
        let mut out = Self::new();
254
19864
        let mut current_min = None;
255
98106
        while let Some(port) = ports.next() {
256
78242
            if current_min.is_none() {
257
47858
                current_min = Some(port);
258
47858
            }
259
78242
            if let Some(next_port) = ports.peek().copied() {
260
                // We do not have to worry about port == 65535, because then
261
                // ports.peek() will be None, as each item in the BTreeSet is
262
                // ordered and unique, implying that there won't be a successor
263
                // to a port == 65535.
264
63240
                if next_port != port + 1 {
265
32856
                    let _ = out.push_ordered(PortRange::new_unchecked(
266
32856
                        current_min.expect("Don't have min port number"),
267
32856
                        port,
268
32856
                    ));
269
32856
                    current_min = None;
270
32976
                }
271
15002
            } else {
272
15002
                let _ = out.push_ordered(PortRange::new_unchecked(
273
15002
                    current_min.expect("Don't have min port number"),
274
15002
                    port,
275
15002
                ));
276
15002
            }
277
        }
278

            
279
19864
        out
280
19864
    }
281
}
282

            
283
// There is deliberately no Display implementation for PortRanges because this
// highly depends on the semantic wrapper around it.  For example, an empty
// PortRanges may either be represented as `reject 1-65535` or `accept 1-65535`
// depending on the context.
287

            
288
impl FromStr for PortRanges {
289
    type Err = PolicyError;
290

            
291
508588
    fn from_str(s: &str) -> Result<Self, Self::Err> {
292
        // Pitfall: Do not use a clever iterator here because we need the result
293
        // of .push() in order to avoid things such as `30-19`.
294
508588
        let mut ranges = Self::new();
295
628704
        for range in s.split(',') {
296
628704
            ranges.push_ordered(range.parse()?)?;
297
        }
298
508558
        Ok(ranges)
299
508588
    }
300
}
301

            
302
impl ItemArgumentParseable for PortRanges {
303
    /// [`PortRanges`] argument parser which is odd because port ranges are
304
    /// syntactically a single argument although semantically multiple ones.
305
498
    fn from_args<'s>(args: &mut ArgumentStream<'s>) -> Result<Self, ArgumentError> {
306
498
        args.next()
307
498
            .map(Self::from_str)
308
498
            .unwrap_or(Ok(Self::new()))
309
498
            .map_err(|_| ArgumentError::Invalid)
310
498
    }
311
}
312

            
313
/// A kind of policy rule: either accepts or rejects addresses
314
/// matching a pattern.
315
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, derive_more::FromStr)]
316
#[display(rename_all = "lowercase")]
317
#[from_str(rename_all = "lowercase")]
318
#[allow(clippy::exhaustive_enums)]
319
pub enum RuleKind {
320
    /// A rule that accepts matching address:port combinations.
321
    Accept,
322
    /// A rule that rejects matching address:port combinations.
323
    Reject,
324
}
325

            
326
// Empty impl: `RuleKind` takes `NormalItemArgument` as-is, overriding nothing.
// NOTE(review): presumably this is what lets a `RuleKind` be parsed as an
// ordinary netdoc item argument — confirm against the trait's definition.
impl NormalItemArgument for RuleKind {}
327

            
328
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::Result;
    use crate::parse2::{self, ParseInput};

    // Parsing of single ports and hyphenated ranges, plus malformed inputs.
    #[test]
    fn parse_portrange() -> Result<()> {
        assert_eq!(
            "1-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!(
            "01-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!("1-65535".parse::<PortRange>()?, PortRange::new_all());
        assert_eq!(
            "10-30".parse::<PortRange>()?,
            PortRange::new(10, 30).unwrap()
        );
        assert_eq!(
            "9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );
        assert_eq!(
            "9001-9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );

        assert!("hello".parse::<PortRange>().is_err());
        assert!("0".parse::<PortRange>().is_err());
        assert!("65536".parse::<PortRange>().is_err());
        assert!("65537".parse::<PortRange>().is_err());
        assert!("1-2-3".parse::<PortRange>().is_err());
        assert!("10-5".parse::<PortRange>().is_err());
        assert!("1-".parse::<PortRange>().is_err());
        assert!("-2".parse::<PortRange>().is_err());
        assert!("-".parse::<PortRange>().is_err());
        assert!("*".parse::<PortRange>().is_err());
        Ok(())
    }

    // `is_all`, `contains`, and the binary-search comparison helper.
    #[test]
    fn pr_manip() {
        assert!(PortRange::new_all().is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());

        assert!(PortRange::new_all().contains(1));
        assert!(PortRange::new_all().contains(65535));
        assert!(PortRange::new_all().contains(7777));

        assert!(PortRange::new(20, 30).unwrap().contains(20));
        assert!(PortRange::new(20, 30).unwrap().contains(25));
        assert!(PortRange::new(20, 30).unwrap().contains(30));
        assert!(!PortRange::new(20, 30).unwrap().contains(19));
        assert!(!PortRange::new(20, 30).unwrap().contains(31));

        use std::cmp::Ordering::*;
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(7), Greater);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(20), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(25), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(30), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(100), Less);
    }

    // `Display` output for ranges and singletons.
    #[test]
    fn pr_fmt() {
        fn chk(a: u16, b: u16, s: &str) {
            let pr = PortRange::new(a, b).unwrap();
            assert_eq!(format!("{}", pr), s);
        }

        chk(1, 65535, "1-65535");
        chk(10, 20, "10-20");
        chk(20, 20, "20");
    }

    // Parsing, membership, inversion, and netdoc-argument parsing of `PortRanges`.
    #[test]
    fn port_ranges() {
        const INPUT: &str = "22,80,443,8000-9000,9002";
        let ranges = PortRanges::from_str(INPUT).unwrap();
        assert_eq!(
            ranges.0,
            [
                PortRange::new(22, 22).unwrap(),
                PortRange::new(80, 80).unwrap(),
                PortRange::new(443, 443).unwrap(),
                PortRange::new(8000, 9000).unwrap(),
                PortRange::new(9002, 9002).unwrap(),
            ]
        );
        assert!(ranges.contains(22));
        assert!(ranges.contains(80));
        assert!(ranges.contains(443));
        assert!(ranges.contains(8000));
        assert!(ranges.contains(8500));
        assert!(ranges.contains(9000));
        assert!(!ranges.contains(9001));
        assert!(ranges.contains(9002));

        let mut ranges_inverse = ranges.clone();
        ranges_inverse.invert();
        assert_eq!(
            ranges_inverse.0,
            [
                PortRange::new(1, 21).unwrap(),
                PortRange::new(23, 79).unwrap(),
                PortRange::new(81, 442).unwrap(),
                PortRange::new(444, 7999).unwrap(),
                PortRange::new(9001, 9001).unwrap(),
                PortRange::new(9003, 65535).unwrap(),
            ]
        );

        #[derive(derive_deftly::Deftly)]
        #[derive_deftly(NetdocParseable)]
        struct Dummy {
            #[deftly(netdoc(single_arg))]
            dummy: PortRanges,
        }
        let ranges2 =
            parse2::parse_netdoc::<Dummy>(&ParseInput::new(&format!("dummy {INPUT}\n"), ""))
                .unwrap();
        assert_eq!(ranges, ranges2.dummy);
    }
}