1
//! Define a [`CompoundRuntime`]: a runtime that can be built from several component
//! pieces.
3

            
4
use std::{net, sync::Arc, time::Duration};
5

            
6
use crate::traits::*;
7
use crate::{CoarseInstant, CoarseTimeProvider};
8
use async_trait::async_trait;
9
use educe::Educe;
10
use futures::{future::FutureObj, task::Spawn};
11
use std::future::Future;
12
use std::io::Result as IoResult;
13
use std::time::{Instant, SystemTime};
14
use tor_general_addr::unix;
15
use tracing::instrument;
16

            
17
/// A runtime made of several parts, each of which implements one trait-group.
18
///
19
/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20
/// the `SleepR` component should implement [`SleepProvider`];
21
/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22
/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23
/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24
/// and
25
/// the `TlsR` component should implement [`TlsProvider`].
26
///
27
/// You can use this structure to create new runtimes in two ways: either by
28
/// overriding a single part of an existing runtime, or by building an entirely
29
/// new runtime from pieces.
30
#[derive(Educe)]
31
#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32
pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33
    /// The actual collection of Runtime objects.
34
    ///
35
    /// We wrap this in an Arc rather than requiring that each item implement
36
    /// Clone, though we could change our minds later on.
37
    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38
}
39

            
40
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Each field holds the component that a [`CompoundRuntime`] delegates one
/// trait-group to.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn` and `Blocking` implementation (and possibly `ToplevelBlockOn`).
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider` implementation.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57

            
58
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59
    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60
{
61
    /// Construct a new CompoundRuntime from its components.
62
30886
    pub fn new(
63
30886
        spawn: TaskR,
64
30886
        sleep: SleepR,
65
30886
        coarse_time: CoarseTimeR,
66
30886
        tcp: TcpR,
67
30886
        unix: UnixR,
68
30886
        tls: TlsR,
69
30886
        udp: UdpR,
70
30886
    ) -> Self {
71
        #[allow(clippy::arc_with_non_send_sync)]
72
30886
        CompoundRuntime {
73
30886
            inner: Arc::new(Inner {
74
30886
                spawn,
75
30886
                sleep,
76
30886
                coarse_time,
77
30886
                tcp,
78
30886
                unix,
79
30886
                tls,
80
30886
                udp,
81
30886
            }),
82
30886
        }
83
30886
    }
84
}
85

            
86
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88
where
89
    TaskR: Spawn,
90
{
91
    #[inline]
92
    #[track_caller]
93
1264
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94
1264
        self.inner.spawn.spawn_obj(future)
95
1264
    }
96
}
97

            
98
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100
where
101
    TaskR: Blocking,
102
    SleepR: Clone + Send + Sync + 'static,
103
    CoarseTimeR: Clone + Send + Sync + 'static,
104
    TcpR: Clone + Send + Sync + 'static,
105
    UnixR: Clone + Send + Sync + 'static,
106
    TlsR: Clone + Send + Sync + 'static,
107
    UdpR: Clone + Send + Sync + 'static,
108
{
109
    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110

            
111
    #[inline]
112
    #[track_caller]
113
    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114
    where
115
        F: FnOnce() -> T + Send + 'static,
116
        T: Send + 'static,
117
    {
118
        self.inner.spawn.spawn_blocking(f)
119
    }
120

            
121
    #[inline]
122
    #[track_caller]
123
    fn reenter_block_on<F>(&self, future: F) -> F::Output
124
    where
125
        F: Future,
126
        F::Output: Send + 'static,
127
    {
128
        self.inner.spawn.reenter_block_on(future)
129
    }
130

            
131
    #[track_caller]
132
    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133
    where
134
        F: FnOnce() -> T + Send + 'static,
135
        T: Send + 'static,
136
    {
137
        self.inner.spawn.blocking_io(f)
138
    }
139
}
140

            
141
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143
where
144
    TaskR: ToplevelBlockOn,
145
    SleepR: Clone + Send + Sync + 'static,
146
    CoarseTimeR: Clone + Send + Sync + 'static,
147
    TcpR: Clone + Send + Sync + 'static,
148
    UnixR: Clone + Send + Sync + 'static,
149
    TlsR: Clone + Send + Sync + 'static,
150
    UdpR: Clone + Send + Sync + 'static,
151
{
152
    #[inline]
153
    #[track_caller]
154
1080
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155
1080
        self.inner.spawn.block_on(future)
156
1080
    }
157
}
158

            
159
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161
where
162
    SleepR: SleepProvider,
163
    TaskR: Clone + Send + Sync + 'static,
164
    CoarseTimeR: Clone + Send + Sync + 'static,
165
    TcpR: Clone + Send + Sync + 'static,
166
    UnixR: Clone + Send + Sync + 'static,
167
    TlsR: Clone + Send + Sync + 'static,
168
    UdpR: Clone + Send + Sync + 'static,
169
{
170
    type SleepFuture = SleepR::SleepFuture;
171

            
172
    #[inline]
173
11147
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174
11147
        self.inner.sleep.sleep(duration)
175
11147
    }
176

            
177
    #[inline]
178
    fn now(&self) -> Instant {
179
        self.inner.sleep.now()
180
    }
181

            
182
    #[inline]
183
36
    fn wallclock(&self) -> SystemTime {
184
36
        self.inner.sleep.wallclock()
185
36
    }
186
}
187

            
188
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190
where
191
    CoarseTimeR: CoarseTimeProvider,
192
    SleepR: Clone + Send + Sync + 'static,
193
    TaskR: Clone + Send + Sync + 'static,
194
    CoarseTimeR: Clone + Send + Sync + 'static,
195
    TcpR: Clone + Send + Sync + 'static,
196
    UnixR: Clone + Send + Sync + 'static,
197
    TlsR: Clone + Send + Sync + 'static,
198
    UdpR: Clone + Send + Sync + 'static,
199
{
200
    #[inline]
201
14939
    fn now_coarse(&self) -> CoarseInstant {
202
14939
        self.inner.coarse_time.now_coarse()
203
14939
    }
204
}
205

            
206
#[async_trait]
207
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209
where
210
    TcpR: NetStreamProvider<net::SocketAddr>,
211
    TaskR: Send + Sync + 'static,
212
    SleepR: Send + Sync + 'static,
213
    CoarseTimeR: Send + Sync + 'static,
214
    TcpR: Send + Sync + 'static,
215
    UnixR: Clone + Send + Sync + 'static,
216
    TlsR: Send + Sync + 'static,
217
    UdpR: Send + Sync + 'static,
218
{
219
    type Stream = TcpR::Stream;
220

            
221
    type Listener = TcpR::Listener;
222

            
223
    #[inline]
224
    #[instrument(skip_all, level = "trace")]
225
    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
226
        self.inner.tcp.connect(addr).await
227
    }
228

            
229
    #[inline]
230
22
    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
231
        self.inner.tcp.listen(addr).await
232
22
    }
233
}
234

            
235
#[async_trait]
236
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
237
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
238
where
239
    UnixR: NetStreamProvider<unix::SocketAddr>,
240
    TaskR: Send + Sync + 'static,
241
    SleepR: Send + Sync + 'static,
242
    CoarseTimeR: Send + Sync + 'static,
243
    TcpR: Send + Sync + 'static,
244
    UnixR: Clone + Send + Sync + 'static,
245
    TlsR: Send + Sync + 'static,
246
    UdpR: Send + Sync + 'static,
247
{
248
    type Stream = UnixR::Stream;
249

            
250
    type Listener = UnixR::Listener;
251

            
252
    #[inline]
253
    #[instrument(skip_all, level = "trace")]
254
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
255
        self.inner.unix.connect(addr).await
256
    }
257

            
258
    #[inline]
259
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
260
        self.inner.unix.listen(addr).await
261
    }
262
}
263

            
264
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
265
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
266
where
267
    TcpR: NetStreamProvider,
268
    TlsR: TlsProvider<S>,
269
    UnixR: Clone + Send + Sync + 'static,
270
    SleepR: Clone + Send + Sync + 'static,
271
    CoarseTimeR: Clone + Send + Sync + 'static,
272
    TaskR: Clone + Send + Sync + 'static,
273
    UdpR: Clone + Send + Sync + 'static,
274
    S: StreamOps,
275
{
276
    type Connector = TlsR::Connector;
277
    type TlsStream = TlsR::TlsStream;
278
    type Acceptor = TlsR::Acceptor;
279
    type TlsServerStream = TlsR::TlsServerStream;
280

            
281
    #[inline]
282
38
    fn tls_connector(&self) -> Self::Connector {
283
38
        self.inner.tls.tls_connector()
284
38
    }
285

            
286
    #[inline]
287
14
    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
288
14
        self.inner.tls.tls_acceptor(settings)
289
14
    }
290

            
291
    #[inline]
292
    fn supports_keying_material_export(&self) -> bool {
293
        self.inner.tls.supports_keying_material_export()
294
    }
295
}
296

            
297
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
298
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
299
{
300
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301
2
        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
302
2
    }
303
}
304

            
305
#[async_trait]
306
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
307
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
308
where
309
    UdpR: UdpProvider,
310
    TaskR: Send + Sync + 'static,
311
    SleepR: Send + Sync + 'static,
312
    CoarseTimeR: Send + Sync + 'static,
313
    TcpR: Send + Sync + 'static,
314
    UnixR: Clone + Send + Sync + 'static,
315
    TlsR: Send + Sync + 'static,
316
    UdpR: Send + Sync + 'static,
317
{
318
    type UdpSocket = UdpR::UdpSocket;
319

            
320
    #[inline]
321
16
    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
322
        self.inner.udp.bind(addr).await
323
16
    }
324
}
325

            
326
/// Private module holding the marker trait that seals `RuntimeSubstExt`.
mod sealed {
    /// Marker supertrait of `RuntimeSubstExt`.
    ///
    /// The trait itself is `pub` so that it may appear in the bounds of the
    /// public `RuntimeSubstExt`. Because this module is private, code outside
    /// the parent module cannot name it, and so cannot implement it.
    // `pub` inside a private module is intentional here (sealed-trait pattern).
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
332
/// Extension trait on Runtime:
333
/// Construct new Runtimes that replace part of an original runtime.
334
///
335
/// (If you need to do more complicated versions of this, you should likely construct
336
/// CompoundRuntime directly.)
337
pub trait RuntimeSubstExt: sealed::Sealed + Sized {
338
    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
339
    fn with_tcp_provider<T>(
340
        &self,
341
        new_tcp: T,
342
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
343
    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
344
    fn with_sleep_provider<T>(
345
        &self,
346
        new_sleep: T,
347
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
348
    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
349
    fn with_coarse_time_provider<T>(
350
        &self,
351
        new_coarse_time: T,
352
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
353
}
354
// Every `Runtime` is admitted through the seal, so each one gets `RuntimeSubstExt`.
impl<R> sealed::Sealed for R
where
    R: Runtime,
{
}
355
impl<R: Runtime + Sized> RuntimeSubstExt for R {
356
    fn with_tcp_provider<T>(
357
        &self,
358
        new_tcp: T,
359
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
360
        CompoundRuntime::new(
361
            self.clone(),
362
            self.clone(),
363
            self.clone(),
364
            new_tcp,
365
            self.clone(),
366
            self.clone(),
367
            self.clone(),
368
        )
369
    }
370

            
371
12
    fn with_sleep_provider<T>(
372
12
        &self,
373
12
        new_sleep: T,
374
12
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
375
12
        CompoundRuntime::new(
376
12
            self.clone(),
377
12
            new_sleep,
378
12
            self.clone(),
379
12
            self.clone(),
380
12
            self.clone(),
381
12
            self.clone(),
382
12
            self.clone(),
383
        )
384
12
    }
385

            
386
12
    fn with_coarse_time_provider<T>(
387
12
        &self,
388
12
        new_coarse_time: T,
389
12
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
390
12
        CompoundRuntime::new(
391
12
            self.clone(),
392
12
            self.clone(),
393
12
            new_coarse_time,
394
12
            self.clone(),
395
12
            self.clone(),
396
12
            self.clone(),
397
12
            self.clone(),
398
        )
399
12
    }
400
}