1
#![cfg_attr(docsrs, feature(doc_cfg))]
2
#![doc = include_str!("../README.md")]
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48

            
49
// TODO probably remove this at some point - see tpo/core/arti#1060
50
#![cfg_attr(
51
    not(all(feature = "full", feature = "experimental")),
52
    allow(unused_imports)
53
)]
54

            
55
mod err;
56
pub mod request;
57
mod response;
58
mod util;
59

            
60
use tor_circmgr::{CircMgr, DirInfo};
61
use tor_error::bad_api_usage;
62
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
63

            
64
// Zlib is required; the others are optional.
65
#[cfg(feature = "xz")]
66
use async_compression::futures::bufread::XzDecoder;
67
use async_compression::futures::bufread::ZlibDecoder;
68
#[cfg(feature = "zstd")]
69
use async_compression::futures::bufread::ZstdDecoder;
70

            
71
use futures::FutureExt;
72
use futures::io::{
73
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
74
};
75
use memchr::memchr;
76
use std::sync::Arc;
77
use std::time::Duration;
78
use tracing::{info, instrument};
79

            
80
pub use err::{Error, RequestError, RequestFailedError};
81
pub use response::{DirResponse, SourceInfo};
82

            
83
/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a [`RequestError`].
pub type RequestResult<T> = std::result::Result<T, RequestError>;
88

            
89
/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
///
/// Among other things, this flag controls which optional content encodings we
/// are willing to accept from a directory (see `get_decoder`): admitting to
/// support for optional compression algorithms could leak information about
/// our capabilities.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}
102

            
103
/// Fetch the resource described by `req` over the Tor network.
104
///
105
/// Circuits are built or found using `circ_mgr`, using paths
106
/// constructed using `dirinfo`.
107
///
108
/// For more fine-grained control over the circuit and stream used,
109
/// construct them yourself, and then call [`send_request`] instead.
110
///
111
/// # TODO
112
///
113
/// This is the only function in this crate that knows about CircMgr and
114
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
115
#[instrument(level = "trace", skip_all)]
116
pub async fn get_resource<CR, R, SP>(
117
    req: &CR,
118
    dirinfo: DirInfo<'_>,
119
    runtime: &SP,
120
    circ_mgr: Arc<CircMgr<R>>,
121
) -> Result<DirResponse>
122
where
123
    CR: request::Requestable + ?Sized,
124
    R: Runtime,
125
    SP: SleepProvider,
126
{
127
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;
128

            
129
    if req.anonymized() == AnonymizedRequest::Anonymized {
130
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
131
    }
132

            
133
    // TODO(nickm) This should be an option, and is too long.
134
    let begin_timeout = Duration::from_secs(5);
135
    let source = match SourceInfo::from_tunnel(&tunnel) {
136
        Ok(source) => source,
137
        Err(e) => {
138
            return Err(Error::RequestFailed(RequestFailedError {
139
                source: None,
140
                error: e.into(),
141
            }));
142
        }
143
    };
144

            
145
    let wrap_err = |error| {
146
        Error::RequestFailed(RequestFailedError {
147
            source: source.clone(),
148
            error,
149
        })
150
    };
151

            
152
    req.check_circuit(&tunnel).await.map_err(wrap_err)?;
153

            
154
    // Launch the stream.
155
    let mut stream = runtime
156
        .timeout(begin_timeout, tunnel.begin_dir_stream())
157
        .await
158
        .map_err(RequestError::from)
159
        .map_err(wrap_err)?
160
        .map_err(RequestError::from)
161
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
162

            
163
    // TODO: Perhaps we want separate timeouts for each phase of this.
164
    // For now, we just use higher-level timeouts in `dirmgr`.
165
    let r = send_request(runtime, req, &mut stream, source.clone()).await;
166

            
167
    if should_retire_circ(&r) {
168
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
169
    }
170

            
171
    r
172
}
173

            
174
/// Return true if `result` holds an error indicating that we should retire the
175
/// circuit used for the corresponding request.
176
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
177
    match result {
178
        Err(e) => e.should_retire_circ(),
179
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
180
    }
181
}
182

            
183
/// Fetch a Tor directory object from a provided stream.
184
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
185
pub async fn download<R, S, SP>(
186
    runtime: &SP,
187
    req: &R,
188
    stream: &mut S,
189
    source: Option<SourceInfo>,
190
) -> Result<DirResponse>
191
where
192
    R: request::Requestable + ?Sized,
193
    S: AsyncRead + AsyncWrite + Send + Unpin,
194
    SP: SleepProvider,
195
{
196
    send_request(runtime, req, stream, source).await
197
}
198

            
199
/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in_dirmgr::bridgedesc.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    // Helper: tag any failure with the `source` it came from.
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    // Capture the request's parameters before `req` is shadowed below by
    // the HTTP request object built from it.
    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let method = req.method().clone();
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    // A non-200 status is reported as an `Ok` DirResponse with an empty
    // body, not as an `Err` at this level.
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            method,
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    // Wrap the stream so the body is transparently decompressed according
    // to the reported Content-Encoding.  Which encodings we accept depends
    // on whether the request is anonymized.
    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    // Decide whether a body-download error is fatal, or whether we should
    // keep whatever partial body we managed to receive.
    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    // 200 response; `ok.err()` records any non-fatal partial-body error.
    Ok(DirResponse::new(
        method,
        200,
        None,
        ok.err(),
        result,
        source,
    ))
}
296

            
297
/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.  If a response's headers grow to this
/// size without parsing as complete, `read_headers` gives up with
/// [`RequestError::HeadersTooLong`].
const MAX_HEADERS_LEN: usize = 16384;
301

            
302
/// Read and parse HTTP/1 headers from `stream`.
///
/// Reads one newline-terminated chunk at a time (at most 2048 bytes per
/// read) into a growing buffer, re-parsing the whole buffer after each
/// read until `httparse` reports a complete header block.
///
/// # Errors
///
/// Returns [`RequestError::TruncatedHeaders`] on EOF before the headers
/// are complete, and [`RequestError::HeadersTooLong`] if the accumulated
/// headers reach [`MAX_HEADERS_LEN`].
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                // For non-200 responses we skip extracting Content-Encoding:
                // the caller won't read a body for those.
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                // We read line by line and stop at the first complete parse,
                // so the parser should have consumed exactly what we buffered.
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        // NOTE(review): this check looks unreachable — the Partial arm above
        // already returned when n == 0, and the Complete arm always returns.
        // Kept as belt-and-braces; confirm before removing.
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}
368

            
369
/// Return value from [`read_headers`].
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code, as reported by the directory (if any).
    status: Option<u16>,
    /// HTTP status message associated with the status code, if any.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.  (Only populated for 200
    /// responses; see `read_headers`.)
    encoding: Option<String>,
}
379

            
380
/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `buf` is empty.
///
/// If we get more than maxlen bytes after decompression, give an error.
///
/// Returns the status of our download attempt, stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
///
/// The whole download shares a single deadline: if it has not finished
/// within `read_timeout` (currently 10 seconds), we give up with
/// [`RequestError::DirTimeout`], keeping whatever bytes we already
/// received in `result`.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    // We read in fixed-size windows directly into `result`, growing it ahead
    // of each read and shrinking it back down before every return.
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    // The timer is created once, outside the loop, so it bounds the *entire*
    // download rather than each individual read.
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        // Race the next read against the overall deadline.
        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        // A zero-byte read means the (decompressed) stream has ended.
        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}
453

            
454
/// Retire a directory circuit because of an error we've encountered on it.
455
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
456
where
457
    R: Runtime,
458
{
459
    info!(
460
        "{}: Retiring circuit because of directory failure: {}",
461
        &id, &error
462
    );
463
    circ_mgr.retire_circ(id);
464
}
465

            
466
/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
467
///
468
/// Note that this function might not actually read any byte of value
469
/// `byte`, since EOF might occur, or we might fill the buffer.
470
///
471
/// A return value of 0 indicates an end-of-file.
472
454
async fn read_until_limited<S>(
473
454
    stream: &mut S,
474
454
    byte: u8,
475
454
    max: usize,
476
454
    buf: &mut Vec<u8>,
477
454
) -> std::io::Result<usize>
478
454
where
479
454
    S: AsyncBufRead + Unpin,
480
454
{
481
454
    let mut n_added = 0;
482
    loop {
483
612
        let data = stream.fill_buf().await?;
484
564
        if data.is_empty() {
485
            // End-of-file has been reached.
486
12
            return Ok(n_added);
487
552
        }
488
552
        debug_assert!(n_added < max);
489
552
        let remaining_space = max - n_added;
490
552
        let (available, found_byte) = match memchr(byte, data) {
491
380
            Some(idx) => (idx + 1, true),
492
172
            None => (data.len(), false),
493
        };
494
552
        debug_assert!(available >= 1);
495
552
        let n_to_copy = std::cmp::min(remaining_space, available);
496
552
        buf.extend(&data[..n_to_copy]);
497
552
        stream.consume_unpin(n_to_copy);
498
552
        n_added += n_to_copy;
499
552
        if found_byte || n_added == max {
500
394
            return Ok(n_added);
501
158
        }
502
    }
503
454
}
504

            
505
/// Helper: Return a boxed decoder object that wraps the stream `$s`.
///
/// The decoder is configured with `multiple_members(true)`, so that it keeps
/// decoding if the input contains several concatenated compressed members
/// rather than stopping after the first one.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}
513

            
514
/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
///
/// `None` or `"identity"` means no decoding at all; `"deflate"` (zlib) is
/// always accepted.  The optional `"x-tor-lzma"` and `"x-zstd"` encodings
/// are accepted only for [`AnonymizedRequest::Direct`] requests, and only
/// when the corresponding cargo feature is enabled.  Anything else yields
/// [`RequestError::ContentEncoding`].
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        // Unknown encoding (or an optional one on an anonymized request).
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}
535

            
536
#[cfg(test)]
537
mod test {
538
    // @@ begin test lint list maintained by maint/add_warning @@
539
    #![allow(clippy::bool_assert_comparison)]
540
    #![allow(clippy::clone_on_copy)]
541
    #![allow(clippy::dbg_macro)]
542
    #![allow(clippy::mixed_attributes_style)]
543
    #![allow(clippy::print_stderr)]
544
    #![allow(clippy::print_stdout)]
545
    #![allow(clippy::single_char_pattern)]
546
    #![allow(clippy::unwrap_used)]
547
    #![allow(clippy::unchecked_time_subtraction)]
548
    #![allow(clippy::useless_vec)]
549
    #![allow(clippy::needless_pass_by_value)]
550
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
551
    use super::*;
552
    use tor_rtmock::io::stream_pair;
553

            
554
    use tor_rtmock::simple_time::SimpleMockTimeProvider;
555

            
556
    use futures_await_test::async_test;
557

            
558
    #[async_test]
559
    async fn test_read_until_limited() -> RequestResult<()> {
560
        let mut out = Vec::new();
561
        let bytes = b"This line eventually ends\nthen comes another\n";
562

            
563
        // Case 1: find a whole line.
564
        let mut s = &bytes[..];
565
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
566
        assert_eq!(res?, 26);
567
        assert_eq!(&out[..], b"This line eventually ends\n");
568

            
569
        // Case 2: reach the limit.
570
        let mut s = &bytes[..];
571
        out.clear();
572
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
573
        assert_eq!(res?, 10);
574
        assert_eq!(&out[..], b"This line ");
575

            
576
        // Case 3: reach EOF.
577
        let mut s = &bytes[..];
578
        out.clear();
579
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
580
        assert_eq!(res?, 45);
581
        assert_eq!(&out[..], &bytes[..]);
582

            
583
        Ok(())
584
    }
585

            
586
    // Basic decompression wrapper.
587
    async fn decomp_basic(
588
        encoding: Option<&str>,
589
        data: &[u8],
590
        maxlen: usize,
591
    ) -> (RequestResult<()>, Vec<u8>) {
592
        // We don't need to do anything fancy here, since we aren't simulating
593
        // a timeout.
594
        #[allow(deprecated)] // TODO #1885
595
        let mock_time = SimpleMockTimeProvider::from_wallclock(std::time::SystemTime::now());
596

            
597
        let mut output = Vec::new();
598
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
599
            Ok(s) => s,
600
            Err(e) => return (Err(e), output),
601
        };
602

            
603
        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
604

            
605
        (r, output)
606
    }
607

            
608
    #[async_test]
609
    async fn decompress_identity() -> RequestResult<()> {
610
        let mut text = Vec::new();
611
        for _ in 0..1000 {
612
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
613
        }
614

            
615
        let limit = 10 << 20;
616
        let (s, r) = decomp_basic(None, &text[..], limit).await;
617
        s?;
618
        assert_eq!(r, text);
619

            
620
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
621
        s?;
622
        assert_eq!(r, text);
623

            
624
        // Try truncated result
625
        let limit = 100;
626
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
627
        assert!(s.is_err());
628
        assert_eq!(r, &text[..100]);
629

            
630
        Ok(())
631
    }
632

            
633
    #[async_test]
634
    async fn decomp_zlib() -> RequestResult<()> {
635
        let compressed =
636
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
637

            
638
        let limit = 10 << 20;
639
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
640
        s?;
641
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");
642

            
643
        Ok(())
644
    }
645

            
646
    #[cfg(feature = "zstd")]
647
    #[async_test]
648
    async fn decomp_zstd() -> RequestResult<()> {
649
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
650
        let limit = 10 << 20;
651
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
652
        s?;
653
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
654

            
655
        Ok(())
656
    }
657

            
658
    #[cfg(feature = "xz")]
659
    #[async_test]
660
    async fn decomp_xz2() -> RequestResult<()> {
661
        // Not so good at tiny files...
662
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
663
        let limit = 10 << 20;
664
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
665
        s?;
666
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
667

            
668
        Ok(())
669
    }
670

            
671
    #[async_test]
672
    async fn decomp_unknown() {
673
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
674
        let limit = 10 << 20;
675
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
676

            
677
        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
678
    }
679

            
680
    #[async_test]
681
    async fn decomp_bad_data() {
682
        let compressed = b"This is not good zlib data";
683
        let limit = 10 << 20;
684
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
685

            
686
        // This should possibly be a different type in the future.
687
        assert!(matches!(s, Err(RequestError::IoError(_))));
688
    }
689

            
690
    #[async_test]
691
    async fn headers_ok() -> RequestResult<()> {
692
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
693

            
694
        let mut s = &text[..];
695
        let h = read_headers(&mut s).await?;
696

            
697
        assert_eq!(h.status, Some(200));
698
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));
699

            
700
        // now try truncated
701
        let mut s = &text[..15];
702
        let h = read_headers(&mut s).await;
703
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
704

            
705
        // now try with no encoding.
706
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
707
        let mut s = &text[..];
708
        let h = read_headers(&mut s).await?;
709

            
710
        assert_eq!(h.status, Some(404));
711
        assert!(h.encoding.is_none());
712

            
713
        Ok(())
714
    }
715

            
716
    #[async_test]
717
    async fn headers_bogus() -> Result<()> {
718
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
719
        let mut s = &text[..];
720
        let h = read_headers(&mut s).await;
721

            
722
        assert!(h.is_err());
723
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
724
        Ok(())
725
    }
726

            
727
    /// Run a trivial download example with a response provided as a binary
728
    /// string.
729
    ///
730
    /// Return the directory response (if any) and the request as encoded (if
731
    /// any.)
732
    fn run_download_test<Req: request::Requestable>(
733
        req: Req,
734
        response: &[u8],
735
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
736
        let (mut s1, s2) = stream_pair();
737
        let (mut s2_r, mut s2_w) = s2.split();
738

            
739
        tor_rtcompat::test_with_one_runtime!(|rt| async move {
740
            let rt2 = rt.clone();
741
            let (v1, v2, v3): (
742
                Result<DirResponse>,
743
                RequestResult<Vec<u8>>,
744
                RequestResult<()>,
745
            ) = futures::join!(
746
                async {
747
                    // Run the download function.
748
                    let r = send_request(&rt, &req, &mut s1, None).await;
749
                    s1.close().await.map_err(|error| {
750
                        Error::RequestFailed(RequestFailedError {
751
                            source: None,
752
                            error: error.into(),
753
                        })
754
                    })?;
755
                    r
756
                },
757
                async {
758
                    // Take the request from the client, and return it in "v2"
759
                    let mut v = Vec::new();
760
                    s2_r.read_to_end(&mut v).await?;
761
                    Ok(v)
762
                },
763
                async {
764
                    // Send back a response.
765
                    s2_w.write_all(response).await?;
766
                    // We wait a moment to give the other side time to notice it
767
                    // has data.
768
                    //
769
                    // (Tentative diagnosis: The `async-compress` crate seems to
770
                    // be behave differently depending on whether the "close"
771
                    // comes right after the incomplete data or whether it comes
772
                    // after a delay.  If there's a delay, it notices the
773
                    // truncated data and tells us about it. But when there's
774
                    // _no_delay, it treats the data as an error and doesn't
775
                    // tell our code.)
776

            
777
                    // TODO: sleeping in tests is not great.
778
                    rt2.sleep(Duration::from_millis(50)).await;
779
                    s2_w.close().await?;
780
                    Ok(())
781
                }
782
            );
783

            
784
            assert!(v3.is_ok());
785

            
786
            (v1, v2)
787
        })
788
    }
789

            
790
    #[test]
    fn test_send_request() -> RequestResult<()> {
        // Ask for a single microdescriptor, with an all-9s digest.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let server_reply = b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.";
        let (response, request) = run_download_test(req, server_reply);

        // The client should have sent a GET for the base64 form of that digest.
        let sent = request?;
        let expected_get =
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n";
        assert!(sent[..].starts_with(expected_get));

        // The parsed response should be a complete, error-free 200.
        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());

        // Both the borrowed and the owned accessors should yield the body.
        let body = b"This is where the descs would go.";
        assert_eq!(response.output_unchecked(), body);
        assert_eq!(&response.into_output_unchecked(), body);

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Deflate-compressed "One fish two fish" twice, with the second copy
        // cut short partway through.
        let headers = b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n";
        let complete =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
        let truncated =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap();
        let response_text: Vec<u8> = headers
            .iter()
            .copied()
            .chain(complete)
            .chain(truncated)
            .collect();

        // Requesting only one md means "partial ok" is not set, so the
        // truncated stream makes the whole download fail.
        let one_md: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let (response, request) = run_download_test(one_md, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err());

        // Requesting two mds sets "partial_ok", so we get a partial response
        // back instead of an error.
        let two_mds: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();
        let (response, request) = run_download_test(two_mds, &response_text);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        let body = response.output_unchecked();
        assert!(body.len() < 37 * 2);
        assert!(body.starts_with(b"One fish"));
    }

    #[test]
    fn test_404() {
        // A non-200 status (here, 418) should be reported as-is, not as an error.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let (response, _request) = run_download_test(req, b"HTTP/1.0 418 I'm a teapot\r\n\r\n");

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        // Any response whose headers end before the terminating blank line —
        // whether cut off mid-stream or empty altogether — should produce a
        // `TruncatedHeaders` error.
        //
        // (The assertion was previously duplicated verbatim for both cases;
        // it is factored into one closure here.)
        let assert_truncated_headers = |response| {
            assert!(matches!(
                response,
                Err(Error::RequestFailed(RequestFailedError {
                    error: RequestError::TruncatedHeaders,
                    ..
                }))
            ));
        };

        // Headers cut off before the blank line that ends them.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let (response, _request) =
            run_download_test(req, b"HTTP/1.0 404 truncation happens here\r\n");
        assert_truncated_headers(response);

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let (response, _request) = run_download_test(req, b"");
        assert_truncated_headers(response);
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        // A 16384-byte response that is all headers: a real status line,
        // then one header padded out with 'A's to fill the rest.
        let prefix = b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ";
        let mut response_text = prefix.to_vec();
        response_text.extend(std::iter::repeat(b'A').take(16384 - prefix.len()));
        let (response, _request) = run_download_test(req, &response_text);

        // Over-long headers should both fail the request and mark the
        // circuit for retirement.
        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}