Skip to content

Commit

Permalink
Prevent panics from escaping SQLite custom functions
Browse files Browse the repository at this point in the history
This is essentially the same treatment that custom aggregate functions
got in ee2f792.
  • Loading branch information
z33ky committed Sep 14, 2020
1 parent b68e252 commit d47bb60
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion diesel/src/sqlite/connection/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
mut f: F,
) -> QueryResult<()>
where
F: FnMut(&RawConnection, Args) -> Ret + Send + 'static,
F: FnMut(&RawConnection, Args) -> Ret + Send + 'static + std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
Expand Down
4 changes: 2 additions & 2 deletions diesel/src/sqlite/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ impl SqliteConnection {
mut f: F,
) -> QueryResult<()>
where
F: FnMut(Args) -> Ret + Send + 'static,
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
F: FnMut(Args) -> Ret + Send + 'static + std::panic::RefUnwindSafe,
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
Expand Down
27 changes: 20 additions & 7 deletions diesel/src/sqlite/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ impl RawConnection {
where
F: FnMut(&Self, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
+ Send
+ 'static,
+ 'static
+ std::panic::RefUnwindSafe,
{
let fn_name = Self::get_fn_name(fn_name)?;
let flags = Self::get_flags(deterministic);
Expand Down Expand Up @@ -250,7 +251,8 @@ extern "C" fn run_custom_function<F>(
) where
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
+ Send
+ 'static,
+ 'static
+ std::panic::RefUnwindSafe,
{
static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
Expand All @@ -276,15 +278,26 @@ extern "C" fn run_custom_function<F>(
return;
}
};
match f(&conn, args) {
Ok(value) => value.result_of(ctx),
Err(e) => {

let mut f = std::panic::AssertUnwindSafe(f);
let result = std::panic::catch_unwind(move || {
use std::ops::DerefMut as _;
let result = f.deref_mut()(&conn, args);
mem::forget(conn);
result
});

match result {
Ok(Ok(value)) => value.result_of(ctx),
Ok(Err(e)) => {
let msg = e.to_string();
context_error_str(ctx, &msg);
},
Err(_) => {
let msg = format!("{} panicked", std::any::type_name::<F>());
unsafe { context_error_str(ctx, &msg) };
}
}

mem::forget(conn);
}
}

Expand Down
10 changes: 6 additions & 4 deletions diesel_derives/src/sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,10 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
f: F,
) -> QueryResult<()>
where
F: Fn(#(#arg_name,)*) -> Ret + Send + 'static,
F: Fn(#(#arg_name,)*) -> Ret + Send + 'static + ::std::panic::RefUnwindSafe,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
::std::panic::UnwindSafe,
Ret: ToSql<#return_type, Sqlite>,
{
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
Expand All @@ -289,9 +290,10 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
mut f: F,
) -> QueryResult<()>
where
F: FnMut(#(#arg_name,)*) -> Ret + Send + 'static,
F: FnMut(#(#arg_name,)*) -> Ret + Send + 'static + ::std::panic::RefUnwindSafe,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
::std::panic::UnwindSafe,
Ret: ToSql<#return_type, Sqlite>,
{
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
Expand Down

0 comments on commit d47bb60

Please sign in to comment.