Skip to content

Commit

Permalink
feat(sockets)!: Downgrade to unidirectional socket
Browse files Browse the repository at this point in the history
This changes beams socket protocol from the bidirectional sockets
interface to a unidirectional streaming interface.
This change was motivated by serveral practical issues with proxies that
did not allow http request upgrades by default.
The new approach uses streaming http bodies which should be supported by
all regular proxies.
  • Loading branch information
Threated committed Sep 4, 2024
1 parent 0d0042a commit 0882f57
Show file tree
Hide file tree
Showing 12 changed files with 345 additions and 389 deletions.
4 changes: 3 additions & 1 deletion beam-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ serde_json = "1"
uuid = { version = "1", features = ["v4", "serde"] }
reqwest = { version = "0.12", features = ["json"], default-features = false, optional = true }
thiserror = { version = "1.0", optional = true }
futures-core = { version = "0.3", optional = true }
bytes = { version = "1.7.1", optional = true }

[features]
strict-ids = []
http-util = ["dep:reqwest", "dep:thiserror"]
sockets = []
sockets = ["dep:futures-core", "dep:bytes"]
43 changes: 25 additions & 18 deletions beam-lib/src/http_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use thiserror::Error;
use crate::{AddressingId, TaskRequest, MsgId, TaskResult, ProxyId};
#[cfg(feature = "sockets")]
use crate::SocketTask;
#[cfg(feature = "sockets")]
use futures_core::{Stream, TryStream};

/// A client used for communicating with the beam network
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -187,33 +189,42 @@ impl BeamClient {
/// Create a socket task for some other application to connect to
/// For this to work both the beam proxy and beam broker need to have the sockets feature enabled.
#[cfg(feature = "sockets")]
pub async fn create_socket(&self, destination: &AddressingId) -> Result<reqwest::Upgraded> {
self.create_socket_with_metadata(destination, serde_json::Value::Null).await
pub async fn create_socket<S>(&self, destination: &AddressingId, data_stream: S) -> Result<reqwest::Response>
where
// Trait bound from Body::wrap_stream
S: TryStream + Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
bytes::Bytes: From<S::Ok>
{
self.create_socket_with_metadata(destination, data_stream, serde_json::Value::Null).await
}

/// Same as `create_socket` but with associated (unencrypted) metadata.
#[cfg(feature = "sockets")]
pub async fn create_socket_with_metadata(&self, destination: &AddressingId, metadata: impl Serialize) -> Result<reqwest::Upgraded> {
pub async fn create_socket_with_metadata<S>(&self, destination: &AddressingId, data_stream: S, metadata: impl Serialize) -> Result<reqwest::Response>
where
// Trait bound from Body::wrap_stream
S: TryStream + Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
bytes::Bytes: From<S::Ok>
{
const METADATA_HEADER: HeaderName = HeaderName::from_static("metadata");
let url = self.beam_proxy_url
.join(&format!("/v1/sockets/{destination}"))
.expect("The proxy url is valid");
let response = self.client
.post(url)
.header(header::UPGRADE, "tcp")
.header(
METADATA_HEADER,
HeaderValue::try_from(serde_json::to_string(&metadata).map_err(BeamError::other)?).map_err(BeamError::other)?
)
.body(reqwest::Body::wrap_stream(data_stream))
.send().await?
.handle_invalid_receivers().await?;
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
Err(BeamError::UnexpectedStatus(response.status()))
if response.status().is_success() {
Ok(response)
} else {
response
.upgrade()
.await
.map_err(Into::into)
Err(BeamError::UnexpectedStatus(response.status()))
}
}

Expand Down Expand Up @@ -245,22 +256,18 @@ impl BeamClient {

/// Connect to a socket by its socket task id
#[cfg(feature = "sockets")]
pub async fn connect_socket(&self, socket_task_id: &MsgId) -> Result<reqwest::Upgraded> {
pub async fn connect_socket(&self, socket_task_id: &MsgId) -> Result<impl Stream<Item = reqwest::Result<bytes::Bytes>>> {
let url = self.beam_proxy_url
.join(&format!("/v1/sockets/{socket_task_id}"))
.expect("The proxy url is valid");
let response = self.client
.get(url)
.header(header::UPGRADE, "tcp")
.send().await?
.handle_invalid_receivers().await?;
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
Err(BeamError::UnexpectedStatus(response.status()))
if response.status().is_success() {
Ok(response.bytes_stream())
} else {
response
.upgrade()
.await
.map_err(Into::into)
Err(BeamError::UnexpectedStatus(response.status()))
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions broker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ once_cell = "1"
# Socket dependencies
bytes = { version = "1", optional = true }
axum-extra = { version = "0.9", features = ["typed-header"] }
hyper = { version = "1", default-features = false, optional = true}
hyper-util = { version = "0.1", default-features = false, features = ["tokio"], optional = true}
futures-util = { version = "0.3", default-features = false, optional = true }

[features]
sockets = ["dep:bytes", "shared/sockets", "dep:hyper", "dep:hyper-util"]
sockets = ["dep:bytes", "shared/sockets", "dep:futures-util"]

[build-dependencies]
build-data = "0"
149 changes: 111 additions & 38 deletions broker/src/serve_sockets.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
use std::{sync::Arc, collections::{HashMap, HashSet}, ops::Deref, time::Duration};
use std::{borrow::Cow, collections::{HashMap, HashSet}, ops::Deref, sync::Arc, time::Duration};

use axum::{extract::{Path, Request, State}, http::{header, request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, RequestExt, Router};
use bytes::BufMut;
use hyper_util::rt::TokioIo;
use axum::{body::{Body, BodyDataStream}, extract::{Path, Request, State}, http::{header, request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, RequestExt, Router};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use dashmap::mapref::entry::Entry;
use futures_core::TryStream;
use futures_util::{stream, StreamExt};
use serde::{Serialize, Serializer, ser::SerializeSeq};
use shared::{config::{CONFIG_CENTRAL, CONFIG_SHARED}, crypto_jwt::Authorized, expire_map::LazyExpireMap, serde_helpers::DerefSerializer, Encrypted, HasWaitId, HowLongToBlock, Msg, MsgEmpty, MsgId, MsgSigned, MsgSocketRequest};
use tokio::sync::{RwLock, broadcast::{Sender, self}, oneshot};
use tracing::{debug, log::error, warn};
use shared::{config::{CONFIG_CENTRAL, CONFIG_SHARED}, crypto_jwt::Authorized, errors::SamplyBeamError, expire_map::LazyExpireMap, serde_helpers::DerefSerializer, Encrypted, HasWaitId, HowLongToBlock, Msg, MsgEmpty, MsgId, MsgSigned, MsgSocketRequest};
use tokio::{sync::{broadcast::{self, Sender}, oneshot, RwLock}, time::Instant};
use tracing::{debug, log::error, warn, Span};

use crate::task_manager::{TaskManager, Task};


#[derive(Clone)]
struct SocketState {
task_manager: Arc<TaskManager<MsgSocketRequest<Encrypted>>>,
waiting_connections: Arc<LazyExpireMap<MsgId, oneshot::Sender<hyper::upgrade::OnUpgrade>>>
waiting_connections: Arc<LazyExpireMap<MsgId, ConnectionState>>
}

enum ConnectionState {
ReadHalfConnected(oneshot::Sender<SocketStream>),
WriteHalfConnected(oneshot::Receiver<SocketStream>),
}

type SocketStream = stream::BoxStream<'static, Result<Bytes, axum::Error>>;

impl SocketState {
const WAITING_CONNECTIONS_TIMEOUT: Duration = Duration::from_secs(60);
const WAITING_CONNECTIONS_CLEANUP_INTERVAL: Duration = Duration::from_secs(5 * 60);
Expand Down Expand Up @@ -78,14 +87,17 @@ async fn post_socket_request(
))
}

// TODO: Instrument with task_id
async fn connect_socket(
state: State<SocketState>,
Path(task_id): Path<MsgId>,
mut parts: Parts,
body: String,
body: Body,
// This Result is just an Either type. An error value does not mean something went wrong
) -> Result<Response, StatusCode> {
let result = shared::crypto_jwt::verify_with_extended_header::<MsgEmpty>(&mut parts, &body).await;
let mut body_stream = body.into_data_stream();
let (is_read, token, remaining) = read_header(&mut body_stream).await?;
let result = shared::crypto_jwt::verify_with_extended_header::<MsgEmpty>(&mut parts, String::from_utf8_lossy(&token).as_ref()).await;
let msg = match result {
Ok(msg) => msg.msg,
Err(e) => return Ok(e.into_response()),
Expand All @@ -97,39 +109,100 @@ async fn connect_socket(
return Err(StatusCode::UNAUTHORIZED);
}
}

let Some(conn) = parts.extensions.remove::<hyper::upgrade::OnUpgrade>() else {
return Err(StatusCode::UPGRADE_REQUIRED);
};

if let Some(req_sender) = state.waiting_connections.remove(&task_id) {
if let Err(_) = req_sender.send(conn) {
warn!("Error sending socket connection to tunnel. Receiver has been dropped");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
if is_read {
let recv = match state.waiting_connections.entry(task_id) {
Entry::Occupied(e) => {
match e.remove().0 {
ConnectionState::ReadHalfConnected(_) => {
warn!("Reader connected twice");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
},
ConnectionState::WriteHalfConnected(recv) => recv,
}
},
Entry::Vacant(empty) => {
let (tx, rx) = oneshot::channel();
empty.insert((ConnectionState::ReadHalfConnected(tx), Instant::now() + SocketState::WAITING_CONNECTIONS_TIMEOUT));
rx
},
};
debug!(%task_id, "Read waiting on write");
let recv_res = recv.await;
_ = state.task_manager.remove(&task_id);
match recv_res {
Ok(s) => Ok(Body::from_stream(s).into_response()),
Err(_) => {
warn!(%task_id, "Socket connection expired");
Err(StatusCode::GONE)
},
}
} else {
let (tx, rx) = tokio::sync::oneshot::channel();
state.waiting_connections.insert_for(SocketState::WAITING_CONNECTIONS_TIMEOUT, task_id, tx);
let Ok(other_con) = rx.await else {
debug!("Socket expired because nobody connected");
let sender = match state.waiting_connections.entry(task_id) {
Entry::Occupied(e) => {
match e.remove().0 {
ConnectionState::ReadHalfConnected(send_read) => send_read,
ConnectionState::WriteHalfConnected(_) => {
warn!("Sender connected twice");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
},
}
},
Entry::Vacant(empty) => {
let (tx, rx) = oneshot::channel();
empty.insert((ConnectionState::WriteHalfConnected(rx), Instant::now() + SocketState::WAITING_CONNECTIONS_TIMEOUT));
tx
},
};
let (tx, rx) = oneshot::channel::<()>();
let mut wrapped = Some(tx);
let send_res = sender.send(stream::once(futures_util::future::ready(Ok(remaining.freeze()))).chain(body_stream).chain(stream::poll_fn(move |_| {
_ = wrapped.take().unwrap().send(());
std::task::Poll::Ready(None)
})).boxed());
let Ok(()) = send_res else {
warn!(%task_id, "Failed to send socket body. Reciever dropped");
return Err(StatusCode::GONE);
};
// We don't care if the task expired by now
_ = state.task_manager.remove(&task_id);
tokio::spawn(async move {
let (socket1, socket2) = match tokio::try_join!(conn, other_con) {
Ok(sockets) => sockets,
Err(e) => {
warn!("Failed to upgrade requests to socket connections: {e}");
return;
},
};
debug!("Write half over stream");
_ = rx.await;
Err(StatusCode::OK)
}
}

let result = tokio::io::copy_bidirectional(&mut TokioIo::new(socket1), &mut TokioIo::new(socket2)).await;
if let Err(e) = result {
debug!("Relaying socket connection ended: {e}");
async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut), StatusCode> {
async fn next(s: &mut BodyDataStream) -> Result<Option<Bytes>, StatusCode> {
s.next().await.transpose().map_err(|e| {
warn!(%e, "Failed to read init for sockets");
StatusCode::BAD_GATEWAY
})
}
let mut buf = BytesMut::new();
#[derive(Debug)]
enum ReadState {
ReadingHeader,
ReadingMessage { is_read: bool, len: usize },
}
let mut state = ReadState::ReadingHeader;
while let Some(mut packet) = next(s).await? {
loop {
match state {
ReadState::ReadingHeader if buf.len() + packet.len() >= 5 => {
buf.put(packet.split_to(packet.len()));
debug_assert!(packet.is_empty());
let is_read = buf.split_to(1)[0] == 1;
let len = u32::from_be_bytes(buf.split_to(4).as_ref().try_into().unwrap());
state = ReadState::ReadingMessage { is_read, len: len as usize };
continue;
},
ReadState::ReadingMessage { is_read, len } if buf.len() + packet.len() >= len => {
buf.put(packet);
return Ok((is_read, buf.split_to(len).freeze(), buf))
},
_ => break,
}
});
}
buf.put(packet);
}
Err(StatusCode::SWITCHING_PROTOCOLS)
dbg!("Not enough data?", state, buf);
Err(StatusCode::BAD_GATEWAY)
}
2 changes: 2 additions & 0 deletions dev/beamdev
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ check_prereqs
if [ "$1" == "--tag" ]; then
TAG="$2"
shift 2
else
TAG="localbuild"
fi

case "$1" in
Expand Down
6 changes: 2 additions & 4 deletions proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ async-sse = "5.1"
async-stream = "0.3"

# Socket dependencies
chacha20poly1305 = { version = "0.10", features = ["stream"], optional = true }
crypto_secretstream = { version = "0.2", optional = true }
dashmap = { version = "6.0", optional = true}
hyper = { version = "1", default-features = false, optional = true }
hyper-util = { version = "0.1", default-features = false, features = ["tokio"], optional = true}

[features]
sockets = ["dep:chacha20poly1305", "dep:dashmap", "tokio-util/codec", "tokio-util/compat", "shared/sockets", "shared/expire_map", "dep:hyper", "dep:hyper-util"]
sockets = ["dep:crypto_secretstream", "dep:dashmap", "shared/sockets", "shared/expire_map"]

[build-dependencies]
build-data = "0"
2 changes: 1 addition & 1 deletion proxy/src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl GetCertsFromBroker {
let req = sign_request(body, parts, &self.config, Some(&self.crypto_conf))
.await
.map_err(|(_, msg)| SamplyBeamError::SignEncryptError(msg.into()))?;
Ok(self.client.execute(req).await?.into())
Ok(self.client.execute(req.try_into().unwrap()).await?.into())
}

async fn query(&self, path: &str) -> Result<String, SamplyBeamError> {
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) {

let req = sign_request(body, parts, &config, None).await.expect("Unable to sign request; this should always work");
// In the future this will poll actual control related tasks
match client.execute(req).await {
match client.execute(req.try_into().unwrap()).await {
Ok(res) => {
match res.status() {
StatusCode::OK => {
Expand Down
Loading

0 comments on commit 0882f57

Please sign in to comment.