1
#![cfg_attr(docsrs, feature(doc_cfg))]
2
#![doc = include_str!("../README.md")]
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
#![deny(clippy::string_slice)] // See arti#2571
48
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
49

            
50
// TODO probably remove this at some point - see tpo/core/arti#1060
51
#![cfg_attr(
52
    not(all(feature = "full", feature = "experimental")),
53
    allow(unused_imports)
54
)]
55

            
56
mod body;
57
mod err;
58
pub mod request;
59
mod response;
60
mod util;
61

            
62
use tor_circmgr::{CircMgr, DirInfo};
63
use tor_error::bad_api_usage;
64
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
65

            
66
// Zlib is required; the others are optional.
67
#[cfg(feature = "xz")]
68
use async_compression::futures::bufread::XzDecoder;
69
use async_compression::futures::bufread::ZlibDecoder;
70
#[cfg(feature = "zstd")]
71
use async_compression::futures::bufread::ZstdDecoder;
72

            
73
use futures::FutureExt;
74
use futures::io::{
75
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
76
};
77
use memchr::memchr;
78
use std::sync::Arc;
79
use std::time::Duration;
80
use tracing::{info, instrument};
81

            
82
pub use err::{Error, RequestError, RequestFailedError};
83
pub use response::{DirResponse, SourceInfo};
84

            
85
/// Type for results returned in this crate.
86
pub type Result<T> = std::result::Result<T, Error>;
87

            
88
/// Type for internal results containing a [`RequestError`].
89
pub type RequestResult<T> = std::result::Result<T, RequestError>;
90

            
91
/// Declares whether a request must always be anonymized.
///
/// `tor-dirclient` consults this flag to decide whether *other* deanonymizing
/// metadata (for example, request headers) may be added to a request.
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and must never be sent in a way that leaks information about
/// our settings or configuration.
///
/// The *caller* of `tor-dirclient` is responsible for ensuring that
///
///   - every request whose anonymization status is `AnonymizedRequest::Direct`
///     is sent only over non-anonymous connections.
///
///     (Sending an `AnonymizedRequest::Direct` request over an anonymized connection
///     would weaken the connection's anonymity, and can therefore weaken the anonymity
///     of user traffic sharing the same circuit.)
///
///   - every request whose anonymization status is `AnonymizedRequest::Anonymized`
///     is sent only over anonymous connections (ie, multi-hop circuits).
///
///     (Sending an `AnonymizedRequest::Anonymized` request over a direct connection
///     would directly reveal user behaviour data to the directory server.)
///
/// TODO: the calling code cannot easily be sure to get this right, because
/// the anonymization status is a run-time property, whereas the choice of
/// connection kind is statically defined in the calling code.
/// (Perhaps this could be checked in tests?)
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// The content or semantics of this request reveal, or are correlated with,
    /// sensitive information.
    ///
    /// For example, requests for hidden service descriptors reveal which hidden services
    /// the client is connecting to.
    ///
    /// The caller must send the request over an anonymous circuit,
    /// and `tor-dirclient` should not add any deanonymizing information to it.
    /// (For example, no client-version-specific information should be
    /// sent in HTTP headers when the request is made.)
    Anonymized,

    /// Making this request reveals nothing sensitive, and no user behaviour.
    ///
    /// The request body is uncorrelated with such things as the websites the user might visit,
    /// the onion services the user is visiting or running, etc.
    ///
    /// For example, requests for all router microdescriptors are made by all clients,
    /// so which microdescriptor(s) are requested reveals nothing to any attacker.
    ///
    /// `tor-dirclient` is allowed to include information about our capabilities
    /// when sending this request.
    /// The caller must *not* send the request over an anonymous circuit
    /// (at least, not one used for anything else).
    Direct,
}
145

            
146
/// Fetch the resource described by `req` over the Tor network.
147
///
148
/// Circuits are built or found using `circ_mgr`, using paths
149
/// constructed using `dirinfo`.
150
///
151
/// For more fine-grained control over the circuit and stream used,
152
/// construct them yourself, and then call [`send_request`] instead.
153
///
154
/// # TODO
155
///
156
/// This is the only function in this crate that knows about CircMgr and
157
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
158
#[instrument(level = "trace", skip_all)]
159
pub async fn get_resource<CR, R, SP>(
160
    req: &CR,
161
    dirinfo: DirInfo<'_>,
162
    runtime: &SP,
163
    circ_mgr: Arc<CircMgr<R>>,
164
) -> Result<DirResponse>
165
where
166
    CR: request::Requestable + ?Sized,
167
    R: Runtime,
168
    SP: SleepProvider,
169
{
170
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;
171

            
172
    if req.anonymized() == AnonymizedRequest::Anonymized {
173
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
174
    }
175

            
176
    // TODO(nickm) This should be an option, and is too long.
177
    let begin_timeout = Duration::from_secs(5);
178
    let source = match SourceInfo::from_tunnel(&tunnel) {
179
        Ok(source) => source,
180
        Err(e) => {
181
            return Err(Error::RequestFailed(RequestFailedError {
182
                source: None,
183
                error: e.into(),
184
            }));
185
        }
186
    };
187

            
188
    let wrap_err = |error| {
189
        Error::RequestFailed(RequestFailedError {
190
            source: source.clone(),
191
            error,
192
        })
193
    };
194

            
195
    req.check_circuit(&tunnel).await.map_err(wrap_err)?;
196

            
197
    // Launch the stream.
198
    let mut stream = runtime
199
        .timeout(begin_timeout, tunnel.begin_dir_stream())
200
        .await
201
        .map_err(RequestError::from)
202
        .map_err(wrap_err)?
203
        .map_err(RequestError::from)
204
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
205

            
206
    // TODO: Perhaps we want separate timeouts for each phase of this.
207
    // For now, we just use higher-level timeouts in `dirmgr`.
208
    let r = send_request(runtime, req, &mut stream, source.clone()).await;
209

            
210
    if should_retire_circ(&r) {
211
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
212
    }
213

            
214
    r
215
}
216

            
217
/// Return true if `result` holds an error indicating that we should retire the
218
/// circuit used for the corresponding request.
219
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
220
    match result {
221
        Err(e) => e.should_retire_circ(),
222
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
223
    }
224
}
225

            
226
/// Fetch a Tor directory object from a provided stream.
227
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
228
pub async fn download<R, S, SP>(
229
    runtime: &SP,
230
    req: &R,
231
    stream: &mut S,
232
    source: Option<SourceInfo>,
233
) -> Result<DirResponse>
234
where
235
    R: request::Requestable + ?Sized,
236
    S: AsyncRead + AsyncWrite + Send + Unpin,
237
    SP: SleepProvider,
238
{
239
    send_request(runtime, req, stream, source).await
240
}
241

            
242
/// Fetch or upload a Tor directory object using the provided stream.
243
///
244
/// To do this, we send a simple HTTP/1.0 request for the described
245
/// object in `req` over `stream`, and then wait for a response.  In
246
/// log messages, we describe the origin of the data as coming from
247
/// `source`.
248
///
249
/// # Notes
250
///
251
/// It's kind of bogus to have a 'source' field here at all; we may
252
/// eventually want to remove it.
253
///
254
/// This function doesn't close the stream; you may want to do that
255
/// yourself.
256
///
257
/// The only error variant returned is [`Error::RequestFailed`].
258
// TODO: should the error return type change to `RequestFailedError`?
259
// If so, that would simplify some code in_dirmgr::bridgedesc.
260
208
pub async fn send_request<R, S, SP>(
261
208
    runtime: &SP,
262
208
    req: &R,
263
208
    stream: &mut S,
264
208
    source: Option<SourceInfo>,
265
208
) -> Result<DirResponse>
266
208
where
267
208
    R: request::Requestable + ?Sized,
268
208
    S: AsyncRead + AsyncWrite + Send + Unpin,
269
208
    SP: SleepProvider,
270
208
{
271
208
    let wrap_err = |error| {
272
40
        Error::RequestFailed(RequestFailedError {
273
40
            source: source.clone(),
274
40
            error,
275
40
        })
276
40
    };
277

            
278
208
    let partial_ok = req.partial_response_body_ok();
279
208
    let maxlen = req.max_response_len();
280
208
    let anonymized = req.anonymized();
281
208
    let req = req.make_request().map_err(wrap_err)?;
282
208
    let method = req.method().clone();
283
208
    let encoded = util::encode_request(&req);
284

            
285
    // Write the request.
286
384
    for chunk in encoded.iter() {
287
384
        stream
288
384
            .write_all(chunk)
289
384
            .await
290
384
            .map_err(RequestError::from)
291
384
            .map_err(wrap_err)?;
292
    }
293
208
    stream
294
208
        .flush()
295
208
        .await
296
208
        .map_err(RequestError::from)
297
208
        .map_err(wrap_err)?;
298

            
299
208
    let mut buffered = BufReader::new(stream);
300

            
301
    // Handle the response
302
    // TODO: should there be a separate timeout here?
303
208
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
304
170
    if header.status != Some(200) {
305
34
        return Ok(DirResponse::new(
306
34
            method,
307
34
            header.status.unwrap_or(0),
308
34
            header.status_message,
309
34
            None,
310
34
            vec![],
311
34
            source,
312
34
        ));
313
136
    }
314

            
315
136
    let mut decoder =
316
136
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;
317

            
318
136
    let mut result = Vec::new();
319
136
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;
320

            
321
136
    let ok = match (partial_ok, ok, result.len()) {
322
2
        (true, Err(e), n) if n > 0 => {
323
            // Note that we _don't_ return here: we want the partial response.
324
2
            Err(e)
325
        }
326
2
        (_, Err(e), _) => {
327
2
            return Err(wrap_err(e));
328
        }
329
132
        (_, Ok(()), _) => Ok(()),
330
    };
331

            
332
134
    Ok(DirResponse::new(
333
134
        method,
334
134
        200,
335
134
        None,
336
134
        ok.err(),
337
134
        result,
338
134
        source,
339
134
    ))
340
208
}
341

            
342
/// Maximum length for the HTTP headers in a single request or response.
343
///
344
/// Chosen more or less arbitrarily.
345
const MAX_HEADERS_LEN: usize = 16384;
346

            
347
/// Read and parse HTTP/1 headers from `stream`.
348
216
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
349
216
where
350
216
    S: AsyncBufRead + Unpin,
351
216
{
352
216
    let mut buf = Vec::with_capacity(1024);
353

            
354
    loop {
355
        // TODO: it's inefficient to do this a line at a time; it would
356
        // probably be better to read until the CRLF CRLF ending of the
357
        // response.  But this should be fast enough.
358
426
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;
359

            
360
        // TODO(nickm): Better maximum and/or let this expand.
361
394
        let mut headers = [httparse::EMPTY_HEADER; 32];
362
394
        let mut response = httparse::Response::new(&mut headers);
363

            
364
394
        match response.parse(&buf[..])? {
365
            httparse::Status::Partial => {
366
                // We didn't get a whole response; we may need to try again.
367

            
368
218
                if n == 0 {
369
                    // We hit an EOF; no more progress can be made.
370
6
                    return Err(RequestError::TruncatedHeaders);
371
212
                }
372

            
373
212
                if buf.len() >= MAX_HEADERS_LEN {
374
2
                    return Err(RequestError::HeadersTooLong(buf.len()));
375
210
                }
376
            }
377
174
            httparse::Status::Complete(n_parsed) => {
378
174
                if response.code != Some(200) {
379
36
                    return Ok(HeaderStatus {
380
36
                        status: response.code,
381
36
                        status_message: response.reason.map(str::to_owned),
382
36
                        encoding: None,
383
36
                    });
384
138
                }
385
138
                let encoding = if let Some(enc) = response
386
138
                    .headers
387
138
                    .iter()
388
138
                    .find(|h| h.name == "Content-Encoding")
389
                {
390
10
                    Some(String::from_utf8(enc.value.to_vec())?)
391
                } else {
392
128
                    None
393
                };
394
                /*
395
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
396
                    let clen = std::str::from_utf8(clen.value)?;
397
                    length = Some(clen.parse()?);
398
                }
399
                 */
400
138
                assert!(n_parsed == buf.len());
401
138
                return Ok(HeaderStatus {
402
138
                    status: Some(200),
403
138
                    status_message: None,
404
138
                    encoding,
405
138
                });
406
            }
407
        }
408
210
        if n == 0 {
409
            return Err(RequestError::TruncatedHeaders);
410
210
        }
411
    }
412
216
}
413

            
414
/// Summary of a parsed HTTP response header block, as produced by
/// `read_headers`.
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// The HTTP status code, if one was parsed.
    status: Option<u16>,
    /// The HTTP status message that accompanied the status code, if any.
    status_message: Option<String>,
    /// The value of the Content-Encoding header, if any.
    encoding: Option<String>,
}
424

            
425
/// Helper: download directory information from `stream` and
426
/// decompress it into a result buffer.  Assumes that `buf` is empty.
427
///
428
/// If we get more than maxlen bytes after decompression, give an error.
429
///
430
/// Returns the status of our download attempt, stores any data that
431
/// we were able to download into `result`.  Existing contents of
432
/// `result` are overwritten.
433
150
async fn read_and_decompress<S, SP>(
434
150
    runtime: &SP,
435
150
    mut stream: S,
436
150
    maxlen: usize,
437
150
    result: &mut Vec<u8>,
438
150
) -> RequestResult<()>
439
150
where
440
150
    S: AsyncRead + Unpin,
441
150
    SP: SleepProvider,
442
150
{
443
150
    let buffer_window_size = 1024;
444
150
    let mut written_total: usize = 0;
445
    // TODO(nickm): This should be an option, and is maybe too long.
446
    // Though for some users it may be too short?
447
150
    let read_timeout = Duration::from_secs(10);
448
150
    let timer = runtime.sleep(read_timeout).fuse();
449
150
    futures::pin_mut!(timer);
450

            
451
    loop {
452
        // allocate buffer for next read
453
888
        result.resize(written_total + buffer_window_size, 0);
454
888
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];
455

            
456
888
        let status = futures::select! {
457
888
            status = stream.read(buf).fuse() => status,
458
            _ = timer => {
459
                result.resize(written_total, 0); // truncate as needed
460
                return Err(RequestError::DirTimeout);
461
            }
462
        };
463
888
        let written_in_this_loop = match status {
464
882
            Ok(n) => n,
465
6
            Err(other) => {
466
6
                result.resize(written_total, 0); // truncate as needed
467
6
                return Err(other.into());
468
            }
469
        };
470

            
471
882
        written_total += written_in_this_loop;
472

            
473
        // exit conditions below
474

            
475
882
        if written_in_this_loop == 0 {
476
            /*
477
            in case we read less than `buffer_window_size` in last `read`
478
            we need to shrink result because otherwise we'll return those
479
            un-read 0s
480
            */
481
142
            if written_total < result.len() {
482
142
                result.resize(written_total, 0);
483
142
            }
484
142
            return Ok(());
485
740
        }
486

            
487
        // TODO: It would be good to detect compression bombs, but
488
        // that would require access to the internal stream, which
489
        // would in turn require some tricky programming.  For now, we
490
        // use the maximum length here to prevent an attacker from
491
        // filling our RAM.
492
740
        if written_total > maxlen {
493
2
            result.resize(maxlen, 0);
494
2
            return Err(RequestError::ResponseTooLong(written_total));
495
738
        }
496
    }
497
150
}
498

            
499
/// Retire a directory circuit because of an error we've encountered on it.
500
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
501
where
502
    R: Runtime,
503
{
504
    info!(
505
        "{}: Retiring circuit because of directory failure: {}",
506
        &id, &error
507
    );
508
    circ_mgr.retire_circ(id);
509
}
510

            
511
/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
512
///
513
/// Note that this function might not actually read any byte of value
514
/// `byte`, since EOF might occur, or we might fill the buffer.
515
///
516
/// A return value of 0 indicates an end-of-file.
517
432
async fn read_until_limited<S>(
518
432
    stream: &mut S,
519
432
    byte: u8,
520
432
    max: usize,
521
432
    buf: &mut Vec<u8>,
522
432
) -> std::io::Result<usize>
523
432
where
524
432
    S: AsyncBufRead + Unpin,
525
432
{
526
432
    let mut n_added = 0;
527
    loop {
528
590
        let data = stream.fill_buf().await?;
529
558
        if data.is_empty() {
530
            // End-of-file has been reached.
531
12
            return Ok(n_added);
532
546
        }
533
546
        debug_assert!(n_added < max);
534
546
        let remaining_space = max - n_added;
535
546
        let (available, found_byte) = match memchr(byte, data) {
536
374
            Some(idx) => (idx + 1, true),
537
172
            None => (data.len(), false),
538
        };
539
546
        debug_assert!(available >= 1);
540
546
        let n_to_copy = std::cmp::min(remaining_space, available);
541
546
        buf.extend(&data[..n_to_copy]);
542
546
        stream.consume_unpin(n_to_copy);
543
546
        n_added += n_to_copy;
544
546
        if found_byte || n_added == max {
545
388
            return Ok(n_added);
546
158
        }
547
    }
548
432
}
549

            
550
/// Helper: build a `$dec` around the stream `$s`, enable its support for
/// multiple members, and yield it boxed inside `Ok`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut wrapped = $dec::new($s);
        wrapped.multiple_members(true);
        Ok(Box::new(wrapped))
    }};
}
558

            
559
/// Wrap `stream` in an appropriate type to undo the content encoding
560
/// as described in `encoding`.
561
152
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
562
152
    stream: S,
563
152
    encoding: Option<&str>,
564
152
    anonymized: AnonymizedRequest,
565
152
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
566
    use AnonymizedRequest::Direct;
567
152
    match (encoding, anonymized) {
568
148
        (None | Some("identity"), _) => Ok(Box::new(stream)),
569
14
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
570
        // We only admit to supporting these on a direct connection; otherwise,
571
        // a hostile directory could send them back even though we hadn't
572
        // requested them.
573
        #[cfg(feature = "xz")]
574
6
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
575
        #[cfg(feature = "zstd")]
576
4
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
577
2
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
578
    }
579
152
}
580

            
581
#[cfg(test)]
582
mod test {
583
    // @@ begin test lint list maintained by maint/add_warning @@
584
    #![allow(clippy::bool_assert_comparison)]
585
    #![allow(clippy::clone_on_copy)]
586
    #![allow(clippy::dbg_macro)]
587
    #![allow(clippy::mixed_attributes_style)]
588
    #![allow(clippy::print_stderr)]
589
    #![allow(clippy::print_stdout)]
590
    #![allow(clippy::single_char_pattern)]
591
    #![allow(clippy::unwrap_used)]
592
    #![allow(clippy::unchecked_time_subtraction)]
593
    #![allow(clippy::useless_vec)]
594
    #![allow(clippy::needless_pass_by_value)]
595
    #![allow(clippy::string_slice)] // See arti#2571
596
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
597
    use super::*;
598
    use tor_rtmock::io::stream_pair;
599

            
600
    use tor_rtmock::simple_time::SimpleMockTimeProvider;
601
    use web_time_compat::{SystemTime, SystemTimeExt};
602

            
603
    use futures_await_test::async_test;
604

            
605
    #[async_test]
606
    async fn test_read_until_limited() -> RequestResult<()> {
607
        let mut out = Vec::new();
608
        let bytes = b"This line eventually ends\nthen comes another\n";
609

            
610
        // Case 1: find a whole line.
611
        let mut s = &bytes[..];
612
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
613
        assert_eq!(res?, 26);
614
        assert_eq!(&out[..], b"This line eventually ends\n");
615

            
616
        // Case 2: reach the limit.
617
        let mut s = &bytes[..];
618
        out.clear();
619
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
620
        assert_eq!(res?, 10);
621
        assert_eq!(&out[..], b"This line ");
622

            
623
        // Case 3: reach EOF.
624
        let mut s = &bytes[..];
625
        out.clear();
626
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
627
        assert_eq!(res?, 45);
628
        assert_eq!(&out[..], &bytes[..]);
629

            
630
        Ok(())
631
    }
632

            
633
    // Basic decompression wrapper.
634
    async fn decomp_basic(
635
        encoding: Option<&str>,
636
        data: &[u8],
637
        maxlen: usize,
638
    ) -> (RequestResult<()>, Vec<u8>) {
639
        // We don't need to do anything fancy here, since we aren't simulating
640
        // a timeout.
641
        #[allow(deprecated)] // TODO #1885
642
        let mock_time = SimpleMockTimeProvider::from_wallclock(SystemTime::get());
643

            
644
        let mut output = Vec::new();
645
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
646
            Ok(s) => s,
647
            Err(e) => return (Err(e), output),
648
        };
649

            
650
        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
651

            
652
        (r, output)
653
    }
654

            
655
    #[async_test]
656
    async fn decompress_identity() -> RequestResult<()> {
657
        let mut text = Vec::new();
658
        for _ in 0..1000 {
659
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
660
        }
661

            
662
        let limit = 10 << 20;
663
        let (s, r) = decomp_basic(None, &text[..], limit).await;
664
        s?;
665
        assert_eq!(r, text);
666

            
667
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
668
        s?;
669
        assert_eq!(r, text);
670

            
671
        // Try truncated result
672
        let limit = 100;
673
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
674
        assert!(s.is_err());
675
        assert_eq!(r, &text[..100]);
676

            
677
        Ok(())
678
    }
679

            
680
    #[async_test]
681
    async fn decomp_zlib() -> RequestResult<()> {
682
        let compressed =
683
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
684

            
685
        let limit = 10 << 20;
686
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
687
        s?;
688
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");
689

            
690
        Ok(())
691
    }
692

            
693
    /// An "x-zstd" body is decompressed to the expected plaintext.
    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;

        let (status, body) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        status?;
        assert_eq!(body, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }
704

            
705
    /// An "x-tor-lzma" body is decompressed to the expected plaintext.
    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;

        let (status, body) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        status?;
        assert_eq!(body, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }
717

            
718
    #[async_test]
719
    async fn decomp_unknown() {
720
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
721
        let limit = 10 << 20;
722
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
723

            
724
        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
725
    }
726

            
727
    #[async_test]
728
    async fn decomp_bad_data() {
729
        let compressed = b"This is not good zlib data";
730
        let limit = 10 << 20;
731
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
732

            
733
        // This should possibly be a different type in the future.
734
        assert!(matches!(s, Err(RequestError::IoError(_))));
735
    }
736

            
737
    #[async_test]
738
    async fn headers_ok() -> RequestResult<()> {
739
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
740

            
741
        let mut s = &text[..];
742
        let h = read_headers(&mut s).await?;
743

            
744
        assert_eq!(h.status, Some(200));
745
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));
746

            
747
        // now try truncated
748
        let mut s = &text[..15];
749
        let h = read_headers(&mut s).await;
750
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
751

            
752
        // now try with no encoding.
753
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
754
        let mut s = &text[..];
755
        let h = read_headers(&mut s).await?;
756

            
757
        assert_eq!(h.status, Some(404));
758
        assert!(h.encoding.is_none());
759

            
760
        Ok(())
761
    }
762

            
763
    #[async_test]
764
    async fn headers_bogus() -> Result<()> {
765
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
766
        let mut s = &text[..];
767
        let h = read_headers(&mut s).await;
768

            
769
        assert!(h.is_err());
770
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
771
        Ok(())
772
    }
773

            
774
    /// Run a trivial download example with a response provided as a binary
775
    /// string.
776
    ///
777
    /// Return the directory response (if any) and the request as encoded (if
778
    /// any.)
779
    fn run_download_test<Req: request::Requestable>(
780
        req: Req,
781
        response: &[u8],
782
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
783
        let (mut s1, s2) = stream_pair();
784
        let (mut s2_r, mut s2_w) = s2.split();
785

            
786
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
787
            let rt2 = rt.clone();
788
            let (v1, v2, v3): (
789
                Result<DirResponse>,
790
                RequestResult<Vec<u8>>,
791
                RequestResult<()>,
792
            ) = futures::join!(
793
                async {
794
                    // Run the download function.
795
                    let r = send_request(&rt, &req, &mut s1, None).await;
796
                    s1.close().await.map_err(|error| {
797
                        Error::RequestFailed(RequestFailedError {
798
                            source: None,
799
                            error: error.into(),
800
                        })
801
                    })?;
802
                    r
803
                },
804
                async {
805
                    // Take the request from the client, and return it in "v2"
806
                    let mut v = Vec::new();
807
                    s2_r.read_to_end(&mut v).await?;
808
                    Ok(v)
809
                },
810
                async {
811
                    // Send back a response.
812
                    s2_w.write_all(response).await?;
813
                    // We wait a moment to give the other side time to notice it
814
                    // has data.
815
                    //
816
                    // (Tentative diagnosis: The `async-compress` crate seems to
817
                    // be behave differently depending on whether the "close"
818
                    // comes right after the incomplete data or whether it comes
819
                    // after a delay.  If there's a delay, it notices the
820
                    // truncated data and tells us about it. But when there's
821
                    // _no_delay, it treats the data as an error and doesn't
822
                    // tell our code.)
823

            
824
                    // TODO: sleeping in tests is not great.
825
                    rt2.sleep(Duration::from_millis(50)).await;
826
                    s2_w.close().await?;
827
                    Ok(())
828
                }
829
            );
830

            
831
            assert!(v3.is_ok());
832

            
833
            (v1, v2)
834
        })
835
    }
836

            
837
    /// A plain, uncompressed 200 response: check the request that went out on
    /// the wire and every accessor of the resulting response.
    #[test]
    fn test_send_request() -> RequestResult<()> {
        let expected_body = b"This is where the descs would go.";
        let req: request::MicrodescRequest = std::iter::once([9; 32]).collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        // The encoded request must open with the expected request line.
        let sent = request?;
        let expected_prefix: &[u8] =
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n";
        assert!(sent.starts_with(expected_prefix));

        // The response is complete, error-free, and carries the body as-is.
        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let borrowed = response.output_unchecked();
        assert_eq!(borrowed, expected_body);
        let owned = response.into_output_unchecked();
        assert_eq!(&owned, expected_body);

        Ok(())
    }
863

            
864
    /// A deflate-encoded body whose second stream is cut short: the download
    /// fails outright for a single-item request, and yields a partial
    /// response for a two-item request.
    #[test]
    fn test_download_truncated() {
        let mut wire = b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n".to_vec();
        // The "One fish two fish" stream twice; the second copy has its last
        // few bytes removed.
        for stream_hex in [
            "789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88",
            "789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5",
        ] {
            wire.extend(hex::decode(stream_hex).unwrap());
        }

        // Request only one md, so "partial ok" will not be set.
        let single: request::MicrodescRequest = std::iter::once([9; 32]).collect();
        let (response, request) = run_download_test(single, &wire);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // Request two microdescs, so "partial_ok" will be set.
        let double: request::MicrodescRequest = std::iter::repeat([9; 32]).take(2).collect();
        let (response, request) = run_download_test(double, &wire);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        let body = response.output_unchecked();
        assert!(body.len() < 37 * 2);
        assert!(body.starts_with(b"One fish"));
    }
894

            
895
    /// A non-200 status is reported through `status_code()` on a successful
    /// `DirResponse` rather than as an error.
    ///
    /// (Despite the test's name, the canned status used here is 418.)
    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = std::iter::once([9; 32]).collect();
        let (response, _request) = run_download_test(req, b"HTTP/1.0 418 I'm a teapot\r\n\r\n");

        let status = response.unwrap().status_code();
        assert_eq!(status, 418);
    }
903

            
904
    /// Responses that end before the header block is terminated must fail
    /// with `RequestError::TruncatedHeaders`.
    #[test]
    fn test_headers_truncated() {
        let cases: [&[u8]; 2] = [
            // A status line with no blank line after it.
            b"HTTP/1.0 404 truncation happens here\r\n",
            // A completely empty response.
            b"",
        ];

        for response_text in cases {
            let req: request::MicrodescRequest = std::iter::once([9; 32]).collect();
            let (response, _request) = run_download_test(req, response_text);

            assert!(matches!(
                response,
                Err(Error::RequestFailed(RequestFailedError {
                    error: RequestError::TruncatedHeaders,
                    ..
                }))
            ));
        }
    }
931

            
932
    /// A header padded with `A`s up to 16384 bytes, with no terminator, must
    /// be rejected as `HeadersTooLong`, and that error must ask for the
    /// circuit to be retired.
    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = std::iter::once([9; 32]).collect();

        // Start an unterminated header line, then pad it out with 'A' bytes.
        let mut oversized = b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ".to_vec();
        oversized.resize(16384, b'A');

        let (response, _request) = run_download_test(req, &oversized);

        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }
948

            
949
    // TODO: test with bad utf-8
950
}