Skip to content

Commit

Permalink
feat: close tunnel server if remote borken
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma233 committed Jul 16, 2024
1 parent 2d83a38 commit 83e5103
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub enum TunnelError {
/// use elsewhere.
AddrInUse = 8,
/// Tunnel already listened.
TunnelInUse = 9,
AlreadyListened = 9,
/// Unknown [std::io::ErrorKind] error.
Unknown = u8::MAX,
}
Expand Down
49 changes: 46 additions & 3 deletions src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ pub struct TunnelServer {
peer_id: PeerId,
next_tunnel_id: Arc<AtomicUsize>,
pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>,
listener_cancel_token: Option<CancellationToken>,
listener: Option<tokio::task::JoinHandle<()>>,
}

pub struct TunnelServerListener {
peer_id: PeerId,
next_tunnel_id: Arc<AtomicUsize>,
pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>,
tunnels: HashMap<TunnelId, Tunnel>,
cancel_token: CancellationToken,
}

pub struct Tunnel {
Expand All @@ -54,6 +57,23 @@ pub struct TunnelListener {
cancel_token: CancellationToken,
}

impl Drop for TunnelServer {
fn drop(&mut self) {
if let Some(cancel_token) = self.listener_cancel_token.take() {
cancel_token.cancel();
}

if let Some(listener) = self.listener.take() {
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
listener.abort();
});
}

tracing::info!("TunnelServer {} dropped", self.peer_id);
}
}

impl Drop for Tunnel {
fn drop(&mut self) {
if let Some(cancel_token) = self.listener_cancel_token.take() {
Expand Down Expand Up @@ -81,10 +101,16 @@ impl TunnelServer {
peer_id,
next_tunnel_id,
pproxy_command_tx,
listener: None,
listener_cancel_token: None,
}
}

pub async fn listen(&mut self, address: SocketAddr) -> Result<SocketAddr, TunnelError> {
if self.listener.is_some() {
return Err(TunnelError::AlreadyListened);
}

let tcp_listener = TcpListener::bind(address).await?;
let local_addr = tcp_listener.local_addr()?;

Expand All @@ -93,7 +119,12 @@ impl TunnelServer {
self.next_tunnel_id.clone(),
self.pproxy_command_tx.clone(),
);
tokio::spawn(Box::pin(async move { listener.listen(tcp_listener).await }));
let listener_cancel_token = listener.cancel_token();
let listener_handler =
tokio::spawn(Box::pin(async move { listener.listen(tcp_listener).await }));

self.listener = Some(listener_handler);
self.listener_cancel_token = Some(listener_cancel_token);

Ok(local_addr)
}
Expand All @@ -110,18 +141,28 @@ impl TunnelServerListener {
next_tunnel_id,
pproxy_command_tx,
tunnels: HashMap::new(),
cancel_token: CancellationToken::new(),
}
}

fn next_tunnel_id(&mut self) -> TunnelId {
TunnelId::from(self.next_tunnel_id.fetch_add(1usize, Ordering::Relaxed))
}

fn cancel_token(&self) -> CancellationToken {
self.cancel_token.clone()
}

async fn listen(&mut self, listener: TcpListener) {
loop {
let Ok((stream, _)) = listener.accept().await else {
if self.cancel_token.is_cancelled() {
break;
}

let Ok((stream, address)) = listener.accept().await else {
continue;
};
tracing::debug!("Received new connection from: {address}");

let tunnel_id = self.next_tunnel_id();
let mut tunnel = Tunnel::new(self.peer_id, tunnel_id, self.pproxy_command_tx.clone());
Expand Down Expand Up @@ -187,7 +228,7 @@ impl Tunnel {
remote_stream_rx: mpsc::Receiver<Vec<u8>>,
) -> Result<(), TunnelError> {
if self.listener.is_some() {
return Err(TunnelError::TunnelInUse);
return Err(TunnelError::AlreadyListened);
}

let mut listener = TunnelListener::new(
Expand Down Expand Up @@ -248,6 +289,7 @@ impl TunnelListener {
break TunnelError::ConnectionClosed;
}
Ok(n) => {
tracing::debug!("Received {} bytes from local stream", n);
let (tx, rx) = oneshot::channel();
let data = buf[..n].to_vec();
let command = PProxyCommand::SendOutboundPackageCommand {
Expand Down Expand Up @@ -283,6 +325,7 @@ impl TunnelListener {
}

if let Some(body) = self.remote_stream_rx.recv().await {
tracing::debug!("Received {} bytes from local stream", body.len());
if let Err(e) = local_write.write_all(&body).await {
tracing::error!("Write to local stream failed: {e:?}");
break e.kind().into();
Expand Down

0 comments on commit 83e5103

Please sign in to comment.