Skip to content

Commit

Permalink
sqlite_worker: refactor locking (#2875)
Browse files Browse the repository at this point in the history
  • Loading branch information
spolu authored Dec 14, 2023
1 parent 895cdd4 commit 0c0d70e
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 155 deletions.
92 changes: 48 additions & 44 deletions core/bin/sqlite_worker.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};

use anyhow::{anyhow, Result};
use axum::{
extract::{self, Path},
extract,
response::Json,
routing::{get, post},
Extension, Json, Router,
Router,
};
use dust::{
databases::database::DatabaseRow,
databases::database::DatabaseTable,
sqlite_workers::store::DatabasesStore,
sqlite_workers::{sqlite_database::SqliteDatabase, store},
utils::{self, error_response, APIResponse},
};
use dust::{databases::database::DatabaseTable, sqlite_workers::store::DatabasesStore};
use hyper::{Body, Client, Request, StatusCode};
use serde::Deserialize;
use serde_json::json;
use tokio::{
signal::unix::{signal, SignalKind},
sync::Mutex,
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::Mutex;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;

// Duration after which a database is considered inactive and can be removed from the registry.
const DATABASE_TIMEOUT_DURATION: Duration = std::time::Duration::from_secs(5 * 60); // 5 minutes

struct DatabaseEntry {
database: SqliteDatabase,
database: Arc<Mutex<SqliteDatabase>>,
last_accessed: Instant,
}

Expand All @@ -47,7 +46,7 @@ struct WorkerState {
impl WorkerState {
fn new(databases_store: Box<dyn store::DatabasesStore + Sync + Send>) -> Self {
Self {
databases_store: databases_store,
databases_store,

// TODO: store an instant of the last access for each DB.
registry: Arc::new(Mutex::new(HashMap::new())),
Expand Down Expand Up @@ -146,28 +145,31 @@ async fn index() -> &'static str {
// Databases

#[derive(Deserialize)]
struct DbQueryBody {
struct DbQueryPayload {
query: String,
tables: Vec<DatabaseTable>,
}

async fn db_query(
Path(db_id): Path<String>,
Json(payload): Json<DbQueryBody>,
Extension(state): Extension<Arc<WorkerState>>,
extract::Path(db_id): extract::Path<String>,
extract::Json(payload): Json<DbQueryPayload>,
extract::Extension(state): extract::Extension<Arc<WorkerState>>,
) -> (StatusCode, Json<APIResponse>) {
let mut registry = state.registry.lock().await;
let database = {
let mut registry = state.registry.lock().await;
let entry = registry
.entry(db_id.clone())
.or_insert_with(|| DatabaseEntry {
database: Arc::new(Mutex::new(SqliteDatabase::new(db_id))),
last_accessed: Instant::now(),
});
entry.last_accessed = Instant::now();
entry.database.clone()
};

let entry = registry
.entry(db_id.clone())
.or_insert_with(|| DatabaseEntry {
database: SqliteDatabase::new(db_id),
last_accessed: Instant::now(),
});
let mut guard = database.lock().await;

entry.last_accessed = Instant::now();
match entry
.database
match guard
.init(payload.tables, state.databases_store.clone())
.await
{
Expand All @@ -182,7 +184,7 @@ async fn db_query(
}
}

match entry.database.query(payload.query).await {
match guard.query(&payload.query).await {
Ok(results) => (
axum::http::StatusCode::OK,
Json(APIResponse {
Expand All @@ -208,18 +210,20 @@ struct DatabasesRowsUpsertPayload {
async fn databases_rows_upsert(
extract::Path((database_id, table_id)): extract::Path<(String, String)>,
extract::Json(payload): extract::Json<DatabasesRowsUpsertPayload>,
Extension(state): Extension<Arc<WorkerState>>,
extract::Extension(state): extract::Extension<Arc<WorkerState>>,
) -> (StatusCode, Json<APIResponse>) {
// Terminate the running DB thread if it exists.
let mut registry = state.registry.lock().await;
match registry.get(&database_id) {
Some(_) => {
// Removing the DB from the registry will terminate the thread once pending queries are
// finished.
registry.remove(&database_id);
// Terminate (invalidate) the DB if it exists.
{
let mut registry = state.registry.lock().await;
match registry.get(&database_id) {
Some(_) => {
// Removing the DB from the registry will destroy the SQLite connection and hence the
// in-memory DB.
registry.remove(&database_id);
}
None => (),
}
None => (),
}
};

let truncate = match payload.truncate {
Some(v) => v,
Expand Down Expand Up @@ -258,7 +262,7 @@ struct DatabasesRowsListQuery {
async fn databases_rows_list(
extract::Path((database_id, table_id)): extract::Path<(String, String)>,
extract::Query(query): extract::Query<DatabasesRowsListQuery>,
Extension(state): Extension<Arc<WorkerState>>,
extract::Extension(state): extract::Extension<Arc<WorkerState>>,
) -> (StatusCode, Json<APIResponse>) {
let limit_offset = match (query.limit, query.offset) {
(Some(limit), Some(offset)) => Some((limit, offset)),
Expand Down Expand Up @@ -291,7 +295,7 @@ async fn databases_rows_list(
async fn databases_row_retrieve(
extract::Path((database_id, table_id)): extract::Path<(String, String)>,
extract::Path(row_id): extract::Path<String>,
Extension(state): Extension<Arc<WorkerState>>,
extract::Extension(state): extract::Extension<Arc<WorkerState>>,
) -> (StatusCode, Json<APIResponse>) {
match state
.databases_store
Expand Down
Loading

0 comments on commit 0c0d70e

Please sign in to comment.