1
//! A crate for performing GeoIP lookups using the Tor GeoIP database.
2

            
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48

            
49
// TODO #1645 (either remove this, or decide to have it everywhere)
50
#![cfg_attr(not(all(feature = "full")), allow(unused))]
51

            
52
pub use crate::err::Error;
53
use rangemap::RangeInclusiveMap;
54
use std::fmt::{Debug, Display, Formatter};
55
use std::net::{IpAddr, Ipv6Addr};
56
use std::num::{NonZeroU8, NonZeroU32, TryFromIntError};
57
use std::str::FromStr;
58
use std::sync::{Arc, OnceLock};
59

            
60
mod err;
61

            
62
/// The IPv4 geoip database that was current when this crate was compiled,
/// embedded into the binary as text.
///
/// FIXME(eta): This does use a few megabytes of binary size, which is less than ideal.
///             It would be better to parse it at compile time or something.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V4: &str = include_str!("../data/geoip");

/// The IPv6 geoip database that was current when this crate was compiled,
/// embedded into the binary as text.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V6: &str = include_str!("../data/geoip6");

/// Lazily-parsed form of the two embedded databases, shared by every caller
/// of `GeoipDb::new_embedded`.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_PARSED: OnceLock<Arc<GeoipDb>> = OnceLock::new();
76

            
77
/// A two-letter country code.
///
/// More precisely, this holds a purported "ISO 3166-1 alpha-2" code, such as
/// "IT" (Italy) or "UY" (Uruguay).
///
/// The "country unknown" sentinel `??` cannot be represented by this type; use
/// [`OptionCc`] when you need it. Apart from excluding `??`, nothing checks that
/// the code names a real country: the only requirement is two printing ASCII
/// characters.
///
/// The geoip databases shipped with Arti contain only real countries. The
/// "anonymous proxy" pseudo-countries `A1` through `An` are deliberately left
/// out, because including them would put nearly every Tor relay into one of
/// them.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CountryCode {
    /// The code itself: two printable ASCII bytes, stored uppercase.
    ///
    /// `??` never appears here, because it is not a country; `OptionCc` is
    /// the type for representing it.
    ///
    /// The bytes are `NonZeroU8` so that zero is available as a niche, which
    /// keeps `Option<CountryCode>` at 2 bytes. That helps with alignment and
    /// storage.
    inner: [NonZeroU8; 2],
}
102

            
103
impl CountryCode {
104
    /// Make a new `CountryCode`.
105
14656486
    fn new(cc_orig: &str) -> Result<Self, Error> {
106
        /// Try to convert an array of 2 bytes into an array of 2 nonzero bytes.
107
        #[inline]
108
14626318
        fn try_cvt_to_nz(inp: [u8; 2]) -> Result<[NonZeroU8; 2], TryFromIntError> {
109
            // I have confirmed that the asm here is reasonably efficient.
110
14626318
            Ok([inp[0].try_into()?, inp[1].try_into()?])
111
14626318
        }
112

            
113
14656486
        let cc = cc_orig.to_ascii_uppercase();
114

            
115
14656486
        let cc: [u8; 2] = cc
116
14656486
            .as_bytes()
117
14656486
            .try_into()
118
14656489
            .map_err(|_| Error::BadCountryCode(cc))?;
119

            
120
29899227
        if !cc.iter().all(|b| b.is_ascii() && !b.is_ascii_control()) {
121
6
            return Err(Error::BadCountryCode(cc_orig.to_owned()));
122
14656474
        }
123

            
124
14656474
        if &cc == b"??" {
125
30156
            return Err(Error::NowhereNotSupported);
126
14626318
        }
127

            
128
        Ok(Self {
129
14626318
            inner: try_cvt_to_nz(cc).map_err(|_| Error::BadCountryCode(cc_orig.to_owned()))?,
130
        })
131
14656486
    }
132

            
133
    /// Get the actual country code.
134
    ///
135
    /// This just calls `.as_ref()`.
136
    pub fn get(&self) -> &str {
137
        self.as_ref()
138
    }
139
}
140

            
141
impl Display for CountryCode {
142
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
143
        write!(f, "{}", self.as_ref())
144
    }
145
}
146

            
147
impl Debug for CountryCode {
148
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149
        write!(f, "CountryCode(\"{}\")", self.as_ref())
150
    }
151
}
152

            
153
impl AsRef<str> for CountryCode {
154
104
    fn as_ref(&self) -> &str {
155
        /// Convert a reference to an array of 2 nonzero bytes to a reference to
156
        /// an array of 2 bytes.
157
        #[inline]
158
104
        fn cvt_ref(inp: &[NonZeroU8; 2]) -> &[u8; 2] {
159
            // SAFETY: Every NonZeroU8 has a layout and bit validity that is
160
            // also a valid u8.  The layout of arrays is also guaranteed.
161
            //
162
            // (We don't use try_into here because we need to return a str that
163
            // points to a reference to self.)
164
104
            let ptr = inp.as_ptr() as *const u8;
165
104
            let slice = unsafe { std::slice::from_raw_parts(ptr, inp.len()) };
166
104
            slice
167
104
                .try_into()
168
104
                .expect("the resulting slice should have the correct length!")
169
104
        }
170

            
171
        // This shouldn't ever panic, since we shouldn't feed non-utf8 country
172
        // codes in.
173
        //
174
        // In theory we could use from_utf8_unchecked, but that's probably not
175
        // needed.
176
104
        std::str::from_utf8(cvt_ref(&self.inner)).expect("invalid country code in CountryCode")
177
104
    }
178
}
179

            
180
impl FromStr for CountryCode {
181
    type Err = Error;
182

            
183
32
    fn from_str(s: &str) -> Result<Self, Self::Err> {
184
32
        CountryCode::new(s)
185
32
    }
186
}
187

            
188
/// Wrapper for an `Option<`[`CountryCode`]`>` that encodes `None` as `??`.
189
///
190
/// Used so that we can implement foreign traits.
191
#[derive(
192
    Copy, Clone, Debug, Eq, PartialEq, derive_more::Into, derive_more::From, derive_more::AsRef,
193
)]
194
#[allow(clippy::exhaustive_structs)]
195
pub struct OptionCc(pub Option<CountryCode>);
196

            
197
impl FromStr for OptionCc {
198
    type Err = Error;
199

            
200
14656454
    fn from_str(s: &str) -> Result<Self, Self::Err> {
201
14656454
        match CountryCode::new(s) {
202
30154
            Err(Error::NowhereNotSupported) => Ok(None.into()),
203
            Err(e) => Err(e),
204
14626300
            Ok(cc) => Ok(Some(cc).into()),
205
        }
206
14656454
    }
207
}
208

            
209
impl Display for OptionCc {
210
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
211
        match self.0 {
212
            Some(cc) => write!(f, "{}", cc),
213
            None => write!(f, "??"),
214
        }
215
    }
216
}
217

            
218
/// A country code / ASN definition.
219
///
220
/// Type lifted from `geoip-db-tool` in the C-tor source.
221
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
222
struct NetDefn {
223
    /// The country code.
224
    ///
225
    /// We translate the value "??" into None.
226
    cc: Option<CountryCode>,
227
    /// The ASN, if we have one. We translate the value "0" into None.
228
    asn: Option<NonZeroU32>,
229
}
230

            
231
impl NetDefn {
232
    /// Make a new `NetDefn`.
233
14656450
    fn new(cc: &str, asn: Option<u32>) -> Result<Self, Error> {
234
14656450
        let asn = NonZeroU32::new(asn.unwrap_or(0));
235
14656450
        let cc = cc.parse::<OptionCc>()?.into();
236

            
237
14656450
        Ok(Self { cc, asn })
238
14656450
    }
239

            
240
    /// Return the country code.
241
250
    fn country_code(&self) -> Option<&CountryCode> {
242
250
        self.cc.as_ref()
243
250
    }
244

            
245
    /// Return the ASN, if there is one.
246
    fn asn(&self) -> Option<u32> {
247
        self.asn.as_ref().map(|x| x.get())
248
    }
249
}
250

            
251
/// A database of IP addresses to country codes.
252
#[derive(Clone, Eq, PartialEq, Debug)]
253
pub struct GeoipDb {
254
    /// The IPv4 subset of the database, with v4 addresses stored as 32-bit integers.
255
    map_v4: RangeInclusiveMap<u32, NetDefn>,
256
    /// The IPv6 subset of the database, with v6 addresses stored as 128-bit integers.
257
    map_v6: RangeInclusiveMap<u128, NetDefn>,
258
}
259

            
260
impl GeoipDb {
261
    /// Make a new `GeoipDb` using a compiled-in copy of the GeoIP database.
262
    ///
263
    /// The returned instance of the database is shared with `Arc` across all invocations of this
264
    /// function in the same program.
265
    #[cfg(feature = "embedded-db")]
266
146
    pub fn new_embedded() -> Arc<Self> {
267
148
        Arc::clone(EMBEDDED_DB_PARSED.get_or_init(|| {
268
50
            Arc::new(
269
                // It's reasonable to assume the one we embedded is fine -- we'll test it in CI, etc.
270
50
                Self::new_from_legacy_format(EMBEDDED_DB_V4, EMBEDDED_DB_V6)
271
50
                    .expect("failed to parse embedded geoip database"),
272
            )
273
50
        }))
274
146
    }
275

            
276
    /// Make a new `GeoipDb` using provided copies of the v4 and v6 database, in Tor legacy format.
277
100
    pub fn new_from_legacy_format(db_v4: &str, db_v6: &str) -> Result<Self, Error> {
278
100
        let mut ret = GeoipDb {
279
100
            map_v4: Default::default(),
280
100
            map_v6: Default::default(),
281
100
        };
282

            
283
8416306
        for line in db_v4.lines() {
284
8416306
            if line.starts_with('#') {
285
850
                continue;
286
8415456
            }
287
8415456
            let line = line.trim();
288
8415456
            if line.is_empty() {
289
4
                continue;
290
8415452
            }
291
8415452
            let mut split = line.split(',');
292
8415452
            let from = split
293
8415452
                .next()
294
8415452
                .ok_or(Error::BadFormat("empty line somehow?"))?
295
8415452
                .parse::<u32>()?;
296
8415452
            let to = split
297
8415452
                .next()
298
8415452
                .ok_or(Error::BadFormat("line with insufficient commas"))?
299
8415452
                .parse::<u32>()?;
300
8415452
            let cc = split
301
8415452
                .next()
302
8415452
                .ok_or(Error::BadFormat("line with insufficient commas"))?;
303
8415452
            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
304

            
305
8415452
            let defn = NetDefn::new(cc, asn)?;
306

            
307
8415452
            ret.map_v4.insert(from..=to, defn);
308
        }
309

            
310
        // This is slightly copypasta, but probably less readable to merge into one thing.
311
6241948
        for line in db_v6.lines() {
312
6241948
            if line.starts_with('#') {
313
850
                continue;
314
6241098
            }
315
6241098
            let line = line.trim();
316
6241098
            if line.is_empty() {
317
100
                continue;
318
6240998
            }
319
6240998
            let mut split = line.split(',');
320
6240998
            let from = split
321
6240998
                .next()
322
6240998
                .ok_or(Error::BadFormat("empty line somehow?"))?
323
6240998
                .parse::<Ipv6Addr>()?;
324
6240998
            let to = split
325
6240998
                .next()
326
6240998
                .ok_or(Error::BadFormat("line with insufficient commas"))?
327
6240998
                .parse::<Ipv6Addr>()?;
328
6240998
            let cc = split
329
6240998
                .next()
330
6240998
                .ok_or(Error::BadFormat("line with insufficient commas"))?;
331
6240998
            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
332

            
333
6240998
            let defn = NetDefn::new(cc, asn)?;
334

            
335
6240998
            ret.map_v6.insert(from.into()..=to.into(), defn);
336
        }
337

            
338
100
        Ok(ret)
339
100
    }
340

            
341
    /// Get the `NetDefn` for an IP address.
342
3134
    fn lookup_defn(&self, ip: IpAddr) -> Option<&NetDefn> {
343
3134
        match ip {
344
2598
            IpAddr::V4(v4) => self.map_v4.get(&v4.into()),
345
536
            IpAddr::V6(v6) => self.map_v6.get(&v6.into()),
346
        }
347
3134
    }
348

            
349
    /// Get a 2-letter country code for the given IP address, if this data is available.
350
3134
    pub fn lookup_country_code(&self, ip: IpAddr) -> Option<&CountryCode> {
351
3144
        self.lookup_defn(ip).and_then(|x| x.country_code())
352
3134
    }
353

            
354
    /// Determine a 2-letter country code for a host with multiple IP addresses.
355
    ///
356
    /// This looks up all of the IP addresses with `lookup_country_code`. If the lookups
357
    /// return different countries, `None` is returned. IP addresses that fail to resolve
358
    /// into a country are ignored if some of the other addresses do resolve successfully.
359
738
    pub fn lookup_country_code_multi<I>(&self, ips: I) -> Option<&CountryCode>
360
738
    where
361
738
        I: IntoIterator<Item = IpAddr>,
362
    {
363
738
        let mut ret = None;
364

            
365
1766
        for ip in ips {
366
1030
            if let Some(cc) = self.lookup_country_code(ip) {
367
                // If we already have a return value and it's different, then return None;
368
                // a server can't be in two different countries.
369
10
                if ret.is_some() && ret != Some(cc) {
370
2
                    return None;
371
8
                }
372

            
373
8
                ret = Some(cc);
374
1020
            }
375
        }
376

            
377
736
        ret
378
738
    }
379

            
380
    /// Return the ASN the IP address is in, if this data is available.
381
    pub fn lookup_asn(&self, ip: IpAddr) -> Option<u32> {
382
        self.lookup_defn(ip)?.asn()
383
    }
384
}
385

            
386
/// A (representation of a) host on the network which may have a known country code.
387
pub trait HasCountryCode {
388
    /// Return the country code in which this server is most likely located.
389
    ///
390
    /// This is usually implemented by simple GeoIP lookup on the addresses provided by `HasAddrs`.
391
    /// It follows that the server might not actually be in the returned country, but this is a
392
    /// halfway decent estimate for what other servers might guess the server's location to be
393
    /// (and thus useful for e.g. getting around simple geo-blocks, or having webpages return
394
    /// the correct localised versions).
395
    ///
396
    /// Returning `None` signifies that no country code information is available. (Conflicting
397
    /// GeoIP lookup results might also cause `None` to be returned.)
398
    fn country_code(&self) -> Option<CountryCode>;
399
}
400

            
401
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::net::Ipv4Addr;

    // NOTE(eta): this test takes a whole 1.6 seconds in *non-release* mode
    #[test]
    #[cfg(feature = "embedded-db")]
    fn embedded_db() {
        let db = GeoipDb::new_embedded();

        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(8, 8, 8, 8).into())
                .map(|x| x.as_ref()),
            Some("US")
        );

        assert_eq!(
            db.lookup_country_code("2001:4860:4860::8888".parse().unwrap())
                .map(|x| x.as_ref()),
            Some("US")
        );
    }

    #[test]
    fn basic_lookups() {
        let src_v4 = r#"
        16909056,16909311,GB
        "#;
        let src_v6 = r#"
        fe80::,fe81::,US
        dead:beef::,dead:ffff::,??
        "#;
        let db = GeoipDb::new_from_legacy_format(src_v4, src_v6).unwrap();

        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 2, 3, 4).into())
                .map(|x| x.as_ref()),
            Some("GB")
        );

        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 1, 1, 1).into()),
            None
        );

        assert_eq!(
            db.lookup_country_code("fe80::dead:beef".parse().unwrap())
                .map(|x| x.as_ref()),
            Some("US")
        );

        assert_eq!(
            db.lookup_country_code("fe81::dead:beef".parse().unwrap()),
            None
        );
        assert_eq!(
            db.lookup_country_code("dead:beef::1".parse().unwrap()),
            None
        );
    }

    #[test]
    fn asn_and_multi_lookups() {
        // 1.2.3.0-1.2.3.255 GB (no ASN), 1.2.4.0-1.2.4.255 FR (ASN 1234),
        // 1.2.5.0-1.2.5.255 DE (ASN 0, i.e. unknown).
        let src_v4 = r#"
        16909056,16909311,GB
        16909312,16909567,FR,1234
        16909568,16909823,DE,0
        "#;
        let src_v6 = r#"
        fe80::,fe81::,US,5678
        "#;
        let db = GeoipDb::new_from_legacy_format(src_v4, src_v6).unwrap();
        let ip = |s: &str| -> IpAddr { s.parse().unwrap() };

        // ASN lookups: present, absent, explicitly zero, v6, and unknown range.
        assert_eq!(db.lookup_asn(ip("1.2.4.1")), Some(1234));
        assert_eq!(db.lookup_asn(ip("1.2.3.4")), None);
        assert_eq!(db.lookup_asn(ip("1.2.5.5")), None);
        assert_eq!(
            db.lookup_country_code(ip("1.2.5.5")).map(|x| x.as_ref()),
            Some("DE")
        );
        assert_eq!(db.lookup_asn(ip("fe80::1")), Some(5678));
        assert_eq!(db.lookup_asn(ip("1.1.1.1")), None);

        // Unresolvable addresses are ignored when others resolve.
        assert_eq!(
            db.lookup_country_code_multi([ip("1.2.3.4"), ip("1.1.1.1")])
                .map(|x| x.as_ref()),
            Some("GB")
        );
        // Agreeing addresses give that country.
        assert_eq!(
            db.lookup_country_code_multi([ip("1.2.3.4"), ip("1.2.3.200")])
                .map(|x| x.as_ref()),
            Some("GB")
        );
        // Conflicting countries give no answer.
        assert_eq!(
            db.lookup_country_code_multi([ip("1.2.3.4"), ip("fe80::1")]),
            None
        );
        assert_eq!(db.lookup_country_code_multi([ip("1.1.1.1")]), None);
        assert_eq!(db.lookup_country_code_multi(Vec::<IpAddr>::new()), None);
    }

    #[test]
    fn cc_parse() -> Result<(), Error> {
        // real countries.
        assert_eq!(CountryCode::from_str("us")?, CountryCode::from_str("US")?);
        assert_eq!(CountryCode::from_str("uy")?, CountryCode::from_str("UY")?);

        // not real as of this writing, but still representable.
        assert_eq!(CountryCode::from_str("A7")?, CountryCode::from_str("a7")?);
        assert_eq!(CountryCode::from_str("xz")?, CountryCode::from_str("XZ")?);

        // Can't convert to two bytes.
        assert!(matches!(
            CountryCode::from_str("z"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("🐻‍❄️"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("Sheboygan"),
            Err(Error::BadCountryCode(_))
        ));

        // Can convert to two bytes, but still not printable ascii
        assert!(matches!(
            CountryCode::from_str("\r\n"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("\0\0"),
            Err(Error::BadCountryCode(_))
        ));
        assert!(matches!(
            CountryCode::from_str("¡"),
            Err(Error::BadCountryCode(_))
        ));

        // Not a country.
        assert!(matches!(
            CountryCode::from_str("??"),
            Err(Error::NowhereNotSupported)
        ));

        Ok(())
    }

    #[test]
    fn cc_formatting() -> Result<(), Error> {
        let it = CountryCode::from_str("it")?;
        assert_eq!(it.get(), "IT");
        assert_eq!(it.to_string(), "IT");
        assert_eq!(format!("{:?}", it), "CountryCode(\"IT\")");
        assert_eq!(OptionCc(Some(it)).to_string(), "IT");
        assert_eq!(OptionCc(None).to_string(), "??");
        Ok(())
    }

    #[test]
    fn opt_cc_parse() -> Result<(), Error> {
        assert_eq!(
            CountryCode::from_str("br")?,
            OptionCc::from_str("BR")?.0.unwrap()
        );
        assert!(OptionCc::from_str("??")?.0.is_none());
        assert!(matches!(
            OptionCc::from_str("z"),
            Err(Error::BadCountryCode(_))
        ));

        Ok(())
    }
}