1
//! Module for helping with dirserver's HTTP interface.
2
//!
3
//! This module is unfortunately necessary as a middleware due to some obscure
4
//! things in Tor, most notably the ".z" extensions.
5

            
6
use cache::StoreCache;
7
use r2d2::Pool;
8
use r2d2_sqlite::SqliteConnectionManager;
9
use tor_error::internal;
10

            
11
use std::{
12
    collections::VecDeque,
13
    convert::Infallible,
14
    panic::{catch_unwind, AssertUnwindSafe},
15
    str::FromStr,
16
    sync::Arc,
17
    task::{Context, Poll},
18
    time::Duration,
19
};
20

            
21
use bytes::Bytes;
22
use futures::{Stream, StreamExt};
23
use http::{header, Method, Request, Response, StatusCode};
24
use http_body::{Body, Frame};
25
use hyper::{
26
    body::Incoming,
27
    server::conn::http1::{self},
28
    service::service_fn,
29
};
30
use hyper_util::rt::TokioIo;
31
use rusqlite::{params, Transaction};
32
use tokio::{
33
    io::{AsyncRead, AsyncWrite},
34
    task::JoinSet,
35
    time,
36
};
37
use tracing::warn;
38

            
39
use crate::database::{self, sql, ContentEncoding, DocumentId};
40

            
41
mod cache;
42

            
43
/// A type alias for the functions implementing endpoint logic.
///
/// An endpoint function is a function of the following form:
/// ```rust,ignore
/// fn get_consensus(
///     tx: &Transaction<'_>,
///     requ: &Request<Incoming>
/// ) -> Result<Response<Vec<DocumentId>>, Box<dyn std::error::Error + Send>>;
/// ```
///
/// The arguments give the endpoint function access to a fixed state of the
/// database ([`Transaction`]) and the incoming [`Request`].  The return type is
/// a [`Result`] with an arbitrary error that implements [`Send`] and gets logged
/// but not returned to the client, which will just receive an `Internal Server Error`.
/// The [`Ok`] type of the [`Result`] is a [`Response`] whose body is a [`Vec`]
/// of [`DocumentId`] hashsums identifying (uncompressed) objects in the
/// `store` table; the caller later resolves those ids to the actual bytes.
///
/// Changes to the database within the [`Transaction`] will (for now) get rolled
/// back, thereby giving the endpoint functions just read-only access to the
/// database.
///
/// TODO DIRMIRROR: Document the responsibilities here.
///
/// TODO DIRMIRROR: The error handling of endpoint functions may need further
/// discussions.  Maybe take a look at what other frameworks do?
type EndpointFn = fn(
    &Transaction,
    &Request<Incoming>,
) -> Result<Response<Vec<DocumentId>>, Box<dyn std::error::Error + Send>>;
72

            
73
/// A type that implements [`Body`] for a list of [`Arc<[u8]>`] data.
///
/// Each queued [`Arc<[u8]>`] chunk is emitted as one data frame, front to
/// back.  Reference-counted slices are used as the first-level return type in
/// order to avoid duplicate entries of the same data in memory.
/// See the documentation of [`StoreCache`] for more information on that.
struct DocumentBody(VecDeque<Arc<[u8]>>);
79

            
80
/// Representation of an endpoint, uniquely identified by a [`Method`] and path
/// pair followed by an appropriate [`EndpointFn`].
///
/// The path itself is a special string that refers to the endpoint at which this
/// resource should be available.  It supports a pattern-matching like syntax
/// through the use of the asterisk `*` character.
///
/// For example:
/// `/tor/status-vote/current/consensus` will match the URL exactly, whereas
/// `/tor/status-vote/current/*` will match every string that is in the
/// fourth component; such as `/tor/status-vote/current/consensus` or
/// `/tor/status-vote/current/consensus-microdesc`; it will however not
/// match in a prefix-like syntax, such as
/// `/tor/status-vote/current/consensus-microdesc/diff`.
///
/// In the case of non-unique matches, the first match wins.  Also, because
/// of wildcards, matching takes place in a `O(n)` fashion, so be sure to
/// keep the `n` at a reasonable size.  This should not be much of a
/// problem for Tor applications though, because the list of endpoints is
/// reasonably small (less than 30).
///
/// TODO: The entire asterisk matching is not so super nice, primarily because
/// it removes compile-time semantic checks; however, I cannot really think
/// of a much cleaner way that would not involve lots of boilerplate.
/// The most minimal "clean" way could be to do `path: &Option<&'static str>`
/// but I am not sure if this overhead is worth it, i.e.:
/// * `/tor/status-vote/current/*/diff/*/*`
/// * `[Some(""), Some("tor"), Some("status-vote"), Some("current"), None, ...]`
///   Maybe a macro could help here though ...
type Endpoint = (Method, &'static str, EndpointFn);
110

            
111
/// Representation of the core HTTP server.
#[derive(Debug)]
pub(crate) struct HttpServer {
    /// List of [`Endpoint`] entries, consulted in first-match-wins order.
    endpoints: Vec<Endpoint>,
    /// Access to the SQLite database pool.
    pool: Pool<SqliteConnectionManager>,
}
119

            
120
impl Body for DocumentBody {
121
    type Data = Bytes;
122
    type Error = Infallible;
123

            
124
8
    fn poll_frame(
125
8
        mut self: std::pin::Pin<&mut Self>,
126
8
        _cx: &mut Context<'_>,
127
8
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
128
        Poll::Ready(
129
8
            self.0
130
8
                .pop_front()
131
10
                .map(|bytes| Ok(Frame::data(Bytes::from_owner(bytes)))),
132
        )
133
8
    }
134
}
135

            
136
impl HttpServer {
137
    /// Creates a new [`HttpServer`] with a given [`Vec`] of [`Endpoint`] entries
138
    /// alongside access to the database [`Pool`].
139
2
    pub(crate) fn new(endpoints: Vec<Endpoint>, pool: Pool<SqliteConnectionManager>) -> Self {
140
2
        Self { endpoints, pool }
141
2
    }
142

            
143
    /// Runs the server endlessly in the current task.
    ///
    /// Almost all errors occur in the sub-tasks spawned by this function and
    /// are handled there, usually by logging the error and continuing the
    /// execution.  The only error this function itself returns is an internal
    /// [`tor_error::Bug`] for the "impossible" case that the listener stream
    /// terminates while we still own it.
    #[allow(clippy::cognitive_complexity)]
    pub(crate) async fn serve<I, S, E>(self, mut listener: I) -> Result<(), tor_error::Bug>
    where
        I: Stream<Item = Result<S, E>> + Unpin,
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        E: std::error::Error,
    {
        let cache = Arc::new(StoreCache::new());
        let endpoints: Arc<[Endpoint]> = self.endpoints.into();
        let pool = self.pool;

        // We operate exclusively in JoinSets so that everything gets aborted
        // nicely in order without causing any sort of leaks.
        let mut hyper_tasks: JoinSet<Result<(), hyper::Error>> = JoinSet::new();
        let mut misc_tasks: JoinSet<()> = JoinSet::new();

        // Spawn a simple garbage collection task that periodically (once per
        // minute) removes dead references, just in case, from the StoreCache.
        misc_tasks.spawn({
            let cache = cache.clone();
            async move {
                loop {
                    cache.gc();
                    time::sleep(Duration::from_secs(60)).await;
                }
            }
        });

        loop {
            tokio::select! {
                res = listener.next() => match res {
                    // Connection successfully accepted: hand it over to hyper
                    // and track the connection task in `hyper_tasks`.
                    Some(Ok(s)) => Self::dispatch_stream(&cache, &endpoints, &pool, &mut hyper_tasks, s),

                    // There has been an error in accepting the connection.
                    Some(Err(e)) => {
                        warn!("listener accept failure: {e}");
                        continue;
                    }

                    // This should not happen due to ownership.
                    None => return Err(internal!("listener was closed externally?")),
                },

                // A hyper task we monitored in our tasks has exited.
                //
                // We distinguish between graceful and ungraceful errors, with
                // the latter one being errors related to a failure in tokio's
                // joining itself, such as if the underlying task panic'ed;
                // whereas graceful errors are logical application level errors.
                Some(res) = hyper_tasks.join_next() => match res {
                    Ok(Ok(())) => {},
                    Ok(Err(e)) => warn!("client task encountered an error: {e}"),
                    Err(e) => warn!("client task exited ungracefully: {e}"),
                },

            }
        }
    }
207

            
208
    /// Dispatches a new [`Stream`] into an existing [`JoinSet`].
209
2
    fn dispatch_stream<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
210
2
        cache: &Arc<StoreCache>,
211
2
        endpoints: &Arc<[Endpoint]>,
212
2
        pool: &Pool<SqliteConnectionManager>,
213
2
        tasks: &mut JoinSet<Result<(), hyper::Error>>,
214
2
        stream: S,
215
2
    ) {
216
2
        let stream = TokioIo::new(stream);
217

            
218
        // Create the `service_fn` to pass to `hyper`.
219
        //
220
        // Unfortunately, we have to clone the reference counter of all shared
221
        // objects two times here.  The first clone is required to not move
222
        // it into the `service_fn`, the second one is required to
223
        // circumvent a hyper limitation, namely that a service function
224
        // requires a `Fn`, not an `FnMut`, which would allow capturing values
225
        // from the environment natively.
226
2
        let cache = cache.clone();
227
2
        let endpoints = endpoints.clone();
228
2
        let pool = pool.clone();
229
4
        let service = service_fn(move |requ| {
230
4
            let cache = cache.clone();
231
4
            let endpoints = endpoints.clone();
232
4
            let pool = pool.clone();
233
4
            async move { Self::handler(cache, endpoints, pool, requ).await }
234
4
        });
235

            
236
2
        tasks.spawn(http1::Builder::new().serve_connection(stream, service));
237
2
    }
238

            
239
    /// A small wrapper function that creates a read-only or read-write
240
    /// [`Transaction`] based upon the [`Method`] and continues execution in
241
    /// [`Self::handler_tx`].
242
    #[allow(clippy::unused_async)] // TODO
243
4
    async fn handler(
244
4
        cache: Arc<StoreCache>,
245
4
        endpoints: Arc<[Endpoint]>,
246
4
        pool: Pool<SqliteConnectionManager>,
247
4
        requ: Request<Incoming>,
248
6
    ) -> Result<Response<DocumentBody>, Infallible> {
249
        // TODO: This would be the place to either use read_tx or rw_tx depending
250
        // on the method, but given that this is all GET at the moment, just go
251
        // with read_tx.
252
        Ok(
253
4
            database::read_tx(&pool, |tx| Self::handler_tx(&cache, &endpoints, tx, &requ))
254
4
                .unwrap_or_else(|e| {
255
                    warn!("database error: {e}");
256
                    Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR)
257
                }),
258
        )
259
4
    }
260

            
261
    /// A big monolithic function that handles an incoming request with a
    /// consistent view upon the database.
    ///
    /// The function works in six steps which are documented with more detail
    /// within the code:
    /// 1. Determine the compression algorithm
    /// 2. Select an [`EndpointFn`] by matching the path component
    /// 3. Call the [`EndpointFn`] to obtain various [`DocumentId`]s
    /// 4. Map the [`DocumentId`]s to their compressed counterpart
    /// 5. Query the [`StoreCache`] with the [`DocumentId`] and [`Transaction`]
    ///    handle to store the document ref
    /// 6. Compose the [`Response`]
    ///
    /// Every failure along the way is logged together with the request method
    /// and URI and answered with an empty `500 Internal Server Error`
    /// (or `404 Not Found` when no endpoint matches).
    ///
    /// TODO DIRMIRROR: Implement [`Method::HEAD`].
    #[allow(clippy::cognitive_complexity)]
    fn handler_tx(
        cache: &Arc<StoreCache>,
        endpoints: &[Endpoint],
        tx: &Transaction,
        requ: &Request<Incoming>,
    ) -> Response<DocumentBody> {
        // (1) Determine the compression algorithm
        //
        // This step determines the compression algorithm, according to:
        // https://spec.torproject.org/dir-spec/standards-compliance.html#http-headers.
        let (encoding, advertise_encoding) = Self::determine_encoding(requ);

        // (2) Select an `EndpointFn` by matching the path component
        let endpoint_fn = match Self::match_endpoint(endpoints, requ) {
            Some((_, _, endpoint_fn)) => endpoint_fn,
            None => return Self::empty_response(StatusCode::NOT_FOUND),
        };

        // (3) Call the `EndpointFn` to obtain various `DocumentId`s
        //
        // The call is wrapped in catch_unwind so a panicking endpoint
        // function takes down only this request, not the whole server.
        let endpoint_fn_resp = match catch_unwind(AssertUnwindSafe(|| endpoint_fn(tx, requ))) {
            // Everything went successful.
            Ok(Ok(r)) => r,

            // The endpoint function gracefully failed with an error.
            Ok(Err(e)) => {
                warn!(
                    "{} {}: endpoint function failed: {e}",
                    requ.method(),
                    requ.uri()
                );
                return Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR);
            }

            // The endpoint function unexpectedly crashed (panicked).
            Err(_) => {
                warn!(
                    "{} {}: endpoint function crashed",
                    requ.method(),
                    requ.uri()
                );
                return Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR);
            }
        };
        let (endpoint_fn_parts, docids) = endpoint_fn_resp.into_parts();

        // (4) Map the docids to their compressed counterpart
        //
        // Any single failed lookup aborts the whole request.
        let docids = docids
            .into_iter()
            .map(|docid| Self::map_encoding(tx, docid, encoding))
            .collect::<Result<Vec<_>, _>>();
        let docids = match docids {
            Ok(s) => s,
            Err(e) => {
                warn!(
                    "{} {}: unable to find compressed document: {e}",
                    requ.method(),
                    requ.uri()
                );
                return Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR);
            }
        };

        // (5) Query the [`StoreCache`] with the [`DocumentId`] and
        //     [`Transaction`] handle to store the document ref.
        let mut documents = VecDeque::new();
        for docid in docids {
            let document = match cache.get(tx, docid) {
                Ok(document) => document,
                Err(e) => {
                    warn!(
                        "{} {}: unable to access the cache: {e}",
                        requ.method(),
                        requ.uri()
                    );
                    return Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR);
                }
            };

            documents.push_back(document);
        }

        // (6) Compose the `Response`.
        //
        // The composing primarily consists of building a response from the parts
        // of the intermediate response plus optionally adding a Content-Encoding
        // header.
        let mut resp = Response::from_parts(endpoint_fn_parts, DocumentBody(documents));
        if advertise_encoding {
            // Add the Content-Encoding header, if necessary.
            resp.headers_mut().insert(
                header::CONTENT_ENCODING,
                encoding
                    .to_string()
                    .try_into()
                    .expect("strum serialized a non-valid header?!?"),
            );
        }

        resp
    }
376

            
377
    /// Determines the [`ContentEncoding`] based on the path and the value of [`header::ACCEPT_ENCODING`].
378
    ///
379
    /// This function returns a tuple containing the determined [`ContentEncoding`]
380
    /// alongside a boolean that indicates whether [`header::CONTENT_ENCODING`]
381
    /// should be set or not with the value of the just determined
382
    /// [`ContentEncoding`].
383
16
    fn determine_encoding<B: Body>(requ: &Request<B>) -> (ContentEncoding, bool) {
384
16
        let z_suffix = requ.uri().path().ends_with(".z");
385

            
386
        // TODO: Refactor this in a flat fashion once we get stable If-Let-Chains
387
        // by upgrading MSVC to 1.88.
388
        //
389
        // This works by branching the parameters into the following four branches:
390
        // 1. Accept-Encoding && ".z" URL
391
        // 2. Accept-Encoding && No ".z" URL
392
        // 3. No Accept-Encoding && ".z" URL
393
        // 4. No Accept-Encoding && No "z" URL
394

            
395
        // Technically we could use an else-if here, but given the branching
396
        // I explained above, I would like to keep it in the nested fashion
397
        // once we got stable If-Let.
398
        #[allow(clippy::collapsible_else_if)]
399
16
        if let Some(accept_encoding) = requ.headers().get(header::ACCEPT_ENCODING) {
400
            // Parse the accept_encoding value by splitting it at "," and then
401
            // parse each trimmed component as a ContentEncoding.  Unsupported
402
            // ContentEncodings are ignored.
403
8
            let encodings = accept_encoding
404
8
                .to_str()
405
8
                .unwrap_or("")
406
8
                .split(",")
407
14
                .filter_map(|encoding| ContentEncoding::from_str(encoding.trim()).ok())
408
8
                .collect::<Vec<_>>();
409

            
410
8
            if z_suffix {
411
                // (1) Accept-Encoding && ".z" URL
412
                //
413
                // From the specification:
414
                // > If the client does send an Accept-Encoding header along with
415
                // > a .z URL, the server SHOULD treat the request the same way
416
                // > as for the URL without the .z.  If deflate is included in the
417
                // > Accept-Encoding, the response MUST be encoded, once, with
418
                // > an encoding advertised by the client, and be accompanied by
419
                // > an appropriate Content-Encoding.
420

            
421
                // We do not check whether Accept-Encoding contains deflate,
422
                // because the specification gives us the assurance.
423
                // TODO: Maybe we should?
424
2
                (ContentEncoding::Deflate, true)
425
            } else {
426
                // (2) Accept-Encoding && No ".z" URL
427
6
                if let Some(encoding) = encodings.first() {
428
                    // Pick the first found encoding and include it in the header,
429
                    // if it is not the identity encoding.
430
4
                    let include_in_header = *encoding != ContentEncoding::Identity;
431
4
                    (*encoding, include_in_header)
432
                } else {
433
                    // No supported encodings were found, fallback to identity
434
                    // and do not provide a Content-Encoding header.
435
                    // This is effectively equivalent to (4).
436
2
                    (ContentEncoding::Identity, false)
437
                }
438
            }
439
        } else {
440
8
            if z_suffix {
441
                // (3) No Accept-Encoding && ".z" URL
442
                //
443
                // From the specification:
444
                // > If the client does not send an Accept-Encoding header along
445
                // > with a .z URL, the server MUST send the response compressed
446
                // > with deflate and SHOULD NOT send a Content-Encoding header.
447
4
                (ContentEncoding::Deflate, false)
448
            } else {
449
                // (4) No Accept-Encoding && No ".z" URL
450
4
                (ContentEncoding::Identity, false)
451
            }
452
        }
453
16
    }
454

            
455
    /// Matches an incoming request to an existing endpoint.
456
    ///
457
    /// The matching works in a first-match wins fashion.
458
    /// An endpoint is said to be matched when the following two properties for
459
    /// the incoming request hold true:
460
    /// * Both [`Method`] values are the same.
461
    /// * Each component of the URL path is equal at the respective position or,
462
    ///   in the case of the endpoint path, is a wildcard.
463
40
    fn match_endpoint<'a, B: Body>(
464
40
        endpoints: &'a [Endpoint],
465
40
        requ: &Request<B>,
466
40
    ) -> Option<&'a Endpoint> {
467
40
        let requ_path = requ.uri().path();
468
40
        let requ_path = requ_path.strip_suffix(".z").unwrap_or(requ_path);
469
40
        let requ_path = requ_path.split('/').collect::<Vec<_>>();
470
40
        let mut res = None;
471
116
        for tuple in endpoints.iter() {
472
116
            let (method, path, _endpoint_fn) = tuple;
473
116
            let path = path.split('/').collect::<Vec<_>>();
474

            
475
            // Filter the method out first.
476
116
            if requ.method() != method {
477
                continue;
478
116
            }
479

            
480
            // Now that the method is filtered out, perform the path matching
481
            // algorithm.
482
            //
483
            // The path algorithm works as follows:
484
            // 1. Check whether `path.len() == requ_path.len()`, for a match,
485
            //    two paths must have the same number of path components.
486
            // 2. Initialize `is_match = true`.
487
            // 3. Walk over the path components in pairs (i.e. compare first
488
            //    component of `path` with the first component of `requ_path`, ...)
489
            //    and check for each component tuple, whether they are equal or
490
            //    whether the component at the current position in path is a
491
            //    wildcard component, that is, a component that equals `*`.
492
            //
493
            //    Stop immediately the moment
494
            //    `path[i] == requ_path[i] || path[i] == "*"` yields `false`;
495
            //     set `is_match = false`.
496
            // 4. Check the result of `is_match`.
497

            
498
            // Paths must have the same number of components in order to match.
499
            // An inequality here means instant disqualification.
500
116
            if path.len() != requ_path.len() {
501
64
                continue;
502
52
            }
503

            
504
            // Iterate over the path component for component until we disqualify
505
            // for a match.
506
52
            let mut is_match = true;
507
164
            for (this, incoming) in path.iter().zip(&requ_path) {
508
164
                if this == incoming || *this == "*" {
509
138
                    continue;
510
                } else {
511
26
                    is_match = false;
512
26
                    break;
513
                }
514
            }
515

            
516
            // Stop on the first match, propagate the match to the outside.
517
52
            if is_match {
518
26
                res = Some(tuple);
519
26
                break;
520
26
            }
521
        }
522

            
523
40
        res
524
40
    }
525

            
526
    /// Looks up the corresponding [`DocumentId`] for a given [`DocumentId`] and
527
    /// a [`ContentEncoding`].
528
14
    fn map_encoding(
529
14
        tx: &Transaction,
530
14
        docid: DocumentId,
531
14
        encoding: ContentEncoding,
532
14
    ) -> Result<DocumentId, rusqlite::Error> {
533
        // If the encoding is the identity, do not bother about it any further.
534
14
        if encoding == ContentEncoding::Identity {
535
4
            return Ok(docid);
536
10
        }
537

            
538
10
        let mut stmt = tx.prepare_cached(sql!(
539
10
            "
540
10
            SELECT compressed_docid
541
10
            FROM compressed_document
542
10
              WHERE identity_docid = ?1
543
10
                AND algorithm = ?2
544
10
            "
545
10
        ))?;
546
10
        let compressed_docid =
547
15
            stmt.query_one(params![docid, encoding.to_string()], |row| row.get(0))?;
548

            
549
10
        Ok(compressed_docid)
550
14
    }
551

            
552
    /// Generates an empty response with a given [`StatusCode`].
553
    fn empty_response(status: StatusCode) -> Response<DocumentBody> {
554
        // TODO DIRMIRROR: Statically assert that.
555
        Response::builder()
556
            .status(status)
557
            .body(DocumentBody(VecDeque::new()))
558
            .expect("response builder for empty response failed?!?")
559
    }
560
}
561

            
562
#[cfg(test)]
563
pub(in crate::http) mod test {
564
    // @@ begin test lint list maintained by maint/add_warning @@
565
    #![allow(clippy::bool_assert_comparison)]
566
    #![allow(clippy::clone_on_copy)]
567
    #![allow(clippy::dbg_macro)]
568
    #![allow(clippy::mixed_attributes_style)]
569
    #![allow(clippy::print_stderr)]
570
    #![allow(clippy::print_stdout)]
571
    #![allow(clippy::single_char_pattern)]
572
    #![allow(clippy::unwrap_used)]
573
    #![allow(clippy::unchecked_time_subtraction)]
574
    #![allow(clippy::useless_vec)]
575
    #![allow(clippy::needless_pass_by_value)]
576
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
577
    use crate::database;
578

            
579
    use super::*;
580

            
581
    use std::{
582
        io::{Cursor, Write},
583
        str::FromStr,
584
    };
585

            
586
    use flate2::{
587
        write::{DeflateDecoder, DeflateEncoder, GzEncoder},
588
        Compression,
589
    };
590
    use http::Version;
591
    use http_body_util::{BodyExt, Empty};
592
    use lazy_static::lazy_static;
593
    use tokio::{
594
        net::{TcpListener, TcpStream},
595
        task,
596
    };
597
    use tokio_stream::wrappers::TcpListenerStream;
598

            
599
    /// The uncompressed ("identity") test document all compressed test
    /// fixtures are derived from.
    pub(in crate::http) const IDENTITY: &str = "Lorem ipsum dolor sit amet.";

    lazy_static! {
        // Digests of IDENTITY and of each of its compressed variants, as
        // produced by DocumentId::digest; init_test_db re-derives and asserts
        // every one of these before inserting the documents into the store.
        pub(in crate::http) static ref IDENTITY_DOCID: DocumentId =
            hex_to_docid("DD14CBBF0E74909AAC7F248A85D190AFD8DA98265CEF95FC90DFDDABEA7C2E66");
        pub(in crate::http) static ref DEFLATE_DOCID: DocumentId =
            hex_to_docid("07564DD13A7F4A6AD98B997F2938B1CEE11F8C7F358C444374521BA54D50D05E");
        pub(in crate::http) static ref GZIP_DOCID: DocumentId =
            hex_to_docid("1518107D3EF1EC6EAC3F3249DF26B2F845BC8226C326309F4822CAEF2E664104");
        pub(in crate::http) static ref XZ_STD_DOCID: DocumentId =
            hex_to_docid("17416948501F8E627CC9A8F7EFE7A2F32788D53CB84A5F67AC8FD4C1B59184CF");
        pub(in crate::http) static ref X_TOR_LZMA_DOCID: DocumentId =
            hex_to_docid("B5549F79A69113BDAF3EF0AD1D7D339D0083BC31400ECEE1B673F331CF26E239");
    }
613

            
614
6
    /// Creates a test database pool and populates it via [`init_test_db`].
    pub(in crate::http) fn create_test_db_pool() -> Pool<SqliteConnectionManager> {
        let pool = database::open("").unwrap();
        database::rw_tx(&pool, init_test_db).unwrap();
        pool
    }
619

            
620
10
    fn hex_to_docid(s: &str) -> DocumentId {
621
10
        let data: [u8; 32] = hex::decode(s).unwrap().try_into().unwrap();
622
10
        data.into()
623
10
    }
624

            
625
6
    /// Populates the test database with the IDENTITY document plus its
    /// deflate, gzip, zstd and lzma compressed variants, and wires them up in
    /// the `compressed_document` mapping table.
    ///
    /// Each fixture digest is asserted against the expected `*_DOCID` static
    /// first, so a change to the fixtures fails loudly here instead of in
    /// some later HTTP-level test.
    fn init_test_db(tx: &Transaction) {
        assert_eq!(DocumentId::digest(IDENTITY.as_bytes()), *IDENTITY_DOCID);

        // Deflate-compressed variant.
        let deflate = {
            let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
            encoder.write_all(IDENTITY.as_bytes()).unwrap();
            encoder.finish().unwrap()
        };
        assert_eq!(DocumentId::digest(&deflate), *DEFLATE_DOCID);

        // Gzip-compressed variant.
        let gzip = {
            let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
            encoder.write_all(IDENTITY.as_bytes()).unwrap();
            encoder.finish().unwrap()
        };
        assert_eq!(DocumentId::digest(&gzip), *GZIP_DOCID);

        // Zstd-compressed variant (level 3).
        let xz_std = zstd::encode_all(IDENTITY.as_bytes(), 3).unwrap();
        assert_eq!(DocumentId::digest(&xz_std), *XZ_STD_DOCID);

        // LZMA-compressed variant.
        let mut x_tor_lzma = Vec::new();
        lzma_rs::lzma_compress(&mut Cursor::new(IDENTITY), &mut x_tor_lzma).unwrap();
        assert_eq!(DocumentId::digest(&x_tor_lzma), *X_TOR_LZMA_DOCID);

        // Insert all five document blobs into the content-addressed store.
        tx.execute(
            sql!(
                "
                INSERT INTO store(docid, content) VALUES
                (?1, ?2), -- identity
                (?3, ?4), -- deflate
                (?5, ?6), -- gzip
                (?7, ?8), -- xzstd
                (?9, ?10) -- lzma
                "
            ),
            params![
                *IDENTITY_DOCID,
                IDENTITY.as_bytes().to_vec(),
                *DEFLATE_DOCID,
                deflate,
                *GZIP_DOCID,
                gzip,
                *XZ_STD_DOCID,
                xz_std,
                *X_TOR_LZMA_DOCID,
                x_tor_lzma
            ],
        )
        .unwrap();

        // Map the identity document to each of its compressed counterparts.
        tx.execute(
            sql!(
                "
                INSERT INTO compressed_document(algorithm, identity_docid, compressed_docid) VALUES
                ('deflate', ?1, ?2),
                ('gzip', ?1, ?3),
                ('x-zstd', ?1, ?4),
                ('x-tor-lzma', ?1, ?5)
                "
            ),
            params![
                *IDENTITY_DOCID,
                *DEFLATE_DOCID,
                *GZIP_DOCID,
                *XZ_STD_DOCID,
                *X_TOR_LZMA_DOCID
            ],
        )
        .unwrap();
    }
695

            
696
    /// Round-trips every [`ContentEncoding`] variant: serialization via
    /// `Display` and case-insensitive parsing via `FromStr`.
    #[test]
    fn content_encoding() {
        let cases = [
            (ContentEncoding::Identity, "identity", "identity"),
            (ContentEncoding::Deflate, "deflate", "DeFlaTe"),
            (ContentEncoding::Gzip, "gzip", "GzIP"),
            (ContentEncoding::XZstd, "x-zstd", "x-zStD"),
            (ContentEncoding::XTorLzma, "x-tor-lzma", "x-tOr-lzMa"),
        ];

        for (variant, serialized, mixed_case) in cases {
            assert_eq!(variant.to_string(), serialized);
            assert_eq!(ContentEncoding::from_str(mixed_case).unwrap(), variant);
        }
    }
727

            
728
    /// Exercises `HttpServer::determine_encoding` over the four combinations
    /// of an `Accept-Encoding` header being present/absent and the URL
    /// carrying a ".z" suffix or not.
    #[test]
    fn determine_encoding() {
        /// Builds a GET request for `uri`, attaching an `Accept-Encoding`
        /// header when `header` is given.
        fn build(uri: &str, header: Option<&str>) -> Request<String> {
            let mut builder = Request::builder().uri(uri);
            if let Some(value) = header {
                builder = builder.header("Accept-Encoding", value);
            }
            builder.body(String::new()).unwrap()
        }

        // 1. Accept-Encoding && ".z" URL.
        let requ = build("/foo.z", Some("deflate,identity  ,  gzip"));
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Deflate, true)
        );

        // 2a. Valid Accept-Encoding && No ".z" URL.
        let requ = build("/foo", Some("  gzip   "));
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Gzip, true)
        );

        // 2b. Identity Accept-Encoding && No ".z" URL.
        let requ = build("/foo", Some("identity"));
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Identity, false)
        );

        // 2c. Invalid Accept-Encoding && No ".z" URL.
        let requ = build("/foo", Some("  unSuppOrtEd_EncODing_SCHEMA , yeah   "));
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Identity, false)
        );

        // 3. No Accept-Encoding && ".z" URL
        let requ = build("/foo.z", None);
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Deflate, false)
        );

        // 4. No Accept-Encoding && No ".z" URL
        let requ = build("/foo", None);
        assert_eq!(
            HttpServer::determine_encoding(&requ),
            (ContentEncoding::Identity, false)
        );
    }
791

            
792
    /// Checks endpoint routing against literal segments, `*` wildcard
    /// segments, and the implicit ".z" suffix handling; matched endpoints are
    /// compared by pointer identity rather than deep equality.
    #[test]
    fn match_endpoint() {
        /// Dummy call back that does nothing and is not even called.
        fn dummy(
            _: &Transaction,
            _: &Request<Incoming>,
        ) -> Result<Response<Vec<DocumentId>>, Box<dyn std::error::Error + Send>> {
            todo!()
        }

        let endpoints: Vec<Endpoint> = vec![
            (Method::GET, "/foo/bar/baz", dummy),
            (Method::GET, "/foo/*/baz", dummy),
            (Method::GET, "/bar/*", dummy),
            (Method::GET, "/", dummy),
        ];

        // Domain-specific `assert_eq!`: compares the matched endpoint to the
        // expected table entry by address instead of a deep comparison.
        let assert_matches = |uri: &str, idx: usize| {
            let requ = Request::builder().uri(uri).body(String::new()).unwrap();
            let matched: *const Endpoint =
                HttpServer::match_endpoint(&endpoints, &requ).unwrap();
            let expected: *const Endpoint = &endpoints[idx];
            assert_eq!(matched, expected);
        };

        let assert_no_match = |uri: &str| {
            let requ = Request::builder().uri(uri).body(String::new()).unwrap();
            assert!(HttpServer::match_endpoint(&endpoints, &requ).is_none());
        };

        assert_matches("/foo/bar/baz", 0);
        assert_matches("/foo/bar/baz.z", 0);
        assert_no_match("/foo/bar/baz1");
        assert_no_match("/foo/bar/baz/");

        assert_matches("/foo/I_DONT_CARE/baz", 1);
        assert_matches("/foo/I_DONT_CARE/baz.z", 1);
        assert_matches("/foo//baz", 1);
        assert_no_match("/foo/");
        assert_no_match("/foo/foo");
        assert_no_match("/foo/foo/foo");

        assert_matches("/bar/", 2);
        assert_matches("/bar/.z", 2);
        assert_matches("/bar/foo", 2);
        assert_matches("/bar/foo.z", 2);
        assert_no_match("/bar/foo/");
        assert_no_match("/bar/foo/foo");

        assert_matches("/", 3);
        assert_matches("/.z", 3);
    }
849

            
850
    /// Verifies that `HttpServer::map_encoding` resolves the identity
    /// document to the matching pre-seeded compressed variant for every
    /// supported encoding.
    #[test]
    fn map_encoding() {
        let pool = create_test_db_pool();

        // (requested encoding, docid the mapping is expected to yield)
        let expected = [
            (ContentEncoding::Identity, *IDENTITY_DOCID),
            (ContentEncoding::Deflate, *DEFLATE_DOCID),
            (ContentEncoding::Gzip, *GZIP_DOCID),
            (ContentEncoding::XZstd, *XZ_STD_DOCID),
            (ContentEncoding::XTorLzma, *X_TOR_LZMA_DOCID),
        ];

        database::read_tx(&pool, |tx| {
            for (encoding, compressed_docid) in expected {
                // Printed so a failing case is identifiable in the test output.
                println!("{encoding}");
                let mapped = HttpServer::map_encoding(tx, *IDENTITY_DOCID, encoding).unwrap();
                assert_eq!(mapped, compressed_docid);
            }
        })
        .unwrap();
    }
873

            
874
    #[tokio::test]
875
3
    async fn basic_http_server() {
876
        // This is a stupid clippy false positive.
877
        #[allow(clippy::unnecessary_wraps)]
878
4
        fn identity(
879
4
            _tx: &Transaction<'_>,
880
4
            _requ: &Request<Incoming>,
881
4
        ) -> Result<Response<Vec<DocumentId>>, Box<dyn std::error::Error + Send>> {
882
4
            Ok(Response::new(vec![*IDENTITY_DOCID]))
883
4
        }
884

            
885
2
        let pool = create_test_db_pool();
886
2
        let server = HttpServer::new(
887
2
            vec![(Method::GET, "/tor/status-vote/current/consensus", identity)],
888
2
            pool,
889
        );
890

            
891
2
        let listener = TcpListener::bind("[::]:0").await.unwrap();
892
2
        let local_addr = listener.local_addr().unwrap();
893
2
        let listener = TcpListenerStream::new(listener);
894

            
895
2
        task::spawn(async move {
896
2
            server.serve(listener).await.unwrap();
897
        });
898

            
899
2
        let stream = TcpStream::connect(local_addr).await.unwrap();
900
2
        let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream))
901
2
            .await
902
2
            .unwrap();
903

            
904
2
        task::spawn(async move {
905
2
            if let Err(e) = conn.await {
906
                println!("Connection failed: {e:?}");
907
            }
908
        });
909

            
910
        // Perform a simple request.
911
        // TODO: Put this into one function for making requests or use reqwest.
912
2
        let requ = Request::builder()
913
2
            .version(Version::HTTP_11)
914
2
            .uri("/tor/status-vote/current/consensus")
915
2
            .body(Empty::<Bytes>::new())
916
2
            .unwrap();
917
2
        let mut resp = sender.send_request(requ).await.unwrap();
918
2
        let mut resp_body: Vec<u8> = Vec::new();
919
4
        while let Some(next) = resp.frame().await {
920
2
            resp_body.append(&mut next.unwrap().data_ref().unwrap().as_ref().to_vec());
921
2
        }
922
2
        assert_eq!(IDENTITY, String::from_utf8_lossy(&resp_body));
923

            
924
        // Perform a ".z" request.
925
2
        let requ = Request::builder()
926
2
            .version(Version::HTTP_11)
927
2
            .uri("/tor/status-vote/current/consensus.z")
928
2
            .body(Empty::<Bytes>::new())
929
2
            .unwrap();
930
2
        let mut resp = sender.send_request(requ).await.unwrap();
931
2
        let mut resp_body: Vec<u8> = Vec::new();
932
4
        while let Some(next) = resp.frame().await {
933
2
            resp_body.append(&mut next.unwrap().data_ref().unwrap().as_ref().to_vec());
934
2
        }
935
2
        let mut decoder = DeflateDecoder::new(Vec::new());
936
2
        decoder.write_all(&resp_body).unwrap();
937
2
        let decoded_resp = decoder.finish().unwrap();
938
3
        assert_eq!(IDENTITY, String::from_utf8_lossy(&decoded_resp));
939
2
    }
940
}