Skip to content

Commit

Permalink
Merge pull request #2722 from karthik2804/mqtt_factors
Browse files Browse the repository at this point in the history
Add outbound MQTT factor
  • Loading branch information
lann authored Aug 19, 2024
2 parents 69c3cd8 + 398cd3f commit 871918b
Show file tree
Hide file tree
Showing 5 changed files with 435 additions and 0 deletions.
17 changes: 17 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions crates/factor-outbound-mqtt/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "spin-factor-outbound-mqtt"
version = { workspace = true }
authors = { workspace = true }
edition = { workspace = true }

[dependencies]
anyhow = "1.0"
rumqttc = { version = "0.24", features = ["url"] }
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-core = { path = "../core" }
spin-world = { path = "../world" }
table = { path = "../table" }
tokio = { version = "1.0", features = ["sync"] }
tracing = { workspace = true }

[dev-dependencies]
spin-factor-variables = { path = "../factor-variables" }
spin-factors-test = { path = "../factors-test" }
tokio = { version = "1", features = ["macros", "rt"] }

[lints]
workspace = true
133 changes: 133 additions & 0 deletions crates/factor-outbound-mqtt/src/host.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::{sync::Arc, time::Duration};

use anyhow::Result;
use spin_core::{async_trait, wasmtime::component::Resource};
use spin_factor_outbound_networking::OutboundAllowedHosts;
use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
use tracing::{instrument, Level};

#[async_trait]
pub trait ClientCreator: Send + Sync {
fn create(
&self,
address: String,
username: String,
password: String,
keep_alive_interval: Duration,
) -> Result<Arc<dyn MqttClient>, Error>;
}

pub struct InstanceState {
allowed_hosts: OutboundAllowedHosts,
connections: table::Table<Arc<dyn MqttClient>>,
create_client: Arc<dyn ClientCreator>,
}

impl InstanceState {
pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
Self {
allowed_hosts,
create_client,
connections: table::Table::new(1024),
}
}
}

#[async_trait]
pub trait MqttClient: Send + Sync {
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
}

impl InstanceState {
async fn is_address_allowed(&self, address: &str) -> Result<bool> {
self.allowed_hosts.check_url(address, "mqtt").await
}

async fn establish_connection(
&mut self,
address: String,
username: String,
password: String,
keep_alive_interval: Duration,
) -> Result<Resource<Connection>, Error> {
self.connections
.push((self.create_client).create(address, username, password, keep_alive_interval)?)
.map(Resource::new_own)
.map_err(|_| Error::TooManyConnections)
}

async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
self.connections
.get(connection.rep())
.ok_or(Error::Other(
"could not find connection for resource".into(),
))
.map(|c| c.as_ref())
}
}

impl v2::Host for InstanceState {
fn convert_error(&mut self, error: Error) -> Result<Error> {
Ok(error)
}
}

#[async_trait]
impl v2::HostConnection for InstanceState {
#[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
async fn open(
&mut self,
address: String,
username: String,
password: String,
keep_alive_interval: u64,
) -> Result<Resource<Connection>, Error> {
if !self
.is_address_allowed(&address)
.await
.map_err(|e| v2::Error::Other(e.to_string()))?
{
return Err(v2::Error::ConnectionFailed(format!(
"address {address} is not permitted"
)));
}
self.establish_connection(
address,
username,
password,
Duration::from_secs(keep_alive_interval),
)
.await
}

/// Publish a message to the MQTT broker.
///
/// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
/// current trace context into the payload yourself.
/// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
#[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
messaging.system = "mqtt"))]
async fn publish(
&mut self,
connection: Resource<Connection>,
topic: String,
payload: Vec<u8>,
qos: Qos,
) -> Result<(), Error> {
let conn = self.get_conn(connection).await.map_err(other_error)?;

conn.publish_bytes(topic, qos, payload).await?;

Ok(())
}

fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

pub fn other_error(e: impl std::fmt::Display) -> Error {
Error::Other(e.to_string())
}
129 changes: 129 additions & 0 deletions crates/factor-outbound-mqtt/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
mod host;

use std::sync::Arc;
use std::time::Duration;

use host::other_error;
use host::InstanceState;
use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
use spin_core::async_trait;
use spin_factor_outbound_networking::OutboundNetworkingFactor;
use spin_factors::{
ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors,
SelfInstanceBuilder,
};
use spin_world::v2::mqtt::{self as v2, Error, Qos};
use tokio::sync::Mutex;

pub use host::{ClientCreator, MqttClient};

pub struct OutboundMqttFactor {
create_client: Arc<dyn ClientCreator>,
}

impl OutboundMqttFactor {
pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
Self { create_client }
}
}

impl Factor for OutboundMqttFactor {
type RuntimeConfig = ();
type AppState = ();
type InstanceBuilder = InstanceState;

fn init<T: Send + 'static>(
&mut self,
mut ctx: spin_factors::InitContext<T, Self>,
) -> anyhow::Result<()> {
ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?;
Ok(())
}

fn configure_app<T: RuntimeFactors>(
&self,
_ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(())
}

fn prepare<T: RuntimeFactors>(
&self,
_ctx: PrepareContext<Self>,
builders: &mut InstanceBuilders<T>,
) -> anyhow::Result<Self::InstanceBuilder> {
let allowed_hosts = builders
.get_mut::<OutboundNetworkingFactor>()?
.allowed_hosts();
Ok(InstanceState::new(
allowed_hosts,
self.create_client.clone(),
))
}
}

impl SelfInstanceBuilder for InstanceState {}

// This is a concrete implementation of the MQTT client using rumqttc.
pub struct NetworkedMqttClient {
inner: rumqttc::AsyncClient,
event_loop: Mutex<rumqttc::EventLoop>,
}

const MQTT_CHANNEL_CAP: usize = 1000;

impl NetworkedMqttClient {
pub fn create(
address: String,
username: String,
password: String,
keep_alive_interval: Duration,
) -> Result<Self, Error> {
let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
tracing::error!("MQTT URL parse error: {e:?}");
Error::InvalidAddress
})?;
conn_opts.set_credentials(username, password);
conn_opts.set_keep_alive(keep_alive_interval);
let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
Ok(Self {
inner: client,
event_loop: Mutex::new(event_loop),
})
}
}

#[async_trait]
impl MqttClient for NetworkedMqttClient {
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error> {
let qos = match qos {
Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
};
// Message published to EventLoop (not MQTT Broker)
self.inner
.publish_bytes(topic, qos, false, payload.into())
.await
.map_err(other_error)?;

// Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
// We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
let mut lock = self.event_loop.lock().await;
loop {
let event = lock
.poll()
.await
.map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;

match (qos, event) {
(QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
| (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
| (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,

(_, _) => continue,
}
}
Ok(())
}
}
Loading

0 comments on commit 871918b

Please sign in to comment.