Skip to content

Commit

Permalink
Reduce scopes of unsafe in sqlite aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
z33ky committed Sep 11, 2020
1 parent 09c1085 commit d694a55
Showing 1 changed file with 46 additions and 45 deletions.
91 changes: 46 additions & 45 deletions diesel/src/sqlite/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
unsafe {
let aggregate_context = unsafe {
// This block of unsafe code makes the following assumptions:
//
// * sqlite3_aggregate_context allocates sizeof::<OptionalAggregator<A>>
Expand Down Expand Up @@ -332,26 +332,26 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
null_aggregate_context_error(ctx);
return;
}
};
}
};

let args = slice::from_raw_parts(value_ptr, num_args as _);
let args = build_sql_function_args::<ArgsSqlType, Args>(args);
let mut aggregator = std::panic::AssertUnwindSafe(aggregator);
let result = args
.map(|args| std::panic::catch_unwind(move || Ok(aggregator.step(args))))
.unwrap_or_else(|e| Ok(Err(e)));
match result {
Ok(Ok(())) => (),
Ok(Err(e)) => {
let msg = e.to_string();
context_error_str(ctx, &msg);
}
Err(_) => {
let msg = format!("{}::step() panicked", std::any::type_name::<A>());
context_error_str(ctx, &msg);
}
};
}
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
let args = build_sql_function_args::<ArgsSqlType, Args>(args);
let mut aggregator = std::panic::AssertUnwindSafe(aggregator);
let result = args
.map(|args| std::panic::catch_unwind(move || Ok(aggregator.step(args))))
.unwrap_or_else(|e| Ok(Err(e)));
match result {
Ok(Ok(())) => (),
Ok(Err(e)) => {
let msg = e.to_string();
unsafe { context_error_str(ctx, &msg) };
}
Err(_) => {
let msg = format!("{}::step() panicked", std::any::type_name::<A>());
unsafe { context_error_str(ctx, &msg) };
}
};
}

extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
Expand All @@ -362,40 +362,41 @@ extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret,
Ret: ToSql<RetSqlType, Sqlite>,
Sqlite: HasSqlType<RetSqlType>,
{
unsafe {
let aggregate_context = unsafe {
// Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
// allocations occur, a null pointer is returned in this case
// See: https://www.sqlite.org/c3ref/aggregate_context.html
//
// For the reasoning about the safety of the OptionalAggregator handling
// see the comment in run_aggregator_step_function.
let aggregate_context = ffi::sqlite3_aggregate_context(ctx, 0);
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
let aggregator = match aggregate_context {
Some(ref mut a) => match std::mem::replace(a.as_mut(), OptionalAggregator::None) {
OptionalAggregator::Some(agg) => Some(agg),
OptionalAggregator::None => {
eprintln!("We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer");
std::process::abort();
}
},
None => None,
};

let result = std::panic::catch_unwind(|| A::finalize(aggregator))
.map(process_sql_function_result::<RetSqlType, Ret>);
ffi::sqlite3_aggregate_context(ctx, 0)
};

match result {
Ok(Ok(value)) => value.result_of(ctx),
Ok(Err(e)) => {
let msg = e.to_string();
context_error_str(ctx, &msg);
}
Err(_) => {
let msg = format!("{}::finalize() panicked", std::any::type_name::<A>());
context_error_str(ctx, &msg);
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
let aggregator = aggregate_context.as_mut().map(|a| {
let a = unsafe { a.as_mut() };
match std::mem::replace(a, OptionalAggregator::None) {
OptionalAggregator::Some(agg) => agg,
OptionalAggregator::None => {
eprintln!("We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer");
std::process::abort();
}
}
});

let result = std::panic::catch_unwind(|| A::finalize(aggregator))
.map(process_sql_function_result::<RetSqlType, Ret>);

match result {
Ok(Ok(value)) => value.result_of(ctx),
Ok(Err(e)) => {
let msg = e.to_string();
unsafe { context_error_str(ctx, &msg) };
}
Err(_) => {
let msg = format!("{}::finalize() panicked", std::any::type_name::<A>());
unsafe { context_error_str(ctx, &msg) };
}
}
}

Expand Down

0 comments on commit d694a55

Please sign in to comment.