1
//! Exit policies: match patterns of addresses and/or ports.
2
//!
3
//! Every Tor relay has a set of address:port combinations that it
//! actually allows connections to.  The set, abstractly, is the
//! relay's "exit policy".
6
//!
7
//! Address policies can be transmitted in two forms.  One is a "full
//! policy", that includes a list of rules that are applied in order
//! to represent addresses and ports.  We represent this with the
//! [`AddrPolicy`] type.
11
//!
12
//! In microdescriptors, and for IPv6 policies, policies are just
//! given a list of ports for which _most_ addresses are permitted.
//! We represent this kind of policy with the [`PortPolicy`] type.
15
//!
16
//! TODO: This module probably belongs in a crate of its own, with
17
//! possibly only the parsing code in this crate.
18

            
19
mod addrpolicy;
20
mod portpolicy;
21

            
22
use std::fmt::Display;
23
use std::str::FromStr;
24
use thiserror::Error;
25

            
26
pub use addrpolicy::{AddrPolicy, AddrPortPattern, RuleKind};
27
pub use portpolicy::PortPolicy;
28

            
29
/// Error from an unparsable or invalid policy.
30
#[derive(Debug, Error, Clone, PartialEq, Eq)]
31
#[non_exhaustive]
32
pub enum PolicyError {
33
    /// A port was not a number in the range 1..65535
34
    #[error("Invalid port")]
35
    InvalidPort,
36
    /// A port range had its starting-point higher than its ending point.
37
    #[error("Invalid port range")]
38
    InvalidRange,
39
    /// An address could not be interpreted.
40
    #[error("Invalid address")]
41
    InvalidAddress,
42
    /// Tried to use a bitmask with the address "*".
43
    #[error("mask with star")]
44
    MaskWithStar,
45
    /// A bit mask was out of range.
46
    #[error("invalid mask")]
47
    InvalidMask,
48
    /// A policy could not be parsed for some other reason.
49
    #[error("Invalid policy")]
50
    InvalidPolicy,
51
}
52

            
53
/// An inclusive run of consecutively numbered TCP or UDP ports.
///
/// The constructors on this type only produce ranges where `lo` is
/// nonzero and `lo <= hi`.
///
/// # Example
/// ```
/// use tor_netdoc::types::policy::PortRange;
///
/// let r: PortRange = "22-8000".parse().unwrap();
/// assert!(r.contains(128));
/// assert!(r.contains(22));
/// assert!(r.contains(8000));
///
/// assert!(! r.contains(21));
/// assert!(! r.contains(8001));
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// Lowest port included in the range.
    pub lo: u16,
    /// Highest port included in the range.
    pub hi: u16,
}

impl PortRange {
    /// Build the range `lo..=hi` without returning an `Option`.
    ///
    /// # Panics
    ///
    /// Panics if `lo` is zero or if `lo` is greater than `hi`.
    fn new_unchecked(lo: u16, hi: u16) -> Self {
        assert!(lo != 0);
        assert!(lo <= hi);
        Self { lo, hi }
    }

    /// Return the range that covers every port, `1..=65535`.
    pub fn new_all() -> Self {
        Self::new_unchecked(1, u16::MAX)
    }

    /// Build the range of all ports from `lo` to `hi`, both inclusive.
    ///
    /// Returns `None` when `lo` is zero or when `lo` is greater than `hi`
    /// (which is also the case whenever `hi` is zero and `lo` is not).
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
        if lo == 0 || lo > hi {
            return None;
        }
        Some(Self { lo, hi })
    }

    /// Report whether `port` lies within this range.
    pub fn contains(&self, port: u16) -> bool {
        (self.lo..=self.hi).contains(&port)
    }

    /// Report whether this range covers every port, i.e. is `1..=65535`.
    pub fn is_all(&self) -> bool {
        (self.lo, self.hi) == (1, u16::MAX)
    }

    /// Binary-search helper: order this range relative to a single port.
    ///
    /// The result is `Equal` for any port inside the range, `Greater`
    /// when the port comes before the start of the range, and `Less`
    /// when the port comes after the end of the range.
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match port {
            p if p < self.lo => Ordering::Greater,
            p if p > self.hi => Ordering::Less,
            _ => Ordering::Equal,
        }
    }
}
125

            
126
/// A PortRange is displayed as a number if it contains a single port,
127
/// and as a start point and end point separated by a dash if it contains
128
/// more than one port.
129
impl Display for PortRange {
130
44
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131
44
        if self.lo == self.hi {
132
18
            write!(f, "{}", self.lo)
133
        } else {
134
26
            write!(f, "{}-{}", self.lo, self.hi)
135
        }
136
44
    }
137
}
138

            
139
impl FromStr for PortRange {
140
    type Err = PolicyError;
141
570485
    fn from_str(s: &str) -> Result<Self, PolicyError> {
142
570485
        let idx = s.find('-');
143
        // Find "lo" and "hi".
144
570485
        let (lo, hi) = if let Some(pos) = idx {
145
            // This is a range; parse each part.
146
            (
147
350742
                s[..pos]
148
350742
                    .parse::<u16>()
149
350742
                    .map_err(|_| PolicyError::InvalidPort)?,
150
350736
                s[pos + 1..]
151
350736
                    .parse::<u16>()
152
350736
                    .map_err(|_| PolicyError::InvalidPort)?,
153
            )
154
        } else {
155
            // There was no hyphen, so try to parse this range as a singleton.
156
219743
            let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
157
219731
            (v, v)
158
        };
159
570463
        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
160
570485
    }
161
}
162

            
163
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::Result;

    /// Well-formed inputs parse to the expected range; malformed ones fail.
    #[test]
    fn parse_portrange() -> Result<()> {
        let accepted = [
            ("1-100", PortRange::new(1, 100).unwrap()),
            ("01-100", PortRange::new(1, 100).unwrap()),
            ("1-65535", PortRange::new_all()),
            ("10-30", PortRange::new(10, 30).unwrap()),
            ("9001", PortRange::new(9001, 9001).unwrap()),
            ("9001-9001", PortRange::new(9001, 9001).unwrap()),
        ];
        for (text, expected) in &accepted {
            assert_eq!(text.parse::<PortRange>()?, *expected, "parsing {:?}", text);
        }

        let rejected = [
            "hello", "0", "65536", "65537", "1-2-3", "10-5", "1-", "-2", "-", "*",
        ];
        for text in &rejected {
            assert!(
                text.parse::<PortRange>().is_err(),
                "{:?} should not parse",
                text
            );
        }
        Ok(())
    }

    /// Exercise `is_all`, `contains` and `compare_to_port`.
    #[test]
    fn pr_manip() {
        use std::cmp::Ordering::*;

        let everything = PortRange::new_all();
        assert!(everything.is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());

        for &port in &[1, 65535, 7777] {
            assert!(everything.contains(port));
        }

        let range = PortRange::new(20, 30).unwrap();
        for &port in &[20, 25, 30] {
            assert!(range.contains(port));
        }
        for &port in &[19, 31] {
            assert!(!range.contains(port));
        }

        let orderings = [
            (7, Greater),
            (20, Equal),
            (25, Equal),
            (30, Equal),
            (100, Less),
        ];
        for &(port, expected) in &orderings {
            assert_eq!(range.compare_to_port(port), expected);
        }
    }

    /// `Display` prints one number for a single port, `lo-hi` otherwise.
    #[test]
    fn pr_fmt() {
        let cases = [(1, 65535, "1-65535"), (10, 20, "10-20"), (20, 20, "20")];
        for &(lo, hi, expected) in &cases {
            let range = PortRange::new(lo, hi).unwrap();
            assert_eq!(range.to_string(), expected);
        }
    }
}