1
//! Support for encoding and decoding RPC Requests.
2
//!
3
//! There are several types in this module:
4
//!
5
//! - [`Request`] is for requests that are generated from within this crate,
6
//!   to implement authentication, negotiation, and other functionality.
7
//! - `ParsedRequestFields` (internal) is for a request we've completely validated,
8
//!   with all of its fields present.
9
//! - [`ValidatedRequest`] is for a request that we have validated,
//!   stored as its re-encoded string (a single line with a terminating newline).
10

            
11
use std::sync::Arc;
12

            
13
use serde::{Deserialize, Serialize};
14

            
15
/// Alias for a Map as used by the serde_json.
16
pub(crate) type JsonMap = serde_json::Map<String, serde_json::Value>;
17

            
18
use crate::conn::ProtoError;
19

            
20
use super::{AnyRequestId, JsonAnyObj, ObjectId};
21

            
22
/// An outbound request that we have generated from within this crate.
23
///
24
/// It lacks a required `id` field (since we will generate one when sending it),
25
/// and it allows any Serialize for its `params`.
26
#[derive(Serialize, Debug)]
27
// Testing only. Don't implement Deserialize here; this is not the type you should parse into!
28
#[cfg_attr(test, derive(Eq, PartialEq, Deserialize))]
29
#[allow(clippy::missing_docs_in_private_items)] // Fields are as for ParsedRequest.
30
pub(crate) struct Request<T> {
31
    #[serde(skip_serializing_if = "Option::is_none")]
32
    pub(crate) id: Option<AnyRequestId>,
33
    pub(crate) obj: ObjectId,
34
    #[serde(skip_serializing_if = "Option::is_none")]
35
    pub(crate) meta: Option<RequestMeta>,
36
    pub(crate) method: String,
37
    pub(crate) params: T,
38
}
39

            
40
/// An error that has prevented us from validating an request.
41
#[derive(Clone, Debug, thiserror::Error)]
42
#[non_exhaustive]
43
pub enum InvalidRequestError {
44
    /// We failed to turn the request into any kind of json.
45
    #[error("Request was not valid Json")]
46
    InvalidJson(#[source] Arc<serde_json::Error>),
47
    /// We got the request into json, but we couldn't find the fields we wanted.
48
    #[error("Request's fields were invalid or missing")]
49
    InvalidFormat(#[source] Arc<serde_json::Error>),
50
    /// We validated the request, but couldn't re-encode it.
51
    #[error("Unable to re-encode or format request")]
52
    ReencodeFailed(#[source] Arc<serde_json::Error>),
53
}
54

            
55
impl<T: Serialize> Request<T> {
56
    /// Construct a new outbound Request.
57
4
    pub(crate) fn new(obj: ObjectId, method: impl Into<String>, params: T) -> Self {
58
4
        Self {
59
4
            id: None,
60
4
            obj,
61
4
            meta: Default::default(),
62
4
            method: method.into(),
63
4
            params,
64
4
        }
65
4
    }
66
    /// Try to encode this request as a String.
67
    ///
68
    /// The string may not yet be a valid request; it might need to get an ID assigned.
69
4
    pub(crate) fn encode(&self) -> Result<String, ProtoError> {
70
4
        serde_json::to_string(self).map_err(|e| ProtoError::CouldNotEncode(Arc::new(e)))
71
4
    }
72
}
73

            
74
/// A request in its decoded (or unencoded) format.
75
///
76
/// We use this type to validate outbound requests from the application.
77
#[derive(Deserialize, Debug)]
78
// Don't implement Serialize here; this is not for generating requests!
79
#[allow(dead_code)] // The fields here are only used for validating serde objects.
80
struct ParsedRequestFields {
81
    /// The identifier for this request.
82
    ///
83
    /// Used to match a request with its responses.
84
    id: AnyRequestId,
85
    /// The ID for the object to which this request is addressed.
86
    ///
87
    /// (Every request goes to a single object.)
88
    obj: ObjectId,
89
    /// Additional information for Arti about how to handle the request.
90
    #[serde(skip_serializing_if = "Option::is_none")]
91
    meta: Option<RequestMeta>,
92
    /// The name of the method to invoke.
93
    method: String,
94
    /// Parameters to pass to the method.
95
    params: JsonAnyObj,
96
}
97

            
98
/// A known-valid request, encoded as a string (in a single line, with a terminating newline).
99
#[derive(derive_more::AsRef, Debug, Clone)]
100
pub(crate) struct ValidatedRequest {
101
    /// The message itself, as encoded.
102
    #[as_ref]
103
    msg: String,
104
    /// The ID for this request.
105
    id: AnyRequestId,
106
}
107

            
108
impl ValidatedRequest {
109
    /// Return the Id associated with this request.
110
4186
    pub(crate) fn id(&self) -> &AnyRequestId {
111
4186
        &self.id
112
4186
    }
113

            
114
    /// Try to construct a validated request from a `serde_json::Value`.
115
4204
    fn from_json_value(val: serde_json::Value) -> Result<Self, InvalidRequestError> {
116
4204
        let mut msg = serde_json::to_string(&val)
117
4204
            .map_err(|e| InvalidRequestError::ReencodeFailed(Arc::new(e)))?;
118
4204
        debug_assert!(!msg.contains('\n'));
119
4204
        msg.push('\n');
120

            
121
4204
        let req: ParsedRequestFields = serde_json::from_value(val)
122
4204
            .map_err(|e| InvalidRequestError::InvalidFormat(Arc::new(e)))?;
123
4204
        let id = req.id;
124

            
125
4204
        Ok(ValidatedRequest { id, msg })
126
4204
    }
127

            
128
    /// Try to construct a validated request using `s`.
129
    // TODO nb: Expose or remove.
130
    #[allow(dead_code)]
131
10
    pub(crate) fn from_string_strict(s: &str) -> Result<Self, InvalidRequestError> {
132
10
        let value: serde_json::Value =
133
10
            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
134
10
        Self::from_json_value(value)
135
10
    }
136

            
137
    /// Try to construct a ValidatedRequest from the string in `s`.
138
    ///
139
    /// If it has no `id`, add one using `id_generator`.
140
4194
    pub(crate) fn from_string_loose<F>(
141
4194
        s: &str,
142
4194
        id_generator: F,
143
4194
    ) -> Result<Self, InvalidRequestError>
144
4194
    where
145
4194
        F: FnOnce() -> AnyRequestId,
146
    {
147
4194
        let mut value: serde_json::Value =
148
4194
            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
149

            
150
4194
        if let Some(obj) = value.as_object_mut() {
151
4194
            obj.entry("id")
152
4194
                .or_insert_with(|| id_generator().into_json_value());
153
        }
154

            
155
4194
        Self::from_json_value(value)
156
4194
    }
157
}
158

            
159
/// Crate-internal: The "meta" field in a request.
160
#[derive(Deserialize, Serialize, Debug, Default)]
161
#[cfg_attr(test, derive(Eq, PartialEq))]
162
pub(crate) struct RequestMeta {
163
    /// If true, the application wants to receive incremental updates
164
    /// about the request that it sent.
165
    ///
166
    /// (Default: false)
167
    #[serde(default)]
168
    pub(crate) updates: bool,
169
    /// Any unrecognized fields that we received from the user.
170
    /// (We re-encode these in case the user knows about fields that we don't.)
171
    #[serde(flatten)]
172
    pub(crate) unrecognized_fields: JsonMap,
173
}
174

            
175
/// A helper to return unique Request identifiers.
///
/// All identifiers are strings prefixed with `"!auto!--"` and followed by a counter:
/// if you don't use that prefix in your own IDs,
/// you won't have any collisions.
#[derive(Debug, Default)]
pub(crate) struct IdGenerator {
    /// The counter value to embed in the next identifier we hand out.
    ///
    /// Starts at 0 (via `Default`) and increases by one per identifier.
    next_id: u64,
}
185

            
186
impl IdGenerator {
187
    /// Return a previously unyielded identifier.
188
4184
    pub(crate) fn next_id(&mut self) -> AnyRequestId {
189
4184
        let id = self.next_id;
190
4184
        self.next_id += 1;
191
4184
        format!("!auto!--{id}").into()
192
4184
    }
193
}
194

            
195
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    impl ParsedRequestFields {
        /// Return true if this request is asking for updates.
        fn updates_requested(&self) -> bool {
            self.meta.as_ref().map(|m| m.updates).unwrap_or(false)
        }
    }

    use crate::util::assert_same_json;

    use super::*;
    const REQ1: &str = r#"{"id":7, "obj": "hi", "meta": {"updates": true}, "method":"twiddle", "params":{"stuff": "nonsense"} }"#;
    const REQ2: &str = r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{} }"#;
    const REQ3: &str =
        r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{},"unrecognized":"waffles"}"#;

    #[test]
    fn parse_requests() {
        let req1: ParsedRequestFields = serde_json::from_str(REQ1).unwrap();
        assert_eq!(req1.id, 7.into());
        assert_eq!(req1.obj.as_ref(), "hi");
        assert_eq!(req1.updates_requested(), true);
        assert_eq!(req1.method, "twiddle");

        let req2: ParsedRequestFields = serde_json::from_str(REQ2).unwrap();
        assert_eq!(req2.id, "fred".to_string().into());
        assert_eq!(req2.obj.as_ref(), "hi");
        assert_eq!(req2.updates_requested(), false);
        assert_eq!(req2.method, "twiddle");

        // Unrecognized top-level members must not prevent parsing.
        let req3: ParsedRequestFields = serde_json::from_str(REQ3).unwrap();
        assert_eq!(req3.id, "fred".to_string().into());
        assert_eq!(req3.obj.as_ref(), "hi");
        assert_eq!(req3.updates_requested(), false);
        assert_eq!(req3.method, "twiddle");
    }

    #[test]
    fn reencode_requests() {
        for r in [REQ1, REQ2, REQ3] {
            let val1 = ValidatedRequest::from_string_strict(r).unwrap();
            let val2 = ValidatedRequest::from_string_loose(r, || panic!()).unwrap();

            assert_same_json!(val1.as_ref(), val2.as_ref());
            assert_same_json!(val1.as_ref(), r);
        }
    }

    #[test]
    fn bad_requests() {
        // Apart from the first two, every case below is syntactically valid JSON,
        // so each one is rejected for the reason its comment gives.
        for text in [
            // not valid JSON at all (trailing comma).
            r#"{"id":12,}"#,
            // not an object.
            "123",
            // missing most parts.
            r#"{"id":12}"#,
            // no id.
            r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // no params
            r#"{"obj":"hi", "id": 7, "method":"twiddle"}"#,
            // bad params type
            r#"{"obj":"hi", "id": 7, "method":"twiddle", "params": []}"#,
            // weird obj.
            r#"{"obj":7, "id": 7, "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird id.
            r#"{"obj":"hi", "id": [], "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird method
            r#"{"obj":"hi", "id": 7, "method":6, "params":{"stuff":"nonsense"}}"#,
        ] {
            let r: Result<ParsedRequestFields, _> = serde_json::from_str(dbg!(text));
            assert!(r.is_err());
        }
    }

    #[test]
    fn fix_requests() {
        let no_id = r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        let validated = ValidatedRequest::from_string_loose(no_id, || 7.into()).unwrap();
        let expected_with_id =
            r#"{"id": 7, "obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn preserve_fields() {
        let orig = r#"
            {"obj":"hi",
             "meta": { "updates": true, "waffles": "yesplz" },
             "method":"twiddle",
             "params":{"stuff":"nonsense"},
             "explosions": -70
            }"#;
        let validated = ValidatedRequest::from_string_loose(orig, || 77.into()).unwrap();
        let expected_with_id = r#"
            {"id":77,
            "obj":"hi",
            "meta": { "updates": true, "waffles": "yesplz" },
            "method":"twiddle",
            "params":{"stuff":"nonsense"},
            "explosions": -70
            }"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn ok_request_encode() {
        let expected_encoded_request =
            r#"{"obj":"connection","method":"arti:get_rpc_proxy_info","params":"123"}"#;
        let obj_id = ObjectId::connection_id();
        let encoded_request = Request::new(obj_id, "arti:get_rpc_proxy_info", "123")
            .encode()
            .unwrap();
        assert_eq!(expected_encoded_request, encoded_request);
    }

    // This should not be possible
    #[test]
    fn err_request_encode() {
        struct FailingSerialization;

        impl serde::Serialize for FailingSerialization {
            fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                Err(serde::ser::Error::custom(
                    "Intentional serialization failure",
                ))
            }
        }

        let obj_id = ObjectId::connection_id();
        let failing_request = Request::new(obj_id, "arti:get_rpc_proxy_info", FailingSerialization);

        let err = failing_request.encode().unwrap_err();
        assert!(matches!(err, ProtoError::CouldNotEncode(_)));
    }
}