1
//! Implementation logic for RpcConn.
//!
//! Except for [`RpcConn`] itself, nothing in this module is a public API.
//! This module exists so that we can more easily audit the code that
//! touches the members of `RpcConn`.
//!
//! NOTE that many of the types and fields here have documented invariants.
//! Except if noted otherwise, these invariants only hold when nobody
//! is holding the lock on [`ReceiverState`].
10
use std::{
11
    collections::{HashMap, VecDeque},
12
    sync::{Arc, Condvar, Mutex, MutexGuard},
13
};
14

            
15
use crate::{
16
    msgs::{
17
        AnyRequestId, ObjectId,
18
        request::{IdGenerator, ValidatedRequest},
19
        response::ValidatedResponse,
20
    },
21
    nb_stream::PollingStream,
22
};
23

            
24
use super::{ProtoError, ShutdownError};
25

            
26
/// State held by the [`RpcConn`] for a single request ID.
27
#[derive(Default)]
28
struct RequestState {
29
    /// A queue of replies received with this request's identity.
30
    queue: VecDeque<ValidatedResponse>,
31
    /// A condition variable used to wake a thread waiting for this request
32
    /// to have messages.
33
    ///
34
    /// We `notify` this condvar thread under one of three circumstances:
35
    ///
36
    /// * When we queue a response for this request.
37
    /// * When we store a fatal error affecting all requests in the RpcConn.
38
    /// * When the thread currently interacting with he [`PollingStream`] for this
39
    ///   RpcConn stops doing so, and the request waiting
40
    ///   on this thread has been chosen to take responsibility for interacting.
41
    ///
42
    /// Invariants:
43
    /// * The condvar is Some if (and only if) some thread is waiting
44
    ///   on it.
45
    waiter: Option<Arc<Condvar>>,
46
}
47

            
48
impl RequestState {
49
    /// Helper: Pop and return the next message for this request.
50
    ///
51
    /// If there are no queued messages, but a fatal error has occurred on the connection,
52
    /// return that.
53
    ///
54
    /// If there are no queued messages and no fatal error, return None.
55
16050
    fn pop_next_msg(
56
16050
        &mut self,
57
16050
        fatal: &Option<ShutdownError>,
58
16050
    ) -> Option<Result<ValidatedResponse, ShutdownError>> {
59
16050
        if let Some(m) = self.queue.pop_front() {
60
7158
            Some(Ok(m))
61
        } else {
62
8932
            fatal.as_ref().map(|f| Err(f.clone()))
63
        }
64
16050
    }
65
}
66

            
67
/// Mutable state to implement receiving replies on an RpcConn.
68
struct ReceiverState {
69
    /// Helper to assign connection- unique IDs to any requests without them.
70
    id_gen: IdGenerator,
71
    /// A fatal error, if any has occurred.
72
    fatal: Option<ShutdownError>,
73
    /// A map from request ID to the corresponding state.
74
    ///
75
    /// There is an entry in this map for every request that we have sent,
76
    /// unless we have received a final response for that request,
77
    /// or we have cancelled that request.
78
    ///
79
    /// (TODO: We might handle cancelling differently.)
80
    pending: HashMap<AnyRequestId, RequestState>,
81
    /// A steam that we use to send requests and receive replies from Arti.
82
    ///
83
    /// Invariants:
84
    ///
85
    /// * If this is None, a thread is polling and will take responsibility
86
    ///   for liveness.
87
    /// * If this is Some, no-one is polling and anyone who cares about liveness
88
    ///   must take on the interactor role.
89
    ///
90
    /// (Therefore, when it becomes Some, we must signal a cv, if any is set.)
91
    stream: Option<PollingStream>,
92
}
93

            
94
impl ReceiverState {
95
    /// Notify an arbitrarily chosen request's condvar.
96
938
    fn alert_anybody(&self) {
97
        // TODO: This is O(n) in the worst case.
98
        //
99
        // But with luck, nobody will make a million requests and
100
        // then wait on them one at a time?
101
1180
        for ent in self.pending.values() {
102
1180
            if let Some(cv) = &ent.waiter {
103
904
                cv.notify_one();
104
904
                return;
105
276
            }
106
        }
107
938
    }
108

            
109
    /// Notify the condvar for every request.
110
6
    fn alert_everybody(&self) {
111
80
        for ent in self.pending.values() {
112
80
            if let Some(cv) = &ent.waiter {
113
80
                // By our rules, each condvar is waited on by precisely one thread.
114
80
                // So we call `notify_one` even though we are trying to wake up everyone.
115
80
                cv.notify_one();
116
80
            }
117
        }
118
6
    }
119
}
120

            
121
/// Object to receive messages on an RpcConn.
122
///
123
/// This is a crate-internal abstraction.
124
/// It's separate from RpcConn for a few reasons:
125
///
126
/// - So we can keep polling the channel while the RpcConn has
127
///   been dropped.
128
/// - So we can hold the lock on this part without being blocked on threads writing.
129
/// - Because this is the only part that for which
130
///   `RequestHandle` needs to keep a reference.
131
pub(super) struct Receiver {
132
    /// Mutable state.
133
    ///
134
    /// This lock should only be held briefly, and never while interacting with the
135
    /// `PollingStream`.
136
    state: Mutex<ReceiverState>,
137
}
138

            
139
/// An open RPC connection to Arti.
140
#[derive(educe::Educe)]
141
#[educe(Debug)]
142
pub struct RpcConn {
143
    /// The receiver object for this conn.
144
    ///
145
    /// It's in an `Arc<>` so that we can share it with the RequestHandles.
146
    #[educe(Debug(ignore))]
147
    receiver: Arc<Receiver>,
148

            
149
    /// A writer that we use to queue requests to be sent back to Arti.
150
    writer: crate::nb_stream::WriteHandle,
151

            
152
    /// If set, we are authenticated and we have negotiated a session that has
153
    /// this ObjectID.
154
    pub(super) session: Option<ObjectId>,
155
}
156

            
157
/// Instruction to alert some additional condvar(s) before releasing our lock and returning.
///
/// Any code which receives one of these must pass the instruction on to someone else,
/// until, eventually, the instruction is acted on in [`Receiver::wait_on_message_for`].
#[must_use]
#[derive(Debug)]
enum AlertWhom {
    /// We don't need to alert anybody;
    /// we have not taken the stream, or registered our own condvar:
    /// therefore nobody expects us to take the stream.
    Nobody,
    /// We have taken the stream or been alerted via our condvar:
    /// therefore, we are responsible for making sure
    /// that _somebody_ takes the stream.
    ///
    /// We should therefore alert somebody if nobody currently has the stream.
    Anybody,
    /// We have been the first to encounter a fatal error.
    /// Therefore, we should inform _everybody_.
    Everybody,
}
178

            
179
impl RpcConn {
180
    /// Construct a new RpcConn with a given PollingStream.
181
10
    pub(super) fn new(stream: PollingStream) -> Self {
182
10
        let writer = stream.writer();
183
10
        Self {
184
10
            receiver: Arc::new(Receiver {
185
10
                state: Mutex::new(ReceiverState {
186
10
                    id_gen: IdGenerator::default(),
187
10
                    fatal: None,
188
10
                    pending: HashMap::new(),
189
10
                    stream: Some(stream),
190
10
                }),
191
10
            }),
192
10
            writer,
193
10
            session: None,
194
10
        }
195
10
    }
196

            
197
    /// Send the request in `msg` on this connection, and return a RequestHandle
198
    /// to wait for a reply.
199
    ///
200
    /// We validate `msg` before sending it out, and reject it if it doesn't
201
    /// make sense. If `msg` has no `id` field, we allocate a new one
202
    /// according to the rules in [`IdGenerator`].
203
    ///
204
    /// Limitation: We don't preserved unrecognized fields in the framing and meta
205
    /// parts of `msg`.  See notes in `request.rs`.
206
4194
    pub(super) fn send_request(&self, msg: &str) -> Result<super::RequestHandle, ProtoError> {
207
        use std::collections::hash_map::Entry::*;
208

            
209
4194
        let mut state = self.receiver.state.lock().expect("poisoned");
210
4194
        if let Some(f) = &state.fatal {
211
            // If there's been a fatal error we don't even try to send the request.
212
10
            return Err(f.clone().into());
213
4184
        }
214

            
215
        // Convert this request into validated form (with an ID) and re-encode it.
216
4184
        let valid: ValidatedRequest =
217
6276
            ValidatedRequest::from_string_loose(msg, || state.id_gen.next_id())?;
218

            
219
        // Do the necessary housekeeping before we send the request, so that
220
        // we'll be able to understand the replies.
221
4184
        let id = valid.id().clone();
222
4184
        match state.pending.entry(id.clone()) {
223
            Occupied(_) => return Err(ProtoError::RequestIdInUse),
224
4184
            Vacant(v) => {
225
4184
                v.insert(RequestState::default());
226
4184
            }
227
        }
228
        // Release the lock on the ReceiverState here; the two locks must not overlap.
229
4184
        drop(state);
230

            
231
        // NOTE: This is the only block of code that holds the writer lock!
232
4184
        let write_outcome = self.writer.send_valid(&valid);
233

            
234
4184
        match write_outcome {
235
            Err(e) => {
236
                // A failed write is a fatal error for everybody.
237
                let e = ShutdownError::Write(Arc::new(e));
238
                let mut state = self.receiver.state.lock().expect("poisoned");
239
                if state.fatal.is_none() {
240
                    state.fatal = Some(e.clone());
241
                    state.alert_everybody();
242
                }
243
                Err(e.into())
244
            }
245

            
246
4184
            Ok(()) => Ok(super::RequestHandle {
247
4184
                id,
248
4184
                conn: Mutex::new(Arc::clone(&self.receiver)),
249
4184
            }),
250
        }
251
4194
    }
252
}
253

            
254
impl Receiver {
255
    /// Wait until there is either a fatal error on this connection,
256
    /// _or_ there is a new message for the request with the provided `id`.
257
    /// Return that message, or a copy of the fatal error.
258
8102
    pub(super) fn wait_on_message_for(
259
8102
        &self,
260
8102
        id: &AnyRequestId,
261
8102
    ) -> Result<ValidatedResponse, ProtoError> {
262
        // Here in wait_on_message_for_impl, we do the the actual work
263
        // of waiting for the message.
264
8102
        let state = self.state.lock().expect("poisoned");
265
8102
        let (result, mut state, should_alert) = self.wait_on_message_for_impl(state, id);
266

            
267
        // Great; we have a message or a fatal error.  All we need to do now
268
        // is to restore our invariants before we drop state_lock.
269
        //
270
        // (It would be a bug to return early without restoring the invariants,
271
        // so we'll use an IEFE pattern to prevent "?" and "return Err".)
272
        #[allow(clippy::redundant_closure_call)]
273
8102
        (|| {
274
            // "final" in this case means that we are not expecting any more
275
            // replies for this request.
276
8102
            let is_final = match &result {
277
86
                Err(_) => true,
278
8016
                Ok(r) => r.is_final(),
279
            };
280

            
281
8102
            if is_final {
282
4184
                // Note 1: It might be cleaner to use Entry::remove(), but Entry is not
283
4184
                // exactly the right shape for us; see note in
284
4184
                // wait_on_message_for_impl.
285
4184

            
286
4184
                // Note 2: This remove isn't necessary if `result` is
287
4184
                // RequestCancelled, but it won't hurt.
288
4184

            
289
4184
                // Note 3: On DuplicateWait, it is not totally clear whether we should
290
4184
                // remove or not.  But that's an internal error that should never occur,
291
4184
                // so it is probably okay if we let the _other_ waiter keep on trying.
292
4184
                state.pending.remove(id);
293
4184
            }
294

            
295
7982
            match should_alert {
296
114
                AlertWhom::Nobody => {}
297
7982
                AlertWhom::Anybody if state.stream.is_none() => {}
298
938
                AlertWhom::Anybody => state.alert_anybody(),
299
6
                AlertWhom::Everybody => state.alert_everybody(),
300
            }
301
        })();
302

            
303
8102
        result
304
8102
    }
305

            
306
    /// Helper to implement [`wait_on_message_for`](Self::wait_on_message_for).
307
    ///
308
    /// Takes a `MutexGuard` as one of its arguments, and returns an equivalent
309
    /// `MutexGuard` on completion.
310
    ///
311
    /// The caller is responsible for:
312
    ///
313
    /// - Removing the appropriate entry from `pending`, if the result
314
    ///   indicates that no more messages will be received for this request.
315
    /// - Possibly, notifying one or more condvars,
316
    ///   depending on the resulting `AlertWhom`.
317
    ///
318
    /// The caller must not drop the `MutexGuard` until it has done the above.
319
8102
    fn wait_on_message_for_impl<'a>(
320
8102
        &'a self,
321
8102
        mut state_lock: MutexGuard<'a, ReceiverState>,
322
8102
        id: &AnyRequestId,
323
8102
    ) -> (
324
8102
        Result<ValidatedResponse, ProtoError>,
325
8102
        MutexGuard<'a, ReceiverState>,
326
8102
        AlertWhom,
327
8102
    ) {
328
        // At this point, we have not registered on a condvar, and we have not
329
        // taken the PollingStream.
330
        // Therefore, we do not yet need to ensure that anybody else takes the PollingStream.
331
        //
332
        // TODO: It is possibly too easy to forget to set this,
333
        // or to set it to a less "alerty" value.  Refactoring might help;
334
        // see discussion at
335
        // https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2258#note_3047267
336
8102
        let mut should_alert = AlertWhom::Nobody;
337

            
338
8102
        let mut state: &mut ReceiverState = &mut state_lock;
339

            
340
        // Initialize `this_ent` to our own entry in the pending table.
341
8102
        let Some(mut this_ent) = state.pending.get_mut(id) else {
342
            return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
343
        };
344

            
345
864
        let mut stream = loop {
346
            // Note: It might be nice to use a hash_map::Entry here, but it
347
            // doesn't really work the way we want.  The `entry()` API is always
348
            // ready to insert, and requires that we clone `id`.  But what we
349
            // want in this case is something that would give us a .remove()able
350
            // Entry only if one is present.
351
16050
            if this_ent.waiter.is_some() {
352
                // This is an internal error; nobody should be able to cause this.
353
                return (Err(ProtoError::DuplicateWait), state_lock, should_alert);
354
16050
            }
355

            
356
16050
            if let Some(ready) = this_ent.pop_next_msg(&state.fatal) {
357
                // There is a reply for us, or a fatal error.
358
7238
                return (ready.map_err(ProtoError::from), state_lock, should_alert);
359
8812
            }
360

            
361
            // If we reach this point, we are about to either take the stream or
362
            // register a cv.  This means that when we return, we need to make
363
            // sure that at least one other cv gets notified.
364
8812
            should_alert = AlertWhom::Anybody;
365

            
366
8812
            if let Some(r) = state.stream.take() {
367
                // Nobody else is polling; we have to do it.
368
864
                break r;
369
7948
            }
370

            
371
            // Somebody else is polling; register a condvar.
372
7948
            let cv = Arc::new(Condvar::new());
373
7948
            this_ent.waiter = Some(Arc::clone(&cv));
374

            
375
7948
            state_lock = cv.wait(state_lock).expect("poisoned lock");
376
7948
            state = &mut state_lock;
377
            // Restore `this_ent`...
378
7948
            let Some(e) = state.pending.get_mut(id) else {
379
                return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
380
            };
381
7948
            this_ent = e;
382
            // ... And un-register our condvar.
383
7948
            this_ent.waiter = None;
384

            
385
            // We have been notified: either there is a reply or us,
386
            // or we are supposed to take the stream.  We'll find out on our
387
            // next time through the loop.
388
        };
389

            
390
864
        let (result, mut state_lock, should_alert) =
391
864
            self.read_until_message_for(state_lock, &mut stream, id);
392
        // Put the stream back.
393
864
        state_lock.stream = Some(stream);
394

            
395
864
        (result.map_err(ProtoError::from), state_lock, should_alert)
396
8102
    }
397

            
398
    /// Interact with `stream`, writing any queued messages,
399
    /// reading messages, and
400
    /// delivering them as appropriate, until we find one for `id`,
401
    /// or a fatal error occurs.
402
    ///
403
    /// Return that message or error, along with a `MutexGuard`.
404
    ///
405
    /// The caller is responsible for restoring the following state before
406
    /// dropping the `MutexGuard`:
407
    ///
408
    /// - Putting `stream` back into the `stream` field.
409
    /// - Other invariants as discussed in wait_on_message_for_impl.
410
864
    fn read_until_message_for<'a>(
411
864
        &'a self,
412
864
        mut state_lock: MutexGuard<'a, ReceiverState>,
413
864
        stream: &mut PollingStream,
414
864
        id: &AnyRequestId,
415
864
    ) -> (
416
864
        Result<ValidatedResponse, ShutdownError>,
417
864
        MutexGuard<'a, ReceiverState>,
418
864
        AlertWhom,
419
864
    ) {
420
        loop {
421
            // Importantly, we drop the state lock while we are polling.
422
            // This is okay, since all our invariants should hold at this point.
423
8022
            drop(state_lock);
424

            
425
8022
            let result = match stream.interact() {
426
2
                Err(e) => Err(ShutdownError::Read(Arc::new(e))),
427
                Ok(None) => Err(ShutdownError::ConnectionClosed),
428
8020
                Ok(Some(m)) => m.try_validate().map_err(ShutdownError::from),
429
            };
430

            
431
8022
            state_lock = self.state.lock().expect("poisoned lock");
432
8022
            let state = &mut state_lock;
433

            
434
8016
            match result {
435
8016
                Ok(m) if m.id() == id => {
436
                    // This only is for us, so there's no need to alert anybody
437
                    // or queue it.
438
858
                    return (Ok(m), state_lock, AlertWhom::Anybody);
439
                }
440
6
                Err(e) => {
441
                    // This is a fatal error on the whole connection.
442
                    //
443
                    // If it's the first one encountered, queue the error, and
444
                    // return it.
445
6
                    if state.fatal.is_none() {
446
6
                        state.fatal = Some(e.clone());
447
6
                    }
448
6
                    return (Err(e), state_lock, AlertWhom::Everybody);
449
                }
450
7158
                Ok(m) => {
451
                    // This is a message for exactly one ID, that isn't us.
452
                    // Queue it and notify them.
453
7158
                    if let Some(ent) = state.pending.get_mut(m.id()) {
454
7158
                        ent.queue.push_back(m);
455
7158
                        if let Some(cv) = &ent.waiter {
456
7106
                            cv.notify_one();
457
7106
                        }
458
                    } else {
459
                        // Nothing wanted this response any longer.
460
                        // _Probably_ this means that we decided to cancel the
461
                        // request but Arti sent this response before it handled
462
                        // our cancellation.
463
                    }
464
                }
465
            };
466
        }
467
864
    }
468
}