diff --git a/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs b/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs
index 4c09aa1ed3a..0b1f472120b 100644
--- a/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs
+++ b/implementations/rust/ockam/ockam_transport_core/src/hostname_port.rs
@@ -30,7 +30,7 @@ impl From for HostnamePort {
 }
 
 /// Hostname and port
-#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode, CborLen)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode, CborLen)]
 #[rustfmt::skip]
 pub struct HostnamePort {
     #[n(0)] hostname: String,
@@ -158,6 +158,10 @@ impl HostnamePort {
 
         Ok(HostnamePort::new(hostname, port))
     }
+
+    pub fn is_localhost(&self) -> bool {
+        self.hostname == "localhost" || self.hostname == "127.0.0.1"
+    }
 }
 
 impl From for HostnamePort {
diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs b/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs
index 12ea8729c50..b16c0cb4e5d 100644
--- a/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs
+++ b/implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs
@@ -1,4 +1,4 @@
-use crate::transport::connect;
+use crate::transport::{connect, CachedConnectionsQueue};
 use crate::workers::{Addresses, TcpRecvProcessor, TcpSendWorker};
 use crate::{TcpConnectionMode, TcpConnectionOptions, TcpTransport};
 use core::fmt;
@@ -10,6 +10,9 @@ use ockam_core::{Address, Result};
 use ockam_node::Context;
 use ockam_transport_core::HostnamePort;
 use std::net::SocketAddr;
+use std::sync::{Arc, Mutex as SyncMutex, Weak};
+use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
+use tokio::time::Instant;
 use tracing::debug;
 
 /// Result of [`TcpTransport::connect`] call.
@@ -84,6 +87,55 @@ impl TcpConnection {
     }
 }
 
+pub(crate) struct ConnectionCollector {
+    hostname_port: HostnamePort,
+    connections_queue: Weak<CachedConnectionsQueue>,
+    read_half: SyncMutex<Option<(Instant, OwnedReadHalf)>>,
+    write_half: SyncMutex<Option<OwnedWriteHalf>>,
+}
+
+impl ConnectionCollector {
+    fn new(hostname_port: HostnamePort, connections_queue: &Arc<CachedConnectionsQueue>) -> Self {
+        Self {
+            hostname_port,
+            connections_queue: Arc::downgrade(connections_queue),
+            read_half: SyncMutex::new(None),
+            write_half: SyncMutex::new(None),
+        }
+    }
+
+    pub(crate) fn collect_read_half(&self, last_known_reply: Instant, read_half: OwnedReadHalf) {
+        debug!("Collecting read half for {}", self.hostname_port);
+        self.read_half
+            .lock()
+            .unwrap()
+            .replace((last_known_reply, read_half));
+        self.check_and_push_connection();
+    }
+
+    pub(crate) fn collect_write_half(&self, write_half: OwnedWriteHalf) {
+        debug!("Collecting write half for {}", self.hostname_port);
+        self.write_half.lock().unwrap().replace(write_half);
+        self.check_and_push_connection();
+    }
+
+    fn check_and_push_connection(&self) {
+        let mut read_half = self.read_half.lock().unwrap();
+        let mut write_half = self.write_half.lock().unwrap();
+
+        if read_half.is_some() && write_half.is_some() {
+            if let Some(connections_queue) = self.connections_queue.upgrade() {
+                let (last_known_reply, read_half) = read_half.take().unwrap();
+                let write_half = write_half.take().unwrap();
+
+                let mut guard = connections_queue.lock().unwrap();
+                let connections = guard.entry(self.hostname_port.clone()).or_default();
+                connections.push_back((last_known_reply, read_half, write_half));
+            }
+        }
+    }
+}
+
 impl TcpTransport {
     /// Establish an outgoing TCP connection.
     ///
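Aside: the `ConnectionCollector` above is a rendezvous point. The send worker and the receive processor shut down independently, each handing back its half of the socket; only when both halves have arrived is the reunited pair pushed into the shared queue. A minimal, self-contained sketch of that pattern — generic `R`/`W` stand in for the tokio socket halves, and all names here are illustrative, not part of this patch:

```rust
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::{Arc, Mutex, Weak};

// A cache keyed by peer, holding recombined (read, write) pairs.
type Cache<K, R, W> = Mutex<HashMap<K, VecDeque<(R, W)>>>;

struct Collector<K, R, W> {
    key: K,
    // Weak: a stray collector must not keep the whole cache alive.
    cache: Weak<Cache<K, R, W>>,
    read: Mutex<Option<R>>,
    write: Mutex<Option<W>>,
}

impl<K: Clone + Eq + Hash, R, W> Collector<K, R, W> {
    fn collect_read(&self, read: R) {
        self.read.lock().unwrap().replace(read);
        self.try_push();
    }

    fn collect_write(&self, write: W) {
        self.write.lock().unwrap().replace(write);
        self.try_push();
    }

    // Whichever half arrives second pushes the reunited pair into the cache.
    fn try_push(&self) {
        let mut read = self.read.lock().unwrap();
        let mut write = self.write.lock().unwrap();
        if read.is_some() && write.is_some() {
            if let Some(cache) = self.cache.upgrade() {
                let pair = (read.take().unwrap(), write.take().unwrap());
                cache
                    .lock()
                    .unwrap()
                    .entry(self.key.clone())
                    .or_default()
                    .push_back(pair);
            }
        }
    }
}

fn main() {
    let cache: Arc<Cache<&str, &str, &str>> = Arc::new(Mutex::new(HashMap::new()));
    let collector = Collector {
        key: "peer:4000",
        cache: Arc::downgrade(&cache),
        read: Mutex::new(None),
        write: Mutex::new(None),
    };
    collector.collect_read("read half");
    collector.collect_write("write half");
    assert_eq!(cache.lock().unwrap()["peer:4000"].len(), 1);
}
```

If the `Weak` upgrade fails because the transport is gone, the halves are simply dropped, which closes the socket.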
@@ -103,9 +155,41 @@ impl TcpTransport {
         options: TcpConnectionOptions,
     ) -> Result<TcpConnection> {
         let peer = HostnamePort::from_str(&peer.into())?;
-        debug!("Connecting to {}", peer.clone());
-        let (read_half, write_half) = connect(&peer).await?;
+        let (last_known_reply, skip_initialization, read_half, write_half) = {
+            let connection = {
+                let mut guard = self.connections.lock().unwrap();
+                if let Some(connections) = guard.get_mut(&peer) {
+                    loop {
+                        if let Some((last_known_reply, read_half, write_half)) =
+                            connections.pop_front()
+                        {
+                            let elapsed = last_known_reply.elapsed();
+                            if elapsed.as_secs() < 2 {
+                                debug!(
+                                    "Reusing existing connection to {}, {}ms old",
+                                    peer.clone(),
+                                    elapsed.as_millis()
+                                );
+                                break Some((last_known_reply, true, read_half, write_half));
+                            }
+                        } else {
+                            break None;
+                        }
+                    }
+                } else {
+                    None
+                }
+            };
+
+            if let Some(read_write_half) = connection {
+                read_write_half
+            } else {
+                let (read_half, write_half) = connect(&peer).await?;
+                (Instant::now(), false, read_half, write_half)
+            }
+        };
+
         let socket = read_half
             .peer_addr()
             .map_err(|e| ockam_core::Error::new(Origin::Transport, Kind::Internal, e))?;
@@ -118,14 +202,24 @@
         let receiver_outgoing_access_control =
             options.create_receiver_outgoing_access_control(self.ctx.flow_controls());
 
+        let connection_collector = {
+            if peer.is_localhost() {
+                None
+            } else {
+                Some(Arc::new(ConnectionCollector::new(peer, &self.connections)))
+            }
+        };
+
         TcpSendWorker::start(
             &self.ctx,
             self.registry.clone(),
             write_half,
+            skip_initialization,
             &addresses,
             socket,
             mode,
             &flow_control_id,
+            connection_collector.clone(),
         )
         .await?;
 
@@ -133,11 +227,14 @@
             &self.ctx,
             self.registry.clone(),
             read_half,
+            skip_initialization,
+            last_known_reply,
             &addresses,
             socket,
             mode,
             &flow_control_id,
             receiver_outgoing_access_control,
+            connection_collector,
         )
         .await?;
diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs b/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs
index 9db988c2393..b91bfba2827 100644
--- a/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs
+++ b/implementations/rust/ockam/ockam_transport_tcp/src/transport/mod.rs
@@ -4,17 +4,23 @@ mod lifecycle;
 mod listener;
 mod portals;
 
-pub(crate) use common::*;
-
 pub use crate::portal::options::*;
+use crate::TcpRegistry;
+pub(crate) use common::*;
 pub use connection::*;
 pub use listener::*;
-pub use portals::*;
-
-use crate::TcpRegistry;
 use ockam_core::compat::sync::Arc;
 use ockam_core::{async_trait, Result};
 use ockam_node::{Context, HasContext};
+use ockam_transport_core::HostnamePort;
+pub use portals::*;
+use std::collections::{HashMap, VecDeque};
+use std::sync::Mutex as SyncMutex;
+use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
+use tokio::time::Instant;
+
+type CachedConnectionsQueue =
+    SyncMutex<HashMap<HostnamePort, VecDeque<(Instant, OwnedReadHalf, OwnedWriteHalf)>>>;
 
 /// High level management interface for TCP transports
 ///
@@ -58,6 +64,7 @@ use ockam_node::{Context, HasContext};
 pub struct TcpTransport {
     ctx: Arc<Context>,
     registry: TcpRegistry,
+    connections: Arc<CachedConnectionsQueue>,
 
     #[cfg(privileged_portals_support)]
     pub(crate) ebpf_support: Arc,
@@ -71,6 +78,7 @@ impl TcpTransport {
             registry: TcpRegistry::default(),
             #[cfg(privileged_portals_support)]
             ebpf_support: Default::default(),
+            connections: Default::default(),
         }
     }
 }
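Aside: `CachedConnectionsQueue` keys a FIFO of parked `(Instant, OwnedReadHalf, OwnedWriteHalf)` triples by the dialed `HostnamePort`, and `connect` walks that FIFO, reusing the first entry whose last known reply is under two seconds old and discarding older ones as it goes. The freshness scan in isolation (a sketch; `Entry` and `pop_fresh` are made-up names, not part of the patch):

```rust
use std::collections::VecDeque;
use std::time::{Duration, Instant};

// Illustrative stand-in for a parked connection entry.
struct Entry {
    last_known_reply: Instant,
    // ...read half, write half
}

// Pop entries in FIFO order until one is fresh enough to reuse; stale
// entries are simply dropped, which closes their socket halves.
fn pop_fresh(queue: &mut VecDeque<Entry>, max_age: Duration) -> Option<Entry> {
    loop {
        let entry = queue.pop_front()?;
        if entry.last_known_reply.elapsed() < max_age {
            return Some(entry);
        }
        // Too old: fall through, drop `entry`, and try the next one.
    }
}

fn main() {
    let mut queue = VecDeque::new();
    queue.push_back(Entry {
        last_known_reply: Instant::now() - Duration::from_secs(10),
    });
    queue.push_back(Entry {
        last_known_reply: Instant::now(),
    });

    // The stale entry is discarded; the fresh one is reused.
    assert!(pop_fresh(&mut queue, Duration::from_secs(2)).is_some());
    assert!(queue.is_empty());
}
```

Note that `is_localhost()` peers bypass the whole mechanism: no collector is created for them, so loopback connections are never cached.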
diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs
index 2b590bc79d2..f5da733f4be 100644
--- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs
+++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs
@@ -5,6 +5,7 @@ use ockam_core::{Address, Processor, Result};
 use ockam_node::Context;
 use ockam_transport_core::TransportError;
 use tokio::net::TcpListener;
+use tokio::time::Instant;
 use tracing::{debug, instrument};
 
 /// A TCP Listen processor
@@ -102,10 +103,12 @@
                 ctx,
                 self.registry.clone(),
                 write_half,
+                false,
                 &addresses,
                 peer,
                 mode,
                 &receiver_flow_control_id,
+                None,
             )
             .await?;
 
@@ -114,11 +117,14 @@
                 ctx,
                 self.registry.clone(),
                 read_half,
+                false,
+                Instant::now(),
                 &addresses,
                 peer,
                 mode,
                 &receiver_flow_control_id,
                 receiver_outgoing_access_control,
+                None,
            )
            .await?;
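Aside: on the accept side both workers start with `skip_initialization = false` and no `ConnectionCollector`: an inbound connection has no dialed `HostnamePort` to key a cache entry by, so it is never parked for reuse. For context, the halves the listener hands to the two workers come from splitting the accepted stream, roughly like this (a sketch assuming tokio's `net`/`rt`/`macros` features, not the listener's actual code):

```rust
use tokio::net::{TcpListener, TcpStream};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    // Dial ourselves so accept() returns immediately in this sketch.
    let _client = TcpStream::connect(addr).await?;

    let (stream, peer) = listener.accept().await?;
    // Split into independently owned halves: one for the receiving
    // processor, one for the sending worker.
    let (read_half, write_half) = stream.into_split();

    println!(
        "accepted {peer}: {:?} / {:?}",
        read_half.peer_addr()?,
        write_half.peer_addr()?
    );
    Ok(())
}
```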
diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs
index 13b8e969319..19df43ccdc4 100644
--- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs
+++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/receiver.rs
@@ -1,3 +1,4 @@
+use crate::transport::ConnectionCollector;
 use crate::transport_message::TcpTransportMessage;
 use crate::workers::Addresses;
 use crate::{
@@ -15,8 +16,9 @@ use ockam_core::{
 use ockam_core::{Processor, Result};
 use ockam_node::{Context, ProcessorBuilder};
 use ockam_transport_core::TransportError;
+use tokio::time::Instant;
 use tokio::{io::AsyncReadExt, net::tcp::OwnedReadHalf};
-use tracing::{info, instrument, trace};
+use tracing::{debug, info, instrument, trace};
 
 /// A TCP receiving message processor
 ///
@@ -29,31 +31,41 @@ use tracing::{info, instrument, trace};
 pub(crate) struct TcpRecvProcessor {
     registry: TcpRegistry,
     incoming_buffer: Vec<u8>,
-    read_half: OwnedReadHalf,
+    read_half: Option<OwnedReadHalf>,
     socket_address: SocketAddr,
     addresses: Addresses,
     mode: TcpConnectionMode,
     flow_control_id: FlowControlId,
+    connection_collector: Option<Arc<ConnectionCollector>>,
+    last_known_reply: Instant,
+    initialized: bool,
 }
 
 impl TcpRecvProcessor {
     /// Create a new `TcpRecvProcessor`
+    #[allow(clippy::too_many_arguments)]
    fn new(
        registry: TcpRegistry,
        read_half: OwnedReadHalf,
+        last_known_reply: Instant,
        socket_address: SocketAddr,
        addresses: Addresses,
        mode: TcpConnectionMode,
        flow_control_id: FlowControlId,
+        connection_collector: Option<Arc<ConnectionCollector>>,
+        initialized: bool,
    ) -> Self {
        Self {
            registry,
            incoming_buffer: Vec::new(),
-            read_half,
+            read_half: Some(read_half),
            socket_address,
            addresses,
            mode,
            flow_control_id,
+            connection_collector,
+            last_known_reply,
+            initialized,
        }
    }
 
@@ -63,19 +75,25 @@ impl TcpRecvProcessor {
         ctx: &Context,
         registry: TcpRegistry,
         read_half: OwnedReadHalf,
+        skip_initialization: bool,
+        last_known_reply: Instant,
         addresses: &Addresses,
         socket_address: SocketAddr,
         mode: TcpConnectionMode,
         flow_control_id: &FlowControlId,
         receiver_outgoing_access_control: Arc<dyn OutgoingAccessControl>,
+        connection_collector: Option<Arc<ConnectionCollector>>,
     ) -> Result<()> {
         let receiver = TcpRecvProcessor::new(
             registry,
             read_half,
+            last_known_reply,
             socket_address,
             addresses.clone(),
             mode,
             flow_control_id.clone(),
+            connection_collector,
+            skip_initialization,
         );
 
         let mailbox = Mailbox::new(
@@ -98,7 +116,12 @@
         Ok(())
     }
 
-    async fn notify_sender_stream_dropped(&self, ctx: &Context, msg: impl Display) -> Result<()> {
+    async fn notify_sender_stream_dropped(
+        &mut self,
+        ctx: &Context,
+        msg: impl Display,
+    ) -> Result<()> {
+        self.read_half.take();
         info!(
             "Connection to peer '{}' was closed; dropping stream. {}",
             self.socket_address, msg
@@ -129,12 +152,23 @@ impl Processor for TcpRecvProcessor {
             self.flow_control_id.clone(),
         ));
 
-        let protocol_version = match self.read_half.read_u8().await {
-            Ok(p) => p,
-            Err(e) => {
-                self.notify_sender_stream_dropped(ctx, e).await?;
-                return Err(TransportError::GenericIo)?;
+        if self.initialized {
+            return Ok(());
+        }
+
+        let protocol_version = if let Some(read_half) = self.read_half.as_mut() {
+            match read_half.read_u8().await {
+                Ok(p) => {
+                    self.last_known_reply = Instant::now();
+                    p
+                }
+                Err(e) => {
+                    self.notify_sender_stream_dropped(ctx, e).await?;
+                    return Err(TransportError::GenericIo)?;
+                }
            }
+        } else {
+            return Err(TransportError::ConnectionDrop)?;
        };
 
        let _protocol_version = match TcpProtocolVersion::try_from(protocol_version) {
@@ -153,13 +187,25 @@
             }
         };
 
+        self.last_known_reply = Instant::now();
+        self.initialized = true;
+
         Ok(())
     }
 
     #[instrument(skip_all, name = "TcpRecvProcessor::shutdown")]
     async fn shutdown(&mut self, ctx: &mut Self::Context) -> Result<()> {
-        self.registry.remove_receiver_processor(&ctx.address());
+        if self.initialized {
+            if let Some(connection_collector) = self.connection_collector.as_ref() {
+                if let Some(read_half) = self.read_half.take() {
+                    connection_collector.collect_read_half(self.last_known_reply, read_half);
+                } else {
+                    debug!("Connection closed, no read half to collect");
+                }
+            }
+        }
+
+        self.registry.remove_receiver_processor(&ctx.address());
 
         Ok(())
     }
@@ -177,7 +223,13 @@
     #[instrument(skip_all, name = "TcpRecvProcessor::process", fields(worker = %ctx.address()))]
     async fn process(&mut self, ctx: &mut Context) -> Result<bool> {
         // Read the message length
-        let len = match self.read_half.read_u32().await {
+        let read_half = if let Some(read_half) = self.read_half.as_mut() {
+            read_half
+        } else {
+            return Ok(false);
+        };
+
+        let len = match read_half.read_u32().await {
             Ok(l) => l,
             Err(e) => {
                 self.notify_sender_stream_dropped(ctx, e).await?;
@@ -217,7 +269,7 @@
         self.incoming_buffer.resize(len_usize, 0);
 
         // Then read into the buffer
-        match self.read_half.read_exact(&mut self.incoming_buffer).await {
+        match read_half.read_exact(&mut self.incoming_buffer).await {
             Ok(_) => {}
             Err(e) => {
                 self.notify_sender_stream_dropped(ctx, e).await?;
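Aside: `read_half` becomes an `Option<OwnedReadHalf>` so that `shutdown` can move the half out of `&mut self` and hand ownership to the collector, while every error path (`notify_sender_stream_dropped`) empties the `Option` first, so a dead connection can never be parked for reuse. The take-on-shutdown pattern in isolation (illustrative types and names):

```rust
// Illustrative stand-ins: a processor owns a resource it either drops
// (closing it) or hands back to a collector on clean shutdown.
struct Processor<T> {
    resource: Option<T>,
    initialized: bool,
}

impl<T> Processor<T> {
    fn fail(&mut self) {
        // Error path: empty the Option first so shutdown cannot park a
        // half-dead resource afterwards.
        self.resource.take();
    }

    fn shutdown(&mut self, collect: impl FnOnce(T)) {
        if self.initialized {
            if let Some(resource) = self.resource.take() {
                // Ownership moves out of &mut self and into the collector.
                collect(resource);
                return;
            }
        }
        // Otherwise whatever is left is dropped here.
    }
}

fn main() {
    let mut healthy = Processor { resource: Some("read half"), initialized: true };
    healthy.shutdown(|r| println!("parked: {r}"));

    let mut broken = Processor { resource: Some("read half"), initialized: true };
    broken.fail();
    broken.shutdown(|_| unreachable!("nothing to park"));
}
```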
diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs
index 9f5cbb0e9a1..6f07cfeb6a2 100644
--- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs
+++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs
@@ -9,12 +9,13 @@ use ockam_core::{
 use ockam_core::{Any, Decodable, Mailbox, Mailboxes, Message, Result, Routed, Worker};
 use ockam_node::{Context, WorkerBuilder};
 
+use crate::transport::ConnectionCollector;
 use crate::transport_message::TcpTransportMessage;
 use ockam_transport_core::TransportError;
 use serde::{Deserialize, Serialize};
 use tokio::io::AsyncWriteExt;
 use tokio::net::tcp::OwnedWriteHalf;
-use tracing::{info, instrument, trace, warn};
+use tracing::{debug, info, instrument, trace, warn};
 
 #[derive(Serialize, Deserialize, Message, Clone)]
 pub(crate) enum TcpSendWorkerMsg {
@@ -32,33 +33,40 @@ pub(crate) enum TcpSendWorkerMsg {
 pub(crate) struct TcpSendWorker {
     buffer: Vec<u8>,
     registry: TcpRegistry,
-    write_half: OwnedWriteHalf,
+    write_half: Option<OwnedWriteHalf>,
     socket_address: SocketAddr,
     addresses: Addresses,
     mode: TcpConnectionMode,
     receiver_flow_control_id: FlowControlId,
     rx_should_be_stopped: bool,
+    connection_collector: Option<Arc<ConnectionCollector>>,
+    initialized: bool,
 }
 
 impl TcpSendWorker {
     /// Create a new `TcpSendWorker`
+    #[allow(clippy::too_many_arguments)]
     fn new(
         registry: TcpRegistry,
         write_half: OwnedWriteHalf,
+        initialized: bool,
         socket_address: SocketAddr,
         addresses: Addresses,
         mode: TcpConnectionMode,
         receiver_flow_control_id: FlowControlId,
+        connection_collector: Option<Arc<ConnectionCollector>>,
     ) -> Self {
         Self {
             buffer: vec![],
             registry,
-            write_half,
+            write_half: Some(write_half),
             socket_address,
             addresses,
             receiver_flow_control_id,
             mode,
+            connection_collector,
             rx_should_be_stopped: true,
+            initialized,
         }
     }
 }
@@ -72,19 +80,23 @@
         ctx: &Context,
         registry: TcpRegistry,
         write_half: OwnedWriteHalf,
+        skip_initialization: bool,
         addresses: &Addresses,
         socket_address: SocketAddr,
         mode: TcpConnectionMode,
         receiver_flow_control_id: &FlowControlId,
+        connection_collector: Option<Arc<ConnectionCollector>>,
     ) -> Result<()> {
         trace!("Creating new TCP worker pair");
         let sender_worker = Self::new(
             registry,
             write_half,
+            skip_initialization,
             socket_address,
             addresses.clone(),
             mode,
             receiver_flow_control_id.clone(),
+            connection_collector,
         );
 
         let main_mailbox = Mailbox::new(
@@ -154,6 +166,7 @@
 
         // Replace zeros with actual length
         self.buffer[..LENGTH_VALUE_SIZE].copy_from_slice(&payload_len_u32.to_be_bytes());
+        trace!("Sending {payload_len_u32} bytes");
 
         Ok(())
     }
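Aside: the wire format here is a 4-byte big-endian length prefix followed by the payload: `serialize_message` reserves `LENGTH_VALUE_SIZE` zero bytes up front and patches them once the payload length is known, and the receiver mirrors this with `read_u32` followed by `read_exact`. A minimal encode sketch under that reading:

```rust
// 4-byte big-endian length prefix followed by the payload. The worker
// reserves the header bytes up front and patches them after serializing.
const LENGTH_VALUE_SIZE: usize = 4;

fn frame(payload: &[u8]) -> Vec<u8> {
    let mut buffer = vec![0u8; LENGTH_VALUE_SIZE];
    buffer.extend_from_slice(payload);
    let payload_len = payload.len() as u32;
    // Replace the zeros with the actual length.
    buffer[..LENGTH_VALUE_SIZE].copy_from_slice(&payload_len.to_be_bytes());
    buffer
}

fn main() {
    let framed = frame(b"hello");
    assert_eq!(&framed[..LENGTH_VALUE_SIZE], &5u32.to_be_bytes());
    assert_eq!(&framed[LENGTH_VALUE_SIZE..], b"hello");
}
```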
@@ -176,27 +189,48 @@ impl Worker for TcpSendWorker {
             self.receiver_flow_control_id.clone(),
         ));
 
-        // First thing send our protocol version
-        if self
-            .write_half
-            .write_u8(TcpProtocolVersion::V1.into())
-            .await
-            .is_err()
-        {
-            warn!(
-                "Failed to send protocol version to peer {}",
-                self.socket_address
-            );
-            self.stop(ctx).await?;
-
+        // First thing: send our protocol version
+        if self.initialized {
             return Ok(());
         }
 
+        if let Some(write_half) = self.write_half.as_mut() {
+            if write_half
+                .write_u8(TcpProtocolVersion::V1.into())
+                .await
+                .is_err()
+            {
+                warn!(
+                    "Failed to send protocol version to peer {}",
+                    self.socket_address
+                );
+                self.write_half.take();
+                self.stop(ctx).await?;
+
+                return Ok(());
+            }
+        } else {
+            self.stop(ctx).await?;
+            return Err(TransportError::ConnectionDrop)?;
+        }
+
+        self.initialized = true;
+
         Ok(())
     }
 
     #[instrument(skip_all, name = "TcpSendWorker::shutdown")]
     async fn shutdown(&mut self, ctx: &mut Self::Context) -> Result<()> {
+        if self.initialized {
+            if let Some(connection_collector) = self.connection_collector.as_ref() {
+                if let Some(write_half) = self.write_half.take() {
+                    connection_collector.collect_write_half(write_half);
+                } else {
+                    debug!("Connection closed, no write half to collect");
+                }
+            }
+        }
+
         self.registry
             .remove_sender_worker(self.addresses.sender_address());
@@ -230,6 +264,7 @@
                 // No need to stop Receiver as it notified us about connection drop and will
                 // stop itself
                 self.rx_should_be_stopped = false;
+                self.write_half.take();
                 self.stop(ctx).await?;
 
                 return Ok(());
@@ -243,16 +278,23 @@
 
                 if let Err(err) = self.serialize_message(local_message) {
                     // Close the stream
+                    self.write_half.take();
                     self.stop(ctx).await?;
 
                     return Err(err);
                 };
 
-                if self.write_half.write_all(&self.buffer).await.is_err() {
-                    warn!("Failed to send message to peer {}", self.socket_address);
-                    self.stop(ctx).await?;
+                if let Some(write_half) = self.write_half.as_mut() {
+                    if write_half.write_all(&self.buffer).await.is_err() {
+                        warn!("Failed to send message to peer {}", self.socket_address);
+                        self.write_half.take();
+                        self.stop(ctx).await?;
 
-                    return Ok(());
+                        return Ok(());
+                    }
+                } else {
+                    self.stop(ctx).await?;
+                    return Err(TransportError::ConnectionDrop)?;
                }
            }
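Aside: the `return Err(TransportError::ConnectionDrop)?;` shape used on the emptied-`Option` paths leans on `?` to apply the `From`-based conversion into the function's error type at the return site. A minimal reproduction with local stand-in types (not the real ockam error types):

```rust
// Local stand-ins for the real error types.
#[derive(Debug)]
enum TransportError {
    ConnectionDrop,
}

#[derive(Debug)]
struct Error(&'static str);

impl From<TransportError> for Error {
    fn from(e: TransportError) -> Self {
        Error(match e {
            TransportError::ConnectionDrop => "connection drop",
        })
    }
}

// Same shape as `return Err(TransportError::ConnectionDrop)?;` in the
// workers: `?` converts the error before the early return.
fn send_on_closed_connection() -> Result<(), Error> {
    return Err(TransportError::ConnectionDrop)?;
}

fn main() {
    assert!(send_on_closed_connection().is_err());
}
```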