From ae3e50fe9501cc76d56bba9bcc4f6f878a4a28da Mon Sep 17 00:00:00 2001 From: Jukka Taimisto Date: Wed, 20 Dec 2023 15:21:07 +0200 Subject: [PATCH] Terminate the packet writer early if stop signal is received If SIGTERM or SIGINT is received, we should not drain the channel before stopping. --- src/channel.rs | 11 +++++++++-- src/main.rs | 18 ++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 84a3f9d..769ec60 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -2,6 +2,7 @@ use std::{ fmt::Display, sync::{ + atomic::AtomicBool, mpsc::{self, Receiver, SendError, Sender}, Arc, Condvar, Mutex, }, @@ -45,6 +46,7 @@ pub struct Rx { recv: Receiver, ctx: Arc<(Mutex, Condvar)>, watermark_lo: u64, + stop: Arc, } /// Iterator for reading packets. @@ -57,6 +59,9 @@ impl Iterator for IntoRxIter { type Item = Packet; fn next(&mut self) -> Option { + if self.rx.stop.load(std::sync::atomic::Ordering::Relaxed) { + return None; + } let (mux, cvar) = &*self.rx.ctx; let packet = self.rx.recv.recv().ok(); if packet.is_some() { @@ -126,12 +131,13 @@ impl Tx { } /// Creates a channel, returning [Tx] and [Rx] for a channel that allows -/// `hi` number of packets to be queued. +/// `hi` number of packets to be queued. `stop` can be used to signal that +/// [Rx] should terminate immediately instead of draining the buffer. /// /// When hi number of packets are queued, the [Tx::write_packet()] will /// block until packets are consumed from channel and only `lo` number of /// packets are left. -pub fn create(hi: u64, lo: u64) -> (Tx, Rx) { +pub fn create(hi: u64, lo: u64, stop: Arc) -> (Tx, Rx) { let (sender, recv) = mpsc::channel(); let ctx = Arc::new(( Mutex::new(ChannelContext { @@ -151,6 +157,7 @@ pub fn create(hi: u64, lo: u64) -> (Tx, Rx) { recv, ctx: ctx2, watermark_lo: lo, + stop, }, ) } diff --git a/src/main.rs b/src/main.rs index e3a85d4..0da4a70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -73,18 +73,19 @@ fn input_task( terminate: Arc, limit: Option, ) -> i32 { + let stop = terminate.clone(); let rd_handle: thread::JoinHandle> = thread::Builder::new() .name("pcap-reader".to_string()) .spawn(move || { loop { let inp = method.to_pcap_input()?; let it = match limit { - Some(n) => Box::new(inp.packets(&terminate)?.take(n)) + Some(n) => Box::new(inp.packets(&stop)?.take(n)) as Box>, - None => Box::new(inp.packets(&terminate)?), + None => Box::new(inp.packets(&stop)?), }; pipe::read_packets_to(it, &tx)?; - if !loop_file || terminate.load(std::sync::atomic::Ordering::Relaxed) { + if !loop_file || stop.load(std::sync::atomic::Ordering::Relaxed) { break; } tracing::info!("pcap file iteration complete"); @@ -94,8 +95,13 @@ fn input_task( .unwrap(); let mut ret = 0; if let Err(err) = rd_handle.join().unwrap() { - tracing::error!("Error while reading packets: {}", err); - ret = -1; + // if we have received signal indicating we should stop, discard + // reader errors as the packet writer might have terminated + // already and reader just complains about closed channel. + if !terminate.load(std::sync::atomic::Ordering::Relaxed) { + tracing::error!("Error while reading packets: {}", err); + ret = -1; + } } tracing::trace!("Reader terminated"); match pipe.wait() { @@ -238,7 +244,7 @@ fn main() { rate = Rate::Full; } - let (tx, rx) = channel::create(ch_hi, ch_low); + let (tx, rx) = channel::create(ch_hi, ch_low, terminate.clone()); let stat_period = params.stats.map(Duration::from_secs); let (stats, stat_printer) = if let Some(period) = stat_period { let (s, r) = pipe::Stats::periodic(period);