diff --git a/src/error.rs b/src/error.rs index 8bc16aa..b18b291 100644 --- a/src/error.rs +++ b/src/error.rs @@ -56,7 +56,7 @@ pub enum TunnelError { /// use elsewhere. AddrInUse = 8, /// Tunnel already listened. - TunnelInUse = 9, + AlreadyListened = 9, /// Unknown [std::io::ErrorKind] error. Unknown = u8::MAX, } diff --git a/src/tunnel.rs b/src/tunnel.rs index 85b1309..44f676c 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -28,6 +28,8 @@ pub struct TunnelServer { peer_id: PeerId, next_tunnel_id: Arc, pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + listener_cancel_token: Option, + listener: Option>, } pub struct TunnelServerListener { @@ -35,6 +37,7 @@ pub struct TunnelServerListener { next_tunnel_id: Arc, pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, tunnels: HashMap, + cancel_token: CancellationToken, } pub struct Tunnel { @@ -54,6 +57,23 @@ pub struct TunnelListener { cancel_token: CancellationToken, } +impl Drop for TunnelServer { + fn drop(&mut self) { + if let Some(cancel_token) = self.listener_cancel_token.take() { + cancel_token.cancel(); + } + + if let Some(listener) = self.listener.take() { + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + listener.abort(); + }); + } + + tracing::info!("TunnelServer {} dropped", self.peer_id); + } +} + impl Drop for Tunnel { fn drop(&mut self) { if let Some(cancel_token) = self.listener_cancel_token.take() { @@ -81,10 +101,16 @@ impl TunnelServer { peer_id, next_tunnel_id, pproxy_command_tx, + listener: None, + listener_cancel_token: None, } } pub async fn listen(&mut self, address: SocketAddr) -> Result { + if self.listener.is_some() { + return Err(TunnelError::AlreadyListened); + } + let tcp_listener = TcpListener::bind(address).await?; let local_addr = tcp_listener.local_addr()?; @@ -93,7 +119,12 @@ impl TunnelServer { self.next_tunnel_id.clone(), self.pproxy_command_tx.clone(), ); - tokio::spawn(Box::pin(async move { listener.listen(tcp_listener).await })); + let listener_cancel_token = listener.cancel_token(); + let listener_handler = + tokio::spawn(Box::pin(async move { listener.listen(tcp_listener).await })); + + self.listener = Some(listener_handler); + self.listener_cancel_token = Some(listener_cancel_token); Ok(local_addr) } @@ -110,6 +141,7 @@ impl TunnelServerListener { next_tunnel_id, pproxy_command_tx, tunnels: HashMap::new(), + cancel_token: CancellationToken::new(), } } @@ -117,11 +149,20 @@ impl TunnelServerListener { TunnelId::from(self.next_tunnel_id.fetch_add(1usize, Ordering::Relaxed)) } + fn cancel_token(&self) -> CancellationToken { + self.cancel_token.clone() + } + async fn listen(&mut self, listener: TcpListener) { loop { - let Ok((stream, _)) = listener.accept().await else { + if self.cancel_token.is_cancelled() { + break; + } + + let Ok((stream, address)) = listener.accept().await else { continue; }; + tracing::debug!("Received new connection from: {address}"); let tunnel_id = self.next_tunnel_id(); let mut tunnel = Tunnel::new(self.peer_id, tunnel_id, self.pproxy_command_tx.clone()); @@ -187,7 +228,7 @@ impl Tunnel { remote_stream_rx: mpsc::Receiver>, ) -> Result<(), TunnelError> { if self.listener.is_some() { - return Err(TunnelError::TunnelInUse); + return Err(TunnelError::AlreadyListened); } let mut listener = TunnelListener::new( @@ -248,6 +289,7 @@ impl TunnelListener { break TunnelError::ConnectionClosed; } Ok(n) => { + tracing::debug!("Received {} bytes from local stream", n); let (tx, rx) = oneshot::channel(); let data = buf[..n].to_vec(); let command = PProxyCommand::SendOutboundPackageCommand { @@ -283,6 +325,7 @@ impl TunnelListener { } if let Some(body) = self.remote_stream_rx.recv().await { + tracing::debug!("Received {} bytes from local stream", body.len()); if let Err(e) = local_write.write_all(&body).await { tracing::error!("Write to local stream failed: {e:?}"); break e.kind().into();