Skip to content

Commit

Permalink
Merge pull request #2709 from fermyon/new-connection-reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
rylev authored Aug 12, 2024
2 parents 250263b + 1ce0b7c commit 23e6a23
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 79 deletions.
26 changes: 15 additions & 11 deletions crates/factor-sqlite/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use spin_factors::{anyhow, SelfInstanceBuilder};
use spin_world::v1::sqlite as v1;
use spin_world::v2::sqlite as v2;

use crate::{Connection, ConnectionPool};
use crate::{Connection, ConnectionCreator};

pub struct InstanceState {
allowed_databases: Arc<HashSet<String>>,
connections: table::Table<Arc<dyn Connection>>,
get_pool: ConnectionPoolGetter,
connections: table::Table<Box<dyn Connection>>,
get_connection_creator: ConnectionCreatorGetter,
}

impl InstanceState {
Expand All @@ -22,25 +22,29 @@ impl InstanceState {
}
}

/// A function that takes a database label and returns a connection pool, if one exists.
pub type ConnectionPoolGetter = Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionPool>> + Send + Sync>;
/// A function that takes a database label and returns a connection creator, if one exists.
pub type ConnectionCreatorGetter =
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;

impl InstanceState {
/// Create a new `InstanceState`
///
/// Takes the list of allowed databases, and a function for getting a connection pool given a database label.
pub fn new(allowed_databases: Arc<HashSet<String>>, get_pool: ConnectionPoolGetter) -> Self {
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
pub fn new(
allowed_databases: Arc<HashSet<String>>,
get_connection_creator: ConnectionCreatorGetter,
) -> Self {
Self {
allowed_databases,
connections: table::Table::new(256),
get_pool,
get_connection_creator,
}
}

fn get_connection(
&self,
connection: Resource<v2::Connection>,
) -> Result<&Arc<dyn Connection>, v2::Error> {
) -> Result<&Box<dyn Connection>, v2::Error> {
self.connections
.get(connection.rep())
.ok_or(v2::Error::InvalidConnection)
Expand All @@ -61,9 +65,9 @@ impl v2::HostConnection for InstanceState {
if !self.allowed_databases.contains(&database) {
return Err(v2::Error::AccessDenied);
}
(self.get_pool)(&database)
(self.get_connection_creator)(&database)
.ok_or(v2::Error::NoSuchDatabase)?
.get_connection()
.create_connection()
.await
.and_then(|conn| {
self.connections
Expand Down
62 changes: 28 additions & 34 deletions crates/factor-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ impl Factor for SqliteFactor {
&self,
mut ctx: spin_factors::ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
let connection_pools = ctx
let connection_creators = ctx
.take_runtime_config()
.map(|r| r.pools)
.map(|r| r.connection_creators)
.unwrap_or_default();

let allowed_databases = ctx
Expand All @@ -68,20 +68,20 @@ impl Factor for SqliteFactor {
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;
let resolver = self.default_label_resolver.clone();
let get_connection_pool: host::ConnectionPoolGetter = Arc::new(move |label| {
connection_pools
let get_connection_creator: host::ConnectionCreatorGetter = Arc::new(move |label| {
connection_creators
.get(label)
.cloned()
.or_else(|| resolver.default(label))
});

ensure_allowed_databases_are_configured(&allowed_databases, |label| {
get_connection_pool(label).is_some()
get_connection_creator(label).is_some()
})?;

Ok(AppState {
allowed_databases,
get_connection_pool,
get_connection_creator,
})
}

Expand All @@ -96,8 +96,11 @@ impl Factor for SqliteFactor {
.get(ctx.app_component().id())
.cloned()
.unwrap_or_default();
let get_connection_pool = ctx.app_state().get_connection_pool.clone();
Ok(InstanceState::new(allowed_databases, get_connection_pool))
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
Ok(InstanceState::new(
allowed_databases,
get_connection_creator,
))
}
}

Expand Down Expand Up @@ -136,46 +139,37 @@ fn ensure_allowed_databases_are_configured(

pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");

/// Resolves a label to a default connection pool.
/// Resolves a label to a default connection creator.
pub trait DefaultLabelResolver: Send + Sync {
/// If there is no runtime configuration for a given database label, return a default connection pool.
/// If there is no runtime configuration for a given database label, return a default connection creator.
///
/// If `Option::None` is returned, the database is not allowed.
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionPool>>;
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>>;
}

pub struct AppState {
/// A map from component id to a set of allowed database labels.
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
/// A function for mapping from database name to a connection pool
get_connection_pool: host::ConnectionPoolGetter,
/// A function for mapping from database name to a connection creator.
get_connection_creator: host::ConnectionCreatorGetter,
}

/// A pool of connections for a particular SQLite database
/// A creator of a connections for a particular SQLite database.
#[async_trait]
pub trait ConnectionPool: Send + Sync {
/// Get a `Connection` from the pool
async fn get_connection(&self) -> Result<Arc<dyn Connection + 'static>, v2::Error>;
}

/// A simple [`ConnectionPool`] that always creates a new connection.
pub struct SimpleConnectionPool(
Box<dyn Fn() -> anyhow::Result<Arc<dyn Connection + 'static>> + Send + Sync>,
);

impl SimpleConnectionPool {
/// Create a new `SimpleConnectionPool` with the given connection factory.
pub fn new(
factory: impl Fn() -> anyhow::Result<Arc<dyn Connection + 'static>> + Send + Sync + 'static,
) -> Self {
Self(Box::new(factory))
}
pub trait ConnectionCreator: Send + Sync {
/// Get a *new* [`Connection`]
///
/// The connection should be a new connection, not a reused one.
async fn create_connection(&self) -> Result<Box<dyn Connection + 'static>, v2::Error>;
}

#[async_trait::async_trait]
impl ConnectionPool for SimpleConnectionPool {
async fn get_connection(&self) -> Result<Arc<dyn Connection + 'static>, v2::Error> {
(self.0)().map_err(|_| v2::Error::InvalidConnection)
impl<F> ConnectionCreator for F
where
F: Fn() -> anyhow::Result<Box<dyn Connection + 'static>> + Send + Sync + 'static,
{
async fn create_connection(&self) -> Result<Box<dyn Connection + 'static>, v2::Error> {
(self)().map_err(|_| v2::Error::InvalidConnection)
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/factor-sqlite/src/runtime_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ pub mod spin;

use std::{collections::HashMap, sync::Arc};

use crate::ConnectionPool;
use crate::ConnectionCreator;

/// A runtime configuration for SQLite databases.
///
/// Maps database labels to connection pools.
/// Maps database labels to connection creators.
pub struct RuntimeConfig {
pub pools: HashMap<String, Arc<dyn ConnectionPool>>,
pub connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}
51 changes: 28 additions & 23 deletions crates/factor-sqlite/src/runtime_config/spin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use spin_factors::{
use spin_world::v2::sqlite as v2;
use tokio::sync::OnceCell;

use crate::{Connection, ConnectionPool, DefaultLabelResolver, SimpleConnectionPool};
use crate::{Connection, ConnectionCreator, DefaultLabelResolver};

/// Spin's default handling of the runtime configuration for SQLite databases.
///
Expand Down Expand Up @@ -66,28 +66,34 @@ impl SpinSqliteRuntimeConfig {
return Ok(None);
};
let config: std::collections::HashMap<String, RuntimeConfig> = table.clone().try_into()?;
let pools = config
let connection_creators = config
.into_iter()
.map(|(k, v)| Ok((k, self.get_pool(v)?)))
.map(|(k, v)| Ok((k, self.get_connection_creator(v)?)))
.collect::<anyhow::Result<_>>()?;
Ok(Some(super::RuntimeConfig { pools }))
Ok(Some(super::RuntimeConfig {
connection_creators,
}))
}

/// Get a connection pool for a given runtime configuration.
pub fn get_pool(&self, config: RuntimeConfig) -> anyhow::Result<Arc<dyn ConnectionPool>> {
/// Get a connection creator for a given runtime configuration.
pub fn get_connection_creator(
&self,
config: RuntimeConfig,
) -> anyhow::Result<Arc<dyn ConnectionCreator>> {
let database_kind = config.type_.as_str();
let pool = match database_kind {
match database_kind {
"spin" => {
let config: LocalDatabase = config.config.try_into()?;
config.pool(&self.local_database_dir)?
Ok(Arc::new(
config.connection_creator(&self.local_database_dir)?,
))
}
"libsql" => {
let config: LibSqlDatabase = config.config.try_into()?;
config.pool()?
Ok(Arc::new(config.connection_creator()?))
}
_ => anyhow::bail!("Unknown database kind: {database_kind}"),
};
Ok(Arc::new(pool))
}
}
}

Expand All @@ -100,7 +106,7 @@ pub struct RuntimeConfig {
}

impl DefaultLabelResolver for SpinSqliteRuntimeConfig {
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionPool>> {
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>> {
// Only default the database labeled "default".
if label != "default" {
return None;
Expand All @@ -110,10 +116,9 @@ impl DefaultLabelResolver for SpinSqliteRuntimeConfig {
let factory = move || {
let location = spin_sqlite_inproc::InProcDatabaseLocation::Path(path.clone());
let connection = spin_sqlite_inproc::InProcConnection::new(location)?;
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
let pool = SimpleConnectionPool::new(factory);
Some(Arc::new(pool))
Some(Arc::new(factory))
}
}

Expand Down Expand Up @@ -196,10 +201,10 @@ pub struct LocalDatabase {
}

impl LocalDatabase {
/// Create a new connection pool for a local database.
/// Get a new connection creator for a local database.
///
/// `base_dir` is the base directory path from which `path` is resolved if it is a relative path.
fn pool(self, base_dir: &Path) -> anyhow::Result<SimpleConnectionPool> {
fn connection_creator(self, base_dir: &Path) -> anyhow::Result<impl ConnectionCreator> {
let location = match self.path {
Some(path) => {
let path = resolve_relative_path(&path, base_dir);
Expand All @@ -213,9 +218,9 @@ impl LocalDatabase {
};
let factory = move || {
let connection = spin_sqlite_inproc::InProcConnection::new(location.clone())?;
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
Ok(SimpleConnectionPool::new(factory))
Ok(factory)
}
}

Expand All @@ -238,8 +243,8 @@ pub struct LibSqlDatabase {
}

impl LibSqlDatabase {
/// Create a new connection pool for a libSQL database.
fn pool(self) -> anyhow::Result<SimpleConnectionPool> {
/// Get a new connection creator for a libSQL database.
fn connection_creator(self) -> anyhow::Result<impl ConnectionCreator> {
let url = check_url(&self.url)
.with_context(|| {
format!(
Expand All @@ -250,9 +255,9 @@ impl LibSqlDatabase {
.to_owned();
let factory = move || {
let connection = LibSqlConnection::new(url.clone(), self.token.clone());
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
Ok(SimpleConnectionPool::new(factory))
Ok(factory)
}
}

Expand Down
16 changes: 8 additions & 8 deletions crates/factor-sqlite/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl TryFrom<TomlRuntimeSource<'_>> for TestFactorsRuntimeConfig {
}
}

/// Will return an `InvalidConnectionPool` for the supplied default database.
/// Will return an `InvalidConnectionCreator` for the supplied default database.
struct DefaultLabelResolver {
default: Option<String>,
}
Expand All @@ -130,22 +130,22 @@ impl DefaultLabelResolver {
}

impl factor_sqlite::DefaultLabelResolver for DefaultLabelResolver {
fn default(&self, label: &str) -> Option<Arc<dyn factor_sqlite::ConnectionPool>> {
fn default(&self, label: &str) -> Option<Arc<dyn factor_sqlite::ConnectionCreator>> {
let Some(default) = &self.default else {
return None;
};
(default == label).then_some(Arc::new(InvalidConnectionPool))
(default == label).then_some(Arc::new(InvalidConnectionCreator))
}
}

/// A connection pool that always returns an error.
struct InvalidConnectionPool;
/// A connection creator that always returns an error.
struct InvalidConnectionCreator;

#[async_trait::async_trait]
impl factor_sqlite::ConnectionPool for InvalidConnectionPool {
async fn get_connection(
impl factor_sqlite::ConnectionCreator for InvalidConnectionCreator {
async fn create_connection(
&self,
) -> Result<Arc<dyn factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error> {
) -> Result<Box<dyn factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error> {
Err(spin_world::v2::sqlite::Error::InvalidConnection)
}
}

0 comments on commit 23e6a23

Please sign in to comment.