1
//! Implement [`RelayIdSet`], a set of `RelayId`s.
2

            
3
use std::collections::HashSet;
4

            
5
use serde::de::Visitor;
6
use tor_llcrypto::pk::{ed25519::Ed25519Identity, rsa::RsaIdentity};
7

            
8
use crate::{RelayId, RelayIdRef};
9

            
10
/// A set of relay identities, backed by `HashSet`.
11
///
12
/// # Note
13
///
14
/// I'd rather use `HashSet` entirely, but that doesn't let us index by
15
/// RelayIdRef.
16
#[derive(Clone, Debug, Default, Eq, PartialEq)]
17
pub struct RelayIdSet {
18
    /// The Ed25519 members of this set.
19
    ed25519: HashSet<Ed25519Identity>,
20
    /// The RSA members of this set.
21
    rsa: HashSet<RsaIdentity>,
22
}
23

            
24
impl RelayIdSet {
25
    /// Construct a new empty RelayIdSet.
26
3534283
    pub fn new() -> Self {
27
3534283
        Self::default()
28
3534283
    }
29

            
30
    /// Insert `key` into this set.  
31
    ///
32
    /// Return true if it was not already there.
33
286126
    pub fn insert<T: Into<RelayId>>(&mut self, key: T) -> bool {
34
286126
        let key: RelayId = key.into();
35
286126
        match key {
36
142940
            RelayId::Ed25519(key) => self.ed25519.insert(key),
37
143186
            RelayId::Rsa(key) => self.rsa.insert(key),
38
        }
39
286126
    }
40

            
41
    /// Remove `key` from the set.
42
    ///
43
    /// Return true if `key` was present.
44
8
    pub fn remove<'a, T: Into<RelayIdRef<'a>>>(&mut self, key: T) -> bool {
45
8
        let key: RelayIdRef<'a> = key.into();
46
8
        match key {
47
4
            RelayIdRef::Ed25519(key) => self.ed25519.remove(key),
48
4
            RelayIdRef::Rsa(key) => self.rsa.remove(key),
49
        }
50
8
    }
51

            
52
    /// Return true if `key` is a member of this set.
53
66338872
    pub fn contains<'a, T: Into<RelayIdRef<'a>>>(&self, key: T) -> bool {
54
66338872
        let key: RelayIdRef<'a> = key.into();
55
66338872
        match key {
56
33230942
            RelayIdRef::Ed25519(key) => self.ed25519.contains(key),
57
33107930
            RelayIdRef::Rsa(key) => self.rsa.contains(key),
58
        }
59
66338872
    }
60

            
61
    /// Return an iterator over the members of this set.
62
    ///
63
    /// The ordering of the iterator is undefined; do not rely on it.
64
1741270
    pub fn iter(&self) -> impl Iterator<Item = RelayIdRef<'_>> {
65
1741270
        self.ed25519
66
1741270
            .iter()
67
1741474
            .map(|id| id.into())
68
1741478
            .chain(self.rsa.iter().map(|id| id.into()))
69
1741270
    }
70

            
71
    /// Return the number of keys in this set.
72
102
    pub fn len(&self) -> usize {
73
102
        self.ed25519.len() + self.rsa.len()
74
102
    }
75

            
76
    /// Return true if there are not keys in this set.
77
241
    pub fn is_empty(&self) -> bool {
78
241
        self.ed25519.is_empty() && self.rsa.is_empty()
79
241
    }
80
}
81

            
82
impl<ID> Extend<ID> for RelayIdSet
where
    ID: Into<RelayId>,
{
    /// Add every identity yielded by `iter` to this set.
    ///
    /// Identities that are already present are left as they are.
    fn extend<T: IntoIterator<Item = ID>>(&mut self, iter: T) {
        iter.into_iter().for_each(|id| {
            // `insert` reports whether `id` was new; `Extend` has no use for that.
            let _ = self.insert(id);
        });
    }
}
89

            
90
impl FromIterator<RelayId> for RelayIdSet {
    /// Build a set holding every identity yielded by `iter`.
    ///
    /// Repeated identities are stored only once.
    fn from_iter<T: IntoIterator<Item = RelayId>>(iter: T) -> Self {
        let mut collected = Self::default();
        for id in iter {
            collected.insert(id);
        }
        collected
    }
}
97

            
98
impl serde::Serialize for RelayIdSet {
99
6
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
100
6
    where
101
6
        S: serde::Serializer,
102
    {
103
6
        serializer.collect_seq(self.iter())
104
6
    }
105
}
106

            
107
impl<'de> serde::Deserialize<'de> for RelayIdSet {
108
12
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109
12
    where
110
12
        D: serde::Deserializer<'de>,
111
    {
112
        /// A serde visitor to deserialize a sequence of RelayIds into a
113
        /// RelayIdSet.
114
        struct IdSetVisitor;
115
        impl<'de> Visitor<'de> for IdSetVisitor {
116
            type Value = RelayIdSet;
117

            
118
            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
119
                write!(f, "a list of relay identities")
120
            }
121

            
122
12
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
123
12
            where
124
12
                A: serde::de::SeqAccess<'de>,
125
            {
126
12
                let mut set = RelayIdSet::new();
127
20
                while let Some(key) = seq.next_element::<RelayId>()? {
128
8
                    set.insert(key);
129
8
                }
130
12
                Ok(set)
131
12
            }
132
        }
133
12
        deserializer.deserialize_seq(IdSetVisitor)
134
12
    }
135
}
136

            
137
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use hex_literal::hex;
    use serde_test::{Token, assert_tokens};

    /// Exercise insertion, lookup, removal, iteration, and the
    /// `FromIterator` / `Extend` impls on a few fixed identities.
    #[test]
    fn basic_usage() {
        #![allow(clippy::cognitive_complexity)]
        let rsa1 = RsaIdentity::from(hex!("42656c6f7665642c207768617420617265206e61"));
        let rsa2 = RsaIdentity::from(hex!("6d657320627574206169723f43686f6f73652074"));
        let rsa3 = RsaIdentity::from(hex!("686f752077686174657665722073756974732074"));

        let ed1 = Ed25519Identity::from(hex!(
            "6865206c696e653a43616c6c206d652053617070686f2c2063616c6c206d6520"
        ));
        let ed2 = Ed25519Identity::from(hex!(
            "43686c6f7269732c2043616c6c206d65204c616c6167652c206f7220446f7269"
        ));
        let ed3 = Ed25519Identity::from(hex!(
            "732c204f6e6c792c206f6e6c792c2063616c6c206d65207468696e652e000000"
        ));

        // A freshly created set has nothing in it.
        let mut set = RelayIdSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);

        // Add two RSA keys and one Ed25519 key; rsa3, ed2 and ed3 stay out.
        set.insert(rsa1);
        set.insert(rsa2);
        set.insert(ed1);

        assert!(!set.is_empty());
        assert_eq!(set.len(), 3);
        assert!(set.contains(&rsa1));
        assert!(set.contains(&rsa2));
        assert!(!set.contains(&rsa3));
        assert!(set.contains(&ed1));
        assert!(!set.contains(&ed2));
        assert!(!set.contains(&ed3));

        // Iteration yields each member exactly once.
        let members: HashSet<_> = set.iter().collect();
        assert_eq!(members.len(), set.len());
        for expected in [
            RelayIdRef::from(&rsa1),
            RelayIdRef::from(&rsa2),
            RelayIdRef::from(&ed1),
        ] {
            assert!(members.contains(&expected));
        }

        // `remove` reports whether the key had been present.
        assert!(!set.remove(&ed2));
        assert!(set.remove(&ed1));
        assert!(!set.remove(&rsa3));
        assert!(set.remove(&rsa1));
        assert!(!set.is_empty());
        assert_eq!(set.len(), 1);
        assert!(!set.contains(&ed1));
        assert!(!set.contains(&rsa1));
        assert!(set.contains(&rsa2));

        // Only rsa2 is left.
        let remaining: Vec<_> = set.iter().collect();
        assert_eq!(remaining, vec![RelayIdRef::from(&rsa2)]);

        // Rebuilding from owned copies, via `collect` and via `extend`,
        // gives equal sets.
        let rebuilt: RelayIdSet = set.iter().map(|id| id.to_owned()).collect();
        assert_eq!(set, rebuilt);

        let mut extended = RelayIdSet::new();
        extended.extend(set.iter().map(|id| id.to_owned()));
        assert_eq!(rebuilt, extended);
    }

    /// An empty set round-trips as an empty sequence.
    #[test]
    fn serde_empty() {
        let empty = RelayIdSet::new();
        let expected = [Token::Seq { len: Some(0) }, Token::SeqEnd];

        assert_tokens(&empty, &expected);
    }

    /// A set holding one RSA identity round-trips as a one-element sequence.
    #[test]
    fn serde_singleton_rsa() {
        let rsa = RsaIdentity::from(hex!("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
        let mut set = RelayIdSet::new();
        set.insert(rsa);

        let expected = [
            Token::Seq { len: Some(1) },
            Token::Str("$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
            Token::SeqEnd,
        ];
        assert_tokens(&set, &expected);
    }

    /// A set holding one Ed25519 identity round-trips as a one-element sequence.
    #[test]
    fn serde_singleton_ed25519() {
        let ed = Ed25519Identity::from(hex!(
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
        ));
        let mut set = RelayIdSet::new();
        set.insert(ed);

        let expected = [
            Token::Seq { len: Some(1) },
            Token::String("ed25519:u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7s"),
            Token::SeqEnd,
        ];
        assert_tokens(&set, &expected);
    }
}