1
//! Implementation logic for RpcConn.
2
//!
3
//! Except for [`RpcConn`] itself, nothing in this module is a public API.
4
//! This module exists so that we can more easily audit the code that
5
//! touches the members of `RpcConn`.
6
//!
7
//! NOTE that many of the types and fields here have documented invariants.
8
//! Except if noted otherwise, these invariants only hold when nobody
9
//! is holding the lock on [`RequestState`].
10
//!
11
//! # Overview
12
//!
13
//! Each connection supports both:
14
//!  - requests that the caller will block on (a Waitable request)
15
//!  - requests that the caller will poll for (a Pollable request).
16
//!
17
//! ## Identifying requests
18
//!
19
//! Each request has a corresponding value of a type that implements QueueId
20
//! to identify which queue responses for the request should go into.
21
//!
22
//! - Waitable requests have [`AnyRequestId`], which implements [`QueueId`].
23
//! - Pollable requests have [`PolledRequests`], a ZST that implements [`QueueId`].
24
//!
25
//! (Requests themselves all have an [`AnyRequestId`] --
26
//! the actual ID that we send out in the request,
27
//! which the RPC server sends back in all responses.
28
//! Additionally, Pollable requests are created with a client-defined [`UserTag`],
29
//! which the client can use to identify their particular requests.
30
//! `UserTag` is a separate type to help FFI-style programs
31
//! that want to put things like pointers in it.)
32
//!
33
//! # Data structure
34
//!
35
//! The connection has
36
//!   - an outbound queue for outbound messages, in its [`BlockingConnection`].
37
//!   - [`RequestMap`], a data structure containing outstanding requests,
38
//!     which is used for knowing what to do with inbound messages
39
//!
40
//! If the request is Waitable,
41
//! its `RequestMap.map` entry is [`RequestState::Waiting`],
42
//! and contains its own [`ResponseQueue`].
43
//!
44
//! If the request is Pollable,
45
//! its `RequestMap` entry is [`RequestState::Pollable`],
46
//! and contains the Tag that the application will use
47
//! to distinguish responses to that request.
48
//! All responses to _all_ Pollable requests
49
//! are queued within `RequestMap::polled_response_queue`.
50
//!
51
//! # Operation
52
//!
53
//! When we make a request, we add an entry to the `RequestMap::map`.
54
//! The entry stays there until we receive a final response to the request.
55
//!
56
//! At any given time,
57
//! multiple threads can be waiting for responses on the same RpcConn object.
58
//! If [`RpcPoll`] is not in use,
59
//! exactly one of these threads will actually be holding the [`BlockingConnection`]
60
//! and trying to read from the network.
61
//! (Otherwise (if [`RpcPoll`] is in use), the `RpcPoll` object will be holding the
62
//! [`NonblockingConnection`] and will be responsible for all reads and writes.)
63
//!
64
//! If the thread holding the connection finds a response for itself, it returns that response.
65
//! Otherwise, it puts the response in the appropriate queue,
66
//! and signals the condvar associated with that queue.
67
//!
68
//! There are two kinds of queue:
69
//! A per-request queue used by Waitable requests,
70
//! and a single queue shared by all Polled requests.
71
//! Every queue has its own associated condvar.
72
//!
73
//! The two kinds of queue are slightly different.
74
//! (We represent their differences with the QueueId trait):
75
//!     - Pollable responses need to carry a `UserTag`;
76
//!       Waitable responses don't. This is [`QueueId::UserTag`].
77
//!     - We need to treat final responses a bit differently
78
//!       in terms of how we find what to remove.
79
//!       This is [`QueueId::remove_entry`].
80
//!     - If we're holding the connection and waiting for responses on a given queue,
81
//!       we need to answer the "is this for us?" question a little differently.
82
//!       This is [`QueueId::response_disposition`].
83

            
84
use std::{
85
    collections::{HashMap, VecDeque},
86
    sync::{Arc, Condvar, Mutex, MutexGuard},
87
};
88

            
89
use crate::{
90
    UserTag,
91
    conn::AnyResponse,
92
    ll_conn::{BlockingConnection, NonblockingConnection},
93
    msgs::{
94
        AnyRequestId, ObjectId,
95
        request::{IdGenerator, ValidatedRequest},
96
        response::ValidatedResponse,
97
    },
98
};
99

            
100
use super::{ProtoError, ShutdownError};
101

            
102
/// An identifier for a [`ResponseQueue`] within a [`RequestMap`].
103
trait QueueId {
104
    /// A tag type associated with responses in the identified queue.
105
    ///
106
    /// ("Polling" requests use [`UserTag`]s to tell the user
107
    /// which response goes with which request.)
108
    type UserTag: Sized;
109

            
110
    /// Find the queue identified by this `QueueId` within `map`,
111
    /// in order to wait for messages on it.
112
    fn get_queue_mut<'a>(
113
        &self,
114
        map: &'a mut RequestMap,
115
    ) -> Result<&'a mut ResponseQueue<Self>, ProtoError>;
116

            
117
    /// Given that we are polling on the queue identified by `self`,
118
    /// determine what we should do with `msg`.
119
    ///
120
    /// (Should we return it, drop it, or forward it to somebody else?)
121
    ///
122
    /// This is used by the core waiting code in `read_until_message_for`,
123
    /// which needs to be able to handle any incoming response,
124
    /// even one which is for a different context / different caller,
125
    /// and reroute the message to the appropriate place.
126
    fn response_disposition<'a>(
127
        &self,
128
        map: &'a mut RequestMap,
129
        msg: &ValidatedResponse,
130
    ) -> ResponseDisposition<'a, Self>;
131

            
132
    /// Remove any state from `map` associated with `msg_id`.
133
    ///
134
    /// (If `msg_id` is absent, an error occurred that was not associated with any message ID.)
135
    fn remove_entry<'a>(&self, map: &'a mut RequestMap, msg_id: Option<&AnyRequestId>);
136

            
137
    /// Create and return a new RequestState to track a request associated with this kind of ID.
138
    fn new_entry(tag: Self::UserTag) -> RequestState;
139
}
140

            
141
impl QueueId for AnyRequestId {
142
    type UserTag = ();
143

            
144
16184
    fn get_queue_mut<'a>(
145
16184
        &self,
146
16184
        map: &'a mut RequestMap,
147
16184
    ) -> Result<&'a mut ResponseQueue<Self>, ProtoError> {
148
16184
        match map.map.get_mut(self) {
149
16184
            Some(RequestState::Waiting(s)) => Ok(s),
150
            Some(RequestState::Pollable(_)) => Err(ProtoError::RequestNotWaitable),
151
            None => Err(ProtoError::RequestCompleted),
152
        }
153
16184
    }
154

            
155
8082
    fn response_disposition<'a>(
156
8082
        &self,
157
8082
        map: &'a mut RequestMap,
158
8082
        msg: &ValidatedResponse,
159
8082
    ) -> ResponseDisposition<'a, Self> {
160
8082
        if self == msg.id() {
161
            // This message is for us; no reason to look anything up.
162
880
            return ResponseDisposition::Return(());
163
7202
        }
164

            
165
7202
        match map.map.get_mut(msg.id()) {
166
7202
            Some(RequestState::Waiting(q)) => ResponseDisposition::ForwardWaiting(q),
167
            Some(RequestState::Pollable(tag)) => {
168
                ResponseDisposition::ForwardPollable(*tag, &mut map.polled_response_queue)
169
            }
170
            None => ResponseDisposition::Ignore,
171
        }
172
8082
    }
173

            
174
4186
    fn remove_entry<'a>(&self, map: &'a mut RequestMap, _: Option<&AnyRequestId>) {
175
4186
        map.map.remove(self);
176
4186
    }
177

            
178
    /// Create and return a new RequestState to track a request associated with this kind of ID.
179
4186
    fn new_entry(_: Self::UserTag) -> RequestState {
180
4186
        RequestState::Waiting(ResponseQueue::default())
181
4186
    }
182
}
183

            
184
/// Queue identifier covering every "Pollable" request at once.
///
/// "Waitable" requests are created with "execute*" methods, and their APIs expect
/// the user to block while waiting for responses.
/// Polled requests, by contrast, are created with "submit*" methods;
/// their replies are handed back from the RpcConn directly,
/// each paired with the request's [`UserTag`].
struct PolledRequests;
192

            
193
impl QueueId for PolledRequests {
194
    type UserTag = UserTag;
195

            
196
    fn get_queue_mut<'a>(
197
        &self,
198
        map: &'a mut RequestMap,
199
    ) -> Result<&'a mut ResponseQueue<Self>, ProtoError> {
200
        Ok(&mut map.polled_response_queue)
201
    }
202

            
203
    fn response_disposition<'a>(
204
        &self,
205
        map: &'a mut RequestMap,
206
        msg: &ValidatedResponse,
207
    ) -> ResponseDisposition<'a, Self> {
208
        match map.map.get_mut(msg.id()) {
209
            Some(RequestState::Waiting(s)) => ResponseDisposition::ForwardWaiting(s),
210
            Some(RequestState::Pollable(tag)) => ResponseDisposition::Return(*tag),
211
            None => ResponseDisposition::Ignore,
212
        }
213
    }
214

            
215
    fn remove_entry<'a>(&self, map: &'a mut RequestMap, msg_id: Option<&AnyRequestId>) {
216
        let Some(msg_id) = msg_id else {
217
            // This can only happen when we have an error that wasn't associated with a message ID.
218
            // We can't actually remove the appropriate thing.
219
            return;
220
        };
221

            
222
        map.map.remove(msg_id);
223
    }
224

            
225
    fn new_entry(tag: Self::UserTag) -> RequestState {
226
        RequestState::Pollable(tag)
227
    }
228
}
229

            
230
/// A queue of responses used to alert a polling function about replies to
231
/// one or more requests.
232
#[derive(educe::Educe)]
233
#[educe(Default)]
234
struct ResponseQueue<Q: QueueId + ?Sized> {
235
    /// A queue of replies received with this request's identity.
236
    queue: VecDeque<(Q::UserTag, ValidatedResponse)>,
237
    /// A condition variable used to wake a thread waiting for this request
238
    /// to have messages.
239
    ///
240
    /// We `notify` this condvar thread under one of three circumstances:
241
    ///
242
    /// * When we queue a response for this request.
243
    /// * When we store a fatal error affecting all requests in the RpcConn.
244
    /// * When the thread currently interacting with he [`BlockingConnection`] for this
245
    ///   RpcConn stops doing so, and the request waiting
246
    ///   on this thread has been chosen to take responsibility for interacting.
247
    ///
248
    /// Invariants:
249
    /// * The condvar is Some if (and only if) some thread is waiting
250
    ///   on it.
251
    waiter: Option<Arc<Condvar>>,
252
}
253

            
254
/// State held by the [`RpcConn`] for a single request ID.
255
enum RequestState {
256
    /// A request submitted by one of the `execute_*` functions:
257
    /// The user must call a "wait" function for this request specifically in order to get
258
    /// responses. This request has its own queue.
259
    Waiting(ResponseQueue<AnyRequestId>),
260

            
261
    /// A request submitted by one of the `submit_*` functions:
262
    /// the user must provide an associated [`UserTag`],
263
    /// and call [`RpcConn::wait`] to find responses.
264
    Pollable(UserTag),
265
}
266

            
267
impl<Q: QueueId + ?Sized> ResponseQueue<Q> {
268
    /// Helper: Pop and return the next message for this request.
269
    ///
270
    /// If there are no queued messages, but a fatal error has occurred on the connection,
271
    /// return that.
272
    ///
273
    /// If there are no queued messages and no fatal error, return None.
274
16184
    fn pop_next_msg(
275
16184
        &mut self,
276
16184
        fatal: &Option<ShutdownError>,
277
16184
    ) -> Option<Result<(Q::UserTag, ValidatedResponse), ShutdownError>> {
278
16184
        if let Some(m) = self.queue.pop_front() {
279
7202
            Some(Ok(m))
280
        } else {
281
8982
            fatal.as_ref().map(|f| Err(f.clone()))
282
        }
283
16184
    }
284

            
285
    /// Queue `response` for this request, and alert the condvar (if any).
286
7202
    fn push_back_and_alert(&mut self, tag: Q::UserTag, response: ValidatedResponse) {
287
7202
        self.queue.push_back((tag, response));
288

            
289
7202
        if let Some(cv) = &self.waiter {
290
7126
            cv.notify_one();
291
7126
        }
292
7202
    }
293
}
294

            
295
/// A map from a [`QueueId`] to a request state.
296
#[derive(Default)]
297
struct RequestMap {
298
    /// A map from request ID to the state for that request ID.
299
    ///
300
    /// Entries are added to this map when a request is sent,
301
    /// and removed when the request encounters
302
    /// an error or a final response.
303
    map: HashMap<AnyRequestId, RequestState>,
304

            
305
    /// A response queue to hold the responses for pollable requests.
306
    polled_response_queue: ResponseQueue<PolledRequests>,
307
}
308

            
309
/// An action to take with a given message.
310
///
311
/// Returned by [`QueueId::response_disposition`]
312
enum ResponseDisposition<'a, Q: QueueId + ?Sized> {
313
    /// This message is for the queue that we are waiting for;
314
    /// we should return it to the caller.
315
    Return(Q::UserTag),
316

            
317
    /// This message if for a dead request that was probably cancelled;
318
    /// we should drop it.
319
    Ignore,
320

            
321
    /// This message is for some other request;
322
    /// we should instead forward it to that request's queue.
323
    ForwardWaiting(&'a mut ResponseQueue<AnyRequestId>),
324

            
325
    /// This message is for some other request;
326
    ///  we should instead forward it to the polled request queue.
327
    ForwardPollable(UserTag, &'a mut ResponseQueue<PolledRequests>),
328
}
329

            
330
/// Mutable state to implement receiving replies on an RpcConn.
331
struct ReceiverState {
332
    /// Helper to assign connection- unique IDs to any requests without them.
333
    id_gen: IdGenerator,
334
    /// A fatal error, if any has occurred.
335
    fatal: Option<ShutdownError>,
336
    /// A map from request ID to the corresponding state.
337
    ///
338
    /// There is an entry in this map for every request that we have sent,
339
    /// unless we have received a final response for that request,
340
    /// or we have cancelled that request.
341
    ///
342
    /// (TODO: We might handle cancelling differently.)
343
    pending: RequestMap,
344
    /// A steam that we use to send requests and receive replies from Arti.
345
    ///
346
    /// Invariants:
347
    ///
348
    /// * If this is None, a thread is polling and will take responsibility
349
    ///   for liveness.
350
    /// * If this is Some, no-one is polling and anyone who cares about liveness
351
    ///   must take on the interactor role.
352
    ///
353
    /// (Therefore, when it becomes Some, we must signal a cv, if any is set.)
354
    conn: Option<BlockingConnection>,
355
}
356

            
357
impl RequestMap {
358
    /// Notify an arbitrarily chosen request's condvar.
359
962
    fn alert_anybody(&self) {
360
        // TODO: This is O(n) in the worst case.
361
        //
362
        // But with luck, nobody will make a million requests and
363
        // then wait on them one at a time?
364
1248
        for ent in self.map.values() {
365
            if let RequestState::Waiting(ResponseQueue {
366
948
                waiter: Some(cv), ..
367
1248
            }) = ent
368
            {
369
948
                cv.notify_one();
370
948
                return;
371
300
            }
372
        }
373
962
    }
374

            
375
    /// Notify the condvar for every request.
376
6
    fn alert_everybody(&self) {
377
82
        for ent in self.map.values() {
378
            if let RequestState::Waiting(ResponseQueue {
379
82
                waiter: Some(cv), ..
380
82
            }) = ent
381
82
            {
382
82
                // By our rules, each condvar is waited on by precisely one thread.
383
82
                // So we call `notify_one` even though we are trying to wake up everyone.
384
82
                cv.notify_one();
385
82
            }
386
        }
387
6
    }
388
}
389

            
390
/// Object to receive messages on an RpcConn.
391
///
392
/// This is a crate-internal abstraction.
393
/// It's separate from RpcConn for a few reasons:
394
///
395
/// - So we can keep polling the channel while the RpcConn has
396
///   been dropped.
397
/// - So we can hold the lock on this part without being blocked on threads writing.
398
/// - Because this is the only part that for which
399
///   `RequestHandle` needs to keep a reference.
400
pub(super) struct Receiver {
401
    /// Mutable state.
402
    ///
403
    /// This lock should only be held briefly, and never while interacting with the
404
    /// `BlockingConnection`.
405
    state: Mutex<ReceiverState>,
406
}
407

            
408
/// An open RPC connection to Arti.
409
#[derive(educe::Educe)]
410
#[educe(Debug)]
411
pub struct RpcConn {
412
    /// The receiver object for this conn.
413
    ///
414
    /// It's in an `Arc<>` so that we can share it with the RequestHandles.
415
    #[educe(Debug(ignore))]
416
    pub(super) receiver: Arc<Receiver>,
417

            
418
    /// A writer that we use to queue requests to be sent back to Arti.
419
    writer: crate::ll_conn::WriteHandle,
420

            
421
    /// If set, we are authenticated and we have negotiated a session that has
422
    /// this ObjectID.
423
    pub(super) session: Option<ObjectId>,
424
}
425

            
426
/// A handle used to poll for RPC responses within an [event-driven IO] loop.
427
///
428
/// Only one handle of this type can exist per [`RpcConn`].
429
///
430
/// This type is _not_ intended to be used by multiple threads at once: Only one thread at a time
431
/// should ever invoke its [`poll`](RpcPoll::poll) method.
432
/// (In Rust, this is enforced by having RpcPoll::poll take `&mut self`.)
433
///
434
/// [event-driven IO]: https://man7.org/linux/man-pages/man2/select.2.html
435
pub struct RpcPoll {
436
    /// The message-receiver that we're using to track request state and report responses.
437
    receiver: Arc<Receiver>,
438

            
439
    /// The underling nonblocking connection that we're polling for readiness,
440
    /// and using to send and receive messages.
441
    nbconn: NonblockingConnection,
442
}
443

            
444
/// Instruction to alert some additional condvar(s) before releasing our lock and returning.
///
/// Any code that receives one of these must hand the instruction on to its caller,
/// until, eventually, the instruction is acted on by the `match` in
/// `Receiver::wait_on_message_for_queue` (reached via [`Receiver::wait_on_message_for`]).
#[must_use]
#[derive(Debug)]
enum AlertWhom {
    /// Nobody needs alerting:
    /// we have neither taken the connection nor registered our own condvar,
    /// so nobody expects us to take the connection.
    Nobody,
    /// We have taken the connection, or been alerted via our condvar,
    /// and are therefore responsible for making sure
    /// that _somebody_ takes the connection.
    ///
    /// So we should alert somebody if nobody currently has the connection.
    Anybody,
    /// We were the first to encounter a fatal error,
    /// and so we should inform _everybody_.
    Everybody,
}
465

            
466
impl RpcConn {
467
    /// Construct a new RpcConn with a given BlockingConnection.
468
10
    pub(super) fn new(conn: BlockingConnection) -> Self {
469
10
        let writer = conn.writer();
470
10
        Self {
471
10
            receiver: Arc::new(Receiver {
472
10
                state: Mutex::new(ReceiverState {
473
10
                    id_gen: IdGenerator::default(),
474
10
                    fatal: None,
475
10
                    pending: RequestMap::default(),
476
10
                    conn: Some(conn),
477
10
                }),
478
10
            }),
479
10
            writer,
480
10
            session: None,
481
10
        }
482
10
    }
483

            
484
    /// Return a new [`RpcPoll`] to use for managing an RpcConn using event-driven IO.
485
    ///
486
    /// Removes the `BlockingConnection` from this `RpcConn`
487
    /// and drops any mio resources associated with it.
488
    /// After this method is called is called, only `RpcPoll::poll()` can interact with it.
489
    ///
490
    /// See caveats on [`RpcConnBuilder::connect_polling`](crate::RpcConnBuilder::connect_polling).
491
    pub(crate) fn construct_rpc_poll(
492
        &mut self,
493
        event_loop: Box<dyn crate::ll_conn::EventLoop>,
494
    ) -> Option<RpcPoll> {
495
        let mut state = self.receiver.state.lock().expect("Lock poisoned");
496
        // TODO nb: enforce that nobody else is holding the state?  Return an error?
497
        let mut nbconn = state.conn.take()?.into_nonblocking();
498
        nbconn.replace_event_loop_handle(event_loop);
499
        Some(RpcPoll {
500
            receiver: Arc::clone(&self.receiver),
501
            nbconn,
502
        })
503
    }
504

            
505
    /// Send the request in `msg` on this connection, and return a RequestHandle
506
    /// to wait for a reply.
507
    ///
508
    /// We validate `msg` before sending it out, and reject it if it doesn't
509
    /// make sense. If `msg` has no `id` field, we allocate a new one
510
    /// according to the rules in [`IdGenerator`].
511
    ///
512
    /// Limitation: We don't preserved unrecognized fields in the framing and meta
513
    /// parts of `msg`.  See notes in `request.rs`.
514
4194
    pub(super) fn send_waitable_request(
515
4194
        &self,
516
4194
        msg: &str,
517
4194
    ) -> Result<super::RequestHandle, ProtoError> {
518
4194
        let id = self.send_request_impl::<AnyRequestId>(msg, ())?;
519
4186
        Ok(super::RequestHandle {
520
4186
            conn: Mutex::new(Arc::clone(&self.receiver)),
521
4186
            id,
522
4186
        })
523
4194
    }
524

            
525
    /// As a`send_waitable_request`, but send a Polled request -- one without a RequestHandle,
526
    /// where responses are returned via [`RpcConn::wait()`].
527
    pub(super) fn send_pollable_request(&self, tag: UserTag, msg: &str) -> Result<(), ProtoError> {
528
        let _id = self.send_request_impl::<PolledRequests>(msg, tag)?;
529
        Ok(())
530
    }
531

            
532
    /// Helper for send_request.
533
    ///
534
    /// We use the [`QueueId`] parameter to determine what kind of queue will
535
4194
    fn send_request_impl<Q: QueueId>(
536
4194
        &self,
537
4194
        msg: &str,
538
4194
        tag: Q::UserTag,
539
4194
    ) -> Result<AnyRequestId, ProtoError> {
540
        use std::collections::hash_map::Entry::*;
541

            
542
4194
        let mut state = self.receiver.state.lock().expect("poisoned");
543
4194
        if let Some(f) = &state.fatal {
544
            // If there's been a fatal error we don't even try to send the request.
545
8
            return Err(f.clone().into());
546
4186
        }
547

            
548
        // Convert this request into validated form (with an ID) and re-encode it.
549
4186
        let valid: ValidatedRequest =
550
4186
            ValidatedRequest::from_string_loose(msg, || state.id_gen.next_id())?;
551

            
552
        // Do the necessary housekeeping before we send the request, so that
553
        // we'll be able to understand the replies.
554
4186
        let id = valid.id().clone();
555
4186
        match state.pending.map.entry(id.clone()) {
556
            Occupied(_) => return Err(ProtoError::RequestIdInUse),
557
4186
            Vacant(v) => {
558
4186
                v.insert(Q::new_entry(tag));
559
4186
            }
560
        }
561
        // Release the lock on the ReceiverState here; the two locks must not overlap.
562
4186
        drop(state);
563

            
564
        // NOTE: This is the only block of code that holds the writer lock!
565
4186
        let write_outcome = self.writer.send_valid(&valid);
566

            
567
4186
        match write_outcome {
568
            Err(e) => {
569
                // A failed write is a fatal error for everybody.
570
                let e = ShutdownError::Write(Arc::new(e));
571
                let mut state = self.receiver.state.lock().expect("poisoned");
572
                if state.fatal.is_none() {
573
                    state.fatal = Some(e.clone());
574
                    state.pending.alert_everybody();
575
                }
576
                Err(e.into())
577
            }
578

            
579
4186
            Ok(()) => Ok(id),
580
        }
581
4194
    }
582
}
583

            
584
impl Receiver {
585
    /// Wait until there is either a fatal error on this connection,
586
    /// _or_ there is a new message for the queue with the provided waiting request `id`.
587
    /// Return that message, or a copy of the fatal error.
588
8170
    pub(super) fn wait_on_message_for(
589
8170
        &self,
590
8170
        id: &AnyRequestId,
591
8170
    ) -> Result<ValidatedResponse, ProtoError> {
592
8170
        let ((), response) = self.wait_on_message_for_queue(id)?;
593
8082
        Ok(response)
594
8170
    }
595

            
596
    /// Wait until there is aeither a fatal error on this connection,
597
    /// _or_ there is a new message for some pollable request.
598
    pub(super) fn wait_on_pollable_response(
599
        &self,
600
    ) -> Result<(UserTag, ValidatedResponse), ProtoError> {
601
        self.wait_on_message_for_queue(&PolledRequests)
602
    }
603

            
604
    /// Wait until there is either a fatal error on this connection,
605
    /// _or_ there is a new message for the queue with the provided `queue_id`.
606
    /// Return that message, or a copy of the fatal error.
607
8170
    fn wait_on_message_for_queue<Q: QueueId>(
608
8170
        &self,
609
8170
        queue_id: &Q,
610
8170
    ) -> Result<(Q::UserTag, ValidatedResponse), ProtoError> {
611
        // Here in wait_on_message_for_impl, we do the actual work
612
        // of waiting for the message.
613
8170
        let state = self.state.lock().expect("poisoned");
614
8170
        let (result, mut state, should_alert) = self.wait_on_message_for_impl(state, queue_id);
615

            
616
        // Great; we have a message or a fatal error.  All we need to do now
617
        // is to restore our invariants before we drop state_lock.
618
        //
619
        // (It would be a bug to return early without restoring the invariants,
620
        // so we'll use an IEFE pattern to prevent "?" and "return Err".)
621
        #[allow(clippy::redundant_closure_call)]
622
8170
        (|| {
623
            // "final" in this case means that we are not expecting any more
624
            // replies for this request.
625
8170
            let (msg_id, is_final) = match &result {
626
88
                Err(_) => (None, true),
627
8082
                Ok(r) => (Some(r.1.id()), r.1.is_final()),
628
            };
629

            
630
8170
            if is_final {
631
4186
                // Note 1: It might be cleaner to use Entry::remove(), but Entry is not
632
4186
                // exactly the right shape for us; see note in
633
4186
                // wait_on_message_for_impl.
634
4186

            
635
4186
                // Note 2: This remove isn't necessary if `result` is
636
4186
                // RequestCancelled, but it won't hurt.
637
4186

            
638
4186
                // Note 3: On DuplicateWait, it is not totally clear whether we should
639
4186
                // remove or not.  But that's an internal error that should never occur,
640
4186
                // so it is probably okay if we let the _other_ waiter keep on trying.
641
4186
                queue_id.remove_entry(&mut state.pending, msg_id);
642
4186
            }
643

            
644
8026
            match should_alert {
645
138
                AlertWhom::Nobody => {}
646
8026
                AlertWhom::Anybody if state.conn.is_none() => {}
647
962
                AlertWhom::Anybody => state.pending.alert_anybody(),
648
6
                AlertWhom::Everybody => state.pending.alert_everybody(),
649
            }
650
        })();
651

            
652
8170
        result
653
8170
    }
654

            
655
    /// Helper to implement [`wait_on_message_for`](Self::wait_on_message_for).
656
    ///
657
    /// Takes a `MutexGuard` as one of its arguments, and returns an equivalent
658
    /// `MutexGuard` on completion.
659
    ///
660
    /// The caller is responsible for:
661
    ///
662
    /// - Removing the appropriate entry from `pending`, if the result
663
    ///   indicates that no more messages will be received for this request.
664
    /// - Possibly, notifying one or more condvars,
665
    ///   depending on the resulting `AlertWhom`.
666
    ///
667
    /// The caller must not drop the `MutexGuard` until it has done the above.
668
    #[allow(clippy::type_complexity)]
669
8170
    fn wait_on_message_for_impl<'a, Q: QueueId>(
670
8170
        &'a self,
671
8170
        mut state_lock: MutexGuard<'a, ReceiverState>,
672
8170
        queue_id: &Q,
673
8170
    ) -> (
674
8170
        Result<(Q::UserTag, ValidatedResponse), ProtoError>,
675
8170
        MutexGuard<'a, ReceiverState>,
676
8170
        AlertWhom,
677
8170
    ) {
678
        // At this point, we have not registered on a condvar, and we have not
679
        // taken the BlockingConnection.
680
        // Therefore, we do not yet need to ensure that anybody else takes the BlockingConnection.
681
        //
682
        // TODO: It is possibly too easy to forget to set this,
683
        // or to set it to a less "alerty" value.  Refactoring might help;
684
        // see discussion at
685
        // https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2258#note_3047267
686
8170
        let mut should_alert = AlertWhom::Nobody;
687

            
688
8170
        let mut state: &mut ReceiverState = &mut state_lock;
689

            
690
        // Initialize `this_ent` to our own entry in the pending table.
691
8170
        let mut this_ent = match queue_id.get_queue_mut(&mut state.pending) {
692
8170
            Ok(ent) => ent,
693
            Err(err) => return (Err(err), state_lock, should_alert),
694
        };
695

            
696
886
        let mut conn = loop {
697
            // Note: It might be nice to use a hash_map::Entry here, but it
698
            // doesn't really work the way we want.  The `entry()` API is always
699
            // ready to insert, and requires that we clone `id`.  But what we
700
            // want in this case is something that would give us a .remove()able
701
            // Entry only if one is present.
702
16184
            if this_ent.waiter.is_some() {
703
                // This is an internal error; nobody should be able to cause this.
704
                return (Err(ProtoError::DuplicateWait), state_lock, should_alert);
705
16184
            }
706

            
707
16184
            if let Some(ready) = this_ent.pop_next_msg(&state.fatal) {
708
                // There is a reply for us, or a fatal error.
709
7284
                return (ready.map_err(ProtoError::from), state_lock, should_alert);
710
8900
            }
711

            
712
            // If we reach this point, we are about to either take the connection or
713
            // register a cv.  This means that when we return, we need to make
714
            // sure that at least one other cv gets notified.
715
8900
            should_alert = AlertWhom::Anybody;
716

            
717
8900
            if let Some(r) = state.conn.take() {
718
                // Nobody else is polling; we have to do it.
719
886
                break r;
720
8014
            }
721

            
722
            // Somebody else is polling; register a condvar.
723
8014
            let cv = Arc::new(Condvar::new());
724
8014
            this_ent.waiter = Some(Arc::clone(&cv));
725

            
726
8014
            state_lock = cv.wait(state_lock).expect("poisoned lock");
727
8014
            state = &mut state_lock;
728
            // Restore `this_ent`...
729
8014
            let e = match queue_id.get_queue_mut(&mut state.pending) {
730
8014
                Ok(ent) => ent,
731
                Err(err) => return (Err(err), state_lock, should_alert),
732
            };
733
8014
            this_ent = e;
734
            // ... And un-register our condvar.
735
8014
            this_ent.waiter = None;
736

            
737
            // We have been notified: either there is a reply for us,
738
            // or we are supposed to take the connection.  We'll find out on our
739
            // next time through the loop.
740
        };
741

            
742
886
        let (result, mut state_lock, should_alert) =
743
886
            self.read_until_message_for(state_lock, &mut conn, queue_id);
744
        // Put the connection back.
745
886
        state_lock.conn = Some(conn);
746

            
747
886
        (result.map_err(ProtoError::from), state_lock, should_alert)
748
8170
    }
749

            
750
    /// Interact with `conn`, writing any queued messages,
751
    /// reading messages, and
752
    /// delivering them as appropriate, until we find one for the queue `queue_id`
753
    /// or a fatal error occurs.
754
    ///
755
    /// Return that message or error, along with a `MutexGuard`.
756
    ///
757
    /// The caller is responsible for restoring the following state before
758
    /// dropping the `MutexGuard`:
759
    ///
760
    /// - Putting `conn` back into the `conn` field.
761
    /// - Other invariants as discussed in wait_on_message_for_impl.
762
    #[allow(clippy::type_complexity)]
763
886
    fn read_until_message_for<'a, Q: QueueId>(
764
886
        &'a self,
765
886
        mut state_lock: MutexGuard<'a, ReceiverState>,
766
886
        conn: &mut BlockingConnection,
767
886
        queue_id: &Q,
768
886
    ) -> (
769
886
        Result<(Q::UserTag, ValidatedResponse), ShutdownError>,
770
886
        MutexGuard<'a, ReceiverState>,
771
886
        AlertWhom,
772
886
    ) {
773
        loop {
774
            // Importantly, we drop the state lock while we are polling.
775
            // This is okay, since all our invariants should hold at this point.
776
8088
            drop(state_lock);
777

            
778
8088
            let result = match conn.interact() {
779
2
                Err(e) => Err(ShutdownError::Read(Arc::new(e))),
780
                Ok(None) => Err(ShutdownError::ConnectionClosed),
781
8086
                Ok(Some(m)) => m.try_validate().map_err(ShutdownError::from),
782
            };
783

            
784
8088
            state_lock = self.state.lock().expect("poisoned lock");
785
8088
            let state = &mut state_lock;
786

            
787
8088
            let response = match result {
788
8082
                Ok(m) => m,
789
6
                Err(e) => {
790
                    // This is a fatal error on the whole connection.
791
                    //
792
                    // If it's the first one encountered, queue the error.
793
                    // In any case, return it.
794
6
                    if state.fatal.is_none() {
795
6
                        state.fatal = Some(e.clone());
796
6
                    }
797
6
                    return (Err(e), state_lock, AlertWhom::Everybody);
798
                }
799
            };
800

            
801
8082
            match queue_id.response_disposition(&mut state.pending, &response) {
802
880
                ResponseDisposition::Return(tag) => {
803
                    // This only is for us, so there's no need to alert anybody specific
804
                    // or queue it.
805
880
                    return (Ok((tag, response)), state_lock, AlertWhom::Anybody);
806
                }
807
7202
                ResponseDisposition::ForwardWaiting(queue) => {
808
7202
                    queue.push_back_and_alert((), response);
809
7202
                }
810
                ResponseDisposition::ForwardPollable(tag, queue) => {
811
                    queue.push_back_and_alert(tag, response);
812
                }
813
                ResponseDisposition::Ignore => {
814
                    // Nothing wanted this response any longer.
815
                    // _Probably_ this means that we decided to cancel the
816
                    // request but Arti sent this response before it handled
817
                    // our cancellation.
818
                }
819
            }
820
        }
821
886
    }
822
}
823

            
824
/// Marker value reported by [`RpcPoll::poll`] to say that it cannot make any
/// further progress until the underlying connection has more data
/// to read or write.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct WouldBlock;
829

            
830
impl RpcPoll {
831
    #[cfg(unix)]
832
    /// If possible, return a fd to use with an underlying event-driven IO code.
833
    ///
834
    /// This implementation fails if the underlying connection to the Arti RPC server
835
    /// is _not_ implemented via an fd.
836
    /// This is not possible in the current implementation,
837
    /// but may become possible in the future.
838
    /// Applications should consider this a fatal error.
839
    pub fn try_as_fd(&self) -> std::io::Result<std::os::fd::BorrowedFd<'_>> {
840
        self.nbconn.try_as_handle()
841
    }
842

            
843
    #[cfg(windows)]
844
    /// If possible, return a SOCKET to use with an underlying event-driven IO code.
845
    ///
846
    /// This implementation fails if the underlying connection to the Arti RPC server
847
    /// is _not_ implemented via a SOCKET.
848
    /// This is not possible in the current implementation,
849
    /// but may become possible in the future.
850
    /// Applications should consider this a fatal error.
851
    pub fn try_as_socket(&self) -> std::io::Result<std::os::windows::io::BorrowedSocket<'_>> {
852
        self.nbconn.try_as_handle()
853
    }
854

            
855
    /// Return true iff this [`RpcPoll`] currently wants to write
856
    ///
857
    /// If this returns true, the RPC library user should invoke [`RpcPoll::poll`]
858
    /// when the underlying connection is ready to write.
859
    ///
860
    /// See [`Eventloop`] for full usage information.
861
    ///
862
    /// Changes to the return value of this function correspond to calls
863
    /// to the methods on [`EventLoop`].
864
    ///
865
    /// A returned `false` value can be invalidated by calls to [`RpcConn::submit`].
866
    ///
867
    /// A returned `true` value can be invalidated by calls to [`RpcPoll::poll`].
868
    ///
869
    /// [`EventLoop`]: crate::EventLoop
870
    pub fn wants_to_write(&self) -> bool {
871
        self.nbconn.wants_to_write()
872
    }
873

            
874
    /// Handle IO for the associated RPC connection, without blocking.
875
    ///
876
    /// This method reads and writes data from the RPC server,
877
    /// until either:
878
    ///
879
    ///   * A response is available to a request created with [`RpcConn::submit`];
880
    ///     in which case, `RpcPoll::poll` returns that response.
881
    ///
882
    ///   * No further progress can be made without blocking;
883
    ///     in which case `RpcPoll::poll` returns [`WouldBlock`].
884
    ///
885
    /// This is used in conjunction with `EventLoop` and/or `wants_to_write`;
886
    /// see [the `EventLoop` documentation] for details.
887
    pub fn poll(&mut self) -> Result<Result<(UserTag, AnyResponse), WouldBlock>, ProtoError> {
888
        use crate::ll_conn::PollStatus;
889
        // We try reading _and_ writing regardless; it won't hurt anything.
890
        loop {
891
            let r = self.nbconn.interact_once();
892
            let response = match r {
893
                Ok(PollStatus::Msg(m)) => m.try_validate().map_err(ShutdownError::from),
894
                Ok(PollStatus::Closed) => return Err(ShutdownError::ConnectionClosed.into()),
895
                Ok(PollStatus::WouldBlock) => return Ok(Err(WouldBlock)),
896
                Err(io_error) => return Err(ShutdownError::Read(Arc::new(io_error)).into()),
897
            };
898

            
899
            let mut state = self.receiver.state.lock().expect("Poisoned lock");
900

            
901
            let response = match response {
902
                Ok(m) => m,
903
                Err(e) => {
904
                    if state.fatal.is_none() {
905
                        state.fatal = Some(e.clone());
906
                        state.pending.alert_everybody();
907
                    }
908
                    return Err(e.into());
909
                }
910
            };
911

            
912
            match PolledRequests.response_disposition(&mut state.pending, &response) {
913
                ResponseDisposition::Return(tag) => {
914
                    return Ok(Ok((tag, AnyResponse::from_validated(response))));
915
                }
916
                ResponseDisposition::Ignore => {}
917
                ResponseDisposition::ForwardWaiting(response_queue) => {
918
                    response_queue.push_back_and_alert((), response);
919
                }
920
                ResponseDisposition::ForwardPollable(_, _) => panic!("This should be unreachable"),
921
            };
922
            drop(state);
923
        }
924
    }
925
}