1
//! Infrastructure required to support managed pluggable transports (PTs).
2

            
3
use crate::config::{ManagedTransportOptions, TransportOptions};
4
use crate::err;
5
use crate::err::PtError;
6
use crate::ipc::{
7
    PluggableClientTransport, PluggableTransport, PtClientParameters, PtCommonParameters,
8
    sealed::PluggableTransportPrivate,
9
};
10
use crate::{PtClientMethod, PtSharedState};
11
use futures::channel::mpsc::UnboundedReceiver;
12
use futures::stream::FuturesUnordered;
13
use futures::{FutureExt, StreamExt, select};
14
use oneshot_fused_workaround as oneshot;
15
use std::collections::{HashMap, HashSet};
16
use std::future::Future;
17
use std::path::{Path, PathBuf};
18
use std::pin::Pin;
19
use std::sync::{Arc, RwLock};
20
use tor_config_path::CfgPathResolver;
21
use tor_error::internal;
22
use tor_linkspec::PtTransportName;
23
use tor_rtcompat::Runtime;
24
use tracing::{debug, warn};
25

            
26
/// A message to the `PtReactor`.
27
pub(crate) enum PtReactorMessage {
28
    /// Notify the reactor that the currently configured set of PTs has changed.
29
    Reconfigured,
30
    /// Ask the reactor to spawn a pluggable transport binary.
31
    Spawn {
32
        /// Spawn a binary to provide this PT.
33
        pt: PtTransportName,
34
        /// Notify the result via this channel.
35
        result: oneshot::Sender<err::Result<PtClientMethod>>,
36
    },
37
}
38

            
39
/// The result of a spawn attempt: the list of transports the spawned binary covers, and the result.
40
type SpawnResult = (Vec<PtTransportName>, err::Result<PluggableClientTransport>);
41

            
42
/// Background reactor to handle managing pluggable transport binaries.
43
pub(crate) struct PtReactor<R> {
44
    /// Runtime.
45
    rt: R,
46
    /// Currently running pluggable transport binaries.
47
    running: Vec<PluggableClientTransport>,
48
    /// A map of asked-for transports.
49
    ///
50
    /// If a transport name has an entry, we will append any additional requests for that entry.
51
    /// If no entry is present, we will start a request.
52
    requests: HashMap<PtTransportName, Vec<oneshot::Sender<err::Result<PtClientMethod>>>>,
53
    /// FuturesUnordered that spawned tasks get pushed on to.
54
    ///
55
    /// WARNING: This MUST always contain one "will never resolve" future!
56
    spawning: FuturesUnordered<Pin<Box<dyn Future<Output = SpawnResult> + Send>>>,
57
    /// State for the corresponding PtMgr.
58
    state: Arc<RwLock<PtSharedState>>,
59
    /// PtMgr channel.
60
    /// (Unbounded so that we can reconfigure without blocking: we're unlikely to have the reactor
61
    /// get behind.)
62
    rx: UnboundedReceiver<PtReactorMessage>,
63
    /// State directory.
64
    state_dir: PathBuf,
65
    /// Path resolver for configuration files.
66
    path_resolver: Arc<CfgPathResolver>,
67
}
68

            
69
impl<R: Runtime> PtReactor<R> {
70
    /// Make a new reactor.
71
22
    pub(crate) fn new(
72
22
        rt: R,
73
22
        state: Arc<RwLock<PtSharedState>>,
74
22
        rx: UnboundedReceiver<PtReactorMessage>,
75
22
        state_dir: PathBuf,
76
22
        path_resolver: Arc<CfgPathResolver>,
77
22
    ) -> Self {
78
22
        let spawning = FuturesUnordered::new();
79
22
        spawning.push(Box::pin(futures::future::pending::<SpawnResult>())
80
22
            as Pin<Box<dyn Future<Output = _> + Send>>);
81
22
        Self {
82
22
            rt,
83
22
            running: vec![],
84
22
            requests: Default::default(),
85
22
            spawning,
86
22
            state,
87
22
            rx,
88
22
            state_dir,
89
22
            path_resolver,
90
22
        }
91
22
    }
92

            
93
    /// Called when a spawn request completes.
94
    #[allow(clippy::needless_pass_by_value)]
95
    fn handle_spawned(
96
        &mut self,
97
        covers: Vec<PtTransportName>,
98
        result: err::Result<PluggableClientTransport>,
99
    ) {
100
        match result {
101
            Err(e) => {
102
                warn!("Spawning PT for {:?} failed: {}", covers, e);
103
                // Go and tell all the transports about the bad news.
104
                let senders = covers
105
                    .iter()
106
                    .flat_map(|x| self.requests.remove(x))
107
                    .flatten();
108
                for sender in senders {
109
                    // We don't really care if the sender went away.
110
                    let _ = sender.send(Err(e.clone()));
111
                }
112
            }
113
            Ok(pt) => {
114
                let mut state = self.state.write().expect("ptmgr state poisoned");
115
                for (transport, method) in pt.transport_methods() {
116
                    state
117
                        .managed_cmethods
118
                        .insert(transport.clone(), method.clone());
119
                    for sender in self.requests.remove(transport).into_iter().flatten() {
120
                        let _ = sender.send(Ok(method.clone()));
121
                    }
122
                }
123

            
124
                let requested: HashSet<_> = covers.iter().collect();
125
                let found: HashSet<_> = pt.transport_methods().keys().collect();
126
                if requested != found {
127
                    warn!(
128
                        "Bug: PT {} succeeded, but did not give the same transports we asked for. ({:?} vs {:?})",
129
                        pt.identifier(),
130
                        found,
131
                        requested
132
                    );
133
                }
134
                self.running.push(pt);
135
            }
136
        }
137
    }
138

            
139
    /// Called to remove a pluggable transport from the shared state.
140
    fn remove_pt(&self, pt: PluggableClientTransport) {
141
        let mut state = self.state.write().expect("ptmgr state poisoned");
142
        for transport in pt.transport_methods().keys() {
143
            state.managed_cmethods.remove(transport);
144
        }
145
        // to satisfy clippy, and make it clear that this is a desired side-effect: doing this
146
        // shuts down the PT (asynchronously).
147
        drop(pt);
148
    }
149

            
150
    /// Run one step of the reactor. Returns true if the reactor should terminate.
151
26
    pub(crate) async fn run_one_step(&mut self) -> err::Result<bool> {
152
        use futures::future::Either;
153

            
154
        // FIXME(eta): This allocates a lot, which is technically unnecessary but requires careful
155
        //             engineering to get right. It's not really in the hot path, at least.
156
26
        let mut all_next_messages = self
157
26
            .running
158
26
            .iter_mut()
159
            // We could avoid the Box, but that'd require using unsafe to replicate what tokio::pin!
160
            // does under the hood.
161
26
            .map(|pt| Box::pin(pt.next_message()))
162
26
            .collect::<Vec<_>>();
163

            
164
        // We can't construct a select_all if all_next_messages is empty.
165
26
        let mut next_message = if all_next_messages.is_empty() {
166
26
            Either::Left(futures::future::pending())
167
        } else {
168
            Either::Right(futures::future::select_all(all_next_messages.iter_mut()).fuse())
169
        };
170

            
171
26
        select! {
172
            (result, idx, _) = next_message => {
173
                drop(all_next_messages); // no idea why NLL doesn't just infer this but sure
174

            
175
                match result {
176
                    Ok(m) => {
177
                        // FIXME(eta): We should forward the Status messages onto API consumers.
178
                        debug!("PT {} message: {:?}", self.running[idx].identifier(), m);
179
                    },
180
                    Err(e) => {
181
                        warn!("PT {} quit: {:?}", self.running[idx].identifier(), e);
182
                        let pt = self.running.remove(idx);
183
                        self.remove_pt(pt);
184
                    }
185
                }
186
            },
187
26
            spawn_result = self.spawning.next() => {
188
                drop(all_next_messages);
189
                // See the Warning in this field's documentation.
190
                let (covers, result) = spawn_result.expect("self.spawning should never dry up");
191
                self.handle_spawned(covers, result);
192
            }
193
26
            internal = self.rx.next() => {
194
4
                drop(all_next_messages);
195

            
196
4
                match internal {
197
4
                    Some(PtReactorMessage::Reconfigured) => {},
198
                    Some(PtReactorMessage::Spawn { pt, result }) => {
199
                        // Make sure we don't already have a running request.
200
                        if let Some(requests) = self.requests.get_mut(&pt) {
201
                            requests.push(result);
202
                            return Ok(false);
203
                        }
204
                        // Make sure we don't already have a binary for this PT.
205
                        for rpt in self.running.iter() {
206
                            if let Some(cmethod) = rpt.transport_methods().get(&pt) {
207
                                let _ = result.send(Ok(cmethod.clone()));
208
                                return Ok(false);
209
                            }
210
                        }
211
                        // We don't, so time to spawn one.
212
                        let config = {
213
                            let state = self.state.read().expect("ptmgr state poisoned");
214
                            state.configured.get(&pt).cloned()
215
                        };
216

            
217
                        let Some(config) = config else {
218
                            let _ = result.send(Err(PtError::UnconfiguredTransportDueToConcurrentReconfiguration));
219
                            return Ok(false);
220
                        };
221

            
222
                        let TransportOptions::Managed(config) = config else {
223
                            let _ = result.send(Err(internal!("Tried to spawn an unmanaged transport").into()));
224
                            return Ok(false);
225
                        };
226

            
227
                        // Keep track of the request, and also fill holes in other protocols so
228
                        // we don't try and run another spawn request for those.
229
                        self.requests.entry(pt).or_default().push(result);
230
                        for proto in config.protocols.iter() {
231
                            self.requests.entry(proto.clone()).or_default();
232
                        }
233

            
234
                        // Add the spawn future to our pile of them.
235
                        let spawn_fut = Box::pin(
236
                            spawn_from_config(
237
                                self.rt.clone(),
238
                                self.state_dir.clone(),
239
                                config.clone(),
240
                                Arc::clone(&self.path_resolver)
241
                            )
242
                            .map(|result| (config.protocols, result))
243
                        );
244
                        self.spawning.push(spawn_fut);
245
                    },
246
                    None => return Ok(true)
247
                }
248
            }
249
        }
250
4
        Ok(false)
251
4
    }
252
}
253

            
254
/// Spawn a managed `PluggableTransport` using a `ManagedTransportOptions`.
255
async fn spawn_from_config<R: Runtime>(
256
    rt: R,
257
    state_dir: PathBuf,
258
    cfg: ManagedTransportOptions,
259
    path_resolver: Arc<CfgPathResolver>,
260
) -> Result<PluggableClientTransport, PtError> {
261
    // FIXME(eta): I really think this expansion should happen at builder validation time...
262

            
263
    let cfg_path = cfg.path;
264

            
265
    let binary_path = cfg_path
266
        .path(&path_resolver)
267
        .map_err(|e| PtError::PathExpansionFailed {
268
            path: cfg_path.clone(),
269
            error: e,
270
        })?;
271

            
272
    let filename = pt_identifier_as_path(&binary_path)?;
273

            
274
    // HACK(eta): Currently the state directory is named after the PT binary name. Maybe we should
275
    //            invent a better way of doing this?
276
    let new_state_dir = state_dir.join(filename);
277
    std::fs::create_dir_all(&new_state_dir).map_err(|e| PtError::StatedirCreateFailed {
278
        path: new_state_dir.clone(),
279
        error: Arc::new(e),
280
    })?;
281

            
282
    // FIXME(eta): make the rest of these parameters configurable
283
    let pt_common_params = PtCommonParameters::builder()
284
        .state_location(new_state_dir)
285
        .build()
286
        .expect("PtCommonParameters constructed incorrectly");
287

            
288
    let pt_client_params = PtClientParameters::builder()
289
        .transports(cfg.protocols)
290
        .build()
291
        .expect("PtClientParameters constructed incorrectly");
292

            
293
    let mut pt = PluggableClientTransport::new(
294
        binary_path,
295
        cfg.arguments,
296
        pt_common_params,
297
        pt_client_params,
298
    );
299
    pt.launch(rt).await?;
300
    Ok(pt)
301
}
302

            
303
/// Given a path to a binary for a pluggable transport, return an identifier for
304
/// that binary in a format that can be used as a path component.
305
fn pt_identifier_as_path(binary_path: impl AsRef<Path>) -> Result<PathBuf, PtError> {
306
    // Extract the final component.
307
    let mut filename =
308
        PathBuf::from(
309
            binary_path
310
                .as_ref()
311
                .file_name()
312
                .ok_or_else(|| PtError::NotAFile {
313
                    path: binary_path.as_ref().to_path_buf(),
314
                })?,
315
        );
316

            
317
    // Strip an "exe" off the end, if appropriate.
318
    if let Some(ext) = filename.extension() {
319
        if ext.eq_ignore_ascii_case(std::env::consts::EXE_EXTENSION) {
320
            filename.set_extension("");
321
        }
322
    }
323

            
324
    Ok(filename)
325
}
326

            
327
/// Given a path to a binary for a pluggable transport, return an identifier for
328
/// that binary in human-readable form.
329
pub(crate) fn pt_identifier(binary_path: impl AsRef<Path>) -> Result<String, PtError> {
330
    Ok(pt_identifier_as_path(binary_path)?
331
        .to_string_lossy()
332
        .to_string())
333
}