From ec47aedb1a0c7e9eecfa7a06032c50551c93d588 Mon Sep 17 00:00:00 2001 From: karthik2804 Date: Mon, 5 Aug 2024 08:44:27 +0200 Subject: [PATCH 1/2] Add outbound MQTT factor Co-authored-by: rylev Signed-off-by: karthik2804 --- Cargo.lock | 17 +++ crates/factor-outbound-mqtt/Cargo.toml | 24 ++++ crates/factor-outbound-mqtt/src/host.rs | 131 ++++++++++++++++++ crates/factor-outbound-mqtt/src/lib.rs | 128 +++++++++++++++++ .../factor-outbound-mqtt/tests/factor_test.rs | 119 ++++++++++++++++ 5 files changed, 419 insertions(+) create mode 100644 crates/factor-outbound-mqtt/Cargo.toml create mode 100644 crates/factor-outbound-mqtt/src/host.rs create mode 100644 crates/factor-outbound-mqtt/src/lib.rs create mode 100644 crates/factor-outbound-mqtt/tests/factor_test.rs diff --git a/Cargo.lock b/Cargo.lock index e66bbf0d08..32d7e70fe2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7715,6 +7715,23 @@ dependencies = [ "wasmtime-wasi-http", ] +[[package]] +name = "spin-factor-outbound-mqtt" +version = "2.7.0-pre0" +dependencies = [ + "anyhow", + "rumqttc", + "spin-core", + "spin-factor-outbound-networking", + "spin-factor-variables", + "spin-factors", + "spin-factors-test", + "spin-world", + "table", + "tokio", + "tracing", +] + [[package]] name = "spin-factor-outbound-mysql" version = "2.7.0-pre0" diff --git a/crates/factor-outbound-mqtt/Cargo.toml b/crates/factor-outbound-mqtt/Cargo.toml new file mode 100644 index 0000000000..76c44511e2 --- /dev/null +++ b/crates/factor-outbound-mqtt/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "spin-factor-outbound-mqtt" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } + +[dependencies] +anyhow = "1.0" +rumqttc = { version = "0.24", features = ["url"] } +spin-factor-outbound-networking = { path = "../factor-outbound-networking" } +spin-factors = { path = "../factors" } +spin-core = { path = "../core" } +spin-world = { path = "../world" } +tracing = { workspace = true } +table = { path = "../table" } +tokio = { version = "1.0", features = ["sync"] } + +[dev-dependencies] +spin-factor-variables = { path = "../factor-variables" } +spin-factors-test = { path = "../factors-test" } +tokio = { version = "1", features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/crates/factor-outbound-mqtt/src/host.rs b/crates/factor-outbound-mqtt/src/host.rs new file mode 100644 index 0000000000..a6d0a1b0cb --- /dev/null +++ b/crates/factor-outbound-mqtt/src/host.rs @@ -0,0 +1,131 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use spin_core::{async_trait, wasmtime::component::Resource}; +use spin_factor_outbound_networking::OutboundAllowedHosts; +use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos}; +use tracing::{instrument, Level}; + +pub type CreateClient = Arc< + dyn Fn(String, String, String, Duration) -> Result, Error> + Send + Sync, +>; + +pub struct InstanceState { + pub allowed_hosts: OutboundAllowedHosts, + pub connections: table::Table>, + pub create_client: CreateClient, +} + +impl InstanceState { + pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: CreateClient) -> Self { + Self { + allowed_hosts, + create_client, + connections: table::Table::new(1024), + } + } +} + +#[async_trait] +pub trait MqttClient: Send + Sync { + async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec) -> Result<(), Error>; +} + +impl InstanceState { + async fn is_address_allowed(&self, address: &str) -> Result { + self.allowed_hosts.check_url(address, "mqtt").await + } + + async fn establish_connection( + &mut self, + address: String, + username: String, + password: String, + keep_alive_interval: Duration, + ) -> Result, Error> { + self.connections + .push((self.create_client)( + address, + username, + password, + keep_alive_interval, + )?) + .map(Resource::new_own) + .map_err(|_| Error::TooManyConnections) + } + + async fn get_conn(&self, connection: Resource) -> Result<&dyn MqttClient, Error> { + self.connections + .get(connection.rep()) + .ok_or(Error::Other( + "could not find connection for resource".into(), + )) + .map(|c| c.as_ref()) + } +} + +impl v2::Host for InstanceState { + fn convert_error(&mut self, error: Error) -> Result { + Ok(error) + } +} + +#[async_trait] +impl v2::HostConnection for InstanceState { + #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))] + async fn open( + &mut self, + address: String, + username: String, + password: String, + keep_alive_interval: u64, + ) -> Result, Error> { + if !self + .is_address_allowed(&address) + .await + .map_err(|e| v2::Error::Other(e.to_string()))? + { + return Err(v2::Error::ConnectionFailed(format!( + "address {address} is not permitted" + ))); + } + self.establish_connection( + address, + username, + password, + Duration::from_secs(keep_alive_interval), + ) + .await + } + + /// Publish a message to the MQTT broker. + /// + /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the + /// current trace context into the payload yourself. + /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation. + #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO), + fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish", + messaging.system = "mqtt"))] + async fn publish( + &mut self, + connection: Resource, + topic: String, + payload: Vec, + qos: Qos, + ) -> Result<(), Error> { + let conn = self.get_conn(connection).await.map_err(other_error)?; + + conn.publish_bytes(topic, qos, payload).await?; + + Ok(()) + } + + fn drop(&mut self, connection: Resource) -> anyhow::Result<()> { + self.connections.remove(connection.rep()); + Ok(()) + } +} + +pub fn other_error(e: impl std::fmt::Display) -> Error { + Error::Other(e.to_string()) +} diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs new file mode 100644 index 0000000000..db63318db4 --- /dev/null +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -0,0 +1,128 @@ +mod host; + +use std::time::Duration; + +use host::other_error; +use host::CreateClient; +use host::InstanceState; +use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; +use spin_core::async_trait; +use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factors::{ + ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, + SelfInstanceBuilder, +}; +use spin_world::v2::mqtt::{self as v2, Error, Qos}; +use tokio::sync::Mutex; + +pub use host::MqttClient; + +pub struct OutboundMqttFactor { + create_client: CreateClient, +} + +impl OutboundMqttFactor { + pub fn new(create_client: CreateClient) -> Self { + Self { create_client } + } +} + +impl Factor for OutboundMqttFactor { + type RuntimeConfig = (); + type AppState = (); + type InstanceBuilder = InstanceState; + + fn init( + &mut self, + mut ctx: spin_factors::InitContext, + ) -> anyhow::Result<()> { + ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?; + Ok(()) + } + + fn configure_app( + &self, + _ctx: ConfigureAppContext, + ) -> anyhow::Result { + Ok(()) + } + + fn prepare( + &self, + _ctx: PrepareContext, + builders: &mut InstanceBuilders, + ) -> anyhow::Result { + let allowed_hosts = builders + .get_mut::()? + .allowed_hosts(); + Ok(InstanceState::new( + allowed_hosts, + self.create_client.clone(), + )) + } +} + +impl SelfInstanceBuilder for InstanceState {} + +pub struct NetworkedMqttClient { + inner: rumqttc::AsyncClient, + event_loop: Mutex, +} + +const MQTT_CHANNEL_CAP: usize = 1000; + +impl NetworkedMqttClient { + pub fn create( + address: String, + username: String, + password: String, + keep_alive_interval: Duration, + ) -> Result { + let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| { + tracing::error!("MQTT URL parse error: {e:?}"); + Error::InvalidAddress + })?; + conn_opts.set_credentials(username, password); + conn_opts.set_keep_alive(keep_alive_interval); + let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP); + Ok(Self { + inner: client, + event_loop: Mutex::new(event_loop), + }) + } +} + +#[async_trait] +impl MqttClient for NetworkedMqttClient { + async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec) -> Result<(), Error> { + let qos = match qos { + Qos::AtMostOnce => rumqttc::QoS::AtMostOnce, + Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce, + Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce, + }; + // Message published to EventLoop (not MQTT Broker) + self.inner + .publish_bytes(topic, qos, false, payload.into()) + .await + .map_err(other_error)?; + + // Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error. + // We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool. + let mut lock = self.event_loop.lock().await; + loop { + let event = lock + .poll() + .await + .map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?; + + match (qos, event) { + (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_))) + | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_))) + | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break, + + (_, _) => continue, + } + } + Ok(()) + } +} diff --git a/crates/factor-outbound-mqtt/tests/factor_test.rs b/crates/factor-outbound-mqtt/tests/factor_test.rs new file mode 100644 index 0000000000..178d17a0e5 --- /dev/null +++ b/crates/factor-outbound-mqtt/tests/factor_test.rs @@ -0,0 +1,119 @@ +use std::sync::Arc; + +use anyhow::{bail, Result}; +use spin_core::async_trait; +use spin_factor_outbound_mqtt::{MqttClient, OutboundMqttFactor}; +use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_variables::VariablesFactor; +use spin_factors::{anyhow, RuntimeFactors}; +use spin_factors_test::{toml, TestEnvironment}; +use spin_world::v2::mqtt::{self as v2, Error, HostConnection, Qos}; + +pub struct MockMqttClient {} + +#[async_trait] +impl MqttClient for MockMqttClient { + async fn publish_bytes( + &self, + _topic: String, + _qos: Qos, + _payload: Vec, + ) -> Result<(), Error> { + Ok(()) + } +} + +#[derive(RuntimeFactors)] +struct TestFactors { + variables: VariablesFactor, + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor, +} + +fn factors() -> TestFactors { + TestFactors { + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor::new(Arc::new(|_, _, _, _| Ok(Box::new(MockMqttClient {})))), + } +} + +fn test_env() -> TestEnvironment { + TestEnvironment::new(factors()).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*"] + }) +} + +#[tokio::test] +async fn disallowed_host_fails() -> anyhow::Result<()> { + let env = TestEnvironment::new(factors()).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + }); + let mut state = env.build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await; + let Err(err) = res else { + bail!("expected Err, got Ok"); + }; + assert!(matches!(err, v2::Error::ConnectionFailed(_))); + + Ok(()) +} + +#[tokio::test] +async fn allowed_host_succeeds() -> anyhow::Result<()> { + let mut state = test_env().build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await; + let Ok(_) = res else { + bail!("expected Ok, got Err"); + }; + + Ok(()) +} + +#[tokio::test] +async fn exercise_publish() -> anyhow::Result<()> { + let mut state = test_env().build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + state + .mqtt + .publish( + res, + "message".to_string(), + b"test message".to_vec(), + Qos::ExactlyOnce, + ) + .await?; + + Ok(()) +} From 398cd3f0198a4c6949d5165ded17eb21259938b9 Mon Sep 17 00:00:00 2001 From: karthik2804 Date: Mon, 19 Aug 2024 15:22:40 +0200 Subject: [PATCH 2/2] use custom trait object instead of closure trait Signed-off-by: karthik2804 --- crates/factor-outbound-mqtt/Cargo.toml | 2 +- crates/factor-outbound-mqtt/src/host.rs | 28 ++++++++++--------- crates/factor-outbound-mqtt/src/lib.rs | 9 +++--- .../factor-outbound-mqtt/tests/factor_test.rs | 17 +++++++++-- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/crates/factor-outbound-mqtt/Cargo.toml b/crates/factor-outbound-mqtt/Cargo.toml index 76c44511e2..95d7dce534 100644 --- a/crates/factor-outbound-mqtt/Cargo.toml +++ b/crates/factor-outbound-mqtt/Cargo.toml @@ -11,9 +11,9 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-core = { path = "../core" } spin-world = { path = "../world" } -tracing = { workspace = true } table = { path = "../table" } tokio = { version = "1.0", features = ["sync"] } +tracing = { workspace = true } [dev-dependencies] spin-factor-variables = { path = "../factor-variables" } diff --git a/crates/factor-outbound-mqtt/src/host.rs b/crates/factor-outbound-mqtt/src/host.rs index a6d0a1b0cb..3cd22abbd0 100644 --- a/crates/factor-outbound-mqtt/src/host.rs +++ b/crates/factor-outbound-mqtt/src/host.rs @@ -6,18 +6,25 @@ use spin_factor_outbound_networking::OutboundAllowedHosts; use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos}; use tracing::{instrument, Level}; -pub type CreateClient = Arc< - dyn Fn(String, String, String, Duration) -> Result, Error> + Send + Sync, ->; +#[async_trait] +pub trait ClientCreator: Send + Sync { + fn create( + &self, + address: String, + username: String, + password: String, + keep_alive_interval: Duration, + ) -> Result, Error>; +} pub struct InstanceState { - pub allowed_hosts: OutboundAllowedHosts, - pub connections: table::Table>, - pub create_client: CreateClient, + allowed_hosts: OutboundAllowedHosts, + connections: table::Table>, + create_client: Arc, } impl InstanceState { - pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: CreateClient) -> Self { + pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc) -> Self { Self { allowed_hosts, create_client, @@ -44,12 +51,7 @@ impl InstanceState { keep_alive_interval: Duration, ) -> Result, Error> { self.connections - .push((self.create_client)( - address, - username, - password, - keep_alive_interval, - )?) + .push((self.create_client).create(address, username, password, keep_alive_interval)?) .map(Resource::new_own) .map_err(|_| Error::TooManyConnections) } diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs index db63318db4..4816e12bec 100644 --- a/crates/factor-outbound-mqtt/src/lib.rs +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -1,9 +1,9 @@ mod host; +use std::sync::Arc; use std::time::Duration; use host::other_error; -use host::CreateClient; use host::InstanceState; use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; use spin_core::async_trait; @@ -15,14 +15,14 @@ use spin_factors::{ use spin_world::v2::mqtt::{self as v2, Error, Qos}; use tokio::sync::Mutex; -pub use host::MqttClient; +pub use host::{ClientCreator, MqttClient}; pub struct OutboundMqttFactor { - create_client: CreateClient, + create_client: Arc, } impl OutboundMqttFactor { - pub fn new(create_client: CreateClient) -> Self { + pub fn new(create_client: Arc) -> Self { Self { create_client } } } @@ -64,6 +64,7 @@ impl Factor for OutboundMqttFactor { impl SelfInstanceBuilder for InstanceState {} +// This is a concrete implementation of the MQTT client using rumqttc. pub struct NetworkedMqttClient { inner: rumqttc::AsyncClient, event_loop: Mutex, diff --git a/crates/factor-outbound-mqtt/tests/factor_test.rs b/crates/factor-outbound-mqtt/tests/factor_test.rs index 178d17a0e5..1d88b3e574 100644 --- a/crates/factor-outbound-mqtt/tests/factor_test.rs +++ b/crates/factor-outbound-mqtt/tests/factor_test.rs @@ -1,8 +1,9 @@ use std::sync::Arc; +use std::time::Duration; use anyhow::{bail, Result}; use spin_core::async_trait; -use spin_factor_outbound_mqtt::{MqttClient, OutboundMqttFactor}; +use spin_factor_outbound_mqtt::{ClientCreator, MqttClient, OutboundMqttFactor}; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; @@ -23,6 +24,18 @@ impl MqttClient for MockMqttClient { } } +impl ClientCreator for MockMqttClient { + fn create( + &self, + _address: String, + _username: String, + _password: String, + _keep_alive_interval: Duration, + ) -> Result, Error> { + Ok(Arc::new(MockMqttClient {})) + } +} + #[derive(RuntimeFactors)] struct TestFactors { variables: VariablesFactor, @@ -34,7 +47,7 @@ fn factors() -> TestFactors { TestFactors { variables: VariablesFactor::default(), networking: OutboundNetworkingFactor, - mqtt: OutboundMqttFactor::new(Arc::new(|_, _, _, _| Ok(Box::new(MockMqttClient {})))), + mqtt: OutboundMqttFactor::new(Arc::new(MockMqttClient {})), } }