Skip to content

Commit

Permalink
delegate statement preparation to sqlx (#135)
Browse files Browse the repository at this point in the history
* delegate statement preparation to sqlx

The logic of preparing statements and caching them for later reuse is now entirely delegated to the sql driver library (sqlx).
This simplifies the code and logic inside sqlpage itself.

 More importantly, statements are now prepared in a streaming fashion when a file is first loaded, instead of all at once, which allows referencing a temporary table created at the start of a file in a later statement in the same file.

 fixes #100

* remove temporary table usage

mssql does not have "create temp table"

* mssql permissions fix
  • Loading branch information
lovasoa authored Nov 19, 2023
1 parent da167bb commit edd94df
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- 18 new icons available (see https://github.com/tabler/tabler-icons/releases/tag/v2.40.0)
- Support multiple statements in [`on_connect.sql`](./configuration.md) in MySQL.
- Randomize postgres prepared statement names to avoid name collisions. This should fix a bug where SQLPage would report errors like `prepared statement "sqlx_s_3" already exists` when using a connection pooler in front of a PostgreSQL database.
- Delegate statement preparation to sqlx. The logic of preparing statements and caching them for later reuse is now entirely delegated to the sql driver library (sqlx). This simplifies the code and logic inside sqlpage itself. More importantly, statements are now prepared in a streaming fashion when a file is first loaded, instead of all at once, which allows referencing a temporary table created at the start of a file in a later statement in the same file.

## 0.15.2 (2023-11-12)

Expand Down
4 changes: 4 additions & 0 deletions mssql/setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ GO

CREATE LOGIN root WITH PASSWORD = 'Password123!';
CREATE USER root FOR LOGIN root;
GO

GRANT CREATE TABLE TO root;
GRANT ALTER, DELETE, INSERT, SELECT, UPDATE ON SCHEMA::dbo TO root;
GO
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl AppState {
let file_system = FileSystem::init(&config.web_root, &db).await;
sql_file_cache.add_static(
PathBuf::from("index.sql"),
ParsedSqlFile::new(&db, include_str!("../index.sql")).await,
ParsedSqlFile::new(&db, include_str!("../index.sql")),
);
Ok(AppState {
db,
Expand Down
63 changes: 42 additions & 21 deletions src/webserver/database/execute_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashMap;

use super::sql::{ParsedSQLStatement, ParsedSqlFile};
use super::sql::{ParsedSqlFile, ParsedStatement, StmtWithParams};
use crate::webserver::database::sql_pseudofunctions::extract_req_param;
use crate::webserver::http::{RequestInfo, SingleOrVec};

use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo};
use sqlx::pool::PoolConnection;
use sqlx::query::Query;
use sqlx::{AnyConnection, Arguments, Either, Executor, Row, Statement};
use sqlx::{Any, AnyConnection, Arguments, Either, Executor, Row, Statement};

use super::sql_pseudofunctions::StmtParam;
use super::sql_to_json::sql_to_json;
use super::{highlight_sql_error, Database, DbItem, PreparedStatement};
use super::{highlight_sql_error, Database, DbItem};

impl Database {
pub(crate) async fn prepare_with(
Expand All @@ -41,33 +40,33 @@ pub fn stream_query_results<'a>(
let mut connection_opt = None;
for res in &sql_file.statements {
match res {
ParsedSQLStatement::Statement(stmt) => {
ParsedStatement::StmtWithParams(stmt) => {
let query = bind_parameters(stmt, request).await?;
let connection = take_connection(db, &mut connection_opt).await?;
let mut stream = query.fetch_many(connection);
let mut stream = connection.fetch_many(query);
while let Some(elem) = stream.next().await {
let is_err = elem.is_err();
yield parse_single_sql_result(stmt, elem);
yield parse_single_sql_result(&stmt.query, elem);
if is_err {
break;
}
}
},
ParsedSQLStatement::SetVariable { variable, value} => {
ParsedStatement::SetVariable { variable, value} => {
let query = bind_parameters(value, request).await?;
let connection = take_connection(db, &mut connection_opt).await?;
let row = query.fetch_optional(connection).await?;
let row = connection.fetch_optional(query).await?;
let (vars, name) = vars_and_name(request, variable)?;
if let Some(row) = row {
vars.insert(name.clone(), row_to_varvalue(&row));
} else {
vars.remove(&name);
}
},
ParsedSQLStatement::StaticSimpleSelect(value) => {
ParsedStatement::StaticSimpleSelect(value) => {
yield DbItem::Row(value.clone().into())
}
ParsedSQLStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)),
ParsedStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)),
}
}
}
Expand Down Expand Up @@ -132,10 +131,7 @@ async fn take_connection<'a, 'b>(
}

#[inline]
fn parse_single_sql_result(
stmt: &PreparedStatement,
res: sqlx::Result<Either<AnyQueryResult, AnyRow>>,
) -> DbItem {
fn parse_single_sql_result(sql: &str, res: sqlx::Result<Either<AnyQueryResult, AnyRow>>) -> DbItem {
match res {
Ok(Either::Right(r)) => DbItem::Row(super::sql_to_json::row_to_json(&r)),
Ok(Either::Left(res)) => {
Expand All @@ -144,7 +140,7 @@ fn parse_single_sql_result(
}
Err(err) => DbItem::Error(highlight_sql_error(
"Failed to execute SQL statement",
stmt.statement.sql(),
sql,
err,
)),
}
Expand All @@ -159,18 +155,43 @@ fn clone_anyhow_err(err: &anyhow::Error) -> anyhow::Error {
}

async fn bind_parameters<'a>(
stmt: &'a PreparedStatement,
stmt: &'a StmtWithParams,
request: &'a RequestInfo,
) -> anyhow::Result<Query<'a, sqlx::Any, AnyArguments<'a>>> {
) -> anyhow::Result<StatementWithParams<'a>> {
let sql = stmt.query.as_str();
let mut arguments = AnyArguments::default();
for param in &stmt.parameters {
for param in &stmt.params {
let argument = extract_req_param(param, request).await?;
log::debug!("Binding value {:?} in statement {}", &argument, stmt);
log::debug!("Binding value {:?} in statement {}", &argument, stmt.query);
match argument {
None => arguments.add(None::<String>),
Some(Cow::Owned(s)) => arguments.add(s),
Some(Cow::Borrowed(v)) => arguments.add(v),
}
}
Ok(stmt.statement.query_with(arguments))
Ok(StatementWithParams { sql, arguments })
}

pub struct StatementWithParams<'a> {
sql: &'a str,
arguments: AnyArguments<'a>,
}

impl<'q> sqlx::Execute<'q, Any> for StatementWithParams<'q> {
fn sql(&self) -> &'q str {
self.sql
}

fn statement(&self) -> Option<&<Any as sqlx::database::HasStatement<'q>>::Statement> {
None
}

fn take_arguments(&mut self) -> Option<<Any as sqlx::database::HasArguments<'q>>::Arguments> {
Some(std::mem::take(&mut self.arguments))
}

fn persistent(&self) -> bool {
// Let sqlx create a prepared statement the first time it is executed, and then reuse it.
true
}
}
12 changes: 0 additions & 12 deletions src/webserver/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ pub enum DbItem {
Error(anyhow::Error),
}

struct PreparedStatement {
statement: sqlx::any::AnyStatement<'static>,
parameters: Vec<sql_pseudofunctions::StmtParam>,
}

impl std::fmt::Display for PreparedStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use sqlx::Statement;
write!(f, "{}", self.statement.sql())
}
}

#[must_use]
pub fn highlight_sql_error(
context: &str,
Expand Down
80 changes: 11 additions & 69 deletions src/webserver/database/sql.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::sql_pseudofunctions::{func_call_to_param, StmtParam};
use super::PreparedStatement;
use crate::file_cache::AsyncFromStrWithState;
use crate::utils::add_value_to_map;
use crate::{AppState, Database};
Expand All @@ -13,101 +12,51 @@ use sqlparser::dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect,
use sqlparser::parser::{Parser, ParserError};
use sqlparser::tokenizer::Token::{SemiColon, EOF};
use sqlparser::tokenizer::Tokenizer;
use sqlx::any::{AnyKind, AnyTypeInfo};
use sqlx::Postgres;
use sqlx::any::AnyKind;
use std::fmt::Write;
use std::ops::ControlFlow;

#[derive(Default)]
pub struct ParsedSqlFile {
pub(super) statements: Vec<ParsedSQLStatement>,
}

pub(super) enum ParsedSQLStatement {
Statement(PreparedStatement),
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
Error(anyhow::Error),
SetVariable {
variable: StmtParam,
value: PreparedStatement,
},
pub(super) statements: Vec<ParsedStatement>,
}

impl ParsedSqlFile {
pub async fn new(db: &Database, sql: &str) -> ParsedSqlFile {
#[must_use]
pub fn new(db: &Database, sql: &str) -> ParsedSqlFile {
let dialect = dialect_for_db(db.connection.any_kind());
let parsed_statements = match parse_sql(dialect.as_ref(), sql) {
Ok(parsed) => parsed,
Err(err) => return Self::from_err(err),
};
let mut statements = Vec::with_capacity(8);
for parsed in parsed_statements {
statements.push(match parsed {
ParsedStatement::StaticSimpleSelect(s) => ParsedSQLStatement::StaticSimpleSelect(s),
ParsedStatement::Error(e) => ParsedSQLStatement::Error(e),
ParsedStatement::StmtWithParams(stmt_with_params) => {
prepare_query_with_params(db, stmt_with_params).await
}
ParsedStatement::SetVariable { variable, value } => {
match prepare_query_with_params(db, value).await {
ParsedSQLStatement::Statement(value) => {
ParsedSQLStatement::SetVariable { variable, value }
}
err => err,
}
}
});
}
statements.shrink_to_fit();
let statements = parsed_statements.collect();
ParsedSqlFile { statements }
}

fn from_err(e: impl Into<anyhow::Error>) -> Self {
Self {
statements: vec![ParsedSQLStatement::Error(
statements: vec![ParsedStatement::Error(
e.into().context("SQLPage could not parse the SQL file"),
)],
}
}
}

async fn prepare_query_with_params(
db: &Database,
StmtWithParams { query, params }: StmtWithParams,
) -> ParsedSQLStatement {
let param_types = get_param_types(&params);
match db.prepare_with(&query, &param_types).await {
Ok(statement) => {
log::debug!("Successfully prepared SQL statement '{query}'");
ParsedSQLStatement::Statement(PreparedStatement {
statement,
parameters: params,
})
}
Err(err) => {
log::warn!("Failed to prepare {query:?}: {err:#}");
ParsedSQLStatement::Error(err.context(format!(
"The database returned an error when preparing this SQL statement: {query}"
)))
}
}
}

#[async_trait(? Send)]
impl AsyncFromStrWithState for ParsedSqlFile {
async fn from_str_with_state(app_state: &AppState, source: &str) -> anyhow::Result<Self> {
Ok(ParsedSqlFile::new(&app_state.db, source).await)
Ok(ParsedSqlFile::new(&app_state.db, source))
}
}

#[derive(Debug, PartialEq)]
struct StmtWithParams {
query: String,
params: Vec<StmtParam>,
pub(super) struct StmtWithParams {
pub query: String,
pub params: Vec<StmtParam>,
}

#[derive(Debug)]
enum ParsedStatement {
pub(super) enum ParsedStatement {
StmtWithParams(StmtWithParams),
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
SetVariable {
Expand Down Expand Up @@ -201,13 +150,6 @@ fn kind_of_dialect(dialect: &dyn Dialect) -> AnyKind {
}
}

fn get_param_types(parameters: &[StmtParam]) -> Vec<AnyTypeInfo> {
parameters
.iter()
.map(|_p| <str as sqlx::Type<Postgres>>::type_info().into())
.collect()
}

fn map_param(mut name: String) -> StmtParam {
if name.is_empty() {
return StmtParam::GetOrPost(name);
Expand Down
7 changes: 7 additions & 0 deletions tests/sql_test_files/it_works_create_table.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
drop table if exists my_tmp_store;
create table my_tmp_store(x varchar(100));

insert into my_tmp_store(x) values ('It works !');

select 'card' as component;
select x as description from my_tmp_store;

0 comments on commit edd94df

Please sign in to comment.