1
//! Define [`CompoundRuntime`], a runtime that can be built from several
//! component pieces.
3

            
4
use std::{net, sync::Arc, time::Duration};
5

            
6
use crate::traits::*;
7
use crate::{CoarseInstant, CoarseTimeProvider};
8
use async_trait::async_trait;
9
use educe::Educe;
10
use futures::{future::FutureObj, task::Spawn};
11
use std::future::Future;
12
use std::io::Result as IoResult;
13
use tor_general_addr::unix;
14
use tracing::instrument;
15
use web_time_compat::{Instant, SystemTime};
16

            
17
/// A runtime made of several parts, each of which implements one trait-group.
18
///
19
/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20
/// the `SleepR` component should implement [`SleepProvider`];
21
/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22
/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23
/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24
/// and
25
/// the `TlsR` component should implement [`TlsProvider`].
26
///
27
/// You can use this structure to create new runtimes in two ways: either by
28
/// overriding a single part of an existing runtime, or by building an entirely
29
/// new runtime from pieces.
30
#[derive(Educe)]
31
#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32
pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33
    /// The actual collection of Runtime objects.
34
    ///
35
    /// We wrap this in an Arc rather than requiring that each item implement
36
    /// Clone, though we could change our minds later on.
37
    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38
}
39

            
40
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Each field holds the component that [`CompoundRuntime`] delegates one
/// trait-group to.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn` and `Blocking` (and possibly `ToplevelBlockOn`) implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for whichever stream type `S` is used.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57

            
58
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59
    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60
{
61
    /// Construct a new CompoundRuntime from its components.
62
32248
    pub fn new(
63
32248
        spawn: TaskR,
64
32248
        sleep: SleepR,
65
32248
        coarse_time: CoarseTimeR,
66
32248
        tcp: TcpR,
67
32248
        unix: UnixR,
68
32248
        tls: TlsR,
69
32248
        udp: UdpR,
70
32248
    ) -> Self {
71
        #[allow(clippy::arc_with_non_send_sync)]
72
32248
        CompoundRuntime {
73
32248
            inner: Arc::new(Inner {
74
32248
                spawn,
75
32248
                sleep,
76
32248
                coarse_time,
77
32248
                tcp,
78
32248
                unix,
79
32248
                tls,
80
32248
                udp,
81
32248
            }),
82
32248
        }
83
32248
    }
84
}
85

            
86
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88
where
89
    TaskR: Spawn,
90
{
91
    #[inline]
92
    #[track_caller]
93
1026
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94
1026
        self.inner.spawn.spawn_obj(future)
95
1026
    }
96
}
97

            
98
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100
where
101
    TaskR: Blocking,
102
    SleepR: Clone + Send + Sync + 'static,
103
    CoarseTimeR: Clone + Send + Sync + 'static,
104
    TcpR: Clone + Send + Sync + 'static,
105
    UnixR: Clone + Send + Sync + 'static,
106
    TlsR: Clone + Send + Sync + 'static,
107
    UdpR: Clone + Send + Sync + 'static,
108
{
109
    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110

            
111
    #[inline]
112
    #[track_caller]
113
    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114
    where
115
        F: FnOnce() -> T + Send + 'static,
116
        T: Send + 'static,
117
    {
118
        self.inner.spawn.spawn_blocking(f)
119
    }
120

            
121
    #[inline]
122
    #[track_caller]
123
    fn reenter_block_on<F>(&self, future: F) -> F::Output
124
    where
125
        F: Future,
126
        F::Output: Send + 'static,
127
    {
128
        self.inner.spawn.reenter_block_on(future)
129
    }
130

            
131
    #[track_caller]
132
    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133
    where
134
        F: FnOnce() -> T + Send + 'static,
135
        T: Send + 'static,
136
    {
137
        self.inner.spawn.blocking_io(f)
138
    }
139
}
140

            
141
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143
where
144
    TaskR: ToplevelBlockOn,
145
    SleepR: Clone + Send + Sync + 'static,
146
    CoarseTimeR: Clone + Send + Sync + 'static,
147
    TcpR: Clone + Send + Sync + 'static,
148
    UnixR: Clone + Send + Sync + 'static,
149
    TlsR: Clone + Send + Sync + 'static,
150
    UdpR: Clone + Send + Sync + 'static,
151
{
152
    #[inline]
153
    #[track_caller]
154
1084
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155
1084
        self.inner.spawn.block_on(future)
156
1084
    }
157
}
158

            
159
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161
where
162
    SleepR: SleepProvider,
163
    TaskR: Clone + Send + Sync + 'static,
164
    CoarseTimeR: Clone + Send + Sync + 'static,
165
    TcpR: Clone + Send + Sync + 'static,
166
    UnixR: Clone + Send + Sync + 'static,
167
    TlsR: Clone + Send + Sync + 'static,
168
    UdpR: Clone + Send + Sync + 'static,
169
{
170
    type SleepFuture = SleepR::SleepFuture;
171

            
172
    #[inline]
173
9907
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174
9907
        self.inner.sleep.sleep(duration)
175
9907
    }
176

            
177
    #[inline]
178
1012
    fn now(&self) -> Instant {
179
1012
        self.inner.sleep.now()
180
1012
    }
181

            
182
    #[inline]
183
512
    fn wallclock(&self) -> SystemTime {
184
512
        self.inner.sleep.wallclock()
185
512
    }
186
}
187

            
188
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190
where
191
    CoarseTimeR: CoarseTimeProvider,
192
    SleepR: Clone + Send + Sync + 'static,
193
    TaskR: Clone + Send + Sync + 'static,
194
    CoarseTimeR: Clone + Send + Sync + 'static,
195
    TcpR: Clone + Send + Sync + 'static,
196
    UnixR: Clone + Send + Sync + 'static,
197
    TlsR: Clone + Send + Sync + 'static,
198
    UdpR: Clone + Send + Sync + 'static,
199
{
200
    #[inline]
201
15351
    fn now_coarse(&self) -> CoarseInstant {
202
15351
        self.inner.coarse_time.now_coarse()
203
15351
    }
204
}
205

            
206
#[async_trait]
207
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209
where
210
    TcpR: NetStreamProvider<net::SocketAddr>,
211
    TaskR: Send + Sync + 'static,
212
    SleepR: Send + Sync + 'static,
213
    CoarseTimeR: Send + Sync + 'static,
214
    TcpR: Send + Sync + 'static,
215
    UnixR: Clone + Send + Sync + 'static,
216
    TlsR: Send + Sync + 'static,
217
    UdpR: Send + Sync + 'static,
218
{
219
    type Stream = TcpR::Stream;
220

            
221
    type Listener = TcpR::Listener;
222

            
223
    type ConnectOptions = TcpR::ConnectOptions;
224
    type ListenOptions = TcpR::ListenOptions;
225

            
226
    #[inline]
227
    #[instrument(skip_all, level = "trace")]
228
    async fn connect(
229
        &self,
230
        addr: &net::SocketAddr,
231
        options: &Self::ConnectOptions,
232
    ) -> IoResult<Self::Stream> {
233
        self.inner.tcp.connect(addr, options).await
234
    }
235

            
236
    #[inline]
237
    async fn listen(
238
        &self,
239
        addr: &net::SocketAddr,
240
        options: &Self::ListenOptions,
241
22
    ) -> IoResult<Self::Listener> {
242
        self.inner.tcp.listen(addr, options).await
243
22
    }
244
}
245

            
246
#[async_trait]
247
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
248
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
249
where
250
    UnixR: NetStreamProvider<unix::SocketAddr>,
251
    TaskR: Send + Sync + 'static,
252
    SleepR: Send + Sync + 'static,
253
    CoarseTimeR: Send + Sync + 'static,
254
    TcpR: Send + Sync + 'static,
255
    UnixR: Clone + Send + Sync + 'static,
256
    TlsR: Send + Sync + 'static,
257
    UdpR: Send + Sync + 'static,
258
{
259
    type Stream = UnixR::Stream;
260

            
261
    type Listener = UnixR::Listener;
262

            
263
    type ConnectOptions = UnixR::ConnectOptions;
264
    type ListenOptions = UnixR::ListenOptions;
265

            
266
    #[inline]
267
    #[instrument(skip_all, level = "trace")]
268
    async fn connect(
269
        &self,
270
        addr: &unix::SocketAddr,
271
        options: &Self::ConnectOptions,
272
    ) -> IoResult<Self::Stream> {
273
        self.inner.unix.connect(addr, options).await
274
    }
275

            
276
    #[inline]
277
    async fn listen(
278
        &self,
279
        addr: &unix::SocketAddr,
280
        options: &Self::ListenOptions,
281
    ) -> IoResult<Self::Listener> {
282
        self.inner.unix.listen(addr, options).await
283
    }
284
}
285

            
286
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
287
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
288
where
289
    TcpR: NetStreamProvider,
290
    TlsR: TlsProvider<S>,
291
    UnixR: Clone + Send + Sync + 'static,
292
    SleepR: Clone + Send + Sync + 'static,
293
    CoarseTimeR: Clone + Send + Sync + 'static,
294
    TaskR: Clone + Send + Sync + 'static,
295
    UdpR: Clone + Send + Sync + 'static,
296
    S: StreamOps,
297
{
298
    type Connector = TlsR::Connector;
299
    type TlsStream = TlsR::TlsStream;
300
    type Acceptor = TlsR::Acceptor;
301
    type TlsServerStream = TlsR::TlsServerStream;
302

            
303
    #[inline]
304
24
    fn tls_connector(&self) -> Self::Connector {
305
24
        self.inner.tls.tls_connector()
306
24
    }
307

            
308
    #[inline]
309
14
    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
310
14
        self.inner.tls.tls_acceptor(settings)
311
14
    }
312

            
313
    #[inline]
314
    fn supports_keying_material_export(&self) -> bool {
315
        self.inner.tls.supports_keying_material_export()
316
    }
317
}
318

            
319
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
320
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
321
{
322
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323
2
        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
324
2
    }
325
}
326

            
327
#[async_trait]
328
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
329
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
330
where
331
    UdpR: UdpProvider,
332
    TaskR: Send + Sync + 'static,
333
    SleepR: Send + Sync + 'static,
334
    CoarseTimeR: Send + Sync + 'static,
335
    TcpR: Send + Sync + 'static,
336
    UnixR: Clone + Send + Sync + 'static,
337
    TlsR: Send + Sync + 'static,
338
    UdpR: Send + Sync + 'static,
339
{
340
    type UdpSocket = UdpR::UdpSocket;
341

            
342
    #[inline]
343
16
    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
344
        self.inner.udp.bind(addr).await
345
16
    }
346
}
347

            
348
/// Private module that holds the sealing supertrait of `RuntimeSubstExt`.
mod sealed {
    /// Supertrait of `RuntimeSubstExt`; because this module is private, code
    /// outside the parent module cannot name it, and so cannot implement
    /// `RuntimeSubstExt` for its own types.
    // The trait must be `pub` to appear as a supertrait of a public trait, even
    // though it is unreachable from outside; hence the lint allowance.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
354
/// Extension trait on Runtime:
355
/// Construct new Runtimes that replace part of an original runtime.
356
///
357
/// (If you need to do more complicated versions of this, you should likely construct
358
/// CompoundRuntime directly.)
359
pub trait RuntimeSubstExt: sealed::Sealed + Sized {
360
    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
361
    fn with_tcp_provider<T>(
362
        &self,
363
        new_tcp: T,
364
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
365
    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
366
    fn with_sleep_provider<T>(
367
        &self,
368
        new_sleep: T,
369
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
370
    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
371
    fn with_coarse_time_provider<T>(
372
        &self,
373
        new_coarse_time: T,
374
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
375
}
376
impl<R: Runtime> sealed::Sealed for R {}
377
impl<R: Runtime + Sized> RuntimeSubstExt for R {
378
    fn with_tcp_provider<T>(
379
        &self,
380
        new_tcp: T,
381
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
382
        CompoundRuntime::new(
383
            self.clone(),
384
            self.clone(),
385
            self.clone(),
386
            new_tcp,
387
            self.clone(),
388
            self.clone(),
389
            self.clone(),
390
        )
391
    }
392

            
393
12
    fn with_sleep_provider<T>(
394
12
        &self,
395
12
        new_sleep: T,
396
12
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
397
12
        CompoundRuntime::new(
398
12
            self.clone(),
399
12
            new_sleep,
400
12
            self.clone(),
401
12
            self.clone(),
402
12
            self.clone(),
403
12
            self.clone(),
404
12
            self.clone(),
405
        )
406
12
    }
407

            
408
12
    fn with_coarse_time_provider<T>(
409
12
        &self,
410
12
        new_coarse_time: T,
411
12
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
412
12
        CompoundRuntime::new(
413
12
            self.clone(),
414
12
            self.clone(),
415
12
            new_coarse_time,
416
12
            self.clone(),
417
12
            self.clone(),
418
12
            self.clone(),
419
12
            self.clone(),
420
        )
421
12
    }
422
}