1
#![cfg_attr(docsrs, feature(doc_cfg))]
2
#![doc = include_str!("../README.md")]
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48

            
49
// TODO probably remove this at some point - see tpo/core/arti#1060
50
#![cfg_attr(
51
    not(all(feature = "full", feature = "experimental")),
52
    allow(unused_imports)
53
)]
54

            
55
mod body;
56
mod err;
57
pub mod request;
58
mod response;
59
mod util;
60

            
61
use tor_circmgr::{CircMgr, DirInfo};
62
use tor_error::bad_api_usage;
63
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
64

            
65
// Zlib is required; the others are optional.
66
#[cfg(feature = "xz")]
67
use async_compression::futures::bufread::XzDecoder;
68
use async_compression::futures::bufread::ZlibDecoder;
69
#[cfg(feature = "zstd")]
70
use async_compression::futures::bufread::ZstdDecoder;
71

            
72
use futures::FutureExt;
73
use futures::io::{
74
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
75
};
76
use memchr::memchr;
77
use std::sync::Arc;
78
use std::time::Duration;
79
use tracing::{info, instrument};
80

            
81
pub use err::{Error, RequestError, RequestFailedError};
82
pub use response::{DirResponse, SourceInfo};
83

            
84
/// Type for results returned in this crate.
///
/// Shorthand for a `std::result::Result` whose error type is this crate's
/// [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
86

            
87
/// Type for internal results containing a [`RequestError`].
///
/// Used by the lower-level helpers in this crate before the error is
/// wrapped into a full [`Error::RequestFailed`].
pub type RequestResult<T> = std::result::Result<T, RequestError>;
89

            
90
/// Flag to declare whether a request is always anonymized or not.
///
/// This is used by tor-dirclient to control whether *other* deanonymizing metadata
/// might be added to the request (eg in request headers):
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
///
/// It is up to the *caller* of `tor-dirclient` to ensure that
///
///   - every request whose anonymization status is `AnonymizedRequest::Direct`
///     is sent only over non-anonymous connections.
///
///     (Sending an `AnonymizedRequest::Direct` request over an anonymized connection
///     would weaken the connection's anonymity, and can therefore weaken the anonymity
///     of user traffic sharing the same circuit.)
///
///   - every request whose anonymization status is `AnonymizedRequest::Anonymized`
///     is sent over only anonymous connections (ie, multi-hop circuits).
///
///     (Sending an `AnonymizedRequest::Anonymized` request over a direct connection
///     would directly reveal user behaviour data to the directory server.)
///
/// TODO the calling code cannot easily be sure to get this right this because
/// the anonymization status is a run-time property and the choice of connection kind
/// is statically defined in the calling code.  (Perhaps this could be checked in tests?)
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request's content or semantics reveals or is correlated with sensitive information.
    ///
    /// For example, requests for hidden service descriptors reveal which hidden services
    /// the client is connecting to.
    ///
    /// The request must be sent over an anonymous circuit by the caller
    /// and no additional deanonymizing information should be added to it by `tor-dirclient`.
    /// (For example, no client-version-specific information should be
    /// sent in HTTP headers when the request is made.)
    Anonymized,

    /// Making this request does not reveal anything sensitive, nor any user behaviour.
    ///
    /// The request body is uncorrelated with such things as the websites the user might visit,
    /// the onion services the user is visiting or running, etc.
    ///
    /// For example, requests for all router microdescriptors are made by all clients,
    /// so which microdescriptor(s) are requested reveals nothing to any attacker.
    ///
    /// tor-dirclient is allowed to include information about our capabilities
    /// when sending this request.
    /// The request must *not* be sent over an anonymous circuit by the caller
    /// (at least, not one used for anything else).
    Direct,
}
144

            
145
/// Fetch the resource described by `req` over the Tor network.
///
/// Circuits are built or found using `circ_mgr`, using paths
/// constructed using `dirinfo`.
///
/// For more fine-grained control over the circuit and stream used,
/// construct them yourself, and then call [`send_request`] instead.
///
/// # Errors
///
/// Returns an error if an anonymized request is passed in (this entry
/// point is for direct directory connections only), if a circuit or
/// stream cannot be set up, or if [`send_request`] itself fails.
///
/// # TODO
///
/// This is the only function in this crate that knows about CircMgr and
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
#[instrument(level = "trace", skip_all)]
pub async fn get_resource<CR, R, SP>(
    req: &CR,
    dirinfo: DirInfo<'_>,
    runtime: &SP,
    circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
    CR: request::Requestable + ?Sized,
    R: Runtime,
    SP: SleepProvider,
{
    // Find or build a one-hop directory circuit/tunnel for this request.
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;

    // Anonymized requests must never go over a direct directory tunnel:
    // reject them as an API-usage error rather than leaking anything.
    if req.anonymized() == AnonymizedRequest::Anonymized {
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
    }

    // TODO(nickm) This should be an option, and is too long.
    let begin_timeout = Duration::from_secs(5);
    // Describe where the response will have come from; if we can't even do
    // that, report the failure with no source attached.
    let source = match SourceInfo::from_tunnel(&tunnel) {
        Ok(source) => source,
        Err(e) => {
            return Err(Error::RequestFailed(RequestFailedError {
                source: None,
                error: e.into(),
            }));
        }
    };

    // Helper: attach `source` to a RequestError so callers know which
    // directory server the failure involves.
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    // Let the request object veto the circuit (e.g. if it is unsuitable).
    req.check_circuit(&tunnel).await.map_err(wrap_err)?;

    // Launch the stream.  The double map_err handles the two layers of
    // fallibility: the timeout itself, and the stream-opening result.
    let mut stream = runtime
        .timeout(begin_timeout, tunnel.begin_dir_stream())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?
        .map_err(RequestError::from)
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too

    // TODO: Perhaps we want separate timeouts for each phase of this.
    // For now, we just use higher-level timeouts in `dirmgr`.
    let r = send_request(runtime, req, &mut stream, source.clone()).await;

    // If the outcome indicates circuit trouble, stop reusing this circuit.
    if should_retire_circ(&r) {
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
    }

    r
}
215

            
216
/// Return true if `result` holds an error indicating that we should retire the
217
/// circuit used for the corresponding request.
218
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
219
    match result {
220
        Err(e) => e.should_retire_circ(),
221
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
222
    }
223
}
224

            
225
/// Fetch a Tor directory object from a provided stream.
///
/// Deprecated: this is a thin forwarder that passes all of its arguments
/// unchanged to [`send_request`]; call that function directly instead.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}
240

            
241
/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in_dirmgr::bridgedesc.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    // Helper: tag any RequestError with the `source` of this exchange.
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    // Capture everything we need from the Requestable *before* `req` is
    // shadowed by the encoded request below.
    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    // Keep the method around: DirResponse::new wants it for both the
    // error and success paths.
    let method = req.method().clone();
    let encoded = util::encode_request(&req);

    // Write the request.
    for chunk in encoded.iter() {
        stream
            .write_all(chunk)
            .await
            .map_err(RequestError::from)
            .map_err(wrap_err)?;
    }
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    // Non-200 responses carry no body we care about: return them as-is,
    // with 0 standing in for "no status code parsed".
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            method,
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    // Pick a decompressor matching the Content-Encoding header (if any).
    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    // Decide what to do with a partially-downloaded body: keep it (with
    // the error recorded in the response) only when the request said
    // partial bodies are acceptable AND we actually got some bytes.
    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    // Success (or accepted-partial): hand back the body plus any
    // non-fatal error that occurred while reading it.
    Ok(DirResponse::new(
        method,
        200,
        None,
        ok.err(),
        result,
        source,
    ))
}
340

            
341
/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.  Exceeding this while parsing a
/// response yields `RequestError::HeadersTooLong`.
const MAX_HEADERS_LEN: usize = 16384;
345

            
346
/// Read and parse HTTP/1 headers from `stream`.
///
/// Reads one line at a time (up to 2048 bytes per read) and re-runs the
/// `httparse` parser over the accumulated buffer until the header block
/// is complete, EOF truncates it, or it grows past [`MAX_HEADERS_LEN`].
///
/// On a non-200 status, only the status code and reason phrase are
/// returned; the `Content-Encoding` header is extracted only for 200s.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                // Refuse to buffer an unbounded amount of header data.
                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                // Pull out Content-Encoding so the caller can pick a
                // matching decompressor.  Non-UTF-8 values are an error.
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                // Since we read only up to each '\n', a complete parse
                // should have consumed exactly everything buffered.
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        // NOTE(review): this check appears unreachable — the Partial arm
        // above already returns when n == 0, and the Complete arm always
        // returns.  Kept as defensive code.
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}
412

            
413
/// Return value from read_headers.
///
/// Summarizes the parsed HTTP/1 response header block.
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code, or `None` if no code could be parsed.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    /// (Only populated for non-200 responses.)
    status_message: Option<String>,
    /// The Content-Encoding header, if any. (Only extracted for 200s.)
    encoding: Option<String>,
}
423

            
424
/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `buf` is empty.
///
/// If we get more than maxlen bytes after decompression, give an error.
///
/// Returns the status of our download attempt, stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
///
/// A single overall read timeout (10 s) covers the whole download; on
/// timeout or I/O error, `result` is truncated to the bytes actually
/// received and the error is returned.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    // Grow `result` in fixed-size steps; each read fills the newest step.
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    // One timer for the whole loop: it is *not* reset between reads.
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        // Race the next read against the overall timeout.
        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}
497

            
498
/// Retire a directory circuit because of an error we've encountered on it.
///
/// Logs the circuit id and the `error` description, then asks `circ_mgr`
/// to stop handing out this circuit.
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
where
    R: Runtime,
{
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &id, &error
    );
    circ_mgr.retire_circ(id);
}
509

            
510
/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
511
///
512
/// Note that this function might not actually read any byte of value
513
/// `byte`, since EOF might occur, or we might fill the buffer.
514
///
515
/// A return value of 0 indicates an end-of-file.
516
400
async fn read_until_limited<S>(
517
400
    stream: &mut S,
518
400
    byte: u8,
519
400
    max: usize,
520
400
    buf: &mut Vec<u8>,
521
400
) -> std::io::Result<usize>
522
400
where
523
400
    S: AsyncBufRead + Unpin,
524
400
{
525
400
    let mut n_added = 0;
526
    loop {
527
558
        let data = stream.fill_buf().await?;
528
526
        if data.is_empty() {
529
            // End-of-file has been reached.
530
12
            return Ok(n_added);
531
514
        }
532
514
        debug_assert!(n_added < max);
533
514
        let remaining_space = max - n_added;
534
514
        let (available, found_byte) = match memchr(byte, data) {
535
342
            Some(idx) => (idx + 1, true),
536
172
            None => (data.len(), false),
537
        };
538
514
        debug_assert!(available >= 1);
539
514
        let n_to_copy = std::cmp::min(remaining_space, available);
540
514
        buf.extend(&data[..n_to_copy]);
541
514
        stream.consume_unpin(n_to_copy);
542
514
        n_added += n_to_copy;
543
514
        if found_byte || n_added == max {
544
356
            return Ok(n_added);
545
158
        }
546
    }
547
400
}
548

            
549
/// Helper: Return a boxed decoder object that wraps the stream  $s.
///
/// Enables `multiple_members` so that concatenated compressed members
/// are decoded as one continuous stream.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}
557

            
558
/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
///
/// `None` or `"identity"` returns the stream unwrapped; `"deflate"` is
/// always supported; the xz/zstd encodings are feature-gated *and*
/// accepted only for `Direct` (non-anonymized) requests.  Any other
/// encoding yields `RequestError::ContentEncoding`.
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}
579

            
580
#[cfg(test)]
581
mod test {
582
    // @@ begin test lint list maintained by maint/add_warning @@
583
    #![allow(clippy::bool_assert_comparison)]
584
    #![allow(clippy::clone_on_copy)]
585
    #![allow(clippy::dbg_macro)]
586
    #![allow(clippy::mixed_attributes_style)]
587
    #![allow(clippy::print_stderr)]
588
    #![allow(clippy::print_stdout)]
589
    #![allow(clippy::single_char_pattern)]
590
    #![allow(clippy::unwrap_used)]
591
    #![allow(clippy::unchecked_time_subtraction)]
592
    #![allow(clippy::useless_vec)]
593
    #![allow(clippy::needless_pass_by_value)]
594
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
595
    use super::*;
596
    use tor_rtmock::io::stream_pair;
597

            
598
    use tor_rtmock::simple_time::SimpleMockTimeProvider;
599
    use web_time_compat::{SystemTime, SystemTimeExt};
600

            
601
    use futures_await_test::async_test;
602

            
603
    /// Exercise `read_until_limited`: whole-line read, limit hit, and EOF.
    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }
630

            
631
    // Basic decompression wrapper.
    //
    // Runs `get_decoder` + `read_and_decompress` over an in-memory byte
    // slice, returning both the status and whatever bytes were produced.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = SimpleMockTimeProvider::from_wallclock(SystemTime::get());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }
652

            
653
    /// Identity (no-op) "decompression", including the truncation case.
    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }
677

            
678
    /// Decompress a known zlib/deflate blob and check the plaintext.
    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }
690

            
691
    /// Decompress a known zstd blob (feature-gated) and check the plaintext.
    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }
702

            
703
    /// Decompress a known xz blob (feature-gated) and check the plaintext.
    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }
715

            
716
    /// An unrecognized Content-Encoding must yield a ContentEncoding error.
    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }
724

            
725
    /// Corrupt deflate input must surface as an IoError from the decoder.
    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }
734

            
735
    /// Parse well-formed headers: a 200 with Content-Encoding, a truncated
    /// header block, and a 404 with no encoding.
    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }
760

            
761
    /// A malformed status line must fail with an httparse error.
    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }
771

            
772
    /// Run a trivial download example with a response provided as a binary
773
    /// string.
774
    ///
775
    /// Return the directory response (if any) and the request as encoded (if
776
    /// any.)
777
    fn run_download_test<Req: request::Requestable>(
778
        req: Req,
779
        response: &[u8],
780
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
781
        let (mut s1, s2) = stream_pair();
782
        let (mut s2_r, mut s2_w) = s2.split();
783

            
784
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
785
            let rt2 = rt.clone();
786
            let (v1, v2, v3): (
787
                Result<DirResponse>,
788
                RequestResult<Vec<u8>>,
789
                RequestResult<()>,
790
            ) = futures::join!(
791
                async {
792
                    // Run the download function.
793
                    let r = send_request(&rt, &req, &mut s1, None).await;
794
                    s1.close().await.map_err(|error| {
795
                        Error::RequestFailed(RequestFailedError {
796
                            source: None,
797
                            error: error.into(),
798
                        })
799
                    })?;
800
                    r
801
                },
802
                async {
803
                    // Take the request from the client, and return it in "v2"
804
                    let mut v = Vec::new();
805
                    s2_r.read_to_end(&mut v).await?;
806
                    Ok(v)
807
                },
808
                async {
809
                    // Send back a response.
810
                    s2_w.write_all(response).await?;
811
                    // We wait a moment to give the other side time to notice it
812
                    // has data.
813
                    //
814
                    // (Tentative diagnosis: The `async-compress` crate seems to
815
                    // be behave differently depending on whether the "close"
816
                    // comes right after the incomplete data or whether it comes
817
                    // after a delay.  If there's a delay, it notices the
818
                    // truncated data and tells us about it. But when there's
819
                    // _no_delay, it treats the data as an error and doesn't
820
                    // tell our code.)
821

            
822
                    // TODO: sleeping in tests is not great.
823
                    rt2.sleep(Duration::from_millis(50)).await;
824
                    s2_w.close().await?;
825
                    Ok(())
826
                }
827
            );
828

            
829
            assert!(v3.is_ok());
830

            
831
            (v1, v2)
832
        })
833
    }
834

            
835
    #[test]
    fn test_send_request() -> RequestResult<()> {
        // Ask for one microdescriptor (a digest of 32 0x09 bytes)...
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        // ...and answer with a complete, uncompressed body.
        let (dir_response, encoded_req) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        // The wire request should name the digest in base64.
        let encoded_req = encoded_req?;
        assert!(encoded_req[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        // The response should be complete and error-free, and both the
        // borrowing and the consuming body accessors should agree.
        let dir_response = dir_response.unwrap();
        assert_eq!(dir_response.status_code(), 200);
        assert!(!dir_response.is_partial());
        assert!(dir_response.error().is_none());
        assert!(dir_response.source().is_none());
        let body_ref = dir_response.output_unchecked();
        assert_eq!(body_ref, b"This is where the descs would go.");
        let body = dir_response.into_output_unchecked();
        assert_eq!(&body, b"This is where the descs would go.");

        Ok(())
    }
861

            
862
    #[test]
    fn test_download_truncated() {
        // With only one microdescriptor requested, "partial ok" is not set,
        // so a truncated body must fail the whole download.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut wire: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        wire.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        wire.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (resp, sent) = run_download_test(req, &wire);
        assert!(sent.is_ok());
        assert!(resp.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // With two microdescriptors requested, "partial_ok" is set, so we
        // should instead get back whatever decoded cleanly plus an error.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (resp, sent) = run_download_test(req, &wire);
        assert!(sent.is_ok());

        let resp = resp.unwrap();
        assert_eq!(resp.status_code(), 200);
        assert!(resp.error().is_some());
        assert!(resp.is_partial());
        assert!(resp.output_unchecked().len() < 37 * 2);
        assert!(resp.output_unchecked().starts_with(b"One fish"));
    }
892

            
893
    #[test]
    fn test_404() {
        // A non-200 status should be reported back to the caller verbatim.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let teapot = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (resp, _encoded_req) = run_download_test(req, teapot);

        assert_eq!(resp.unwrap().status_code(), 418);
    }
901

            
902
    #[test]
    fn test_headers_truncated() {
        // Cut the response off in the middle of the header block...
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let wire = b"HTTP/1.0 404 truncation happens here\r\n";
        let (resp, _encoded_req) = run_download_test(req, wire);

        assert!(matches!(
            resp,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // ...and also try sending nothing at all; both should surface as
        // TruncatedHeaders.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let wire = b"";
        let (resp, _encoded_req) = run_download_test(req, wire);

        assert!(matches!(
            resp,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }
929

            
930
    #[test]
    fn test_headers_too_long() {
        // Pad a header with enough filler to blow past the header-size
        // limit: the request should fail, and the failure should also mark
        // the circuit for retirement.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut wire: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        wire.resize(16384, b'A');
        let (resp, _encoded_req) = run_download_test(req, &wire);

        assert!(resp.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            resp,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }
946

            
947
    // TODO: test with bad utf-8
948
}