1
//! Access to the database schema.
2
//!
3
//! This module is not intended to provide a high-level ORM, instead it serves
4
//! the purpose of initializing and upgrading the database, if necessary.
5
//!
6
//! # Synchronous or Asynchronous?
7
//!
8
//! The question on whether the database and access to it shall be synchronous
9
//! or asynchronous has been a fairly long debate that eventually got settled
10
//! after realizing that an asynchronous approach does not work.  This comment
11
//! should serve as a reminder for future devs, wondering why we use certain
12
//! synchronous primitives in an otherwise asynchronous codebase.
13
//!
14
//! Early on, it was clear that we would need some sort of connection pool,
15
//! primarily for two reasons:
16
//! 1. Performing frequent open and close calls in every task would be costly.
17
//! 2. Sharing a single connection object with a Mutex would be a waste
18
//!
19
//! Because the application itself is primarily asynchronous, we decided to go
20
//! with an asynchronous connection pool as well, leading to the choice of
21
//! `deadpool` initially.
22
//!
23
//! However, soon thereafter, problems with `deadpool` became evident.  Those
24
//! problems mostly stemmed from the synchronous nature of SQLite itself.  In our
25
//! case, this problem was initially triggered by figuring out a way to solve
26
//! `SQLITE_BUSY` handling.  In the end, we decided to settle upon the following
27
//! approach: Set `PRAGMA busy_timeout` to a certain value and create write
28
//! transactions with `BEGIN EXCLUSIVE`.  This way, SQLite would try to obtain
29
//! a write transaction for `busy_timeout` milliseconds by blocking the current
30
//! thread.  Due to this blocking, async no longer made any sense and was in
31
//! fact quite counter-productive because those potential sleeps could screw a
32
//! lot of things up, which became very evident while trying to test this.
33
//!
34
//! Besides, throughout refactoring the code base, we realized that, even while
35
//! still using `deadpool`, the actual "asynchronous" calls interfacing with the
36
//! database became smaller and smaller.  In the end, the asynchronous code just
37
//! involved parts of obtaining a connection and creating a transaction,
38
//! eventually resulting in calling a synchronous function taking the
39
//! transaction handle to perform the lion's share of the operation.
40

            
41
// TODO DIRMIRROR: This could benefit from methods by wrapping the pool into a
42
// custom type.
43

            
44
use std::{
45
    fmt::Display,
46
    io::{Cursor, Write},
47
    num::NonZero,
48
    ops::{Add, Sub},
49
    path::Path,
50
    time::{Duration, SystemTime},
51
};
52

            
53
use flate2::write::{DeflateEncoder, GzEncoder};
54
use r2d2::Pool;
55
use r2d2_sqlite::SqliteConnectionManager;
56
use rusqlite::{
57
    named_params, params,
58
    types::{FromSql, FromSqlError, FromSqlResult, ToSqlOutput, ValueRef},
59
    ToSql, Transaction, TransactionBehavior,
60
};
61
use saturating_time::SaturatingTime;
62
use sha2::Digest;
63
use tor_error::into_internal;
64

            
65
use crate::err::DatabaseError;
66

            
67
/// The identifier for documents in the content-addressable cache.
///
/// Right now, this is a Sha256 hash, but this may change in future.
///
/// The inner array holds the raw 32-byte digest.  It is rendered as
/// uppercase hexadecimal when stored in or displayed from the database
/// (see the [`Display`], [`FromSql`], and [`ToSql`] impls below).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct DocumentId([u8; 32]);
72

            
73
impl DocumentId {
74
    /// Computes the [`DocumentId`] from arbitrary data.
75
74
    pub(crate) fn digest(data: &[u8]) -> Self {
76
74
        Self(sha2::Sha256::digest(data).into())
77
74
    }
78
}
79

            
80
impl Display for DocumentId {
81
    /// Formats the [`DocumentId`] in uppercase hexadecimal.
82
200
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83
200
        write!(f, "{}", hex::encode_upper(self.0))
84
200
    }
85
}
86

            
87
impl FromSql for DocumentId {
88
10
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
89
        // We read the document id as a hexadecimal string from the database.
90
        // Afterwards, we convert it to binary data, which should succeed due
91
        // to database check constraints.  Finally, we verify the length to see
92
        // whether it actually constitutes a valid SHA256 checksum.
93
10
        let data: [u8; 32] = value
94
10
            .as_str()
95
10
            .map(hex::decode)?
96
10
            .map_err(|e| {
97
                FromSqlError::Other(Box::new(tor_error::internal!(
98
                    "non hex data in database? {e}"
99
                )))
100
            })?
101
10
            .try_into()
102
10
            .map_err(|_| {
103
                FromSqlError::Other(Box::new(tor_error::internal!(
104
                    "document id with invalid length in database?"
105
                )))
106
            })?;
107

            
108
10
        Ok(Self(data))
109
10
    }
110
}
111

            
112
impl ToSql for DocumentId {
113
198
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
114
        // Because Self is only constructed with FromSql and digest data, it is
115
        // safe to assume to be valid.  Even if not, database constraints will
116
        // catch us from inserting invalid data.
117
198
        Ok(ToSqlOutput::from(self.to_string()))
118
198
    }
119
}
120

            
121
impl PartialEq<&str> for DocumentId {
122
2
    fn eq(&self, other: &&str) -> bool {
123
2
        self.to_string() == other.to_uppercase()
124
2
    }
125
}
126

            
127
#[cfg(test)]
impl From<[u8; 32]> for DocumentId {
    /// Wraps a raw 32-byte digest directly (tests only).
    fn from(digest: [u8; 32]) -> Self {
        Self(digest)
    }
}
133

            
134
/// The supported content encodings.
///
/// Per the `strum` attributes below, the string representation is kebab-case
/// (e.g. `x-zstd`) and parsing back is ASCII-case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, strum::EnumString, strum::Display, strum::EnumIter)]
#[strum(serialize_all = "kebab-case", ascii_case_insensitive)]
pub(crate) enum ContentEncoding {
    /// RFC2616 section 3.5.
    Identity,
    /// RFC2616 section 3.5.
    Deflate,
    /// RFC2616 section 3.5.
    Gzip,
    /// The zstandard compression algorithm (www.zstd.net).
    XZstd,
    /// The lzma compression algorithm with a "preset" value no higher than 6.
    XTorLzma,
}
149

            
150
/// A wrapper around [`SystemTime`] with convenient features.
///
/// Please use this type throughout the crate internally, instead of
/// [`SystemTime`].
///
/// # Conversion
///
/// This type can be safely converted from and into a [`SystemTime`], because
/// it is just a wrapper type.
///
/// # Saturating Arithmetic
///
/// This type implements [`Add`] and [`Sub`] for [`Duration`] and [`Timestamp`]
/// ([`Sub`] only) using saturating arithmetic from the [`saturating_time`]
/// crate.  It means that addition and subtraction can be safely performed
/// without the potential risk of an unexpected panic, instead wrapping to
/// a local maximum/minimum or [`Duration::ZERO`] depending on the type.
///
/// Note that we don't provide a saturating version of [`Duration`], so addition
/// or subtraction of two [`Duration`]s still needs care to avoid panics.
///
/// # SQLite Interaction
///
/// This type implements [`FromSql`] and [`ToSql`], making it convenient to
/// integrate into SQL statements, as the database schema represents timestamps
/// internally using a non-negative [`i64`] storing the seconds since the epoch.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub(crate) struct Timestamp(SystemTime);
178

            
179
impl From<SystemTime> for Timestamp {
180
20020
    fn from(value: SystemTime) -> Self {
181
20020
        Self(value)
182
20020
    }
183
}
184

            
185
impl From<Timestamp> for SystemTime {
    /// Unwraps the inner [`SystemTime`] without any transformation.
    fn from(ts: Timestamp) -> Self {
        ts.0
    }
}
190

            
191
impl Add<Duration> for Timestamp {
192
    type Output = Self;
193

            
194
    /// Performs a saturating addition wrapping to [`SystemTime::max_value()`].
195
20014
    fn add(self, rhs: Duration) -> Self::Output {
196
20014
        Self(self.0.saturating_add(rhs))
197
20014
    }
198
}
199

            
200
impl Sub<Duration> for Timestamp {
201
    type Output = Self;
202

            
203
    /// Performs a saturating subtraction wrapping to [`SystemTime::min_value()`].
204
4
    fn sub(self, rhs: Duration) -> Self::Output {
205
4
        Self(self.0.saturating_sub(rhs))
206
4
    }
207
}
208

            
209
impl Sub<Timestamp> for Timestamp {
210
    type Output = Duration;
211

            
212
    /// Performs a saturating duration_since wrapping to [`Duration::ZERO`].
213
80000
    fn sub(self, rhs: Timestamp) -> Self::Output {
214
80000
        self.0.saturating_duration_since(rhs.0)
215
80000
    }
216
}
217

            
218
impl FromSql for Timestamp {
219
34
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
220
34
        let mut res = SystemTime::UNIX_EPOCH;
221
34
        res = res.saturating_add(Duration::from_secs(value.as_i64()?.try_into().unwrap_or(0)));
222
34
        Ok(Self(res))
223
34
    }
224
}
225

            
226
impl ToSql for Timestamp {
227
42
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
228
42
        Ok(ToSqlOutput::from(
229
42
            self.0
230
42
                .saturating_duration_since(SystemTime::UNIX_EPOCH)
231
42
                .as_secs()
232
42
                .try_into()
233
42
                .unwrap_or(i64::MAX),
234
42
        ))
235
42
    }
236
}
237

            
238
/// A no-op macro that returns the supplied literal unchanged.
///
/// Its purpose is purely semantic: it marks [`str`] literals as being SQL
/// statements for the human reader.
///
/// Keep in mind that the compiler will not notice if you forget this macro.
/// Unfortunately, you have to ensure it yourself.
macro_rules! sql {
    ($stmt:literal) => {
        $stmt
    };
}

pub(crate) use sql;
252

            
253
/// Version 1 of the database schema.
254
const V1_SCHEMA: &str = sql!(
255
    "
256
-- Meta table to store the current schema version.
257
CREATE TABLE arti_dirserver_schema_version(
258
    version TEXT NOT NULL -- currently, always `1`
259
) STRICT;
260

            
261
-- Stores consensuses.
262
--
263
-- http://<hostname>/tor/status-vote/current/consensus-<FLAVOR>
264
-- http://<hostname>/tor/status-vote/current/consensus-<FLAVOR>/<F1>+<F2>+<F3>
265
-- http://<hostname>/tor/status-vote/current/consensus-<FLAVOR>/diff/<HASH>/<FPRLIST>
266
CREATE TABLE consensus(
267
    rowid               INTEGER PRIMARY KEY AUTOINCREMENT,
268
    docid               TEXT NOT NULL UNIQUE,
269
    -- Required for consensus diffs.
270
    -- https://spec.torproject.org/dir-spec/directory-cache-operation.html#diff-format
271
    unsigned_sha3_256   TEXT NOT NULL UNIQUE,
272
    flavor              TEXT NOT NULL,
273
    valid_after         INTEGER NOT NULL,
274
    fresh_until         INTEGER NOT NULL,
275
    valid_until         INTEGER NOT NULL,
276
    FOREIGN KEY(docid) REFERENCES store(docid),
277
    CHECK(GLOB('*[^0-9A-F]*', unsigned_sha3_256) == 0),
278
    CHECK(LENGTH(unsigned_sha3_256) == 64),
279
    CHECK(flavor IN ('ns', 'md')),
280
    CHECK(valid_after >= 0),
281
    CHECK(fresh_until >= 0),
282
    CHECK(valid_until >= 0),
283
    CHECK(valid_after < fresh_until),
284
    CHECK(fresh_until < valid_until)
285
) STRICT;
286

            
287
-- Stores consensus diffs.
288
--
289
-- http://<hostname>/tor/status-vote/current/consensus-<FLAVOR>/diff/<HASH>/<FPRLIST>
290
CREATE TABLE consensus_diff(
291
    rowid                   INTEGER PRIMARY KEY AUTOINCREMENT,
292
    docid                   TEXT NOT NULL UNIQUE,
293
    old_consensus_rowid     INTEGER NOT NULL,
294
    new_consensus_rowid     INTEGER NOT NULL,
295
    FOREIGN KEY(docid) REFERENCES store(docid),
296
    FOREIGN KEY(old_consensus_rowid) REFERENCES consensus(rowid),
297
    FOREIGN KEY(new_consensus_rowid) REFERENCES consensus(rowid)
298
) STRICT;
299

            
300
-- Stores the router descriptors.
301
--
302
-- http://<hostname>/tor/server/fp/<F>
303
-- http://<hostname>/tor/server/d/<D>
304
-- http://<hostname>/tor/server/authority
305
-- http://<hostname>/tor/server/all
306
CREATE TABLE router_descriptor(
307
    rowid                   INTEGER PRIMARY KEY AUTOINCREMENT,
308
    docid                   TEXT NOT NULL UNIQUE,
309
    sha1                    TEXT NOT NULL UNIQUE,
310
    sha2                    TEXT NOT NULL UNIQUE,
311
    kp_relay_id_rsa_sha1    TEXT NOT NULL,
312
    flavor                  TEXT NOT NULL,
313
    router_extra_info_rowid  INTEGER,
314
    FOREIGN KEY(docid) REFERENCES store(docid),
315
    FOREIGN KEY(router_extra_info_rowid) REFERENCES router_extra_info(rowid),
316
    CHECK(GLOB('*[^0-9A-F]*', sha1) == 0),
317
    CHECK(GLOB('*[^0-9A-F]*', kp_relay_id_rsa_sha1) == 0),
318
    CHECK(LENGTH(sha1) == 40),
319
    CHECK(docid == sha2),
320
    CHECK(LENGTH(kp_relay_id_rsa_sha1) == 40),
321
    CHECK(flavor IN ('ns', 'md'))
322
) STRICT;
323

            
324
-- Stores extra-info documents.
325
--
326
-- http://<hostname>/tor/extra/d/<D>
327
-- http://<hostname>/tor/extra/fp/<FP>
328
-- http://<hostname>/tor/extra/all
329
-- http://<hostname>/tor/extra/authority
330
CREATE TABLE router_extra_info(
331
    rowid                   INTEGER PRIMARY KEY AUTOINCREMENT,
332
    docid                   TEXT NOT NULL UNIQUE,
333
    sha1                    TEXT NOT NULL UNIQUE,
334
    kp_relay_id_rsa_sha1    TEXT NOT NULL,
335
    FOREIGN KEY(docid) REFERENCES store(docid),
336
    CHECK(GLOB('*[^0-9A-F]*', sha1) == 0),
337
    CHECK(GLOB('*[^0-9A-F]*', kp_relay_id_rsa_sha1) == 0),
338
    CHECK(LENGTH(sha1) == 40),
339
    CHECK(LENGTH(kp_relay_id_rsa_sha1) == 40)
340
) STRICT;
341

            
342
-- Directory authority key certificates.
343
--
344
-- This information is derived from the consensus documents.
345
--
346
-- http://<hostname>/tor/keys/all
347
-- http://<hostname>/tor/keys/authority
348
-- http://<hostname>/tor/keys/fp/<F>
349
-- http://<hostname>/tor/keys/sk/<F>-<S>
350
CREATE TABLE authority_key_certificate(
351
    rowid                   INTEGER PRIMARY KEY AUTOINCREMENT,
352
    docid                   TEXT NOT NULL UNIQUE,
353
    kp_auth_id_rsa_sha1     TEXT NOT NULL,
354
    kp_auth_sign_rsa_sha1   TEXT NOT NULL,
355
    dir_key_published       INTEGER NOT NULL,
356
    dir_key_expires         INTEGER NOT NULL,
357
    FOREIGN KEY(docid) REFERENCES store(docid),
358
    CHECK(GLOB('*[^0-9A-F]*', kp_auth_id_rsa_sha1) == 0),
359
    CHECK(GLOB('*[^0-9A-F]*', kp_auth_sign_rsa_sha1) == 0),
360
    CHECK(LENGTH(kp_auth_id_rsa_sha1) == 40),
361
    CHECK(LENGTH(kp_auth_sign_rsa_sha1) == 40),
362
    CHECK(dir_key_published >= 0),
363
    CHECK(dir_key_expires >= 0),
364
    CHECK(dir_key_published < dir_key_expires)
365

            
366
) STRICT;
367

            
368
-- Content addressable storage, storing all contents.
369
CREATE TABLE store(
370
    rowid   INTEGER PRIMARY KEY AUTOINCREMENT, -- hex uppercase
371
    docid   TEXT NOT NULL UNIQUE,
372
    content BLOB NOT NULL,
373
    CHECK(GLOB('*[^0-9A-F]*', docid) == 0),
374
    CHECK(LENGTH(docid) == 64)
375
) STRICT;
376

            
377
-- Stores compressed network documents.
378
CREATE TABLE compressed_document(
379
    rowid               INTEGER PRIMARY KEY AUTOINCREMENT,
380
    algorithm           TEXT NOT NULL,
381
    identity_docid      TEXT NOT NULL,
382
    compressed_docid   TEXT NOT NULL,
383
    FOREIGN KEY(identity_docid) REFERENCES store(docid),
384
    FOREIGN KEY(compressed_docid) REFERENCES store(docid),
385
    UNIQUE(algorithm, identity_docid)
386
) STRICT;
387

            
388
-- Stores the N:M cardinality of which router descriptors are contained in which
389
-- consensuses.
390
CREATE TABLE consensus_router_descriptor_member(
391
    consensus_rowid         INTEGER,
392
    router_descriptor_rowid INTEGER,
393
    PRIMARY KEY(consensus_rowid, router_descriptor_rowid),
394
    FOREIGN KEY(consensus_rowid) REFERENCES consensus(rowid),
395
    FOREIGN KEY(router_descriptor_rowid) REFERENCES router_descriptor(rowid)
396
) STRICT;
397

            
398
-- Stores which authority key signed which consensuses.
399
--
400
-- Required to implement the consensus retrieval by authority fingerprints as
401
-- well as the garbage collection of authority key certificates.
402
--
403
-- http://<hostname>/tor/status-vote/current/consensus-<FLAVOR>/<F1>+<F2>+<F3>
404
CREATE TABLE consensus_authority_voter(
405
    consensus_rowid INTEGER,
406
    authority_rowid INTEGER,
407
    PRIMARY KEY(consensus_rowid, authority_rowid),
408
    FOREIGN KEY(consensus_rowid) REFERENCES consensus(rowid),
409
    FOREIGN KEY(authority_rowid) REFERENCES authority_key_certificate(rowid)
410
) STRICT;
411

            
412
INSERT INTO arti_dirserver_schema_version VALUES ('1');
413
"
414
);
415

            
416
/// Global options set in every connection.
///
/// * `journal_mode=WAL`: use write-ahead logging.
/// * `foreign_keys=ON`: enforce the schema's FOREIGN KEY constraints.
/// * `busy_timeout=1000`: block up to one second waiting for a lock before
///   failing with `SQLITE_BUSY` (see the module docs for the rationale).
const GLOBAL_OPTIONS: &str = sql!(
    "
PRAGMA journal_mode=WAL;
PRAGMA foreign_keys=ON;
PRAGMA busy_timeout=1000;
"
);
424

            
425
/// Opens a database from disk, creating a [`Pool`] for it.
///
/// This function should be the entry point for all things requiring a database
/// handle, as this function prepares all necessary steps required for operating
/// on the database correctly, such as:
/// * Schema initialization.
/// * Schema upgrade.
/// * Setting connection specific settings.
///
/// # `SQLITE_BUSY` Caveat
///
/// There is a problem with the handling of `SQLITE_BUSY` when opening an
/// SQLite database.  In WAL, opening a database might acquire an exclusive lock
/// for a very short amount of time, in order to perform clean-up from previous
/// connections alongside other tasks for maintaining database integrity.  This
/// means, that opening multiple SQLite databases simultaneously will result in
/// a busy error regardless of a busy handler, as setting a busy handler will
/// require an existing connection, something we are unable to obtain in the
/// first place.
///
/// In order to mitigate this issue, the recommended way in the SQLite community
/// is to simply ensure that database connections are opened sequentially,
/// by urging calling applications to just use a single [`Pool`] instance.
///
/// Testing this is hard unfortunately.
pub(crate) fn open<P: AsRef<Path>>(
    path: P,
) -> Result<Pool<SqliteConnectionManager>, DatabaseError> {
    // Size the pool by the core count; fall back to 8 if it is unavailable.
    let num_cores = std::thread::available_parallelism()
        .unwrap_or(NonZero::new(8).expect("8 == 0?"))
        .get() as u32;

    let manager = r2d2_sqlite::SqliteConnectionManager::file(&path);
    let pool = Pool::builder().max_size(num_cores).build(manager)?;

    rw_tx(&pool, |tx| {
        // Prepare the database, doing the following steps:
        // 1. Checking the database schema.
        // 2. Upgrading (in future) or initializing the database schema (if empty).

        // Probe sqlite_master for the version table to detect whether this
        // database has been initialized before.
        let has_arti_dirserver_schema_version = match tx.query_one(
            sql!(
                "
                SELECT name
                FROM sqlite_master
                  WHERE type = 'table'
                    AND name = 'arti_dirserver_schema_version'
                "
            ),
            params![],
            |_| Ok(()),
        ) {
            Ok(()) => true,
            Err(rusqlite::Error::QueryReturnedNoRows) => false,
            Err(e) => return Err(DatabaseError::LowLevel(e)),
        };

        if has_arti_dirserver_schema_version {
            let version = tx.query_one(
                sql!("SELECT version FROM arti_dirserver_schema_version WHERE rowid = 1"),
                params![],
                |row| row.get::<_, String>(0),
            )?;

            // Only version "1" exists so far; anything else indicates an
            // incompatible (likely newer) schema.
            match version.as_ref() {
                "1" => {}
                unknown => {
                    return Err(DatabaseError::IncompatibleSchema {
                        version: unknown.into(),
                    })
                }
            }
        } else {
            // Empty database: install the current schema from scratch.
            tx.execute_batch(V1_SCHEMA)?;
        }

        Ok::<_, DatabaseError>(())
    })??;

    Ok(pool)
}
506

            
507
/// Executes a closure `op` with a given read-only [`Transaction`].
508
///
509
/// The [`Transaction`] always gets rolled back the moment `op` returns.
510
///
511
/// The [`Transaction`] gets initialized with the global pragma options set.
512
///
513
/// **The closure shall not perform write operations!**
514
/// Not only do they get rolled back anyways, but upgrading the [`Transaction`]
515
/// from a read to a write transaction will lead to other simultanous write upgrades
516
/// to fail.  Unfortunately, there is no real programatic way to ensure this.
517
26
pub(crate) fn read_tx<U, F>(pool: &Pool<SqliteConnectionManager>, op: F) -> Result<U, DatabaseError>
518
26
where
519
26
    F: FnOnce(&Transaction<'_>) -> U,
520
{
521
26
    let mut conn = pool.get()?;
522
26
    conn.execute_batch(GLOBAL_OPTIONS)?;
523
26
    let tx = conn.transaction_with_behavior(TransactionBehavior::Deferred)?;
524
26
    let res = op(&tx);
525
26
    tx.rollback()?;
526
26
    Ok(res)
527
26
}
528

            
529
/// Executes a closure `op` with a given read-write [`Transaction`].
///
/// The [`Transaction`] always gets committed the moment `op` returns.
///
/// The [`Transaction`] gets initialized with the global pragma options set.
///
/// The [`Transaction`] gets created with [`TransactionBehavior::Exclusive`]
/// (`BEGIN EXCLUSIVE`; see the module docs), meaning it will immediately
/// exist as a write connection, retrying in the case of a
/// [`rusqlite::ErrorCode::DatabaseBusy`] until it failed after 1s
/// (per `PRAGMA busy_timeout` in [`GLOBAL_OPTIONS`]).
pub(crate) fn rw_tx<U, F>(pool: &Pool<SqliteConnectionManager>, op: F) -> Result<U, DatabaseError>
where
    F: FnOnce(&Transaction<'_>) -> U,
{
    let mut conn = pool.get()?;
    conn.execute_batch(GLOBAL_OPTIONS)?;
    let tx = conn.transaction_with_behavior(TransactionBehavior::Exclusive)?;
    let res = op(&tx);
    tx.commit()?;
    Ok(res)
}
549

            
550
/// Inserts `data` into store while also compressing it with given encodings.
///
/// Returns the [`DocumentId`] of `data`.
///
/// This function inserts `data` into store and also compresses it into all
/// given compression formats.
///
/// Duplicates get re-encoded and replaced in the database, including
/// [`ContentEncoding::Identity`].
///
/// # Parameters
///
/// * `tx` - The transaction to perform the inserts in; committing it is the
///   caller's responsibility.
/// * `data` - The raw (identity-encoded) document contents.
/// * `encodings` - The encodings to additionally store.  `Identity` entries
///   are skipped, as the plain document is always inserted anyway.
///
/// # Errors
///
/// Returns a [`DatabaseError`] on SQL failures; compression failures are
/// reported as internal errors (bugs), see the comment in the loop below.
pub(crate) fn store_insert<I: Iterator<Item = ContentEncoding>>(
    tx: &Transaction,
    data: &[u8],
    encodings: I,
) -> Result<DocumentId, DatabaseError> {
    // The statement to insert some data into the store.
    //
    // Parameters:
    // :docid - The docid.
    // :content - The binary data.
    let mut store_stmt = tx.prepare_cached(sql!(
        "
        INSERT OR REPLACE INTO store (docid, content)
        VALUES
        (:docid, :content)
        "
    ))?;

    // The statement to insert a compressed document into the metatable.
    //
    // Parameters:
    // :algorithm - The name of the encoding algorithm.
    // :identity_docid - The docid of the plain-text document in the store.
    // :compressed_docid - The docid of the encoded document in the store.
    let mut compressed_stmt = tx.prepare_cached(sql!(
        "
        INSERT OR REPLACE INTO compressed_document (algorithm, identity_docid, compressed_docid)
        VALUES
        (:algorithm, :identity_docid, :compressed_docid)
        "
    ))?;

    // Insert the plain document into the store.
    let identity_docid = DocumentId::digest(data);
    store_stmt.execute(named_params! {
        ":docid": identity_docid,
        ":content": data
    })?;

    // Compress it into all formats and insert it into store and compressed.
    for encoding in encodings {
        if encoding == ContentEncoding::Identity {
            // Ignore identity because we inserted that above.
            continue;
        }

        // We map a compression error to a bug because there is no good reason
        // on why it should fail, given that we compress from memory data to
        // memory data.  Probably because it uses the std::io::Writer interface
        // which itself demands use of std::io::Result.
        let compressed = compress(data, encoding).map_err(into_internal!("{encoding} failed?"))?;
        let compressed_docid = DocumentId::digest(&compressed);
        store_stmt.execute(named_params! {
            ":docid": compressed_docid,
            ":content": compressed,
        })?;
        compressed_stmt.execute(named_params! {
            ":algorithm": encoding.to_string(),
            ":identity_docid": identity_docid,
            ":compressed_docid": compressed_docid,
        })?;
    }

    Ok(identity_docid)
}
624

            
625
/// Compresses `data` into a specified [`ContentEncoding`].
626
///
627
/// Returns a [`Vec`] containing the encoded data.
628
42
fn compress(data: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>, std::io::Error> {
629
42
    match encoding {
630
2
        ContentEncoding::Identity => Ok(data.to_vec()),
631
        ContentEncoding::Deflate => {
632
10
            let mut w = DeflateEncoder::new(Vec::new(), Default::default());
633
10
            w.write_all(data)?;
634
10
            w.finish()
635
        }
636
        ContentEncoding::Gzip => {
637
10
            let mut w = GzEncoder::new(Vec::new(), Default::default());
638
10
            w.write_all(data)?;
639
10
            w.finish()
640
        }
641
10
        ContentEncoding::XZstd => zstd::encode_all(data, Default::default()),
642
        ContentEncoding::XTorLzma => {
643
10
            let mut res = Vec::new();
644
10
            lzma_rs::lzma_compress(&mut Cursor::new(data), &mut res)?;
645
10
            Ok(res)
646
        }
647
    }
648
42
}
649

            
650
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use std::{
        collections::HashSet,
        io::Read,
        sync::{Arc, Once},
    };

    use flate2::read::{DeflateDecoder, GzDecoder};
    use rusqlite::Connection;
    use strum::IntoEnumIterator;
    use tempfile::tempdir;

    use super::*;

    /// Tests that opening a fresh database initializes the schema version,
    /// and that an unknown schema version is rejected on re-open.
    #[test]
    fn open() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");

        super::open(&db_path).unwrap();
        let conn = Connection::open(&db_path).unwrap();

        // Check if the version was initialized properly.
        let version = conn
            .query_one(
                "SELECT version FROM arti_dirserver_schema_version WHERE rowid = 1",
                params![],
                |row| row.get::<_, String>(0),
            )
            .unwrap();
        assert_eq!(version, "1");

        // Set the version to something unknown.
        conn.execute(
            "UPDATE arti_dirserver_schema_version SET version = 42",
            params![],
        )
        .unwrap();
        drop(conn);

        assert_eq!(
            super::open(&db_path).unwrap_err().to_string(),
            "incompatible schema version: 42"
        );
    }

    /// Tests `read_tx`, including that writes performed inside a read
    /// transaction are rolled back rather than committed.
    #[test]
    fn read_tx() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");

        let pool = super::open(&db_path).unwrap();

        // Do a write transaction despite forbidden.
        super::read_tx(&pool, |tx| {
            tx.execute_batch("DELETE FROM arti_dirserver_schema_version")
                .unwrap();
            let e = tx
                .query_one(
                    sql!("SELECT version FROM arti_dirserver_schema_version"),
                    params![],
                    |row| row.get::<_, String>(0),
                )
                .unwrap_err();
            assert_eq!(e, rusqlite::Error::QueryReturnedNoRows);
        })
        .unwrap();

        // Normal check.
        let version: String = super::read_tx(&pool, |tx| {
            tx.query_one(
                sql!("SELECT version FROM arti_dirserver_schema_version"),
                params![],
                |row| row.get(0),
            )
            .unwrap()
        })
        .unwrap();
        assert_eq!(version, "1");
    }

    /// Tests that writes performed inside `rw_tx` are committed.
    #[test]
    fn rw_tx() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");

        let pool = super::open(&db_path).unwrap();

        // Do a write transaction.
        super::rw_tx(&pool, |tx| {
            tx.execute_batch("DELETE FROM arti_dirserver_schema_version")
                .unwrap();
        })
        .unwrap();

        // Check that it was deleted.
        super::read_tx(&pool, |tx| {
            let e = tx
                .query_one(
                    sql!("SELECT version FROM arti_dirserver_schema_version"),
                    params![],
                    |row| row.get::<_, String>(0),
                )
                .unwrap_err();
            assert_eq!(e, rusqlite::Error::QueryReturnedNoRows);
        })
        .unwrap();
    }

    /// Tests whether our SQLite busy error handling works in normal situations.
    ///
    /// A normal situations means a situation where a lock is never held for
    /// more than 1000ms.  In our case, we will work with two threads.
    /// t1 will acquire an exclusive lock and inform t2 about it.  t2 waits
    /// until t1 has acquired this lock and then immediately informs t1, that
    /// it will now wait for a lock too.  Now, t1 will immediately terminate,
    /// thereby releasing the lock and leading t2 to eventually acquire it.
    #[test]
    fn rw_tx_busy_timeout_working() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");
        let pool = super::open(&db_path).unwrap();

        // t2 will wait on this before it starts doing stuff.
        let t1_acquired_lock = Arc::new(Once::new());
        // t1 will wait on this in order to terminate properly.
        let t2_is_waiting = Arc::new(Once::new());

        let t1 = std::thread::spawn({
            let pool = pool.clone();
            let t1_acquired_lock = t1_acquired_lock.clone();
            let t2_is_waiting = t2_is_waiting.clone();
            move || {
                super::rw_tx(&pool, move |_tx| {
                    // Inform t2 we have write lock.
                    t1_acquired_lock.call_once(|| ());
                    println!("t1 acquired write lock");

                    // Wait for t2 to start waiting.
                    t2_is_waiting.wait();
                })
                .unwrap();
                // This runs on t1, right after t1's transaction has committed
                // and thereby released the write lock.
                println!("t1 released write lock");
            }
        });

        println!("t2 waits for t1 to acquire write lock");
        t1_acquired_lock.wait();
        t2_is_waiting.call_once(|| ());
        super::rw_tx(&pool, |_| ()).unwrap();
        println!("t2 acquired and released write lock");
        t1.join().unwrap();
    }

    /// Tests whether our SQLite busy error handlings fails as expected.
    ///
    /// We configure SQLite to fail after 1000ms.  This test works with two
    /// threads.  t1 will acquire an exclusive lock on the database and will
    /// inform t2 about it, which itself will wait until t1 has acquired the
    /// lock.  t2 will then immediately try to also obtain an exclusive lock,
    /// which should fail after about 1000ms.  After the failure, t2 informs
    /// t1 that it has failed, causing t1 to terminate.
    #[test]
    fn rw_tx_busy_timeout_busy() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");
        let pool = super::open(&db_path).unwrap();

        // t2 will wait on this before it starts doing stuff.
        let t1_acquired_lock = Arc::new(Once::new());
        // t1 will wait on this in order to terminate properly.
        let t2_gave_up = Arc::new(Once::new());

        let t1 = std::thread::spawn({
            let pool = pool.clone();
            let t1_acquired_lock = t1_acquired_lock.clone();
            let t2_gave_up = t2_gave_up.clone();

            move || {
                super::rw_tx(&pool, move |_tx| {
                    // Inform t2 we have the write lock.
                    t1_acquired_lock.call_once(|| ());
                    println!("t1 acquired write lock");
                    // Wait for t2 to give up before we release (how mean from us).
                    t2_gave_up.wait();
                })
                .unwrap();
                println!("t1 released write lock");
            }
        });

        println!("t2 waits for t1 to acquire write lock");
        t1_acquired_lock.wait();
        let e = super::rw_tx(&pool, |_| ()).unwrap_err();
        assert_eq!(
            e.to_string(),
            "low-level rusqlite error: database is locked"
        );
        println!("t2 gave up on acquiring write lock");
        t2_gave_up.call_once(|| ());
        t1.join().unwrap();
    }

    /// Tests `store_insert`: identity and compressed rows are written,
    /// re-inserting the same document is idempotent, and missing compressed
    /// rows are re-created on a subsequent insert.
    #[test]
    fn store_insert() {
        let db_dir = tempdir().unwrap();
        let db_path = db_dir.path().join("db");

        super::open(&db_path).unwrap();
        let mut conn = Connection::open(&db_path).unwrap();
        let tx = conn.transaction().unwrap();

        let docid = super::store_insert(&tx, "foobar".as_bytes(), ContentEncoding::iter()).unwrap();
        assert_eq!(
            docid,
            "C3AB8FF13720E8AD9047DD39466B3C8974E592C2FA383D4A3960714CAEF0C4F2"
        );

        let res = tx
            .query_one(
                sql!(
                    "
                    SELECT content
                    FROM store
                    WHERE docid = 'C3AB8FF13720E8AD9047DD39466B3C8974E592C2FA383D4A3960714CAEF0C4F2'
                    "
                ),
                params![],
                |row| row.get::<_, Vec<u8>>(0),
            )
            .unwrap();
        assert_eq!(res, "foobar".as_bytes());

        let mut stmt = tx.prepare_cached(sql!(
            "
            SELECT algorithm
            FROM compressed_document
            WHERE identity_docid = 'C3AB8FF13720E8AD9047DD39466B3C8974E592C2FA383D4A3960714CAEF0C4F2'
            "
        )).unwrap();

        let algorithms = stmt
            .query_map(params![], |row| row.get::<_, String>(0))
            .unwrap();

        let algorithms = algorithms.map(|x| x.unwrap()).collect::<HashSet<_>>();
        assert_eq!(
            algorithms,
            HashSet::from([
                "deflate".to_string(),
                "gzip".to_string(),
                "x-zstd".to_string(),
                "x-tor-lzma".to_string()
            ])
        );

        // Now insert the same thing a second time again and see whether the
        // ON CONFLICT magic works.
        let docid_second =
            super::store_insert(&tx, "foobar".as_bytes(), ContentEncoding::iter()).unwrap();
        assert_eq!(docid, docid_second);

        // Remove a few compressed entries and get them again.
        let n = tx
            .execute(
                sql!(
                    "
                    DELETE FROM
                    compressed_document
                    WHERE algorithm IN ('deflate', 'x-zstd')
                    "
                ),
                params![],
            )
            .unwrap();
        assert_eq!(n, 2);

        let docid_third =
            super::store_insert(&tx, "foobar".as_bytes(), ContentEncoding::iter()).unwrap();
        assert_eq!(docid, docid_third);
        let algorithms = stmt
            .query_map(params![], |row| row.get::<_, String>(0))
            .unwrap();
        let algorithms = algorithms.map(|x| x.unwrap()).collect::<HashSet<_>>();
        assert_eq!(
            algorithms,
            HashSet::from([
                "deflate".to_string(),
                "gzip".to_string(),
                "x-zstd".to_string(),
                "x-tor-lzma".to_string()
            ])
        );
    }

    /// Tests that `compress` handles every [`ContentEncoding`] and that each
    /// encoding round-trips through the matching decoder.
    #[test]
    fn compress() {
        /// Asserts that `res` contains `encoding`.
        fn contains(encoding: ContentEncoding, res: &[(ContentEncoding, Vec<u8>)]) {
            assert!(res.iter().any(|x| x.0 == encoding));
        }

        const INPUT: &[u8] = "foobar".as_bytes();

        // Check whether everything was encoded.
        let res = ContentEncoding::iter()
            .map(|encoding| (encoding, super::compress(INPUT, encoding).unwrap()))
            .collect::<Vec<_>>();
        assert_eq!(res.len(), 5);
        contains(ContentEncoding::Identity, &res);
        contains(ContentEncoding::Deflate, &res);
        contains(ContentEncoding::Gzip, &res);
        contains(ContentEncoding::XTorLzma, &res);
        contains(ContentEncoding::XZstd, &res);

        // Check if we can decode it.
        for (encoding, compressed) in res {
            let mut decompressed = Vec::new();

            match encoding {
                ContentEncoding::Identity => decompressed = compressed,
                ContentEncoding::Deflate => {
                    DeflateDecoder::new(Cursor::new(compressed))
                        .read_to_end(&mut decompressed)
                        .unwrap();
                }
                ContentEncoding::Gzip => {
                    GzDecoder::new(Cursor::new(compressed))
                        .read_to_end(&mut decompressed)
                        .unwrap();
                }
                ContentEncoding::XTorLzma => {
                    lzma_rs::lzma_decompress(&mut Cursor::new(compressed), &mut decompressed)
                        .unwrap();
                }
                ContentEncoding::XZstd => {
                    decompressed = zstd::decode_all(Cursor::new(compressed)).unwrap();
                }
            }
            assert_eq!(decompressed, INPUT);
        }
    }
}