Skip to content

Commit

Permalink
refactor: seperate read/writing streams into get/post requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Oct 22, 2024
1 parent 165fc18 commit a09a0e8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
32 changes: 17 additions & 15 deletions broker/src/serve_sockets.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{sync::Arc, time::Duration};

use axum::{body::{Body, BodyDataStream}, extract::{Path, State}, http::{header, request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, Router};
use axum::{body::{Body, BodyDataStream}, extract::{Path, State}, http::{header, request::Parts, Method, StatusCode}, response::{IntoResponse, Response}, routing::get, Router};
use bytes::{BufMut, Bytes, BytesMut};
use dashmap::mapref::entry::Entry;
use futures_util::{stream, StreamExt};
use shared::{expire_map::LazyExpireMap, serde_helpers::DerefSerializer, Encrypted, HasWaitId, HowLongToBlock, Msg, MsgEmpty, MsgId, MsgSigned, MsgSocketRequest};
use tokio::{sync::oneshot, time::Instant};
use tracing::{debug, warn};
use tracing::{debug, warn, Span};

use crate::task_manager::TaskManager;

Expand Down Expand Up @@ -49,7 +49,7 @@ impl Default for SocketState {
pub(crate) fn router() -> Router {
Router::new()
.route("/v1/sockets", get(get_socket_requests).post(post_socket_request))
.route("/v1/sockets/:id", get(connect_socket))
.route("/v1/sockets/:id", get(connect_socket).post(connect_socket))
.with_state(SocketState::default())
}

Expand Down Expand Up @@ -85,16 +85,19 @@ async fn post_socket_request(
))
}

// TODO: Instrument with task_id
#[tracing::instrument(skip(method, state, parts, body), fields(is_read))]
async fn connect_socket(
method: Method,
state: State<SocketState>,
Path(task_id): Path<MsgId>,
mut parts: Parts,
body: Body,
// This Result is just an Either type. An error value does not mean something went wrong
) -> Result<Response, StatusCode> {
let is_read = method == Method::GET;
Span::current().record("is_read", is_read);
let mut body_stream = body.into_data_stream();
let (is_read, jwt, remaining) = read_header(&mut body_stream).await?;
let (jwt, remaining) = read_header(&mut body_stream).await?;
let result = shared::crypto_jwt::verify_with_extended_header::<MsgEmpty>(&mut parts, String::from_utf8_lossy(&jwt).as_ref()).await;
let msg = match result {
Ok(msg) => msg.msg,
Expand Down Expand Up @@ -124,13 +127,13 @@ async fn connect_socket(
rx
},
};
debug!(%task_id, "Read waiting on write");
debug!("Read waiting on write");
let recv_res = recv.await;
_ = state.task_manager.remove(&task_id);
match recv_res {
Ok(s) => Ok(Body::from_stream(s).into_response()),
Err(_) => {
warn!(%task_id, "Socket connection expired");
warn!("Socket connection expired");
Err(StatusCode::GONE)
},
}
Expand Down Expand Up @@ -158,7 +161,7 @@ async fn connect_socket(
std::task::Poll::Ready(None)
})).boxed());
let Ok(()) = send_res else {
warn!(%task_id, "Failed to send socket body. Reciever dropped");
warn!("Failed to send socket body. Reciever dropped");
return Err(StatusCode::GONE);
};
debug!("Write half send the stream to the read half");
Expand All @@ -168,7 +171,7 @@ async fn connect_socket(
}
}

async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut), StatusCode> {
async fn read_header(s: &mut BodyDataStream) -> Result<(Bytes, BytesMut), StatusCode> {
async fn next(s: &mut BodyDataStream) -> Result<Option<Bytes>, StatusCode> {
s.next().await.transpose().map_err(|e| {
warn!(%e, "Failed to read init for sockets");
Expand All @@ -179,23 +182,22 @@ async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut),
#[derive(Debug)]
enum ReadState {
ReadingHeader,
ReadingMessage { is_read: bool, len: usize },
ReadingMessage { len: usize },
}
let mut state = ReadState::ReadingHeader;
while let Some(mut packet) = next(s).await? {
loop {
match state {
ReadState::ReadingHeader if buf.len() + packet.len() >= 5 => {
ReadState::ReadingHeader if buf.len() + packet.len() >= 4 => {
buf.put(packet.split_to(packet.len()));
debug_assert!(packet.is_empty());
let is_read = buf.split_to(1)[0] == 1;
let len = u32::from_be_bytes(buf.split_to(4).as_ref().try_into().unwrap());
state = ReadState::ReadingMessage { is_read, len: len as usize };
state = ReadState::ReadingMessage { len: len as usize };
continue;
},
ReadState::ReadingMessage { is_read, len } if buf.len() + packet.len() >= len => {
ReadState::ReadingMessage { len } if buf.len() + packet.len() >= len => {
buf.put(packet);
return Ok((is_read, buf.split_to(len).freeze(), buf))
return Ok((buf.split_to(len).freeze(), buf))
},
_ => break,
}
Expand Down
26 changes: 14 additions & 12 deletions proxy/src/serve_sockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
};

use axum::{
body::Body, extract::{Path, Request, State}, http::{self, header, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router
body::Body, extract::{Path, Request, State}, http::{self, header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crypto_secretstream::{Header, Key, PullStream, PushStream};
Expand Down Expand Up @@ -136,15 +136,13 @@ async fn create_socket_con(
);
return (res.status(), "Failed to post MsgSocketRequest to broker").into_response();
}
let req = match prepare_socket_request(sender, task_id, &state).await {
let req = match prepare_socket_request(false, sender, task_id, &state).await {
Ok(req) => req,
Err(e) => return e,
};
let req = req.map(|b| {
let n = b.as_bytes().len();
let mut body = Vec::with_capacity(n + 5);
// This 0 signals write interest
body.push(0);
let mut body = Vec::with_capacity(n + 4);
body.extend(u32::to_be_bytes(n as _));
body.extend(b.as_bytes());
Bytes::from(body)
Expand All @@ -155,7 +153,6 @@ async fn create_socket_con(
};
let (mut parts, body) = req.into_parts();
parts.headers.append(header::TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
parts.headers.append(header::CONNECTION, HeaderValue::from_static("keep-alive"));
let stream = stream::once(ready(Ok(body))).chain(Encrypter::new(key).encrypt(og_req.into_body().into_data_stream()));
let req = Request::from_parts(parts, reqwest::Body::wrap_stream(stream));
match state.client.execute(req.try_into().expect("Conversion to reqwest::Request should always work")).await {
Expand All @@ -173,15 +170,13 @@ async fn connect_read(
Extension(task_secret_map): Extension<MsgSecretMap>,
Path(task_id): Path<MsgId>,
) -> Response {
let req = match prepare_socket_request(sender, task_id, &state).await {
let req = match prepare_socket_request(true, sender, task_id, &state).await {
Ok(value) => value,
Err(e) => return e,
};
let req = req.map(|b| {
let n = b.as_bytes().len();
let mut body = Vec::with_capacity(n + 4 + 1);
// This 1 signals read interest
body.push(1);
let mut body = Vec::with_capacity(n + 4);
body.extend(u32::to_be_bytes(n as _));
body.extend(b.as_bytes());
Bytes::from(body)
Expand All @@ -200,15 +195,22 @@ async fn connect_read(
Response::new(Body::from_stream(Decrypter::new(key).decrypt(res.bytes_stream())))
}

async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, state: &State<TasksState>) -> Result<http::Request<String>, http::Response<Body>> {
async fn prepare_socket_request(is_read: bool, sender: beam_lib::AppId, task_id: MsgId, state: &State<TasksState>) -> Result<http::Request<String>, http::Response<Body>> {
let msg_empty = MsgEmpty {
from: AppOrProxyId::App(sender.clone()),
};
let Ok(body) = serde_json::to_vec(&msg_empty) else {
warn!("Failed to serialize MsgEmpty");
return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
};
let new_req = Request::get(format!("/v1/sockets/{task_id}")).body(axum::body::Body::from(body));
let new_req = Request::builder()
.method(if is_read {
Method::GET
} else {
Method::POST
})
.uri(format!("/v1/sockets/{task_id}"))
.body(axum::body::Body::from(body));
let get_socket_con_req = match new_req {
Ok(req) => req,
Err(e) => {
Expand Down

0 comments on commit a09a0e8

Please sign in to comment.