1
//! Middle-level API for RPC connections
2
//!
3
//! This module focuses around the `RpcConn` type, which supports sending RPC requests
4
//! and matching them with their responses.
5

            
6
use std::{
7
    io::{self},
8
    sync::{Arc, Mutex},
9
};
10

            
11
use crate::msgs::{
12
    AnyRequestId, ObjectId,
13
    request::InvalidRequestError,
14
    response::{ResponseKind, RpcError, ValidatedResponse},
15
};
16

            
17
mod auth;
18
mod builder;
19
mod connimpl;
20
mod stream;
21

            
22
use crate::util::Utf8CString;
23
pub use builder::{BuilderError, ConnPtDescription, RpcConnBuilder};
24
pub use connimpl::RpcConn;
25
use serde::{Deserialize, de::DeserializeOwned};
26
pub use stream::StreamError;
27
use tor_rpc_connect::{HasClientErrorAction, auth::cookie::CookieAccessError};
28

            
29
/// A handle to an open request.
30
///
31
/// These handles are created with [`RpcConn::execute_with_handle`].
32
///
33
/// Note that dropping a RequestHandle does not cancel the associated request:
34
/// it will continue running, but you won't have a way to receive updates from it.
35
/// To cancel a request, use [`RpcConn::cancel`].
36
#[derive(educe::Educe)]
37
#[educe(Debug)]
38
pub struct RequestHandle {
39
    /// The underlying `Receiver` that we'll use to get updates for this request
40
    ///
41
    /// It's wrapped in a `Mutex` to prevent concurrent calls to `Receiver::wait_on_message_for`.
42
    //
43
    // NOTE: As an alternative to using a Mutex here, we _could_ remove
44
    // the restriction from `wait_on_message_for` that says that only one thread
45
    // may be waiting on a given request ID at once.  But that would introduce
46
    // complexity to the implementation,
47
    // and it's not clear that the benefit would be worth it.
48
    #[educe(Debug(ignore))]
49
    conn: Mutex<Arc<connimpl::Receiver>>,
50
    /// The ID of this request.
51
    id: AnyRequestId,
52
}
53

            
54
// TODO RPC: Possibly abolish these types.
55
//
56
// I am keeping this for now because it makes it more clear that we can never reinterpret
57
// a success as an update or similar.
58
//
59
// I am not at all pleased with these types; we should revise them.
60
//
61
// TODO RPC: Possibly, all of these should be reconstructed
62
// from their serde_json::Values rather than forwarded verbatim.
63
// (But why would we want our json to be more canonical than arti's? See #1491.)
64
//
65
// DODGY TYPES BEGIN: TODO RPC
66

            
67
/// A Success Response from Arti, indicating that a request was successful.
68
///
69
/// This is the complete message, including `id` and `result` fields.
70
//
71
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
72
// TODO RPC: check that the newline invariant is enforced in constructors.
73
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
74
#[as_ref(forward)]
75
pub struct SuccessResponse(Utf8CString);
76

            
77
impl SuccessResponse {
78
    /// Helper: Decode the `result` field of this response as an instance of D.
79
2040
    fn decode<D: DeserializeOwned>(&self) -> Result<D, serde_json::Error> {
80
        /// Helper object for decoding the "result" field.
81
        #[derive(Deserialize)]
82
        struct Response<R> {
83
            /// The decoded value.
84
            result: R,
85
        }
86
2040
        let response: Response<D> = serde_json::from_str(self.as_ref())?;
87
2040
        Ok(response.result)
88
2040
    }
89
}
90

            
91
/// An Update Response from Arti, with information about the progress of a request.
92
///
93
/// This is the complete message, including `id` and `update` fields.
94
//
95
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
96
// TODO RPC: check that the newline invariant is enforced in constructors.
97
// TODO RPC consider changing this to CString.
98
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
99
#[as_ref(forward)]
100
pub struct UpdateResponse(Utf8CString);
101

            
102
/// A Error Response from Arti, indicating that an error occurred.
103
///
104
/// (This is the complete message, including the `error` field.
105
/// It also an `id` if it
106
/// is in response to a request; but not if it is a fatal protocol error.)
107
//
108
// Invariant: Does not contain a NUL. (Safe to convert to CString.)
109
//
110
// Invariant: This field MUST encode a response whose body is an RPC error.
111
//
112
// Otherwise the `decode` method may panic.
113
//
114
// TODO RPC: check that the newline invariant is enforced in constructors.
115
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
116
#[as_ref(forward)]
117
// TODO: If we keep this, it should implement Error.
118
pub struct ErrorResponse(Utf8CString);
119
impl ErrorResponse {
120
    /// Construct an ErrorResponse from the Error reply.
121
    ///
122
    /// This not a From impl because we want it to be crate-internal.
123
2062
    pub(crate) fn from_validated_string(s: Utf8CString) -> Self {
124
2062
        ErrorResponse(s)
125
2062
    }
126

            
127
    /// Convert this response into an internal error in response to `cmd`.
128
    ///
129
    /// This is only appropriate when the error cannot be caused because of user behavior.
130
    pub(crate) fn internal_error(&self, cmd: &str) -> ProtoError {
131
        ProtoError::InternalRequestFailed(UnexpectedReply {
132
            request: cmd.to_string(),
133
            reply: self.to_string(),
134
            problem: UnexpectedReplyProblem::ErrorNotExpected,
135
        })
136
    }
137

            
138
    /// Try to interpret this response as an [`RpcError`].
139
2058
    pub fn decode(&self) -> RpcError {
140
2058
        crate::msgs::response::try_decode_response_as_err(self.0.as_ref())
141
2058
            .expect("Could not decode response that was already decoded as an error?")
142
2058
            .expect("Could not extract error from response that was already decoded as an error?")
143
2058
    }
144
}
145

            
146
impl std::fmt::Display for ErrorResponse {
147
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148
        let e = self.decode();
149
        write!(f, "Peer said {:?}", e.message())
150
    }
151
}
152

            
153
/// A final response -- that is, the last one that we expect to receive for a request.
154
///
155
type FinalResponse = Result<SuccessResponse, ErrorResponse>;
156

            
157
/// Any of the three types of Arti responses.
158
#[derive(Clone, Debug)]
159
#[allow(clippy::exhaustive_structs)]
160
pub enum AnyResponse {
161
    /// The request has succeeded; no more response will be given.
162
    Success(SuccessResponse),
163
    /// The request has failed; no more response will be given.
164
    Error(ErrorResponse),
165
    /// An incremental update; more messages may arrive.
166
    Update(UpdateResponse),
167
}
168
// TODO RPC: DODGY TYPES END.
169

            
170
impl AnyResponse {
171
    /// Convert `v` into `AnyResponse`.
172
8056
    fn from_validated(v: ValidatedResponse) -> Self {
173
        // TODO RPC, Perhaps unify AnyResponse with ValidatedResponse, once we are sure what
174
        // AnyResponse should look like.
175
8056
        match v.meta.kind {
176
2058
            ResponseKind::Error => AnyResponse::Error(ErrorResponse::from_validated_string(v.msg)),
177
2040
            ResponseKind::Success => AnyResponse::Success(SuccessResponse(v.msg)),
178
3958
            ResponseKind::Update => AnyResponse::Update(UpdateResponse(v.msg)),
179
        }
180
8056
    }
181

            
182
    /// Consume this `AnyResponse`, and return its internal string.
183
    #[cfg(feature = "ffi")]
184
    pub(crate) fn into_string(self) -> Utf8CString {
185
        match self {
186
            AnyResponse::Success(m) => m.into(),
187
            AnyResponse::Error(m) => m.into(),
188
            AnyResponse::Update(m) => m.into(),
189
        }
190
    }
191
}
192

            
193
impl RpcConn {
194
    /// Return the ObjectId for the negotiated Session.
195
    ///
196
    /// Nearly all RPC methods require a Session, or some other object
197
    /// accessed via the session.
198
    ///
199
    /// (This function will only return None if no authentication has been performed.
200
    /// TODO RPC: It is not currently possible to make an unauthenticated connection.)
201
    pub fn session(&self) -> Option<&ObjectId> {
202
        self.session.as_ref()
203
    }
204

            
205
    /// Run a command, and wait for success or failure.
206
    ///
207
    /// Note that this function will return `Err(.)` only if sending the command or getting a
208
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
209
    /// this function returns `Ok(Err(.))`.
210
    ///
211
    /// Note that the command does not need to include an `id` field.  If you omit it,
212
    /// one will be generated.
213
98
    pub fn execute(&self, cmd: &str) -> Result<FinalResponse, ProtoError> {
214
98
        let hnd = self.execute_with_handle(cmd)?;
215
68
        hnd.wait()
216
98
    }
217

            
218
    /// Helper for executing internally-generated requests and decoding their results.
219
    ///
220
    /// Behaves like `execute`, except on success, where it tries to decode the `result` field
221
    /// of the response as a `T`.
222
    ///
223
    /// Use this method in cases where it's reasonable for Arti to sometimes return an RPC error:
224
    /// in other words, where it's not necessarily a programming error or version mismatch.
225
    ///
226
    /// Don't use this for user-generated requests: it will misreport unexpected replies
227
    /// as internal errors.
228
2
    pub(crate) fn execute_internal<T: DeserializeOwned>(
229
2
        &self,
230
2
        cmd: &str,
231
2
    ) -> Result<Result<T, ErrorResponse>, ProtoError> {
232
2
        match self.execute(cmd)? {
233
2
            Ok(success) => match success.decode::<T>() {
234
2
                Ok(result) => Ok(Ok(result)),
235
                Err(json_error) => Err(ProtoError::InternalRequestFailed(UnexpectedReply {
236
                    request: cmd.to_string(),
237
                    reply: Utf8CString::from(success).to_string(),
238
                    problem: UnexpectedReplyProblem::CannotDecode(Arc::new(json_error)),
239
                })),
240
            },
241
            Err(error) => Ok(Err(error)),
242
        }
243
2
    }
244

            
245
    /// Helper for executing internally-generated requests and decoding their results.
246
    ///
247
    /// Behaves like `execute_internal`, except that it treats any RPC error reply
248
    /// as an internal error or version mismatch.
249
    ///
250
    /// Don't use this for user-generated requests, or for requests that can fail because of
251
    /// incorrect user inputs: it will misreport failures in those requests as internal errors.
252
2
    pub(crate) fn execute_internal_ok<T: DeserializeOwned>(
253
2
        &self,
254
2
        cmd: &str,
255
2
    ) -> Result<T, ProtoError> {
256
2
        match self.execute_internal(cmd)? {
257
2
            Ok(v) => Ok(v),
258
            Err(err_response) => Err(err_response.internal_error(cmd)),
259
        }
260
2
    }
261

            
262
    /// Cancel a request by ID.
263
    pub fn cancel(&self, request_id: &AnyRequestId) -> Result<(), ProtoError> {
264
        /// Arguments to an `rpc::cancel` request.
265
        #[derive(serde::Serialize, Debug)]
266
        struct CancelParams<'a> {
267
            /// The request to cancel.
268
            request_id: &'a AnyRequestId,
269
        }
270

            
271
        let request = crate::msgs::request::Request::new(
272
            ObjectId::connection_id(),
273
            "rpc:cancel",
274
            CancelParams { request_id },
275
        );
276
        match self.execute_internal::<EmptyReply>(&request.encode()?)? {
277
            Ok(EmptyReply {}) => Ok(()),
278
            Err(_) => Err(ProtoError::RequestCompleted),
279
        }
280
    }
281

            
282
    /// Like `execute`, but don't wait.  This lets the caller see the
283
    /// request ID and  maybe cancel it.
284
4194
    pub fn execute_with_handle(&self, cmd: &str) -> Result<RequestHandle, ProtoError> {
285
4194
        self.send_request(cmd)
286
4194
    }
287
    /// As execute(), but run update_cb for every update we receive.
288
4096
    pub fn execute_with_updates<F>(
289
4096
        &self,
290
4096
        cmd: &str,
291
4096
        mut update_cb: F,
292
4096
    ) -> Result<FinalResponse, ProtoError>
293
4096
    where
294
4096
        F: FnMut(UpdateResponse) + Send + Sync,
295
    {
296
4096
        let hnd = self.execute_with_handle(cmd)?;
297
        loop {
298
8054
            match hnd.wait_with_updates()? {
299
2038
                AnyResponse::Success(s) => return Ok(Ok(s)),
300
2058
                AnyResponse::Error(e) => return Ok(Err(e)),
301
3958
                AnyResponse::Update(u) => update_cb(u),
302
            }
303
        }
304
4096
    }
305

            
306
    /// Helper: Tell Arti to release `obj`.
307
    ///
308
    /// Do not use this method for a user-provided object ID:
309
    /// It gives an internal error if the object does not exist.
310
    pub(crate) fn release_obj(&self, obj: ObjectId) -> Result<(), ProtoError> {
311
        let release_request = crate::msgs::request::Request::new(obj, "rpc:release", NoParams {});
312
        let _empty_response: EmptyReply = self.execute_internal_ok(&release_request.encode()?)?;
313
        Ok(())
314
    }
315

            
316
    // TODO RPC: shutdown() on the socket on Drop.
317
}
318

            
319
impl RequestHandle {
320
    /// Return the ID of this request, to help cancelling it.
321
    pub fn id(&self) -> &AnyRequestId {
322
        &self.id
323
    }
324
    /// Wait for success or failure, and return what happened.
325
    ///
326
    /// (Ignores any update messages that are received.)
327
    ///
328
    /// Note that this function will return `Err(.)` only if sending the command or getting a
329
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
330
    /// this function returns `Ok(Err(.))`.
331
68
    pub fn wait(self) -> Result<FinalResponse, ProtoError> {
332
        loop {
333
68
            match self.wait_with_updates()? {
334
2
                AnyResponse::Success(s) => return Ok(Ok(s)),
335
                AnyResponse::Error(e) => return Ok(Err(e)),
336
                AnyResponse::Update(_) => {}
337
            }
338
        }
339
68
    }
340
    /// Wait for the next success, failure, or update from this handle.
341
    ///
342
    /// Note that this function will return `Err(.)` only if sending the command or getting a
343
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
344
    /// this function returns `Ok(AnyResponse::Error(.))`.
345
    ///
346
    /// You may call this method on the same `RequestHandle` from multiple threads.
347
    /// If you do so, those calls will receive responses (or errors) in an unspecified order.
348
    ///
349
    /// If this function returns Success or Error, then you shouldn't call it again.
350
    /// All future calls to this function will fail with `CmdError::RequestCancelled`.
351
    /// (TODO RPC: Maybe rename that error.)
352
8122
    pub fn wait_with_updates(&self) -> Result<AnyResponse, ProtoError> {
353
8122
        let conn = self.conn.lock().expect("Poisoned lock");
354
8122
        let validated = conn.wait_on_message_for(&self.id)?;
355

            
356
8056
        Ok(AnyResponse::from_validated(validated))
357
8122
    }
358

            
359
    // TODO RPC: Sketch out how we would want to do this in an async world,
360
    // or with poll
361
}
362

            
363
/// An error (or other condition) that has caused an RPC connection to shut down.
364
#[derive(Clone, Debug, thiserror::Error)]
365
#[non_exhaustive]
366
pub enum ShutdownError {
367
    // TODO nb: Read/Write are no longer well separated in the API.
368
    //
369
    /// Io error occurred while reading.
370
    #[error("Unable to read response")]
371
    Read(#[source] Arc<io::Error>),
372
    /// Io error occurred while writing.
373
    #[error("Unable to write request")]
374
    Write(#[source] Arc<io::Error>),
375
    /// Something was wrong with Arti's responses; this is a protocol violation.
376
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0:?}")]
377
    ProtocolViolated(String),
378
    /// Arti has told us that we violated the protocol somehow.
379
    #[error("Arti reported a fatal error: {0:?}")]
380
    ProtocolViolationReport(ErrorResponse),
381
    /// The underlying connection closed.
382
    ///
383
    /// This probably means that Arti has shut down.
384
    #[error("Connection closed")]
385
    ConnectionClosed,
386
}
387

            
388
impl From<crate::msgs::response::DecodeResponseError> for ShutdownError {
    fn from(value: crate::msgs::response::DecodeResponseError) -> Self {
        use crate::msgs::response::DecodeResponseError as D;
        match value {
            // Arti told us we broke the protocol.
            D::Fatal(rpc_err) => Self::ProtocolViolationReport(rpc_err),
            // Arti's own message broke the protocol.
            D::JsonProtocolViolation(e) => Self::ProtocolViolated(e.to_string()),
            D::ProtocolViolation(s) => Self::ProtocolViolated(s.to_string()),
        }
    }
}
399

            
400
/// An error that has occurred while launching an RPC command.
401
#[derive(Clone, Debug, thiserror::Error)]
402
#[non_exhaustive]
403
pub enum ProtoError {
404
    /// The RPC connection failed, or was closed by the other side.
405
    #[error("RPC connection is shut down")]
406
    Shutdown(#[from] ShutdownError),
407

            
408
    /// There was a problem in the request we tried to send.
409
    #[error("Invalid request")]
410
    InvalidRequest(#[from] InvalidRequestError),
411

            
412
    /// We tried to send a request with an ID that was already pending.
413
    #[error("Request ID already in use.")]
414
    RequestIdInUse,
415

            
416
    /// We tried to wait for or inspect a request that had already succeeded or failed.
417
    #[error("Request has already completed (or failed)")]
418
    RequestCompleted,
419

            
420
    /// We tried to wait for the same request more than once.
421
    ///
422
    /// (This should be impossible.)
423
    #[error("Internal error: waiting on the same request more than once at a time.")]
424
    DuplicateWait,
425

            
426
    /// We got an internal error while trying to encode an RPC request.
427
    ///
428
    /// (This should be impossible.)
429
    #[error("Internal error while encoding request")]
430
    CouldNotEncode(#[source] Arc<serde_json::Error>),
431

            
432
    /// We got a response to some internally generated request that wasn't what we expected.
433
    #[error("{0}")]
434
    InternalRequestFailed(#[source] UnexpectedReply),
435
}
436

            
437
/// A set of errors encountered while trying to connect to the Arti process
438
#[derive(Clone, Debug, thiserror::Error)]
439
pub struct ConnectFailure {
440
    /// A list of all the declined connect points we encountered, and how they failed.
441
    declined: Vec<(builder::ConnPtDescription, ConnectError)>,
442
    /// A description of where we found the final error (if it's an abort.)
443
    final_desc: Option<builder::ConnPtDescription>,
444
    /// The final error explaining why we couldn't connect.
445
    ///
446
    /// This is either an abort, an AllAttemptsDeclined, or an error that prevented the
447
    /// search process from even beginning.
448
    #[source]
449
    pub(crate) final_error: ConnectError,
450
}
451

            
452
impl std::fmt::Display for ConnectFailure {
453
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454
        write!(f, "Unable to connect")?;
455
        if !self.declined.is_empty() {
456
            write!(
457
                f,
458
                " ({} attempts failed{})",
459
                self.declined.len(),
460
                if matches!(self.final_error, ConnectError::AllAttemptsDeclined) {
461
                    ""
462
                } else {
463
                    " before fatal error"
464
                }
465
            )?;
466
        }
467
        Ok(())
468
    }
469
}
470

            
471
impl ConnectFailure {
472
    /// If this attempt failed because of a fatal error that made a connect point attempt abort,
473
    /// return a description of the origin of that connect point.
474
    pub fn fatal_error_origin(&self) -> Option<&builder::ConnPtDescription> {
475
        self.final_desc.as_ref()
476
    }
477

            
478
    /// For each connect attempt that failed nonfatally, return a description of the
479
    /// origin of that connect point, and the error that caused it to fail.
480
    pub fn declined_attempt_outcomes(
481
        &self,
482
    ) -> impl Iterator<Item = (&builder::ConnPtDescription, &ConnectError)> {
483
        // Note: this map looks like a no-op, but isn't.
484
        self.declined.iter().map(|(a, b)| (a, b))
485
    }
486

            
487
    /// Return a helper type to format this error, and all of its internal errors recursively.
488
    ///
489
    /// Unlike [`tor_error::Report`], this method includes not only fatal errors, but also
490
    /// information about connect attempts that failed nonfatally.
491
    pub fn display_verbose(&self) -> ConnectFailureVerboseFmt<'_> {
492
        ConnectFailureVerboseFmt(self)
493
    }
494
}
495

            
496
/// Helper type to format a ConnectFailure along with all of its internal errors,
497
/// including non-fatal errors.
498
#[derive(Debug, Clone)]
499
pub struct ConnectFailureVerboseFmt<'a>(&'a ConnectFailure);
500

            
501
impl<'a> std::fmt::Display for ConnectFailureVerboseFmt<'a> {
502
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503
        use tor_error::ErrorReport as _;
504
        writeln!(f, "{}:", self.0)?;
505
        for (idx, (origin, error)) in self.0.declined_attempt_outcomes().enumerate() {
506
            writeln!(f, "  {}. {}: {}", idx + 1, origin, error.report())?;
507
        }
508
        if let Some(origin) = self.0.fatal_error_origin() {
509
            writeln!(
510
                f,
511
                "  {}. [FATAL] {}: {}",
512
                self.0.declined.len() + 1,
513
                origin,
514
                self.0.final_error.report()
515
            )?;
516
        } else {
517
            writeln!(f, "  - {}", self.0.final_error.report())?;
518
        }
519
        Ok(())
520
    }
521
}
522

            
523
/// An error while trying to connect to the Arti process.
524
#[derive(Clone, Debug, thiserror::Error)]
525
#[non_exhaustive]
526
pub enum ConnectError {
527
    /// Unable to parse connect points from an environment variable.
528
    #[error("Cannot parse connect points from environment variable")]
529
    BadEnvironment,
530
    /// We were unable to load and/or parse a given connect point.
531
    #[error("Unable to load and parse connect point")]
532
    CannotParse(#[from] tor_rpc_connect::load::LoadError),
533
    /// The path used to specify a connect file couldn't be resolved.
534
    #[error("Unable to resolve connect point path")]
535
    CannotResolvePath(#[source] tor_config_path::CfgPathError),
536
    /// A parsed connect point couldn't be resolved.
537
    #[error("Unable to resolve connect point")]
538
    CannotResolveConnectPoint(#[from] tor_rpc_connect::ResolveError),
539
    /// IO error while connecting to Arti.
540
    #[error("Unable to make a connection")]
541
    CannotConnect(#[from] tor_rpc_connect::ConnectError),
542
    /// The connect point told us to connect via a type of stream we don't know how to support.
543
    #[error("Connect point stream type was unsupported")]
544
    StreamTypeUnsupported,
545
    /// Opened a connection, but didn't get a banner message.
546
    ///
547
    /// (This isn't a `BadMessage`, since it is likelier to represent something that isn't
548
    /// pretending to be Arti at all than it is to be a malfunctioning Arti.)
549
    #[error("Did not receive expected banner message upon connecting")]
550
    InvalidBanner,
551
    /// All attempted connect points were declined, and none were aborted.
552
    #[error("All connect points were declined (or there were none)")]
553
    AllAttemptsDeclined,
554
    /// A connect file or directory was given as a relative path.
555
    /// (Only absolute paths are supported).
556
    #[error("Connect file was given as a relative path.")]
557
    RelativeConnectFile,
558
    /// One of our authentication messages received an error.
559
    #[error("Received an error while trying to authenticate: {0}")]
560
    AuthenticationFailed(ErrorResponse),
561
    /// The connect point uses an RPC authentication type we don't support.
562
    #[error("Authentication type is not supported")]
563
    AuthenticationNotSupported,
564
    /// We couldn't decode one of the responses we got.
565
    #[error("Message not in expected format")]
566
    BadMessage(#[source] Arc<serde_json::Error>),
567
    /// A protocol error occurred during negotiations.
568
    #[error("Error while negotiating with Arti")]
569
    ProtoError(#[from] ProtoError),
570
    /// The server thinks it is listening on an address where we don't expect to find it.
571
    /// This can be misconfiguration or an attempted MITM attack.
572
    #[error("We connected to the server at {ours}, but it thinks it's listening at {theirs}")]
573
    ServerAddressMismatch {
574
        /// The address we think the server has
575
        ours: String,
576
        /// The address that the server says it has.
577
        theirs: String,
578
    },
579
    /// The server tried to prove knowledge of a cookie file, but its proof was incorrect.
580
    #[error("Server's cookie MAC was not as expected.")]
581
    CookieMismatch,
582
    /// We were unable to access the configured cookie file.
583
    #[error("Unable to load secret cookie value")]
584
    LoadCookie(#[from] CookieAccessError),
585
}
586

            
587
impl HasClientErrorAction for ConnectError {
588
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
589
        use ConnectError as E;
590
        use tor_rpc_connect::ClientErrorAction as A;
591
        match self {
592
            E::BadEnvironment => A::Abort,
593
            E::CannotParse(e) => e.client_action(),
594
            E::CannotResolvePath(_) => A::Abort,
595
            E::CannotResolveConnectPoint(e) => e.client_action(),
596
            E::CannotConnect(e) => e.client_action(),
597
            E::StreamTypeUnsupported => A::Decline,
598
            E::InvalidBanner => A::Decline,
599
            E::RelativeConnectFile => A::Abort,
600
            E::AuthenticationFailed(_) => A::Decline,
601
            // TODO RPC: Is this correct?  This error can also occur when
602
            // we are talking to something other than an RPC server.
603
            E::BadMessage(_) => A::Abort,
604
            E::ProtoError(e) => e.client_action(),
605
            E::AllAttemptsDeclined => A::Abort,
606
            E::AuthenticationNotSupported => A::Decline,
607
            E::ServerAddressMismatch { .. } => A::Abort,
608
            E::CookieMismatch => A::Abort,
609
            E::LoadCookie(e) => e.client_action(),
610
        }
611
    }
612
}
613

            
614
impl HasClientErrorAction for ProtoError {
615
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
616
        use ProtoError as E;
617
        use tor_rpc_connect::ClientErrorAction as A;
618
        match self {
619
            E::Shutdown(_) => A::Decline,
620
            E::InternalRequestFailed(_) => A::Decline,
621
            // These are always internal errors if they occur while negotiating a connection to RPC,
622
            // which is the context we care about for `HasClientErrorAction`.
623
            E::InvalidRequest(_)
624
            | E::RequestIdInUse
625
            | E::RequestCompleted
626
            | E::DuplicateWait
627
            | E::CouldNotEncode(_) => A::Abort,
628
        }
629
    }
630
}
631

            
632
/// In response to a request that we generated internally,
633
/// Arti gave a reply that we did not understand.
634
///
635
/// This could be due to a bug in this library, a bug in Arti,
636
/// or a compatibility issue between the two.
637
#[derive(Clone, Debug, thiserror::Error)]
638
#[error("In response to our request {request:?}, Arti gave the unexpected reply {reply:?}")]
639
pub struct UnexpectedReply {
640
    /// The request we sent.
641
    request: String,
642
    /// The response we got.
643
    reply: String,
644
    /// What was wrong with the response.
645
    #[source]
646
    problem: UnexpectedReplyProblem,
647
}
648

            
649
/// Underlying reason for an UnexpectedReply
650
#[derive(Clone, Debug, thiserror::Error)]
651
enum UnexpectedReplyProblem {
652
    /// There was a json failure while trying to decode the response:
653
    /// the result type was not what we expected.
654
    #[error("Cannot decode as correct JSON type")]
655
    CannotDecode(Arc<serde_json::Error>),
656
    /// Arti replied with an RPC error in a context no error should have been possible.
657
    #[error("Unexpected error")]
658
    ErrorNotExpected,
659
}
660

            
661
/// Arguments to a request that takes no parameters.
662
#[derive(serde::Serialize, Debug)]
663
struct NoParams {}
664

            
665
/// A reply with no data.
666
#[derive(serde::Deserialize, Debug)]
667
struct EmptyReply {}
668

            
669
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::{sync::atomic::AtomicUsize, thread, time::Duration};

    use io::{BufRead as _, BufReader, Write as _};
    use rand::{Rng as _, SeedableRng as _, seq::SliceRandom as _};
    use tor_basic_utils::{RngExt as _, test_rng::testing_rng};

    use crate::{
        msgs::request::{JsonMap, Request, ValidatedRequest},
        nb_stream::PollingStream,
    };

    use super::*;

    /// helper: Return a dummy RpcConn, along with a socketpair for it to talk to.
    fn dummy_connected() -> (RpcConn, crate::testing::SocketpairStream) {
        let (s1, s2) = crate::testing::construct_socketpair().unwrap();
        let conn = RpcConn::new(PollingStream::new(s1).unwrap());

        (conn, s2)
    }

    /// helper: Write `v` to `w` as one line of JSON, terminated by a newline.
    fn write_val(w: &mut impl io::Write, v: &serde_json::Value) {
        let mut enc = serde_json::to_string(v).unwrap();
        enc.push('\n');
        w.write_all(enc.as_bytes()).unwrap();
    }

    #[test]
    fn simple() {
        let (conn, sock) = dummy_connected();

        let user_thread = thread::spawn(move || {
            let response1 = conn
                .execute_internal_ok::<JsonMap>(
                    r#"{"obj":"fred","method":"arti:x-frob","params":{}}"#,
                )
                .unwrap();
            (response1, conn)
        });

        let fake_arti_thread = thread::spawn(move || {
            let mut sock = BufReader::new(sock);
            let mut s = String::new();
            let _len = sock.read_line(&mut s).unwrap();
            let request = ValidatedRequest::from_string_strict(s.as_ref()).unwrap();
            let response = serde_json::json!({
                "id": request.id().clone(),
                "result": { "xyz" : 3 }
            });
            write_val(sock.get_mut(), &response);
            sock // prevent close
        });

        let _sock = fake_arti_thread.join().unwrap();
        let (map, _conn) = user_thread.join().unwrap();
        assert_eq!(map.get("xyz"), Some(&serde_json::Value::Number(3.into())));
    }

    #[test]
    fn complex() {
        use std::sync::atomic::Ordering::SeqCst;
        let n_threads = 16;
        let n_commands_per_thread = 128;
        let n_commands_total = n_threads * n_commands_per_thread;
        let n_completed = Arc::new(AtomicUsize::new(0));

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        let mut rng = testing_rng();

        // -------
        // User threads: Make a bunch of requests.
        for th_idx in 0..n_threads {
            let conn = Arc::clone(&conn);
            let n_completed = Arc::clone(&n_completed);
            let mut rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
            let th = thread::spawn(move || {
                for cmd_idx in 0..n_commands_per_thread {
                    // We are spawning a bunch of worker threads, each of which will run a number of
                    // commands in sequence.  Each command will be a request that gets optional
                    // updates, and an error or a success.
                    // We will double-check that each request gets the response it asked for.
                    let s = format!("{}:{}", th_idx, cmd_idx);
                    let want_updates: bool = rng.random();
                    let want_failure: bool = rng.random();
                    let req = serde_json::json!({
                        "obj":"fred",
                        "method":"arti:x-echo",
                        "meta": {
                            "updates": want_updates,
                        },
                        "params": {
                            "val": &s,
                            "fail": want_failure,
                        },
                    });
                    let req = serde_json::to_string(&req).unwrap();

                    // Wait for a final response, processing updates if we asked for them.
                    let mut n_updates = 0;
                    let outcome = conn
                        .execute_with_updates(&req, |_update| {
                            n_updates += 1;
                        })
                        .unwrap();
                    assert_eq!(n_updates > 0, want_updates);

                    // See if we liked the final response.
                    if want_failure {
                        let e = outcome.unwrap_err().decode();
                        assert_eq!(e.message(), "You asked me to fail");
                        assert_eq!(i32::from(e.code()), 33);
                        assert_eq!(
                            e.kinds_iter().collect::<Vec<_>>(),
                            vec!["Example".to_string()]
                        );
                    } else {
                        let success = outcome.unwrap();
                        let map = success.decode::<JsonMap>().unwrap();
                        assert_eq!(map.get("echo"), Some(&serde_json::Value::String(s)));
                    }
                    n_completed.fetch_add(1, SeqCst);
                    if rng.random::<f32>() < 0.02 {
                        thread::sleep(Duration::from_millis(3));
                    }
                }
            });
            user_threads.push(th);
        }

        #[derive(serde::Deserialize, Debug)]
        struct Echo {
            val: String,
            fail: bool,
        }

        // -----
        // Worker thread: handles user requests.
        let worker_rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
        let worker_thread = thread::spawn(move || {
            let mut rng = worker_rng;
            let mut sock = BufReader::new(sock);
            let mut pending: Vec<Request<Echo>> = Vec::new();
            let mut n_received = 0;

            // How many requests do we buffer before we shuffle them and answer them out-of-order?
            let scramble_factor = 7;
            // After receiving how many requests do we stop shuffling requests?
            //
            // (Our shuffling algorithm can deadlock us otherwise.)
            let scramble_threshold =
                n_commands_total - (n_commands_per_thread + 1) * scramble_factor;

            'outer: loop {
                let flush_pending_at = if n_received >= scramble_threshold {
                    1
                } else {
                    scramble_factor
                };

                // Queue a handful of requests in "pending"
                while pending.len() < flush_pending_at {
                    let mut buf = String::new();
                    if sock.read_line(&mut buf).unwrap() == 0 {
                        break 'outer;
                    }
                    n_received += 1;
                    let req: Request<Echo> = serde_json::from_str(&buf).unwrap();
                    pending.push(req);
                }

                // Handle the requests in "pending" in random order.
                let mut handling = std::mem::take(&mut pending);
                handling.shuffle(&mut rng);

                for req in handling {
                    if req.meta.unwrap_or_default().updates {
                        let n_updates = rng.gen_range_checked(1..4).unwrap();
                        for _ in 0..n_updates {
                            let up = serde_json::json!({
                                "id": req.id.clone(),
                                "update": {
                                    "hello": req.params.val.clone(),
                                }
                            });
                            write_val(sock.get_mut(), &up);
                        }
                    }

                    let response = if req.params.fail {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "error": { "message": "You asked me to fail", "code": 33, "kinds": ["Example"], "data": req.params.val },
                        })
                    } else {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "result": {
                                "echo": req.params.val
                            }
                        })
                    };
                    write_val(sock.get_mut(), &response);
                }
            }
        });
        drop(conn);
        for t in user_threads {
            t.join().unwrap();
        }

        worker_thread.join().unwrap();

        assert_eq!(n_completed.load(SeqCst), n_commands_total);
    }

    #[test]
    fn arti_socket_closed() {
        // Here we send a bunch of requests and then close the socket without answering them.
        //
        // Every request should get a ProtoError::Shutdown.
        let n_threads = 16;

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request, which will never be answered.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                let outcome = conn.execute(&req);

                // Any of these shutdown reasons is acceptable: which one we see
                // depends on when the connection notices that the socket is closed.
                assert!(
                    matches!(
                        &outcome,
                        Err(ProtoError::Shutdown(ShutdownError::Write(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::Read(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::ConnectionClosed))
                    ),
                    "unexpected outcome: {outcome:?}"
                );
            });
            user_threads.push(th);
        }

        drop(sock);

        for t in user_threads {
            t.join().unwrap();
        }
    }

    /// Send a bunch of requests and then send back the single message `msg`.
    ///
    /// That message should cause every outstanding request to fail with a
    /// `ProtoError`; `outcome_ok` must return true for each such error.
    fn proto_err_with_msg<F>(msg: &str, outcome_ok: F)
    where
        F: Fn(ProtoError) -> bool,
    {
        let n_threads = 16;

        let (conn, mut sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request, and returns whatever outcome it gets.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                conn.execute(&req)
            });
            user_threads.push(th);
        }

        sock.write_all(msg.as_bytes()).unwrap();

        for t in user_threads {
            let outcome = t.join().unwrap();
            let err = outcome.unwrap_err();
            // `outcome_ok` consumes the error, so describe it first for the failure message.
            let err_desc = format!("{err:?}");
            assert!(outcome_ok(err), "unexpected error: {err_desc}");
        }
    }

    #[test]
    fn syntax_error() {
        proto_err_with_msg("this is not json\n", |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolated(_))
            )
        });
    }

    #[test]
    fn fatal_error() {
        let j = serde_json::json!({
            "error":{ "message": "This test is doomed", "code": 413, "kinds": ["Example"], "data": {} },
        });
        let mut s = serde_json::to_string(&j).unwrap();
        s.push('\n');
        proto_err_with_msg(&s, |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolationReport(_))
            )
        });
    }
}