1
#![cfg_attr(docsrs, feature(doc_cfg))]
2
#![doc = include_str!("../README.md")]
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48

            
49
// TODO probably remove this at some point - see tpo/core/arti#1060
50
#![cfg_attr(
51
    not(all(feature = "full", feature = "experimental")),
52
    allow(unused_imports)
53
)]
54

            
55
mod err;
56
pub mod request;
57
mod response;
58
mod util;
59

            
60
use tor_circmgr::{CircMgr, DirInfo};
61
use tor_error::bad_api_usage;
62
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
63

            
64
// Zlib is required; the others are optional.
65
#[cfg(feature = "xz")]
66
use async_compression::futures::bufread::XzDecoder;
67
use async_compression::futures::bufread::ZlibDecoder;
68
#[cfg(feature = "zstd")]
69
use async_compression::futures::bufread::ZstdDecoder;
70

            
71
use futures::FutureExt;
72
use futures::io::{
73
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
74
};
75
use memchr::memchr;
76
use std::sync::Arc;
77
use std::time::Duration;
78
use tracing::{info, instrument};
79

            
80
pub use err::{Error, RequestError, RequestFailedError};
81
pub use response::{DirResponse, SourceInfo};
82

            
83
/// Type for results returned in this crate.
///
/// The error type is this crate's top-level [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
85

            
86
/// Type for internal results containing a [`RequestError`].
pub type RequestResult<T> = std::result::Result<T, RequestError>;
88

            
89
/// Flag to declare whether a request is always anonymized or not.
///
/// This is used by tor-dirclient to control whether *other* deanonymizing metadata
/// might be added to the request (eg in request headers):
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
///
/// It is up to the *caller* of `tor-dirclient` to ensure that
///
///   - every request whose anonymization status is `AnonymizedRequest::Direct`
///     is sent only over non-anonymous connections.
///
///     (Sending an `AnonymizedRequest::Direct` request over an anonymized connection
///     would weaken the connection's anonymity, and can therefore weaken the anonymity
///     of user traffic sharing the same circuit.)
///
///   - every request whose anonymization status is `AnonymizedRequest::Anonymized`
///     is sent over only anonymous connections (ie, multi-hop circuits).
///
///     (Sending an `AnonymizedRequest::Anonymized` request over a direct connection
///     would directly reveal user behaviour data to the directory server.)
///
/// TODO the calling code cannot easily be sure to get this right this because
/// the anonymization status is a run-time property and the choice of connection kind
/// is statically defined in the calling code.  (Perhaps this could be checked in tests?)
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request's content or semantics reveals or is correlated with sensitive information.
    ///
    /// For example, requests for hidden service descriptors reveal which hidden services
    /// the client is connecting to.
    ///
    /// The request must be sent over an anonymous circuit by the caller
    /// and no additional deanonymizing information should be added to it by `tor-dirclient`.
    /// (For example, no client-version-specific information should be
    /// sent in HTTP headers when the request is made.)
    Anonymized,

    /// Making this request does not reveal anything sensitive, nor any user behaviour.
    ///
    /// The request body is uncorrelated with such things as the websites the user might visit,
    /// the onion services the user is visiting or running, etc.
    ///
    /// For example, requests for all router microdescriptors are made by all clients,
    /// so which microdescriptor(s) are requested reveals nothing to any attacker.
    ///
    /// tor-dirclient is allowed to include information about our capabilities
    /// when sending this request.
    /// The request must *not* be sent over an anonymous circuit by the caller
    /// (at least, not one used for anything else).
    Direct,
}
143

            
144
/// Fetch the resource described by `req` over the Tor network.
145
///
146
/// Circuits are built or found using `circ_mgr`, using paths
147
/// constructed using `dirinfo`.
148
///
149
/// For more fine-grained control over the circuit and stream used,
150
/// construct them yourself, and then call [`send_request`] instead.
151
///
152
/// # TODO
153
///
154
/// This is the only function in this crate that knows about CircMgr and
155
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
156
#[instrument(level = "trace", skip_all)]
157
pub async fn get_resource<CR, R, SP>(
158
    req: &CR,
159
    dirinfo: DirInfo<'_>,
160
    runtime: &SP,
161
    circ_mgr: Arc<CircMgr<R>>,
162
) -> Result<DirResponse>
163
where
164
    CR: request::Requestable + ?Sized,
165
    R: Runtime,
166
    SP: SleepProvider,
167
{
168
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;
169

            
170
    if req.anonymized() == AnonymizedRequest::Anonymized {
171
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
172
    }
173

            
174
    // TODO(nickm) This should be an option, and is too long.
175
    let begin_timeout = Duration::from_secs(5);
176
    let source = match SourceInfo::from_tunnel(&tunnel) {
177
        Ok(source) => source,
178
        Err(e) => {
179
            return Err(Error::RequestFailed(RequestFailedError {
180
                source: None,
181
                error: e.into(),
182
            }));
183
        }
184
    };
185

            
186
    let wrap_err = |error| {
187
        Error::RequestFailed(RequestFailedError {
188
            source: source.clone(),
189
            error,
190
        })
191
    };
192

            
193
    req.check_circuit(&tunnel).await.map_err(wrap_err)?;
194

            
195
    // Launch the stream.
196
    let mut stream = runtime
197
        .timeout(begin_timeout, tunnel.begin_dir_stream())
198
        .await
199
        .map_err(RequestError::from)
200
        .map_err(wrap_err)?
201
        .map_err(RequestError::from)
202
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
203

            
204
    // TODO: Perhaps we want separate timeouts for each phase of this.
205
    // For now, we just use higher-level timeouts in `dirmgr`.
206
    let r = send_request(runtime, req, &mut stream, source.clone()).await;
207

            
208
    if should_retire_circ(&r) {
209
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
210
    }
211

            
212
    r
213
}
214

            
215
/// Return true if `result` holds an error indicating that we should retire the
216
/// circuit used for the corresponding request.
217
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
218
    match result {
219
        Err(e) => e.should_retire_circ(),
220
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
221
    }
222
}
223

            
224
/// Fetch a Tor directory object from a provided stream.
///
/// Deprecated: this is a thin forwarding wrapper around [`send_request`];
/// call that function directly instead.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}
239

            
240
/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in_dirmgr::bridgedesc.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    // Helper: wrap a low-level RequestError into the one error variant we
    // return, annotated with `source`.
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    // Capture the request's parameters before `req` is consumed below.
    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let method = req.method().clone();
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        // A non-200 status is reported to the caller as a DirResponse with an
        // empty body, not as an Err at this level.  (A missing status code is
        // encoded as 0.)
        return Ok(DirResponse::new(
            method,
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    // Decide whether a body-read failure is fatal: if the request tolerates
    // partial bodies and we got *some* data, keep the data and record the
    // error inside the DirResponse instead of failing outright.
    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    Ok(DirResponse::new(
        method,
        200,
        None,
        ok.err(),
        result,
        source,
    ))
}
337

            
338
/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.  `read_headers` gives up with
/// `RequestError::HeadersTooLong` once its buffer reaches this size.
const MAX_HEADERS_LEN: usize = 16384;
342

            
343
/// Read and parse HTTP/1 headers from `stream`.
///
/// Accumulates input a line at a time, re-parsing the whole buffer after
/// each line until `httparse` reports a complete response.  Returns a
/// [`HeaderStatus`] summarizing the status code, status message, and
/// Content-Encoding (the latter only for 200 responses).
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                // Cap the total header size we are willing to buffer.
                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    // Non-200: no need to look at Content-Encoding, since we
                    // won't be reading a body.
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                // We read one line at a time, so the parse should consume
                // exactly what we have buffered.
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        // NOTE(review): this check appears unreachable — the Partial arm
        // above already returns on n == 0, and the Complete arm always
        // returns.  Kept for safety; confirm before removing.
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}
409

            
410
/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code.
    ///
    /// `None` when the parser reported no code; `send_request` renders that
    /// case as status 0.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    ///
    /// Only populated for non-200 responses.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    ///
    /// Only populated for 200 responses.
    encoding: Option<String>,
}
420

            
421
/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `buf` is empty.
///
/// If we get more than maxlen bytes after decompression, give an error.
///
/// Returns the status of our download attempt, stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
///
/// The whole download shares a single wall-clock timeout; on timeout or
/// I/O error, `result` is truncated to the bytes actually received.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    // Read in fixed-size windows, growing `result` ahead of each read and
    // reading directly into its tail to avoid an intermediate buffer.
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        // Race the next read against the overall download timer.
        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}
494

            
495
/// Retire a directory circuit because of an error we've encountered on it.
496
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
497
where
498
    R: Runtime,
499
{
500
    info!(
501
        "{}: Retiring circuit because of directory failure: {}",
502
        &id, &error
503
    );
504
    circ_mgr.retire_circ(id);
505
}
506

            
507
/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
508
///
509
/// Note that this function might not actually read any byte of value
510
/// `byte`, since EOF might occur, or we might fill the buffer.
511
///
512
/// A return value of 0 indicates an end-of-file.
513
408
async fn read_until_limited<S>(
514
408
    stream: &mut S,
515
408
    byte: u8,
516
408
    max: usize,
517
408
    buf: &mut Vec<u8>,
518
408
) -> std::io::Result<usize>
519
408
where
520
408
    S: AsyncBufRead + Unpin,
521
408
{
522
408
    let mut n_added = 0;
523
    loop {
524
566
        let data = stream.fill_buf().await?;
525
534
        if data.is_empty() {
526
            // End-of-file has been reached.
527
12
            return Ok(n_added);
528
522
        }
529
522
        debug_assert!(n_added < max);
530
522
        let remaining_space = max - n_added;
531
522
        let (available, found_byte) = match memchr(byte, data) {
532
350
            Some(idx) => (idx + 1, true),
533
172
            None => (data.len(), false),
534
        };
535
522
        debug_assert!(available >= 1);
536
522
        let n_to_copy = std::cmp::min(remaining_space, available);
537
522
        buf.extend(&data[..n_to_copy]);
538
522
        stream.consume_unpin(n_to_copy);
539
522
        n_added += n_to_copy;
540
522
        if found_byte || n_added == max {
541
364
            return Ok(n_added);
542
158
        }
543
    }
544
408
}
545

            
546
/// Helper: Return a boxed decoder object that wraps the stream  $s.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        // Keep decoding across concatenated compressed members rather than
        // stopping at the first end-of-stream marker.
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}
554

            
555
/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
///
/// Returns `RequestError::ContentEncoding` for any encoding we do not
/// support for this kind of request.
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        // No encoding (or the identity encoding): pass the stream through.
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        // Zlib is accepted for every request kind.
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        // Anything else (including xz/zstd on an anonymized request) is
        // rejected as an unsupported content encoding.
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}
576

            
577
#[cfg(test)]
578
mod test {
579
    // @@ begin test lint list maintained by maint/add_warning @@
580
    #![allow(clippy::bool_assert_comparison)]
581
    #![allow(clippy::clone_on_copy)]
582
    #![allow(clippy::dbg_macro)]
583
    #![allow(clippy::mixed_attributes_style)]
584
    #![allow(clippy::print_stderr)]
585
    #![allow(clippy::print_stdout)]
586
    #![allow(clippy::single_char_pattern)]
587
    #![allow(clippy::unwrap_used)]
588
    #![allow(clippy::unchecked_time_subtraction)]
589
    #![allow(clippy::useless_vec)]
590
    #![allow(clippy::needless_pass_by_value)]
591
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
592
    use super::*;
593
    use tor_rtmock::io::stream_pair;
594

            
595
    use tor_rtmock::simple_time::SimpleMockTimeProvider;
596
    use web_time_compat::{SystemTime, SystemTimeExt};
597

            
598
    use futures_await_test::async_test;
599

            
600
    #[async_test]
601
    async fn test_read_until_limited() -> RequestResult<()> {
602
        let mut out = Vec::new();
603
        let bytes = b"This line eventually ends\nthen comes another\n";
604

            
605
        // Case 1: find a whole line.
606
        let mut s = &bytes[..];
607
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
608
        assert_eq!(res?, 26);
609
        assert_eq!(&out[..], b"This line eventually ends\n");
610

            
611
        // Case 2: reach the limit.
612
        let mut s = &bytes[..];
613
        out.clear();
614
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
615
        assert_eq!(res?, 10);
616
        assert_eq!(&out[..], b"This line ");
617

            
618
        // Case 3: reach EOF.
619
        let mut s = &bytes[..];
620
        out.clear();
621
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
622
        assert_eq!(res?, 45);
623
        assert_eq!(&out[..], &bytes[..]);
624

            
625
        Ok(())
626
    }
627

            
628
    // Basic decompression wrapper.
629
    async fn decomp_basic(
630
        encoding: Option<&str>,
631
        data: &[u8],
632
        maxlen: usize,
633
    ) -> (RequestResult<()>, Vec<u8>) {
634
        // We don't need to do anything fancy here, since we aren't simulating
635
        // a timeout.
636
        #[allow(deprecated)] // TODO #1885
637
        let mock_time = SimpleMockTimeProvider::from_wallclock(SystemTime::get());
638

            
639
        let mut output = Vec::new();
640
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
641
            Ok(s) => s,
642
            Err(e) => return (Err(e), output),
643
        };
644

            
645
        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
646

            
647
        (r, output)
648
    }
649

            
650
    #[async_test]
651
    async fn decompress_identity() -> RequestResult<()> {
652
        let mut text = Vec::new();
653
        for _ in 0..1000 {
654
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
655
        }
656

            
657
        let limit = 10 << 20;
658
        let (s, r) = decomp_basic(None, &text[..], limit).await;
659
        s?;
660
        assert_eq!(r, text);
661

            
662
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
663
        s?;
664
        assert_eq!(r, text);
665

            
666
        // Try truncated result
667
        let limit = 100;
668
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
669
        assert!(s.is_err());
670
        assert_eq!(r, &text[..100]);
671

            
672
        Ok(())
673
    }
674

            
675
    #[async_test]
676
    async fn decomp_zlib() -> RequestResult<()> {
677
        let compressed =
678
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
679

            
680
        let limit = 10 << 20;
681
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
682
        s?;
683
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");
684

            
685
        Ok(())
686
    }
687

            
688
    #[cfg(feature = "zstd")]
689
    #[async_test]
690
    async fn decomp_zstd() -> RequestResult<()> {
691
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
692
        let limit = 10 << 20;
693
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
694
        s?;
695
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
696

            
697
        Ok(())
698
    }
699

            
700
    #[cfg(feature = "xz")]
701
    #[async_test]
702
    async fn decomp_xz2() -> RequestResult<()> {
703
        // Not so good at tiny files...
704
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
705
        let limit = 10 << 20;
706
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
707
        s?;
708
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
709

            
710
        Ok(())
711
    }
712

            
713
    #[async_test]
714
    async fn decomp_unknown() {
715
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
716
        let limit = 10 << 20;
717
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
718

            
719
        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
720
    }
721

            
722
    #[async_test]
723
    async fn decomp_bad_data() {
724
        let compressed = b"This is not good zlib data";
725
        let limit = 10 << 20;
726
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
727

            
728
        // This should possibly be a different type in the future.
729
        assert!(matches!(s, Err(RequestError::IoError(_))));
730
    }
731

            
732
    #[async_test]
733
    async fn headers_ok() -> RequestResult<()> {
734
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
735

            
736
        let mut s = &text[..];
737
        let h = read_headers(&mut s).await?;
738

            
739
        assert_eq!(h.status, Some(200));
740
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));
741

            
742
        // now try truncated
743
        let mut s = &text[..15];
744
        let h = read_headers(&mut s).await;
745
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
746

            
747
        // now try with no encoding.
748
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
749
        let mut s = &text[..];
750
        let h = read_headers(&mut s).await?;
751

            
752
        assert_eq!(h.status, Some(404));
753
        assert!(h.encoding.is_none());
754

            
755
        Ok(())
756
    }
757

            
758
    #[async_test]
759
    async fn headers_bogus() -> Result<()> {
760
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
761
        let mut s = &text[..];
762
        let h = read_headers(&mut s).await;
763

            
764
        assert!(h.is_err());
765
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
766
        Ok(())
767
    }
768

            
769
    /// Run a trivial download example with a response provided as a binary
770
    /// string.
771
    ///
772
    /// Return the directory response (if any) and the request as encoded (if
773
    /// any.)
774
    fn run_download_test<Req: request::Requestable>(
775
        req: Req,
776
        response: &[u8],
777
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
778
        let (mut s1, s2) = stream_pair();
779
        let (mut s2_r, mut s2_w) = s2.split();
780

            
781
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
782
            let rt2 = rt.clone();
783
            let (v1, v2, v3): (
784
                Result<DirResponse>,
785
                RequestResult<Vec<u8>>,
786
                RequestResult<()>,
787
            ) = futures::join!(
788
                async {
789
                    // Run the download function.
790
                    let r = send_request(&rt, &req, &mut s1, None).await;
791
                    s1.close().await.map_err(|error| {
792
                        Error::RequestFailed(RequestFailedError {
793
                            source: None,
794
                            error: error.into(),
795
                        })
796
                    })?;
797
                    r
798
                },
799
                async {
800
                    // Take the request from the client, and return it in "v2"
801
                    let mut v = Vec::new();
802
                    s2_r.read_to_end(&mut v).await?;
803
                    Ok(v)
804
                },
805
                async {
806
                    // Send back a response.
807
                    s2_w.write_all(response).await?;
808
                    // We wait a moment to give the other side time to notice it
809
                    // has data.
810
                    //
811
                    // (Tentative diagnosis: The `async-compress` crate seems to
812
                    // be behave differently depending on whether the "close"
813
                    // comes right after the incomplete data or whether it comes
814
                    // after a delay.  If there's a delay, it notices the
815
                    // truncated data and tells us about it. But when there's
816
                    // _no_delay, it treats the data as an error and doesn't
817
                    // tell our code.)
818

            
819
                    // TODO: sleeping in tests is not great.
820
                    rt2.sleep(Duration::from_millis(50)).await;
821
                    s2_w.close().await?;
822
                    Ok(())
823
                }
824
            );
825

            
826
            assert!(v3.is_ok());
827

            
828
            (v1, v2)
829
        })
830
    }
831

            
832
    #[test]
    fn test_send_request() -> RequestResult<()> {
        // Ask for a single microdescriptor with an all-nines digest.
        let md_req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            md_req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        // The encoded request must start with a GET for the base64'd digest.
        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        // The response should be a complete, successful 200 whose body we can
        // access both borrowed and owned.
        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let borrowed = response.output_unchecked();
        assert_eq!(borrowed, b"This is where the descs would go.");
        let owned = response.into_output_unchecked();
        assert_eq!(&owned, b"This is where the descs would go.");

        Ok(())
    }
    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let single: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut body: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        body.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        body.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(single, &body);
        assert!(request.is_ok());
        // The whole download should fail, since partial_ok wasn't set.
        assert!(response.is_err());

        // Request two microdescs, so "partial_ok" will be set.
        let double: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(double, &body);
        assert!(request.is_ok());

        // With partial_ok set we get back the intact first document, flagged
        // as a partial (errored) response.
        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }
    #[test]
    fn test_404() {
        // A non-200 status is simply reported back to the caller.
        let md_req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let reply = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(md_req, reply);

        assert_eq!(response.unwrap().status_code(), 418);
    }
    #[test]
    fn test_headers_truncated() {
        // Headers that stop before the terminating blank line should be
        // reported as truncated.
        let md_req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let reply = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(md_req, reply);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let md_req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let reply = b"";
        let (response, _request) = run_download_test(md_req, reply);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }
    #[test]
    fn test_headers_too_long() {
        // Pad a single header out to 16 KiB so it exceeds the parser's limit.
        let md_req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut reply: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        reply.resize(16384, b'A');
        let (response, _request) = run_download_test(md_req, &reply);

        // Oversized headers are grounds for retiring the circuit.
        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}