simplify server shutdown
mhils committed Dec 19, 2023
1 parent bc040d8 commit 80c7ac1
Showing 4 changed files with 61 additions and 147 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.5.1
+
+- Make server shutdown more robust.
+
 ## 0.5.0
 
 - UDP connections are now modeled as streams.
81 changes: 24 additions & 57 deletions mitmproxy-rs/src/server/base.rs
@@ -3,48 +3,31 @@ use crate::task::PyInteropTask
 use anyhow::Result;
 
 use mitmproxy::packet_sources::{PacketSourceConf, PacketSourceTask};
-use mitmproxy::shutdown::ShutdownTask;
+use mitmproxy::shutdown::shutdown_task;
 use pyo3::prelude::*;
-#[cfg(target_os = "macos")]
-use std::path::Path;
 
+use tokio::task::JoinSet;
 use tokio::{sync::broadcast, sync::mpsc};
 
 #[derive(Debug)]
 pub struct Server {
-    /// channel for notifying subtasks of requested server shutdown
-    sd_trigger: broadcast::Sender<()>,
-    /// channel for getting notified of successful server shutdown
-    sd_barrier: broadcast::Sender<()>,
-    /// flag to indicate whether server shutdown is in progress
-    closing: bool,
+    shutdown_done: broadcast::Receiver<()>,
+    start_shutdown: Option<broadcast::Sender<()>>,
 }
 
 impl Server {
     pub fn close(&mut self) {
-        if !self.closing {
-            self.closing = true;
-            // XXX: Does not really belong here.
-            #[cfg(target_os = "macos")]
-            {
-                if Path::new("/Applications/MitmproxyAppleTunnel.app").exists() {
-                    std::fs::remove_dir_all("/Applications/MitmproxyAppleTunnel.app").expect(
-                        "Failed to remove MitmproxyAppleTunnel.app from Applications folder",
-                    );
-                }
-            }
+        if let Some(trigger) = self.start_shutdown.take() {
             log::info!("Shutting down.");
-            // notify tasks to shut down
-            let _ = self.sd_trigger.send(());
+            trigger.send(()).ok();
         }
     }
 
     pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
-        let mut barrier = self.sd_barrier.subscribe();
+        let mut receiver = self.shutdown_done.resubscribe();
        pyo3_asyncio::tokio::future_into_py(py, async move {
-            barrier.recv().await.map_err(|_| {
-                pyo3::exceptions::PyRuntimeError::new_err("Failed to wait for server shutdown.")
-            })
+            receiver.recv().await.ok();
+            Ok(())
        })
    }
 }
@@ -62,61 +45,45 @@ impl Server {
         let typ = packet_source_conf.name();
         log::debug!("Initializing {} ...", typ);
 
-        // initialize channels between the virtual network device and the python interop task
-        // - only used to notify of incoming connections and datagrams
+        // Channel used to notify Python land of incoming connections.
         let (transport_events_tx, transport_events_rx) = mpsc::channel(256);
-        // - used to send data and to ask for packets
-        // This channel needs to be unbounded because write() is not async.
+        // Channel used to send data and ask for packets.
+        // This needs to be unbounded because write() is not async.
         let (transport_commands_tx, transport_commands_rx) = mpsc::unbounded_channel();
 
-        // initialize barriers for handling graceful shutdown
-        let shutdown = broadcast::channel(1).0;
-        let shutdown_done = broadcast::channel(1).0;
+        // Channel used to trigger graceful shutdown
+        let (shutdown_start_tx, shutdown_start_rx) = broadcast::channel(1);
 
         let (packet_source_task, data) = packet_source_conf
             .build(
                 transport_events_tx,
                 transport_commands_rx,
-                shutdown.subscribe(),
+                shutdown_start_rx.resubscribe(),
             )
             .await?;
 
-        // initialize Python interop task
-        // Note: The current asyncio event loop needs to be determined here on the main thread.
-        let py_loop: PyObject = Python::with_gil(|py| {
-            let py_loop = pyo3_asyncio::tokio::get_current_loop(py)?.into_py(py);
-            Ok::<PyObject, PyErr>(py_loop)
-        })?;
-
         let py_task = PyInteropTask::new(
-            py_loop,
             transport_commands_tx,
             transport_events_rx,
             py_tcp_handler,
             py_udp_handler,
-            shutdown.subscribe(),
-        );
+            shutdown_start_rx,
+        )?;
 
         // spawn tasks
-        let wg_handle = tokio::spawn(async move { packet_source_task.run().await });
-        let py_handle = tokio::spawn(async move { py_task.run().await });
+        let mut tasks = JoinSet::new();
+        tasks.spawn(async move { packet_source_task.run().await });
+        tasks.spawn(async move { py_task.run().await });
 
-        // initialize and run shutdown handler
-        let sd_task = ShutdownTask::new(
-            py_handle,
-            wg_handle,
-            shutdown.clone(),
-            shutdown_done.clone(),
-        );
-        tokio::spawn(async move { sd_task.run().await });
+        let (shutdown_done_tx, shutdown_done_rx) = broadcast::channel(1);
+        tokio::spawn(shutdown_task(tasks, shutdown_done_tx));
 
         log::debug!("{} successfully initialized.", typ);
 
         Ok((
             Server {
-                sd_trigger: shutdown,
-                sd_barrier: shutdown_done,
-                closing: false,
+                shutdown_done: shutdown_done_rx,
+                start_shutdown: Some(shutdown_start_tx),
             },
             data,
         ))
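
The new design in base.rs reduces to a small, self-contained pattern: one broadcast channel starts shutdown, a JoinSet owns the subtasks, and a second broadcast channel reports completion. A minimal sketch of that wiring in plain tokio (worker, the channel names, and main are illustrative stand-ins, not the crate's API):

    use tokio::sync::broadcast;
    use tokio::task::JoinSet;

    // Stand-in for PacketSourceTask / PyInteropTask: run until told to stop.
    async fn worker(mut start_shutdown: broadcast::Receiver<()>) -> anyhow::Result<()> {
        let _ = start_shutdown.recv().await;
        Ok(())
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Channel used to trigger graceful shutdown (`Server.close` sends here).
        let (shutdown_start_tx, shutdown_start_rx) = broadcast::channel::<()>(1);
        // Channel used to signal completed shutdown (`Server.wait_closed` listens here).
        let (shutdown_done_tx, mut shutdown_done_rx) = broadcast::channel::<()>(1);

        let mut tasks = JoinSet::new();
        tasks.spawn(worker(shutdown_start_rx.resubscribe()));
        tasks.spawn(worker(shutdown_start_rx));

        // Supervisor: reap every subtask, then notify waiters (cf. shutdown_task below).
        tokio::spawn(async move {
            while tasks.join_next().await.is_some() {}
            shutdown_done_tx.send(()).ok();
        });

        shutdown_start_tx.send(()).ok(); // Server.close()
        shutdown_done_rx.recv().await.ok(); // Server.wait_closed()
        Ok(())
    }

Storing the trigger as Option<Sender> lets close() simply take() it, so a second call becomes a no-op without the old closing flag.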
30 changes: 16 additions & 14 deletions mitmproxy-rs/src/task.rs
@@ -1,7 +1,7 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
-use anyhow::Result;
+use anyhow::{Context, Result};
 use pyo3::prelude::*;
 use pyo3_asyncio::TaskLocals;
 use tokio::sync::{broadcast, mpsc, Mutex};
@@ -12,7 +12,7 @@ use crate::stream::Stream;
 use crate::stream::StreamState;
 
 pub struct PyInteropTask {
-    py_loop: PyObject,
+    locals: TaskLocals,
     transport_commands: mpsc::UnboundedSender<TransportCommand>,
     transport_events: mpsc::Receiver<TransportEvent>,
     py_tcp_handler: PyObject,
@@ -23,30 +23,32 @@ pub struct PyInteropTask
 impl PyInteropTask {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
-        py_loop: PyObject,
         transport_commands: mpsc::UnboundedSender<TransportCommand>,
         transport_events: mpsc::Receiver<TransportEvent>,
         py_tcp_handler: PyObject,
         py_udp_handler: PyObject,
-        sd_watcher: broadcast::Receiver<()>,
-    ) -> Self {
-        PyInteropTask {
-            py_loop,
+        shutdown: broadcast::Receiver<()>,
+    ) -> Result<Self> {
+        // Note: The current asyncio event loop needs to be determined here on the main thread.
+        let locals = Python::with_gil(|py| -> Result<TaskLocals, PyErr> {
+            let py_loop = pyo3_asyncio::tokio::get_current_loop(py)?.into_py(py);
+            TaskLocals::new(py_loop.as_ref(py)).copy_context(py)
+        })
+        .context("failed to get python task locals")?;
+
+        Ok(PyInteropTask {
+            locals,
             transport_commands,
             transport_events,
             py_tcp_handler,
             py_udp_handler,
-            shutdown: sd_watcher,
-        }
+            shutdown,
+        })
     }
 
     pub async fn run(mut self) -> Result<()> {
         let active_streams = Arc::new(Mutex::new(HashMap::new()));
 
-        let locals = Python::with_gil(|py| {
-            TaskLocals::new(self.py_loop.as_ref(py)).copy_context(self.py_loop.as_ref(py).py())
-        })?;
-
         loop {
             tokio::select! {
                 // wait for graceful shutdown
@@ -90,7 +92,7 @@ impl PyInteropTask {
                 };
 
                 // convert Python awaitable into Rust Future
-                let future = pyo3_asyncio::into_future_with_locals(&locals, coro.as_ref(py))?;
+                let future = pyo3_asyncio::into_future_with_locals(&self.locals, coro.as_ref(py))?;
 
                 // run Future on a new Tokio task
                 let handle = {
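
The task.rs change moves event-loop capture out of run() and into the constructor, which still executes on the main thread while Python's asyncio loop is current; run() later executes on a Tokio worker thread, where querying the current loop would fail. A condensed sketch of the pattern, assuming a pyo3 0.19-era pyo3-asyncio (the Handler type and call method are invented for illustration; the pyo3_asyncio calls are the ones used in the diff):

    use anyhow::{Context, Result};
    use pyo3::prelude::*;
    use pyo3_asyncio::TaskLocals;

    struct Handler {
        locals: TaskLocals,
    }

    impl Handler {
        // Must be called on the main thread, while the asyncio loop is running.
        fn new() -> Result<Self> {
            let locals = Python::with_gil(|py| -> PyResult<TaskLocals> {
                let py_loop = pyo3_asyncio::tokio::get_current_loop(py)?;
                TaskLocals::new(py_loop).copy_context(py)
            })
            .context("failed to get python task locals")?;
            Ok(Handler { locals })
        }

        // Can now run on any Tokio worker thread: the stored locals carry the loop.
        async fn call(&self, coro: PyObject) -> Result<PyObject> {
            let future = Python::with_gil(|py| {
                pyo3_asyncio::into_future_with_locals(&self.locals, coro.as_ref(py))
            })?;
            Ok(future.await?)
        }
    }

As a bonus, the fallible part of construction now surfaces as Result<Self> at creation time instead of as a runtime error inside run().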
93 changes: 17 additions & 76 deletions src/shutdown.rs
@@ -1,84 +1,25 @@
-use std::sync::Arc;
-
 use anyhow::Result;
-use tokio::{
-    sync::{broadcast::Sender as BroadcastSender, RwLock},
-    task::JoinHandle,
-};
+use tokio::sync::broadcast;
+use tokio::task::JoinSet;
 
-pub struct ShutdownTask {
-    py_handle: JoinHandle<Result<()>>,
-    wg_handle: JoinHandle<Result<()>>,
-    sd_trigger: BroadcastSender<()>,
-    sd_barrier: BroadcastSender<()>,
-}
-
-impl ShutdownTask {
-    pub fn new(
-        py_handle: JoinHandle<Result<()>>,
-        wg_handle: JoinHandle<Result<()>>,
-        sd_trigger: BroadcastSender<()>,
-        sd_barrier: BroadcastSender<()>,
-    ) -> Self {
-        ShutdownTask {
-            py_handle,
-            wg_handle,
-            sd_trigger,
-            sd_barrier,
-        }
-    }
-
-    pub async fn run(self) {
-        let mut sd_watcher = self.sd_trigger.subscribe();
-        let shutting_down = Arc::new(RwLock::new(false));
-
-        // wait for Python interop task to return
-        let py_sd_trigger = self.sd_trigger.clone();
-        let py_shutting_down = shutting_down.clone();
-        let py_task_handle = tokio::spawn(async move {
-            match self.py_handle.await {
-                Ok(Ok(())) => (),
-                Ok(Err(error)) => log::error!("Python interop task failed: {}", error),
-                Err(error) => log::error!("Python interop task panicked: {}", error),
-            }
-
-            if !*py_shutting_down.read().await {
-                log::error!("Python interop task shut down early, exiting.");
-                let _ = py_sd_trigger.send(());
-            }
-        });
-
-        // wait for WireGuard server task to return
-        let wg_sd_trigger = self.sd_trigger.clone();
-        let wg_shutting_down = shutting_down.clone();
-        let wg_task_handle = tokio::spawn(async move {
-            match self.wg_handle.await {
-                Ok(Ok(())) => (),
-                Ok(Err(error)) => log::error!("Proxy server task failed: {}", error),
-                Err(error) => log::error!("Proxy server task panicked: {}", error),
-            }
-
-            if !*wg_shutting_down.read().await {
-                log::error!("Proxy server task shut down early, exiting.");
-                let _ = wg_sd_trigger.send(());
-            }
-        });
-
-        // wait for shutdown trigger:
-        // - either `Server.stop` was called, or
-        // - one of the subtasks failed early
-        let _ = sd_watcher.recv().await;
-        *shutting_down.write().await = true;
-
-        // wait for all tasks to terminate and log any errors
-        if let Err(error) = py_task_handle.await {
-            log::error!("Shutdown of Python interop task failed: {}", error);
-        }
-        if let Err(error) = wg_task_handle.await {
-            log::error!("Shutdown of WireGuard server task failed: {}", error);
-        }
-
-        // make `Server.wait_closed` method yield
-        self.sd_barrier.send(()).ok();
-    }
-}
+pub async fn shutdown_task(mut tasks: JoinSet<Result<()>>, shutdown_done: broadcast::Sender<()>) {
+    while let Some(task) = tasks.join_next().await {
+        match task {
+            Ok(Ok(())) => (),
+            Ok(Err(error)) => {
+                log::error!("Task failed: {}\n{}", error, error.backtrace().to_string());
+                tasks.shutdown().await;
+            }
+            Err(error) => {
+                if error.is_cancelled() {
+                    log::error!("Task cancelled: {}", error);
+                } else {
+                    log::error!("Task panicked: {}", error);
+                }
+                tasks.shutdown().await;
+            }
+        }
+    }
+    shutdown_done.send(()).ok();
+}

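Compared with the old ShutdownTask, which supervised exactly two hard-coded handles, shutdown_task treats the JoinSet uniformly: the first task that fails or panics causes every remaining task to be aborted, and completion is only signalled once the set is drained. A hedged usage sketch (the supervisor is condensed from the committed function, with log::error! swapped for eprintln! so it runs without a logger; the two spawned tasks are invented for illustration):

    use anyhow::{anyhow, Result};
    use tokio::sync::broadcast;
    use tokio::task::JoinSet;

    // Condensed from the committed shutdown_task: reap tasks as they finish,
    // abort the rest on the first failure, then signal completion.
    async fn shutdown_task(mut tasks: JoinSet<Result<()>>, shutdown_done: broadcast::Sender<()>) {
        while let Some(task) = tasks.join_next().await {
            match task {
                Ok(Ok(())) => (),
                Ok(Err(error)) => {
                    eprintln!("Task failed: {error}");
                    tasks.shutdown().await; // abort all remaining tasks
                }
                Err(error) if error.is_cancelled() => (), // e.g. aborted elsewhere
                Err(error) => {
                    eprintln!("Task panicked: {error}");
                    tasks.shutdown().await;
                }
            }
        }
        shutdown_done.send(()).ok();
    }

    #[tokio::main]
    async fn main() {
        let (done_tx, mut done_rx) = broadcast::channel::<()>(1);
        let mut tasks: JoinSet<Result<()>> = JoinSet::new();
        tasks.spawn(async { Err(anyhow!("boom")) }); // fails immediately
        tasks.spawn(async {
            std::future::pending::<()>().await; // would run forever...
            Ok(()) // ...but is aborted by tasks.shutdown()
        });
        tokio::spawn(shutdown_task(tasks, done_tx));
        done_rx.recv().await.ok(); // returns once every task has ended
    }

Because JoinSet::shutdown aborts and then drains the set, join_next eventually returns None and the done signal fires even on the failure path, which is what makes Server.wait_closed reliable.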