1
#![cfg_attr(docsrs, feature(doc_cfg))]
2
#![doc = include_str!("../README.md")]
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_time_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
#![allow(clippy::collapsible_if)] // See arti#2342
46
#![deny(clippy::unused_async)]
47
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48

            
49
// TODO #1645 (either remove this, or decide to have it everywhere)
50
#![cfg_attr(not(all(feature = "full")), allow(unused))]
51

            
52
#[cfg(all(
53
    any(feature = "native-tls", feature = "rustls"),
54
    any(feature = "async-std", feature = "tokio", feature = "smol")
55
))]
56
pub(crate) mod impls;
57
pub mod task;
58

            
59
mod coarse_time;
60
mod compound;
61
mod dyn_time;
62
pub mod general;
63
mod opaque;
64
pub mod scheduler;
65
mod timer;
66
mod traits;
67
pub mod unimpl;
68
pub mod unix;
69

            
70
#[cfg(any(feature = "async-std", feature = "tokio", feature = "smol"))]
71
use std::io;
72
pub use traits::{
73
    Blocking, CertifiedConn, CoarseTimeProvider, NetStreamListener, NetStreamProvider,
74
    NoOpStreamOpsHandle, Runtime, SleepProvider, SpawnExt, StreamOps, TlsProvider, ToplevelBlockOn,
75
    ToplevelRuntime, UdpProvider, UdpSocket, UnsupportedStreamOp,
76
};
77

            
78
pub use coarse_time::{CoarseDuration, CoarseInstant, RealCoarseTimeProvider};
79
pub use dyn_time::DynTimeProvider;
80
pub use timer::{SleepProviderExt, Timeout, TimeoutError};
81

            
82
/// Traits used to describe TLS connections and objects that can
83
/// create them.
84
pub mod tls {
85
    #[cfg(all(
86
        any(feature = "native-tls", feature = "rustls"),
87
        any(feature = "async-std", feature = "tokio", feature = "smol")
88
    ))]
89
    pub use crate::impls::unimpl_tls::UnimplementedTls;
90
    pub use crate::traits::{
91
        CertifiedConn, TlsAcceptorSettings, TlsConnector, TlsServerUnsupported,
92
    };
93

            
94
    #[cfg(all(
95
        feature = "native-tls",
96
        any(feature = "tokio", feature = "async-std", feature = "smol")
97
    ))]
98
    pub use crate::impls::native_tls::NativeTlsProvider;
99
    #[cfg(all(
100
        feature = "rustls",
101
        any(feature = "tokio", feature = "async-std", feature = "smol")
102
    ))]
103
    pub use crate::impls::rustls::RustlsProvider;
104
    #[cfg(all(
105
        feature = "rustls",
106
        feature = "tls-server",
107
        any(feature = "tokio", feature = "async-std", feature = "smol")
108
    ))]
109
    pub use crate::impls::rustls::rustls_server::{RustlsAcceptor, RustlsServerStream};
110
}
111

            
112
#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))]
113
pub mod tokio;
114

            
115
#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "async-std"))]
116
pub mod async_std;
117

            
118
#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "smol"))]
119
pub mod smol;
120

            
121
pub use compound::{CompoundRuntime, RuntimeSubstExt};
122

            
123
#[cfg(all(
124
    any(feature = "native-tls", feature = "rustls"),
125
    feature = "async-std",
126
    not(feature = "tokio")
127
))]
128
use async_std as preferred_backend_mod;
129
#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))]
130
use tokio as preferred_backend_mod;
131

            
132
/// The runtime we prefer to use, chosen from among all the runtimes compiled
/// into the tor-rtcompat crate.
///
/// When both `tokio` and `async-std` are available, `tokio` is chosen for its
/// performance.
/// When both `native_tls` and `rustls` are available, `native_tls` is chosen
/// because it has been used in Arti for longer.
///
/// The process [**may not fork**](crate#do-not-fork)
/// (except, very carefully, before exec)
/// after creating this or any other `Runtime`.
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio")
))]
#[derive(Clone)]
pub struct PreferredRuntime {
    /// The backend-specific runtime object that this type wraps.
    inner: preferred_backend_mod::PreferredRuntime,
}

// Give `PreferredRuntime` its runtime trait implementations through
// `crate::opaque::implement_opaque_runtime`, which is told the name and type
// of the wrapped `inner` field.
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio")
))]
crate::opaque::implement_opaque_runtime! {
    PreferredRuntime { inner: preferred_backend_mod::PreferredRuntime }
}
160

            
161
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio")
))]
impl PreferredRuntime {
    /// Wrap the asynchronous runtime that is currently running as a
    /// [`PreferredRuntime`]. This is usually the constructor you want.
    ///
    /// It looks up a handle to the asynchronous runtime that is already
    /// running and wraps it. The returned [`PreferredRuntime`] is not the
    /// asynchronous runtime object itself (e.g. `tokio::runtime::Runtime`).
    ///
    /// # Panics
    ///
    /// When `tor-rtcompat` is built with the `tokio` feature (whether or not
    /// the `async-std` feature is enabled as well), this panics if called
    /// outside of a Tokio runtime context.
    /// See `tokio::runtime::Handle::current`.
    ///
    /// # Usage notes
    ///
    /// Once you have a runtime from this function, obtain further handles to
    /// it with [`Clone`].
    ///
    /// # Limitations
    ///
    /// If the `tor-rtcompat` crate was built with `tokio` support, the runtime
    /// returned here is never based on `async_std`.
    ///
    /// The process [**may not fork**](crate#do-not-fork)
    /// (except, very carefully, before exec)
    /// after creating this or any other `Runtime`.
    //
    // ## Note to Arti developers
    //
    // Never call this from inside other Arti crates, or from library crates
    // that want to support multiple runtimes! It exists for Arti _users_ who
    // want to wrap an existing Tokio or Async_std runtime as a [`Runtime`].
    // It is not meant for library crates that work with multiple runtimes.
    pub fn current() -> io::Result<Self> {
        preferred_backend_mod::PreferredRuntime::current().map(|inner| Self { inner })
    }

    /// Build and return a new instance of the default [`Runtime`].
    ///
    /// Call this at most once in general, and use [`Clone::clone()`] to make
    /// additional references to the runtime it returns.
    ///
    /// Tokio users may prefer [`PreferredRuntime::current`] instead: this
    /// function always _builds_ a new runtime, which is not what you want
    /// with Tokio when a runtime already exists.
    ///
    /// For finer-grained control over how the runtime is set up, build it
    /// with an appropriate builder type or function instead.
    ///
    /// The process [**may not fork**](crate#do-not-fork)
    /// (except, very carefully, before exec)
    /// after creating this or any other `Runtime`.
    //
    // ## Note to Arti developers
    //
    // Never call this from inside other Arti crates, or from library crates
    // that want to support multiple runtimes! It exists for Arti _users_ who
    // want to build a fresh Tokio or Async_std runtime as a [`Runtime`].
    // It is not meant for library crates that work with multiple runtimes.
    pub fn create() -> io::Result<Self> {
        preferred_backend_mod::PreferredRuntime::create().map(|inner| Self { inner })
    }

    /// Run one test function to completion on a newly created runtime.
    ///
    /// The runtime is handed to `func`, and the future that `func` returns is
    /// driven to completion with `block_on`. The future's output is returned.
    ///
    /// # Panics
    ///
    /// Panics if this runtime cannot be created.
    ///
    /// # Warning
    ///
    /// This API is **NOT** for consumption outside Arti. Semver guarantees are not provided.
    #[doc(hidden)]
    pub fn run_test<P, F, O>(func: P) -> O
    where
        P: FnOnce(Self) -> F,
        F: futures::Future<Output = O>,
    {
        let rt = Self::create().expect("Failed to create runtime");
        // Keep one handle to drive the future; give the other to the test.
        let driver = rt.clone();
        let test_future = func(rt);
        driver.block_on(test_future)
    }
}
255

            
256
/// Support code used by the `test_with_all_runtimes` macro.
///
/// # Warning
///
/// This API is **NOT** for consumption outside Arti. Semver guarantees are not provided.
#[doc(hidden)]
pub mod testing__ {
    /// The result of a test closure: either `()`, which carries no failure
    /// information, or a `Result` whose `Err` variant marks a failed test.
    pub trait TestOutcome {
        /// Panic if this outcome represents a test failure.
        fn check_ok(&self);
    }

    // A unit outcome never signals failure, so there is nothing to check.
    impl TestOutcome for () {
        fn check_ok(&self) {}
    }

    // `Ok(())` passes; `Err(e)` panics with "Test failure: {e:?}".
    impl<E> TestOutcome for Result<(), E>
    where
        E: std::fmt::Debug,
    {
        fn check_ok(&self) {
            if let Err(failure) = self {
                panic!("Test failure: {:?}", failure);
            }
        }
    }
}
278

            
279
/// Helper: declare a macro named `$name` that expands its single token-tree
/// argument only when this crate was built with both cargo features
/// `$feat_a` and `$feat_b`, and expands to nothing otherwise.
macro_rules! declare_conditional_macro {
    ( $(#[$attr:meta])* macro $name:ident = ($feat_a:expr, $feat_b:expr) ) => {
        // Both features enabled: emit the argument unchanged.
        $( #[$attr] )*
        #[cfg(all(feature=$feat_a, feature=$feat_b))]
        #[macro_export]
        macro_rules! $name {
            ($body:tt) => {
                $body
            };
        }

        // At least one feature missing: discard the argument.
        $( #[$attr] )*
        #[cfg(not(all(feature=$feat_a, feature=$feat_b)))]
        #[macro_export]
        macro_rules! $name {
            ($body:tt) => {};
        }

        // Re-export the macro so it can be reached at this module path, both
        // from inside this crate and from other crates.
        pub use $name;
    };
}
304

            
305
/// Defines macros that will expand when certain runtimes are available.
306
#[doc(hidden)]
307
pub mod cond {
308
    declare_conditional_macro! {
309
        /// Expand a token tree if the TokioNativeTlsRuntime is available.
310
        #[doc(hidden)]
311
        macro if_tokio_native_tls_present = ("tokio", "native-tls")
312
    }
313
    declare_conditional_macro! {
314
        /// Expand a token tree if the TokioRustlsRuntime is available.
315
        #[doc(hidden)]
316
        macro if_tokio_rustls_present = ("tokio", "rustls")
317
    }
318
    declare_conditional_macro! {
319
        /// Expand a token tree if the TokioNativeTlsRuntime is available.
320
        #[doc(hidden)]
321
        macro if_async_std_native_tls_present = ("async-std", "native-tls")
322
    }
323
    declare_conditional_macro! {
324
        /// Expand a token tree if the TokioNativeTlsRuntime is available.
325
        #[doc(hidden)]
326
        macro if_async_std_rustls_present = ("async-std", "rustls")
327
    }
328
    declare_conditional_macro! {
329
        /// Expand a token tree if the SmolNativeTlsRuntime is available.
330
        #[doc(hidden)]
331
        macro if_smol_native_tls_present = ("smol", "native-tls")
332
    }
333
    declare_conditional_macro! {
334
        /// Expand a token tree if the SmolRustlsRuntime is available.
335
        #[doc(hidden)]
336
        macro if_smol_rustls_present = ("smol", "rustls")
337
    }
338
}
339

            
340
/// Run a test closure, passing as argument every supported runtime.
///
/// Usually, prefer `tor_rtmock::MockRuntime::test_with_various` to this.
/// Use this macro only when you need to interact with things
/// that `MockRuntime` can't handle.
///
/// If everything in your test case is supported by `MockRuntime`,
/// you should use that instead:
/// that will give superior test coverage *and* a (more) deterministic test.
///
/// The closure may return either `()` or a `Result<(), E>` where `E: Debug`.
/// An `Err` returned under any runtime makes the test panic.
///
/// (This is a macro rather than a function so that the closure expression is
/// repeated once per runtime; each copy can then take on a different type.)
//
// NOTE(eta): changing this #[cfg] can affect tests inside this crate that use
//            this macro, like in scheduler.rs
#[macro_export]
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "tokio", feature = "async-std", feature = "smol"),
))]
macro_rules! test_with_all_runtimes {
    ( $fn:expr ) => {{
        use $crate::cond::*;
        use $crate::testing__::TestOutcome;
        // We have to do this outcome-checking business rather than just using
        // the ? operator or calling expect() because some of the closures that
        // we use this macro with return (), and some return Result.

        if_tokio_native_tls_present! {{
            $crate::tokio::TokioNativeTlsRuntime::run_test($fn).check_ok();
        }}
        if_tokio_rustls_present! {{
            $crate::tokio::TokioRustlsRuntime::run_test($fn).check_ok();
        }}
        if_async_std_native_tls_present! {{
            $crate::async_std::AsyncStdNativeTlsRuntime::run_test($fn).check_ok();
        }}
        if_async_std_rustls_present! {{
            $crate::async_std::AsyncStdRustlsRuntime::run_test($fn).check_ok();
        }}
        if_smol_native_tls_present! {{
            $crate::smol::SmolNativeTlsRuntime::run_test($fn).check_ok();
        }}
        if_smol_rustls_present! {{
            $crate::smol::SmolRustlsRuntime::run_test($fn).check_ok();
        }}
    }};
}
388

            
389
/// Run a test closure with a single runtime, the crate's `PreferredRuntime`,
/// and evaluate to whatever the closure's future produces.
///
/// In most cases `tor_rtmock::MockRuntime::test_with_various` is the better
/// choice. Reach for this macro only when the test must interact with things
/// that `MockRuntime` can't handle.
///
/// When everything the test needs is supported by `MockRuntime`, use that
/// instead: it gives superior test coverage *and* a (more) deterministic test.
///
/// (Always prefers tokio if present.)
#[macro_export]
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "tokio", feature = "async-std"),
))]
macro_rules! test_with_one_runtime {
    ( $test_fn:expr ) => {{
        $crate::PreferredRuntime::run_test($test_fn)
    }};
}
408

            
409
#[cfg(all(
    test,
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio", feature = "smol"),
    not(miri), // Many of these tests use real sockets or SystemTime.
))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::unnecessary_wraps)]
    use crate::SleepProviderExt;
    use crate::ToplevelRuntime;

    use crate::traits::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use native_tls_crate as native_tls;
    use std::io::Result as IoResult;
    use std::net::SocketAddr;
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::time::{Duration, Instant};

    // Test "sleep" with a tiny delay, and make sure that at least that
    // much delay happens.
    fn small_delay<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let rt = runtime.clone();
        runtime.block_on(async {
            let i1 = Instant::now();
            let one_msec = Duration::from_millis(1);
            rt.sleep(one_msec).await;
            let i2 = Instant::now();
            assert!(i2 >= i1 + one_msec);
        });
        Ok(())
    }

    // Try a timeout operation that will succeed.
    fn small_timeout_ok<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let rt = runtime.clone();
        runtime.block_on(async {
            let one_day = Duration::from_secs(86400);
            let outcome = rt.timeout(one_day, async { 413_u32 }).await;
            assert_eq!(outcome, Ok(413));
        });
        Ok(())
    }

    // Try a timeout operation that will time out.
    fn small_timeout_expire<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        use futures::future::pending;

        let rt = runtime.clone();
        runtime.block_on(async {
            let one_micros = Duration::from_micros(1);
            let outcome = rt.timeout(one_micros, pending::<()>()).await;
            assert_eq!(outcome, Err(crate::TimeoutError));
            assert_eq!(
                outcome.err().unwrap().to_string(),
                "Timeout expired".to_string()
            );
        });
        Ok(())
    }

    // Try a little wallclock delay.
    //
    // NOTE: This test will fail if the clock jumps a lot while it's
    // running.  We should use simulated time instead.
    fn tiny_wallclock<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let rt = runtime.clone();
        runtime.block_on(async {
            let i1 = Instant::now();
            let now = runtime.wallclock();
            let one_millis = Duration::from_millis(1);
            let one_millis_later = now + one_millis;

            rt.sleep_until_wallclock(one_millis_later).await;

            let i2 = Instant::now();
            let newtime = runtime.wallclock();
            assert!(newtime >= one_millis_later);
            assert!(i2 - i1 >= one_millis);
        });
        Ok(())
    }

    // Try connecting to ourself and sending a little data.
    //
    // NOTE: requires Ipv4 localhost.
    fn self_connect_tcp<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
        let rt1 = runtime.clone();

        let listener = runtime.block_on(rt1.listen(&(SocketAddr::from(localhost))))?;
        let addr = listener.local_addr()?;

        runtime.block_on(async {
            let task1 = async {
                let mut buf = vec![0_u8; 11];
                let (mut con, _addr) = listener.incoming().next().await.expect("closed?")?;
                con.read_exact(&mut buf[..]).await?;
                IoResult::Ok(buf)
            };
            let task2 = async {
                let mut con = rt1.connect(&addr).await?;
                con.write_all(b"Hello world").await?;
                con.flush().await?;
                IoResult::Ok(())
            };

            let (data, send_r) = futures::join!(task1, task2);
            send_r?;

            assert_eq!(&data?[..], b"Hello world");

            Ok(())
        })
    }

    // Try connecting to ourself and sending a little data.
    //
    // NOTE: requires Ipv4 localhost.
    fn self_connect_udp<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
        let rt1 = runtime.clone();

        let socket1 = runtime.block_on(rt1.bind(&(localhost.into())))?;
        let addr1 = socket1.local_addr()?;

        let socket2 = runtime.block_on(rt1.bind(&(localhost.into())))?;
        let addr2 = socket2.local_addr()?;

        runtime.block_on(async {
            let task1 = async {
                let mut buf = [0_u8; 16];
                let (len, addr) = socket1.recv(&mut buf[..]).await?;
                IoResult::Ok((buf[..len].to_vec(), addr))
            };
            let task2 = async {
                socket2.send(b"Hello world", &addr1).await?;
                IoResult::Ok(())
            };

            let (recv_r, send_r) = futures::join!(task1, task2);
            send_r?;
            let (buff, addr) = recv_r?;
            assert_eq!(addr2, addr);
            assert_eq!(&buff, b"Hello world");

            Ok(())
        })
    }

    // Try out our incoming connection stream code.
    //
    // We launch a few connections and make sure that we can read data on
    // them.
    fn listener_stream<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
        let rt1 = runtime.clone();

        let listener = runtime
            .block_on(rt1.listen(&SocketAddr::from(localhost)))
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let mut stream = listener.incoming();

        runtime.block_on(async {
            let task1 = async {
                let mut n = 0_u32;
                loop {
                    let (mut con, _addr) = stream.next().await.unwrap()?;
                    let mut buf = [0_u8; 11];
                    con.read_exact(&mut buf[..]).await?;
                    n += 1;
                    if &buf[..] == b"world done!" {
                        break IoResult::Ok(n);
                    }
                }
            };
            let task2 = async {
                for _ in 0_u8..5 {
                    let mut con = rt1.connect(&addr).await?;
                    con.write_all(b"Hello world").await?;
                    con.flush().await?;
                }
                let mut con = rt1.connect(&addr).await?;
                con.write_all(b"world done!").await?;
                con.flush().await?;
                con.close().await?;
                IoResult::Ok(())
            };

            let (n, send_r) = futures::join!(task1, task2);
            send_r?;

            assert_eq!(n?, 6);

            Ok(())
        })
    }

    // Try listening on an address and connecting there, except using TLS.
    //
    // Note that since we didn't have TLS server support when this test was first written,
    // we're going to use a thread.
    fn simple_tls<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        /*
         A simple expired self-signed rsa-2048 certificate.

         Generated by running the make-cert.c program in tor-rtcompat/test-data-helper,
         and then making a PFX file using

         openssl pkcs12 -export -certpbe PBE-SHA1-3DES -out test.pfx -inkey test.key -in test.crt

         The password is "abc".
        */
        static PFX_ID: &[u8] = include_bytes!("test.pfx");
        // Note that we need to set a password on the pkcs12 file, since apparently
        // OSX doesn't support pkcs12 with empty passwords. (That was arti#111).
        static PFX_PASSWORD: &str = "abc";

        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
        let listener = std::net::TcpListener::bind(localhost)?;
        let addr = listener.local_addr()?;

        let identity = native_tls::Identity::from_pkcs12(PFX_ID, PFX_PASSWORD).unwrap();

        // See note on function for why we're using a thread here.
        let th = std::thread::spawn(move || {
            // Accept a single TLS connection and run an echo server
            use std::io::{Read, Write};
            let acceptor = native_tls::TlsAcceptor::new(identity).unwrap();
            let (con, _addr) = listener.accept()?;
            let mut con = acceptor.accept(con).unwrap();
            let mut buf = [0_u8; 16];
            loop {
                let n = con.read(&mut buf)?;
                if n == 0 {
                    break;
                }
                con.write_all(&buf[..n])?;
            }
            IoResult::Ok(())
        });

        let connector = runtime.tls_connector();

        runtime.block_on(async {
            let text = b"I Suddenly Dont Understand Anything";
            let mut buf = vec![0_u8; text.len()];
            let conn = runtime.connect(&addr).await?;
            let mut conn = connector.negotiate_unvalidated(conn, "Kan.Aya").await?;
            assert!(conn.peer_certificate()?.is_some());
            conn.write_all(text).await?;
            conn.flush().await?;
            conn.read_exact(&mut buf[..]).await?;
            assert_eq!(&buf[..], text);
            conn.close().await?;
            IoResult::Ok(())
        })?;

        th.join().unwrap()?;
        IoResult::Ok(())
    }

    // Try a TLS server and client from this crate talking to each other.
    // Skipped (returning Ok) when the runtime has no TLS acceptor support.
    fn simple_tls_server<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
        let mut rng = tor_basic_utils::test_rng::testing_rng();
        let tls_cert = tor_cert_x509::TlsKeyAndCert::create(
            &mut rng,
            std::time::SystemTime::now(),
            "prospit.example.org",
            "derse.example.org",
        )
        .unwrap();
        let cert = tls_cert.certificates_der()[0].to_vec();
        let settings = TlsAcceptorSettings::new(tls_cert).unwrap();

        let Ok(tls_acceptor) = runtime.tls_acceptor(settings) else {
            println!("Skipping tls-server test for runtime {:?}", &runtime);
            return IoResult::Ok(());
        };
        println!("Running tls-server test for runtime {:?}", &runtime);

        let tls_connector = runtime.tls_connector();

        let localhost: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0).into();
        let rt1 = runtime.clone();

        let msg = b"Derse Reviles Him And Outlaws Frogs Wherever They Can";
        runtime.block_on(async move {
            let listener = runtime.listen(&localhost).await.unwrap();
            let address = listener.local_addr().unwrap();

            let h1 = runtime
                .spawn_with_handle(async move {
                    let conn = listener.incoming().next().await.unwrap().unwrap().0;
                    let mut conn = tls_acceptor.negotiate_unvalidated(conn, "").await.unwrap();

                    let mut buf = vec![];
                    conn.read_to_end(&mut buf).await.unwrap();
                    (buf, conn.own_certificate().unwrap().unwrap().into_owned())
                })
                .unwrap();

            let h2 = runtime
                .spawn_with_handle(async move {
                    let conn = rt1.connect(&address).await.unwrap();
                    let mut conn = tls_connector
                        .negotiate_unvalidated(conn, "prospit.example.org")
                        .await
                        .unwrap();
                    conn.write_all(msg).await.unwrap();
                    conn.close().await.unwrap();
                    conn.peer_certificate().unwrap().unwrap().into_owned()
                })
                .unwrap();

            let (received, server_own_cert) = h1.await;
            let client_peer_cert = h2.await;
            assert_eq!(received, msg);
            assert_eq!(&server_own_cert, &cert);
            assert_eq!(&client_peer_cert, &cert);
        });
        IoResult::Ok(())
    }

    /// Declare one `#[test]` per listed function, each calling `super::$id`
    /// with the runtime produced by `$runtime`.
    macro_rules! tests_with_runtime {
        { $runtime:expr  => $($id:ident),* $(,)? } => {
            $(
                #[test]
                fn $id() -> std::io::Result<()> {
                    super::$id($runtime)
                }
            )*
        }
    }

    /// Run the listed tests under each enabled backend's preferred runtime,
    /// and under the crate-wide `PreferredRuntime` when it exists.
    macro_rules! runtime_tests {
        { $($id:ident),* $(,)? } =>
        {
            #[cfg(feature="tokio")]
            mod tokio_runtime_tests {
                tests_with_runtime! { &crate::tokio::PreferredRuntime::create()? => $($id),* }
            }
            #[cfg(feature="async-std")]
            mod async_std_runtime_tests {
                tests_with_runtime! { &crate::async_std::PreferredRuntime::create()? => $($id),* }
            }
            #[cfg(feature="smol")]
            mod smol_runtime_tests {
                tests_with_runtime! { &crate::smol::PreferredRuntime::create()? => $($id),* }
            }
            // `crate::PreferredRuntime` only exists with `tokio` or `async-std`
            // (not with `smol` alone), so these tests must be gated the same way.
            #[cfg(any(feature = "tokio", feature = "async-std"))]
            mod default_runtime_tests {
                tests_with_runtime! { &crate::PreferredRuntime::create()? => $($id),* }
            }
        }
    }

    /// Run the listed tests under every enabled runtime/TLS combination,
    /// and under the crate-wide `PreferredRuntime` when it exists.
    macro_rules! tls_runtime_tests {
        { $($id:ident),* $(,)? } =>
        {
            #[cfg(all(feature="tokio", feature = "native-tls"))]
            mod tokio_native_tls_tests {
                tests_with_runtime! { &crate::tokio::TokioNativeTlsRuntime::create()? => $($id),* }
            }
            #[cfg(all(feature="async-std", feature = "native-tls"))]
            mod async_std_native_tls_tests {
                tests_with_runtime! { &crate::async_std::AsyncStdNativeTlsRuntime::create()? => $($id),* }
            }
            #[cfg(all(feature="smol", feature = "native-tls"))]
            mod smol_native_tls_tests {
                tests_with_runtime! { &crate::smol::SmolNativeTlsRuntime::create()? => $($id),* }
            }
            #[cfg(all(feature="tokio", feature="rustls"))]
            mod tokio_rustls_tests {
                tests_with_runtime! {  &crate::tokio::TokioRustlsRuntime::create()? => $($id),* }
            }
            #[cfg(all(feature="async-std", feature="rustls"))]
            mod async_std_rustls_tests {
                tests_with_runtime! {  &crate::async_std::AsyncStdRustlsRuntime::create()? => $($id),* }
            }
            #[cfg(all(feature="smol", feature="rustls"))]
            mod smol_rustls_tests {
                tests_with_runtime! {  &crate::smol::SmolRustlsRuntime::create()? => $($id),* }
            }
            // See the matching note in `runtime_tests!`: `crate::PreferredRuntime`
            // is absent unless `tokio` or `async-std` is enabled.
            #[cfg(any(feature = "tokio", feature = "async-std"))]
            mod default_runtime_tls_tests {
                tests_with_runtime! { &crate::PreferredRuntime::create()? => $($id),* }
            }
        }
    }

    runtime_tests! {
        small_delay,
        small_timeout_ok,
        small_timeout_expire,
        tiny_wallclock,
        self_connect_tcp,
        self_connect_udp,
        listener_stream,
    }

    tls_runtime_tests! {
        simple_tls,
        simple_tls_server,
    }
}