-
Notifications
You must be signed in to change notification settings - Fork 255
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2722 from karthik2804/mqtt_factors
Add outbound MQTT factor
- Loading branch information
Showing
5 changed files
with
435 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
[package] | ||
name = "spin-factor-outbound-mqtt" | ||
version = { workspace = true } | ||
authors = { workspace = true } | ||
edition = { workspace = true } | ||
|
||
[dependencies] | ||
anyhow = "1.0" | ||
rumqttc = { version = "0.24", features = ["url"] } | ||
spin-factor-outbound-networking = { path = "../factor-outbound-networking" } | ||
spin-factors = { path = "../factors" } | ||
spin-core = { path = "../core" } | ||
spin-world = { path = "../world" } | ||
table = { path = "../table" } | ||
tokio = { version = "1.0", features = ["sync"] } | ||
tracing = { workspace = true } | ||
|
||
[dev-dependencies] | ||
spin-factor-variables = { path = "../factor-variables" } | ||
spin-factors-test = { path = "../factors-test" } | ||
tokio = { version = "1", features = ["macros", "rt"] } | ||
|
||
[lints] | ||
workspace = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
use std::{sync::Arc, time::Duration}; | ||
|
||
use anyhow::Result; | ||
use spin_core::{async_trait, wasmtime::component::Resource}; | ||
use spin_factor_outbound_networking::OutboundAllowedHosts; | ||
use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos}; | ||
use tracing::{instrument, Level}; | ||
|
||
#[async_trait] | ||
pub trait ClientCreator: Send + Sync { | ||
fn create( | ||
&self, | ||
address: String, | ||
username: String, | ||
password: String, | ||
keep_alive_interval: Duration, | ||
) -> Result<Arc<dyn MqttClient>, Error>; | ||
} | ||
|
||
pub struct InstanceState { | ||
allowed_hosts: OutboundAllowedHosts, | ||
connections: table::Table<Arc<dyn MqttClient>>, | ||
create_client: Arc<dyn ClientCreator>, | ||
} | ||
|
||
impl InstanceState { | ||
pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self { | ||
Self { | ||
allowed_hosts, | ||
create_client, | ||
connections: table::Table::new(1024), | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
pub trait MqttClient: Send + Sync { | ||
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>; | ||
} | ||
|
||
impl InstanceState { | ||
async fn is_address_allowed(&self, address: &str) -> Result<bool> { | ||
self.allowed_hosts.check_url(address, "mqtt").await | ||
} | ||
|
||
async fn establish_connection( | ||
&mut self, | ||
address: String, | ||
username: String, | ||
password: String, | ||
keep_alive_interval: Duration, | ||
) -> Result<Resource<Connection>, Error> { | ||
self.connections | ||
.push((self.create_client).create(address, username, password, keep_alive_interval)?) | ||
.map(Resource::new_own) | ||
.map_err(|_| Error::TooManyConnections) | ||
} | ||
|
||
async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> { | ||
self.connections | ||
.get(connection.rep()) | ||
.ok_or(Error::Other( | ||
"could not find connection for resource".into(), | ||
)) | ||
.map(|c| c.as_ref()) | ||
} | ||
} | ||
|
||
impl v2::Host for InstanceState { | ||
fn convert_error(&mut self, error: Error) -> Result<Error> { | ||
Ok(error) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl v2::HostConnection for InstanceState { | ||
#[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))] | ||
async fn open( | ||
&mut self, | ||
address: String, | ||
username: String, | ||
password: String, | ||
keep_alive_interval: u64, | ||
) -> Result<Resource<Connection>, Error> { | ||
if !self | ||
.is_address_allowed(&address) | ||
.await | ||
.map_err(|e| v2::Error::Other(e.to_string()))? | ||
{ | ||
return Err(v2::Error::ConnectionFailed(format!( | ||
"address {address} is not permitted" | ||
))); | ||
} | ||
self.establish_connection( | ||
address, | ||
username, | ||
password, | ||
Duration::from_secs(keep_alive_interval), | ||
) | ||
.await | ||
} | ||
|
||
/// Publish a message to the MQTT broker. | ||
/// | ||
/// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the | ||
/// current trace context into the payload yourself. | ||
/// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation. | ||
#[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO), | ||
fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish", | ||
messaging.system = "mqtt"))] | ||
async fn publish( | ||
&mut self, | ||
connection: Resource<Connection>, | ||
topic: String, | ||
payload: Vec<u8>, | ||
qos: Qos, | ||
) -> Result<(), Error> { | ||
let conn = self.get_conn(connection).await.map_err(other_error)?; | ||
|
||
conn.publish_bytes(topic, qos, payload).await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> { | ||
self.connections.remove(connection.rep()); | ||
Ok(()) | ||
} | ||
} | ||
|
||
pub fn other_error(e: impl std::fmt::Display) -> Error { | ||
Error::Other(e.to_string()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
mod host; | ||
|
||
use std::sync::Arc; | ||
use std::time::Duration; | ||
|
||
use host::other_error; | ||
use host::InstanceState; | ||
use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; | ||
use spin_core::async_trait; | ||
use spin_factor_outbound_networking::OutboundNetworkingFactor; | ||
use spin_factors::{ | ||
ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, | ||
SelfInstanceBuilder, | ||
}; | ||
use spin_world::v2::mqtt::{self as v2, Error, Qos}; | ||
use tokio::sync::Mutex; | ||
|
||
pub use host::{ClientCreator, MqttClient}; | ||
|
||
pub struct OutboundMqttFactor { | ||
create_client: Arc<dyn ClientCreator>, | ||
} | ||
|
||
impl OutboundMqttFactor { | ||
pub fn new(create_client: Arc<dyn ClientCreator>) -> Self { | ||
Self { create_client } | ||
} | ||
} | ||
|
||
impl Factor for OutboundMqttFactor { | ||
type RuntimeConfig = (); | ||
type AppState = (); | ||
type InstanceBuilder = InstanceState; | ||
|
||
fn init<T: Send + 'static>( | ||
&mut self, | ||
mut ctx: spin_factors::InitContext<T, Self>, | ||
) -> anyhow::Result<()> { | ||
ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?; | ||
Ok(()) | ||
} | ||
|
||
fn configure_app<T: RuntimeFactors>( | ||
&self, | ||
_ctx: ConfigureAppContext<T, Self>, | ||
) -> anyhow::Result<Self::AppState> { | ||
Ok(()) | ||
} | ||
|
||
fn prepare<T: RuntimeFactors>( | ||
&self, | ||
_ctx: PrepareContext<Self>, | ||
builders: &mut InstanceBuilders<T>, | ||
) -> anyhow::Result<Self::InstanceBuilder> { | ||
let allowed_hosts = builders | ||
.get_mut::<OutboundNetworkingFactor>()? | ||
.allowed_hosts(); | ||
Ok(InstanceState::new( | ||
allowed_hosts, | ||
self.create_client.clone(), | ||
)) | ||
} | ||
} | ||
|
||
impl SelfInstanceBuilder for InstanceState {} | ||
|
||
// This is a concrete implementation of the MQTT client using rumqttc. | ||
pub struct NetworkedMqttClient { | ||
inner: rumqttc::AsyncClient, | ||
event_loop: Mutex<rumqttc::EventLoop>, | ||
} | ||
|
||
const MQTT_CHANNEL_CAP: usize = 1000; | ||
|
||
impl NetworkedMqttClient { | ||
pub fn create( | ||
address: String, | ||
username: String, | ||
password: String, | ||
keep_alive_interval: Duration, | ||
) -> Result<Self, Error> { | ||
let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| { | ||
tracing::error!("MQTT URL parse error: {e:?}"); | ||
Error::InvalidAddress | ||
})?; | ||
conn_opts.set_credentials(username, password); | ||
conn_opts.set_keep_alive(keep_alive_interval); | ||
let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP); | ||
Ok(Self { | ||
inner: client, | ||
event_loop: Mutex::new(event_loop), | ||
}) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl MqttClient for NetworkedMqttClient { | ||
async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error> { | ||
let qos = match qos { | ||
Qos::AtMostOnce => rumqttc::QoS::AtMostOnce, | ||
Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce, | ||
Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce, | ||
}; | ||
// Message published to EventLoop (not MQTT Broker) | ||
self.inner | ||
.publish_bytes(topic, qos, false, payload.into()) | ||
.await | ||
.map_err(other_error)?; | ||
|
||
// Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error. | ||
// We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool. | ||
let mut lock = self.event_loop.lock().await; | ||
loop { | ||
let event = lock | ||
.poll() | ||
.await | ||
.map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?; | ||
|
||
match (qos, event) { | ||
(QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_))) | ||
| (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_))) | ||
| (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break, | ||
|
||
(_, _) => continue, | ||
} | ||
} | ||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.