1
//! Declare a type for streams that do hostname lookups
2

            
3
use crate::client::stream::StreamReceiver;
4
use crate::memquota::StreamAccount;
5
use crate::stream::cmdcheck::{AnyCmdChecker, CmdChecker, StreamStatus};
6
use crate::{Error, Result};
7

            
8
use futures::StreamExt;
9
use tor_cell::relaycell::RelayCmd;
10
use tor_cell::relaycell::msg::Resolved;
11
use tor_cell::restricted_msg;
12

            
13
/// A ResolveStream represents a pending DNS request made with a RESOLVE
14
/// cell.
15
pub struct ResolveStream {
16
    /// The underlying RawCellStream.
17
    s: StreamReceiver,
18

            
19
    /// The memory quota account that should be used for this "stream"'s data
20
    ///
21
    /// Exists to keep the account alive
22
    _memquota: StreamAccount,
23
}
24

            
25
restricted_msg! {
    /// The relay messages that are acceptable as a reply to a RESOLVE
    /// message: either an END or a RESOLVED.
    enum ResolveResponseMsg : RelayMsg {
        End,
        Resolved,
    }
}
32

            
33
impl ResolveStream {
34
    /// Wrap a RawCellStream into a ResolveStream.
35
    ///
36
    /// Call only after sending a RESOLVE cell.
37
    pub(crate) fn new(s: StreamReceiver, memquota: StreamAccount) -> Self {
38
        ResolveStream {
39
            s,
40
            _memquota: memquota,
41
        }
42
    }
43

            
44
    /// Read a message from this stream telling us the answer to our
45
    /// name lookup request.
46
    pub async fn read_msg(&mut self) -> Result<Resolved> {
47
        use ResolveResponseMsg::*;
48
        let cell = match self.s.next().await {
49
            Some(cell) => cell?,
50
            None => return Err(Error::NotConnected),
51
        };
52
        let msg = match cell.decode::<ResolveResponseMsg>() {
53
            Ok(cell) => cell.into_msg(),
54
            Err(e) => {
55
                self.s.protocol_error();
56
                return Err(Error::from_bytes_err(e, "response on a resolve stream"));
57
            }
58
        };
59
        match msg {
60
            End(e) => Err(Error::EndReceived(e.reason())),
61
            Resolved(r) => Ok(r),
62
        }
63
    }
64
}
65

            
66
/// The `CmdChecker` used to validate incoming commands on an outbound
/// resolve stream.
///
/// NOTE(prop349): this implements the "Resolve Stream Handler".
/// It is set via [crate::ClientTunnel::begin_stream_impl],
/// which installs the checker on the last hop in the circuit.
///
/// The checker is invoked via `CircHop::deliver_msg_to_stream`.
/// Any error it returns is propagated all the way up to
/// [`Circuit::handle_cell`](crate::client::reactor::circuit::Circuit),
/// and eventually ends up being returned from the reactor's `run_once`
/// function, causing the reactor to shut down.
///
/// A [`StreamStatus::Closed`] result is handled in the `CircHop`'s
/// stream map (by marking the stream as closed, or returning
/// a CircProto error, as appropriate).
#[derive(Debug, Default)]
pub(crate) struct ResolveCmdChecker {}
84

            
85
impl CmdChecker for ResolveCmdChecker {
86
    fn check_msg(&mut self, msg: &tor_cell::relaycell::UnparsedRelayMsg) -> Result<StreamStatus> {
87
        use StreamStatus::Closed;
88
        match msg.cmd() {
89
            RelayCmd::RESOLVED => Ok(Closed),
90
            RelayCmd::END => Ok(Closed),
91
            _ => Err(Error::StreamProto(format!(
92
                "Unexpected {} on resolve stream",
93
                msg.cmd()
94
            ))),
95
        }
96
    }
97

            
98
    fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
99
        let _ = msg
100
            .decode::<ResolveResponseMsg>()
101
            .map_err(|err| Error::from_bytes_err(err, "message on resolve stream."))?;
102
        Ok(())
103
    }
104
}
105

            
106
impl ResolveCmdChecker {
107
    /// Return a new boxed `DataCmdChecker` in a state suitable for a newly
108
    /// constructed connection.
109
    pub(crate) fn new_any() -> AnyCmdChecker {
110
        Box::<Self>::default()
111
    }
112
}