1
//! Implementations of the relay channel handshake, as both initiator and responder.
2

            
3
use futures::SinkExt;
4
use futures::io::{AsyncRead, AsyncWrite};
5
use futures::stream::{Stream, StreamExt};
6
use rand::Rng;
7
use safelog::Sensitive;
8
use std::net::IpAddr;
9
use std::{sync::Arc, time::SystemTime};
10
use tracing::trace;
11

            
12
use tor_cell::chancell::msg::AnyChanMsg;
13
use tor_cell::chancell::{AnyChanCell, ChanMsg, msg};
14
use tor_cell::restrict::{RestrictedMsg, restricted_msg};
15
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
16
use tor_rtcompat::{CertifiedConn, CoarseTimeProvider, SleepProvider, StreamOps};
17

            
18
use crate::channel::handshake::{
19
    ChannelBaseHandshake, ChannelInitiatorHandshake, UnverifiedChannel, unauthenticated_clock_skew,
20
};
21
use crate::channel::{ChannelFrame, ChannelType, UniqId, new_frame};
22
use crate::memquota::ChannelAccount;
23
use crate::peer::PeerAddr;
24
use crate::relay::channel::initiator::UnverifiedInitiatorRelayChannel;
25
use crate::relay::channel::responder::{
26
    MaybeVerifiableRelayResponderChannel, NonVerifiableResponderRelayChannel,
27
    UnverifiedResponderRelayChannel,
28
};
29
use crate::relay::channel::{RelayIdentities, build_certs_cell, build_netinfo_cell};
30
use crate::{Error, Result};
31

            
32
/// The "Ed25519-SHA256-RFC5705" link authentication which is value "00 03".
33
pub(super) static AUTHTYPE_ED25519_SHA256_RFC5705: u16 = 3;
34

            
35
/// A relay channel handshake as the initiator.
36
pub struct RelayInitiatorHandshake<
37
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
38
    S: CoarseTimeProvider + SleepProvider,
39
> {
40
    /// Runtime handle (insofar as we need it)
41
    sleep_prov: S,
42
    /// Memory quota account
43
    memquota: ChannelAccount,
44
    /// Underlying TLS stream in a channel frame.
45
    ///
46
    /// (We don't enforce that this is actually TLS, but if it isn't, the
47
    /// connection won't be secure.)
48
    framed_tls: ChannelFrame<T>,
49
    /// Logging identifier for this stream.  (Used for logging only.)
50
    unique_id: UniqId,
51
    /// Our identity keys needed for authentication.
52
    identities: Arc<RelayIdentities>,
53
    /// The peer we are attempting to connect to.
54
    target_method: ChannelMethod,
55
    /// Our advertised addresses. Needed for the NETINFO.
56
    my_addrs: Vec<IpAddr>,
57
}
58

            
59
/// Implement the base channel handshake trait.
60
impl<T, S> ChannelBaseHandshake<T> for RelayInitiatorHandshake<T, S>
61
where
62
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
63
    S: CoarseTimeProvider + SleepProvider,
64
{
65
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
66
        &mut self.framed_tls
67
    }
68
    fn unique_id(&self) -> &UniqId {
69
        &self.unique_id
70
    }
71
}
72

            
73
/// Implement the initiator channel handshake trait.
74
impl<T, S> ChannelInitiatorHandshake<T> for RelayInitiatorHandshake<T, S>
75
where
76
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
77
    S: CoarseTimeProvider + SleepProvider,
78
{
79
}
80

            
81
impl<
82
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
83
    S: CoarseTimeProvider + SleepProvider,
84
> RelayInitiatorHandshake<T, S>
85
{
86
    /// Constructor.
87
    pub(crate) fn new(
88
        tls: T,
89
        sleep_prov: S,
90
        identities: Arc<RelayIdentities>,
91
        my_addrs: Vec<IpAddr>,
92
        peer: &OwnedChanTarget,
93
        memquota: ChannelAccount,
94
    ) -> Self {
95
        Self {
96
            framed_tls: new_frame(tls, ChannelType::RelayInitiator),
97
            unique_id: UniqId::new(),
98
            sleep_prov,
99
            identities,
100
            memquota,
101
            my_addrs,
102
            target_method: peer.chan_method(),
103
        }
104
    }
105

            
106
    /// Connect to another relay as the relay Initiator.
107
    ///
108
    /// Takes a function that reports the current time.  In theory, this can just be
109
    /// `SystemTime::now()`.
110
    pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedInitiatorRelayChannel<T, S>>
111
    where
112
        F: FnOnce() -> SystemTime,
113
    {
114
        // Send the VERSIONS.
115
        let (versions_flushed_at, versions_flushed_wallclock) =
116
            self.send_versions_cell(now_fn).await?;
117

            
118
        // Receive the VERSIONS.
119
        let link_protocol = self.recv_versions_cell().await?;
120

            
121
        // Read until we have all the remaining cells from the responder.
122
        let (auth_challenge_cell, certs_cell, (netinfo_cell, netinfo_rcvd_at)) =
123
            self.recv_cells_from_responder().await?;
124

            
125
        trace!(stream_id = %self.unique_id,
126
            "received handshake, ready to verify.",
127
        );
128

            
129
        // Calculate our clock skew from the timings we just got/calculated.
130
        let clock_skew = unauthenticated_clock_skew(
131
            &netinfo_cell,
132
            netinfo_rcvd_at,
133
            versions_flushed_at,
134
            versions_flushed_wallclock,
135
        );
136

            
137
        Ok(UnverifiedInitiatorRelayChannel {
138
            inner: UnverifiedChannel {
139
                link_protocol,
140
                framed_tls: self.framed_tls,
141
                clock_skew,
142
                memquota: self.memquota,
143
                target_method: Some(self.target_method),
144
                unique_id: self.unique_id,
145
                sleep_prov: self.sleep_prov.clone(),
146
                certs_cell: Some(certs_cell),
147
            },
148
            auth_challenge_cell,
149
            netinfo_cell,
150
            identities: self.identities,
151
            my_addrs: self.my_addrs,
152
        })
153
    }
154
}
155

            
156
/// A relay channel handshake as the responder.
157
pub struct RelayResponderHandshake<
158
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
159
    S: CoarseTimeProvider + SleepProvider,
160
> {
161
    /// Runtime handle (insofar as we need it)
162
    sleep_prov: S,
163
    /// Memory quota account
164
    memquota: ChannelAccount,
165
    /// Underlying TLS stream in a channel frame.
166
    ///
167
    /// (We don't enforce that this is actually TLS, but if it isn't, the
168
    /// connection won't be secure.)
169
    framed_tls: ChannelFrame<T>,
170
    /// The peer IP address as in the address the initiator is connecting from. This can be a
171
    /// client so keep it sensitive.
172
    peer_addr: Sensitive<PeerAddr>,
173
    /// Our advertised addresses. Needed for the NETINFO.
174
    my_addrs: Vec<IpAddr>,
175
    /// Logging identifier for this stream.  (Used for logging only.)
176
    unique_id: UniqId,
177
    /// Our identity keys needed for authentication.
178
    identities: Arc<RelayIdentities>,
179
}
180

            
181
/// Implement the base channel handshake trait.
182
impl<T, S> ChannelBaseHandshake<T> for RelayResponderHandshake<T, S>
183
where
184
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
185
    S: CoarseTimeProvider + SleepProvider,
186
{
187
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
188
        &mut self.framed_tls
189
    }
190
    fn unique_id(&self) -> &UniqId {
191
        &self.unique_id
192
    }
193
}
194

            
195
impl<
196
    T: AsyncRead + AsyncWrite + CertifiedConn + StreamOps + Send + Unpin + 'static,
197
    S: CoarseTimeProvider + SleepProvider,
198
> RelayResponderHandshake<T, S>
199
{
200
    /// Constructor.
201
    pub(crate) fn new(
202
        peer_addr: Sensitive<PeerAddr>,
203
        my_addrs: Vec<IpAddr>,
204
        tls: T,
205
        sleep_prov: S,
206
        identities: Arc<RelayIdentities>,
207
        memquota: ChannelAccount,
208
    ) -> Self {
209
        Self {
210
            peer_addr,
211
            my_addrs,
212
            framed_tls: new_frame(
213
                tls,
214
                ChannelType::RelayResponder {
215
                    authenticated: false,
216
                },
217
            ),
218
            unique_id: UniqId::new(),
219
            sleep_prov,
220
            identities,
221
            memquota,
222
        }
223
    }
224

            
225
    /// Begin the handshake process.
226
    ///
227
    /// Takes a function that reports the current time.  In theory, this can just be
228
    /// `SystemTime::now()`.
229
    pub async fn handshake<F>(
230
        mut self,
231
        now_fn: F,
232
    ) -> Result<MaybeVerifiableRelayResponderChannel<T, S>>
233
    where
234
        F: FnOnce() -> SystemTime,
235
    {
236
        // Receive initiator VERSIONS.
237
        let link_protocol = self.recv_versions_cell().await?;
238

            
239
        // Send VERSION, CERTS, AUTH_CHALLENGE and NETINFO
240
        let (versions_flushed_at, versions_flushed_wallclock) =
241
            self.send_cells_to_initiator(now_fn).await?;
242

            
243
        // Receive NETINFO and possibly [CERTS, AUTHENTICATE]. The connection could be from a
244
        // client/bridge and thus no authentication meaning no CERTS/AUTHENTICATE cells.
245
        let (cells, (netinfo_cell, netinfo_rcvd_at)) = self.recv_cells_from_initiator().await?;
246
        let (certs_cell, auth_cell) = cells.unzip();
247

            
248
        // Calculate our clock skew from the timings we just got/calculated.
249
        let clock_skew = unauthenticated_clock_skew(
250
            &netinfo_cell,
251
            netinfo_rcvd_at,
252
            versions_flushed_at,
253
            versions_flushed_wallclock,
254
        );
255

            
256
        let inner = UnverifiedChannel {
257
            link_protocol,
258
            framed_tls: self.framed_tls,
259
            clock_skew,
260
            memquota: self.memquota,
261
            target_method: None,
262
            unique_id: self.unique_id,
263
            sleep_prov: self.sleep_prov,
264
            certs_cell,
265
        };
266

            
267
        // With an AUTHENTICATE cell, we can verify (relay). Else (client/bridge), we can't.
268
        Ok(match auth_cell {
269
            Some(auth_cell) => {
270
                MaybeVerifiableRelayResponderChannel::Verifiable(UnverifiedResponderRelayChannel {
271
                    inner,
272
                    auth_cell,
273
                    netinfo_cell,
274
                    identities: self.identities,
275
                    my_addrs: self.my_addrs,
276
                    peer_addr: self.peer_addr.into_inner(), // Relay address.
277
                })
278
            }
279
            None => MaybeVerifiableRelayResponderChannel::NonVerifiable(
280
                NonVerifiableResponderRelayChannel {
281
                    inner,
282
                    netinfo_cell,
283
                    my_addrs: self.my_addrs,
284
                    peer_addr: self.peer_addr,
285
                },
286
            ),
287
        })
288
    }
289

            
290
    /// Receive all the cells expected from the initiator of the connection. Keep in mind that it
291
    /// can be either a relay or client or bridge.
292
    async fn recv_cells_from_initiator(
293
        &mut self,
294
    ) -> Result<(
295
        Option<(msg::Certs, msg::Authenticate)>,
296
        (msg::Netinfo, coarsetime::Instant),
297
    )> {
298
        // IMPORTANT: Protocol wise, we MUST only allow one single cell of each type for a valid
299
        // handshake. Any duplicates lead to a failure.
300
        // They must arrive in a specific order in order for the CLOG calculation to be valid.
301

            
302
        /// Read a message from the stream.
303
        ///
304
        /// The `expecting` parameter is used for logging purposes, not filtering.
305
        async fn read_msg<T>(
306
            stream_id: UniqId,
307
            mut stream: impl Stream<Item = Result<AnyChanCell>> + Unpin,
308
        ) -> Result<T>
309
        where
310
            T: RestrictedMsg + TryFrom<AnyChanMsg, Error = AnyChanMsg>,
311
        {
312
            let Some(cell) = stream.next().await.transpose()? else {
313
                // The entire channel has ended, so nothing else to be done.
314
                return Err(Error::HandshakeProto("Stream ended unexpectedly".into()));
315
            };
316

            
317
            let (id, m) = cell.into_circid_and_msg();
318
            trace!(%stream_id, "received a {} cell", m.cmd());
319

            
320
            // TODO: Maybe also check this in the channel handshake codec?
321
            if let Some(id) = id {
322
                return Err(Error::HandshakeProto(format!(
323
                    "Expected no circ ID for {} cell, but received circ ID of {id} instead",
324
                    m.cmd(),
325
                )));
326
            }
327

            
328
            let m = m.try_into().map_err(|m: AnyChanMsg| {
329
                Error::HandshakeProto(format!(
330
                    "Expected [{}] cell, but received {} cell instead",
331
                    tor_basic_utils::iter_join(", ", T::cmds_for_logging().iter()),
332
                    m.cmd(),
333
                ))
334
            })?;
335

            
336
            Ok(m)
337
        }
338

            
339
        // Note that the `ChannelFrame` already restricts the messages due to its handshake cell
340
        // handler.
341

            
342
        // This is kind of ugly, but I don't see a nicer way to write the authentication branch
343
        // without a bunch of boilerplate for a state machine.
344
        let (certs_and_auth, netinfo, netinfo_rcvd_at) = 'outer: {
345
            // CERTS or NETINFO cell.
346
            let certs = loop {
347
                restricted_msg! {
348
                    enum CertsNetinfoMsg : ChanMsg {
349
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
350
                        Vpadding,
351
                        Netinfo,
352
                        Certs,
353
                   }
354
                }
355

            
356
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
357
                    CertsNetinfoMsg::Vpadding(_) => continue,
358
                    // If a NETINFO cell, the initiator did not authenticate and we can stop early.
359
                    CertsNetinfoMsg::Netinfo(msg) => {
360
                        break 'outer (None, msg, coarsetime::Instant::now());
361
                    }
362
                    // If a CERTS cell, the initiator is authenticating.
363
                    CertsNetinfoMsg::Certs(msg) => msg,
364
                };
365
            };
366

            
367
            // AUTHENTICATE cell.
368
            let auth = loop {
369
                restricted_msg! {
370
                    enum AuthenticateMsg : ChanMsg {
371
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
372
                        Vpadding,
373
                        Authenticate,
374
                   }
375
                }
376

            
377
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
378
                    AuthenticateMsg::Vpadding(_) => continue,
379
                    AuthenticateMsg::Authenticate(msg) => msg,
380
                };
381
            };
382

            
383
            // NETINFO cell (if we didn't receive it earlier).
384
            let (netinfo, netinfo_rcvd_at) = loop {
385
                restricted_msg! {
386
                    enum NetinfoMsg : ChanMsg {
387
                        // VPADDING cells (but not PADDING) can be sent during handshaking.
388
                        Vpadding,
389
                        Netinfo,
390
                   }
391
                }
392

            
393
                break match read_msg(*self.unique_id(), self.framed_tls()).await? {
394
                    NetinfoMsg::Vpadding(_) => continue,
395
                    NetinfoMsg::Netinfo(msg) => (msg, coarsetime::Instant::now()),
396
                };
397
            };
398

            
399
            (Some((certs, auth)), netinfo, netinfo_rcvd_at)
400
        };
401

            
402
        Ok((certs_and_auth, (netinfo, netinfo_rcvd_at)))
403
    }
404

            
405
    /// Send all expected cells to the initiator of the channel as the responder.
406
    ///
407
    /// Return the sending times of the [`msg::Versions`] so it can be used for clock skew
408
    /// validation.
409
    async fn send_cells_to_initiator<F>(
410
        &mut self,
411
        now_fn: F,
412
    ) -> Result<(coarsetime::Instant, SystemTime)>
413
    where
414
        F: FnOnce() -> SystemTime,
415
    {
416
        // Send the VERSIONS message.
417
        let (versions_flushed_at, versions_flushed_wallclock) =
418
            self.send_versions_cell(now_fn).await?;
419

            
420
        // Send the CERTS message.
421
        let certs = build_certs_cell(&self.identities, /* is_responder */ true);
422
        trace!(channel_id = %self.unique_id, "Sending CERTS as responder cell.");
423
        self.framed_tls.send(certs.into()).await?;
424

            
425
        // Send the AUTH_CHALLENGE.
426
        let challenge: [u8; 32] = rand::rng().random();
427
        let auth_challenge = msg::AuthChallenge::new(challenge, [AUTHTYPE_ED25519_SHA256_RFC5705]);
428
        trace!(channel_id = %self.unique_id, "Sending AUTH_CHALLENGE as responder cell.");
429
        self.framed_tls.send(auth_challenge.into()).await?;
430

            
431
        // Send the NETINFO message.
432
        let peer_ip = self.peer_addr.netinfo_addr();
433
        let netinfo = build_netinfo_cell(peer_ip, self.my_addrs.clone(), &self.sleep_prov)?;
434
        trace!(channel_id = %self.unique_id, "Sending NETINFO as responder cell.");
435
        self.framed_tls.send(netinfo.into()).await?;
436

            
437
        Ok((versions_flushed_at, versions_flushed_wallclock))
438
    }
439
}