1
//! An object mapper for looking up `rpc::Object`s by ID.
//!
//! This mapper stores strong or weak references, and uses a generational index
//! to keep track of names for them.
//!
//! TODO RPC: Add an object diagram here once the implementation settles down.
7

            
8
use std::sync::{Arc, Weak};
9

            
10
use rand::RngExt;
11

            
12
use slotmap_careful::{Key as _, KeyData, SlotMap};
13
use tor_rpcbase as rpc;
14

            
15
pub(crate) mod methods;
16

            
17
slotmap_careful::new_key_type! {
    /// A generational index naming one entry in an [`ObjMap`].
    ///
    /// It is converted to a client-facing `rpc::ObjectId` with
    /// [`GenIdx::encode`], and back again with [`GenIdx::try_decode`].
    pub(crate) struct GenIdx;
}
21

            
22
/// A weak or a strong reference to an RPC object.
23
//
24
// Note: This type does not pack very efficiently, due to Rust's current lack
25
// of alignment-based niche optimization.
26
// If this ever matters, we can either use two slotmaps, or we can implement
27
// some kind of kludgey hack on our own.
28
#[derive(Clone, derive_more::From)]
29
enum ObjectRef {
30
    /// A strong reference.
31
    Strong(Arc<dyn rpc::Object>),
32
    /// A weak reference reference.
33
    Weak(Weak<dyn rpc::Object>),
34
}
35

            
36
impl ObjectRef {
37
    /// Return this reference as an Arc, if it is present.
38
16
    fn get(&self) -> Option<Arc<dyn rpc::Object>> {
39
16
        match self {
40
10
            ObjectRef::Strong(s) => Some(Arc::clone(s)),
41
6
            ObjectRef::Weak(w) => w.upgrade(),
42
        }
43
16
    }
44
}
45

            
46
/// A mechanism to look up RPC `Objects` by their `ObjectId`.
47
#[derive(Default)]
48
pub(crate) struct ObjMap {
49
    /// Generationally indexed arena of strong object references.
50
    arena: SlotMap<GenIdx, ObjectRef>,
51
}
52

            
53
/// Encoding functions for GenIdx.
54
///
55
/// The encoding is deliberately nondeterministic: we want to avoid situations
56
/// where applications depend on the details of our ObjectIds, or hardcode the
57
/// ObjectIds they expect, or rely on the same  weak generational index getting
58
/// encoded the same way every time they see it.
59
///
60
/// The encoding is deliberately non-cryptographic: we do not want to imply
61
/// that this gives any security. It is just a mild deterrent to misuse.
62
///
63
/// If you find yourself wanting to reverse-engineer this code so that you can
64
/// analyze these object IDs, please contact the Arti developers instead and let
65
/// us give you a better way to do whatever you want.
66
impl GenIdx {
67
    /// The length of a byte-encoded (but not base-64 encoded) GenIdx.
68
    pub(crate) const BYTE_LEN: usize = 16;
69

            
70
    /// Encode `self` into an rpc::ObjectId that we can give to a client.
71
    pub(crate) fn encode(self) -> rpc::ObjectId {
72
        self.encode_with_rng(&mut rand::rng())
73
    }
74

            
75
    /// As `encode`, but take a Rng as an argument. For testing.
76
1040
    fn encode_with_rng<R: rand::Rng>(self, rng: &mut R) -> rpc::ObjectId {
77
        use base64ct::Encoding;
78
1040
        let bytes = self.to_bytes(rng);
79
1040
        rpc::ObjectId::from(base64ct::Base64UrlUnpadded::encode_string(&bytes[..]))
80
1040
    }
81

            
82
    /// As `encode_with_rng`, but return an array of bytes.
83
1046
    pub(crate) fn to_bytes<R: rand::Rng>(self, rng: &mut R) -> [u8; Self::BYTE_LEN] {
84
        use tor_bytes::Writer;
85
1046
        let ffi_idx = self.data().as_ffi();
86
1046
        let x = rng.random::<u64>();
87
1046
        let mut bytes = Vec::with_capacity(Self::BYTE_LEN);
88
1046
        bytes.write_u64(x);
89
1046
        bytes.write_u64(ffi_idx.wrapping_add(x));
90

            
91
1046
        bytes.try_into().expect("Length was wrong!")
92
1046
    }
93

            
94
    /// Attempt to decode `id` into a `GenIdx` than an ObjMap can use.
95
1040
    pub(crate) fn try_decode(id: &rpc::ObjectId) -> Result<Self, rpc::LookupError> {
96
        use base64ct::Encoding;
97

            
98
1040
        let bytes = base64ct::Base64UrlUnpadded::decode_vec(id.as_ref())
99
1040
            .map_err(|_| rpc::LookupError::NoObject(id.clone()))?;
100
1040
        Self::from_bytes(&bytes).ok_or_else(|| rpc::LookupError::NoObject(id.clone()))
101
1040
    }
102

            
103
    /// As `try_decode`, but take a slice of bytes.
104
1042
    pub(crate) fn from_bytes(bytes: &[u8]) -> Option<Self> {
105
        use tor_bytes::Reader;
106
1042
        let mut r = Reader::from_slice(bytes);
107
1042
        let x = r.take_u64().ok()?;
108
1042
        let ffi_idx = r.take_u64().ok()?;
109
1042
        r.should_be_exhausted().ok()?;
110

            
111
1042
        let ffi_idx = ffi_idx.wrapping_sub(x);
112
1042
        Some(GenIdx::from(KeyData::from_ffi(ffi_idx)))
113
1042
    }
114
}
115

            
116
impl ObjMap {
117
    /// Create a new empty ObjMap.
118
8
    pub(crate) fn new() -> Self {
119
8
        Self::default()
120
8
    }
121

            
122
    /// Unconditionally insert a strong entry for `value` in self, and return its index.
123
14
    pub(crate) fn insert_strong(&mut self, value: Arc<dyn rpc::Object>) -> GenIdx {
124
14
        self.arena.insert(ObjectRef::Strong(value))
125
14
    }
126

            
127
    /// Unconditionally insert a weak entry for `value` in self, and return its index.
128
10
    pub(crate) fn insert_weak(&mut self, value: &Arc<dyn rpc::Object>) -> GenIdx {
129
10
        self.arena.insert(ObjectRef::Weak(Arc::downgrade(value)))
130
10
    }
131

            
132
    /// Return the entry from this ObjMap for `idx`.
133
24
    pub(crate) fn lookup(&self, idx: GenIdx) -> Result<Arc<dyn rpc::Object>, LookupError> {
134
24
        self.arena
135
24
            .get(idx)
136
24
            .ok_or(LookupError::NoObject)?
137
16
            .get()
138
16
            .ok_or(LookupError::Expired)
139
24
    }
140

            
141
    /// Remove the entry at `idx`.
142
    ///
143
    /// Return true if anything was removed.
144
6
    pub(crate) fn remove(&mut self, idx: GenIdx) -> bool {
145
6
        self.arena.remove(idx).is_some()
146
6
    }
147

            
148
    /// Testing only: Assert that every invariant for this structure is met.
149
    #[cfg(test)]
150
16
    fn assert_okay(&self) {}
151
}
152

            
153
/// A failure from ObjMap::lookup.
154
///
155
/// (This type is immediately returned into rpc::LookupError before we return it.)
156
#[derive(Clone, Debug, thiserror::Error)]
157
pub(crate) enum LookupError {
158
    /// There was no object with the given ID.
159
    #[error("Object not found")]
160
    NoObject,
161

            
162
    /// The object was present, but it was a weak reference that expired.
163
    #[error("Object expired")]
164
    Expired,
165
}
166

            
167
impl LookupError {
168
    /// Convert this `LookupError` into an [`rpc::LookupError`]
169
    pub(crate) fn to_rpc_lookup_error(&self, id: rpc::ObjectId) -> rpc::LookupError {
170
        match self {
171
            LookupError::NoObject => rpc::LookupError::NoObject(id),
172
            LookupError::Expired => rpc::LookupError::Expired(id),
173
        }
174
    }
175
}
176

            
177
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use derive_deftly::Deftly;
    use tor_rpcbase::templates::*;

    /// A trivial RPC object used to populate the map in these tests.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(Object)]
    struct ExampleObject(#[allow(unused)] String);

    #[test]
    fn map_basics() {
        // Insert an object twice, make sure it gets two distinct IDs, and look it up.
        let obj1 = Arc::new(ExampleObject("abcdef".to_string()));
        let mut map = ObjMap::new();
        map.assert_okay();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_strong(obj1.clone());
        assert_ne!(id1, id2);
        let obj1: Arc<dyn rpc::Object> = obj1;
        let obj_out1 = map.lookup(id1).unwrap();
        let obj_out2 = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj1, &obj_out1));
        assert!(Arc::ptr_eq(&obj1, &obj_out2));
        map.assert_okay();

        assert!(map.remove(id1));
        assert!(map.lookup(id1).is_err());
        let obj_out2b = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj_out2, &obj_out2b));

        map.assert_okay();
    }

    #[test]
    fn strong_and_weak() {
        // Make sure that a strong object behaves like one, and so does a weak
        // object.
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);

        {
            let out1 = map.lookup(id1).unwrap();
            let out2 = map.lookup(id2).unwrap();
            assert!(Arc::ptr_eq(&obj1, &out1));
            assert!(Arc::ptr_eq(&obj2, &out2));
        }
        map.assert_okay();

        // Now drop every object we've got, and see what we can still find.
        drop(obj1);
        drop(obj2);
        {
            let out1 = map.lookup(id1);
            let out2 = map.lookup(id2);

            // This one was strong, so it is still there.
            assert!(out1.is_ok());

            // This one is weak so it went away, and is reported as expired
            // (rather than as never having existed).
            assert!(matches!(out2, Err(super::LookupError::Expired)));
        }
        map.assert_okay();
    }

    #[test]
    fn remove() {
        // Make sure that removing an object makes it go away.
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);
        map.assert_okay();

        assert!(map.remove(id1));
        map.assert_okay();
        assert!(matches!(
            map.lookup(id1),
            Err(super::LookupError::NoObject)
        ));
        assert!(map.lookup(id2).is_ok());
        // Removing an entry that is already gone reports that nothing happened.
        assert!(!map.remove(id1));

        assert!(map.remove(id2));
        map.assert_okay();
        assert!(map.lookup(id1).is_err());
        assert!(map.lookup(id2).is_err());
    }

    #[test]
    fn duplicates() {
        // Inserting an object that is already in the map must always yield a
        // fresh index, whether the new entry is weak or strong.
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);

        {
            // Compare each new weak entry against the existing entry for the
            // *same* object; comparing against a different object's ID would
            // be trivially true.
            assert_ne!(id1, map.insert_weak(&obj1));
            assert_ne!(id2, map.insert_weak(&obj2));
        }

        {
            assert_ne!(id1, map.insert_strong(obj1.clone()));
            assert_ne!(id2, map.insert_strong(obj2.clone()));
        }
        map.assert_okay();
    }

    #[test]
    fn objid_encoding() {
        fn test_roundtrip(a: u32, b: u32, rng: &mut tor_basic_utils::test_rng::TestingRng) {
            let a: u64 = a.into();
            let b: u64 = b.into();
            let data = KeyData::from_ffi((a << 33) | (1_u64 << 32) | b);
            let idx = GenIdx::from(data);
            let s1 = idx.encode_with_rng(rng);
            let s2 = idx.encode_with_rng(rng);
            assert_ne!(s1, s2);
            assert_eq!(idx, GenIdx::try_decode(&s1).unwrap());
            assert_eq!(idx, GenIdx::try_decode(&s2).unwrap());
        }
        let mut rng = tor_basic_utils::test_rng::testing_rng();

        test_roundtrip(0, 1, &mut rng);
        test_roundtrip(0, 2, &mut rng);
        test_roundtrip(1, 1, &mut rng);
        test_roundtrip(0xffffffff, 0xffffffff, &mut rng);

        for _ in 0..256 {
            test_roundtrip(rng.random(), rng.random(), &mut rng);
        }
    }
}