1
//! Middle-level API for RPC connections
2
//!
3
//! This module focuses on the `RpcConn` type, which supports sending RPC requests
4
//! and matching them with their responses.
5

            
6
use std::{
7
    io::{self},
8
    sync::{Arc, Mutex},
9
};
10

            
11
use crate::msgs::{
12
    AnyRequestId, ObjectId,
13
    request::InvalidRequestError,
14
    response::{ResponseKind, RpcError, ValidatedResponse},
15
};
16

            
17
mod auth;
18
mod builder;
19
mod connimpl;
20
mod stream;
21

            
22
use crate::util::Utf8CString;
23
pub use builder::{BuilderError, ConnPtDescription, RpcConnBuilder};
24
pub use connimpl::{RpcConn, RpcPoll, WouldBlock};
25
use serde::{Deserialize, de::DeserializeOwned};
26
pub use stream::StreamError;
27
use tor_rpc_connect::{HasClientErrorAction, auth::cookie::CookieAccessError};
28

            
29
/// A user-provided tag used to identify requests provided to
/// [`RpcConn::submit`].
///
/// Most users will want to create tags that are unique
/// for the lifetime of their associated requests.
/// This is not enforced: the only drawback of duplicating tags
/// is that you will not be able to use them to distinguish
/// which reply is which.
///
/// This is distinct from the request ID type (represented by [`AnyRequestId`])
/// that is sent to the RPC server with each request
/// and returned along with each corresponding response.
/// By contrast, a `UserTag` is never sent to the RPC server,
/// and therefore is safe to use with information
/// (like callback and data pointers)
/// which it would not be safe to take from an untrusted source.
//
// Note: The tag is chosen to be two pointers in size,
// to accommodate C implementations that want to
// stuff a `void (*)(void*)` function pointer and a `void*` inside of one of these.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[allow(clippy::exhaustive_structs)]
pub struct UserTag(pub usize, pub usize);
52

            
53
/// A handle to an open request.
54
///
55
/// These handles are created with [`RpcConn::execute_with_handle`].
56
///
57
/// Note that dropping a RequestHandle does not cancel the associated request:
58
/// it will continue running, but you won't have a way to receive updates from it.
59
/// To cancel a request, use [`RpcConn::cancel`].
60
#[derive(educe::Educe)]
61
#[educe(Debug)]
62
pub struct RequestHandle {
63
    /// The underlying `Receiver` that we'll use to get updates for this request
64
    ///
65
    /// It's wrapped in a `Mutex` to prevent concurrent calls to `Receiver::wait_on_message_for`.
66
    //
67
    // NOTE: As an alternative to using a Mutex here, we _could_ remove
68
    // the restriction from `wait_on_message_for` that says that only one thread
69
    // may be waiting on a given request ID at once.  But that would introduce
70
    // complexity to the implementation,
71
    // and it's not clear that the benefit would be worth it.
72
    #[educe(Debug(ignore))]
73
    conn: Mutex<Arc<connimpl::Receiver>>,
74
    /// The ID of this request.
75
    id: AnyRequestId,
76
}
77

            
78
// TODO RPC: Possibly abolish these types.
79
//
80
// I am keeping this for now because it makes it more clear that we can never reinterpret
81
// a success as an update or similar.
82
//
83
// I am not at all pleased with these types; we should revise them.
84
//
85
// TODO RPC: Possibly, all of these should be reconstructed
86
// from their serde_json::Values rather than forwarded verbatim.
87
// (But why would we want our json to be more canonical than arti's? See #1491.)
88
//
89
// DODGY TYPES BEGIN: TODO RPC
90

            
91
/// A Success Response from Arti, indicating that a request was successful.
92
///
93
/// This is the complete message, including `id` and `result` fields.
94
//
95
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
96
// TODO RPC: check that the newline invariant is enforced in constructors.
97
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
98
#[as_ref(forward)]
99
pub struct SuccessResponse(Utf8CString);
100

            
101
impl SuccessResponse {
102
    /// Helper: Decode the `result` field of this response as an instance of D.
103
2018
    fn decode<D: DeserializeOwned>(&self) -> Result<D, serde_json::Error> {
104
        /// Helper object for decoding the "result" field.
105
        #[derive(Deserialize)]
106
        struct Response<R> {
107
            /// The decoded value.
108
            result: R,
109
        }
110
2018
        let response: Response<D> = serde_json::from_str(self.as_ref())?;
111
2018
        Ok(response.result)
112
2018
    }
113
}
114

            
115
/// An Update Response from Arti, with information about the progress of a request.
116
///
117
/// This is the complete message, including `id` and `update` fields.
118
//
119
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
120
// TODO RPC: check that the newline invariant is enforced in constructors.
121
// TODO RPC consider changing this to CString.
122
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
123
#[as_ref(forward)]
124
pub struct UpdateResponse(Utf8CString);
125

            
126
/// A Error Response from Arti, indicating that an error occurred.
127
///
128
/// (This is the complete message, including the `error` field.
129
/// It also an `id` if it
130
/// is in response to a request; but not if it is a fatal protocol error.)
131
//
132
// Invariant: Does not contain a NUL. (Safe to convert to CString.)
133
//
134
// Invariant: This field MUST encode a response whose body is an RPC error.
135
//
136
// Otherwise the `decode` method may panic.
137
//
138
// TODO RPC: check that the newline invariant is enforced in constructors.
139
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
140
#[as_ref(forward)]
141
// TODO: If we keep this, it should implement Error.
142
pub struct ErrorResponse(Utf8CString);
143
impl ErrorResponse {
144
    /// Construct an ErrorResponse from the Error reply.
145
    ///
146
    /// This not a From impl because we want it to be crate-internal.
147
2084
    pub(crate) fn from_validated_string(s: Utf8CString) -> Self {
148
2084
        ErrorResponse(s)
149
2084
    }
150

            
151
    /// Convert this response into an internal error in response to `cmd`.
152
    ///
153
    /// This is only appropriate when the error cannot be caused because of user behavior.
154
    pub(crate) fn internal_error(&self, cmd: &str) -> ProtoError {
155
        ProtoError::InternalRequestFailed(UnexpectedReply {
156
            request: cmd.to_string(),
157
            reply: self.to_string(),
158
            problem: UnexpectedReplyProblem::ErrorNotExpected,
159
        })
160
    }
161

            
162
    /// Try to interpret this response as an [`RpcError`].
163
2080
    pub fn decode(&self) -> RpcError {
164
2080
        crate::msgs::response::try_decode_response_as_err(self.0.as_ref())
165
2080
            .expect("Could not decode response that was already decoded as an error?")
166
2080
            .expect("Could not extract error from response that was already decoded as an error?")
167
2080
    }
168
}
169

            
170
impl std::fmt::Display for ErrorResponse {
171
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172
        let e = self.decode();
173
        write!(f, "Peer said {:?}", e.message())
174
    }
175
}
176

            
177
/// A final response -- that is, the last one that we expect to receive for a request.
178
///
179
type FinalResponse = Result<SuccessResponse, ErrorResponse>;
180

            
181
/// Any of the three types of Arti responses.
182
#[derive(Clone, Debug)]
183
#[allow(clippy::exhaustive_structs)]
184
pub enum AnyResponse {
185
    /// The request has succeeded; no more response will be given.
186
    Success(SuccessResponse),
187
    /// The request has failed; no more response will be given.
188
    Error(ErrorResponse),
189
    /// An incremental update; more messages may arrive.
190
    Update(UpdateResponse),
191
}
192
// TODO RPC: DODGY TYPES END.
193

            
194
impl AnyResponse {
195
    /// Convert `v` into `AnyResponse`.
196
8082
    fn from_validated(v: ValidatedResponse) -> Self {
197
        // TODO RPC, Perhaps unify AnyResponse with ValidatedResponse, once we are sure what
198
        // AnyResponse should look like.
199
8082
        match v.meta.kind {
200
2080
            ResponseKind::Error => AnyResponse::Error(ErrorResponse::from_validated_string(v.msg)),
201
2018
            ResponseKind::Success => AnyResponse::Success(SuccessResponse(v.msg)),
202
3984
            ResponseKind::Update => AnyResponse::Update(UpdateResponse(v.msg)),
203
        }
204
8082
    }
205

            
206
    /// Consume this `AnyResponse`, and return its internal string.
207
    #[cfg(feature = "ffi")]
208
    pub(crate) fn into_string(self) -> Utf8CString {
209
        match self {
210
            AnyResponse::Success(m) => m.into(),
211
            AnyResponse::Error(m) => m.into(),
212
            AnyResponse::Update(m) => m.into(),
213
        }
214
    }
215
}
216

            
217
impl RpcConn {
218
    /// Return the ObjectId for the negotiated Session.
219
    ///
220
    /// Nearly all RPC methods require a Session, or some other object
221
    /// accessed via the session.
222
    ///
223
    /// (This function will only return None if no authentication has been performed.
224
    /// TODO RPC: It is not currently possible to make an unauthenticated connection.)
225
    pub fn session(&self) -> Option<&ObjectId> {
226
        self.session.as_ref()
227
    }
228

            
229
    /// Run a command, and wait for success or failure.
230
    ///
231
    /// Note that this function will return `Err(.)` only if sending the command or getting a
232
    /// response failed.
233
    /// If the command was sent successfully, and Arti reported an error in response,
234
    /// this function returns `Ok(Err(.))`.
235
    ///
236
    /// Note that the command does not need to include an `id` field.  If you omit it,
237
    /// one will be generated.
238
98
    pub fn execute(&self, cmd: &str) -> Result<FinalResponse, ProtoError> {
239
98
        let hnd = self.execute_with_handle(cmd)?;
240
90
        hnd.wait()
241
98
    }
242

            
243
    /// Helper for executing internally-generated requests and decoding their results.
244
    ///
245
    /// Behaves like `execute`, except on success, where it tries to decode the `result` field
246
    /// of the response as a `T`.
247
    ///
248
    /// Use this method in cases where it's reasonable for Arti to sometimes return an RPC error:
249
    /// in other words, where it's not necessarily a programming error or version mismatch.
250
    ///
251
    /// Don't use this for user-generated requests: it will misreport unexpected replies
252
    /// as internal errors.
253
2
    pub(crate) fn execute_internal<T: DeserializeOwned>(
254
2
        &self,
255
2
        cmd: &str,
256
2
    ) -> Result<Result<T, ErrorResponse>, ProtoError> {
257
2
        match self.execute(cmd)? {
258
2
            Ok(success) => match success.decode::<T>() {
259
2
                Ok(result) => Ok(Ok(result)),
260
                Err(json_error) => Err(ProtoError::InternalRequestFailed(UnexpectedReply {
261
                    request: cmd.to_string(),
262
                    reply: Utf8CString::from(success).to_string(),
263
                    problem: UnexpectedReplyProblem::CannotDecode(Arc::new(json_error)),
264
                })),
265
            },
266
            Err(error) => Ok(Err(error)),
267
        }
268
2
    }
269

            
270
    /// Helper for executing internally-generated requests and decoding their results.
271
    ///
272
    /// Behaves like `execute_internal`, except that it treats any RPC error reply
273
    /// as an internal error or version mismatch.
274
    ///
275
    /// Don't use this for user-generated requests, or for requests that can fail because of
276
    /// incorrect user inputs: it will misreport failures in those requests as internal errors.
277
2
    pub(crate) fn execute_internal_ok<T: DeserializeOwned>(
278
2
        &self,
279
2
        cmd: &str,
280
2
    ) -> Result<T, ProtoError> {
281
2
        match self.execute_internal(cmd)? {
282
2
            Ok(v) => Ok(v),
283
            Err(err_response) => Err(err_response.internal_error(cmd)),
284
        }
285
2
    }
286

            
287
    /// Cancel a request by ID.
288
    pub fn cancel(&self, request_id: &AnyRequestId) -> Result<(), ProtoError> {
289
        /// Arguments to an `rpc::cancel` request.
290
        #[derive(serde::Serialize, Debug)]
291
        struct CancelParams<'a> {
292
            /// The request to cancel.
293
            request_id: &'a AnyRequestId,
294
        }
295

            
296
        let request = crate::msgs::request::Request::new(
297
            ObjectId::connection_id(),
298
            "rpc:cancel",
299
            CancelParams { request_id },
300
        );
301
        match self.execute_internal::<EmptyReply>(&request.encode()?)? {
302
            Ok(EmptyReply {}) => Ok(()),
303
            Err(_) => Err(ProtoError::RequestCompleted),
304
        }
305
    }
306

            
307
    /// Like `execute`, but don't wait.  This lets the caller see the
308
    /// request ID and  maybe cancel it.
309
4194
    pub fn execute_with_handle(&self, cmd: &str) -> Result<RequestHandle, ProtoError> {
310
4194
        self.send_waitable_request(cmd)
311
4194
    }
312
    /// As execute(), but run update_cb for every update we receive.
313
4096
    pub fn execute_with_updates<F>(
314
4096
        &self,
315
4096
        cmd: &str,
316
4096
        mut update_cb: F,
317
4096
    ) -> Result<FinalResponse, ProtoError>
318
4096
    where
319
4096
        F: FnMut(UpdateResponse) + Send + Sync,
320
    {
321
4096
        let hnd = self.execute_with_handle(cmd)?;
322
        loop {
323
8080
            match hnd.wait_with_updates()? {
324
2016
                AnyResponse::Success(s) => return Ok(Ok(s)),
325
2080
                AnyResponse::Error(e) => return Ok(Err(e)),
326
3984
                AnyResponse::Update(u) => update_cb(u),
327
            }
328
        }
329
4096
    }
330

            
331
    /// As execute(), but do not wait for a response.
332
    ///
333
    /// Instead, the caller must provide a [`UserTag`] to identify a particular request,
334
    /// and must make sure that responses are being processed via [`wait()`](Self::wait).
335
    ///
336
    /// (If nobody is running `wait()`, then responses will never be handled,
337
    /// and can potentially fill up memory.)
338
    pub fn submit(&self, tag: UserTag, cmd: &str) -> Result<(), ProtoError> {
339
        self.send_pollable_request(tag, cmd)
340
    }
341

            
342
    /// Helper: Tell Arti to release `obj`.
343
    ///
344
    /// Do not use this method for a user-provided object ID:
345
    /// It gives an internal error if the object does not exist.
346
    pub(crate) fn release_obj(&self, obj: ObjectId) -> Result<(), ProtoError> {
347
        let release_request = crate::msgs::request::Request::new(obj, "rpc:release", NoParams {});
348
        let _empty_response: EmptyReply = self.execute_internal_ok(&release_request.encode()?)?;
349
        Ok(())
350
    }
351

            
352
    /// Wait for a response to arrive for a request that was sent via [`submit()`](Self::submit).
353
    ///
354
    /// Return that response,
355
    /// along with the [`UserTag`] that was associated with its request.
356
    ///
357
    /// This method will never return responses
358
    /// to any requests made with one of the `execute` methods;
359
    /// only to requests submitted with `submit()`.
360
    ///
361
    /// It is safe, but generally pointless, to call this method from multiple threads.
362
    pub fn wait(&self) -> Result<(UserTag, AnyResponse), ProtoError> {
363
        let (tag, r) = self.receiver.wait_on_pollable_response()?;
364
        Ok((tag, AnyResponse::from_validated(r)))
365
    }
366

            
367
    // TODO RPC: shutdown() on the socket on Drop.
368
}
369

            
370
impl RequestHandle {
371
    /// Return the ID of this request, to help cancelling it.
372
    pub fn id(&self) -> &AnyRequestId {
373
        &self.id
374
    }
375
    /// Wait for success or failure, and return what happened.
376
    ///
377
    /// (Ignores any update messages that are received.)
378
    ///
379
    /// Note that this function will return `Err(.)` only if sending the command or getting a
380
    /// response failed.
381
    /// If the command was sent successfully, and Arti reported an error in response,
382
    /// this function returns `Ok(Err(.))`.
383
90
    pub fn wait(self) -> Result<FinalResponse, ProtoError> {
384
        loop {
385
90
            match self.wait_with_updates()? {
386
2
                AnyResponse::Success(s) => return Ok(Ok(s)),
387
                AnyResponse::Error(e) => return Ok(Err(e)),
388
                AnyResponse::Update(_) => {}
389
            }
390
        }
391
90
    }
392
    /// Wait for the next success, failure, or update from this handle.
393
    ///
394
    /// Note that this function will return `Err(.)` only if sending the command or getting a
395
    /// response failed.
396
    /// If the command was sent successfully, and Arti reported an error in response,
397
    /// this function returns `Ok(AnyResponse::Error(.))`.
398
    ///
399
    /// You may call this method on the same `RequestHandle` from multiple threads.
400
    /// If you do so, those calls will receive responses (or errors) in an unspecified order.
401
    ///
402
    /// If this function returns Success or Error, then you shouldn't call it again.
403
    /// All future calls to this function will fail with `CmdError::RequestCancelled`.
404
    /// (TODO RPC: Maybe rename that error.)
405
8170
    pub fn wait_with_updates(&self) -> Result<AnyResponse, ProtoError> {
406
8170
        let conn = self.conn.lock().expect("Poisoned lock");
407
8170
        let validated = conn.wait_on_message_for(&self.id)?;
408
8082
        Ok(AnyResponse::from_validated(validated))
409
8170
    }
410

            
411
    // TODO RPC: Sketch out how we would want to do this in an async world,
412
    // or with poll
413
}
414

            
415
/// An error (or other condition) that has caused an RPC connection to shut down.
416
#[derive(Clone, Debug, thiserror::Error)]
417
#[non_exhaustive]
418
pub enum ShutdownError {
419
    // TODO nb: Read/Write are no longer well separated in the API.
420
    //
421
    /// Io error occurred while reading.
422
    #[error("Unable to read response")]
423
    Read(#[source] Arc<io::Error>),
424
    /// Io error occurred while writing.
425
    #[error("Unable to write request")]
426
    Write(#[source] Arc<io::Error>),
427
    /// Something was wrong with Arti's responses; this is a protocol violation.
428
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0:?}")]
429
    ProtocolViolated(String),
430
    /// Arti has told us that we violated the protocol somehow.
431
    #[error("Arti reported a fatal error: {0:?}")]
432
    ProtocolViolationReport(ErrorResponse),
433
    /// The underlying connection closed.
434
    ///
435
    /// This probably means that Arti has shut down.
436
    #[error("Connection closed")]
437
    ConnectionClosed,
438
}
439

            
440
impl From<crate::msgs::response::DecodeResponseError> for ShutdownError {
    /// Map a response-decoding failure onto the shutdown reason it implies.
    fn from(value: crate::msgs::response::DecodeResponseError) -> Self {
        use crate::msgs::response::DecodeResponseError as D;
        match value {
            D::Fatal(rpc_err) => Self::ProtocolViolationReport(rpc_err),
            D::JsonProtocolViolation(json_err) => Self::ProtocolViolated(json_err.to_string()),
            D::ProtocolViolation(msg) => Self::ProtocolViolated(msg.to_string()),
        }
    }
}
451

            
452
/// An error that has occurred while launching an RPC command.
453
#[derive(Clone, Debug, thiserror::Error)]
454
#[non_exhaustive]
455
pub enum ProtoError {
456
    /// The RPC connection failed, or was closed by the other side.
457
    #[error("RPC connection is shut down")]
458
    Shutdown(#[from] ShutdownError),
459

            
460
    /// There was a problem in the request we tried to send.
461
    #[error("Invalid request")]
462
    InvalidRequest(#[from] InvalidRequestError),
463

            
464
    /// We tried to send a request with an ID that was already pending.
465
    #[error("Request ID already in use.")]
466
    RequestIdInUse,
467

            
468
    /// We tried to wait for or inspect a request that had already succeeded or failed.
469
    #[error("Request has already completed (or failed)")]
470
    RequestCompleted,
471

            
472
    /// We tried to wait for the same request more than once.
473
    ///
474
    /// (This should be impossible.)
475
    #[error("Internal error: waiting on the same request more than once at a time.")]
476
    DuplicateWait,
477

            
478
    /// We got an internal error while trying to encode an RPC request.
479
    ///
480
    /// (This should be impossible.)
481
    #[error("Internal error while encoding request")]
482
    CouldNotEncode(#[source] Arc<serde_json::Error>),
483

            
484
    /// We tried to wait on a request that was not created with a queue.
485
    ///
486
    /// (This should be impossible).
487
    #[error("Internal error: waiting on a request created for polling.")]
488
    RequestNotWaitable,
489

            
490
    /// We got a response to some internally generated request that wasn't what we expected.
491
    #[error("{0}")]
492
    InternalRequestFailed(#[source] UnexpectedReply),
493
}
494

            
495
/// A set of errors encountered while trying to connect to the Arti process
496
#[derive(Clone, Debug, thiserror::Error)]
497
pub struct ConnectFailure {
498
    /// A list of all the declined connect points we encountered, and how they failed.
499
    declined: Vec<(builder::ConnPtDescription, ConnectError)>,
500
    /// A description of where we found the final error (if it's an abort.)
501
    final_desc: Option<builder::ConnPtDescription>,
502
    /// The final error explaining why we couldn't connect.
503
    ///
504
    /// This is either an abort, an AllAttemptsDeclined, or an error that prevented the
505
    /// search process from even beginning.
506
    #[source]
507
    pub(crate) final_error: ConnectError,
508
}
509

            
510
impl std::fmt::Display for ConnectFailure {
511
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512
        write!(f, "Unable to connect")?;
513
        if !self.declined.is_empty() {
514
            write!(
515
                f,
516
                " ({} attempts failed{})",
517
                self.declined.len(),
518
                if matches!(self.final_error, ConnectError::AllAttemptsDeclined) {
519
                    ""
520
                } else {
521
                    " before fatal error"
522
                }
523
            )?;
524
        }
525
        Ok(())
526
    }
527
}
528

            
529
impl ConnectFailure {
530
    /// If this attempt failed because of a fatal error that made a connect point attempt abort,
531
    /// return a description of the origin of that connect point.
532
    pub fn fatal_error_origin(&self) -> Option<&builder::ConnPtDescription> {
533
        self.final_desc.as_ref()
534
    }
535

            
536
    /// For each connect attempt that failed nonfatally, return a description of the
537
    /// origin of that connect point, and the error that caused it to fail.
538
    pub fn declined_attempt_outcomes(
539
        &self,
540
    ) -> impl Iterator<Item = (&builder::ConnPtDescription, &ConnectError)> {
541
        // Note: this map looks like a no-op, but isn't.
542
        self.declined.iter().map(|(a, b)| (a, b))
543
    }
544

            
545
    /// Return a helper type to format this error, and all of its internal errors recursively.
546
    ///
547
    /// Unlike [`tor_error::Report`], this method includes not only fatal errors, but also
548
    /// information about connect attempts that failed nonfatally.
549
    pub fn display_verbose(&self) -> ConnectFailureVerboseFmt<'_> {
550
        ConnectFailureVerboseFmt(self)
551
    }
552
}
553

            
554
/// Helper type to format a ConnectFailure along with all of its internal errors,
555
/// including non-fatal errors.
556
#[derive(Debug, Clone)]
557
pub struct ConnectFailureVerboseFmt<'a>(&'a ConnectFailure);
558

            
559
impl<'a> std::fmt::Display for ConnectFailureVerboseFmt<'a> {
560
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
561
        use tor_error::ErrorReport as _;
562
        writeln!(f, "{}:", self.0)?;
563
        for (idx, (origin, error)) in self.0.declined_attempt_outcomes().enumerate() {
564
            writeln!(f, "  {}. {}: {}", idx + 1, origin, error.report())?;
565
        }
566
        if let Some(origin) = self.0.fatal_error_origin() {
567
            writeln!(
568
                f,
569
                "  {}. [FATAL] {}: {}",
570
                self.0.declined.len() + 1,
571
                origin,
572
                self.0.final_error.report()
573
            )?;
574
        } else {
575
            writeln!(f, "  - {}", self.0.final_error.report())?;
576
        }
577
        Ok(())
578
    }
579
}
580

            
581
/// An error while trying to connect to the Arti process.
582
#[derive(Clone, Debug, thiserror::Error)]
583
#[non_exhaustive]
584
pub enum ConnectError {
585
    /// Unable to parse connect points from an environment variable.
586
    #[error("Cannot parse connect points from environment variable")]
587
    BadEnvironment,
588
    /// We were unable to load and/or parse a given connect point.
589
    #[error("Unable to load and parse connect point")]
590
    CannotParse(#[from] tor_rpc_connect::load::LoadError),
591
    /// The path used to specify a connect file couldn't be resolved.
592
    #[error("Unable to resolve connect point path")]
593
    CannotResolvePath(#[source] tor_config_path::CfgPathError),
594
    /// A parsed connect point couldn't be resolved.
595
    #[error("Unable to resolve connect point")]
596
    CannotResolveConnectPoint(#[from] tor_rpc_connect::ResolveError),
597
    /// IO error while connecting to Arti.
598
    #[error("Unable to make a connection")]
599
    CannotConnect(#[from] tor_rpc_connect::ConnectError),
600
    /// The connect point told us to connect via a type of stream we don't know how to support.
601
    #[error("Connect point stream type was unsupported")]
602
    StreamTypeUnsupported,
603
    /// Opened a connection, but didn't get a banner message.
604
    ///
605
    /// (This isn't a `BadMessage`, since it is likelier to represent something that isn't
606
    /// pretending to be Arti at all than it is to be a malfunctioning Arti.)
607
    #[error("Did not receive expected banner message upon connecting")]
608
    InvalidBanner,
609
    /// All attempted connect points were declined, and none were aborted.
610
    #[error("All connect points were declined (or there were none)")]
611
    AllAttemptsDeclined,
612
    /// A connect file or directory was given as a relative path.
613
    /// (Only absolute paths are supported).
614
    #[error("Connect file was given as a relative path.")]
615
    RelativeConnectFile,
616
    /// One of our authentication messages received an error.
617
    #[error("Received an error while trying to authenticate: {0}")]
618
    AuthenticationFailed(ErrorResponse),
619
    /// The connect point uses an RPC authentication type we don't support.
620
    #[error("Authentication type is not supported")]
621
    AuthenticationNotSupported,
622
    /// We couldn't decode one of the responses we got.
623
    #[error("Message not in expected format")]
624
    BadMessage(#[source] Arc<serde_json::Error>),
625
    /// A protocol error occurred during negotiations.
626
    #[error("Error while negotiating with Arti")]
627
    ProtoError(#[from] ProtoError),
628
    /// The server thinks it is listening on an address where we don't expect to find it.
629
    /// This can be misconfiguration or an attempted MITM attack.
630
    #[error("We connected to the server at {ours}, but it thinks it's listening at {theirs}")]
631
    ServerAddressMismatch {
632
        /// The address we think the server has
633
        ours: String,
634
        /// The address that the server says it has.
635
        theirs: String,
636
    },
637
    /// The server tried to prove knowledge of a cookie file, but its proof was incorrect.
638
    #[error("Server's cookie MAC was not as expected.")]
639
    CookieMismatch,
640
    /// We were unable to access the configured cookie file.
641
    #[error("Unable to load secret cookie value")]
642
    LoadCookie(#[from] CookieAccessError),
643
    /// We want superuser permission, and this connect point does not grant it.
644
    #[error("Connect point does not provide superuser permission.")]
645
    NoSuperuserPermission,
646
}
647

            
648
impl HasClientErrorAction for ConnectError {
649
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
650
        use ConnectError as E;
651
        use tor_rpc_connect::ClientErrorAction as A;
652
        match self {
653
            E::BadEnvironment => A::Abort,
654
            E::CannotParse(e) => e.client_action(),
655
            E::CannotResolvePath(_) => A::Abort,
656
            E::CannotResolveConnectPoint(e) => e.client_action(),
657
            E::CannotConnect(e) => e.client_action(),
658
            E::StreamTypeUnsupported => A::Decline,
659
            E::InvalidBanner => A::Decline,
660
            E::RelativeConnectFile => A::Abort,
661
            E::AuthenticationFailed(_) => A::Decline,
662
            // TODO RPC: Is this correct?  This error can also occur when
663
            // we are talking to something other than an RPC server.
664
            E::BadMessage(_) => A::Abort,
665
            E::ProtoError(e) => e.client_action(),
666
            E::AllAttemptsDeclined => A::Abort,
667
            E::AuthenticationNotSupported => A::Decline,
668
            E::ServerAddressMismatch { .. } => A::Abort,
669
            E::CookieMismatch => A::Abort,
670
            E::LoadCookie(e) => e.client_action(),
671
            E::NoSuperuserPermission => A::Decline,
672
        }
673
    }
674
}
675

            
676
impl HasClientErrorAction for ProtoError {
677
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
678
        use ProtoError as E;
679
        use tor_rpc_connect::ClientErrorAction as A;
680
        match self {
681
            E::Shutdown(_) => A::Decline,
682
            E::InternalRequestFailed(_) => A::Decline,
683
            // These are always internal errors if they occur
684
            // while negotiating a connection to RPC,
685
            // which is the context we care about for `HasClientErrorAction`.
686
            E::InvalidRequest(_)
687
            | E::RequestIdInUse
688
            | E::RequestCompleted
689
            | E::DuplicateWait
690
            | E::RequestNotWaitable
691
            | E::CouldNotEncode(_) => A::Abort,
692
        }
693
    }
694
}
695

            
696
/// In response to a request that we generated internally,
697
/// Arti gave a reply that we did not understand.
698
///
699
/// This could be due to a bug in this library, a bug in Arti,
700
/// or a compatibility issue between the two.
701
#[derive(Clone, Debug, thiserror::Error)]
702
#[error("In response to our request {request:?}, Arti gave the unexpected reply {reply:?}")]
703
pub struct UnexpectedReply {
704
    /// The request we sent.
705
    request: String,
706
    /// The response we got.
707
    reply: String,
708
    /// What was wrong with the response.
709
    #[source]
710
    problem: UnexpectedReplyProblem,
711
}
712

            
713
/// Underlying reason for an UnexpectedReply
714
#[derive(Clone, Debug, thiserror::Error)]
715
enum UnexpectedReplyProblem {
716
    /// There was a json failure while trying to decode the response:
717
    /// the result type was not what we expected.
718
    #[error("Cannot decode as correct JSON type")]
719
    CannotDecode(Arc<serde_json::Error>),
720
    /// Arti replied with an RPC error in a context no error should have been possible.
721
    #[error("Unexpected error")]
722
    ErrorNotExpected,
723
}
724

            
725
/// Arguments to a request that takes no parameters.
726
#[derive(serde::Serialize, Debug)]
727
struct NoParams {}
728

            
729
/// A reply with no data.
730
#[derive(serde::Deserialize, Debug)]
731
struct EmptyReply {}
732

            
733
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::{sync::atomic::AtomicUsize, thread, time::Duration};

    use io::{BufRead as _, BufReader, Write as _};
    use rand::{Rng as _, SeedableRng as _, seq::SliceRandom as _};
    use tor_basic_utils::{RngExt as _, test_rng::testing_rng};

    use crate::{
        ll_conn::BlockingConnection,
        msgs::request::{JsonMap, Request, ValidatedRequest},
    };

    use super::*;

    /// helper: Return a dummy RpcConn, along with a socketpair for it to talk to.
    fn dummy_connected() -> (RpcConn, crate::testing::SocketpairStream) {
        let (s1, s2) = crate::testing::construct_socketpair().unwrap();
        let conn = RpcConn::new(BlockingConnection::new(s1).unwrap());

        (conn, s2)
    }

    /// helper: Encode `v` as JSON and write it to `w`, followed by a newline.
    fn write_val(w: &mut impl io::Write, v: &serde_json::Value) {
        let mut enc = serde_json::to_string(v).unwrap();
        enc.push('\n');
        w.write_all(enc.as_bytes()).unwrap();
    }

    /// A single request receives the response that carries its own id.
    #[test]
    fn simple() {
        let (conn, sock) = dummy_connected();

        let user_thread = thread::spawn(move || {
            let response1 = conn
                .execute_internal_ok::<JsonMap>(
                    r#"{"obj":"fred","method":"arti:x-frob","params":{}}"#,
                )
                .unwrap();
            (response1, conn)
        });

        let fake_arti_thread = thread::spawn(move || {
            let mut sock = BufReader::new(sock);
            let mut s = String::new();
            let _len = sock.read_line(&mut s).unwrap();
            let request = ValidatedRequest::from_string_strict(s.as_ref()).unwrap();
            let response = serde_json::json!({
                "id": request.id().clone(),
                "result": { "xyz" : 3 }
            });
            write_val(sock.get_mut(), &response);
            sock // prevent close
        });

        let _sock = fake_arti_thread.join().unwrap();
        let (map, _conn) = user_thread.join().unwrap();
        assert_eq!(map.get("xyz"), Some(&serde_json::Value::Number(3.into())));
    }

    /// Many threads issue requests concurrently; a fake Arti answers them out of order.
    #[test]
    fn complex() {
        use std::sync::atomic::Ordering::SeqCst;
        let n_threads = 16;
        let n_commands_per_thread = 128;
        let n_commands_total = n_threads * n_commands_per_thread;
        let n_completed = Arc::new(AtomicUsize::new(0));

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        let mut rng = testing_rng();

        // -------
        // User threads: Make a bunch of requests.
        for th_idx in 0..n_threads {
            let conn = Arc::clone(&conn);
            let n_completed = Arc::clone(&n_completed);
            let mut rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
            let th = thread::spawn(move || {
                for cmd_idx in 0..n_commands_per_thread {
                    // Each user thread runs a number of commands in sequence.
                    // Each command is a request that gets optional
                    // updates, and then an error or a success.
                    // We double-check that each request gets the response it asked for.
                    let s = format!("{}:{}", th_idx, cmd_idx);
                    let want_updates: bool = rng.random();
                    let want_failure: bool = rng.random();
                    let req = serde_json::json!({
                        "obj":"fred",
                        "method":"arti:x-echo",
                        "meta": {
                            "updates": want_updates,
                        },
                        "params": {
                            "val": &s,
                            "fail": want_failure,
                        },
                    });
                    let req = serde_json::to_string(&req).unwrap();

                    // Wait for a final response, processing updates if we asked for them.
                    let mut n_updates = 0;
                    let outcome = conn
                        .execute_with_updates(&req, |_update| {
                            n_updates += 1;
                        })
                        .unwrap();
                    assert_eq!(n_updates > 0, want_updates);

                    // See if we liked the final response.
                    if want_failure {
                        let e = outcome.unwrap_err().decode();
                        assert_eq!(e.message(), "You asked me to fail");
                        assert_eq!(i32::from(e.code()), 33);
                        assert_eq!(
                            e.kinds_iter().collect::<Vec<_>>(),
                            vec!["Example".to_string()]
                        );
                    } else {
                        let success = outcome.unwrap();
                        let map = success.decode::<JsonMap>().unwrap();
                        assert_eq!(map.get("echo"), Some(&serde_json::Value::String(s)));
                    }
                    n_completed.fetch_add(1, SeqCst);
                    if rng.random::<f32>() < 0.02 {
                        thread::sleep(Duration::from_millis(3));
                    }
                }
            });
            user_threads.push(th);
        }

        #[derive(serde::Deserialize, Debug)]
        struct Echo {
            val: String,
            fail: bool,
        }

        // -----
        // Worker thread: handles user requests.
        let worker_rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
        let worker_thread = thread::spawn(move || {
            let mut rng = worker_rng;
            let mut sock = BufReader::new(sock);
            let mut pending: Vec<Request<Echo>> = Vec::new();
            let mut n_received = 0;

            // How many requests do we buffer before we shuffle them and answer them out-of-order?
            let scramble_factor = 7;
            // After receiving how many requests do we stop shuffling requests?
            //
            // (Our shuffling algorithm can deadlock us otherwise.)
            let scramble_threshold =
                n_commands_total - (n_commands_per_thread + 1) * scramble_factor;

            'outer: loop {
                let flush_pending_at = if n_received >= scramble_threshold {
                    1
                } else {
                    scramble_factor
                };

                // Queue a handful of requests in "pending"
                while pending.len() < flush_pending_at {
                    let mut buf = String::new();
                    if sock.read_line(&mut buf).unwrap() == 0 {
                        break 'outer;
                    }
                    n_received += 1;
                    let req: Request<Echo> = serde_json::from_str(&buf).unwrap();
                    pending.push(req);
                }

                // Handle the requests in "pending" in random order.
                let mut handling = std::mem::take(&mut pending);
                handling.shuffle(&mut rng);

                for req in handling {
                    if req.meta.unwrap_or_default().updates {
                        let n_updates = rng.gen_range_checked(1..4).unwrap();
                        for _ in 0..n_updates {
                            let up = serde_json::json!({
                                "id": req.id.clone(),
                                "update": {
                                    "hello": req.params.val.clone(),
                                }
                            });
                            write_val(sock.get_mut(), &up);
                        }
                    }

                    let response = if req.params.fail {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "error": {
                                "message": "You asked me to fail",
                                "code": 33,
                                "kinds": ["Example"],
                                "data": req.params.val,
                            },
                        })
                    } else {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "result": {
                                "echo": req.params.val
                            }
                        })
                    };
                    write_val(sock.get_mut(), &response);
                }
            }
        });
        drop(conn);
        for t in user_threads {
            t.join().unwrap();
        }

        worker_thread.join().unwrap();

        assert_eq!(n_completed.load(SeqCst), n_commands_total);
    }

    #[test]
    fn arti_socket_closed() {
        // Here we send a bunch of requests and then close the socket without answering them.
        //
        // Every request should get a ProtoError::Shutdown.
        let n_threads = 16;

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request, which will never be answered.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                let outcome = conn.execute(&req);
                // Any of these shutdown kinds is acceptable; which one a given
                // thread sees presumably depends on timing.
                assert!(
                    matches!(
                        &outcome,
                        Err(ProtoError::Shutdown(ShutdownError::Write(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::Read(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::ConnectionClosed))
                    ),
                    "unexpected outcome: {outcome:?}"
                );
            });
            user_threads.push(th);
        }
        drop(sock);
        for t in user_threads {
            t.join().unwrap();
        }
    }

    /// helper: Send a bunch of requests, then write `msg` to the socket in Arti's place.
    ///
    /// `msg` should cause every pending request to fail with a `ProtoError`;
    /// each such error is passed to `outcome_ok`, which must return true for it.
    fn proto_err_with_msg<F>(msg: &str, outcome_ok: F)
    where
        F: Fn(&ProtoError) -> bool,
    {
        let n_threads = 16;
        let (conn, mut sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request and hands back its outcome.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                conn.execute(&req)
            });
            user_threads.push(th);
        }
        sock.write_all(msg.as_bytes()).unwrap();
        for t in user_threads {
            let err = t.join().unwrap().unwrap_err();
            assert!(outcome_ok(&err), "unexpected error: {err:?}");
        }
    }

    /// A line that is not JSON at all shuts down every pending request.
    #[test]
    fn syntax_error() {
        proto_err_with_msg("this is not json\n", |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolated(_))
            )
        });
    }

    /// An error with no `id` fails every pending request with `ProtocolViolationReport`.
    #[test]
    fn fatal_error() {
        let j = serde_json::json!({
            "error": {
                "message": "This test is doomed",
                "code": 413,
                "kinds": ["Example"],
                "data": {},
            },
        });
        let mut s = serde_json::to_string(&j).unwrap();
        s.push('\n');
        proto_err_with_msg(&s, |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolationReport(_))
            )
        });
    }
}