Skip to content

Commit

Permalink
Add --cors flag
Browse files Browse the repository at this point in the history
  • Loading branch information
erik committed Oct 8, 2023
1 parent b4cd989 commit aeafb38
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 60 deletions.
36 changes: 28 additions & 8 deletions projects/hotpot/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,22 @@ enum Commands {
#[arg(short, long, default_value = "8080")]
port: u16,

/// Allow uploading of activities via `/upload`
// TODO: specify token somehow
/// Allow uploading of activities via `/upload`.
///
/// Remember to set `HOTPOT_UPLOAD_TOKEN` environment variable.
#[arg(long, default_value = "false")]
allow_upload: bool,
upload: bool,

/// Enable Strava activity webhook (requires authenticating via `strava-auth`)
/// Enable Strava activity webhook
///
/// Use `strava-auth` subcommand to grab OAuth tokens.
#[arg(long, default_value = "false")]
strava_webhook: bool,

/// Allow cross origin requests (CORS headers)
#[arg(long, default_value = "false")]
cors: bool,
},

/// Authenticate with Strava to fetch OAuth tokens for webhook.
Expand Down Expand Up @@ -174,19 +183,25 @@ fn run() -> Result<()> {
Commands::Serve {
host,
port,
allow_upload,
upload,
strava_webhook,
cors,
} => {
let db = Database::new(&opts.global.db_path)?;
let addr = format!("{}:{}", host, port).parse()?;
let routes = web::RouteConfig {
strava_webhook,
allow_upload,
upload,
tiles: true,
strava_auth: false,
};

web::run_blocking(addr, db, routes)?;
let config = web::Config {
cors,
upload_token: std::env::var("HOTPOT_UPLOAD_TOKEN").ok(),
};

web::run_blocking(addr, db, config, routes)?;
}

Commands::StravaAuth { host, port } => {
Expand All @@ -196,7 +211,12 @@ fn run() -> Result<()> {
strava_auth: true,
tiles: false,
strava_webhook: false,
allow_upload: false,
upload: false,
};

let config = web::Config {
cors: false,
upload_token: None,
};

println!(
Expand All @@ -205,7 +225,7 @@ fn run() -> Result<()> {
\n==============================",
addr
);
web::run_blocking(addr, db, routes)?;
web::run_blocking(addr, db, config, routes)?;
}
};

Expand Down
118 changes: 66 additions & 52 deletions projects/hotpot/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,54 +19,35 @@ use tokio::runtime::Runtime;
use tower_http::trace::TraceLayer;
use tracing::log::info;

use crate::db::ActivityFilter;
use crate::db::Database;
use crate::db::{ActivityFilter, Database};
use crate::raster::DEFAULT_GRADIENT;
use crate::tile::Tile;
use crate::web::strava::StravaAuth;
use crate::{activity, raster};

mod strava;

struct RequestData {
method: Method,
uri: Uri,
}

async fn store_request_data<B>(req: Request<B>, next: Next<B>) -> Response {
let data = RequestData {
method: req.method().clone(),
uri: req.uri().clone(),
};

let mut res = next.run(req).await;
res.extensions_mut().insert(data);

res
}

fn trace_response(res: &Response, latency: Duration, _span: &tracing::Span) {
let data = res.extensions().get::<RequestData>().unwrap();

tracing::info!(
status = %res.status().as_u16(),
method = %data.method,
uri = %data.uri,
latency = ?latency,
size = res.size_hint().exact(),
"response"
);
#[derive(Clone)]
pub struct Config {
pub cors: bool,
pub upload_token: Option<String>,
}

#[derive(Clone)]
pub struct AppState {
db: Arc<Database>,
strava: StravaAuth,
config: Config,
}

pub fn run_blocking(addr: SocketAddr, db: Database, routes: RouteConfig) -> Result<()> {
pub fn run_blocking(
addr: SocketAddr,
db: Database,
config: Config,
routes: RouteConfig,
) -> Result<()> {
let rt = Runtime::new()?;
let fut = run_async(addr, db, routes);
let fut = run_async(addr, db, config, routes);
rt.block_on(fut)?;
Ok(())
}
Expand All @@ -75,11 +56,11 @@ pub struct RouteConfig {
pub tiles: bool,
pub strava_webhook: bool,
pub strava_auth: bool,
pub allow_upload: bool,
pub upload: bool,
}

impl RouteConfig {
fn build<S>(&self, db: Database) -> Result<Router<S>> {
fn build<S>(&self, db: Database, config: Config) -> Result<Router<S>> {
// TODO: MVT endpoint?
let mut router = Router::new()
.layer(axum::middleware::from_fn(store_request_data))
Expand All @@ -100,7 +81,7 @@ impl RouteConfig {
router = router.nest("/strava", strava::auth_routes());
}

if self.allow_upload {
if self.upload {
router = router.route("/upload", post(upload_activity));
}

Expand All @@ -112,21 +93,27 @@ impl RouteConfig {
};

Ok(router.with_state(AppState {
config,
strava,
db: Arc::new(db),
}))
}
}

async fn run_async(addr: SocketAddr, db: Database, routes: RouteConfig) -> Result<()> {
async fn run_async(
addr: SocketAddr,
db: Database,
config: Config,
routes: RouteConfig,
) -> Result<()> {
tracing_subscriber::fmt()
.compact()
.with_max_level(tracing::Level::INFO)
.init();

info!("Listening on http://{}", addr);

let router = routes.build(db)?;
let router = routes.build(db, config)?;
Server::bind(&addr)
.serve(router.into_make_service())
.await?;
Expand All @@ -150,7 +137,7 @@ struct RenderQueryParams {
}

async fn render_tile(
State(AppState { db, .. }): State<AppState>,
State(AppState { db, config, .. }): State<AppState>,
Path((z, x, y)): Path<(u8, u32, u32)>,
Query(params): Query<RenderQueryParams>,
) -> impl IntoResponse {
Expand Down Expand Up @@ -182,17 +169,15 @@ async fn render_tile(
))
.unwrap();

// TODO: seems hacky
(
axum::response::AppendHeaders([
(header::CONTENT_TYPE, "image/png"),
(header::CACHE_CONTROL, "max-age=3600"),
// TODO: should be configurable.
(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
]),
bytes,
)
.into_response()
let mut res = axum::response::Response::builder()
.header(header::CONTENT_TYPE, "image/png")
.header(header::CACHE_CONTROL, "max-age=86400");

if config.cors {
res = res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
}

res.body(bytes).unwrap().into_parts().into_response()
}
Ok(None) => StatusCode::NO_CONTENT.into_response(),
Err(e) => {
Expand All @@ -203,12 +188,11 @@ async fn render_tile(
}

async fn upload_activity(
State(AppState { db, .. }): State<AppState>,
State(AppState { db, config, .. }): State<AppState>,
TypedHeader(auth): TypedHeader<axum::headers::Authorization<Bearer>>,
mut multipart: Multipart,
) -> impl IntoResponse {
// TODO: real auth
if auth.token() != "magic" {
if Some(auth.token()) != config.upload_token.as_deref() {
return (StatusCode::UNAUTHORIZED, "bad token");
}

Expand Down Expand Up @@ -246,3 +230,33 @@ async fn upload_activity(

(StatusCode::OK, "added!")
}

struct RequestData {
method: Method,
uri: Uri,
}

async fn store_request_data<B>(req: Request<B>, next: Next<B>) -> Response {
let data = RequestData {
method: req.method().clone(),
uri: req.uri().clone(),
};

let mut res = next.run(req).await;
res.extensions_mut().insert(data);

res
}

fn trace_response(res: &Response, latency: Duration, _span: &tracing::Span) {
let data = res.extensions().get::<RequestData>().unwrap();

tracing::info!(
status = %res.status().as_u16(),
method = %data.method,
uri = %data.uri,
latency = ?latency,
size = res.size_hint().exact(),
"response"
);
}

0 comments on commit aeafb38

Please sign in to comment.