Skip to content

Commit

Permalink
refactor: use &'static Config instead of Arc
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Oct 30, 2024
1 parent e0cf95f commit 987ef21
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 22 deletions.
8 changes: 4 additions & 4 deletions src/logic_ask.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{sync::Arc, str::FromStr};
use std::str::FromStr;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
Expand All @@ -18,7 +18,7 @@ use crate::{config::Config, structs::MyStatusCode, msg::{HttpRequest, HttpRespon
/// This function knows from its map which app to direct the message to
pub(crate) async fn handler_http(
mut req: Request<Incoming>,
config: Arc<Config>,
config: &Config,
https_authority: Option<Authority>,
) -> Result<Response, MyStatusCode> {

Expand Down Expand Up @@ -96,12 +96,12 @@ pub(crate) async fn handler_http(
info!("{method} {} via {target}", req.uri());
let span = info_span!("request", %method, via = %target, url = %req.uri());
#[cfg(feature = "sockets")]
return crate::sockets::handle_via_sockets(req, &config, target, auth).instrument(span).await;
return crate::sockets::handle_via_sockets(req, config, target, auth).instrument(span).await;
#[cfg(not(feature = "sockets"))]
return handle_via_tasks(req, &config, target, auth).instrument(span).await;
}

async fn handle_via_tasks(req: Request<Incoming>, config: &Arc<Config>, target: &AppId, auth: HeaderValue) -> Result<Response, MyStatusCode> {
async fn handle_via_tasks(req: Request<Incoming>, config: &Config, target: &AppId, auth: HeaderValue) -> Result<Response, MyStatusCode> {
let msg = http_req_to_struct(req, &config.my_app_id, &target, config.expire).await?;

// Send to Proxy
Expand Down
19 changes: 7 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{convert::Infallible, time::Duration};

use config::Config;
use http_body_util::combinators::BoxBody;
Expand Down Expand Up @@ -28,7 +28,7 @@ async fn main() -> anyhow::Result<()> {
tracing::subscriber::set_global_default(tracing_subscriber::fmt().with_env_filter(EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy()).finish())?;
banner::print_banner();
let config = Config::load().await?;
let config2 = config.clone();
let config: &'static _ = Box::leak(Box::new(config));
let client = config.client.clone();
let client2 = client.clone();
banner::print_startup_app_config(&config);
Expand All @@ -42,7 +42,7 @@ async fn main() -> anyhow::Result<()> {
let mut timer= std::pin::pin!(tokio::time::sleep(Duration::from_secs(60)));
loop {
debug!("Waiting for next request ...");
if let Err(e) = logic_reply::process_requests(config2.clone(), client2.clone()).await {
if let Err(e) = logic_reply::process_requests(config, client2.clone()).await {
match e {
BeamConnectError::ProxyTimeoutError => {
debug!("{e}");
Expand Down Expand Up @@ -71,11 +71,9 @@ async fn main() -> anyhow::Result<()> {
#[allow(unused_mut)]
let mut executers = vec![http_executor];
#[cfg(feature = "sockets")]
executers.push(sockets::spawn_socket_task_poller(config.clone()));
executers.push(sockets::spawn_socket_task_poller(config));

let config = Arc::new(config.clone());

if let Err(e) = server(&config).await {
if let Err(e) = server(config).await {
error!("Server error: {}", e);
}
info!("Shutting down...");
Expand All @@ -84,7 +82,7 @@ async fn main() -> anyhow::Result<()> {
}

// See https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs
async fn server(config: &Arc<Config>) -> anyhow::Result<()> {
async fn server(config: &'static Config) -> anyhow::Result<()> {
let listener = TcpListener::bind(config.bind_addr.clone()).await?;

let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
Expand All @@ -106,9 +104,7 @@ async fn server(config: &Arc<Config>) -> anyhow::Result<()> {

let stream = hyper_util::rt::TokioIo::new(stream);

let config = config.clone();
let conn = server.serve_connection_with_upgrades(stream, service_fn(move |req| {
let config = config.clone();
handler_http_wrapper(req, config)
}));

Expand Down Expand Up @@ -146,7 +142,7 @@ pub type Response<T = BoxBody<Bytes, anyhow::Error>> = hyper::Response<T>;

pub(crate) async fn handler_http_wrapper(
req: Request<Incoming>,
config: Arc<Config>,
config: &'static Config,
) -> Result<Response, Infallible> {
// On https connections we want to emulate that we successfully connected to get the actual http request
if req.method() == Method::CONNECT {
Expand All @@ -162,7 +158,6 @@ pub(crate) async fn handler_http_wrapper(
Ok(s) => s,
};
server::conn::auto::Builder::new(TokioExecutor::new()).serve_connection_with_upgrades(TokioIo::new(tls_connection), service_fn(|req| {
let config = config.clone();
let authority = authority.clone();
async move {
match handler_http(req, config, authority).await {
Expand Down
11 changes: 5 additions & 6 deletions src/sockets.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{time::Duration, collections::HashSet, sync::Arc, convert::Infallible};
use std::{time::Duration, collections::HashSet, convert::Infallible};

use futures_util::TryStreamExt;
use http_body_util::{combinators::BoxBody, BodyExt, BodyStream, StreamBody};
Expand All @@ -12,7 +12,7 @@ use reqwest::Response;
use crate::{config::Config, errors::BeamConnectError, structs::MyStatusCode};


pub(crate) fn spawn_socket_task_poller(config: Config) -> JoinHandle<()> {
pub(crate) fn spawn_socket_task_poller(config: &'static Config) -> JoinHandle<()> {
tokio::spawn(async move {
use BeamConnectError::*;
let mut seen: HashSet<MsgId> = HashSet::new();
Expand Down Expand Up @@ -41,10 +41,9 @@ pub(crate) fn spawn_socket_task_poller(config: Config) -> JoinHandle<()> {
warn!("Invalid app id skipping");
continue;
};
let config_clone = config.clone();
tokio::spawn(async move {
match connect_proxy(&task.id, &config_clone).await {
Ok(resp) => tunnel(resp, client, &config_clone).await,
match connect_proxy(&task.id, config).await {
Ok(resp) => tunnel(resp, client, config).await,
Err(e) => {
warn!("{e}");
},
Expand Down Expand Up @@ -200,7 +199,7 @@ fn tunnel_upgrade(client: Option<OnUpgrade>, server: Option<OnUpgrade>) {
}
}

pub(crate) async fn handle_via_sockets(mut req: Request<Incoming>, config: &Arc<Config>, target: &AppId, auth: HeaderValue) -> Result<crate::Response, MyStatusCode> {
pub(crate) async fn handle_via_sockets(mut req: Request<Incoming>, config: &Config, target: &AppId, auth: HeaderValue) -> Result<crate::Response, MyStatusCode> {
let resp = config.client
.post(format!("{}v1/sockets/{target}", config.proxy_url))
.header(header::AUTHORIZATION, auth)
Expand Down

0 comments on commit 987ef21

Please sign in to comment.