1
//! `HasMemoryCost`-related traits, and typed memory cost tracking
2

            
3
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
4

            
5
use crate::internal_prelude::*;
6

            
7
pub use tor_memquota_cost::memory_cost::HasMemoryCost;
8

            
9
/// A [`Participation`] for use only for tracking the memory use of objects of type `T`
///
/// Wrapping a `Participation` in a `TypedParticipation`
/// helps prevent accidentally passing wrongly calculated costs
/// to `claim` and `release`.
///
/// Dereferences (read-only: only `Deref` is derived, not `DerefMut`)
/// to the inner [`Participation`].
/// Use `as_raw` or `into_raw` to deliberately bypass the type check.
#[derive(Deref, Educe)]
#[educe(Clone)]
// `named_field = false`: format like a tuple struct, without field names.
#[educe(Debug(named_field = false))]
pub struct TypedParticipation<T> {
    /// The actual participation
    #[deref]
    raw: Participation,
    /// Marker
    ///
    /// Tags this participation with `T` without storing any `T`.
    /// Using `fn(T)` rather than `T` means the wrapper's auto traits
    /// (such as `Send` and `Sync`) do not depend on `T`.
    #[educe(Debug(ignore))]
    marker: PhantomData<fn(T)>,
}
25

            
26
/// Memory cost obtained from a `T`
///
/// A number of bytes, tagged with the type `T` it was calculated from.
/// Through `HasTypedMemoryCost<T>`, it can be passed to the `claim` and `release`
/// of a `TypedParticipation<T>`, but not to one for a different type.
///
/// `Display` shows just the number of bytes.
#[derive(Educe, derive_more::Display)]
#[educe(Copy, Clone)]
// `named_field = false`: format like a tuple struct, without field names.
#[educe(Debug(named_field = false))]
#[display("{raw}")]
pub struct TypedMemoryCost<T> {
    /// The actual cost in bytes
    raw: usize,
    /// Marker
    ///
    /// Type tag only: no `T` is stored.
    #[educe(Debug(ignore))]
    marker: PhantomData<fn(T)>,
}
38

            
39
/// Types that can return a memory cost known to be the cost of some value of type `T`
40
///
41
/// [`TypedParticipation::claim`] and
42
/// [`release`](TypedParticipation::release)
43
/// take arguments implementing this trait.
44
///
45
/// Implemented by:
46
///
47
///   * `T: HasMemoryCost` (the usual case)
48
///   * `HasTypedMemoryCost<T>` (memory cost, calculated earlier, from a `T`)
49
///
50
/// ### Guarantees
51
///
52
/// This trait has the same guarantees as `HasMemoryCost`.
53
/// Normally, it will not be necessary to add an implementation.
54
// We could seal this trait, but we would need to use a special variant of Sealed,
55
// since we wouldn't want to `impl<T: HasMemoryCost> Sealed for T`
56
// for a normal Sealed trait also used elsewhere.
57
// The bug of implementing this trait for other types seems unlikely,
58
// and we don't think there's a significant API stability hazard.
59
pub trait HasTypedMemoryCost<T>: Sized {
60
    /// The cost, as a `TypedMemoryCost<T>` rather than a raw `usize`
61
    fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T>;
62
}
63

            
64
impl<T: HasMemoryCost> HasTypedMemoryCost<T> for T {
65
19388
    fn typed_memory_cost(&self, enabled: EnabledToken) -> TypedMemoryCost<T> {
66
19388
        TypedMemoryCost::from_raw(self.memory_cost(enabled))
67
19388
    }
68
}
69
impl<T> HasTypedMemoryCost<T> for TypedMemoryCost<T> {
70
19396
    fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T> {
71
19396
        *self
72
19396
    }
73
}
74

            
75
impl<T> TypedParticipation<T> {
76
    /// Wrap a [`Participation`], ensuring that future calls claim and release only `T`
77
6864
    pub fn new(raw: Participation) -> Self {
78
6864
        TypedParticipation {
79
6864
            raw,
80
6864
            marker: PhantomData,
81
6864
        }
82
6864
    }
83

            
84
    /// Record increase in memory use, of a `T: HasMemoryCost` or a `TypedMemoryCost<T>`
85
9950
    pub fn claim(&mut self, t: &impl HasTypedMemoryCost<T>) -> Result<(), Error> {
86
9950
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
87
            return Ok(());
88
        };
89
9950
        self.raw.claim(t.typed_memory_cost(enabled).raw)
90
9950
    }
91
    /// Record decrease in memory use, of a `T: HasMemoryCost` or a `TypedMemoryCost<T>`
92
9458
    pub fn release(&mut self, t: &impl HasTypedMemoryCost<T>) {
93
9458
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
94
            return;
95
        };
96
9458
        self.raw.release(t.typed_memory_cost(enabled).raw);
97
9458
    }
98

            
99
    /// Claiming wrapper for a closure
100
    ///
101
    /// Claims the memory, iff `call` succeeds.
102
    ///
103
    /// Specifically:
104
    /// Claims memory for `item`.   If that fails, returns the error.
105
    /// If the claim succeeded, calls `call`.
106
    /// If it fails or panics, the memory is released, undoing the claim,
107
    /// and the error is returned (or the panic propagated).
108
    ///
109
    /// In these error cases, `item` will typically be dropped by `call`,
110
    /// it is not convenient for `call` to do otherwise.
111
    /// If that's wanted, use [`try_claim_or_return`](TypedParticipation::try_claim_or_return).
112
9686
    pub fn try_claim<C, F, E, R>(&mut self, item: C, call: F) -> Result<Result<R, E>, Error>
113
9686
    where
114
9686
        C: HasTypedMemoryCost<T>,
115
9686
        F: FnOnce(C) -> Result<R, E>,
116
    {
117
9686
        self.try_claim_or_return(item, call).map_err(|(e, _item)| e)
118
9686
    }
119

            
120
    /// Claiming wrapper for a closure
121
    ///
122
    /// Claims the memory, iff `call` succeeds.
123
    ///
124
    /// Like [`try_claim`](TypedParticipation::try_claim),
125
    /// but returns the item if memory claim fails.
126
    /// Typically, a failing `call` will need to return the item in `E`.
127
9942
    pub fn try_claim_or_return<C, F, E, R>(
128
9942
        &mut self,
129
9942
        item: C,
130
9942
        call: F,
131
9942
    ) -> Result<Result<R, E>, (Error, C)>
132
9942
    where
133
9942
        C: HasTypedMemoryCost<T>,
134
9942
        F: FnOnce(C) -> Result<R, E>,
135
    {
136
9942
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
137
            return Ok(call(item));
138
        };
139

            
140
9942
        let cost = item.typed_memory_cost(enabled);
141
9942
        match self.claim(&cost) {
142
9934
            Ok(()) => {}
143
8
            Err(e) => return Err((e, item)),
144
        }
145
        // Unwind safety:
146
        //  - "`F` may not be safely transferred across an unwind boundary"
147
        //    but we don't; it is moved into the closure and
148
        //   it can't obwerve its own panic
149
        //  - "`C` may not be safely transferred across an unwind boundary"
150
        //   Once again, item is moved into call, and never seen again.
151
9934
        match catch_unwind(AssertUnwindSafe(move || call(item))) {
152
4
            Err(panic_payload) => {
153
4
                self.release(&cost);
154
4
                std::panic::resume_unwind(panic_payload)
155
            }
156
12
            Ok(Err(caller_error)) => {
157
12
                self.release(&cost);
158
12
                Ok(Err(caller_error))
159
            }
160
9918
            Ok(Ok(y)) => Ok(Ok(y)),
161
        }
162
9938
    }
163

            
164
    /// Mutably access the inner `Participation`
165
    ///
166
    /// This bypasses the type check.
167
    /// It is up to you to make sure that the `claim` and `release` calls
168
    /// are only made with properly calculated costs.
169
    pub fn as_raw(&mut self) -> &mut Participation {
170
        &mut self.raw
171
    }
172

            
173
    /// Unwrap, and obtain the inner `Participation`
174
3374
    pub fn into_raw(self) -> Participation {
175
3374
        self.raw
176
3374
    }
177
}
178

            
179
impl<T> From<Participation> for TypedParticipation<T> {
    /// Tag an untyped [`Participation`] as tracking values of type `T`.
    ///
    /// Same as [`TypedParticipation::new`].
    fn from(raw: Participation) -> Self {
        Self::new(raw)
    }
}
184

            
185
impl<T> TypedMemoryCost<T> {
186
    /// Convert a raw number of bytes into a type-tagged memory cost
187
19388
    pub fn from_raw(raw: usize) -> Self {
188
19388
        TypedMemoryCost {
189
19388
            raw,
190
19388
            marker: PhantomData,
191
19388
        }
192
19388
    }
193

            
194
    /// Convert a type-tagged memory cost into a raw number of bytes
195
    pub fn into_raw(self) -> usize {
196
        self.raw
197
    }
198
}
199

            
200
#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests

    use super::*;
    use crate::mtracker::test::*;
    use crate::mtracker::*;
    use tor_rtmock::MockRuntime;

    // We don't really need to test the correctness, since this is just type wrappers.
    // But we should at least demonstrate that the API is usable.

    /// A participant with nothing to offer
    ///
    /// It reports no oldest data, and panics if it is ever asked to reclaim.
    #[derive(Debug)]
    struct DummyParticipant;
    impl IsParticipant for DummyParticipant {
        fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
            None
        }
        fn reclaim(self: Arc<Self>, _: EnabledToken) -> ReclaimFuture {
            panic!()
        }
    }

    /// A value whose single instance costs nearly the whole test limit
    struct Costed;
    impl HasMemoryCost for Costed {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            // We nearly exceed the limit with one allocation.
            //
            // This proves that claim does claim, or we'd underflow on release,
            // and that release does release, not claim, or we'd reclaim and crash.
            TEST_DEFAULT_LIMIT - mbytes(1)
        }
    }

    #[test]
    fn api() {
        MockRuntime::test_with_various(|rt| async move {
            let trk = mk_tracker(&rt);
            let acct = trk.new_account(None).unwrap();
            let particip = Arc::new(DummyParticipant);
            let partn = acct
                .register_participant(Arc::downgrade(&particip) as _)
                .unwrap();
            let mut partn: TypedParticipation<Costed> = partn.into();

            partn.claim(&Costed).unwrap();
            partn.release(&Costed);

            let cost = Costed.typed_memory_cost(EnabledToken::new());
            partn.claim(&cost).unwrap();
            partn.release(&cost);

            // claim, then release due to error
            partn
                .try_claim(Costed, |_: Costed| Err::<Void, _>(()))
                .unwrap()
                .unwrap_err();

            // claim, then release due to panic
            let panic_payload = catch_unwind(AssertUnwindSafe(|| {
                let didnt_panic =
                    partn.try_claim(Costed, |_: Costed| -> Result<Void, Void> { panic!() });
                panic!("{:?}", didnt_panic);
            }))
            .unwrap_err();
            // Both possible panics unwind out of the closure above, so check which
            // one it was: the closure's bare `panic!()` carries a `&'static str`
            // payload, whereas our "didn't panic" report carries a formatted `String`.
            assert!(
                panic_payload.downcast_ref::<String>().is_none(),
                "try_claim did not propagate the panic: {:?}",
                panic_payload.downcast_ref::<String>(),
            );

            // claim OK, then explicitly release later
            let did_claim = partn
                .try_claim(Costed, |c: Costed| Ok::<Costed, Void>(c))
                .unwrap()
                .void_unwrap();
            // Check that we did claim at least something!
            assert!(trk.used_current_approx().unwrap() > 0);

            partn.release(&did_claim);

            // With the tracker gone, the claim itself must fail (outer `Err`),
            // so the panicking closure is never run.
            drop(acct);
            drop(particip);
            drop(trk);
            partn
                .try_claim(Costed, |_| -> Result<Void, Void> { panic!() })
                .unwrap_err();

            rt.advance_until_stalled().await;
        });
    }
}