1
//! Implements a simple mock network for testing purposes.
2

            
3
// Note: There are lots of opportunities here for making the network
4
// more and more realistic, but please remember that this module only
5
// exists for writing unit tests.  Let's resist the temptation to add
6
// things we don't need.
7

            
8
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
9

            
10
use super::MockNetRuntime;
11
use super::io::{LocalStream, stream_pair};
12
use crate::util::mpsc_channel;
13
use core::fmt;
14
use std::borrow::Cow;
15
use tor_rtcompat::tls::{TlsAcceptorSettings, TlsConnector};
16
use tor_rtcompat::{
17
    CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
18
};
19
use tor_rtcompat::{UdpProvider, UdpSocket};
20

            
21
use async_trait::async_trait;
22
use futures::FutureExt;
23
use futures::channel::mpsc;
24
use futures::io::{AsyncRead, AsyncWrite};
25
use futures::lock::Mutex as AsyncMutex;
26
use futures::sink::SinkExt;
27
use futures::stream::{Stream, StreamExt};
28
use std::collections::HashMap;
29
use std::fmt::Formatter;
30
use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
31
use std::net::{IpAddr, SocketAddr};
32
use std::pin::Pin;
33
use std::sync::atomic::{AtomicU16, Ordering};
34
use std::sync::{Arc, Mutex};
35
use std::task::{Context, Poll};
36
use thiserror::Error;
37
use void::Void;
38

            
39
/// A channel sender that we use to send incoming connections to
40
/// listeners.
41
type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
42
/// A channel receiver that listeners use to receive incoming connections.
43
type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
44

            
45
/// A simulated Internet, for testing.
46
///
47
/// We simulate TCP streams only, and skip all the details. Connection
48
/// are implemented using [`LocalStream`]. The MockNetwork object is
49
/// shared by a large set of MockNetworkProviders, each of which has
50
/// its own view of its address(es) on the network.
51
#[derive(Default)]
52
pub struct MockNetwork {
53
    /// A map from address to the entries about listeners there.
54
    listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
55
}
56

            
57
/// The `MockNetwork`'s view of a listener.
58
#[derive(Clone)]
59
struct ListenerEntry {
60
    /// A sender that need to be informed about connection attempts
61
    /// there.
62
    send: ConnSender,
63

            
64
    /// A notional TLS certificate for this listener.  If absent, the
65
    /// listener isn't a TLS listener.
66
    tls_cert: Option<Vec<u8>>,
67
}
68

            
69
/// A possible non-error behavior from an address
70
#[derive(Clone)]
71
enum AddrBehavior {
72
    /// There's a listener at this address, which would like to reply.
73
    Listener(ListenerEntry),
74
    /// All connections sent to this address will time out.
75
    Timeout,
76
}
77

            
78
/// A view of a single host's access to a MockNetwork.
79
///
80
/// Each simulated host has its own addresses that it's allowed to listen on,
81
/// and a reference to the network.
82
///
83
/// This type implements [`NetStreamProvider`] for [`SocketAddr`]
84
/// so that it can be used as a
85
/// drop-in replacement for testing code that uses the network.
86
///
87
/// # Limitations
88
///
89
/// There's no randomness here, so we can't simulate the weirdness of
90
/// real networks.
91
///
92
/// So far, there's no support for DNS or UDP.
93
///
94
/// We don't handle localhost specially, and we don't simulate providers
95
/// that can connect to some addresses but not all.
96
///
97
/// We don't do the right thing (block) if there is a listener that
98
/// never calls accept.
99
///
100
/// UDP is completely broken:
101
/// datagrams appear to be transmitted, but will never be received.
102
/// And local address assignment is not implemented
103
/// so [`.local_addr()`](UdpSocket::local_addr) can return `NONE`
104
// TODO MOCK UDP: Documentation does describe the brokennesses
105
///
106
/// We use a simple `u16` counter to decide what arbitrary port
107
/// numbers to use: Once that counter is exhausted, we will fail with
108
/// an assertion.  We don't do anything to prevent those arbitrary
109
/// ports from colliding with specified ports, other than declare that
110
/// you can't have two listeners on the same addr:port at the same
111
/// time.
112
///
113
/// We pretend to provide TLS, but there's no actual encryption or
114
/// authentication.
115
#[derive(Clone)]
116
pub struct MockNetProvider {
117
    /// Actual implementation of this host's view of the network.
118
    ///
119
    /// We have to use a separate type here and reference count it,
120
    /// since the `next_port` counter needs to be shared.
121
    inner: Arc<MockNetProviderInner>,
122
}
123

            
124
impl fmt::Debug for MockNetProvider {
125
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126
        f.debug_struct("MockNetProvider").finish_non_exhaustive()
127
    }
128
}
129

            
130
/// Shared part of a MockNetworkProvider.
131
///
132
/// This is separate because providers need to implement Clone, but
133
/// `next_port` can't be cloned.
134
struct MockNetProviderInner {
135
    /// List of public addresses
136
    addrs: Vec<IpAddr>,
137
    /// Shared reference to the network.
138
    net: Arc<MockNetwork>,
139
    /// Next port number to hand out when we're asked to listen on
140
    /// port 0.
141
    ///
142
    /// See discussion of limitations on `listen()` implementation.
143
    next_port: AtomicU16,
144
}
145

            
146
/// A [`NetStreamListener`] implementation returned by a [`MockNetProvider`].
147
///
148
/// Represents listening on a public address for incoming TCP connections.
149
pub struct MockNetListener {
150
    /// The address that we're listening on.
151
    addr: SocketAddr,
152
    /// The incoming channel that tells us about new connections.
153
    // TODO: I'm not thrilled to have to use an AsyncMutex and a
154
    // std Mutex in the same module.
155
    receiver: AsyncMutex<ConnReceiver>,
156
}
157

            
158
/// A builder object used to configure a [`MockNetProvider`]
159
///
160
/// Returned by [`MockNetwork::builder()`].
161
pub struct ProviderBuilder {
162
    /// List of public addresses.
163
    addrs: Vec<IpAddr>,
164
    /// Shared reference to the network.
165
    net: Arc<MockNetwork>,
166
}
167

            
168
impl Default for MockNetProvider {
169
63029
    fn default() -> Self {
170
63029
        Arc::new(MockNetwork::default()).builder().provider()
171
63029
    }
172
}
173

            
174
impl MockNetwork {
175
    /// Make a new MockNetwork with no active listeners.
176
134
    pub fn new() -> Arc<Self> {
177
134
        Default::default()
178
134
    }
179

            
180
    /// Return a [`ProviderBuilder`] for creating a [`MockNetProvider`]
181
    ///
182
    /// # Examples
183
    ///
184
    /// ```
185
    /// # use tor_rtmock::net::*;
186
    /// # let mock_network = MockNetwork::new();
187
    /// let client_net = mock_network.builder()
188
    ///       .add_address("198.51.100.6".parse().unwrap())
189
    ///       .add_address("2001:db8::7".parse().unwrap())
190
    ///       .provider();
191
    /// ```
192
63295
    pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
193
63295
        ProviderBuilder {
194
63295
            addrs: vec![],
195
63295
            net: Arc::clone(self),
196
63295
        }
197
63295
    }
198

            
199
    /// Add a "black hole" at the given address, where all traffic will time out.
200
53
    pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
201
53
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
202
53
        if listener_map.contains_key(&address) {
203
            return Err(err(ErrorKind::AddrInUse));
204
53
        }
205
53
        listener_map.insert(address, AddrBehavior::Timeout);
206
53
        Ok(())
207
53
    }
208

            
209
    /// Tell the listener at `target_addr` (if any) about an incoming
210
    /// connection from `source_addr` at `peer_stream`.
211
    ///
212
    /// If the listener is a TLS listener, returns its certificate.
213
    /// **Note:** Callers should check whether the presence or absence of a certificate
214
    /// matches their expectations.
215
    ///
216
    /// Returns an error if there isn't any such listener.
217
1132
    async fn send_connection(
218
1132
        &self,
219
1132
        source_addr: SocketAddr,
220
1132
        target_addr: SocketAddr,
221
1132
        peer_stream: LocalStream,
222
1188
    ) -> IoResult<Option<Vec<u8>>> {
223
1132
        let entry = {
224
1132
            let listener_map = self.listening.lock().expect("Poisoned lock for listener");
225
1132
            listener_map.get(&target_addr).cloned()
226
        };
227
855
        match entry {
228
590
            Some(AddrBehavior::Listener(mut entry)) => {
229
590
                if entry.send.send((peer_stream, source_addr)).await.is_ok() {
230
590
                    return Ok(entry.tls_cert);
231
                }
232
                Err(err(ErrorKind::ConnectionRefused))
233
            }
234
265
            Some(AddrBehavior::Timeout) => futures::future::pending().await,
235
277
            None => Err(err(ErrorKind::ConnectionRefused)),
236
        }
237
867
    }
238

            
239
    /// Register a listener at `addr` and return the ConnReceiver
240
    /// that it should use for connections.
241
    ///
242
    /// If tls_cert is provided, then the listener is a TLS listener
243
    /// and any only TLS connection attempts should succeed.
244
    ///
245
    /// Returns an error if the address is already in use.
246
195
    fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
247
195
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
248
195
        if listener_map.contains_key(&addr) {
249
            // TODO: Maybe this should ignore dangling Weak references?
250
            return Err(err(ErrorKind::AddrInUse));
251
195
        }
252

            
253
195
        let (send, recv) = mpsc_channel(16);
254

            
255
195
        let entry = ListenerEntry { send, tls_cert };
256

            
257
195
        listener_map.insert(addr, AddrBehavior::Listener(entry));
258

            
259
195
        Ok(recv)
260
195
    }
261
}
262

            
263
impl ProviderBuilder {
264
    /// Add `addr` as a new address for the provider we're building.
265
321
    pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
266
321
        self.addrs.push(addr);
267
321
        self
268
321
    }
269
    /// Use this builder to return a new [`MockNetRuntime`] wrapping
270
    /// an existing `runtime`.
271
8
    pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
272
8
        MockNetRuntime::new(runtime, self.provider())
273
8
    }
274
    /// Use this builder to return a new [`MockNetProvider`]
275
63295
    pub fn provider(&self) -> MockNetProvider {
276
63295
        let inner = MockNetProviderInner {
277
63295
            addrs: self.addrs.clone(),
278
63295
            net: Arc::clone(&self.net),
279
63295
            next_port: AtomicU16::new(1),
280
63295
        };
281
63295
        MockNetProvider {
282
63295
            inner: Arc::new(inner),
283
63295
        }
284
63295
    }
285
}
286

            
287
impl NetStreamListener for MockNetListener {
288
    type Stream = LocalStream;
289

            
290
    type Incoming = Self;
291

            
292
36
    fn local_addr(&self) -> IoResult<SocketAddr> {
293
36
        Ok(self.addr)
294
36
    }
295

            
296
89
    fn incoming(self) -> Self {
297
89
        self
298
89
    }
299
}
300

            
301
impl Stream for MockNetListener {
302
    type Item = IoResult<(LocalStream, SocketAddr)>;
303
113
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
304
113
        let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
305
113
        match recv.poll_next_unpin(cx) {
306
            Poll::Pending => Poll::Pending,
307
            Poll::Ready(None) => Poll::Ready(None),
308
113
            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
309
        }
310
113
    }
311
}
312

            
313
/// A very poor imitation of a UDP socket
314
#[derive(Debug)]
315
#[non_exhaustive]
316
pub struct MockUdpSocket {
317
    /// This is uninhabited.
318
    ///
319
    /// To implement UDP support, implement `.bind()`, and abolish this field,
320
    /// replacing it with the actual implementation.
321
    void: Void,
322
}
323

            
324
#[async_trait]
325
impl UdpProvider for MockNetProvider {
326
    type UdpSocket = MockUdpSocket;
327

            
328
    async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
329
        let _ = addr; // MockNetProvider UDP is not implemented
330
        Err(io::ErrorKind::Unsupported.into())
331
    }
332
}
333

            
334
#[allow(clippy::diverging_sub_expression)] // void::unimplemented + async_trait
335
#[async_trait]
336
impl UdpSocket for MockUdpSocket {
337
    async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
338
        // This tuple idiom avoids unused variable warnings.
339
        // An alternative would be to write _buf, but then when this is implemented,
340
        // and the void::unreachable call removed, we actually *want* those warnings.
341
        void::unreachable((self.void, buf).0)
342
    }
343
    async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
344
        void::unreachable((self.void, buf, target).0)
345
    }
346
    fn local_addr(&self) -> IoResult<SocketAddr> {
347
        void::unreachable(self.void)
348
    }
349
}
350

            
351
impl MockNetProvider {
352
    /// If we have a local addresses that is in the same family as `other`,
353
    /// return it.
354
1174
    fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
355
1174
        self.inner
356
1174
            .addrs
357
1174
            .iter()
358
1253
            .find(|a| a.is_ipv4() == other.is_ipv4())
359
1174
            .copied()
360
1174
    }
361

            
362
    /// Return an arbitrary port number that we haven't returned from
363
    /// this function before.
364
    ///
365
    /// # Panics
366
    ///
367
    /// Panics if there are no remaining ports that this function hasn't
368
    /// returned before.
369
1148
    fn arbitrary_port(&self) -> u16 {
370
1148
        let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
371
1148
        assert!(next != 0);
372
1148
        next
373
1148
    }
374

            
375
    /// Helper for connecting: Picks the socketaddr to use
376
    /// when told to connect to `addr`.
377
    ///
378
    /// The IP is one of our own IPs with the same family as `addr`.
379
    /// The port is a port that we haven't used as an arbitrary port
380
    /// before.
381
1132
    fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
382
1132
        let my_addr = self
383
1132
            .get_addr_in_family(&addr.ip())
384
1132
            .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
385
1132
        Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
386
1132
    }
387

            
388
    /// Helper for binding a listener: Picks the socketaddr to use
389
    /// when told to bind to `addr`.
390
    ///
391
    /// If addr is `0.0.0.0` or `[::]`, then we pick one of our own
392
    /// addresses with the same family. Otherwise we fail unless `addr` is
393
    /// one of our own addresses.
394
    ///
395
    /// If port is 0, we pick a new arbitrary port we haven't used as
396
    /// an arbitrary port before.
397
211
    fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
398
207
        let ipaddr = {
399
211
            let ip = spec.ip();
400
211
            if ip.is_unspecified() {
401
42
                self.get_addr_in_family(&ip)
402
42
                    .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
403
236
            } else if self.inner.addrs.iter().any(|a| a == &ip) {
404
165
                ip
405
            } else {
406
4
                return Err(err(ErrorKind::AddrNotAvailable));
407
            }
408
        };
409
207
        let port = {
410
207
            if spec.port() == 0 {
411
16
                self.arbitrary_port()
412
            } else {
413
191
                spec.port()
414
            }
415
        };
416

            
417
207
        Ok(SocketAddr::new(ipaddr, port))
418
211
    }
419

            
420
    /// Create a mock TLS listener with provided certificate.
421
    ///
422
    /// Note that no encryption or authentication is actually
423
    /// performed!  Other parties are simply told that their connections
424
    /// succeeded and were authenticated against the given certificate.
425
65
    pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
426
65
        let addr = self.get_listener_addr(addr)?;
427

            
428
65
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
429

            
430
65
        Ok(MockNetListener { addr, receiver })
431
65
    }
432
}
433

            
434
#[async_trait]
435
impl NetStreamProvider for MockNetProvider {
436
    type Stream = LocalStream;
437
    type Listener = MockNetListener;
438

            
439
1132
    async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
440
        let my_addr = self.get_origin_addr_for(addr)?;
441
        let (mut mine, theirs) = stream_pair();
442

            
443
        let cert = self
444
            .inner
445
            .net
446
            .send_connection(my_addr, *addr, theirs)
447
            .await?;
448

            
449
        mine.tls_cert = cert;
450

            
451
        Ok(mine)
452
1132
    }
453

            
454
130
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
455
        let addr = self.get_listener_addr(addr)?;
456

            
457
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
458

            
459
        Ok(MockNetListener { addr, receiver })
460
130
    }
461
}
462

            
463
#[async_trait]
464
impl TlsProvider<LocalStream> for MockNetProvider {
465
    type Connector = MockTlsConnector;
466
    type TlsStream = MockTlsStream;
467
    type Acceptor = MockTlsAcceptor;
468
    type TlsServerStream = MockTlsStream;
469

            
470
1125
    fn tls_connector(&self) -> MockTlsConnector {
471
1125
        MockTlsConnector {}
472
1125
    }
473
    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<MockTlsAcceptor> {
474
        Ok(MockTlsAcceptor {
475
            own_cert: settings.cert_der().to_vec(),
476
        })
477
    }
478

            
479
    fn supports_keying_material_export(&self) -> bool {
480
        false
481
    }
482
}
483

            
484
/// Client-side mock TLS connector for use with [`MockNetProvider`].
///
/// No TLS happens here: a connection is simply told that it succeeded and
/// that the peer presented the certificate of the listener it reached.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
491

            
492
/// Server-side mock TLS acceptor for use with [`MockNetProvider`].
///
/// No TLS happens here: connections are simply told that they succeeded.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsAcceptor {
    /// The certificate we pretend to present to clients.
    own_cert: Vec<u8>,
}
502

            
503
/// Mock TLS connector for use with MockNetProvider.
504
///
505
/// Note that no TLS is actually performed here: connections are simply
506
/// told that they succeeded with a given certificate.
507
///
508
/// Note also that we only use this type for client-side connections
509
/// right now: Arti doesn't support being a real TLS Listener yet,
510
/// since we only handle Tor client operations.
511
pub struct MockTlsStream {
512
    /// The peer certificate that we are pretending our peer has.
513
    peer_cert: Option<Vec<u8>>,
514
    /// The certificate that we are pretending that we sent.
515
    own_cert: Option<Vec<u8>>,
516
    /// The underlying stream.
517
    stream: LocalStream,
518
}
519

            
520
#[async_trait]
521
impl TlsConnector<LocalStream> for MockTlsConnector {
522
    type Conn = MockTlsStream;
523

            
524
    async fn negotiate_unvalidated(
525
        &self,
526
        mut stream: LocalStream,
527
        _sni_hostname: &str,
528
65
    ) -> IoResult<MockTlsStream> {
529
        let peer_cert = stream.tls_cert.take();
530

            
531
        if peer_cert.is_none() {
532
            return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
533
        }
534

            
535
        Ok(MockTlsStream {
536
            peer_cert,
537
            own_cert: None,
538
            stream,
539
        })
540
65
    }
541
}
542

            
543
#[async_trait]
544
impl TlsConnector<LocalStream> for MockTlsAcceptor {
545
    type Conn = MockTlsStream;
546

            
547
    async fn negotiate_unvalidated(
548
        &self,
549
        stream: LocalStream,
550
        _sni_hostname: &str,
551
    ) -> IoResult<MockTlsStream> {
552
        Ok(MockTlsStream {
553
            peer_cert: None,
554
            own_cert: Some(self.own_cert.clone()),
555
            stream,
556
        })
557
    }
558
}
559

            
560
impl CertifiedConn for MockTlsStream {
561
65
    fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
562
65
        Ok(self.peer_cert.clone().map(Cow::from))
563
65
    }
564

            
565
    fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
566
        Ok(self.own_cert.clone().map(Cow::from))
567
    }
568
    fn export_keying_material(
569
        &self,
570
        _len: usize,
571
        _label: &[u8],
572
        _context: Option<&[u8]>,
573
    ) -> IoResult<Vec<u8>> {
574
        Ok(Vec::new())
575
    }
576
}
577

            
578
impl AsyncRead for MockTlsStream {
579
831
    fn poll_read(
580
831
        mut self: Pin<&mut Self>,
581
831
        cx: &mut Context<'_>,
582
831
        buf: &mut [u8],
583
831
    ) -> Poll<IoResult<usize>> {
584
831
        Pin::new(&mut self.stream).poll_read(cx, buf)
585
831
    }
586
}
587
impl AsyncWrite for MockTlsStream {
588
224
    fn poll_write(
589
224
        mut self: Pin<&mut Self>,
590
224
        cx: &mut Context<'_>,
591
224
        buf: &[u8],
592
224
    ) -> Poll<IoResult<usize>> {
593
224
        Pin::new(&mut self.stream).poll_write(cx, buf)
594
224
    }
595
106
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
596
106
        Pin::new(&mut self.stream).poll_flush(cx)
597
106
    }
598
12
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
599
12
        Pin::new(&mut self.stream).poll_close(cx)
600
12
    }
601
}
602

            
603
impl StreamOps for MockTlsStream {
604
    fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
605
        Err(std::io::Error::new(
606
            std::io::ErrorKind::Unsupported,
607
            "not supported on non-StreamOps stream!",
608
        ))
609
    }
610

            
611
53
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
612
53
        Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
613
53
    }
614
}
615

            
616
/// Inner error type returned when a `MockNetwork` operation fails.
617
#[derive(Clone, Error, Debug)]
618
#[non_exhaustive]
619
pub enum MockNetError {
620
    /// General-purpose error.  The real information is in `ErrorKind`.
621
    #[error("Invalid operation on mock network")]
622
    BadOp,
623
}
624

            
625
/// Wrap `k` in a new [`std::io::Error`].
626
281
fn err(k: ErrorKind) -> IoError {
627
281
    IoError::new(k, MockNetError::BadOp)
628
281
}
629

            
630
#[cfg(all(test, not(miri)))] // miri cannot simulate the networking
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Build two providers on one shared network, each with a single IPv4
    /// address: 192.0.2.55 and 198.51.100.7, in that order.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let make_host = |ip: &str| net.builder().add_address(ip.parse().unwrap()).provider();
        (make_host("192.0.2.55"), make_host("198.51.100.7"))
    }

    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (sender, receiver) = client_pair();
            let listener = receiver.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let target = listener.local_addr()?;
            let message: &[u8] = b"This is totally a network.";

            let client_side = async {
                let mut stream = sender.connect(&target).await?;
                stream.write_all(message).await?;
                stream.close().await?;

                // Nobody is listening here, so this connection must fail.
                let nowhere = "192.0.2.200:99".parse().unwrap();
                let refused = sender.connect(&nowhere).await;
                assert!(refused.is_err());
                IoResult::Ok(())
            };
            let server_side = async {
                let (mut stream, peer) = listener.incoming().next().await.expect("closed?")?;
                assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                let mut received = Vec::new();
                stream.read_to_end(&mut received).await?;
                assert_eq!(received, message);
                IoResult::Ok(())
            };

            let (client_result, server_result) = futures::join!(client_side, server_side);
            client_result?;
            server_result?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4: IpAddr = "192.0.2.55".parse().unwrap();
        let ip6: IpAddr = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();
        let pick = |spec: &str| client.get_listener_addr(&spec.parse().unwrap());

        // Explicit ports are kept; wildcard IPs become our own address.
        let a1 = pick("0.0.0.0:99")?;
        assert_eq!((a1.ip(), a1.port()), (ip4, 99));
        let a2 = pick("192.0.2.55:100")?;
        assert_eq!((a2.ip(), a2.port()), (ip4, 100));

        // Port 0 yields distinct, nonzero arbitrary ports.
        let a3 = pick("192.0.2.55:0")?;
        assert_eq!(a3.ip(), ip4);
        assert_ne!(a3.port(), 0);
        let a4 = pick("0.0.0.0:0")?;
        assert_eq!(a4.ip(), ip4);
        assert_ne!(a4.port(), 0);
        assert_ne!(a4.port(), a3.port());

        // The same rules hold for IPv6.
        let a5 = pick("[::]:99")?;
        assert_eq!((a5.ip(), a5.port()), (ip6, 99));
        let a6 = pick("[2001:db8::7]:100")?;
        assert_eq!((a6.ip(), a6.port()), (ip6, 100));

        // Addresses that do not belong to us are rejected.
        let foreign_v4 = pick("192.0.2.56:0");
        let foreign_v6 = pick("[2001:db8::8]:0");
        assert!(foreign_v4.is_err());
        assert!(foreign_v6.is_err());

        Ok(())
    }

    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (sender, receiver) = client_pair();

            let listener = receiver.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let target = listener.local_addr()?;
            let mut incoming = listener.incoming();
            let rounds = 3_u8;

            let connect_all = async {
                for _ in 0..rounds {
                    let mut stream = sender.connect(&target).await?;
                    stream.close().await?;
                }
                IoResult::Ok(())
            };
            let accept_all = async {
                for _ in 0..rounds {
                    let (mut stream, peer) = incoming.next().await.unwrap()?;
                    let mut drained = Vec::new();
                    let _ = stream.read_to_end(&mut drained).await?;
                    assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                }
                IoResult::Ok(())
            };

            let (connect_result, accept_result) = futures::join!(connect_all, accept_all);
            connect_result?;
            accept_result?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn tls_basics() {
        let (sender, receiver) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let listener = receiver
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert.to_vec())
                .unwrap();
            let target = listener.local_addr().unwrap();

            let client_side = async {
                let connector = sender.tls_connector();
                let plain = sender.connect(&target).await?;
                let mut tls = connector
                    .negotiate_unvalidated(plain, "zombo.example.com")
                    .await?;
                assert_eq!(&tls.peer_certificate()?.unwrap()[..], &cert[..]);
                tls.write_all(b"This is totally encrypted.").await?;
                let mut reply = Vec::new();
                tls.read_to_end(&mut reply).await?;
                tls.close().await?;
                assert_eq!(reply[..], b"Yup, your secrets is safe"[..]);
                IoResult::Ok(())
            };
            let server_side = async {
                let (mut stream, peer) = listener.incoming().next().await.expect("closed?")?;
                assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                let mut request = [0_u8; 26];
                stream.read_exact(&mut request[..]).await?;
                assert_eq!(&request[..], &b"This is totally encrypted."[..]);
                stream.write_all(b"Yup, your secrets is safe").await?;
                IoResult::Ok(())
            };

            let (client_result, server_result) = futures::join!(client_side, server_side);
            client_result?;
            server_result?;
            IoResult::Ok(())
        });
    }
}