1
//! Implementations for the client channel handshake
2

            
3
use futures::SinkExt;
4
use futures::io::{AsyncRead, AsyncWrite};
5
use std::sync::Arc;
6
use std::time::SystemTime;
7
use tracing::{debug, instrument, trace};
8

            
9
use safelog::{MaybeSensitive, Sensitive};
10
use tor_cell::chancell::msg;
11
use tor_linkspec::{ChannelMethod, OwnedChanTarget};
12
use tor_rtcompat::{CoarseTimeProvider, SleepProvider, StreamOps};
13

            
14
use crate::ClockSkew;
15
use crate::Result;
16
use crate::channel::handshake::{
17
    ChannelBaseHandshake, ChannelInitiatorHandshake, UnverifiedChannel, VerifiedChannel,
18
    unauthenticated_clock_skew,
19
};
20
use crate::channel::{Channel, ChannelFrame, ChannelType, Reactor, UniqId, new_frame};
21
use crate::memquota::ChannelAccount;
22
use crate::peer::{PeerAddr, PeerInfo};
23

            
24
/// A raw client channel on which nothing has been done.
25
pub struct ClientInitiatorHandshake<
26
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
27
    S: CoarseTimeProvider + SleepProvider,
28
> {
29
    /// Runtime handle (insofar as we need it)
30
    sleep_prov: S,
31

            
32
    /// Memory quota account
33
    memquota: ChannelAccount,
34

            
35
    /// Cell encoder/decoder wrapping the underlying TLS stream
36
    ///
37
    /// (We don't enforce that this is actually TLS, but if it isn't, the
38
    /// connection won't be secure.)
39
    framed_tls: ChannelFrame<T>,
40

            
41
    /// Declared target method for this channel, if any.
42
    target_method: Option<ChannelMethod>,
43

            
44
    /// Logging identifier for this stream.  (Used for logging only.)
45
    unique_id: UniqId,
46
}
47

            
48
/// Implement the base channel handshake trait.
49
impl<T, S> ChannelBaseHandshake<T> for ClientInitiatorHandshake<T, S>
50
where
51
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
52
    S: CoarseTimeProvider + SleepProvider,
53
{
54
120
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
55
120
        &mut self.framed_tls
56
120
    }
57
134
    fn unique_id(&self) -> &UniqId {
58
134
        &self.unique_id
59
134
    }
60
}
61

            
62
/// Implement the initiator channel handshake trait.
63
impl<T, S> ChannelInitiatorHandshake<T> for ClientInitiatorHandshake<T, S>
64
where
65
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
66
    S: CoarseTimeProvider + SleepProvider,
67
{
68
}
69

            
70
impl<
71
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
72
    S: CoarseTimeProvider + SleepProvider,
73
> ClientInitiatorHandshake<T, S>
74
{
75
    /// Construct a new ClientInitiatorHandshake.
76
24
    pub(crate) fn new(
77
24
        tls: T,
78
24
        target_method: Option<ChannelMethod>,
79
24
        sleep_prov: S,
80
24
        memquota: ChannelAccount,
81
24
    ) -> Self {
82
24
        Self {
83
24
            framed_tls: new_frame(tls, ChannelType::ClientInitiator),
84
24
            target_method,
85
24
            unique_id: UniqId::new(),
86
24
            sleep_prov,
87
24
            memquota,
88
24
        }
89
24
    }
90

            
91
    /// Negotiate a link protocol version with the relay, and read
92
    /// the relay's handshake information.
93
    ///
94
    /// Takes a function that reports the current time.  In theory, this can just be
95
    /// `SystemTime::now()`.
96
    #[instrument(skip_all, level = "trace")]
97
24
    pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedClientChannel<T, S>>
98
24
    where
99
24
        F: FnOnce() -> SystemTime,
100
24
    {
101
        match &self.target_method {
102
            Some(method) => debug!(
103
                stream_id = %self.unique_id,
104
                "starting Tor handshake with {:?}",
105
                method
106
            ),
107
            None => debug!(stream_id = %self.unique_id, "starting Tor handshake"),
108
        }
109
        // Send versions cell.
110
        let (versions_flushed_at, versions_flushed_wallclock) =
111
            self.send_versions_cell(now_fn).await?;
112

            
113
        // Receive versions cell.
114
        let link_protocol = self.recv_versions_cell().await?;
115

            
116
        // Receive the relay responder cells. Ignore the AUTH_CHALLENGE cell, we don't need it as
117
        // we are not authenticating with our responder because we are a client.
118
        let (_, certs_cell, (netinfo_cell, netinfo_rcvd_at)) =
119
            self.recv_cells_from_responder().await?;
120

            
121
        // Get the clock skew.
122
        let clock_skew = unauthenticated_clock_skew(
123
            &netinfo_cell,
124
            netinfo_rcvd_at,
125
            versions_flushed_at,
126
            versions_flushed_wallclock,
127
        );
128

            
129
        trace!(stream_id = %self.unique_id, "received handshake, ready to verify.");
130

            
131
        Ok(UnverifiedClientChannel {
132
            inner: UnverifiedChannel {
133
                link_protocol,
134
                framed_tls: self.framed_tls,
135
                certs_cell: Some(certs_cell),
136
                clock_skew,
137
                target_method: self.target_method.take(),
138
                unique_id: self.unique_id,
139
                sleep_prov: self.sleep_prov.clone(),
140
                memquota: self.memquota.clone(),
141
            },
142
            netinfo_cell,
143
        })
144
24
    }
145
}
146

            
147
/// A client channel on which versions have been negotiated and the relay's handshake has been
148
/// read, but where the certs have not been checked.
149
pub struct UnverifiedClientChannel<
150
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
151
    S: CoarseTimeProvider + SleepProvider,
152
> {
153
    /// Inner generic unverified channel.
154
    inner: UnverifiedChannel<T, S>,
155
    /// Received [`msg::Netinfo`] cell during the handshake.
156
    netinfo_cell: msg::Netinfo,
157
}
158

            
159
impl<
160
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
161
    S: CoarseTimeProvider + SleepProvider,
162
> UnverifiedClientChannel<T, S>
163
{
164
    /// Validate the certificates and keys in the relay's handshake. As a client, we always verify
165
    /// but we don't authenticate.
166
    ///
167
    /// 'peer' is the peer that we want to make sure we're connecting to.
168
    ///
169
    /// 'peer_cert' is the x.509 certificate that the peer presented during
170
    /// its TLS handshake (ServerHello).
171
    ///
172
    /// 'now' is the time at which to check that certificates are
173
    /// valid.  `None` means to use the current time. It can be used
174
    /// for testing to override the current view of the time.
175
    ///
176
    /// This is a separate function because it's likely to be somewhat
177
    /// CPU-intensive.
178
    #[instrument(skip_all, level = "trace")]
179
2
    pub fn verify(
180
2
        self,
181
2
        peer: &OwnedChanTarget,
182
2
        peer_cert: &[u8],
183
2
        now: Option<std::time::SystemTime>,
184
2
    ) -> Result<VerifiedClientChannel<T, S>> {
185
2
        let inner = self.inner.check(peer, peer_cert, now)?;
186
2
        Ok(VerifiedClientChannel {
187
2
            inner,
188
2
            netinfo_cell: self.netinfo_cell,
189
2
        })
190
2
    }
191

            
192
    /// Return the clock skew of this channel.
193
8
    pub fn clock_skew(&self) -> ClockSkew {
194
8
        self.inner.clock_skew
195
8
    }
196

            
197
    /// Return the link protocol version of this channel.
198
    #[cfg(test)]
199
2
    pub(crate) fn link_protocol(&self) -> u16 {
200
2
        self.inner.link_protocol
201
2
    }
202
}
203

            
204
/// A client channel on which versions have been negotiated, relay's handshake has been read, but
205
/// the client has not yet finished the handshake.
206
///
207
/// This type is separate from UnverifiedClientChannel, since finishing the handshake requires a
208
/// bunch of CPU, and you might want to do it as a separate task or after a yield.
209
pub struct VerifiedClientChannel<
210
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
211
    S: CoarseTimeProvider + SleepProvider,
212
> {
213
    /// Inner generic verified channel.
214
    inner: VerifiedChannel<T, S>,
215
    /// Received [`msg::Netinfo`] cell during the handshake.
216
    netinfo_cell: msg::Netinfo,
217
}
218

            
219
impl<
220
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
221
    S: CoarseTimeProvider + SleepProvider,
222
> VerifiedClientChannel<T, S>
223
{
224
    /// Send a NETINFO message to the relay to finish the handshake, and create an open channel and
225
    /// reactor.
226
    ///
227
    /// The `peer_addr` is sensitive because it can be a secret bridge or guard.
228
    ///
229
    /// The channel is used to send cells, and to create outgoing circuits. The reactor is used to
230
    /// route incoming messages to their appropriate circuit.
231
    #[instrument(skip_all, level = "trace")]
232
2
    pub async fn finish(
233
2
        mut self,
234
2
        peer_addr: Sensitive<PeerAddr>,
235
2
    ) -> Result<(Arc<Channel>, Reactor<S>)> {
236
        // Send the NETINFO message.
237
        let netinfo = msg::Netinfo::from_client(peer_addr.netinfo_addr());
238
        trace!(stream_id = %self.inner.unique_id, "Sending netinfo cell.");
239
        self.inner.framed_tls.send(netinfo.into()).await?;
240

            
241
        // This could be a client Guard so it is sensitive.
242
        let peer_info = MaybeSensitive::sensitive(PeerInfo::new(
243
            peer_addr.into_inner(),
244
            self.inner.relay_ids()?,
245
        ));
246

            
247
        // Finish the channel to get a reactor.
248
        self.inner.finish(&self.netinfo_cell, &[], peer_info).await
249
2
    }
250
}