1
//! Implementations for the client channel handshake
2

            
3
use digest::Digest;
4
use futures::SinkExt;
5
use futures::io::{AsyncRead, AsyncWrite};
6
use std::sync::Arc;
7
use std::time::SystemTime;
8
use tracing::{debug, instrument, trace};
9

            
10
use safelog::MaybeSensitive;
11
use tor_cell::chancell::msg;
12
use tor_linkspec::{ChannelMethod, OwnedChanTarget};
13
use tor_rtcompat::{CoarseTimeProvider, SleepProvider, StreamOps};
14

            
15
use crate::ClockSkew;
16
use crate::Result;
17
use crate::channel::handshake::{
18
    AuthLogAction, ChannelBaseHandshake, ChannelInitiatorHandshake, UnverifiedChannel,
19
    UnverifiedInitiatorChannel, VerifiedChannel, unauthenticated_clock_skew,
20
};
21
use crate::channel::{Channel, ChannelFrame, ChannelType, Reactor, UniqId, new_frame};
22
use crate::memquota::ChannelAccount;
23
use crate::peer::{PeerAddr, PeerInfo};
24

            
25
/// A raw client channel on which nothing has been done.
26
pub struct ClientInitiatorHandshake<
27
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
28
    S: CoarseTimeProvider + SleepProvider,
29
> {
30
    /// Runtime handle (insofar as we need it)
31
    sleep_prov: S,
32

            
33
    /// Memory quota account
34
    memquota: ChannelAccount,
35

            
36
    /// Cell encoder/decoder wrapping the underlying TLS stream
37
    ///
38
    /// (We don't enforce that this is actually TLS, but if it isn't, the
39
    /// connection won't be secure.)
40
    framed_tls: ChannelFrame<T>,
41

            
42
    /// Declared target method for this channel, if any.
43
    target_method: Option<ChannelMethod>,
44

            
45
    /// Logging identifier for this stream.  (Used for logging only.)
46
    unique_id: UniqId,
47
}
48

            
49
/// Implement the base channel handshake trait.
50
impl<T, S> ChannelBaseHandshake<T> for ClientInitiatorHandshake<T, S>
51
where
52
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
53
    S: CoarseTimeProvider + SleepProvider,
54
{
55
120
    fn framed_tls(&mut self) -> &mut ChannelFrame<T> {
56
120
        &mut self.framed_tls
57
120
    }
58
134
    fn unique_id(&self) -> &UniqId {
59
134
        &self.unique_id
60
134
    }
61
}
62

            
63
/// Implement the initiator channel handshake trait.
64
impl<T, S> ChannelInitiatorHandshake<T> for ClientInitiatorHandshake<T, S>
65
where
66
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
67
    S: CoarseTimeProvider + SleepProvider,
68
{
69
}
70

            
71
impl<
72
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
73
    S: CoarseTimeProvider + SleepProvider,
74
> ClientInitiatorHandshake<T, S>
75
{
76
    /// Construct a new ClientInitiatorHandshake.
77
24
    pub(crate) fn new(
78
24
        tls: T,
79
24
        target_method: Option<ChannelMethod>,
80
24
        sleep_prov: S,
81
24
        memquota: ChannelAccount,
82
24
    ) -> Self {
83
24
        Self {
84
24
            framed_tls: new_frame(tls, ChannelType::ClientInitiator),
85
24
            target_method,
86
24
            unique_id: UniqId::new(),
87
24
            sleep_prov,
88
24
            memquota,
89
24
        }
90
24
    }
91

            
92
    /// Negotiate a link protocol version with the relay, and read
93
    /// the relay's handshake information.
94
    ///
95
    /// Takes a function that reports the current time.  In theory, this can just be
96
    /// `SystemTime::get()`.
97
    #[instrument(skip_all, level = "trace")]
98
24
    pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedClientChannel<T, S>>
99
24
    where
100
24
        F: FnOnce() -> SystemTime,
101
24
    {
102
        match &self.target_method {
103
            Some(method) => debug!(
104
                stream_id = %self.unique_id,
105
                "starting Tor handshake with {:?}",
106
                method
107
            ),
108
            None => debug!(stream_id = %self.unique_id, "starting Tor handshake"),
109
        }
110
        // Send versions cell.
111
        let (versions_flushed_at, versions_flushed_wallclock) =
112
            self.send_versions_cell(now_fn).await?;
113

            
114
        // Receive versions cell.
115
        let link_protocol = self.recv_versions_cell().await?;
116

            
117
        // VERSIONS cell have been exchanged, set the link protocol into our channel frame.
118
        self.set_link_protocol(link_protocol)?;
119

            
120
        // Receive the relay responder cells. Ignore the AUTH_CHALLENGE cell and SLOG; we don't need
121
        // them as we are not authenticating with our responder because we are a client.
122
        let (_auth_chal_cell, certs_cell, (netinfo_cell, netinfo_rcvd_at), _slog) =
123
            self.recv_cells_from_responder(AuthLogAction::Leave).await?;
124

            
125
        // Get the clock skew.
126
        let clock_skew = unauthenticated_clock_skew(
127
            &netinfo_cell,
128
            netinfo_rcvd_at,
129
            versions_flushed_at,
130
            versions_flushed_wallclock,
131
        );
132

            
133
        trace!(stream_id = %self.unique_id, "received handshake, ready to verify.");
134

            
135
        Ok(UnverifiedClientChannel {
136
            inner: UnverifiedInitiatorChannel {
137
                inner: UnverifiedChannel {
138
                    link_protocol,
139
                    framed_tls: self.framed_tls,
140
                    clock_skew,
141
                    target_method: self.target_method.take(),
142
                    unique_id: self.unique_id,
143
                    sleep_prov: self.sleep_prov.clone(),
144
                    memquota: self.memquota.clone(),
145
                },
146
                certs_cell,
147
            },
148
            netinfo_cell,
149
        })
150
24
    }
151
}
152

            
153
/// A client channel on which versions have been negotiated and the relay's handshake has been
154
/// read, but where the certs have not been checked.
155
pub struct UnverifiedClientChannel<
156
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
157
    S: CoarseTimeProvider + SleepProvider,
158
> {
159
    /// Inner generic unverified initiator channel.
160
    inner: UnverifiedInitiatorChannel<T, S>,
161
    /// Received [`msg::Netinfo`] cell during the handshake.
162
    netinfo_cell: msg::Netinfo,
163
}
164

            
165
impl<
166
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
167
    S: CoarseTimeProvider + SleepProvider,
168
> UnverifiedClientChannel<T, S>
169
{
170
    /// Validate the certificates and keys in the relay's handshake. As a client, we always verify
171
    /// but we don't authenticate.
172
    ///
173
    /// 'peer_target' is the peer that we want to make sure we're connecting to.
174
    ///
175
    /// 'peer_tls_cert' is the x.509 certificate that the peer presented during
176
    /// its TLS handshake (ServerHello).
177
    ///
178
    /// 'now' is the time at which to check that certificates are
179
    /// valid.  `None` means to use the current time. It can be used
180
    /// for testing to override the current view of the time.
181
    ///
182
    /// This is a separate function because it's likely to be somewhat
183
    /// CPU-intensive.
184
    #[instrument(skip_all, level = "trace")]
185
2
    pub fn verify(
186
2
        self,
187
2
        peer_target: &OwnedChanTarget,
188
2
        peer_tls_cert: &[u8],
189
2
        now: Option<std::time::SystemTime>,
190
2
    ) -> Result<VerifiedClientChannel<T, S>> {
191
2
        let peer_cert_digest = tor_llcrypto::d::Sha256::digest(peer_tls_cert).into();
192
2
        let inner = self.inner.verify(peer_target, peer_cert_digest, now)?;
193

            
194
2
        Ok(VerifiedClientChannel {
195
2
            inner,
196
2
            netinfo_cell: self.netinfo_cell,
197
2
        })
198
2
    }
199

            
200
    /// Return the clock skew of this channel.
201
8
    pub fn clock_skew(&self) -> ClockSkew {
202
8
        self.inner.inner.clock_skew
203
8
    }
204

            
205
    /// Return the link protocol version of this channel.
206
    #[cfg(test)]
207
2
    pub(crate) fn link_protocol(&self) -> u16 {
208
2
        self.inner.inner.link_protocol
209
2
    }
210
}
211

            
212
/// A client channel on which versions have been negotiated, relay's handshake has been read, but
213
/// the client has not yet finished the handshake.
214
///
215
/// This type is separate from UnverifiedClientChannel, since finishing the handshake requires a
216
/// bunch of CPU, and you might want to do it as a separate task or after a yield.
217
pub struct VerifiedClientChannel<
218
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
219
    S: CoarseTimeProvider + SleepProvider,
220
> {
221
    /// Inner generic verified channel.
222
    inner: VerifiedChannel<T, S>,
223
    /// Received [`msg::Netinfo`] cell during the handshake.
224
    netinfo_cell: msg::Netinfo,
225
}
226

            
227
impl<
228
    T: AsyncRead + AsyncWrite + StreamOps + Send + Unpin + 'static,
229
    S: CoarseTimeProvider + SleepProvider,
230
> VerifiedClientChannel<T, S>
231
{
232
    /// Send a NETINFO message to the relay to finish the handshake, and create an open channel and
233
    /// reactor.
234
    ///
235
    /// The `peer_addr` is sensitive because it can be a secret bridge or guard.
236
    ///
237
    /// The channel is used to send cells, and to create outgoing circuits. The reactor is used to
238
    /// route incoming messages to their appropriate circuit.
239
    #[instrument(skip_all, level = "trace")]
240
2
    pub async fn finish(
241
2
        mut self,
242
2
        peer_addr: MaybeSensitive<PeerAddr>,
243
2
    ) -> Result<(Arc<Channel>, Reactor<S>)> {
244
        // Send the NETINFO message.
245
        let netinfo = msg::Netinfo::from_client(peer_addr.netinfo_addr());
246
        trace!(stream_id = %self.inner.unique_id, "Sending netinfo cell.");
247
        self.inner.framed_tls.send(netinfo.into()).await?;
248

            
249
        // This could be a client Guard so it is sensitive.
250
        let peer_info = MaybeSensitive::sensitive(PeerInfo::new(
251
            peer_addr.inner(),
252
            self.inner.relay_ids().clone(),
253
        ));
254

            
255
        // Finish the channel to get a reactor.
256
        self.inner.finish(&self.netinfo_cell, &[], peer_info).await
257
2
    }
258
}