diff --git a/CHANGELOG.md b/CHANGELOG.md index d0a78752..737f4e3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - 18 new icons available (see https://github.com/tabler/tabler-icons/releases/tag/v2.40.0) - Support multiple statements in [`on_connect.sql`](./configuration.md) in MySQL. - Randomize postgres prepared statement names to avoid name collisions. This should fix a bug where SQLPage would report errors like `prepared statement "sqlx_s_3" already exists` when using a connection pooler in front of a PostgreSQL database. + - Delegate statement preparation to sqlx. The logic of preparing statements and caching them for later reuse is now entirely delegated to the sql driver library (sqlx). This simplifies the code and logic inside sqlpage itself. More importantly, statements are now prepared in a streaming fashion when a file is first loaded, instead of all at once, which allows referencing a temporary table created at the start of a file in a later statement in the same file. ## 0.15.2 (2023-11-12) diff --git a/mssql/setup.sql b/mssql/setup.sql index 44dcf5aa..7fca01f3 100644 --- a/mssql/setup.sql +++ b/mssql/setup.sql @@ -9,4 +9,8 @@ GO CREATE LOGIN root WITH PASSWORD = 'Password123!'; CREATE USER root FOR LOGIN root; +GO + +GRANT CREATE TABLE TO root; +GRANT ALTER, DELETE, INSERT, SELECT, UPDATE ON SCHEMA::dbo TO root; GO \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index cff981c6..aa7a9295 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ impl AppState { let file_system = FileSystem::init(&config.web_root, &db).await; sql_file_cache.add_static( PathBuf::from("index.sql"), - ParsedSqlFile::new(&db, include_str!("../index.sql")).await, + ParsedSqlFile::new(&db, include_str!("../index.sql")), ); Ok(AppState { db, diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index d0e505be..83dc83ae 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -5,18 +5,17 @@ use serde_json::Value; use std::borrow::Cow; use std::collections::HashMap; -use super::sql::{ParsedSQLStatement, ParsedSqlFile}; +use super::sql::{ParsedSqlFile, ParsedStatement, StmtWithParams}; use crate::webserver::database::sql_pseudofunctions::extract_req_param; use crate::webserver::http::{RequestInfo, SingleOrVec}; use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use sqlx::pool::PoolConnection; -use sqlx::query::Query; -use sqlx::{AnyConnection, Arguments, Either, Executor, Row, Statement}; +use sqlx::{Any, AnyConnection, Arguments, Either, Executor, Row, Statement}; use super::sql_pseudofunctions::StmtParam; use super::sql_to_json::sql_to_json; -use super::{highlight_sql_error, Database, DbItem, PreparedStatement}; +use super::{highlight_sql_error, Database, DbItem}; impl Database { pub(crate) async fn prepare_with( @@ -41,22 +40,22 @@ pub fn stream_query_results<'a>( let mut connection_opt = None; for res in &sql_file.statements { match res { - ParsedSQLStatement::Statement(stmt) => { + ParsedStatement::StmtWithParams(stmt) => { let query = bind_parameters(stmt, request).await?; let connection = take_connection(db, &mut connection_opt).await?; - let mut stream = query.fetch_many(connection); + let mut stream = connection.fetch_many(query); while let Some(elem) = stream.next().await { let is_err = elem.is_err(); - yield parse_single_sql_result(stmt, elem); + yield parse_single_sql_result(&stmt.query, elem); if is_err { break; } } }, - ParsedSQLStatement::SetVariable { variable, value} => { + ParsedStatement::SetVariable { variable, value} => { let query = bind_parameters(value, request).await?; let connection = take_connection(db, &mut connection_opt).await?; - let row = query.fetch_optional(connection).await?; + let row = connection.fetch_optional(query).await?; let (vars, name) = vars_and_name(request, variable)?; if let Some(row) = row { vars.insert(name.clone(), row_to_varvalue(&row)); @@ -64,10 +63,10 @@ pub fn stream_query_results<'a>( vars.remove(&name); } }, - ParsedSQLStatement::StaticSimpleSelect(value) => { + ParsedStatement::StaticSimpleSelect(value) => { yield DbItem::Row(value.clone().into()) } - ParsedSQLStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)), + ParsedStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)), } } } @@ -132,10 +131,7 @@ async fn take_connection<'a, 'b>( } #[inline] -fn parse_single_sql_result( - stmt: &PreparedStatement, - res: sqlx::Result>, -) -> DbItem { +fn parse_single_sql_result(sql: &str, res: sqlx::Result>) -> DbItem { match res { Ok(Either::Right(r)) => DbItem::Row(super::sql_to_json::row_to_json(&r)), Ok(Either::Left(res)) => { @@ -144,7 +140,7 @@ fn parse_single_sql_result( } Err(err) => DbItem::Error(highlight_sql_error( "Failed to execute SQL statement", - stmt.statement.sql(), + sql, err, )), } @@ -159,18 +155,43 @@ fn clone_anyhow_err(err: &anyhow::Error) -> anyhow::Error { } async fn bind_parameters<'a>( - stmt: &'a PreparedStatement, + stmt: &'a StmtWithParams, request: &'a RequestInfo, -) -> anyhow::Result>> { +) -> anyhow::Result> { + let sql = stmt.query.as_str(); let mut arguments = AnyArguments::default(); - for param in &stmt.parameters { + for param in &stmt.params { let argument = extract_req_param(param, request).await?; - log::debug!("Binding value {:?} in statement {}", &argument, stmt); + log::debug!("Binding value {:?} in statement {}", &argument, stmt.query); match argument { None => arguments.add(None::), Some(Cow::Owned(s)) => arguments.add(s), Some(Cow::Borrowed(v)) => arguments.add(v), } } - Ok(stmt.statement.query_with(arguments)) + Ok(StatementWithParams { sql, arguments }) +} + +pub struct StatementWithParams<'a> { + sql: &'a str, + arguments: AnyArguments<'a>, +} + +impl<'q> sqlx::Execute<'q, Any> for StatementWithParams<'q> { + fn sql(&self) -> &'q str { + self.sql + } + + fn statement(&self) -> Option<&>::Statement> { + None + } + + fn take_arguments(&mut self) -> Option<>::Arguments> { + Some(std::mem::take(&mut self.arguments)) + } + + fn persistent(&self) -> bool { + // Let sqlx create a prepared statement the first time it is executed, and then reuse it. + true + } } diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 722a677f..24e1d77e 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -18,18 +18,6 @@ pub enum DbItem { Error(anyhow::Error), } -struct PreparedStatement { - statement: sqlx::any::AnyStatement<'static>, - parameters: Vec, -} - -impl std::fmt::Display for PreparedStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use sqlx::Statement; - write!(f, "{}", self.statement.sql()) - } -} - #[must_use] pub fn highlight_sql_error( context: &str, diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 6fa0eec7..d22a24e5 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -1,5 +1,4 @@ use super::sql_pseudofunctions::{func_call_to_param, StmtParam}; -use super::PreparedStatement; use crate::file_cache::AsyncFromStrWithState; use crate::utils::add_value_to_map; use crate::{AppState, Database}; @@ -13,101 +12,51 @@ use sqlparser::dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, use sqlparser::parser::{Parser, ParserError}; use sqlparser::tokenizer::Token::{SemiColon, EOF}; use sqlparser::tokenizer::Tokenizer; -use sqlx::any::{AnyKind, AnyTypeInfo}; -use sqlx::Postgres; +use sqlx::any::AnyKind; use std::fmt::Write; use std::ops::ControlFlow; #[derive(Default)] pub struct ParsedSqlFile { - pub(super) statements: Vec, -} - -pub(super) enum ParsedSQLStatement { - Statement(PreparedStatement), - StaticSimpleSelect(serde_json::Map), - Error(anyhow::Error), - SetVariable { - variable: StmtParam, - value: PreparedStatement, - }, + pub(super) statements: Vec, } impl ParsedSqlFile { - pub async fn new(db: &Database, sql: &str) -> ParsedSqlFile { + #[must_use] + pub fn new(db: &Database, sql: &str) -> ParsedSqlFile { let dialect = dialect_for_db(db.connection.any_kind()); let parsed_statements = match parse_sql(dialect.as_ref(), sql) { Ok(parsed) => parsed, Err(err) => return Self::from_err(err), }; - let mut statements = Vec::with_capacity(8); - for parsed in parsed_statements { - statements.push(match parsed { - ParsedStatement::StaticSimpleSelect(s) => ParsedSQLStatement::StaticSimpleSelect(s), - ParsedStatement::Error(e) => ParsedSQLStatement::Error(e), - ParsedStatement::StmtWithParams(stmt_with_params) => { - prepare_query_with_params(db, stmt_with_params).await - } - ParsedStatement::SetVariable { variable, value } => { - match prepare_query_with_params(db, value).await { - ParsedSQLStatement::Statement(value) => { - ParsedSQLStatement::SetVariable { variable, value } - } - err => err, - } - } - }); - } - statements.shrink_to_fit(); + let statements = parsed_statements.collect(); ParsedSqlFile { statements } } fn from_err(e: impl Into) -> Self { Self { - statements: vec![ParsedSQLStatement::Error( + statements: vec![ParsedStatement::Error( e.into().context("SQLPage could not parse the SQL file"), )], } } } -async fn prepare_query_with_params( - db: &Database, - StmtWithParams { query, params }: StmtWithParams, -) -> ParsedSQLStatement { - let param_types = get_param_types(¶ms); - match db.prepare_with(&query, ¶m_types).await { - Ok(statement) => { - log::debug!("Successfully prepared SQL statement '{query}'"); - ParsedSQLStatement::Statement(PreparedStatement { - statement, - parameters: params, - }) - } - Err(err) => { - log::warn!("Failed to prepare {query:?}: {err:#}"); - ParsedSQLStatement::Error(err.context(format!( - "The database returned an error when preparing this SQL statement: {query}" - ))) - } - } -} - #[async_trait(? Send)] impl AsyncFromStrWithState for ParsedSqlFile { async fn from_str_with_state(app_state: &AppState, source: &str) -> anyhow::Result { - Ok(ParsedSqlFile::new(&app_state.db, source).await) + Ok(ParsedSqlFile::new(&app_state.db, source)) } } #[derive(Debug, PartialEq)] -struct StmtWithParams { - query: String, - params: Vec, +pub(super) struct StmtWithParams { + pub query: String, + pub params: Vec, } #[derive(Debug)] -enum ParsedStatement { +pub(super) enum ParsedStatement { StmtWithParams(StmtWithParams), StaticSimpleSelect(serde_json::Map), SetVariable { @@ -201,13 +150,6 @@ fn kind_of_dialect(dialect: &dyn Dialect) -> AnyKind { } } -fn get_param_types(parameters: &[StmtParam]) -> Vec { - parameters - .iter() - .map(|_p| >::type_info().into()) - .collect() -} - fn map_param(mut name: String) -> StmtParam { if name.is_empty() { return StmtParam::GetOrPost(name); diff --git a/tests/sql_test_files/it_works_create_table.sql b/tests/sql_test_files/it_works_create_table.sql new file mode 100644 index 00000000..e3ec8f39 --- /dev/null +++ b/tests/sql_test_files/it_works_create_table.sql @@ -0,0 +1,7 @@ +drop table if exists my_tmp_store; +create table my_tmp_store(x varchar(100)); + +insert into my_tmp_store(x) values ('It works !'); + +select 'card' as component; +select x as description from my_tmp_store; \ No newline at end of file