1
//! Implementations for the relay channel handshake
2

            
3
use futures::SinkExt;
4
use futures::io::{AsyncRead, AsyncWrite};
5
use rand::Rng;
6
use safelog::Sensitive;
7
use std::net::IpAddr;
8
use std::{sync::Arc, time::SystemTime};
9
use tracing::trace;
10

            
11
use tor_cell::chancell::msg;
12
use tor_cell::restrict::restricted_msg;
13
use tor_error::internal;
14
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
15
use tor_rtcompat::{CertifiedConn, CoarseTimeProvider, SleepProvider, StreamOps};
16

            
17
use crate::Result;
18
use crate::channel::handshake::{
19
    AuthLogAction, ChannelBaseHandshake, ChannelInitiatorHandshake, UnverifiedChannel,
20
    UnverifiedInitiatorChannel, read_msg, unauthenticated_clock_skew,
21
};
22
use crate::channel::{AuthLogDigest, ChannelFrame, ChannelType, UniqId, new_frame};
23
use crate::memquota::ChannelAccount;
24
use crate::peer::PeerAddr;
25
use crate::relay::channel::initiator::UnverifiedInitiatorRelayChannel;
26
use crate::relay::channel::responder::{
27
    MaybeVerifiableRelayResponderChannel, NonVerifiableResponderRelayChannel,
28
    UnverifiedResponderRelayChannel,
29
};
30
use crate::relay::channel::{RelayChannelAuthMaterial, build_certs_cell, build_netinfo_cell};
31

            
32
/// The "Ed25519-SHA256-RFC5705" link authentication which is value "00 03".
33
pub(super) static AUTHTYPE_ED25519_SHA256_RFC5705: u16 = 3;
34

            
35
/// A relay channel handshake as the initiator.
36
pub struct RelayInitiatorHandshake<
37
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
38
    S: CoarseTimeProvider + SleepProvider,
39
> {
40
    /// Runtime handle (insofar as we need it)
41
    sleep_prov: S,
42
    /// Memory quota account
43
    memquota: ChannelAccount,
44
    /// Underlying TLS stream in a channel frame.
45
    ///
46
    /// (We don't enforce that this is actually TLS, but if it isn't, the
47
    /// connection won't be secure.)
48
    framed_tls: ChannelFrame<T>,
49
    /// Logging identifier for this stream.  (Used for logging only.)
50
    unique_id: UniqId,
51
    /// Our identity keys needed for authentication.
52
    auth_material: Arc<RelayChannelAuthMaterial>,
53
    /// The peer we are attempting to connect to.
54
    target_method: ChannelMethod,
55
    /// Our advertised addresses. Needed for the NETINFO.
56
    my_addrs: Vec<IpAddr>,
57
}
58

            
59
/// Implement the base channel handshake trait.
60
impl<T, S> ChannelBaseHandshake<T> for RelayInitiatorHandshake<T, S>
61
where
62
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
63
    S: CoarseTimeProvider + SleepProvider,
64
{
65
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
66
        &mut self.framed_tls
67
    }
68
    fn unique_id(&self) -> &UniqId {
69
        &self.unique_id
70
    }
71
}
72

            
73
/// Implement the initiator channel handshake trait.
74
impl<T, S> ChannelInitiatorHandshake<T> for RelayInitiatorHandshake<T, S>
75
where
76
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
77
    S: CoarseTimeProvider + SleepProvider,
78
{
79
}
80

            
81
impl<
82
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
83
    S: CoarseTimeProvider + SleepProvider,
84
> RelayInitiatorHandshake<T, S>
85
{
86
    /// Constructor.
87
    pub(crate) fn new(
88
        tls: T,
89
        sleep_prov: S,
90
        auth_material: Arc<RelayChannelAuthMaterial>,
91
        my_addrs: Vec<IpAddr>,
92
        peer_target: &OwnedChanTarget,
93
        memquota: ChannelAccount,
94
    ) -> Self {
95
        Self {
96
            framed_tls: new_frame(tls, ChannelType::RelayInitiator),
97
            unique_id: UniqId::new(),
98
            sleep_prov,
99
            auth_material,
100
            memquota,
101
            my_addrs,
102
            target_method: peer_target.chan_method(),
103
        }
104
    }
105

            
106
    /// Connect to another relay as the relay Initiator.
107
    ///
108
    /// Takes a function that reports the current time.  In theory, this can just be
109
    /// `SystemTime::get()`.
110
    pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedInitiatorRelayChannel<T, S>>
111
    where
112
        F: FnOnce() -> SystemTime,
113
    {
114
        // Send the VERSIONS.
115
        let (versions_flushed_at, versions_flushed_wallclock) =
116
            self.send_versions_cell(now_fn).await?;
117

            
118
        // Receive the VERSIONS.
119
        let link_protocol = self.recv_versions_cell().await?;
120

            
121
        // VERSIONS cell have been exchanged, set the link protocol into our channel frame.
122
        self.set_link_protocol(link_protocol)?;
123

            
124
        // Read until we have all the remaining cells from the responder.
125
        let (auth_challenge_cell, certs_cell, (netinfo_cell, netinfo_rcvd_at), slog_digest) =
126
            self.recv_cells_from_responder(AuthLogAction::Take).await?;
127

            
128
        // TODO: It would be nice to come up with a better design for getting the SLOG.
129
        let slog_digest = slog_digest.ok_or(internal!("Asked for SLOG, but `None` returned?"))?;
130

            
131
        trace!(stream_id = %self.unique_id,
132
            "received handshake, ready to verify.",
133
        );
134

            
135
        // Calculate our clock skew from the timings we just got/calculated.
136
        let clock_skew = unauthenticated_clock_skew(
137
            &netinfo_cell,
138
            netinfo_rcvd_at,
139
            versions_flushed_at,
140
            versions_flushed_wallclock,
141
        );
142

            
143
        Ok(UnverifiedInitiatorRelayChannel {
144
            inner: UnverifiedInitiatorChannel {
145
                inner: UnverifiedChannel {
146
                    link_protocol,
147
                    framed_tls: self.framed_tls,
148
                    clock_skew,
149
                    memquota: self.memquota,
150
                    target_method: Some(self.target_method),
151
                    unique_id: self.unique_id,
152
                    sleep_prov: self.sleep_prov.clone(),
153
                },
154
                certs_cell,
155
            },
156
            auth_challenge_cell,
157
            slog_digest,
158
            netinfo_cell,
159
            auth_material: self.auth_material,
160
            my_addrs: self.my_addrs,
161
        })
162
    }
163
}
164

            
165
/// A relay channel handshake as the responder.
166
pub struct RelayResponderHandshake<
167
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
168
    S: CoarseTimeProvider + SleepProvider,
169
> {
170
    /// Runtime handle (insofar as we need it)
171
    sleep_prov: S,
172
    /// Memory quota account
173
    memquota: ChannelAccount,
174
    /// Underlying TLS stream in a channel frame.
175
    ///
176
    /// (We don't enforce that this is actually TLS, but if it isn't, the
177
    /// connection won't be secure.)
178
    framed_tls: ChannelFrame<T>,
179
    /// The peer IP address as in the address the initiator is connecting from. This can be a
180
    /// client so keep it sensitive.
181
    peer_addr: Sensitive<PeerAddr>,
182
    /// Our advertised addresses. Needed for the NETINFO.
183
    my_addrs: Vec<IpAddr>,
184
    /// Logging identifier for this stream.  (Used for logging only.)
185
    unique_id: UniqId,
186
    /// Our identity keys needed for authentication.
187
    auth_material: Arc<RelayChannelAuthMaterial>,
188
}
189

            
190
/// Implement the base channel handshake trait.
191
impl<T, S> ChannelBaseHandshake<T> for RelayResponderHandshake<T, S>
192
where
193
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
194
    S: CoarseTimeProvider + SleepProvider,
195
{
196
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
197
        &mut self.framed_tls
198
    }
199
    fn unique_id(&self) -> &UniqId {
200
        &self.unique_id
201
    }
202
}
203

            
204
impl<
205
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
206
    S: CoarseTimeProvider + SleepProvider,
207
> RelayResponderHandshake<T, S>
208
{
209
    /// Constructor.
210
    pub(crate) fn new(
211
        peer_addr: Sensitive<PeerAddr>,
212
        my_addrs: Vec<IpAddr>,
213
        tls: T,
214
        sleep_prov: S,
215
        auth_material: Arc<RelayChannelAuthMaterial>,
216
        memquota: ChannelAccount,
217
    ) -> Self {
218
        Self {
219
            peer_addr,
220
            my_addrs,
221
            framed_tls: new_frame(
222
                tls,
223
                ChannelType::RelayResponder {
224
                    authenticated: false,
225
                },
226
            ),
227
            unique_id: UniqId::new(),
228
            sleep_prov,
229
            auth_material,
230
            memquota,
231
        }
232
    }
233

            
234
    /// Begin the handshake process.
235
    ///
236
    /// Takes a function that reports the current time.  In theory, this can just be
237
    /// `SystemTime::get()`.
238
    pub async fn handshake<F>(
239
        mut self,
240
        now_fn: F,
241
    ) -> Result<MaybeVerifiableRelayResponderChannel<T, S>>
242
    where
243
        F: FnOnce() -> SystemTime,
244
    {
245
        // Receive initiator VERSIONS.
246
        let link_protocol = self.recv_versions_cell().await?;
247

            
248
        // Send the VERSIONS message.
249
        let (versions_flushed_at, versions_flushed_wallclock) =
250
            self.send_versions_cell(now_fn).await?;
251

            
252
        // VERSIONS cell have been exchanged, set the link protocol into our channel frame.
253
        self.set_link_protocol(link_protocol)?;
254

            
255
        // Send CERTS, AUTH_CHALLENGE and NETINFO
256
        let slog_digest = self.send_cells_to_initiator().await?;
257

            
258
        // Receive NETINFO and possibly [CERTS, AUTHENTICATE]. The connection could be from a
259
        // client/bridge and thus no authentication meaning no CERTS/AUTHENTICATE cells.
260
        let (certs_and_auth_and_clog, (netinfo_cell, netinfo_rcvd_at)) =
261
            self.recv_cells_from_initiator().await?;
262

            
263
        // Try to unpack these into something we can use later.
264
        let (certs_cell, auth_and_clog) = match certs_and_auth_and_clog {
265
            Some((certs, auth, clog)) => (Some(certs), Some((auth, clog))),
266
            None => (None, None),
267
        };
268

            
269
        // Calculate our clock skew from the timings we just got/calculated.
270
        let clock_skew = unauthenticated_clock_skew(
271
            &netinfo_cell,
272
            netinfo_rcvd_at,
273
            versions_flushed_at,
274
            versions_flushed_wallclock,
275
        );
276

            
277
        let inner = UnverifiedChannel {
278
            link_protocol,
279
            framed_tls: self.framed_tls,
280
            clock_skew,
281
            memquota: self.memquota,
282
            target_method: None,
283
            unique_id: self.unique_id,
284
            sleep_prov: self.sleep_prov,
285
        };
286

            
287
        // With an AUTHENTICATE cell, we can verify (relay). Else (client/bridge), we can't.
288
        Ok(match auth_and_clog {
289
            Some((auth_cell, clog_digest)) => {
290
                MaybeVerifiableRelayResponderChannel::Verifiable(UnverifiedResponderRelayChannel {
291
                    inner,
292
                    auth_cell,
293
                    netinfo_cell,
294
                    // TODO(relay): Should probably put that in the match {} and not assume.
295
                    certs_cell: certs_cell.expect("AUTHENTICATE cell without CERTS cell"),
296
                    auth_material: self.auth_material,
297
                    my_addrs: self.my_addrs,
298
                    peer_addr: self.peer_addr.into_inner(), // Relay address.
299
                    clog_digest,
300
                    slog_digest,
301
                })
302
            }
303
            None => MaybeVerifiableRelayResponderChannel::NonVerifiable(
304
                NonVerifiableResponderRelayChannel {
305
                    inner,
306
                    netinfo_cell,
307
                    my_addrs: self.my_addrs,
308
                    peer_addr: self.peer_addr,
309
                },
310
            ),
311
        })
312
    }
313

            
314
    /// Receive all the cells expected from the initiator of the connection. Keep in mind that it
315
    /// can be either a relay or client or bridge.
316
    async fn recv_cells_from_initiator(
317
        &mut self,
318
    ) -> Result<(
319
        Option<(msg::Certs, msg::Authenticate, AuthLogDigest /* CLOG */)>,
320
        (msg::Netinfo, coarsetime::Instant),
321
    )> {
322
        // IMPORTANT: Protocol wise, we MUST only allow one single cell of each type for a valid
323
        // handshake. Any duplicates lead to a failure.
324
        // They must arrive in a specific order in order for the CLOG calculation to be valid.
325

            
326
        // Note that the `ChannelFrame` already restricts the messages due to its handshake cell
327
        // handler.
328

            
329
        // This is kind of ugly, but I don't see a nicer way to write the authentication branch
330
        // without a bunch of boilerplate for a state machine.
331
        let (certs_and_auth_and_clog, netinfo, netinfo_rcvd_at) = 'outer: {
332
            // CERTS or NETINFO cell.
333
            let certs = loop {
334
                restricted_msg! {
335
                    enum CertsNetinfoMsg : ChanMsg {
336
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
337
                        Vpadding,
338
                        Netinfo,
339
                        Certs,
340
                   }
341
                }
342

            
343
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
344
                    CertsNetinfoMsg::Vpadding(_) => continue,
345
                    // If a NETINFO cell, the initiator did not authenticate and we can stop early.
346
                    CertsNetinfoMsg::Netinfo(msg) => {
347
                        break 'outer (None, msg, coarsetime::Instant::now());
348
                    }
349
                    // If a CERTS cell, the initiator is authenticating.
350
                    CertsNetinfoMsg::Certs(msg) => msg,
351
                };
352
            };
353

            
354
            // We're the responder, which means that the recv log is the CLOG.
355
            let clog_digest = self.framed_tls().codec_mut().take_recv_log_digest()?;
356

            
357
            // AUTHENTICATE cell.
358
            let auth = loop {
359
                restricted_msg! {
360
                    enum AuthenticateMsg : ChanMsg {
361
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
362
                        Vpadding,
363
                        Authenticate,
364
                   }
365
                }
366

            
367
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
368
                    AuthenticateMsg::Vpadding(_) => continue,
369
                    AuthenticateMsg::Authenticate(msg) => msg,
370
                };
371
            };
372

            
373
            // NETINFO cell (if we didn't receive it earlier).
374
            let (netinfo, netinfo_rcvd_at) = loop {
375
                restricted_msg! {
376
                    enum NetinfoMsg : ChanMsg {
377
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
378
                        Vpadding,
379
                        Netinfo,
380
                   }
381
                }
382

            
383
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
384
                    NetinfoMsg::Vpadding(_) => continue,
385
                    NetinfoMsg::Netinfo(msg) => (msg, coarsetime::Instant::now()),
386
                };
387
            };
388

            
389
            (Some((certs, auth, clog_digest)), netinfo, netinfo_rcvd_at)
390
        };
391

            
392
        Ok((certs_and_auth_and_clog, (netinfo, netinfo_rcvd_at)))
393
    }
394

            
395
    /// Send all expected cells to the initiator of the channel as the responder.
396
    ///
397
    /// Return the SLOG (send log) digest to be later used when verifying the initiator's
398
    /// AUTHENTICATE cell.
399
    async fn send_cells_to_initiator(&mut self) -> Result<AuthLogDigest> {
400
        // Send the CERTS message.
401
        let certs = build_certs_cell(&self.auth_material, /* is_responder */ true);
402
        trace!(channel_id = %self.unique_id, "Sending CERTS as responder cell.");
403
        self.framed_tls.send(certs.into()).await?;
404

            
405
        // Send the AUTH_CHALLENGE.
406
        let challenge: [u8; 32] = rand::rng().random();
407
        let auth_challenge = msg::AuthChallenge::new(challenge, [AUTHTYPE_ED25519_SHA256_RFC5705]);
408
        trace!(channel_id = %self.unique_id, "Sending AUTH_CHALLENGE as responder cell.");
409
        self.framed_tls.send(auth_challenge.into()).await?;
410

            
411
        // We're the responder, which means that the send log is the SLOG.
412
        let slog_digest = self.framed_tls.codec_mut().take_send_log_digest()?;
413

            
414
        // Send the NETINFO message.
415
        let peer_ip = self.peer_addr.netinfo_addr();
416
        let netinfo = build_netinfo_cell(peer_ip, self.my_addrs.clone(), &self.sleep_prov)?;
417
        trace!(channel_id = %self.unique_id, "Sending NETINFO as responder cell.");
418
        self.framed_tls.send(netinfo.into()).await?;
419

            
420
        Ok(slog_digest)
421
    }
422
}