1
//! Implements a simple mock network for testing purposes.
2

            
3
// Note: There are lots of opportunities here for making the network
4
// more and more realistic, but please remember that this module only
5
// exists for writing unit tests.  Let's resist the temptation to add
6
// things we don't need.
7

            
8
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
9

            
10
use super::MockNetRuntime;
11
use super::io::{LocalStream, stream_pair};
12
use crate::util::mpsc_channel;
13
use core::fmt;
14
use std::borrow::Cow;
15
use tor_rtcompat::tls::{TlsAcceptorSettings, TlsConnector};
16
use tor_rtcompat::{
17
    CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
18
};
19
use tor_rtcompat::{UdpProvider, UdpSocket};
20

            
21
use async_trait::async_trait;
22
use futures::FutureExt;
23
use futures::channel::mpsc;
24
use futures::io::{AsyncRead, AsyncWrite};
25
use futures::lock::Mutex as AsyncMutex;
26
use futures::sink::SinkExt;
27
use futures::stream::{Stream, StreamExt};
28
use std::collections::HashMap;
29
use std::fmt::Formatter;
30
use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
31
use std::net::{IpAddr, SocketAddr};
32
use std::pin::Pin;
33
use std::sync::atomic::{AtomicU16, Ordering};
34
use std::sync::{Arc, Mutex};
35
use std::task::{Context, Poll};
36
use thiserror::Error;
37
use void::Void;
38

            
39
/// A channel sender that we use to send incoming connections to
40
/// listeners.
41
type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
42
/// A channel receiver that listeners use to receive incoming connections.
43
type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
44

            
45
/// A simulated Internet, for testing.
46
///
47
/// We simulate TCP streams only, and skip all the details. Connection
48
/// are implemented using [`LocalStream`]. The MockNetwork object is
49
/// shared by a large set of MockNetworkProviders, each of which has
50
/// its own view of its address(es) on the network.
51
#[derive(Default)]
52
pub struct MockNetwork {
53
    /// A map from address to the entries about listeners there.
54
    listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
55
}
56

            
57
/// The `MockNetwork`'s view of a listener.
58
#[derive(Clone)]
59
struct ListenerEntry {
60
    /// A sender that need to be informed about connection attempts
61
    /// there.
62
    send: ConnSender,
63

            
64
    /// A notional TLS certificate for this listener.  If absent, the
65
    /// listener isn't a TLS listener.
66
    tls_cert: Option<Vec<u8>>,
67
}
68

            
69
/// A possible non-error behavior from an address
70
#[derive(Clone)]
71
enum AddrBehavior {
72
    /// There's a listener at this address, which would like to reply.
73
    Listener(ListenerEntry),
74
    /// All connections sent to this address will time out.
75
    Timeout,
76
}
77

            
78
/// A view of a single host's access to a MockNetwork.
79
///
80
/// Each simulated host has its own addresses that it's allowed to listen on,
81
/// and a reference to the network.
82
///
83
/// This type implements [`NetStreamProvider`] for [`SocketAddr`]
84
/// so that it can be used as a
85
/// drop-in replacement for testing code that uses the network.
86
///
87
/// # Limitations
88
///
89
/// There's no randomness here, so we can't simulate the weirdness of
90
/// real networks.
91
///
92
/// So far, there's no support for DNS or UDP.
93
///
94
/// We don't handle localhost specially, and we don't simulate providers
95
/// that can connect to some addresses but not all.
96
///
97
/// We don't do the right thing (block) if there is a listener that
98
/// never calls accept.
99
///
100
/// UDP is completely broken:
101
/// datagrams appear to be transmitted, but will never be received.
102
/// And local address assignment is not implemented
103
/// so [`.local_addr()`](UdpSocket::local_addr) can return `NONE`
104
// TODO MOCK UDP: Documentation does describe the brokennesses
105
///
106
/// We use a simple `u16` counter to decide what arbitrary port
107
/// numbers to use: Once that counter is exhausted, we will fail with
108
/// an assertion.  We don't do anything to prevent those arbitrary
109
/// ports from colliding with specified ports, other than declare that
110
/// you can't have two listeners on the same addr:port at the same
111
/// time.
112
///
113
/// We pretend to provide TLS, but there's no actual encryption or
114
/// authentication.
115
#[derive(Clone)]
116
pub struct MockNetProvider {
117
    /// Actual implementation of this host's view of the network.
118
    ///
119
    /// We have to use a separate type here and reference count it,
120
    /// since the `next_port` counter needs to be shared.
121
    inner: Arc<MockNetProviderInner>,
122
}
123

            
124
impl fmt::Debug for MockNetProvider {
125
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126
        f.debug_struct("MockNetProvider").finish_non_exhaustive()
127
    }
128
}
129

            
130
/// Shared part of a MockNetworkProvider.
131
///
132
/// This is separate because providers need to implement Clone, but
133
/// `next_port` can't be cloned.
134
struct MockNetProviderInner {
135
    /// List of public addresses
136
    addrs: Vec<IpAddr>,
137
    /// Shared reference to the network.
138
    net: Arc<MockNetwork>,
139
    /// Next port number to hand out when we're asked to listen on
140
    /// port 0.
141
    ///
142
    /// See discussion of limitations on `listen()` implementation.
143
    next_port: AtomicU16,
144
}
145

            
146
/// A [`NetStreamListener`] implementation returned by a [`MockNetProvider`].
147
///
148
/// Represents listening on a public address for incoming TCP connections.
149
pub struct MockNetListener {
150
    /// The address that we're listening on.
151
    addr: SocketAddr,
152
    /// The incoming channel that tells us about new connections.
153
    // TODO: I'm not thrilled to have to use an AsyncMutex and a
154
    // std Mutex in the same module.
155
    receiver: AsyncMutex<ConnReceiver>,
156
}
157

            
158
/// A builder object used to configure a [`MockNetProvider`]
159
///
160
/// Returned by [`MockNetwork::builder()`].
161
pub struct ProviderBuilder {
162
    /// List of public addresses.
163
    addrs: Vec<IpAddr>,
164
    /// Shared reference to the network.
165
    net: Arc<MockNetwork>,
166
}
167

            
168
impl Default for MockNetProvider {
169
70740
    fn default() -> Self {
170
70740
        Arc::new(MockNetwork::default()).builder().provider()
171
70740
    }
172
}
173

            
174
impl MockNetwork {
175
    /// Make a new MockNetwork with no active listeners.
176
140
    pub fn new() -> Arc<Self> {
177
140
        Default::default()
178
140
    }
179

            
180
    /// Return a [`ProviderBuilder`] for creating a [`MockNetProvider`]
181
    ///
182
    /// # Examples
183
    ///
184
    /// ```
185
    /// # use tor_rtmock::net::*;
186
    /// # let mock_network = MockNetwork::new();
187
    /// let client_net = mock_network.builder()
188
    ///       .add_address("198.51.100.6".parse().unwrap())
189
    ///       .add_address("2001:db8::7".parse().unwrap())
190
    ///       .provider();
191
    /// ```
192
71018
    pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
193
71018
        ProviderBuilder {
194
71018
            addrs: vec![],
195
71018
            net: Arc::clone(self),
196
71018
        }
197
71018
    }
198

            
199
    /// Add a "black hole" at the given address, where all traffic will time out.
200
56
    pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
201
56
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
202
56
        if listener_map.contains_key(&address) {
203
            return Err(err(ErrorKind::AddrInUse));
204
56
        }
205
56
        listener_map.insert(address, AddrBehavior::Timeout);
206
56
        Ok(())
207
56
    }
208

            
209
    /// Tell the listener at `target_addr` (if any) about an incoming
210
    /// connection from `source_addr` at `peer_stream`.
211
    ///
212
    /// If the listener is a TLS listener, returns its certificate.
213
    /// **Note:** Callers should check whether the presence or absence of a certificate
214
    /// matches their expectations.
215
    ///
216
    /// Returns an error if there isn't any such listener.
217
1192
    async fn send_connection(
218
1192
        &self,
219
1192
        source_addr: SocketAddr,
220
1192
        target_addr: SocketAddr,
221
1192
        peer_stream: LocalStream,
222
1248
    ) -> IoResult<Option<Vec<u8>>> {
223
1192
        let entry = {
224
1192
            let listener_map = self.listening.lock().expect("Poisoned lock for listener");
225
1192
            listener_map.get(&target_addr).cloned()
226
        };
227
900
        match entry {
228
620
            Some(AddrBehavior::Listener(mut entry)) => {
229
620
                if entry.send.send((peer_stream, source_addr)).await.is_ok() {
230
620
                    return Ok(entry.tls_cert);
231
                }
232
                Err(err(ErrorKind::ConnectionRefused))
233
            }
234
280
            Some(AddrBehavior::Timeout) => futures::future::pending().await,
235
292
            None => Err(err(ErrorKind::ConnectionRefused)),
236
        }
237
912
    }
238

            
239
    /// Register a listener at `addr` and return the ConnReceiver
240
    /// that it should use for connections.
241
    ///
242
    /// If tls_cert is provided, then the listener is a TLS listener
243
    /// and any only TLS connection attempts should succeed.
244
    ///
245
    /// Returns an error if the address is already in use.
246
204
    fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
247
204
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
248
204
        if listener_map.contains_key(&addr) {
249
            // TODO: Maybe this should ignore dangling Weak references?
250
            return Err(err(ErrorKind::AddrInUse));
251
204
        }
252

            
253
204
        let (send, recv) = mpsc_channel(16);
254

            
255
204
        let entry = ListenerEntry { send, tls_cert };
256

            
257
204
        listener_map.insert(addr, AddrBehavior::Listener(entry));
258

            
259
204
        Ok(recv)
260
204
    }
261
}
262

            
263
impl ProviderBuilder {
264
    /// Add `addr` as a new address for the provider we're building.
265
336
    pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
266
336
        self.addrs.push(addr);
267
336
        self
268
336
    }
269
    /// Use this builder to return a new [`MockNetRuntime`] wrapping
270
    /// an existing `runtime`.
271
8
    pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
272
8
        MockNetRuntime::new(runtime, self.provider())
273
8
    }
274
    /// Use this builder to return a new [`MockNetProvider`]
275
71018
    pub fn provider(&self) -> MockNetProvider {
276
71018
        let inner = MockNetProviderInner {
277
71018
            addrs: self.addrs.clone(),
278
71018
            net: Arc::clone(&self.net),
279
71018
            next_port: AtomicU16::new(1),
280
71018
        };
281
71018
        MockNetProvider {
282
71018
            inner: Arc::new(inner),
283
71018
        }
284
71018
    }
285
}
286

            
287
impl NetStreamListener for MockNetListener {
288
    type Stream = LocalStream;
289

            
290
    type Incoming = Self;
291

            
292
36
    fn local_addr(&self) -> IoResult<SocketAddr> {
293
36
        Ok(self.addr)
294
36
    }
295

            
296
92
    fn incoming(self) -> Self {
297
92
        self
298
92
    }
299
}
300

            
301
impl Stream for MockNetListener {
302
    type Item = IoResult<(LocalStream, SocketAddr)>;
303
116
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
304
116
        let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
305
116
        match recv.poll_next_unpin(cx) {
306
            Poll::Pending => Poll::Pending,
307
            Poll::Ready(None) => Poll::Ready(None),
308
116
            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
309
        }
310
116
    }
311
}
312

            
313
/// A very poor imitation of a UDP socket
314
#[derive(Debug)]
315
#[non_exhaustive]
316
pub struct MockUdpSocket {
317
    /// This is uninhabited.
318
    ///
319
    /// To implement UDP support, implement `.bind()`, and abolish this field,
320
    /// replacing it with the actual implementation.
321
    void: Void,
322
}
323

            
324
#[async_trait]
325
impl UdpProvider for MockNetProvider {
326
    type UdpSocket = MockUdpSocket;
327

            
328
    async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
329
        let _ = addr; // MockNetProvider UDP is not implemented
330
        Err(io::ErrorKind::Unsupported.into())
331
    }
332
}
333

            
334
#[allow(clippy::diverging_sub_expression)] // void::unimplemented + async_trait
335
#[async_trait]
336
impl UdpSocket for MockUdpSocket {
337
    async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
338
        // This tuple idiom avoids unused variable warnings.
339
        // An alternative would be to write _buf, but then when this is implemented,
340
        // and the void::unreachable call removed, we actually *want* those warnings.
341
        void::unreachable((self.void, buf).0)
342
    }
343
    async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
344
        void::unreachable((self.void, buf, target).0)
345
    }
346
    fn local_addr(&self) -> IoResult<SocketAddr> {
347
        void::unreachable(self.void)
348
    }
349
}
350

            
351
impl MockNetProvider {
352
    /// If we have a local addresses that is in the same family as `other`,
353
    /// return it.
354
1234
    fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
355
1234
        self.inner
356
1234
            .addrs
357
1234
            .iter()
358
1313
            .find(|a| a.is_ipv4() == other.is_ipv4())
359
1234
            .copied()
360
1234
    }
361

            
362
    /// Return an arbitrary port number that we haven't returned from
363
    /// this function before.
364
    ///
365
    /// # Panics
366
    ///
367
    /// Panics if there are no remaining ports that this function hasn't
368
    /// returned before.
369
1208
    fn arbitrary_port(&self) -> u16 {
370
1208
        let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
371
1208
        assert!(next != 0);
372
1208
        next
373
1208
    }
374

            
375
    /// Helper for connecting: Picks the socketaddr to use
376
    /// when told to connect to `addr`.
377
    ///
378
    /// The IP is one of our own IPs with the same family as `addr`.
379
    /// The port is a port that we haven't used as an arbitrary port
380
    /// before.
381
1192
    fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
382
1192
        let my_addr = self
383
1192
            .get_addr_in_family(&addr.ip())
384
1192
            .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
385
1192
        Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
386
1192
    }
387

            
388
    /// Helper for binding a listener: Picks the socketaddr to use
389
    /// when told to bind to `addr`.
390
    ///
391
    /// If addr is `0.0.0.0` or `[::]`, then we pick one of our own
392
    /// addresses with the same family. Otherwise we fail unless `addr` is
393
    /// one of our own addresses.
394
    ///
395
    /// If port is 0, we pick a new arbitrary port we haven't used as
396
    /// an arbitrary port before.
397
220
    fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
398
216
        let ipaddr = {
399
220
            let ip = spec.ip();
400
220
            if ip.is_unspecified() {
401
42
                self.get_addr_in_family(&ip)
402
42
                    .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
403
248
            } else if self.inner.addrs.iter().any(|a| a == &ip) {
404
174
                ip
405
            } else {
406
4
                return Err(err(ErrorKind::AddrNotAvailable));
407
            }
408
        };
409
216
        let port = {
410
216
            if spec.port() == 0 {
411
16
                self.arbitrary_port()
412
            } else {
413
200
                spec.port()
414
            }
415
        };
416

            
417
216
        Ok(SocketAddr::new(ipaddr, port))
418
220
    }
419

            
420
    /// Create a mock TLS listener with provided certificate.
421
    ///
422
    /// Note that no encryption or authentication is actually
423
    /// performed!  Other parties are simply told that their connections
424
    /// succeeded and were authenticated against the given certificate.
425
68
    pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
426
68
        let addr = self.get_listener_addr(addr)?;
427

            
428
68
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
429

            
430
68
        Ok(MockNetListener { addr, receiver })
431
68
    }
432
}
433

            
434
#[async_trait]
435
impl NetStreamProvider for MockNetProvider {
436
    type Stream = LocalStream;
437
    type Listener = MockNetListener;
438
    type ConnectOptions = tor_rtcompat::TcpConnectOptions;
439
    type ListenOptions = tor_rtcompat::TcpListenOptions;
440

            
441
    async fn connect(
442
        &self,
443
        addr: &SocketAddr,
444
        _options: &Self::ConnectOptions,
445
1192
    ) -> IoResult<LocalStream> {
446
        let my_addr = self.get_origin_addr_for(addr)?;
447
        let (mut mine, theirs) = stream_pair();
448

            
449
        let cert = self
450
            .inner
451
            .net
452
            .send_connection(my_addr, *addr, theirs)
453
            .await?;
454

            
455
        mine.tls_cert = cert;
456

            
457
        Ok(mine)
458
1192
    }
459

            
460
    async fn listen(
461
        &self,
462
        addr: &SocketAddr,
463
        _options: &Self::ListenOptions,
464
136
    ) -> IoResult<Self::Listener> {
465
        let addr = self.get_listener_addr(addr)?;
466

            
467
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
468

            
469
        Ok(MockNetListener { addr, receiver })
470
136
    }
471
}
472

            
473
#[async_trait]
474
impl TlsProvider<LocalStream> for MockNetProvider {
475
    type Connector = MockTlsConnector;
476
    type TlsStream = MockTlsStream;
477
    type Acceptor = MockTlsAcceptor;
478
    type TlsServerStream = MockTlsStream;
479

            
480
964
    fn tls_connector(&self) -> MockTlsConnector {
481
964
        MockTlsConnector {}
482
964
    }
483
    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<MockTlsAcceptor> {
484
        Ok(MockTlsAcceptor {
485
            own_cert: settings.cert_der().to_vec(),
486
        })
487
    }
488

            
489
    fn supports_keying_material_export(&self) -> bool {
490
        false
491
    }
492
}
493

            
494
/// Mock TLS connector for use with MockNetProvider.
///
/// No handshake or cryptography happens.  Wrapping a stream just records
/// the certificate of the TLS listener on the other end.  Wrapping a stream
/// that did not come from a TLS listener fails.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
501

            
502
/// Mock TLS acceptor for use with MockNetProvider.
///
/// No handshake takes place: accepted streams are simply marked as having
/// presented `own_cert` to their peer.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsAcceptor {
    /// The certificate we claim to have sent to each peer.
    own_cert: Vec<u8>,
}
512

            
513
/// Mock TLS connector for use with MockNetProvider.
514
///
515
/// Note that no TLS is actually performed here: connections are simply
516
/// told that they succeeded with a given certificate.
517
///
518
/// Note also that we only use this type for client-side connections
519
/// right now: Arti doesn't support being a real TLS Listener yet,
520
/// since we only handle Tor client operations.
521
pub struct MockTlsStream {
522
    /// The peer certificate that we are pretending our peer has.
523
    peer_cert: Option<Vec<u8>>,
524
    /// The certificate that we are pretending that we sent.
525
    own_cert: Option<Vec<u8>>,
526
    /// The underlying stream.
527
    stream: LocalStream,
528
}
529

            
530
#[async_trait]
531
impl TlsConnector<LocalStream> for MockTlsConnector {
532
    type Conn = MockTlsStream;
533

            
534
    async fn negotiate_unvalidated(
535
        &self,
536
        mut stream: LocalStream,
537
        _sni_hostname: &str,
538
68
    ) -> IoResult<MockTlsStream> {
539
        let peer_cert = stream.tls_cert.take();
540

            
541
        if peer_cert.is_none() {
542
            return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
543
        }
544

            
545
        Ok(MockTlsStream {
546
            peer_cert,
547
            own_cert: None,
548
            stream,
549
        })
550
68
    }
551
}
552

            
553
#[async_trait]
554
impl TlsConnector<LocalStream> for MockTlsAcceptor {
555
    type Conn = MockTlsStream;
556

            
557
    async fn negotiate_unvalidated(
558
        &self,
559
        stream: LocalStream,
560
        _sni_hostname: &str,
561
    ) -> IoResult<MockTlsStream> {
562
        Ok(MockTlsStream {
563
            peer_cert: None,
564
            own_cert: Some(self.own_cert.clone()),
565
            stream,
566
        })
567
    }
568
}
569

            
570
impl CertifiedConn for MockTlsStream {
571
68
    fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
572
68
        Ok(self.peer_cert.clone().map(Cow::from))
573
68
    }
574

            
575
    fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
576
        Ok(self.own_cert.clone().map(Cow::from))
577
    }
578
    fn export_keying_material(
579
        &self,
580
        _len: usize,
581
        _label: &[u8],
582
        _context: Option<&[u8]>,
583
    ) -> IoResult<Vec<u8>> {
584
        Ok(Vec::new())
585
    }
586
}
587

            
588
impl AsyncRead for MockTlsStream {
589
876
    fn poll_read(
590
876
        mut self: Pin<&mut Self>,
591
876
        cx: &mut Context<'_>,
592
876
        buf: &mut [u8],
593
876
    ) -> Poll<IoResult<usize>> {
594
876
        Pin::new(&mut self.stream).poll_read(cx, buf)
595
876
    }
596
}
597
impl AsyncWrite for MockTlsStream {
598
236
    fn poll_write(
599
236
        mut self: Pin<&mut Self>,
600
236
        cx: &mut Context<'_>,
601
236
        buf: &[u8],
602
236
    ) -> Poll<IoResult<usize>> {
603
236
        Pin::new(&mut self.stream).poll_write(cx, buf)
604
236
    }
605
112
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
606
112
        Pin::new(&mut self.stream).poll_flush(cx)
607
112
    }
608
12
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
609
12
        Pin::new(&mut self.stream).poll_close(cx)
610
12
    }
611
}
612

            
613
impl StreamOps for MockTlsStream {
614
    fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
615
        Err(std::io::Error::new(
616
            std::io::ErrorKind::Unsupported,
617
            "not supported on non-StreamOps stream!",
618
        ))
619
    }
620

            
621
56
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
622
56
        Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
623
56
    }
624
}
625

            
626
/// Inner error type returned when a `MockNetwork` operation fails.
627
#[derive(Clone, Error, Debug)]
628
#[non_exhaustive]
629
pub enum MockNetError {
630
    /// General-purpose error.  The real information is in `ErrorKind`.
631
    #[error("Invalid operation on mock network")]
632
    BadOp,
633
}
634

            
635
/// Wrap `k` in a new [`std::io::Error`].
636
296
fn err(k: ErrorKind) -> IoError {
637
296
    IoError::new(k, MockNetError::BadOp)
638
296
}
639

            
640
#[cfg(all(test, not(miri)))] // miri cannot simulate the networking
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Two providers attached to one fresh network: the first owns
    /// 192.0.2.55, the second owns 198.51.100.7.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let host = |ip: &str| net.builder().add_address(ip.parse().unwrap()).provider();
        (host("192.0.2.55"), host("198.51.100.7"))
    }

    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let listener = client2
                .listen(&"0.0.0.0:99".parse().unwrap(), &Default::default())
                .await?;
            let listen_addr = listener.local_addr()?;

            let sender = async {
                let opts = Default::default();
                let mut stream = client1.connect(&listen_addr, &opts).await?;
                stream.write_all(b"This is totally a network.").await?;
                stream.close().await?;

                // Nobody listening here...
                let nobody: SocketAddr = "192.0.2.200:99".parse().unwrap();
                assert!(client1.connect(&nobody, &opts).await.is_err());
                IoResult::Ok(())
            };
            let receiver = async {
                let (mut stream, peer) = listener.incoming().next().await.expect("closed?")?;
                assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                let mut received = Vec::new();
                stream.read_to_end(&mut received).await?;
                assert_eq!(&received[..], &b"This is totally a network."[..]);
                IoResult::Ok(())
            };

            let (sent, got) = futures::join!(sender, receiver);
            sent?;
            got?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let ip4: IpAddr = "192.0.2.55".parse().unwrap();
        let ip6: IpAddr = "2001:db8::7".parse().unwrap();
        let net = MockNetwork::new();
        let provider = net.builder().add_address(ip4).add_address(ip6).provider();
        let pick = |spec: &str| provider.get_listener_addr(&spec.parse().unwrap());

        // Explicit ports are kept; unspecified IPs resolve within their family.
        for (spec, want_ip, want_port) in [
            ("0.0.0.0:99", ip4, 99_u16),
            ("192.0.2.55:100", ip4, 100),
            ("[::]:99", ip6, 99),
            ("[2001:db8::7]:100", ip6, 100),
        ] {
            let got = pick(spec)?;
            assert_eq!(got.ip(), want_ip);
            assert_eq!(got.port(), want_port);
        }

        // Port 0 asks for a fresh arbitrary port every time.
        let specific_ip = pick("192.0.2.55:0")?;
        let any_ip = pick("0.0.0.0:0")?;
        for got in [specific_ip, any_ip] {
            assert_eq!(got.ip(), ip4);
            assert_ne!(got.port(), 0);
        }
        assert_ne!(specific_ip.port(), any_ip.port());

        // Addresses that don't belong to this provider are refused.
        assert!(pick("192.0.2.56:0").is_err());
        assert!(pick("[2001:db8::8]:0").is_err());

        Ok(())
    }

    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let listener = client2
                .listen(&"0.0.0.0:99".parse().unwrap(), &Default::default())
                .await?;
            let target = listener.local_addr()?;
            let mut incoming = listener.incoming();

            let dialer = async {
                let opts = Default::default();
                for _ in 0..3_u8 {
                    let mut stream = client1.connect(&target, &opts).await?;
                    stream.close().await?;
                }
                IoResult::Ok(())
            };
            let acceptor = async {
                for _ in 0..3_u8 {
                    let (mut stream, peer) = incoming.next().await.unwrap()?;
                    let mut drained = Vec::new();
                    let _ = stream.read_to_end(&mut drained).await?;
                    assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                }
                IoResult::Ok(())
            };

            let (dialed, accepted) = futures::join!(dialer, acceptor);
            dialed?;
            accepted?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn tls_basics() {
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let listener = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert.to_vec())
                .unwrap();
            let target = listener.local_addr().unwrap();

            let client_side = async {
                let tls = client1.tls_connector();
                let plain = client1.connect(&target, &Default::default()).await?;
                let mut stream = tls
                    .negotiate_unvalidated(plain, "zombo.example.com")
                    .await?;
                assert_eq!(&stream.peer_certificate()?.unwrap()[..], &cert[..]);
                stream.write_all(b"This is totally encrypted.").await?;
                let mut reply = Vec::new();
                stream.read_to_end(&mut reply).await?;
                stream.close().await?;
                assert_eq!(reply[..], b"Yup, your secrets is safe"[..]);
                IoResult::Ok(())
            };
            let server_side = async {
                let (mut stream, peer) = listener.incoming().next().await.expect("closed?")?;
                assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                let mut request = [0_u8; 26];
                stream.read_exact(&mut request[..]).await?;
                assert_eq!(&request[..], &b"This is totally encrypted."[..]);
                stream.write_all(b"Yup, your secrets is safe").await?;
                IoResult::Ok(())
            };

            let (client_res, server_res) = futures::join!(client_side, server_side);
            client_res?;
            server_res?;
            IoResult::Ok(())
        });
    }
}