Skip to content

Commit

Permalink
Add UI state persistence to DB
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasPickering committed Dec 6, 2023
1 parent 7f27875 commit 0abb0b3
Show file tree
Hide file tree
Showing 20 changed files with 564 additions and 170 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Mostly for development purposes
- Add collection ID/path to help modal ([#59](https://github.com/LucasPickering/slumber/issues/59))
- Also add collection ID to terminal title
- Persist UI state between sessions ([#39](https://github.com/LucasPickering/slumber/issues/39))

### Changed

Expand Down
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# latest watchexec we can get rid of this.
# https://github.com/watchexec/cargo-watch/issues/269

RUST_LOG=slumber=trace watchexec --restart \
RUST_LOG=${RUST_LOG:-slumber=trace} watchexec --restart \
--watch Cargo.toml --watch Cargo.lock --watch src/ \
-- cargo run \
-- $@
14 changes: 14 additions & 0 deletions src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ pub struct Profile {
)]
pub struct ProfileId(String);

/// Needed for persistence loading
impl PartialEq<Profile> for ProfileId {
fn eq(&self, other: &Profile) -> bool {
self == &other.id
}
}

/// The value type of a profile's data mapping
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
Expand Down Expand Up @@ -136,6 +143,13 @@ pub struct RequestRecipe {
)]
pub struct RequestRecipeId(String);

/// Needed for persistence loading
impl PartialEq<RequestRecipe> for RequestRecipeId {
fn eq(&self, other: &RequestRecipe) -> bool {
self == &other.id
}
}

/// A chain is a means to data from one response in another request. The chain
/// is the middleman: it defines where and how to pull the value, then recipes
/// can use it in a template via `{{chains.<chain_id>}}`.
Expand Down
177 changes: 112 additions & 65 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::{
collection::{CollectionId, RequestRecipeId},
http::{Request, RequestId, RequestRecord, Response},
http::{RequestId, RequestRecord},
util::{Directory, ResultExt},
};
use anyhow::Context;
Expand All @@ -12,29 +12,30 @@ use rusqlite::{
Connection, OptionalExtension, Row, ToSql,
};
use rusqlite_migration::{Migrations, M};
use std::{ops::Deref, path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
use serde::{de::DeserializeOwned, Serialize};
use std::{
fmt::{Debug, Display},
ops::Deref,
path::PathBuf,
sync::{Arc, Mutex},
};
use tracing::debug;
use uuid::Uuid;

/// A SQLite database for persisting data. Generally speaking, any error that
/// occurs *after* opening the DB connection should be an internal bug, but
/// should be shown to the user whenever possible. All operations are async
/// to enable concurrent accesses to yield. Do not call block on thisfrom the
/// draw phase; instead, cache the results in UI state for as long as they're
/// needed.
/// should be shown to the user whenever possible. All operations are blocking,
/// to enable calling from the view code. Do not call on every frame though,
/// cache results in UI state for as long as they're needed.
///
/// This uses an `Arc` internally, so it's safe and cheap to clone.
///
/// Note: Despite all the operations being async, the actual database isn't
/// async. Each operation will asynchronously wait for the connection mutex,
/// then block while performing the operation. This is just a shortcut, if it
/// becomes a bottleneck we can change that.
#[derive(Clone, Debug)]
pub struct Database {
/// Data is stored in a sqlite DB. Mutex is needed for multi-threaded
/// access. This is a bottleneck but the access rate should be so low that
/// it doesn't matter.
/// it doesn't matter. If it does become a bottleneck, we could spawn
/// one connection per thread, but the code would be a bit more
/// complicated.
connection: Arc<Mutex<Connection>>,
}

Expand All @@ -61,36 +62,48 @@ impl Database {

/// Apply database migrations
fn migrate(connection: &mut Connection) -> anyhow::Result<()> {
let migrations = Migrations::new(vec![M::up(
// The request state kind is a bit hard to map to tabular data.
// Everything that we need to query on (HTTP status code,
// end_time, etc.) is in its own column. The
// request/repsonse and response will be serialized into
// msgpack bytes
"CREATE TABLE requests (
id UUID PRIMARY KEY,
recipe_id TEXT,
start_time TEXT,
end_time TEXT,
request BLOB,
response BLOB,
status_code INTEGER
let migrations = Migrations::new(vec![
M::up(
// The request state kind is a bit hard to map to tabular data.
// Everything that we need to query on (HTTP status code,
// end_time, etc.) is in its own column. Therequest/response
// will be serialized into msgpack bytes
"CREATE TABLE requests (
id UUID PRIMARY KEY NOT NULL,
recipe_id TEXT NOT NULL,
start_time TEXT NOT NULL,
end_time TEXT NOT NULL,
request BLOB NOT NULL,
response BLOB NOT NULL,
status_code INTEGER NOT NULL
)",
)
.down("DROP TABLE requests")]);
)
.down("DROP TABLE requests"),
M::up(
// Values will be serialized as msgpack
"CREATE TABLE ui_state (
key TEXT PRIMARY KEY NOT NULL,
value BLOB NOT NULL
)",
)
.down("DROP TABLE ui_state"),
]);
migrations.to_latest(connection)?;
Ok(())
}

/// Get a reference to the DB connection. Panics if the lock is poisoned
fn connection(&self) -> impl '_ + Deref<Target = Connection> {
self.connection.lock().expect("Connection lock poisoned")
}

/// Get the most recent request+response for a recipe, or `None` if there
/// has never been one received.
pub async fn get_last_request(
pub fn get_last_request(
&self,
recipe_id: &RequestRecipeId,
) -> anyhow::Result<Option<RequestRecord>> {
self.connection
.lock()
.await
self.connection()
.query_row(
"SELECT * FROM requests WHERE recipe_id = ?1
ORDER BY start_time DESC LIMIT 1",
Expand All @@ -107,18 +120,13 @@ impl Database {
/// response should be stored. In-flight requests, invalid requests, and
/// requests that failed to complete (e.g. because of a network error)
/// should not (and cannot) be stored.
pub async fn insert_request(
&self,
record: &RequestRecord,
) -> anyhow::Result<()> {
pub fn insert_request(&self, record: &RequestRecord) -> anyhow::Result<()> {
debug!(
id = %record.id(),
url = %record.request.url,
"Adding request record to database",
);
self.connection
.lock()
.await
self.connection()
.execute(
"INSERT INTO
requests (
Expand All @@ -136,15 +144,57 @@ impl Database {
&record.request.recipe_id,
&record.start_time,
&record.end_time,
&record.request,
&record.response,
&Bytes(&record.request),
&Bytes(&record.response),
record.response.status.as_u16(),
),
)
.context("Error saving request to database")
.traced()?;
Ok(())
}

/// Get the value of a UI state field
pub fn get_ui<K, V>(&self, key: K) -> anyhow::Result<Option<V>>
where
K: Display,
V: Debug + DeserializeOwned,
{
let value = self
.connection()
.query_row(
"SELECT value FROM ui_state WHERE key = ?1",
(key.to_string(),),
|row| {
let value: Bytes<V> = row.get(0)?;
Ok(value.0)
},
)
.optional()
.context("Error fetching UI state from database")
.traced()?;
debug!(%key, ?value, "Fetched UI state");
Ok(value)
}

/// Set the value of a UI state field
pub fn set_ui<K, V>(&self, key: K, value: V) -> anyhow::Result<()>
where
K: Display,
V: Debug + Serialize,
{
debug!(%key, ?value, "Setting UI state");
self.connection()
.execute(
// Upsert!
"INSERT INTO ui_state VALUES (?1, ?2)
ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(key.to_string(), Bytes(value)),
)
.context("Error saving UI state to database")
.traced()?;
Ok(())
}
}

/// Test-only helpers
Expand Down Expand Up @@ -184,30 +234,26 @@ impl FromSql for RequestRecipeId {
}
}

/// Macro to convert a serializable type to/from SQL via MessagePack
macro_rules! serial_sql {
($t:ty) => {
impl ToSql for $t {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let bytes = rmp_serde::to_vec(self).map_err(|err| {
rusqlite::Error::ToSqlConversionFailure(Box::new(err))
})?;
Ok(ToSqlOutput::Owned(bytes.into()))
}
}
/// A wrapper to serialize/deserialize a value as msgpack for DB storage
struct Bytes<T>(T);

impl FromSql for $t {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
rmp_serde::from_slice(bytes)
.map_err(|err| FromSqlError::Other(Box::new(err)))
}
}
};
impl<T: Serialize> ToSql for Bytes<T> {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let bytes = rmp_serde::to_vec(&self.0).map_err(|err| {
rusqlite::Error::ToSqlConversionFailure(Box::new(err))
})?;
Ok(ToSqlOutput::Owned(bytes.into()))
}
}

serial_sql!(Request);
serial_sql!(Response);
impl<T: DeserializeOwned> FromSql for Bytes<T> {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let value: T = rmp_serde::from_slice(bytes)
.map_err(|err| FromSqlError::Other(Box::new(err)))?;
Ok(Self(value))
}
}

/// Convert from `SELECT * FROM requests` to `RequestRecord`
impl<'a, 'b> TryFrom<&'a Row<'b>> for RequestRecord {
Expand All @@ -218,8 +264,9 @@ impl<'a, 'b> TryFrom<&'a Row<'b>> for RequestRecord {
id: row.get("id")?,
start_time: row.get("start_time")?,
end_time: row.get("end_time")?,
request: row.get("request")?,
response: row.get("response")?,
// Deserialize from bytes
request: row.get::<_, Bytes<_>>("request")?.0,
response: row.get::<_, Bytes<_>>("response")?.0,
})
}
}
2 changes: 1 addition & 1 deletion src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl HttpEngine {
};

// Error here should *not* kill the request
let _ = self.database.insert_request(&record).await;
let _ = self.database.insert_request(&record);
Ok(record)
}
Err(error) => Err(RequestError {
Expand Down
2 changes: 0 additions & 2 deletions src/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ mod tests {
.insert_request(
&create!(RequestRecord, request: request, response: response),
)
.await
.unwrap();
let selector = selector.map(|s| s.parse().unwrap());
let chains = indexmap! {"chain1".into() => create!(
Expand Down Expand Up @@ -348,7 +347,6 @@ mod tests {
database
.insert_request(&create!(
RequestRecord, request: request, response: response))
.await
.unwrap();
}
let chains = indexmap! {chain_id.into() => chain};
Expand Down
1 change: 0 additions & 1 deletion src/template/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ impl<'a> ChainTemplateSource<'a> {
let record = context
.database
.get_last_request(recipe_id)
.await
.map_err(ChainError::Database)?
.ok_or(ChainError::NoResponse)?;

Expand Down
Loading

0 comments on commit 0abb0b3

Please sign in to comment.