Skip to content

Commit

Permalink
small code reorganization
Browse files Browse the repository at this point in the history
make functions available in a separate module
  • Loading branch information
lovasoa committed Jul 16, 2023
1 parent 62105c4 commit 0061d13
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 136 deletions.
2 changes: 1 addition & 1 deletion src/filesystem.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::webserver::database::ErrorWithStatus;
use crate::webserver::ErrorWithStatus;
use crate::webserver::{make_placeholder, Database};
use crate::AppState;
use anyhow::Context;
Expand Down
102 changes: 4 additions & 98 deletions src/webserver/database/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mod sql;
mod sql_pseudofunctions;

use actix_web::http::StatusCode;
use actix_web_httpauth::headers::authorization::Basic;
use anyhow::{anyhow, Context};
use futures_util::stream::{self, BoxStream, Stream};
use futures_util::StreamExt;
Expand All @@ -15,7 +14,8 @@ use std::time::Duration;
use crate::app_config::AppConfig;
pub use crate::file_cache::FileCache;
use crate::utils::add_value_to_map;
use crate::webserver::http::{RequestInfo, SingleOrVec};
use crate::webserver::database::sql_pseudofunctions::extract_req_param;
use crate::webserver::http::RequestInfo;
use crate::MIGRATIONS_DIR;
pub use sql::make_placeholder;
pub use sql::ParsedSqlFile;
Expand Down Expand Up @@ -162,86 +162,6 @@ fn bind_parameters<'a>(
Ok(stmt.statement.query_with(arguments))
}

fn extract_req_param<'a>(
param: &StmtParam,
request: &'a RequestInfo,
) -> anyhow::Result<Option<Cow<'a, str>>> {
Ok(match param {
StmtParam::Get(x) => request.get_variables.get(x).map(SingleOrVec::as_json_str),
StmtParam::Post(x) => request.post_variables.get(x).map(SingleOrVec::as_json_str),
StmtParam::GetOrPost(x) => request
.post_variables
.get(x)
.or_else(|| request.get_variables.get(x))
.map(SingleOrVec::as_json_str),
StmtParam::Cookie(x) => request.cookies.get(x).map(SingleOrVec::as_json_str),
StmtParam::Header(x) => request.headers.get(x).map(SingleOrVec::as_json_str),
StmtParam::Error(x) => anyhow::bail!("{}", x),
StmtParam::BasicAuthPassword => extract_basic_auth_password(request)
.map(Cow::Borrowed)
.map(Some)?,
StmtParam::BasicAuthUsername => extract_basic_auth_username(request)
.map(Cow::Borrowed)
.map(Some)?,
StmtParam::HashPassword(inner) => extract_req_param(inner, request)?
.map_or(Ok(None), |x| hash_password(&x).map(Cow::Owned).map(Some))?,
StmtParam::RandomString(len) => Some(Cow::Owned(random_string(*len))),
})
}

fn random_string(len: usize) -> String {
use rand::{distributions::Alphanumeric, Rng};
password_hash::rand_core::OsRng
.sample_iter(&Alphanumeric)
.take(len)
.map(char::from)
.collect()
}

fn hash_password(password: &str) -> anyhow::Result<String> {
let phf = argon2::Argon2::default();
let salt = password_hash::SaltString::generate(&mut password_hash::rand_core::OsRng);
let password_hash = &password_hash::PasswordHash::generate(phf, password, &salt)
.map_err(|e| anyhow!("Unable to hash password: {}", e))?;
Ok(password_hash.to_string())
}

#[derive(Debug)]
pub struct ErrorWithStatus {
pub status: StatusCode,
}
impl std::fmt::Display for ErrorWithStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.status)
}
}
impl std::error::Error for ErrorWithStatus {}

fn extract_basic_auth(request: &RequestInfo) -> anyhow::Result<&Basic> {
request
.basic_auth
.as_ref()
.ok_or_else(|| {
anyhow::Error::new(ErrorWithStatus {
status: StatusCode::UNAUTHORIZED,
})
})
.with_context(|| "Expected the user to be authenticated with HTTP basic auth")
}

fn extract_basic_auth_username(request: &RequestInfo) -> anyhow::Result<&str> {
Ok(extract_basic_auth(request)?.user_id())
}

fn extract_basic_auth_password(request: &RequestInfo) -> anyhow::Result<&str> {
let password = extract_basic_auth(request)?.password().ok_or_else(|| {
anyhow::Error::new(ErrorWithStatus {
status: StatusCode::UNAUTHORIZED,
})
})?;
Ok(password)
}

#[derive(Debug)]
pub enum DbItem {
Row(Value),
Expand Down Expand Up @@ -371,7 +291,7 @@ fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfi
}
struct PreparedStatement {
statement: AnyStatement<'static>,
parameters: Vec<StmtParam>,
parameters: Vec<sql_pseudofunctions::StmtParam>,
}

impl Display for PreparedStatement {
Expand All @@ -380,20 +300,6 @@ impl Display for PreparedStatement {
}
}

#[derive(Debug, PartialEq, Eq)]
enum StmtParam {
Get(String),
Post(String),
GetOrPost(String),
Cookie(String),
Header(String),
Error(String),
BasicAuthPassword,
BasicAuthUsername,
HashPassword(Box<StmtParam>),
RandomString(usize),
}

#[actix_web::test]
async fn test_row_to_json() -> anyhow::Result<()> {
use sqlx::Connection;
Expand Down
52 changes: 16 additions & 36 deletions src/webserver/database/sql.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::sql_pseudofunctions::{func_call_to_param, StmtParam};
use super::PreparedStatement;
use crate::file_cache::AsyncFromStrWithState;
use crate::webserver::database::StmtParam;
use crate::{AppState, Database};
use async_trait::async_trait;
use sqlparser::ast::{
Expand Down Expand Up @@ -155,27 +155,22 @@ impl ParameterExtractor {
}
}

fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam {
match func_name {
"cookie" => extract_single_quoted_string("cookie", arguments)
.map_or_else(StmtParam::Error, StmtParam::Cookie),
"header" => extract_single_quoted_string("header", arguments)
.map_or_else(StmtParam::Error, StmtParam::Header),
"basic_auth_username" => StmtParam::BasicAuthUsername,
"basic_auth_password" => StmtParam::BasicAuthPassword,
"hash_password" => extract_variable_argument("hash_password", arguments)
.map(Box::new)
.map_or_else(StmtParam::Error, StmtParam::HashPassword),
"random_string" => extract_integer("random_string", arguments)
.map_or_else(StmtParam::Error, StmtParam::RandomString),
unknown_name => StmtParam::Error(format!(
"Unknown function {unknown_name}({})",
FormatArguments(arguments)
)),
/** This is a helper struct to format a list of arguments for an error message. */
pub(super) struct FormatArguments<'a>(pub &'a [FunctionArg]);
impl std::fmt::Display for FormatArguments<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut args = self.0.iter();
if let Some(arg) = args.next() {
write!(f, "{arg}")?;
}
for arg in args {
write!(f, ", {arg}")?;
}
Ok(())
}
}

fn extract_single_quoted_string(
pub(super) fn extract_single_quoted_string(
func_name: &'static str,
arguments: &mut [FunctionArg],
) -> Result<String, String> {
Expand All @@ -190,7 +185,7 @@ fn extract_single_quoted_string(
}
}

fn extract_integer(
pub(super) fn extract_integer(
func_name: &'static str,
arguments: &mut [FunctionArg],
) -> Result<usize, String> {
Expand All @@ -205,22 +200,7 @@ fn extract_integer(
}
}

/** This is a helper struct to format a list of arguments for an error message. */
struct FormatArguments<'a>(&'a [FunctionArg]);
impl std::fmt::Display for FormatArguments<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut args = self.0.iter();
if let Some(arg) = args.next() {
write!(f, "{arg}")?;
}
for arg in args {
write!(f, ", {arg}")?;
}
Ok(())
}
}

fn extract_variable_argument(
pub(super) fn extract_variable_argument(
func_name: &'static str,
arguments: &mut [FunctionArg],
) -> Result<StmtParam, String> {
Expand Down
118 changes: 118 additions & 0 deletions src/webserver/database/sql_pseudofunctions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::borrow::Cow;

use actix_web::http::StatusCode;
use actix_web_httpauth::headers::authorization::Basic;
use sqlparser::ast::FunctionArg;

use crate::webserver::{
http::{RequestInfo, SingleOrVec},
ErrorWithStatus,
};

use super::sql::{
extract_integer, extract_single_quoted_string, extract_variable_argument, FormatArguments,
};
use anyhow::{anyhow, Context};

#[derive(Debug, PartialEq, Eq)]
pub(super) enum StmtParam {
Get(String),
Post(String),
GetOrPost(String),
Cookie(String),
Header(String),
Error(String),
BasicAuthPassword,
BasicAuthUsername,
HashPassword(Box<StmtParam>),
RandomString(usize),
}

pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam {
match func_name {
"cookie" => extract_single_quoted_string("cookie", arguments)
.map_or_else(StmtParam::Error, StmtParam::Cookie),
"header" => extract_single_quoted_string("header", arguments)
.map_or_else(StmtParam::Error, StmtParam::Header),
"basic_auth_username" => StmtParam::BasicAuthUsername,
"basic_auth_password" => StmtParam::BasicAuthPassword,
"hash_password" => extract_variable_argument("hash_password", arguments)
.map(Box::new)
.map_or_else(StmtParam::Error, StmtParam::HashPassword),
"random_string" => extract_integer("random_string", arguments)
.map_or_else(StmtParam::Error, StmtParam::RandomString),
unknown_name => StmtParam::Error(format!(
"Unknown function {unknown_name}({})",
FormatArguments(arguments)
)),
}
}

pub(super) fn extract_req_param<'a>(
param: &StmtParam,
request: &'a RequestInfo,
) -> anyhow::Result<Option<Cow<'a, str>>> {
Ok(match param {
StmtParam::Get(x) => request.get_variables.get(x).map(SingleOrVec::as_json_str),
StmtParam::Post(x) => request.post_variables.get(x).map(SingleOrVec::as_json_str),
StmtParam::GetOrPost(x) => request
.post_variables
.get(x)
.or_else(|| request.get_variables.get(x))
.map(SingleOrVec::as_json_str),
StmtParam::Cookie(x) => request.cookies.get(x).map(SingleOrVec::as_json_str),
StmtParam::Header(x) => request.headers.get(x).map(SingleOrVec::as_json_str),
StmtParam::Error(x) => anyhow::bail!("{}", x),
StmtParam::BasicAuthPassword => extract_basic_auth_password(request)
.map(Cow::Borrowed)
.map(Some)?,
StmtParam::BasicAuthUsername => extract_basic_auth_username(request)
.map(Cow::Borrowed)
.map(Some)?,
StmtParam::HashPassword(inner) => extract_req_param(inner, request)?
.map_or(Ok(None), |x| hash_password(&x).map(Cow::Owned).map(Some))?,
StmtParam::RandomString(len) => Some(Cow::Owned(random_string(*len))),
})
}

fn random_string(len: usize) -> String {
use rand::{distributions::Alphanumeric, Rng};
password_hash::rand_core::OsRng
.sample_iter(&Alphanumeric)
.take(len)
.map(char::from)
.collect()
}

fn hash_password(password: &str) -> anyhow::Result<String> {
let phf = argon2::Argon2::default();
let salt = password_hash::SaltString::generate(&mut password_hash::rand_core::OsRng);
let password_hash = &password_hash::PasswordHash::generate(phf, password, &salt)
.map_err(|e| anyhow!("Unable to hash password: {}", e))?;
Ok(password_hash.to_string())
}

fn extract_basic_auth_username(request: &RequestInfo) -> anyhow::Result<&str> {
Ok(extract_basic_auth(request)?.user_id())
}

fn extract_basic_auth_password(request: &RequestInfo) -> anyhow::Result<&str> {
let password = extract_basic_auth(request)?.password().ok_or_else(|| {
anyhow::Error::new(ErrorWithStatus {
status: StatusCode::UNAUTHORIZED,
})
})?;
Ok(password)
}

fn extract_basic_auth(request: &RequestInfo) -> anyhow::Result<&Basic> {
request
.basic_auth
.as_ref()
.ok_or_else(|| {
anyhow::Error::new(ErrorWithStatus {
status: StatusCode::UNAUTHORIZED,
})
})
.with_context(|| "Expected the user to be authenticated with HTTP basic auth")
}
12 changes: 12 additions & 0 deletions src/webserver/error_with_status.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use actix_web::http::StatusCode;

#[derive(Debug)]
pub struct ErrorWithStatus {
pub status: StatusCode,
}
impl std::fmt::Display for ErrorWithStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.status)
}
}
impl std::error::Error for ErrorWithStatus {}
3 changes: 2 additions & 1 deletion src/webserver/http.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::render::{HeaderContext, PageContext, RenderContext};
use crate::webserver::database::{stream_query_results, DbItem, ErrorWithStatus};
use crate::webserver::database::{stream_query_results, DbItem};
use crate::webserver::ErrorWithStatus;
use crate::{AppState, Config, ParsedSqlFile};
use actix_web::dev::{fn_service, ServiceFactory, ServiceRequest};
use actix_web::error::{ErrorInternalServerError, ErrorNotFound};
Expand Down
2 changes: 2 additions & 0 deletions src/webserver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub mod database;
pub mod error_with_status;
pub mod http;

pub use database::Database;
pub use error_with_status::ErrorWithStatus;

pub use database::apply_migrations;
pub use database::make_placeholder;
Expand Down

0 comments on commit 0061d13

Please sign in to comment.