From d67ceb72731309f1c727b39248eca9704a724347 Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:20:36 +0100 Subject: [PATCH 01/23] Refactor to desired structure (#93) * Refactor to desired structure * fix metric test * Update branch.yaml * cargo fmt * fix feature flags * simplify feature flags for now * add env_dependency --- .github/workflows/branch.yaml | 8 +- dsh_sdk/CHANGELOG.md | 26 + dsh_sdk/Cargo.toml | 28 +- dsh_sdk/README.md | 2 +- dsh_sdk/examples/mqtt_token_fetcher.rs | 3 +- .../mqtt_token_fetcher_specific_claims.rs | 3 +- dsh_sdk/examples/rest_api_token_fetcher.rs | 2 +- .../src/{dsh => certificates}/bootstrap.rs | 147 ++- dsh_sdk/src/certificates/mod.rs | 392 ++++++++ .../{dsh => certificates}/pki_config_dir.rs | 93 +- dsh_sdk/src/{dsh => }/datastream.rs | 6 +- dsh_sdk/src/dlq.rs | 542 +---------- dsh_sdk/src/dsh.rs | 701 ++++++++++++++ dsh_sdk/src/{dsh => dsh_old}/certificates.rs | 130 +-- dsh_sdk/src/dsh_old/datastream.rs | 649 +++++++++++++ dsh_sdk/src/{dsh => dsh_old}/mod.rs | 31 +- dsh_sdk/src/{dsh => dsh_old}/properties.rs | 354 ++++--- dsh_sdk/src/error.rs | 20 +- dsh_sdk/src/graceful_shutdown.rs | 2 +- dsh_sdk/src/lib.rs | 76 +- dsh_sdk/src/management_api/mod.rs | 1 + dsh_sdk/src/management_api/token_fetcher.rs | 597 ++++++++++++ dsh_sdk/src/metrics.rs | 17 +- dsh_sdk/src/mqtt_token_fetcher.rs | 856 +---------------- .../protocol_adapters/http_protocol/mod.rs | 3 + .../kafka_protocol}/config.rs | 141 ++- .../protocol_adapters/kafka_protocol/mod.rs | 12 + dsh_sdk/src/protocol_adapters/mod.rs | 11 + .../protocol_adapters/mqtt_protocol/mod.rs | 3 + .../protocol_adapters/token_fetcher/mod.rs | 860 ++++++++++++++++++ dsh_sdk/src/rest_api_token_fetcher.rs | 576 +----------- dsh_sdk/src/utils/dlq.rs | 562 ++++++++++++ dsh_sdk/src/utils/graceful_shutdown.rs | 203 +++++ dsh_sdk/src/utils/metrics.rs | 334 +++++++ dsh_sdk/src/{utils.rs => utils/mod.rs} | 59 +- example_dsh_service/src/main.rs | 2 +- 36 files changed, 5027 insertions(+), 2425 deletions(-) rename dsh_sdk/src/{dsh => certificates}/bootstrap.rs (67%) create mode 100644 dsh_sdk/src/certificates/mod.rs rename dsh_sdk/src/{dsh => certificates}/pki_config_dir.rs (73%) rename dsh_sdk/src/{dsh => }/datastream.rs (99%) create mode 100644 dsh_sdk/src/dsh.rs rename dsh_sdk/src/{dsh => dsh_old}/certificates.rs (73%) create mode 100644 dsh_sdk/src/dsh_old/datastream.rs rename dsh_sdk/src/{dsh => dsh_old}/mod.rs (56%) rename dsh_sdk/src/{dsh => dsh_old}/properties.rs (96%) create mode 100644 dsh_sdk/src/management_api/mod.rs create mode 100644 dsh_sdk/src/management_api/token_fetcher.rs create mode 100644 dsh_sdk/src/protocol_adapters/http_protocol/mod.rs rename dsh_sdk/src/{dsh => protocol_adapters/kafka_protocol}/config.rs (76%) create mode 100644 dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/mqtt_protocol/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs create mode 100644 dsh_sdk/src/utils/dlq.rs create mode 100644 dsh_sdk/src/utils/graceful_shutdown.rs create mode 100644 dsh_sdk/src/utils/metrics.rs rename dsh_sdk/src/{utils.rs => utils/mod.rs} (82%) diff --git a/.github/workflows/branch.yaml b/.github/workflows/branch.yaml index eef40da..9ca519e 100644 --- a/.github/workflows/branch.yaml +++ b/.github/workflows/branch.yaml @@ -2,9 +2,11 @@ name: Branch on: push: - branches: [ "*", "!main" ] + branches-ignore: + - main 
pull_request: - branches: [ "*", "!main" ] + branches-ignore: + - main env: CARGO_TERM_COLOR: always @@ -56,4 +58,4 @@ jobs: if: matrix.version == 'stable' - name: cargo check all features run: cargo hack check --feature-powerset --no-dev-deps - if: matrix.version == 'stable' \ No newline at end of file + if: matrix.version == 'stable' diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index e8cf0cf..c0eac99 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - unreleased +### Added +- New public functions `dsh_sdk::certificates::Cert` + - Bootstrap to DSH + - Read certificates from PKI_CONFIG_DIR + - Add support reading private key in DER format when reading from PKI_CONFIG_DIR + +### Changed +- Moved `dsh_sdk::dsh::properties` to `dsh_sdk::propeties` +- Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher` + - **NOTE** Cargo.toml feature flag falls now under `management_api` (`rest-token-fetcher` will be removed in v0.6.0) +- Moved `dsh_sdk::dsh::datastreams` to `dsh_sdk::datastreams` +- Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates` + - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module +- Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` where it is renamed to `ProtocolTokenFetcher` + - **NOTE** Cargo.toml feature flag falls now under `mqtt-protocol` (`mqtt_token_fetcher` will be removed in v0.6.0) +- Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq` +- Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown` +- Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics` + +### Removed +- Removed `Default` trait for `Dsh` (original `Properties`) struct as this should be public + +### Fixed + + ## [0.4.10] -2024-09-30 ### Added - Add new with client methods to REST and MQTTT token fetcher diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 2acd177..0dfa5b3 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -17,13 +17,12 @@ all-features = true [dependencies] base64 = {version = "0.22", optional = true } bytes = { version = "1.6", optional = true } -dashmap = {version = "6.0", optional = true} http-body-util = { version = "0.1", optional = true } hyper = { version = "1.3", features = ["server", "http1"], optional = true } hyper-util = { version = "0.1", features = ["tokio"], optional = true } lazy_static = { version = "1.5", optional = true } log = "0.4" -pem = "3" +pem = {version = "3", optional = true } prometheus = { version = "0.13", features = ["process"], optional = true } rcgen = { version = "0.13", optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "blocking"], optional = true } @@ -36,22 +35,29 @@ tokio = { version = "^1.35", features = ["signal", "sync", "time", "macros"], op tokio-util = { version = "0.7", default-features = false, optional = true } [features] -default = ["bootstrap", "graceful_shutdown", "metrics", "rdkafka-ssl"] +default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl"] -bootstrap = ["rcgen", "serde_json", "reqwest", "tokio/rt-multi-thread"] +bootstrap = ["certificate", "serde_json", 
"tokio/rt-multi-thread"] +certificate = ["rcgen", "reqwest", "pem"] metrics = ["prometheus", "hyper", "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] -dlq = ["tokio", "bootstrap", "rdkafka-ssl", "graceful_shutdown"] -graceful_shutdown = ["tokio", "tokio-util"] -rdkafka-ssl-vendored = ["rdkafka", "rdkafka/ssl-vendored"] -rdkafka-ssl = ["rdkafka", "rdkafka/ssl"] -rest-token-fetcher = ["reqwest"] -mqtt-token-fetcher = ["base64","dashmap","reqwest","serde_json","sha2","tokio/sync"] +dlq = ["tokio", "bootstrap", "rdkafka-ssl", "graceful-shutdown"] +graceful-shutdown = ["tokio", "tokio-util"] +management-api = ["reqwest"] +protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] +# http-protocol-adapter = ["protocol-token-fetcher"] +# mqtt-protocol-adapter = ["protocol-token-fetcher"] + +# TODO: Remove the following features at v0.6.0 +rdkafka-ssl-vendored = ["rdkafka", "rdkafka/ssl-vendored", "rdkafka/cmake-build"] +rdkafka-ssl = ["rdkafka", "rdkafka/ssl", "rdkafka/cmake-build"] +#mqtt-token-fetcher = ["protocol-token-fetcher"] +#rest-token-fetcher = ["management-api"] [dev-dependencies] mockito = "1.1.1" openssl = "0.10" tokio = { version = "^1.35", features = ["full"] } -hyper = { version = "1.2.0", features = ["full"]} +hyper = { version = "1.3", features = ["full"]} serial_test = "3.1.0" dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.2.0"} dsh_sdk = { features = ["dlq"], path = "." } \ No newline at end of file diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index 11500ed..4d7639a 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -64,7 +64,7 @@ The following features are available in this library and can be enabled/disabled | `rest-token-fetcher` | ✗ | Fetch tokens to use DSH Rest API | | `mqtt-token-fetcher` | ✗ | Fetch tokens to use DSH MQTT | | `metrics` | ✓ | Enable (custom) metrics for your service | -| `graceful_shutdown` | ✓ | Create a signal handler for implementing a graceful shutdown | +| `graceful-shutdown` | ✓ | Create a signal handler for implementing a graceful shutdown | | `dlq` | ✗ | Dead Letter Queue implementation (experimental) | | `rdkafka-ssl` | ✓ | Dynamically link to librdkafka to a locally installed OpenSSL | | `rdkafka-ssl-vendored` | ✗ | Build OpenSSL during compile and statically link librdkafka
(No initial install required in environment, slower compile time) | diff --git a/dsh_sdk/examples/mqtt_token_fetcher.rs b/dsh_sdk/examples/mqtt_token_fetcher.rs index 827fb3b..9b48c95 100644 --- a/dsh_sdk/examples/mqtt_token_fetcher.rs +++ b/dsh_sdk/examples/mqtt_token_fetcher.rs @@ -6,8 +6,7 @@ use dsh_sdk::mqtt_token_fetcher::{MqttToken, MqttTokenFetcher}; async fn main() { let tenant_name = env::var("TENANT").unwrap().to_string(); let api_key = env::var("API_KEY").unwrap().to_string(); - let mqtt_token_fetcher: MqttTokenFetcher = - MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); + let mqtt_token_fetcher = MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); let token: MqttToken = mqtt_token_fetcher .get_token("Client-id", None) //Claims = None fetches all possible claims .await diff --git a/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs b/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs index 1f820a3..3439128 100644 --- a/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs +++ b/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs @@ -16,8 +16,7 @@ async fn main() { let claims_vector = vec![claims_sub, claims_pub]; - let mqtt_token_fetcher: MqttTokenFetcher = - MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); + let mqtt_token_fetcher = MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); let token: MqttToken = mqtt_token_fetcher .get_token("Client-id", Some(claims_vector)) diff --git a/dsh_sdk/examples/rest_api_token_fetcher.rs b/dsh_sdk/examples/rest_api_token_fetcher.rs index 53a9532..dd1df21 100644 --- a/dsh_sdk/examples/rest_api_token_fetcher.rs +++ b/dsh_sdk/examples/rest_api_token_fetcher.rs @@ -8,6 +8,7 @@ use dsh_rest_api_client::Client; use dsh_sdk::{Platform, RestTokenFetcherBuilder}; use std::env; + #[tokio::main] async fn main() { let platform = Platform::NpLz; @@ -15,7 +16,6 @@ async fn main() { env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set as environment variable"); let tenant = env::var("TENANT").expect("TENANT must be set as environment variable"); let client = Client::new(platform.endpoint_rest_api()); - let tf = RestTokenFetcherBuilder::new(platform) .tenant_name(tenant.clone()) .client_secret(client_secret) diff --git a/dsh_sdk/src/dsh/bootstrap.rs b/dsh_sdk/src/certificates/bootstrap.rs similarity index 67% rename from dsh_sdk/src/dsh/bootstrap.rs rename to dsh_sdk/src/certificates/bootstrap.rs index d7a33e2..8e7acb2 100644 --- a/dsh_sdk/src/dsh/bootstrap.rs +++ b/dsh_sdk/src/certificates/bootstrap.rs @@ -1,18 +1,19 @@ -//! Module for bootstrapping the DSH client. +//! Module for bootstrapping to DSH. //! //! This module contains the logic to connect to DSH and retrieve the certificates and datastreams.json //! to create the properties struct. It follows the certificate signing request pattern as normally //! used in the get_signed_certificates_json.sh script. //! //! ## Note -//! This module is not intended to be used directly, but through the `Properties` struct. It will -//! always be used when getting a `Properties` struct via dsh::Properties::get(). +//! This module is NOT intended to be used directly, but through [Cert] or indirectly via [Properties](crate::Properties). 
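A minimal sketch of driving this bootstrap flow through the `Cert` API that this patch introduces in `certificates/mod.rs` (requires the `bootstrap` feature; the config host shown is the documented default, while the tenant name and task id are hypothetical):

```rust
use dsh_sdk::certificates::Cert;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Preferred entry point: uses PKI_CONFIG_DIR when set, otherwise bootstraps
    // with the DSH-injected environment variables.
    let cert = Cert::from_env()?;

    // Explicit equivalent of the bootstrap branch; the config host is the
    // documented default, tenant and task id are hypothetical:
    // let cert = Cert::from_bootstrap(
    //     "https://pikachu.dsh.marathon.mesos:4443",
    //     "my-tenant",
    //     "task-1234",
    // )?;

    println!("{}", cert.dsh_ca_certificate_pem());
    Ok(())
}
```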
use log::{debug, info}; use reqwest::blocking::Client; +use rcgen::{CertificateParams, CertificateSigningRequest, DnType, KeyPair}; + use crate::error::DshError; -use super::certificates::Cert; +use super::Cert; use crate::utils; use crate::{VAR_DSH_CA_CERTIFICATE, VAR_DSH_SECRET_TOKEN, VAR_DSH_SECRET_TOKEN_PATH}; @@ -22,11 +23,11 @@ pub(crate) fn bootstrap( tenant_name: &str, task_id: &str, ) -> Result { - let dsh_config = DshConfig::new(config_host, tenant_name, task_id)?; + let dsh_config = DshBootstrapConfig::new(config_host, tenant_name, task_id)?; let client = reqwest_ca_client(dsh_config.dsh_ca_certificate.as_bytes())?; let dn = DshBootstapCall::Dn(&dsh_config).perform_call(&client)?; let dn = Dn::parse_string(&dn)?; - let certificates = Cert::get_signed_client_cert(dn, &dsh_config, &client)?; + let certificates = get_signed_client_cert(dn, &dsh_config, &client)?; info!("Successfully connected to DSH"); Ok(certificates) } @@ -40,16 +41,51 @@ fn reqwest_ca_client(dsh_ca_certificate: &[u8]) -> Result Result { + let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P384_SHA384)?; + let csr = generate_csr(&key_pair, dn)?; + let client_certificate = DshBootstapCall::CertificateSignRequest { + config: dsh_config, + csr: &csr.pem()?, + } + .perform_call(client)?; + let ca_cert = pem::parse_many(&dsh_config.dsh_ca_certificate)?; + let client_cert = pem::parse_many(client_certificate)?; + Ok(Cert::new( + pem::encode_many(&ca_cert), + pem::encode_many(&client_cert), + key_pair, + )) +} + +/// Generate the certificate signing request. +fn generate_csr(key_pair: &KeyPair, dn: Dn) -> Result { + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, dn.cn); + params + .distinguished_name + .push(DnType::OrganizationalUnitName, dn.ou); + params + .distinguished_name + .push(DnType::OrganizationName, dn.o); + Ok(params.serialize_request(key_pair)?) +} + /// Helper struct to store the config needed for bootstrapping to DSH #[derive(Debug)] -pub(crate) struct DshConfig<'a> { +struct DshBootstrapConfig<'a> { config_host: &'a str, tenant_name: &'a str, task_id: &'a str, dsh_secret_token: String, dsh_ca_certificate: String, } -impl<'a> DshConfig<'a> { +impl<'a> DshBootstrapConfig<'a> { fn new(config_host: &'a str, tenant_name: &'a str, task_id: &'a str) -> Result { let dsh_secret_token = match utils::get_env_var(VAR_DSH_SECRET_TOKEN) { Ok(token) => token, @@ -62,7 +98,7 @@ impl<'a> DshConfig<'a> { } }; let dsh_ca_certificate = utils::get_env_var(VAR_DSH_CA_CERTIFICATE)?; - Ok(DshConfig { + Ok(DshBootstrapConfig { config_host, task_id, tenant_name, @@ -70,24 +106,20 @@ impl<'a> DshConfig<'a> { dsh_ca_certificate, }) } - - pub(crate) fn dsh_ca_certificate(&self) -> &str { - &self.dsh_ca_certificate - } } -pub(crate) enum DshBootstapCall<'a> { +enum DshBootstapCall<'a> { /// Call to retreive distinguished name. - Dn(&'a DshConfig<'a>), + Dn(&'a DshBootstrapConfig<'a>), /// Call to post the certificate signing request. CertificateSignRequest { - config: &'a DshConfig<'a>, + config: &'a DshBootstrapConfig<'a>, csr: &'a str, }, } impl DshBootstapCall<'_> { - fn url_for_call(&self) -> String { + fn url(&self) -> String { match self { DshBootstapCall::Dn(config) => { format!( @@ -105,7 +137,7 @@ impl DshBootstapCall<'_> { } fn request_builder(&self, client: &Client) -> reqwest::blocking::RequestBuilder { - let url = self.url_for_call(); + let url = self.url(); match self { DshBootstapCall::Dn(..) 
=> client.get(url), DshBootstapCall::CertificateSignRequest { config, csr, .. } => client @@ -115,11 +147,11 @@ impl DshBootstapCall<'_> { } } - pub(crate) fn perform_call(&self, client: &Client) -> Result { + fn perform_call(&self, client: &Client) -> Result { let response = self.request_builder(client).send()?; if !response.status().is_success() { return Err(DshError::DshCallError { - url: self.url_for_call(), + url: self.url(), status_code: response.status(), error_body: response.text().unwrap_or_default(), }); @@ -131,7 +163,7 @@ impl DshBootstapCall<'_> { /// Struct to parse DN string into separate fields. /// Needed for Picky solution. #[derive(Debug)] -pub(crate) struct Dn { +struct Dn { cn: String, ou: String, o: String, @@ -139,7 +171,7 @@ pub(crate) struct Dn { impl Dn { /// Parse the DN string into Dn struct. - pub(crate) fn parse_string(dn_string: &str) -> Result { + fn parse_string(dn_string: &str) -> Result { let mut cn = None; let mut ou = None; let mut o = None; @@ -168,17 +200,6 @@ impl Dn { ))?, }) } - pub(crate) fn cn(&self) -> &str { - &self.cn - } - - pub(crate) fn ou(&self) -> &str { - &self.ou - } - - pub(crate) fn o(&self) -> &str { - &self.o - } } #[cfg(test)] @@ -188,9 +209,24 @@ mod tests { use std::env; use std::str::from_utf8; + use rcgen::{generate_simple_self_signed, CertifiedKey}; + use std::sync::OnceLock; + + use openssl::pkey::PKey; + use openssl::x509::X509Req; + + static TEST_CERTIFICATES: OnceLock = OnceLock::new(); + + fn set_test_cert() -> Cert { + let subject_alt_names = vec!["hello.world.example".to_string(), "localhost".to_string()]; + let CertifiedKey { cert, key_pair } = + generate_simple_self_signed(subject_alt_names).unwrap(); + Cert::new(cert.pem(), cert.pem(), key_pair) + } + #[test] fn test_dsh_call_request_builder() { - let dsh_config = DshConfig { + let dsh_config = DshBootstrapConfig { config_host: "https://test_host", tenant_name: "test_tenant_name", task_id: "test_task_id", @@ -234,8 +270,8 @@ mod tests { .create(); // simple reqwest client let client = Client::new(); - // create a DshConfig struct - let dsh_config = DshConfig { + // create a DshBootstrapConfig struct + let dsh_config = DshBootstrapConfig { config_host: &dsh.url(), tenant_name: "tenant", task_id: "test_task_id", @@ -258,6 +294,37 @@ mod tests { assert_eq!(dn.o, "test_o"); } + #[test] + fn test_dsh_certificate_sign_request() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); + let csr = generate_csr(&cert.key_pair, dn).unwrap(); + let req = csr.pem().unwrap(); + assert!(req.starts_with("-----BEGIN CERTIFICATE REQUEST-----")); + assert!(req.trim().ends_with("-----END CERTIFICATE REQUEST-----")); + } + + #[test] + fn test_verify_csr() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); + let csr = generate_csr(&cert.key_pair, dn).unwrap(); + let csr_pem = csr.pem().unwrap(); + let key = cert.private_key_pkcs8(); + let pkey = PKey::private_key_from_der(&key).unwrap(); + + let req = X509Req::from_pem(csr_pem.as_bytes()).unwrap(); + req.verify(&pkey).unwrap(); + let subject = req + .subject_name() + .entries() + .into_iter() + .map(|e| e.data().as_utf8().unwrap().to_string()) + .collect::>() + .join(","); + assert_eq!(subject, "Test CN,Test OU,Test Org"); + } + #[test] #[serial(env_dependency)] fn test_dsh_config_new() { @@ -267,7 +334,7 @@ mod tests { let config_host = "https://test_host"; let tenant_name = 
"test_tenant"; let task_id = "test_task_id"; - let dsh_config = DshConfig::new(config_host, tenant_name, task_id).unwrap(); + let dsh_config = DshBootstrapConfig::new(config_host, tenant_name, task_id).unwrap(); assert_eq!(dsh_config.config_host, "https://test_host"); assert_eq!(dsh_config.task_id, "test_task_id"); assert_eq!(dsh_config.tenant_name, "test_tenant"); @@ -280,14 +347,14 @@ mod tests { let test_token_dir = format!("{}/test_token", test_token_dir); let _ = std::fs::remove_file(&test_token_dir); env::set_var(VAR_DSH_SECRET_TOKEN_PATH, &test_token_dir); - let result = DshConfig::new(config_host, tenant_name, task_id); + let result = DshBootstrapConfig::new(config_host, tenant_name, task_id); assert!(result.is_err()); std::fs::write(test_token_dir.as_str(), "test_token_from_file").unwrap(); - let dsh_config = DshConfig::new(config_host, tenant_name, task_id).unwrap(); + let dsh_config = DshBootstrapConfig::new(config_host, tenant_name, task_id).unwrap(); assert_eq!(dsh_config.dsh_secret_token, "test_token_from_file"); // fail if DSH_CA_CERTIFICATE is not set env::remove_var(VAR_DSH_CA_CERTIFICATE); - let result = DshConfig::new(config_host, tenant_name, task_id); + let result = DshBootstrapConfig::new(config_host, tenant_name, task_id); assert!(result.is_err()); env::remove_var(VAR_DSH_SECRET_TOKEN); env::remove_var(VAR_DSH_CA_CERTIFICATE); diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs new file mode 100644 index 0000000..c917399 --- /dev/null +++ b/dsh_sdk/src/certificates/mod.rs @@ -0,0 +1,392 @@ +//! This module holds the certificate struct and its methods. +//! +//! The certificate struct holds the DSH CA certificate, the DSH Kafka certificate and +//! the private key. It also has methods to create a reqwest client with the DSH Kafka +//! certificate included and to retrieve the certificates and keys as PEM strings. Also +//! it is possible to create the ca.crt, client.pem, and client.key files in a desired +//! directory. +//! +//! ## Create files +//! +//! To create the ca.crt, client.pem, and client.key files in a desired directory, use the +//! `to_files` method. +//! ```no_run +//! use dsh_sdk::certificates::Cert; +//! use std::path::PathBuf; +//! +//! # fn main() -> Result<(), Box> { +//! let certificates = Cert::from_env()?; +//! let directory = PathBuf::from("path/to/dir"); +//! certificates.to_files(&directory)?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Reqwest Client +//! With this request client we can retrieve datastreams.json and connect to Schema Registry. +use std::path::PathBuf; +use std::sync::Arc; + +use log::{info, warn}; +use rcgen::KeyPair; +use reqwest::blocking::{Client, ClientBuilder}; + +use crate::error::DshError; +use crate::utils; +use crate::{DEFAULT_CONFIG_HOST, VAR_KAFKA_CONFIG_HOST, VAR_PKI_CONFIG_DIR, VAR_TASK_ID}; + +#[cfg(feature = "bootstrap")] +mod bootstrap; +#[cfg(feature = "bootstrap")] +mod pki_config_dir; + +/// Hold all relevant certificates and keys to connect to DSH Kafka Cluster and Schema Store. +#[derive(Debug, Clone)] +pub struct Cert { + dsh_ca_certificate_pem: String, + dsh_client_certificate_pem: String, + key_pair: Arc, +} + +impl Cert { + /// Create new `Cert` struct + fn new( + dsh_ca_certificate_pem: String, + dsh_client_certificate_pem: String, + key_pair: KeyPair, + ) -> Cert { + Self { + dsh_ca_certificate_pem, + dsh_client_certificate_pem, + key_pair: Arc::new(key_pair), + } + } + + /// Bootstrap to DSH and sign the certificates. 
+ /// + /// This method will get DSH CA certificate, sign the Kafka certificate and generate a private key. + /// + /// ## Recommended + /// Use [Cert::from_env] to get the certificates and keys. As this method will check based on the injected environment variables by DSH. + /// This method also allows you to easily switch between Kafka Proxy or VPN connection, based on `PKI_CONFIG_DIR` environment variable. + /// + /// ## Arguments + /// * `config_host` - The DSH config host where the CSR can be send to. (default: `"https://pikachu.dsh.marathon.mesos:4443"`) + /// * `tenant_name` - The tenant name. + /// * `task_id` - The task id of running container. + #[cfg(feature = "bootstrap")] + pub fn from_bootstrap( + config_host: &str, + tenant_name: &str, + task_id: &str, + ) -> Result { + bootstrap::bootstrap(config_host, tenant_name, task_id) + } + + /// Bootstrap to DSH and sign the certificates based on the injected environment variables by DSH. + /// + /// This method will first check if `PKI_CONFIG_DIR` environment variable is set. If set, it will use the certificates from the directory. + /// This is usefull when you want to use Kafka Proxy, VPN or when a different process that already created the certificates. More info at [CONNECT_PROXY_VPN_LOCAL.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/CONNECT_PROXY_VPN_LOCAL.md). + /// + /// Else it will check `KAFKA_CONFIG_HOST`, `MESOS_TASK_ID` and `MARATHON_APP_ID` environment variables to bootstrap to DSH and sign the certificates. + /// These environment variables are injected by DSH. + #[cfg(feature = "bootstrap")] + pub fn from_env() -> Result { + if let Ok(path) = utils::get_env_var(VAR_PKI_CONFIG_DIR) { + Self::from_pki_config_dir(Some(path)) + } else { + let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST) + .map(|host| ensure_https_prefix(host)) + .unwrap_or_else(|_| { + warn!( + "{} is not set, using default value {}", + VAR_KAFKA_CONFIG_HOST, DEFAULT_CONFIG_HOST + ); + DEFAULT_CONFIG_HOST.to_string() + }); + let task_id = utils::get_env_var(VAR_TASK_ID)?; + let tenant_name = utils::tenant_name()?; + Self::from_bootstrap(&config_host, &tenant_name, &task_id) + } + } + + /// Get the certificates from a directory. + /// + /// This method is usefull if you already have the certificates in a directory. + /// For example if you are using Kafka Proxy, VPN or when a different process already + /// created the certificates. + /// + /// ## Arguments + /// * `path` - Path to the directory where the certificates are stored (Optional). + /// + /// path can be overruled by setting the environment variable `PKI_CONFIG_DIR`. + /// + /// ## Note + /// Only certificates in PEM format are supported. + /// Key files should be in PKCS8 format and can be DER or PEM files. + #[cfg(feature = "bootstrap")] + pub fn from_pki_config_dir

<P>(path: Option<P>
) -> Result + where + P: AsRef, + { + pki_config_dir::get_pki_certificates(path) + } + + /// Build an async reqwest client with the DSH Kafka certificate included. + /// With this client we can retrieve datastreams.json and conenct to Schema Registry. + pub fn reqwest_client_config(&self) -> Result { + let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client( + self.dsh_kafka_certificate_pem(), + &self.private_key_pem(), + self.dsh_ca_certificate_pem(), + )?; + Ok(reqwest::Client::builder() + .add_root_certificate(reqwest_cert) + .identity(pem_identity) + .use_rustls_tls()) + } + + /// Build a reqwest client with the DSH Kafka certificate included. + /// With this client we can retrieve datastreams.json and conenct to Schema Registry. + pub fn reqwest_blocking_client_config(&self) -> Result { + let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client( + self.dsh_kafka_certificate_pem(), + &self.private_key_pem(), + self.dsh_ca_certificate_pem(), + )?; + Ok(Client::builder() + .add_root_certificate(reqwest_cert) + .identity(pem_identity) + .use_rustls_tls()) + } + + /// Get the root certificate as PEM string. Equivalent to ca.crt. + pub fn dsh_ca_certificate_pem(&self) -> &str { + self.dsh_ca_certificate_pem.as_str() + } + + /// Get the kafka certificate as PEM string. Equivalent to client.pem. + pub fn dsh_kafka_certificate_pem(&self) -> &str { + self.dsh_client_certificate_pem.as_str() + } + + /// Get the private key as PKCS8 and return bytes based on asn1 DER format. + pub fn private_key_pkcs8(&self) -> Vec { + self.key_pair.serialize_der() + } + + /// Get the private key as PEM string. Equivalent to client.key. + pub fn private_key_pem(&self) -> String { + self.key_pair.serialize_pem() + } + + /// Get the public key as PEM string. + pub fn public_key_pem(&self) -> String { + self.key_pair.public_key_pem() + } + + /// Get the public key as DER bytes. + pub fn public_key_der(&self) -> Vec { + self.key_pair.public_key_der() + } + + /// Create the ca.crt, client.pem, and client.key files in a desired directory. + /// + /// This method will create the directory if it does not exist. 
+ /// + /// # Example + /// + /// ```no_run + /// use dsh_sdk::Properties; + /// use std::path::PathBuf; + /// + /// # fn main() -> Result<(), Box> { + /// let dsh_properties = Properties::get(); + /// let directory = PathBuf::from("dir"); + /// dsh_properties.certificates()?.to_files(&directory)?; + /// # Ok(()) + /// # } + /// ``` + pub fn to_files(&self, dir: &PathBuf) -> Result<(), DshError> { + std::fs::create_dir_all(dir)?; + Self::create_file(dir.join("ca.crt"), self.dsh_ca_certificate_pem())?; + Self::create_file(dir.join("client.pem"), self.dsh_kafka_certificate_pem())?; + Self::create_file(dir.join("client.key"), self.private_key_pem())?; + Ok(()) + } + + fn create_file>(path: PathBuf, contents: C) -> Result<(), DshError> { + std::fs::write(&path, contents)?; + info!("File created ({})", path.display()); + Ok(()) + } + + fn create_identity( + cert: &[u8], + private_key: &[u8], + ) -> Result { + let mut ident = private_key.to_vec(); + ident.extend_from_slice(b"\n"); + ident.extend_from_slice(cert); + reqwest::Identity::from_pem(&ident) + } + + fn prepare_reqwest_client( + kafka_certificate: &str, + private_key: &str, + ca_certificate: &str, + ) -> Result<(reqwest::Identity, reqwest::tls::Certificate), DshError> { + let pem_identity = + Cert::create_identity(kafka_certificate.as_bytes(), private_key.as_bytes())?; + let reqwest_cert = reqwest::tls::Certificate::from_pem(ca_certificate.as_bytes())?; + Ok((pem_identity, reqwest_cert)) + } +} + +/// Helper function to ensure that the host starts with `https://` (or `http://`) +fn ensure_https_prefix(host: String) -> String { + if host.starts_with("https://") || host.starts_with("http://") { + host + } else { + format!("https://{}", host) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{generate_simple_self_signed, CertifiedKey}; + use std::sync::OnceLock; + + use openssl::pkey::PKey; + + static TEST_CERTIFICATES: OnceLock = OnceLock::new(); + + fn set_test_cert() -> Cert { + let subject_alt_names = vec!["hello.world.example".to_string(), "localhost".to_string()]; + let CertifiedKey { cert, key_pair } = + generate_simple_self_signed(subject_alt_names).unwrap(); + Cert::new(cert.pem(), cert.pem(), key_pair) + } + + #[test] + fn test_private_key_pem() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let der = cert.key_pair.serialize_der(); + let pkey = PKey::private_key_from_der(der.as_slice()).unwrap(); + let pkey_pem_bytes = pkey.private_key_to_pem_pkcs8().unwrap(); + + let key_pem = cert.private_key_pem(); + let pkey_pem = String::from_utf8_lossy(pkey_pem_bytes.as_slice()); + assert_eq!(key_pem, pkey_pem); + } + + #[test] + fn test_public_key_pem() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let der = cert.key_pair.serialize_der(); + let pkey = PKey::private_key_from_der(der.as_slice()).unwrap(); + let pkey_pub_pem_bytes = pkey.public_key_to_pem().unwrap(); + + let pub_pem = cert.public_key_pem(); + let pkey_pub_pem = String::from_utf8_lossy(pkey_pub_pem_bytes.as_slice()); + assert_eq!(pub_pem, pkey_pub_pem); + } + + #[test] + fn test_public_key_der() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let der = cert.key_pair.serialize_der(); + let pkey = PKey::private_key_from_der(der.as_slice()).unwrap(); + let pkey_pub_der = pkey.public_key_to_der().unwrap(); + + let pub_der = cert.public_key_der(); + assert_eq!(pub_der, pkey_pub_der); + } + + #[test] + fn test_private_key_pkcs8() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let der = 
cert.key_pair.serialize_der(); + let pkey = PKey::private_key_from_der(der.as_slice()).unwrap(); + let pkey = pkey.private_key_to_pkcs8().unwrap(); + + let key = cert.private_key_pkcs8(); + assert_eq!(key, pkey); + } + + #[test] + fn test_dsh_ca_certificate_pem() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let pem = cert.dsh_ca_certificate_pem(); + assert_eq!(pem, cert.dsh_ca_certificate_pem); + } + + #[test] + fn test_dsh_kafka_certificate_pem() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let pem = cert.dsh_kafka_certificate_pem(); + assert_eq!(pem, cert.dsh_client_certificate_pem); + } + + #[test] + fn test_write_files() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let dir = PathBuf::from("test_files"); + cert.to_files(&dir).unwrap(); + let dir = "test_files"; + assert!(std::path::Path::new(&format!("{}/ca.crt", dir)).exists()); + assert!(std::path::Path::new(&format!("{}/client.pem", dir)).exists()); + assert!(std::path::Path::new(&format!("{}/client.key", dir)).exists()); + } + + #[test] + fn test_create_identity() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let identity = Cert::create_identity( + cert.dsh_kafka_certificate_pem().as_bytes(), + cert.private_key_pem().as_bytes(), + ); + assert!(identity.is_ok()); + } + + #[test] + fn test_prepare_reqwest_client() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let result = Cert::prepare_reqwest_client( + cert.dsh_kafka_certificate_pem(), + &cert.private_key_pem(), + cert.dsh_ca_certificate_pem(), + ); + assert!(result.is_ok()); + } + + #[test] + fn test_reqwest_client_config() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let client = cert.reqwest_client_config(); + assert!(client.is_ok()); + } + + #[test] + fn test_reqwest_blocking_client_config() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let client = cert.reqwest_blocking_client_config(); + assert!(client.is_ok()); + } + + #[test] + fn test_ensure_https_prefix() { + let host = "http://example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "http://example.com"); + + let host = "https://example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "https://example.com"); + + let host = "example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "https://example.com"); + } +} diff --git a/dsh_sdk/src/dsh/pki_config_dir.rs b/dsh_sdk/src/certificates/pki_config_dir.rs similarity index 73% rename from dsh_sdk/src/dsh/pki_config_dir.rs rename to dsh_sdk/src/certificates/pki_config_dir.rs index d409702..a2665f6 100644 --- a/dsh_sdk/src/dsh/pki_config_dir.rs +++ b/dsh_sdk/src/certificates/pki_config_dir.rs @@ -5,7 +5,7 @@ //! //! This also makes it possible to use the DSH SDK with Kafka Proxy //! or VPN outside of the DSH environment. -use super::certificates::Cert; +use super::Cert; use crate::error::DshError; use crate::{utils, VAR_PKI_CONFIG_DIR}; @@ -14,8 +14,13 @@ use pem::{self, Pem}; use rcgen::KeyPair; use std::path::{Path, PathBuf}; -pub(crate) fn get_pki_cert() -> Result { - let config_dir = PathBuf::from(utils::get_env_var(VAR_PKI_CONFIG_DIR)?); +pub(crate) fn get_pki_certificates

<P>(pki_config_dir: Option<P>
) -> Result +where + P: AsRef, +{ + let config_dir = pki_config_dir + .map(|dir| dir.as_ref().to_path_buf()) + .unwrap_or(PathBuf::from(utils::get_env_var(VAR_PKI_CONFIG_DIR)?)); let ca_cert_paths = get_file_path_bufs("ca", PkiFileType::Cert, &config_dir)?; let dsh_ca_certificate_pem = get_certificate(ca_cert_paths)?; let client_cert_paths = get_file_path_bufs("client", PkiFileType::Cert, &config_dir)?; @@ -33,8 +38,8 @@ pub(crate) fn get_pki_cert() -> Result { /// Get certificate from the PKI config directory /// /// Looks for all files containing client*.pem and client.crt in the PKI config directory. -fn get_certificate(mut cert_paths: Vec) -> Result, DshError> { - 'file: while let Some(file) = cert_paths.pop() { +fn get_certificate(cert_paths: Vec) -> Result, DshError> { + 'file: for file in cert_paths { info!("{} - Reading certificate file", file.display()); if let Ok(ca_cert) = std::fs::read(&file) { let pem_result = pem::parse_many(&ca_cert); @@ -61,29 +66,33 @@ fn get_certificate(mut cert_paths: Vec) -> Result, DshError> { Err(DshError::NoCertificates) } -/// Get certificate from the PKI config directory +/// Get key pair from a file in the PKI config directory /// -/// Looks for all files containing client*.pem and client.crt in the PKI config directory. -fn get_key_pair(mut key_paths: Vec) -> Result { - while let Some(file) = key_paths.pop() { +/// Returns first succesfull converted key pair found in the given list of paths +fn get_key_pair(key_paths: Vec) -> Result { + for file in key_paths { info!("{} - Reading key file", file.display()); if let Ok(bytes) = std::fs::read(&file) { - if let Ok(string) = std::string::String::from_utf8(bytes) { + if let Ok(string) = std::string::String::from_utf8(bytes.clone()) { debug!("{} - Key parsed as string", file.display()); - match rcgen::KeyPair::from_pem(&string) { - Ok(key_pair) => { - debug!("{} - Key parsed as KeyPair from string", file.display()); - return Ok(key_pair); - } - Err(e) => warn!("{} - Error parsing key: {:?}", file.display(), e), + if let Ok(key_pair) = rcgen::KeyPair::from_pem(&string) { + debug!("{} - Key parsed as KeyPair from string", file.display()); + return Ok(key_pair); + } else { + warn!("{} - Error parsing key from string", file.display()); } } + if let Ok(key_pair) = rcgen::KeyPair::try_from(bytes) { + debug!("{} - Key parsed as KeyPair from bytes", file.display()); + return Ok(key_pair); + } else { + warn!("{} - Error parsing key from bytes", file.display()); + } } } info!("No (valid) key found in the PKI config directory"); Err(DshError::NoCertificates) } - /// Get the path to the PKI config direc fn get_file_path_bufs
<P>
( prefix: &str, @@ -136,22 +145,26 @@ mod tests { use serial_test::serial; const PKI_CONFIG_DIR: &str = "test_files/pki_config_dir"; - const PKI_KEY_FILE_NAME: &str = "client.key"; + const PKI_KEY_FILE_PEM_NAME: &str = "client.key"; + const PKI_KEY_FILE_DER_NAME: &str = "client-der.key"; const PKI_CERT_FILE_NAME: &str = "client.pem"; const PKI_CA_FILE_NAME: &str = "ca.crt"; fn create_test_pki_config_dir() { let path = PathBuf::from(PKI_CONFIG_DIR); - let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_key_pem = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_PEM_NAME); + let path_key_der = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_DER_NAME); let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); - if path_key.exists() && path_cert.exists() && path_ca.exists() { + if path_key_pem.exists() && path_cert.exists() && path_ca.exists() && path_key_der.exists() + { return; } let _ = std::fs::create_dir(path); let priv_key = openssl::rsa::Rsa::generate(2048).unwrap(); let pkey = PKey::from_rsa(priv_key).unwrap(); - let key = pkey.private_key_to_pem_pkcs8().unwrap(); + let key_pem = pkey.private_key_to_pem_pkcs8().unwrap(); + let key_der = pkey.private_key_to_pkcs8().unwrap(); let mut x509_name = openssl::x509::X509NameBuilder::new().unwrap(); x509_name.append_entry_by_text("CN", "test_ca").unwrap(); let x509_name = x509_name.build(); @@ -168,7 +181,8 @@ mod tests { let x509 = x509.build(); let ca_cert = x509.to_pem().unwrap(); let cert = x509.to_pem().unwrap(); - std::fs::write(path_key, key).unwrap(); + std::fs::write(path_key_pem, key_pem).unwrap(); + std::fs::write(path_key_der, key_der).unwrap(); std::fs::write(path_ca, ca_cert).unwrap(); std::fs::write(path_cert, cert).unwrap(); } @@ -181,7 +195,7 @@ mod tests { let result_cert = get_file_path_bufs("client", PkiFileType::Cert, &path).unwrap(); assert_eq!(result_cert.len(), 1); let result_key = get_file_path_bufs("client", PkiFileType::Key, &path).unwrap(); - assert_eq!(result_key.len(), 1); + assert_eq!(result_key.len(), 2); assert_ne!(result_cert, result_key); let result_ca = get_file_path_bufs("ca", PkiFileType::Cert, &path).unwrap(); assert_eq!(result_ca.len(), 1); @@ -194,7 +208,7 @@ mod tests { #[serial(pki)] fn test_get_certificate() { create_test_pki_config_dir(); - let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_PEM_NAME); let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); let path_ne = PathBuf::from(PKI_CONFIG_DIR).join("not_existing.crt"); @@ -223,9 +237,30 @@ mod tests { #[test] #[serial(pki)] - fn test_get_key_pair() { + fn test_get_key_pair_pem() { + create_test_pki_config_dir(); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_PEM_NAME); + let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); + let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); + let path_ne = PathBuf::from(PKI_CONFIG_DIR).join("not_existing.key"); + let result = get_key_pair(vec![path_key.clone()]); + assert!(result.is_ok()); + let result = get_key_pair(vec![path_ne.clone(), path_key.clone()]); + assert!(result.is_ok()); + let result = + get_key_pair(vec![path_ne.clone(), path_cert.clone(), path_ca.clone()]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + let result = 
get_key_pair(vec![]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + let result = get_key_pair(vec![path_ne]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + } + + #[test] + #[serial(pki)] + fn test_get_key_pair_der() { create_test_pki_config_dir(); - let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_DER_NAME); let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); let path_ne = PathBuf::from(PKI_CONFIG_DIR).join("not_existing.key"); @@ -246,10 +281,10 @@ mod tests { #[serial(pki, env_dependency)] fn test_get_pki_cert() { create_test_pki_config_dir(); - let result = get_pki_cert().unwrap_err(); - assert!(matches!(result, DshError::EnvVarError(_))); + let result = get_pki_certificates::(None).unwrap_err(); + assert!(matches!(result, DshError::EnvVarError(_, _))); std::env::set_var(VAR_PKI_CONFIG_DIR, PKI_CONFIG_DIR); - let result = get_pki_cert(); + let result = get_pki_certificates::(None); assert!(result.is_ok()); std::env::remove_var(VAR_PKI_CONFIG_DIR); } diff --git a/dsh_sdk/src/dsh/datastream.rs b/dsh_sdk/src/datastream.rs similarity index 99% rename from dsh_sdk/src/dsh/datastream.rs rename to dsh_sdk/src/datastream.rs index 65c28d2..3687268 100644 --- a/dsh_sdk/src/dsh/datastream.rs +++ b/dsh_sdk/src/datastream.rs @@ -20,7 +20,7 @@ use std::env; use std::fs::File; use std::io::Read; -use log::{debug, error, info, warn}; +use log::{debug, error, info}; use serde::{Deserialize, Serialize}; use crate::error::DshError; @@ -151,7 +151,7 @@ impl Datastream { /// /// # Example /// ```no_run - /// # use dsh_sdk::dsh::datastream::Datastream; + /// # use dsh_sdk::datastream::Datastream; /// # let datastream = Datastream::default(); /// let path = std::path::PathBuf::from("/path/to/directory"); /// datastream.to_file(&path).unwrap(); @@ -402,7 +402,7 @@ impl GroupType { GroupType::Shared(0) } Err(_) => { - warn!("KAFKA_CONSUMER_GROUP_TYPE is not set, defaulting to shared group type."); + debug!("KAFKA_CONSUMER_GROUP_TYPE is not set, defaulting to shared group type."); GroupType::Shared(0) } } diff --git a/dsh_sdk/src/dlq.rs b/dsh_sdk/src/dlq.rs index 37963a8..b74f928 100644 --- a/dsh_sdk/src/dlq.rs +++ b/dsh_sdk/src/dlq.rs @@ -23,540 +23,8 @@ //! ### Example: //! See the examples folder on github for a working example. 
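With this patch the DLQ lives at `dsh_sdk::utils::dlq` (the old path below only re-exports it with a deprecation notice). A minimal wiring sketch under the new path, assuming the moved module keeps the constructor and channel API shown in the removed code; `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC` must be set in the environment:

```rust
use dsh_sdk::utils::dlq::Dlq;
use dsh_sdk::utils::graceful_shutdown::Shutdown;
use dsh_sdk::Properties;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let shutdown = Shutdown::new();
    let mut dlq = Dlq::new(Properties::get(), shutdown)?;
    // Clone this sender into each worker; failed messages are pushed through it.
    let dlq_tx = dlq.dlq_records_tx();
    // The DLQ loop runs until the shutdown signal fires.
    tokio::spawn(async move { dlq.run().await });
    // In a worker, on a non-recoverable error:
    // my_error.to_dlq(owned_message).send(&mut dlq_tx.clone()).await;
    Ok(())
}
```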
-use std::collections::HashMap; -use std::env; -use std::str::from_utf8; - -use log::{debug, error, info, warn}; - -use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage}; -use rdkafka::producer::{FutureProducer, FutureRecord}; - -use tokio::sync::mpsc; - -use crate::graceful_shutdown::Shutdown; -use crate::Properties; - -/// Trait to convert an error to a dlq message -/// This trait is implemented for all errors that can and should be converted to a dlq message -/// -/// Example: -///``` -/// use dsh_sdk::dlq; -/// use std::backtrace::Backtrace; -/// use thiserror::Error; -/// -/// #[derive(Error, Debug)] -/// enum ConsumerError { -/// #[error("Deserialization error: {0}")] -/// DeserializeError(String), -/// } -/// -/// impl dlq::ErrorToDlq for ConsumerError { -/// fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq { -/// dlq::SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None) -/// } -/// fn retryable(&self) -> dlq::Retryable { -/// match self { -/// ConsumerError::DeserializeError(e) => dlq::Retryable::NonRetryable, -/// } -/// } -/// } -/// ``` -pub trait ErrorToDlq { - /// Convert error message to a dlq message - fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq; - /// Match error if the orignal message is able to be retried or not - fn retryable(&self) -> Retryable; -} - -/// Struct with required details to send a channel message to the dlq -/// Error needs to be send as string, as it is not possible to send a struct that implements Error trait -pub struct SendToDlq { - kafka_message: OwnedMessage, - retryable: Retryable, - error: String, - stack_trace: Option, -} - -impl SendToDlq { - /// Create new SendToDlq message - pub fn new( - kafka_message: OwnedMessage, - retryable: Retryable, - error: String, - stack_trace: Option, - ) -> Self { - Self { - kafka_message, - retryable, - error, - stack_trace, - } - } - /// Send message to dlq channel - pub async fn send(self, dlq_tx: &mut mpsc::Sender) { - match dlq_tx.send(self).await { - Ok(_) => debug!("Message sent to DLQ channel"), - Err(e) => error!("Error sending message to DLQ: {}", e), - } - } - - fn get_original_msg(&self) -> OwnedMessage { - self.kafka_message.clone() - } -} - -/// Helper enum to decide to which topic the message should be sent to. 
-#[derive(Debug, Clone, Copy)] -pub enum Retryable { - Retryable, - NonRetryable, - Other, -} - -impl std::fmt::Display for Retryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Retryable::Retryable => write!(f, "Retryable"), - Retryable::NonRetryable => write!(f, "NonRetryable"), - Retryable::Other => write!(f, "Other"), - } - } -} - -/// Struct with implementation to send messages to the dlq -pub struct Dlq { - dlq_producer: FutureProducer, - dlq_rx: mpsc::Receiver, - dlq_tx: mpsc::Sender, - dlq_dead_topic: String, - dlq_retry_topic: String, - shutdown: Shutdown, -} - -impl Dlq { - /// Create new Dlq struct - pub fn new( - dsh_prop: &Properties, - shutdown: Shutdown, - ) -> Result> { - use crate::dsh::datastream::ReadWriteAccess; - let (dlq_tx, dlq_rx) = mpsc::channel(200); - let dlq_producer = Self::build_producer(dsh_prop)?; - let dlq_dead_topic = env::var("DLQ_DEAD_TOPIC")?; - let dlq_retry_topic = env::var("DLQ_RETRY_TOPIC")?; - dsh_prop.datastream().verify_list_of_topics( - &vec![&dlq_dead_topic, &dlq_retry_topic], - ReadWriteAccess::Write, - )?; - Ok(Self { - dlq_producer, - dlq_rx, - dlq_tx, - dlq_dead_topic, - dlq_retry_topic, - shutdown, - }) - } - - /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics - /// This function will run until the shutdown channel is closed - pub async fn run(&mut self) { - info!("DLQ started"); - loop { - tokio::select! { - _ = self.shutdown.recv() => { - warn!("DLQ shutdown"); - return; - }, - Some(mut dlq_message) = self.dlq_rx.recv() => { - match self.send(&mut dlq_message).await { - Ok(_) => {}, - Err(e) => error!("Error sending message to DLQ: {}", e), - }; - } - } - } - } - - /// Get the dlq channel sender. To be used in your service to send messages to the dlq in case of errors. - /// - /// This channel can be used to send messages to the dlq from different threads. 
- pub fn dlq_records_tx(&self) -> mpsc::Sender { - self.dlq_tx.clone() - } - - /// Create and send message towards the dlq - async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), rdkafka::error::KafkaError> { - let orignal_kafka_msg: OwnedMessage = dlq_message.get_original_msg(); - let headers = orignal_kafka_msg - .generate_dlq_headers(dlq_message) - .to_owned_headers(); - let topic = self.dlq_topic(dlq_message.retryable); - let key: &[u8] = orignal_kafka_msg.key().unwrap_or_default(); - let payload = orignal_kafka_msg.payload().unwrap_or_default(); - debug!("Sending message to DLQ topic: {}", topic); - let record = FutureRecord::to(topic) - .payload(payload) - .key(key) - .headers(headers); - let s = self.dlq_producer.send(record, None).await; - match s { - Ok((p, o)) => warn!( - "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}", - from_utf8(key), - topic, - p, - o - ), - Err((e, _)) => return Err(e), - }; - Ok(()) - } - - fn dlq_topic(&self, retryable: Retryable) -> &str { - match retryable { - Retryable::Retryable => &self.dlq_retry_topic, - Retryable::NonRetryable => &self.dlq_dead_topic, - Retryable::Other => &self.dlq_dead_topic, - } - } - - fn build_producer(dsh_prop: &Properties) -> Result { - dsh_prop.producer_rdkafka_config().create() - } -} - -trait DlqHeaders { - fn generate_dlq_headers<'a>( - &'a self, - dlq_message: &'a mut SendToDlq, - ) -> HashMap<&'a str, Option>>; -} - -impl DlqHeaders for OwnedMessage { - fn generate_dlq_headers<'a>( - &'a self, - dlq_message: &'a mut SendToDlq, - ) -> HashMap<&'a str, Option>> { - let mut hashmap_headers: HashMap<&str, Option>> = HashMap::new(); - // Get original headers and add to hashmap - if let Some(headers) = self.headers() { - for header in headers.iter() { - hashmap_headers.insert(header.key, header.value.map(|v| v.to_vec())); - } - } - - // Add dlq headers if not exist (we don't want to overwrite original dlq headers if message already failed earlier) - let partition = self.partition().to_string().as_bytes().to_vec(); - let offset = self.offset().to_string().as_bytes().to_vec(); - let timestamp = self - .timestamp() - .to_millis() - .unwrap_or(-1) - .to_string() - .as_bytes() - .to_vec(); - hashmap_headers - .entry("dlq_topic_origin") - .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); - hashmap_headers - .entry("dlq_partition_origin") - .or_insert_with(move || Some(partition)); - hashmap_headers - .entry("dlq_partition_offset_origin") - .or_insert_with(move || Some(offset)); - hashmap_headers - .entry("dlq_topic_origin") - .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); - hashmap_headers - .entry("dlq_timestamp_origin") - .or_insert_with(move || Some(timestamp)); - // Overwrite if exist - hashmap_headers.insert( - "dlq_retryable", - Some(dlq_message.retryable.to_string().as_bytes().to_vec()), - ); - hashmap_headers.insert( - "dlq_error", - Some(dlq_message.error.to_string().as_bytes().to_vec()), - ); - if let Some(stack_trace) = &dlq_message.stack_trace { - hashmap_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); - } - // update dlq_retries with +1 if exists, else add dlq_retries wiith 1 - let retries = hashmap_headers - .get("dlq_retries") - .map(|v| { - let mut retries = [0; 4]; - retries.copy_from_slice(v.as_ref().unwrap()); - i32::from_be_bytes(retries) - }) - .unwrap_or(0); - hashmap_headers.insert("dlq_retries", Some((retries + 1).to_be_bytes().to_vec())); - - hashmap_headers - } -} - -trait HashMapToKafkaHeaders { - fn to_owned_headers(&self) -> 
OwnedHeaders; -} - -impl HashMapToKafkaHeaders for HashMap<&str, Option>> { - fn to_owned_headers(&self) -> OwnedHeaders { - // Convert to OwnedHeaders - let mut owned_headers = OwnedHeaders::new_with_capacity(self.len()); - for header in self { - let value = header.1.as_ref().map(|value| value.as_slice()); - owned_headers = owned_headers.insert(Header { - key: header.0, - value, - }); - } - owned_headers - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rdkafka::config::ClientConfig; - use rdkafka::mocking::MockCluster; - - #[derive(Debug)] - enum MockError { - MockErrorRetryable(String), - MockErrorDead(String), - } - impl MockError { - fn to_string(&self) -> String { - match self { - MockError::MockErrorRetryable(e) => e.to_string(), - MockError::MockErrorDead(e) => e.to_string(), - } - } - } - - impl std::fmt::Display for MockError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MockError::MockErrorRetryable(e) => write!(f, "{}", e), - MockError::MockErrorDead(e) => write!(f, "{}", e), - } - } - } - - impl ErrorToDlq for MockError { - fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq { - let backtrace = "some_backtrace"; - SendToDlq::new( - kafka_message, - self.retryable(), - self.to_string(), - Some(backtrace.to_string()), - ) - } - - fn retryable(&self) -> Retryable { - match self { - MockError::MockErrorRetryable(_) => Retryable::Retryable, - MockError::MockErrorDead(_) => Retryable::NonRetryable, - } - } - } - - #[test] - fn test_dlq_get_original_msg() { - let topic = "original_topic"; - let partition = 0; - let offset = 123; - let timestamp = 456; - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "some_key", - value: Some("some_value".as_bytes()), - }); - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - topic.to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - let dlq_message = - MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone()); - let result = dlq_message.get_original_msg(); - assert_eq!( - result.payload(), - dlq_message.kafka_message.payload(), - "payoad does not match" - ); - assert_eq!( - result.key(), - dlq_message.kafka_message.key(), - "key does not match" - ); - assert_eq!( - result.topic(), - dlq_message.kafka_message.topic(), - "topic does not match" - ); - assert_eq!( - result.partition(), - dlq_message.kafka_message.partition(), - "partition does not match" - ); - assert_eq!( - result.offset(), - dlq_message.kafka_message.offset(), - "offset does not match" - ); - assert_eq!( - result.timestamp(), - dlq_message.kafka_message.timestamp(), - "timestamp does not match" - ); - } - - #[test] - fn test_dlq_hashmap_to_owned_headers() { - let mut hashmap: HashMap<&str, Option>> = HashMap::new(); - hashmap.insert("some_key", Some(b"key_value".to_vec())); - hashmap.insert("some_other_key", None); - let result: Vec<(&str, Option<&[u8]>)> = - vec![("some_key", Some(b"key_value")), ("some_other_key", None)]; - - let owned_headers = hashmap.to_owned_headers(); - for header in owned_headers.iter() { - assert!(result.contains(&(header.key, header.value))); - } - } - - #[test] - fn test_dlq_topic() { - let mock_cluster = MockCluster::new(1).unwrap(); - let mut producer = ClientConfig::new(); - producer.set("bootstrap.servers", mock_cluster.bootstrap_servers()); - let producer = producer.create().unwrap(); - let dlq = 
Dlq { - dlq_producer: producer, - dlq_rx: mpsc::channel(200).1, - dlq_tx: mpsc::channel(200).0, - dlq_dead_topic: "dead_topic".to_string(), - dlq_retry_topic: "retry_topic".to_string(), - shutdown: Shutdown::new(), - }; - let error = MockError::MockErrorRetryable("some_error".to_string()); - let topic = dlq.dlq_topic(error.retryable()); - assert_eq!(topic, "retry_topic"); - let error = MockError::MockErrorDead("some_error".to_string()); - let topic = dlq.dlq_topic(error.retryable()); - assert_eq!(topic, "dead_topic"); - } - - #[test] - fn test_dlq_generate_dlq_headers() { - let topic = "original_topic"; - let partition = 0; - let offset = 123; - let timestamp = 456; - let error = Box::new(MockError::MockErrorRetryable("some_error".to_string())); - - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "some_key", - value: Some("some_value".as_bytes()), - }); - - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - topic.to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - - let mut dlq_message = error.to_dlq(owned_message.clone()); - - let mut expected_headers: HashMap<&str, Option>> = HashMap::new(); - expected_headers.insert("some_key", Some(b"some_value".to_vec())); - expected_headers.insert("dlq_topic_origin", Some(topic.as_bytes().to_vec())); - expected_headers.insert( - "dlq_partition_origin", - Some(partition.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_partition_offset_origin", - Some(offset.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_timestamp_origin", - Some(timestamp.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_retryable", - Some(Retryable::Retryable.to_string().as_bytes().to_vec()), - ); - expected_headers.insert("dlq_retries", Some(1_i32.to_be_bytes().to_vec())); - expected_headers.insert("dlq_error", Some(error.to_string().as_bytes().to_vec())); - if let Some(stack_trace) = &dlq_message.stack_trace { - expected_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); - } - - let result = owned_message.generate_dlq_headers(&mut dlq_message); - for header in result.iter() { - assert_eq!( - header.1, - expected_headers.get(header.0).unwrap_or(&None), - "Header {} does not match", - header.0 - ); - } - - // Test if dlq headers are correctly overwritten when to be retried message was already retried before - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "dlq_error", - value: Some( - "to_be_overwritten_error_as_this_was_the_original_error_from_1st_retry".as_bytes(), - ), - }); - original_headers = original_headers.insert(Header { - key: "dlq_topic_origin", - value: Some(topic.as_bytes()), - }); - original_headers = original_headers.insert(Header { - key: "dlq_retries", - value: Some(&1_i32.to_be_bytes().to_vec()), - }); - - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - "retry_topic".to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - let result = owned_message.generate_dlq_headers(&mut dlq_message); - assert_eq!( - result.get("dlq_error").unwrap(), - &Some(error.to_string().as_bytes().to_vec()) - ); - assert_eq!( - result.get("dlq_topic_origin").unwrap(), - &Some(topic.as_bytes().to_vec()) - ); - assert_eq!( - 
result.get("dlq_retries").unwrap(), - &Some(2_i32.to_be_bytes().to_vec()) - ); - } -} +#[deprecated( + since = "0.5.0", + note = "The DLQ is moved to [crate::utils::dlq](crate::utils::dlq)" +)] +pub use crate::utils::dlq::*; diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs new file mode 100644 index 0000000..dc3f3e9 --- /dev/null +++ b/dsh_sdk/src/dsh.rs @@ -0,0 +1,701 @@ +//! # Dsh +//! +//! This module contains the High-level struct for all related +//! +//! From `Dsh` there are level functions to get the correct config to connect to Kafka and schema store. +//! For more low level functions, see +//! - [datastream](datastream/index.html) module. +//! - [certificates](certificates/index.html) module. +//! +//! ## Environment variables +//! See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for +//! more information configuring the consmer or producer via environment variables. +//! +//! # Example +//! ``` +//! use dsh_sdk::Dsh; +//! use rdkafka::consumer::{Consumer, StreamConsumer}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let dsh_properties = Dsh::get(); +//! let consumer_config = dsh_properties.consumer_rdkafka_config(); +//! let consumer: StreamConsumer = consumer_config.create()?; +//! +//! # Ok(()) +//! # } +//! ``` +use log::{error, warn}; +use std::env; +use std::sync::OnceLock; + +use crate::certificates::Cert; +use crate::datastream::Datastream; +use crate::error::DshError; +use crate::protocol_adapters::kafka_protocol::config; +use crate::utils; +use crate::*; + +// TODP: Remove at v0.6.0 +pub use crate::dsh_old::*; + +/// DSH properties struct. Create new to initialize all related components to connect to the DSH kafka clusters +/// - Contains info from datastreams.json +/// - Metadata of running container/task +/// - Certificates for Kafka and DSH Schema Registry +/// +/// ## Environment variables +/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for +/// more information configuring the consmer or producer via environment variables. +/// +/// # Example +/// ``` +/// use dsh_sdk::Dsh; +/// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer}; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let dsh_properties = Dsh::get(); +/// +/// let consumer_config = dsh_properties.consumer_rdkafka_config(); +/// let consumer: StreamConsumer = consumer_config.create()?; +/// +/// Ok(()) +/// } +/// ``` + +#[derive(Debug, Clone)] +pub struct Dsh { + config_host: String, + task_id: String, + tenant_name: String, + datastream: Datastream, + certificates: Option, +} + +impl Dsh { + /// New `Dsh` struct + pub(crate) fn new( + config_host: String, + task_id: String, + tenant_name: String, + datastream: Datastream, + certificates: Option, + ) -> Self { + Self { + config_host, + task_id, + tenant_name, + datastream, + certificates, + } + } + /// Get the DSH Dsh on a lazy way. If not already initialized, it will initialize the properties + /// and bootstrap to DSH. + /// + /// This struct contains all configuration and certificates needed to connect to Kafka and DSH. 
+    ///
+    /// - Contains a struct equal to datastreams.json
+    /// - Metadata of running container/task
+    /// - Certificates for Kafka and DSH
+    ///
+    /// # Panics
+    /// This method can panic when running on a local machine and it tries to load an incorrect [local_datastreams.json](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/local_datastreams.json).
+    /// When no file is available in the root folder, or the env variable `LOCAL_DATASTREAMS_JSON` is not set, it will
+    /// return a default datastream struct and NOT panic.
+    ///
+    /// # Example
+    /// ```
+    /// use dsh_sdk::Dsh;
+    /// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let dsh_properties = Dsh::get();
+    /// let consumer: StreamConsumer = dsh_properties.consumer_rdkafka_config().create()?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn get() -> &'static Self {
+        static PROPERTIES: OnceLock<Dsh> = OnceLock::new();
+        PROPERTIES.get_or_init(|| tokio::task::block_in_place(Self::init))
+    }
+
+    /// Initialize the properties and bootstrap to DSH
+    fn init() -> Self {
+        let tenant_name = match utils::tenant_name() {
+            Ok(tenant_name) => tenant_name,
+            Err(_) => {
+                error!("{} and {} are not set, this may cause unexpected behaviour when connecting to the DSH Kafka cluster! Please set one of these environment variables.", VAR_APP_ID, VAR_DSH_TENANT_NAME);
+                "local_tenant".to_string()
+            }
+        };
+        let task_id = utils::get_env_var(VAR_TASK_ID).unwrap_or("local_task_id".to_string());
+        let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST)
+            .map(|host| format!("https://{}", host))
+            .unwrap_or_else(|_| {
+                warn!(
+                    "{} is not set, using default value {}",
+                    VAR_KAFKA_CONFIG_HOST, DEFAULT_CONFIG_HOST
+                );
+                DEFAULT_CONFIG_HOST.to_string()
+            });
+        let certificates = if let Ok(cert) = Cert::from_pki_config_dir::(None) {
+            Some(cert)
+        } else {
+            Cert::from_bootstrap(&config_host, &tenant_name, &task_id)
+                .inspect_err(|e| {
+                    warn!("Could not bootstrap to DSH, due to: {}", e);
+                })
+                .ok()
+        };
+        let fetched_datastreams = certificates.as_ref().and_then(|cert| {
+            cert.reqwest_blocking_client_config()
+                .ok()
+                .and_then(|cb| cb.build().ok())
+                .and_then(|client| {
+                    Datastream::fetch_blocking(&client, &config_host, &tenant_name, &task_id).ok()
+                })
+        });
+        let datastream = if let Some(datastream) = fetched_datastreams {
+            datastream
+        } else {
+            warn!("Could not fetch datastreams.json, using local or default datastreams");
+            Datastream::load_local_datastreams().unwrap_or_default()
+        };
+        Self::new(config_host, task_id, tenant_name, datastream, certificates)
+    }
+
+    /// Get reqwest async client config to connect to DSH Schema Registry.
+    /// If certificates are present, it will use SSL to connect to Schema Registry.
+    ///
+    /// Use [schema_registry_converter](https://crates.io/crates/schema_registry_converter) to connect to Schema Registry.
+    ///
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::Dsh;
+    /// # use reqwest::Client;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let dsh_properties = Dsh::get();
+    /// let client = dsh_properties.reqwest_client_config()?.build()?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn reqwest_client_config(&self) -> Result<reqwest::ClientBuilder, DshError> {
+        let mut client_builder = reqwest::Client::builder();
+        if let Ok(certificates) = &self.certificates() {
+            client_builder = certificates.reqwest_client_config()?;
+        }
+        Ok(client_builder)
+    }
+
+    /// Get reqwest blocking client config to connect to DSH Schema Registry.
+    /// If certificates are present, it will use SSL to connect to Schema Registry.
+    ///
+    /// Use [schema_registry_converter](https://crates.io/crates/schema_registry_converter) to connect to Schema Registry.
+    ///
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::Dsh;
+    /// # use reqwest::blocking::Client;
+    /// # use dsh_sdk::error::DshError;
+    /// # fn main() -> Result<(), DshError> {
+    /// let dsh_properties = Dsh::get();
+    /// let client = dsh_properties.reqwest_blocking_client_config()?.build()?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn reqwest_blocking_client_config(
+        &self,
+    ) -> Result<reqwest::blocking::ClientBuilder, DshError> {
+        let mut client_builder: reqwest::blocking::ClientBuilder =
+            reqwest::blocking::Client::builder();
+        if let Ok(certificates) = &self.certificates() {
+            client_builder = certificates.reqwest_blocking_client_config()?;
+        }
+        Ok(client_builder)
+    }
+
+    /// Get the certificates and private key. Returns an error when running on a local machine.
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::Dsh;
+    /// # use dsh_sdk::error::DshError;
+    /// # fn main() -> Result<(), DshError> {
+    /// let dsh_properties = Dsh::get();
+    /// let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem();
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn certificates(&self) -> Result<&Cert, DshError> {
+        if let Some(cert) = &self.certificates {
+            Ok(cert)
+        } else {
+            Err(DshError::NoCertificates)
+        }
+    }
+
+    /// Get the client id based on the task id.
+    pub fn client_id(&self) -> &str {
+        &self.task_id
+    }
+
+    /// Get the tenant name of the running container.
+    pub fn tenant_name(&self) -> &str {
+        &self.tenant_name
+    }
+
+    /// Get the task id of the running container.
+    pub fn task_id(&self) -> &str {
+        &self.task_id
+    }
+
+    /// Get the kafka properties provided by DSH (datastreams.json)
+    ///
+    /// This datastream is fetched at initialization of the properties and cannot be updated during runtime.
+    pub fn datastream(&self) -> &Datastream {
+        &self.datastream
+    }
+
+    /// High-level method to fetch the kafka properties provided by DSH (datastreams.json).
+    /// This will fetch the datastream from DSH and can be used to update the datastream during runtime.
+    ///
+    /// This method keeps the reqwest client in memory to prevent creating a new client for every request.
+    ///
+    /// # Panics
+    /// This method panics when it can't initialize a reqwest client.
+    ///
+    /// Use [Datastream::fetch] as a low-level method where you can provide your own client.
+    pub async fn fetch_datastream(&self) -> Result<Datastream, DshError> {
+        static ASYNC_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
+
+        let client = ASYNC_CLIENT.get_or_init(|| {
+            self.reqwest_client_config()
+                .expect("Failed loading certificates into reqwest client config")
+                .build()
+                .expect("Could not build reqwest client for fetching datastream")
+        });
+        Datastream::fetch(client, &self.config_host, &self.tenant_name, &self.task_id).await
+    }
+
+    /// High-level method to fetch the kafka properties provided by DSH (datastreams.json) in a blocking way.
+    /// This will fetch the datastream from DSH and can be used to update the datastream during runtime.
+    ///
+    /// This method keeps the reqwest client in memory to prevent creating a new client for every request.
+    ///
+    /// # Panics
+    /// This method panics when it can't initialize a reqwest client.
+    ///
+    /// Use [Datastream::fetch_blocking] as a low-level method where you can provide your own client.
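+    ///
+    /// # Example
+    /// A minimal sketch of refreshing the datastream during runtime; it assumes the
+    /// service has valid certificates available so the fetch can succeed:
+    /// ```no_run
+    /// # use dsh_sdk::Dsh;
+    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let dsh = Dsh::get();
+    /// // Fetch a fresh copy of datastreams.json (blocking)
+    /// let datastream = dsh.fetch_datastream_blocking()?;
+    /// println!("Brokers: {}", datastream.get_brokers_string());
+    /// # Ok(())
+    /// # }
+    /// ```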
+    pub fn fetch_datastream_blocking(&self) -> Result<Datastream, DshError> {
+        static BLOCKING_CLIENT: OnceLock<reqwest::blocking::Client> = OnceLock::new();
+
+        let client = BLOCKING_CLIENT.get_or_init(|| {
+            self.reqwest_blocking_client_config()
+                .expect("Failed loading certificates into reqwest client config")
+                .build()
+                .expect("Could not build reqwest client for fetching datastream")
+        });
+        Datastream::fetch_blocking(client, &self.config_host, &self.tenant_name, &self.task_id)
+    }
+
+    /// Get the schema registry host of DSH.
+    pub fn schema_registry_host(&self) -> &str {
+        self.datastream().schema_store()
+    }
+
+    /// Get the Kafka brokers.
+    ///
+    /// ## Environment variables
+    /// To manipulate the hostnames of the brokers, you can set the following environment variables.
+    ///
+    /// ### `KAFKA_BOOTSTRAP_SERVERS`
+    /// - Usage: Overwrite hostnames of brokers
+    /// - Default: Brokers based on datastreams
+    /// - Required: `false`
+    pub fn kafka_brokers(&self) -> String {
+        self.datastream().get_brokers_string()
+    }
+
+    /// Get the Kafka group id.
+    ///
+    /// ## Environment variables
+    /// To manipulate the group id, you can set the following environment variables.
+    ///
+    /// ### `KAFKA_CONSUMER_GROUP_TYPE`
+    /// - Usage: Picks group_id based on type from datastreams
+    /// - Default: Shared
+    /// - Options: private, shared
+    /// - Required: `false`
+    ///
+    /// ### `KAFKA_GROUP_ID`
+    /// - Usage: Custom group id
+    /// - Default: NA
+    /// - Required: `false`
+    /// - Remark: Overrules `KAFKA_CONSUMER_GROUP_TYPE`. Mandatory to start with tenant name. (will prefix tenant name automatically if not set)
+    pub fn kafka_group_id(&self) -> String {
+        if let Ok(group_id) = env::var(VAR_KAFKA_GROUP_ID) {
+            if !group_id.starts_with(self.tenant_name()) {
+                format!("{}_{}", self.tenant_name(), group_id)
+            } else {
+                group_id
+            }
+        } else {
+            self.datastream()
+                .get_group_id(crate::datastream::GroupType::from_env())
+                .unwrap_or(&format!("{}_CONSUMER", self.tenant_name()))
+                .to_string()
+        }
+    }
+
+    /// Get the configured Kafka auto commit settings.
+    ///
+    /// ## Environment variables
+    /// To manipulate the auto commit settings, you can set the following environment variables.
+    ///
+    /// ### `KAFKA_ENABLE_AUTO_COMMIT`
+    /// - Usage: Enable/Disable auto commit
+    /// - Default: `false`
+    /// - Required: `false`
+    /// - Options: `true`, `false`
+    pub fn kafka_auto_commit(&self) -> bool {
+        config::KafkaConfig::get().enable_auto_commit()
+    }
+
+    /// Get the Kafka auto offset reset settings.
+    ///
+    /// ## Environment variables
+    /// To manipulate the auto offset reset settings, you can set the following environment variables.
+    ///
+    /// ### `KAFKA_AUTO_OFFSET_RESET`
+    /// - Usage: Set the offset reset setting to start consuming from.
+ /// - Default: earliest + /// - Required: `false` + /// - Options: smallest, earliest, beginning, largest, latest, end + pub fn kafka_auto_offset_reset(&self) -> String { + config::KafkaConfig::get().auto_offset_reset() + } + #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] + pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { + let consumer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get(); + let mut config = rdkafka::config::ClientConfig::new(); + config + .set("bootstrap.servers", self.kafka_brokers()) + .set("group.id", self.kafka_group_id()) + .set("client.id", self.client_id()) + .set("enable.auto.commit", self.kafka_auto_commit().to_string()) + .set("auto.offset.reset", self.kafka_auto_offset_reset()); + if let Some(session_timeout) = consumer_config.session_timeout() { + config.set("session.timeout.ms", session_timeout.to_string()); + } + if let Some(queued_buffering_max_messages_kbytes) = + consumer_config.queued_buffering_max_messages_kbytes() + { + config.set( + "queued.max.messages.kbytes", + queued_buffering_max_messages_kbytes.to_string(), + ); + } + log::debug!("Consumer config: {:#?}", config); + // Set SSL if certificates are present + if let Ok(certificates) = &self.certificates() { + config + .set("security.protocol", "ssl") + .set("ssl.key.pem", certificates.private_key_pem()) + .set( + "ssl.certificate.pem", + certificates.dsh_kafka_certificate_pem(), + ) + .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); + } else { + config.set("security.protocol", "plaintext"); + } + config + } + + #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] + pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { + let producer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get(); + let mut config = rdkafka::config::ClientConfig::new(); + config + .set("bootstrap.servers", self.kafka_brokers()) + .set("client.id", self.client_id()); + if let Some(batch_num_messages) = producer_config.batch_num_messages() { + config.set("batch.num.messages", batch_num_messages.to_string()); + } + if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() { + config.set( + "queue.buffering.max.messages", + queue_buffering_max_messages.to_string(), + ); + } + if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() { + config.set( + "queue.buffering.max.kbytes", + queue_buffering_max_kbytes.to_string(), + ); + } + if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() { + config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string()); + } + log::debug!("Producer config: {:#?}", config); + + // Set SSL if certificates are present + if let Ok(certificates) = self.certificates() { + config + .set("security.protocol", "ssl") + .set("ssl.key.pem", certificates.private_key_pem()) + .set( + "ssl.certificate.pem", + certificates.dsh_kafka_certificate_pem(), + ) + .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); + } else { + config.set("security.protocol", "plaintext"); + } + config + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_KAFKA_CONSUMER_GROUP_TYPE}; + use serial_test::serial; + use std::io::Read; + + impl Default for Dsh { + fn default() -> Self { + let datastream = Datastream::load_local_datastreams().unwrap_or_default(); + Self { + task_id: "local_task_id".to_string(), + tenant_name: "local_tenant".to_string(), + 
config_host: "http://localhost/".to_string(), + datastream, + certificates: None, + } + } + } + + // maybe replace with local_datastreams.json? + fn datastreams_json() -> String { + std::fs::File::open("test_resources/valid_datastreams.json") + .map(|mut file| { + let mut contents = String::new(); + file.read_to_string(&mut contents).unwrap(); + contents + }) + .unwrap() + } + + // Define a reusable Dsh instance + fn datastream() -> Datastream { + serde_json::from_str(datastreams_json().as_str()).unwrap() + } + + #[test] + #[serial(env_dependency)] + fn test_get_or_init() { + let properties = Dsh::get(); + assert_eq!(properties.client_id(), "local_task_id"); + assert_eq!(properties.task_id, "local_task_id"); + assert_eq!(properties.tenant_name, "local_tenant"); + assert_eq!( + properties.config_host, + "https://pikachu.dsh.marathon.mesos:4443" + ); + assert!(properties.certificates.is_none()); + } + + #[test] + #[serial(env_dependency)] + fn test_reqwest_client_config() { + let properties = Dsh::default(); + let config = properties.reqwest_client_config(); + assert!(config.is_ok()); + } + + #[test] + #[serial(env_dependency)] + fn test_client_id() { + let properties = Dsh::default(); + assert_eq!(properties.client_id(), "local_task_id"); + } + + #[test] + #[serial(env_dependency)] + fn test_tenant_name() { + let properties = Dsh::default(); + assert_eq!(properties.tenant_name(), "local_tenant"); + } + + #[test] + #[serial(env_dependency)] + fn test_task_id() { + let properties = Dsh::default(); + assert_eq!(properties.task_id(), "local_task_id"); + } + + #[test] + #[serial(env_dependency)] + fn test_schema_registry_host() { + let properties = Dsh::default(); + assert_eq!( + properties.schema_registry_host(), + "http://localhost:8081/apis/ccompat/v7" + ); + } + + #[test] + #[serial(env_dependency)] + fn test_kafka_brokers() { + let properties = Dsh::default(); + assert_eq!( + properties.kafka_brokers(), + properties.datastream().get_brokers_string() + ); + env::set_var(VAR_KAFKA_BOOTSTRAP_SERVERS, "test:9092"); + let properties = Dsh::default(); + assert_eq!(properties.kafka_brokers(), "test:9092"); + env::remove_var(VAR_KAFKA_BOOTSTRAP_SERVERS); + } + + #[test] + #[serial(env_dependency)] + fn test_kafka_group_id() { + let properties = Dsh::default(); + assert_eq!( + properties.kafka_group_id(), + properties + .datastream() + .get_group_id(crate::datastream::GroupType::Shared(0)) + .unwrap() + ); + env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "private"); + assert_eq!( + properties.kafka_group_id(), + properties + .datastream() + .get_group_id(crate::datastream::GroupType::Private(0)) + .unwrap() + ); + env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "shared"); + assert_eq!( + properties.kafka_group_id(), + properties + .datastream() + .get_group_id(crate::datastream::GroupType::Shared(0)) + .unwrap() + ); + env::set_var(VAR_KAFKA_GROUP_ID, "test_group"); + assert_eq!( + properties.kafka_group_id(), + format!("{}_test_group", properties.tenant_name()) + ); + env::set_var( + VAR_KAFKA_GROUP_ID, + format!("{}_test_group", properties.tenant_name()), + ); + assert_eq!( + properties.kafka_group_id(), + format!("{}_test_group", properties.tenant_name()) + ); + env::remove_var(VAR_KAFKA_CONSUMER_GROUP_TYPE); + assert_eq!( + properties.kafka_group_id(), + format!("{}_test_group", properties.tenant_name()) + ); + env::remove_var(VAR_KAFKA_GROUP_ID); + } + + #[test] + #[serial(env_dependency)] + fn test_kafka_auto_commit() { + let properties = Dsh::default(); + assert!(!properties.kafka_auto_commit()); + 
} + + #[test] + #[serial(env_dependency)] + fn test_kafka_auto_offset_reset() { + let properties = Dsh::default(); + assert_eq!(properties.kafka_auto_offset_reset(), "earliest"); + } + + #[tokio::test] + async fn test_fetch_datastream() { + let mut server = mockito::Server::new_async().await; + let tenant = "test-tenant"; + let task_id = "test-task-id"; + let host = server.url(); + let prop = Dsh::new( + host, + task_id.to_string(), + tenant.to_string(), + Datastream::default(), + None, + ); + server + .mock("GET", "/kafka/config/test-tenant/test-task-id") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(datastreams_json()) + .create(); + let fetched_datastream = prop.fetch_datastream().await.unwrap(); + assert_eq!(fetched_datastream, datastream()); + } + + #[test] + fn test_fetch_blocking_datastream() { + let mut dsh = mockito::Server::new(); + let tenant = "test-tenant"; + let task_id = "test-task-id"; + let host = dsh.url(); + let prop = Dsh::new( + host, + task_id.to_string(), + tenant.to_string(), + Datastream::default(), + None, + ); + dsh.mock("GET", "/kafka/config/test-tenant/test-task-id") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(datastreams_json()) + .create(); + let fetched_datastream = prop.fetch_datastream_blocking().unwrap(); + assert_eq!(fetched_datastream, datastream()); + } + + #[test] + #[serial(env_dependency)] + fn test_consumer_rdkafka_config() { + let dsh = Dsh::default(); + let config = dsh.consumer_rdkafka_config(); + assert_eq!( + config.get("bootstrap.servers").unwrap(), + dsh.datastream().get_brokers_string() + ); + assert_eq!( + config.get("group.id").unwrap(), + dsh.datastream() + .get_group_id(crate::datastream::GroupType::from_env()) + .unwrap() + ); + assert_eq!(config.get("client.id").unwrap(), dsh.client_id()); + assert_eq!(config.get("enable.auto.commit").unwrap(), "false"); + assert_eq!(config.get("auto.offset.reset").unwrap(), "earliest"); + } + + #[test] + #[serial(env_dependency)] + fn test_producer_rdkafka_config() { + let dsh = Dsh::default(); + let config = dsh.producer_rdkafka_config(); + assert_eq!( + config.get("bootstrap.servers").unwrap(), + dsh.datastream().get_brokers_string() + ); + assert_eq!(config.get("client.id").unwrap(), dsh.client_id()); + } +} diff --git a/dsh_sdk/src/dsh/certificates.rs b/dsh_sdk/src/dsh_old/certificates.rs similarity index 73% rename from dsh_sdk/src/dsh/certificates.rs rename to dsh_sdk/src/dsh_old/certificates.rs index 0d6b506..04daaf8 100644 --- a/dsh_sdk/src/dsh/certificates.rs +++ b/dsh_sdk/src/dsh_old/certificates.rs @@ -11,33 +11,29 @@ //! To create the ca.crt, client.pem, and client.key files in a desired directory, use the //! `to_files` method. //! ```no_run -//! use dsh_sdk::Properties; +//! use dsh_sdk::certificates::Cert; //! use std::path::PathBuf; //! //! # fn main() -> Result<(), Box> { -//! let dsh_properties = Properties::get(); -//! let directory = PathBuf::from("dir"); -//! dsh_properties.certificates()?.to_files(&directory)?; +//! let certificates = Cert::from_env()?; +//! let directory = PathBuf::from("path/to/dir"); +//! certificates.to_files(&directory)?; //! # Ok(()) //! # } //! ``` //! //! ## Reqwest Client //! With this request client we can retrieve datastreams.json and connect to Schema Registry. 
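+//!
+//! A minimal sketch, assuming certificates are available in the environment (e.g.
+//! via `Cert::from_env` as in the example above):
+//! ```no_run
+//! use dsh_sdk::certificates::Cert;
+//!
+//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
+//! let certificates = Cert::from_env()?;
+//! // Blocking reqwest client with the DSH certificates included
+//! let client = certificates.reqwest_blocking_client_config()?.build()?;
+//! # Ok(())
+//! # }
+//! ```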
+use std::path::PathBuf; use std::sync::Arc; use log::info; +use rcgen::KeyPair; use reqwest::blocking::{Client, ClientBuilder}; use reqwest::Identity; -use std::path::PathBuf; - -use super::bootstrap::{Dn, DshBootstapCall, DshConfig}; use crate::error::DshError; -use pem; -use rcgen::{CertificateParams, CertificateSigningRequest, DnType, KeyPair}; - /// Hold all relevant certificates and keys to connect to DSH Kafka Cluster and Schema Store. #[derive(Debug, Clone)] pub struct Cert { @@ -47,40 +43,6 @@ pub struct Cert { } impl Cert { - /// Create new `Cert` struct - pub(crate) fn new( - dsh_ca_certificate_pem: String, - dsh_client_certificate_pem: String, - key_pair: KeyPair, - ) -> Cert { - Self { - dsh_ca_certificate_pem, - dsh_client_certificate_pem, - key_pair: Arc::new(key_pair), - } - } - /// Generate private key and call for a signed certificate to DSH. - pub(crate) fn get_signed_client_cert( - dn: Dn, - dsh_config: &DshConfig, - client: &Client, - ) -> Result { - let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P384_SHA384)?; - let csr = Self::generate_csr(&key_pair, dn)?; - let client_certificate = DshBootstapCall::CertificateSignRequest { - config: dsh_config, - csr: &csr.pem()?, - } - .perform_call(client)?; - let ca_cert = pem::parse_many(dsh_config.dsh_ca_certificate())?; - let client_cert = pem::parse_many(client_certificate)?; - Ok(Self::new( - pem::encode_many(&ca_cert), - pem::encode_many(&client_cert), - key_pair, - )) - } - /// Build an async reqwest client with the DSH Kafka certificate included. /// With this client we can retrieve datastreams.json and conenct to Schema Registry. pub fn reqwest_client_config(&self) -> Result { @@ -164,19 +126,6 @@ impl Cert { Ok(()) } - /// Generate the certificate signing request. - fn generate_csr(key_pair: &KeyPair, dn: Dn) -> Result { - let mut params = CertificateParams::default(); - params.distinguished_name.push(DnType::CommonName, dn.cn()); - params - .distinguished_name - .push(DnType::OrganizationalUnitName, dn.ou()); - params - .distinguished_name - .push(DnType::OrganizationName, dn.o()); - Ok(params.serialize_request(key_pair)?) 
- } - fn create_file>(path: PathBuf, contents: C) -> Result<(), DshError> { std::fs::write(&path, contents)?; info!("File created ({})", path.display()); @@ -202,6 +151,15 @@ impl Cert { } } +/// Helper function to ensure that the host starts with `https://` (or `http://`) +fn ensure_https_prefix(host: String) -> String { + if host.starts_with("https://") || host.starts_with("http://") { + host + } else { + format!("https://{}", host) + } +} + #[cfg(test)] mod tests { use super::*; @@ -209,7 +167,6 @@ mod tests { use std::sync::OnceLock; use openssl::pkey::PKey; - use openssl::x509::X509Req; static TEST_CERTIFICATES: OnceLock = OnceLock::new(); @@ -217,12 +174,11 @@ mod tests { let subject_alt_names = vec!["hello.world.example".to_string(), "localhost".to_string()]; let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names).unwrap(); - Cert::new(cert.pem(), cert.pem(), key_pair) - //Cert { - // dsh_ca_certificate_pem: CA_CERT.to_string(), - // dsh_client_certificate_pem: KAFKA_CERT.to_string(), - // key_pair: Arc::new(KeyPair::generate().unwrap()), - //} + Cert { + dsh_ca_certificate_pem: cert.pem(), + dsh_client_certificate_pem: cert.pem(), + key_pair: Arc::new(key_pair), + } } #[test] @@ -296,37 +252,6 @@ mod tests { assert!(std::path::Path::new(&format!("{}/client.key", dir)).exists()); } - #[test] - fn test_dsh_certificate_sign_request() { - let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); - let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); - let csr = Cert::generate_csr(&cert.key_pair, dn).unwrap(); - let req = csr.pem().unwrap(); - assert!(req.starts_with("-----BEGIN CERTIFICATE REQUEST-----")); - assert!(req.trim().ends_with("-----END CERTIFICATE REQUEST-----")); - } - - #[test] - fn test_verify_csr() { - let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); - let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); - let csr = Cert::generate_csr(&cert.key_pair, dn).unwrap(); - let csr_pem = csr.pem().unwrap(); - let key = cert.private_key_pkcs8(); - let pkey = PKey::private_key_from_der(&key).unwrap(); - - let req = X509Req::from_pem(csr_pem.as_bytes()).unwrap(); - req.verify(&pkey).unwrap(); - let subject = req - .subject_name() - .entries() - .into_iter() - .map(|e| e.data().as_utf8().unwrap().to_string()) - .collect::>() - .join(","); - assert_eq!(subject, "Test CN,Test OU,Test Org"); - } - #[test] fn test_create_identity() { let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); @@ -361,4 +286,19 @@ mod tests { let client = cert.reqwest_blocking_client_config(); assert!(client.is_ok()); } + + #[test] + fn test_ensure_https_prefix() { + let host = "http://example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "http://example.com"); + + let host = "https://example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "https://example.com"); + + let host = "example.com".to_string(); + let result = ensure_https_prefix(host); + assert_eq!(result, "https://example.com"); + } } diff --git a/dsh_sdk/src/dsh_old/datastream.rs b/dsh_sdk/src/dsh_old/datastream.rs new file mode 100644 index 0000000..4d86a67 --- /dev/null +++ b/dsh_sdk/src/dsh_old/datastream.rs @@ -0,0 +1,649 @@ +//! Module to handle the datastreams.json file. +//! +//! The datastreams.json can be parsed into a Datastream struct using serde_json. +//! This struct contains all the information from the datastreams.json file. +//! +//! 
You can get the Datastream struct via the `Properties` struct.
+//!
+//! # Example
+//! ```
+//! use dsh_sdk::Properties;
+//!
+//! let properties = Properties::get();
+//! let datastream = properties.datastream();
+//!
+//! let brokers = datastream.get_brokers();
+//! let schema_store = datastream.schema_store();
+//! ```
+use std::collections::HashMap;
+
+use log::info;
+use serde::{Deserialize, Serialize};
+
+use crate::error::DshError;
+use crate::{utils, VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_SCHEMA_REGISTRY_HOST};
+
+/// This struct is equivalent to the datastreams.json
+///
+/// # Example
+/// ```
+/// use dsh_sdk::Properties;
+///
+/// let properties = Properties::get();
+/// let datastream = properties.datastream();
+///
+/// let brokers = datastream.get_brokers();
+/// let streams = datastream.streams();
+/// let schema_store = datastream.schema_store();
+/// ```
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+pub struct Datastream {
+    brokers: Vec<String>,
+    streams: HashMap<String, Stream>,
+    private_consumer_groups: Vec<String>,
+    shared_consumer_groups: Vec<String>,
+    non_enveloped_streams: Vec<String>,
+    schema_store: String,
+}
+
+impl Datastream {
+    /// Get the kafka brokers from the datastreams as a vector of strings
+    pub fn get_brokers(&self) -> Vec<&str> {
+        self.brokers.iter().map(|s| s.as_str()).collect()
+    }
+
+    /// Get the kafka brokers as a comma-separated string from the datastreams
+    pub fn get_brokers_string(&self) -> String {
+        self.brokers.join(", ")
+    }
+
+    /// Get the group id from the datastreams based on GroupType
+    ///
+    /// # Error
+    /// If the index is greater than the number of groups in the datastreams
+    /// (index out of bounds)
+    pub fn get_group_id(&self, group_type: GroupType) -> Result<&str, DshError> {
+        let group_id = match group_type {
+            GroupType::Private(i) => self.private_consumer_groups.get(i),
+            GroupType::Shared(i) => self.shared_consumer_groups.get(i),
+        };
+        match group_id {
+            Some(id) => Ok(id),
+            None => Err(DshError::IndexGroupIdError(group_type)),
+        }
+    }
+
+    /// Get all available datastreams (scratch topics, internal topics and stream topics)
+    pub fn streams(&self) -> &HashMap<String, Stream> {
+        &self.streams
+    }
+
+    /// Get a specific datastream based on the topic name.
+    /// If the topic is not found, it will return None.
+    pub fn get_stream(&self, topic: &str) -> Option<&Stream> {
+        // if the topic name contains 2 dots, take the first 2 parts of the topic name,
+        // as the topic name in datastreams.json is only the first 2 parts
+        let topic_name = topic.split('.').take(2).collect::<Vec<&str>>().join(".");
+        self.streams().get(&topic_name)
+    }
+
+    /// Check if a list of topics is present in the read topics of datastreams
+    pub fn verify_list_of_topics<T: std::fmt::Display>(
+        &self,
+        topics: &Vec<T>,
+        access: crate::datastream::ReadWriteAccess,
+    ) -> Result<(), DshError> {
+        let read_topics = self
+            .streams()
+            .values()
+            .map(|datastream| match access {
+                crate::datastream::ReadWriteAccess::Read => datastream
+                    .read
+                    .split('.')
+                    .take(2)
+                    .collect::<Vec<&str>>()
+                    .join(".")
+                    .replace('\\', ""),
+                crate::datastream::ReadWriteAccess::Write => datastream
+                    .write
+                    .split('.')
+                    .take(2)
+                    .collect::<Vec<&str>>()
+                    .join(".")
+                    .replace('\\', ""),
+            })
+            .collect::<Vec<String>>();
+        for topic in topics {
+            let topic_name = topic
+                .to_string()
+                .split('.')
+                .take(2)
+                .collect::<Vec<&str>>()
+                .join(".");
+            if !read_topics.contains(&topic_name) {
+                return Err(DshError::NotFoundTopicError(topic.to_string()));
+            }
+        }
+        Ok(())
+    }
+
+    /// Get the schema store url from the datastreams.
+    ///
+    /// ## How to connect to schema registry
+    /// Use the Reqwest client from `Cert` to connect to the schema registry,
+    /// as this client is already configured with the correct certificates.
+    ///
+    /// You can use [schema_registry_converter](https://crates.io/crates/schema_registry_converter)
+    /// to fetch the schema and decode your payload.
+    pub fn schema_store(&self) -> &str {
+        &self.schema_store
+    }
+
+    /// Write datastreams.json to a directory
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::datastream::Datastream;
+    /// # let datastream = Datastream::default();
+    /// let path = std::path::PathBuf::from("/path/to/directory");
+    /// datastream.to_file(&path).unwrap();
+    /// ```
+    pub fn to_file(&self, path: &std::path::Path) -> Result<(), DshError> {
+        let json_string = serde_json::to_string_pretty(self)?;
+        std::fs::write(path.join("datastreams.json"), json_string)?;
+        info!("File created ({})", path.display());
+        Ok(())
+    }
+
+    /// Fetch datastreams from the DSH server (async)
+    ///
+    /// Make sure you use a Reqwest client from `Cert` to connect to the DSH server,
+    /// as this client is already configured with the correct certificates.
+    pub async fn fetch(
+        client: &reqwest::Client,
+        host: &str,
+        tenant: &str,
+        task_id: &str,
+    ) -> Result<Self, DshError> {
+        let url = Self::datastreams_endpoint(host, tenant, task_id);
+        let response = client.get(&url).send().await?;
+        if !response.status().is_success() {
+            return Err(DshError::DshCallError {
+                url,
+                status_code: response.status(),
+                error_body: response.text().await.unwrap_or_default(),
+            });
+        }
+        Ok(response.json().await?)
+    }
+
+    /// Fetch datastreams from the DSH server (blocking)
+    ///
+    /// Make sure you use a Reqwest client from `Cert` to connect to the DSH server,
+    /// as this client is already configured with the correct certificates.
+    pub fn fetch_blocking(
+        client: &reqwest::blocking::Client,
+        host: &str,
+        tenant: &str,
+        task_id: &str,
+    ) -> Result<Self, DshError> {
+        let url = Self::datastreams_endpoint(host, tenant, task_id);
+        let response = client.get(&url).send()?;
+        if !response.status().is_success() {
+            return Err(DshError::DshCallError {
+                url,
+                status_code: response.status(),
+                error_body: response.text().unwrap_or_default(),
+            });
+        }
+        Ok(response.json()?)
+ } + + pub(crate) fn datastreams_endpoint(host: &str, tenant: &str, task_id: &str) -> String { + format!("{}/kafka/config/{}/{}", host, tenant, task_id) + } +} + +impl Default for Datastream { + fn default() -> Self { + let group_id = format!( + "{}_default_group", + utils::tenant_name().unwrap_or("local".to_string()) + ); + let brokers = if let Ok(brokers) = utils::get_env_var(VAR_KAFKA_BOOTSTRAP_SERVERS) { + brokers.split(',').map(|s| s.to_string()).collect() + } else { + vec!["localhost:9092".to_string()] + }; + let schema_store = utils::get_env_var(VAR_SCHEMA_REGISTRY_HOST) + .unwrap_or("http://localhost:8081/apis/ccompat/v7".to_string()); + Datastream { + brokers, + streams: HashMap::new(), + private_consumer_groups: vec![group_id.clone()], + shared_consumer_groups: vec![group_id], + non_enveloped_streams: Vec::new(), + schema_store, + } + } +} + +/// Struct containing all topic information which also is provided in datastreams.json +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct Stream { + name: String, + cluster: String, + read: String, + write: String, + partitions: i32, + replication: i32, + partitioner: String, + partitioning_depth: i32, + can_retain: bool, +} + +impl Stream { + /// Get the Stream's name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the Stream's cluster + pub fn cluster(&self) -> &str { + &self.cluster + } + + /// Get the read pattern as stated in datastreams. + /// + /// Use `read_pattern` method to validate if read access is allowed. + pub fn read(&self) -> &str { + &self.read + } + + /// Get the write pattern + /// + /// Use `write_pattern` method to validate if write access is allowed. + pub fn write(&self) -> &str { + &self.write + } + + /// Get the Stream's number of partitions + pub fn partitions(&self) -> i32 { + self.partitions + } + + /// Get the Stream's replication factor + pub fn replication(&self) -> i32 { + self.replication + } + + /// Get the Stream's partitioner + pub fn partitioner(&self) -> &str { + &self.partitioner + } + + /// Get the Stream's partitioning depth + pub fn partitioning_depth(&self) -> i32 { + self.partitioning_depth + } + + /// Get the Stream's can retain value + pub fn can_retain(&self) -> bool { + self.can_retain + } + + /// Check read access on topic based on datastream + pub fn read_access(&self) -> bool { + !self.read.is_empty() + } + + /// Check write access on topic based on datastream + pub fn write_access(&self) -> bool { + !self.write.is_empty() + } + + /// Get the Stream's Read whitelist pattern + /// + /// ## Error + /// If the topic does not have read access it returns a `TopicPermissionsError` + pub fn read_pattern(&self) -> Result<&str, DshError> { + if self.read_access() { + Ok(&self.read) + } else { + Err(DshError::TopicPermissionsError( + self.name.clone(), + crate::datastream::ReadWriteAccess::Read, + )) + } + } + + /// Get the Stream's Write pattern + /// + /// ## Error + /// If the topic does not have write access it returns a `TopicPermissionsError` + pub fn write_pattern(&self) -> Result<&str, DshError> { + if self.write_access() { + Ok(&self.write) + } else { + Err(DshError::TopicPermissionsError( + self.name.clone(), + crate::datastream::ReadWriteAccess::Write, + )) + } + } +} + +pub use crate::datastream::GroupType; +pub use crate::datastream::ReadWriteAccess; + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::io::Read; + + use crate::VAR_KAFKA_CONSUMER_GROUP_TYPE; + + // Define a reusable 
Properties instance + fn datastream() -> Datastream { + serde_json::from_str(datastreams_json().as_str()).unwrap() + } + + // maybe replace with local_datastreams.json? + fn datastreams_json() -> String { + std::fs::File::open("test_resources/valid_datastreams.json") + .map(|mut file| { + let mut contents = String::new(); + file.read_to_string(&mut contents).unwrap(); + contents + }) + .unwrap() + } + + #[test] + fn test_name() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.name(), "scratch.test"); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.name(), "stream.test"); + } + + #[test] + fn test_read() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.read(), "scratch.test.test-tenant"); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.read(), "stream\\.test\\.[^.]*"); + } + + #[test] + fn test_write() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.write(), "scratch.test.test-tenant"); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.write(), ""); + } + + #[test] + fn test_cluster() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.cluster(), "/tt"); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.cluster(), "/tt"); + } + + #[test] + fn test_partitions() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.partitions(), 3); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.partitions(), 1); + } + + #[test] + fn test_replication() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.replication(), 1); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.replication(), 1); + } + + #[test] + fn test_partitioner() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.partitioner(), "default-partitioner"); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.partitioner(), "default-partitioner"); + } + + #[test] + fn test_partitioning_depth() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.partitioning_depth(), 0); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.partitioning_depth(), 0); + } + + #[test] + fn test_can_retain() { + let datastream = datastream(); + let stream = datastream.streams().get("scratch.test").unwrap(); + assert_eq!(stream.can_retain(), false); + let stream = datastream.streams().get("stream.test").unwrap(); + assert_eq!(stream.can_retain(), true); + } + + #[test] + fn test_datastream_get_brokers() { + assert_eq!( + datastream().get_brokers(), + vec![ + "broker-0.tt.kafka.mesos:9091", + "broker-1.tt.kafka.mesos:9091", + "broker-2.tt.kafka.mesos:9091" + ] + ); + } + + #[test] + fn test_datastream_get_brokers_string() { + assert_eq!( + datastream().get_brokers_string(), + "broker-0.tt.kafka.mesos:9091, broker-1.tt.kafka.mesos:9091, broker-2.tt.kafka.mesos:9091" + ); + } + + #[test] + fn test_datastream_verify_list_of_topics() { + let topics = 
vec![ + "scratch.test.test-tenant".to_string(), + "stream.test.test-tenant".to_string(), + ]; + datastream() + .verify_list_of_topics(&topics, ReadWriteAccess::Read) + .unwrap() + } + + #[test] + fn test_datastream_get_schema_store() { + assert_eq!( + datastream().schema_store(), + "http://schema-registry.tt.kafka.mesos:8081" + ); + } + + #[test] + #[serial(env_dependency)] + fn test_datastream_get_group_type_from_env() { + // Set the KAFKA_CONSUMER_GROUP_TYPE environment variable to "private" + std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "private"); + assert_eq!(GroupType::from_env(), GroupType::Private(0),); + std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "shared"); + assert_eq!(GroupType::from_env(), GroupType::Shared(0),); + std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "invalid-type"); + assert_eq!(GroupType::from_env(), GroupType::Shared(0),); + std::env::remove_var(VAR_KAFKA_CONSUMER_GROUP_TYPE); + assert_eq!(GroupType::from_env(), GroupType::Shared(0),); + } + + #[test] + fn test_datastream_get_group_id() { + assert_eq!( + datastream().get_group_id(GroupType::Private(0)).unwrap(), + "test-app.7e93a513-6556-11eb-841e-f6ab8576620c_1", + "KAFKA_CONSUMER_GROUP_TYPE is set to private, but did not return test-app.7e93a513-6556-11eb-841e-f6ab8576620c_1" + ); + assert_eq!( + datastream().get_group_id(GroupType::Shared(0)).unwrap(), + "test-app_1", + "KAFKA_CONSUMER_GROUP_TYPE is set to shared, but did not return test-app_1" + ); + assert_eq!( + datastream().get_group_id(GroupType::Shared(3)).unwrap(), + "test-app_4", + "KAFKA_CONSUMER_GROUP_TYPE is set to shared, but did not return test-app_1" + ); + assert!(datastream().get_group_id(GroupType::Private(1000)).is_err(),); + } + + #[test] + fn test_datastream_check_access_read_topic() { + assert_eq!( + datastream() + .get_stream("scratch.test.test-tenant") + .unwrap() + .read_access(), + true + ); + assert_eq!( + datastream() + .get_stream("stream.test.test-tenant") + .unwrap() + .read_access(), + true + ); + } + + #[test] + fn test_datastream_check_access_write_topic() { + assert_eq!( + datastream() + .get_stream("scratch.test.test-tenant") + .unwrap() + .write_access(), + true + ); + assert_eq!( + datastream() + .get_stream("stream.test.test-tenant") + .unwrap() + .write_access(), + false + ); + } + + #[test] + fn test_datastream_check_read_topic() { + assert_eq!( + datastream() + .get_stream("scratch.test.test-tenant") + .unwrap() + .read_pattern() + .unwrap(), + "scratch.test.test-tenant" + ); + assert_eq!( + datastream() + .get_stream("stream.test.test-tenant") + .unwrap() + .read_pattern() + .unwrap(), + "stream\\.test\\.[^.]*" + ); + } + + #[test] + fn test_datastream_check_write_topic() { + assert_eq!( + datastream() + .get_stream("scratch.test.test-tenant") + .unwrap() + .write_pattern() + .unwrap(), + "scratch.test.test-tenant" + ); + let e = datastream() + .get_stream("stream.test.test-tenant") + .unwrap() + .write_pattern() + .unwrap_err(); + + assert!(matches!( + e, + DshError::TopicPermissionsError(_, ReadWriteAccess::Write) + )); + } + + #[test] + fn test_to_file() { + let test_path = std::path::PathBuf::from("test_files"); + let result = datastream().to_file(&test_path); + assert!(result.is_ok()) + } + + #[test] + fn test_datastream_endpoint() { + let host = "http://localhost:8080"; + let tenant = "test-tenant"; + let task_id = "test-task-id"; + let endpoint = Datastream::datastreams_endpoint(host, tenant, task_id); + assert_eq!( + endpoint, + "http://localhost:8080/kafka/config/test-tenant/test-task-id" + ); + } + 
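+    // Illustrative sketch: `get_stream` truncates a fully qualified topic name to
+    // its first two parts (see the comment inside `get_stream`), so the stream is
+    // found under the key "scratch.test".
+    #[test]
+    fn test_datastream_get_stream_truncates_topic_name() {
+        let datastream = datastream();
+        let stream = datastream.get_stream("scratch.test.test-tenant").unwrap();
+        assert_eq!(stream.name(), "scratch.test");
+    }
+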
+    #[tokio::test]
+    async fn test_fetch() {
+        let mut dsh = mockito::Server::new_async().await;
+        let tenant = "test-tenant";
+        let task_id = "test-task-id";
+        let host = dsh.url();
+        dsh.mock("GET", "/kafka/config/test-tenant/test-task-id")
+            .with_status(200)
+            .with_header("content-type", "application/json")
+            .with_body(datastreams_json())
+            .create();
+        let client = reqwest::Client::new();
+        let fetched_datastream = Datastream::fetch(&client, &host, tenant, task_id)
+            .await
+            .unwrap();
+        assert_eq!(fetched_datastream, datastream());
+    }
+
+    #[test]
+    fn test_fetch_blocking() {
+        let mut dsh = mockito::Server::new();
+        let tenant = "test-tenant";
+        let task_id = "test-task-id";
+        let host = dsh.url();
+        dsh.mock("GET", "/kafka/config/test-tenant/test-task-id")
+            .with_status(200)
+            .with_header("content-type", "application/json")
+            .with_body(datastreams_json())
+            .create();
+        let client = reqwest::blocking::Client::new();
+        let fetched_datastream =
+            Datastream::fetch_blocking(&client, &host, tenant, task_id).unwrap();
+        assert_eq!(fetched_datastream, datastream());
+    }
+}
diff --git a/dsh_sdk/src/dsh/mod.rs b/dsh_sdk/src/dsh_old/mod.rs
similarity index 56%
rename from dsh_sdk/src/dsh/mod.rs
rename to dsh_sdk/src/dsh_old/mod.rs
index 96d6019..c0024aa 100644
--- a/dsh_sdk/src/dsh/mod.rs
+++ b/dsh_sdk/src/dsh_old/mod.rs
@@ -21,13 +21,36 @@
 //! # Ok(())
 //! # }
 //! ```
-mod bootstrap;
+
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::dsh::certificates` is moved to `dsh_sdk::certificates`"
+)]
 pub mod certificates;
-mod config;
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::dsh::datastream` is moved to `dsh_sdk::datastream`"
+)]
 pub mod datastream;
-mod pki_config_dir;
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::dsh::properties` is moved to `dsh_sdk::dsh`"
+)]
 pub mod properties;
 
 // Re-export the properties struct to avoid braking changes
-pub use super::utils::get_configured_topics;
+
+#[deprecated(
+    since = "0.5.0",
+    note = "get_configured_topics is moved to `dsh_sdk::utils::get_configured_topics`"
+)]
+pub fn get_configured_topics() -> Result<Vec<String>, crate::error::DshError> {
+    let kafka_topic_string = crate::utils::get_env_var("TOPICS")?;
+    Ok(kafka_topic_string
+        .split(',')
+        .map(str::trim)
+        .map(String::from)
+        .collect())
+}
+
 pub use properties::Properties;
diff --git a/dsh_sdk/src/dsh/properties.rs b/dsh_sdk/src/dsh_old/properties.rs
similarity index 96%
rename from dsh_sdk/src/dsh/properties.rs
rename to dsh_sdk/src/dsh_old/properties.rs
index ce9427a..32e347d 100644
--- a/dsh_sdk/src/dsh/properties.rs
+++ b/dsh_sdk/src/dsh_old/properties.rs
@@ -25,18 +25,16 @@
 //! # Ok(())
 //! # }
 //! ```
-use log::{debug, error, warn};
+use log::{error, warn};
 use std::env;
 use std::sync::OnceLock;
 
-use super::bootstrap::bootstrap;
-use super::{certificates, config, datastream, pki_config_dir};
+use crate::certificates::Cert;
+use crate::datastream;
 use crate::error::DshError;
+use crate::protocol_adapters::kafka_protocol::config;
 use crate::utils;
 use crate::*;
 
-static PROPERTIES: OnceLock<Properties> = OnceLock::new();
-static CONSUMER_CONFIG: OnceLock<config::ConsumerConfig> = OnceLock::new();
-static PRODUCER_CONFIG: OnceLock<config::ProducerConfig> = OnceLock::new();
 
 /// DSH properties struct. 
Create new to initialize all related components to connect to the DSH kafka clusters /// - Contains info from datastreams.json @@ -63,6 +61,7 @@ static PRODUCER_CONFIG: OnceLock = OnceLock::new(); /// } /// ``` +#[deprecated(since = "0.5.0", note = "`Properties` is renamed to `dsh_sdk::Dsh`")] #[derive(Debug, Clone)] pub struct Properties { config_host: String, @@ -116,6 +115,7 @@ impl Properties { /// # } /// ``` pub fn get() -> &'static Self { + static PROPERTIES: OnceLock = OnceLock::new(); PROPERTIES.get_or_init(|| tokio::task::block_in_place(Self::init)) } @@ -138,10 +138,10 @@ impl Properties { ); DEFAULT_CONFIG_HOST.to_string() }); - let certificates = if let Ok(cert) = pki_config_dir::get_pki_cert() { + let certificates = if let Ok(cert) = Cert::from_pki_config_dir::(None) { Some(cert) } else { - bootstrap(&config_host, &tenant_name, &task_id) + Cert::from_bootstrap(&config_host, &tenant_name, &task_id) .inspect_err(|e| { warn!("Could not bootstrap to DSH, due to: {}", e); }) @@ -170,172 +170,6 @@ impl Properties { Self::new(config_host, task_id, tenant_name, datastream, certificates) } - /// Get default RDKafka Consumer config to connect to Kafka on DSH. - /// - /// Note: This config is set to auto commit to false. You need to manually commit offsets. - /// You can overwrite this config by setting the enable.auto.commit and enable.auto.offset.store property to `true`. - /// - /// # Group ID - /// There are 2 types of group id's in DSH: private and shared. Private will have a unique group id per running instance. - /// Shared will have the same group id for all running instances. With this you can horizontally scale your service. - /// The group type can be manipulated by environment variable KAFKA_CONSUMER_GROUP_TYPE. - /// If not set, it will default to shared. - /// - /// # Example - /// ``` - /// use dsh_sdk::Properties; - /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel; - /// use dsh_sdk::rdkafka::consumer::stream_consumer::StreamConsumer; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box> { - /// let dsh_properties = Properties::get(); - /// let mut consumer_config = dsh_properties.consumer_rdkafka_config(); - /// let consumer: StreamConsumer = consumer_config.create()?; - /// Ok(()) - /// } - /// ``` - /// - /// # Default configs - /// See full list of configs properties in case you want to add/overwrite the config: - /// - /// - /// Some configurations are overwitable by environment variables. - /// - /// | **config** | **Default value** | **Remark** | - /// |---------------------------|----------------------------------|------------------------------------------------------------------------| - /// | `bootstrap.servers` | Brokers based on datastreams | Overwritable by env variable KAFKA_BOOTSTRAP_SERVERS` | - /// | `group.id` | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`| - /// | `client.id` | Task_id of service | | - /// | `enable.auto.commit` | `false` | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT` | - /// | `auto.offset.reset` | `earliest` | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET` | - /// | `security.protocol` | ssl (DSH) / plaintext (local) | Security protocol | - /// | `ssl.key.pem` | private key | Generated when bootstrap is initiated | - /// | `ssl.certificate.pem` | dsh kafka certificate | Signed certificate to connect to kafka cluster | - /// | `ssl.ca.pem` | CA certifacte | CA certificate, provided by DSH. 
| - /// - /// ## Environment variables - /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information - /// configuring the consmer via environment variables. - #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] - pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { - let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new); - let mut config = rdkafka::config::ClientConfig::new(); - config - .set("bootstrap.servers", self.kafka_brokers()) - .set("group.id", self.kafka_group_id()) - .set("client.id", self.client_id()) - .set("enable.auto.commit", self.kafka_auto_commit().to_string()) - .set("auto.offset.reset", self.kafka_auto_offset_reset()); - if let Some(session_timeout) = consumer_config.session_timeout() { - config.set("session.timeout.ms", session_timeout.to_string()); - } - if let Some(queued_buffering_max_messages_kbytes) = - consumer_config.queued_buffering_max_messages_kbytes() - { - config.set( - "queued.max.messages.kbytes", - queued_buffering_max_messages_kbytes.to_string(), - ); - } - debug!("Consumer config: {:#?}", config); - // Set SSL if certificates are present - if let Ok(certificates) = &self.certificates() { - config - .set("security.protocol", "ssl") - .set("ssl.key.pem", certificates.private_key_pem()) - .set( - "ssl.certificate.pem", - certificates.dsh_kafka_certificate_pem(), - ) - .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); - } else { - config.set("security.protocol", "plaintext"); - } - config - } - - /// Get default RDKafka Producer config to connect to Kafka on DSH. - /// If certificates are present, it will use SSL to connect to Kafka. - /// If not, it will use plaintext so it can connect to local as well. - /// - /// Note: The default config is set to auto commit to false. You need to manually commit offsets. - /// - /// # Example - /// ``` - /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel; - /// use dsh_sdk::rdkafka::producer::FutureProducer; - /// use dsh_sdk::Properties; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box>{ - /// let dsh_properties = Properties::get(); - /// let mut producer_config = dsh_properties.producer_rdkafka_config(); - /// let producer: FutureProducer = producer_config.create().expect("Producer creation failed"); - /// Ok(()) - /// } - /// ``` - /// - /// # Default configs - /// See full list of configs properties in case you want to manually add/overwrite the config: - /// - /// - /// | **config** | **Default value** | **Remark** | - /// |---------------------|--------------------------------|-----------------------------------------------------------------------------------------| - /// | bootstrap.servers | Brokers based on datastreams | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS` | - /// | client.id | task_id of service | Based on task_id of running service | - /// | security.protocol | ssl (DSH)) / plaintext (local) | Security protocol | - /// | ssl.key.pem | private key | Generated when bootstrap is initiated | - /// | ssl.certificate.pem | dsh kafka certificate | Signed certificate to connect to kafka cluster
(signed when bootstrap is initiated) | - /// | ssl.ca.pem | CA certifacte | CA certificate, provided by DSH. | - /// | log_level | Info | Log level of rdkafka | - /// - /// ## Environment variables - /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information - /// configuring the producer via environment variables. - #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] - pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { - let producer_config = PRODUCER_CONFIG.get_or_init(config::ProducerConfig::new); - let mut config = rdkafka::config::ClientConfig::new(); - config - .set("bootstrap.servers", self.kafka_brokers()) - .set("client.id", self.client_id()); - if let Some(batch_num_messages) = producer_config.batch_num_messages() { - config.set("batch.num.messages", batch_num_messages.to_string()); - } - if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() { - config.set( - "queue.buffering.max.messages", - queue_buffering_max_messages.to_string(), - ); - } - if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() { - config.set( - "queue.buffering.max.kbytes", - queue_buffering_max_kbytes.to_string(), - ); - } - if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() { - config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string()); - } - debug!("Producer config: {:#?}", config); - - // Set SSL if certificates are present - if let Ok(certificates) = self.certificates() { - config - .set("security.protocol", "ssl") - .set("ssl.key.pem", certificates.private_key_pem()) - .set( - "ssl.certificate.pem", - certificates.dsh_kafka_certificate_pem(), - ) - .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); - } else { - config.set("security.protocol", "plaintext"); - } - config - } - /// Get reqwest async client config to connect to DSH Schema Registry. /// If certificates are present, it will use SSL to connect to Schema Registry. /// @@ -535,8 +369,7 @@ impl Properties { /// - Required: `false` /// - Options: `true`, `false` pub fn kafka_auto_commit(&self) -> bool { - let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new); - consumer_config.enable_auto_commit() + config::KafkaConfig::get().enable_auto_commit() } /// Get the kafka auto offset reset settings. @@ -550,8 +383,173 @@ impl Properties { /// - Required: `false` /// - Options: smallest, earliest, beginning, largest, latest, end pub fn kafka_auto_offset_reset(&self) -> String { - let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new); - consumer_config.auto_offset_reset() + config::KafkaConfig::get().auto_offset_reset() + } + + /// Get default RDKafka Consumer config to connect to Kafka on DSH. + /// + /// Note: This config is set to auto commit to false. You need to manually commit offsets. + /// You can overwrite this config by setting the enable.auto.commit and enable.auto.offset.store property to `true`. + /// + /// # Group ID + /// There are 2 types of group id's in DSH: private and shared. Private will have a unique group id per running instance. + /// Shared will have the same group id for all running instances. With this you can horizontally scale your service. + /// The group type can be manipulated by environment variable KAFKA_CONSUMER_GROUP_TYPE. + /// If not set, it will default to shared. 
+    ///
+    /// # Example
+    /// ```
+    /// use dsh_sdk::Properties;
+    /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel;
+    /// use dsh_sdk::rdkafka::consumer::stream_consumer::StreamConsumer;
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let dsh_properties = Properties::get();
+    ///     let mut consumer_config = dsh_properties.consumer_rdkafka_config();
+    ///     let consumer: StreamConsumer = consumer_config.create()?;
+    ///     Ok(())
+    /// }
+    /// ```
+    ///
+    /// # Default configs
+    /// See the full list of config properties in case you want to add/overwrite the config:
+    /// <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
+    ///
+    /// Some configurations are overwritable by environment variables.
+    ///
+    /// | **config**                | **Default value**                | **Remark**                                                               |
+    /// |---------------------------|----------------------------------|--------------------------------------------------------------------------|
+    /// | `bootstrap.servers`       | Brokers based on datastreams     | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS`                   |
+    /// | `group.id`                | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`  |
+    /// | `client.id`               | Task_id of service               |                                                                          |
+    /// | `enable.auto.commit`      | `false`                          | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT`                       |
+    /// | `auto.offset.reset`       | `earliest`                       | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET`                        |
+    /// | `security.protocol`       | ssl (DSH) / plaintext (local)    | Security protocol                                                        |
+    /// | `ssl.key.pem`             | private key                      | Generated when bootstrap is initiated                                    |
+    /// | `ssl.certificate.pem`     | dsh kafka certificate            | Signed certificate to connect to kafka cluster                           |
+    /// | `ssl.ca.pem`              | CA certificate                   | CA certificate, provided by DSH.                                         |
+    ///
+    /// ## Environment variables
+    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
+    /// on configuring the consumer via environment variables.
+    #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))]
+    pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
+        let consumer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get();
+        let mut config = rdkafka::config::ClientConfig::new();
+        config
+            .set("bootstrap.servers", self.kafka_brokers())
+            .set("group.id", self.kafka_group_id())
+            .set("client.id", self.client_id())
+            .set("enable.auto.commit", self.kafka_auto_commit().to_string())
+            .set("auto.offset.reset", self.kafka_auto_offset_reset());
+        if let Some(session_timeout) = consumer_config.session_timeout() {
+            config.set("session.timeout.ms", session_timeout.to_string());
+        }
+        if let Some(queued_buffering_max_messages_kbytes) =
+            consumer_config.queued_buffering_max_messages_kbytes()
+        {
+            config.set(
+                "queued.max.messages.kbytes",
+                queued_buffering_max_messages_kbytes.to_string(),
+            );
+        }
+        log::debug!("Consumer config: {:#?}", config);
+        // Set SSL if certificates are present
+        if let Ok(certificates) = &self.certificates() {
+            config
+                .set("security.protocol", "ssl")
+                .set("ssl.key.pem", certificates.private_key_pem())
+                .set(
+                    "ssl.certificate.pem",
+                    certificates.dsh_kafka_certificate_pem(),
+                )
+                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
+        } else {
+            config.set("security.protocol", "plaintext");
+        }
+        config
+    }
+
+    /// Get default RDKafka Producer config to connect to Kafka on DSH.
+    /// If certificates are present, it will use SSL to connect to Kafka.
+    /// If not, it will use plaintext so it can connect to local as well.
+    ///
+    /// Note: The default config sets auto commit to false. You need to manually commit offsets.
+    ///
+    /// # Example
+    /// ```
+    /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel;
+    /// use dsh_sdk::rdkafka::producer::FutureProducer;
+    /// use dsh_sdk::Properties;
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let dsh_properties = Properties::get();
+    ///     let mut producer_config = dsh_properties.producer_rdkafka_config();
+    ///     let producer: FutureProducer = producer_config.create().expect("Producer creation failed");
+    ///     Ok(())
+    /// }
+    /// ```
+    ///
+    /// # Default configs
+    /// See the full list of config properties in case you want to manually add/overwrite the config:
+    /// <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
+    ///
+    /// | **config**          | **Default value**              | **Remark**                                                                              |
+    /// |---------------------|--------------------------------|------------------------------------------------------------------------------------------|
+    /// | bootstrap.servers   | Brokers based on datastreams   | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS`                                  |
+    /// | client.id           | task_id of service             | Based on task_id of running service                                                     |
+    /// | security.protocol   | ssl (DSH) / plaintext (local)  | Security protocol                                                                       |
+    /// | ssl.key.pem         | private key                    | Generated when bootstrap is initiated                                                   |
+    /// | ssl.certificate.pem | dsh kafka certificate          | Signed certificate to connect to kafka cluster (signed when bootstrap is initiated)     |
+    /// | ssl.ca.pem          | CA certificate                 | CA certificate, provided by DSH.                                                        |
+    /// | log_level           | Info                           | Log level of rdkafka                                                                    |
+    ///
+    /// ## Environment variables
+    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
+    /// on configuring the producer via environment variables.
+    #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))]
+    pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
+        let producer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get();
+        let mut config = rdkafka::config::ClientConfig::new();
+        config
+            .set("bootstrap.servers", self.kafka_brokers())
+            .set("client.id", self.client_id());
+        if let Some(batch_num_messages) = producer_config.batch_num_messages() {
+            config.set("batch.num.messages", batch_num_messages.to_string());
+        }
+        if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() {
+            config.set(
+                "queue.buffering.max.messages",
+                queue_buffering_max_messages.to_string(),
+            );
+        }
+        if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() {
+            config.set(
+                "queue.buffering.max.kbytes",
+                queue_buffering_max_kbytes.to_string(),
+            );
+        }
+        if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() {
+            config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string());
+        }
+        log::debug!("Producer config: {:#?}", config);
+
+        // Set SSL if certificates are present
+        if let Ok(certificates) = self.certificates() {
+            config
+                .set("security.protocol", "ssl")
+                .set("ssl.key.pem", certificates.private_key_pem())
+                .set(
+                    "ssl.certificate.pem",
+                    certificates.dsh_kafka_certificate_pem(),
+                )
+                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
+        } else {
+            config.set("security.protocol", "plaintext");
+        }
+        config
+    }
 }
 
diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs
index 1a47e45..1996b1d 100644
--- a/dsh_sdk/src/error.rs
+++ b/dsh_sdk/src/error.rs
@@ -7,11 +7,11 @@ use thiserror::Error;
 pub enum DshError {
     #[error("IO Error: {0}")]
     IoError(#[from] std::io::Error),
-    #[error("Env var error: {0}")]
-    EnvVarError(#[from] std::env::VarError),
+    #[error("Env variable {0} error: {1}")]
+    EnvVarError(String, std::env::VarError),
     #[error("Convert bytes to utf8 error: {0}")]
     Utf8(#[from] std::string::FromUtf8Error),
-    #[cfg(any(feature = "bootstrap", feature = "mqtt-token-fetcher"))]
+    #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))]
     #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")]
     DshCallError {
         url: String,
@@ -21,24 +21,24 @@ pub enum DshError {
     #[cfg(feature = "bootstrap")]
     #[error("Certificates are not set")]
     NoCertificates,
-    #[cfg(feature = "bootstrap")]
+    #[cfg(any(feature = "bootstrap", feature = "pki-config-dir"))]
     #[error("Invalid PEM certificate: {0}")]
     PemError(#[from] pem::PemError),
-    #[cfg(any(feature = "bootstrap", feature = "mqtt-token-fetcher"))]
+    #[cfg(any(feature = "certificate", feature = "protocol-token-fetcher"))]
     #[error("Reqwest: {0}")]
     ReqwestError(#[from] reqwest::Error),
-    #[cfg(any(feature = "bootstrap", feature = "mqtt-token-fetcher"))]
+    #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))]
     #[error("Serde_json error: {0}")]
     JsonError(#[from] serde_json::Error),
     #[cfg(feature = "bootstrap")]
     #[error("Rcgen error: {0}")]
     PrivateKeyError(#[from] rcgen::Error),
-    #[cfg(any(feature = "bootstrap", 
feature = "mqtt-token-fetcher"))] + #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))] #[error("Error parsing: {0}")] ParseDnError(String), #[cfg(feature = "bootstrap")] #[error("Error getting group id, index out of bounds for {0}")] - IndexGroupIdError(crate::dsh::datastream::GroupType), + IndexGroupIdError(crate::datastream::GroupType), #[error("No tenant name found")] NoTenantName, #[cfg(feature = "bootstrap")] @@ -46,7 +46,7 @@ pub enum DshError { NotFoundTopicError(String), #[cfg(feature = "bootstrap")] #[error("Error in topic permissions: {0} does not have {1:?} permissions.")] - TopicPermissionsError(String, crate::dsh::datastream::ReadWriteAccess), + TopicPermissionsError(String, crate::datastream::ReadWriteAccess), #[cfg(feature = "metrics")] #[error("Prometheus error: {0}")] Prometheus(#[from] prometheus::Error), @@ -55,7 +55,7 @@ pub enum DshError { HyperError(#[from] hyper::http::Error), } -#[cfg(feature = "rest-token-fetcher")] +#[cfg(feature = "management-api")] #[derive(Error, Debug)] #[non_exhaustive] pub enum DshRestTokenError { diff --git a/dsh_sdk/src/graceful_shutdown.rs b/dsh_sdk/src/graceful_shutdown.rs index 942a0d8..0354bf1 100644 --- a/dsh_sdk/src/graceful_shutdown.rs +++ b/dsh_sdk/src/graceful_shutdown.rs @@ -1,4 +1,4 @@ -//! Graceful shutdown for tokio tasks. +//! Graceful shutdown //! //! This module provides a shutdown handle for graceful shutdown of (tokio tasks within) your service. //! It listens for SIGTERM requests and sends out shutdown requests to all shutdown handles. diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index 0276d35..47e58b3 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -72,28 +72,82 @@ //! The DLQ is implemented by running the `Dlq` struct to push messages towards the DLQ topics. //! The `ErrorToDlq` trait can be implemented on your defined errors, to be able to send messages towards the DLQ Struct. -#[cfg(feature = "dlq")] -pub mod dlq; +#![allow(deprecated)] + +// to be kept in v0.6.0 +#[cfg(feature = "certificate")] +pub mod certificates; +#[cfg(feature = "bootstrap")] +pub mod datastream; #[cfg(feature = "bootstrap")] pub mod dsh; pub mod error; -#[cfg(feature = "graceful_shutdown")] +#[cfg(feature = "management-api")] +pub mod management_api; +pub mod protocol_adapters; +pub mod utils; + +#[cfg(feature = "bootstrap")] +#[doc(inline)] +pub use dsh::Dsh; + +#[cfg(feature = "management-api")] +pub use management_api::token_fetcher::{ + ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder, +}; + +#[doc(inline)] +pub use utils::Platform; + +// TODO: to be removed in v0.6.0 +#[cfg(feature = "dlq")] +#[deprecated(since = "0.5.0", note = "The DLQ is moved to `dsh_sdk::utils::dlq`")] +pub mod dlq; + +#[cfg(feature = "bootstrap")] +#[deprecated( + since = "0.5.0", + note = "The `dsh` as module is phased out. 
Use
+    `dsh_sdk::Dsh` for all info about your running container;
+    `dsh_sdk::certificates` for all certificate related info;
+    `dsh_sdk::datastream` for all datastream related info;
+    "
+)]
+pub mod dsh_old;
+
+#[cfg(feature = "graceful-shutdown")]
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::graceful_shutdown` is moved to `dsh_sdk::utils::graceful_shutdown`"
+)]
 pub mod graceful_shutdown;
+
 #[cfg(feature = "metrics")]
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::metrics` is moved to `dsh_sdk::utils::metrics`"
+)]
 pub mod metrics;
+
 #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))]
 pub use rdkafka;
-#[cfg(feature = "mqtt-token-fetcher")]
+#[cfg(feature = "protocol-token-fetcher")]
+#[deprecated(
+    since = "0.5.0",
+    note = "`dsh_sdk::mqtt_token_fetcher` is moved to `dsh_sdk::protocol_adapters::token_fetcher`"
+)]
 pub mod mqtt_token_fetcher;
-#[cfg(feature = "rest-token-fetcher")]
-mod rest_api_token_fetcher;
-mod utils;
-
 #[cfg(feature = "bootstrap")]
-pub use dsh::Properties;
-#[cfg(feature = "rest-token-fetcher")]
+pub use dsh_old::Properties;
+
+#[cfg(feature = "management-api")]
+#[deprecated(
+    since = "0.5.0",
+    note = "`RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to `ManagementApiTokenFetcher` and `ManagementApiTokenFetcherBuilder`"
+)]
+mod rest_api_token_fetcher;
+#[cfg(feature = "management-api")]
 pub use rest_api_token_fetcher::{RestTokenFetcher, RestTokenFetcherBuilder};
-pub use utils::Platform;
 
 // Environment variables
 const VAR_APP_ID: &str = "MARATHON_APP_ID";
diff --git a/dsh_sdk/src/management_api/mod.rs b/dsh_sdk/src/management_api/mod.rs
new file mode 100644
index 0000000..a3e273a
--- /dev/null
+++ b/dsh_sdk/src/management_api/mod.rs
@@ -0,0 +1 @@
+pub mod token_fetcher;
diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs
new file mode 100644
index 0000000..eda43be
--- /dev/null
+++ b/dsh_sdk/src/management_api/token_fetcher.rs
@@ -0,0 +1,597 @@
+//! Module for fetching and storing access tokens for the DSH Management Rest API client
+//!
+//! This module is meant to be used together with the [dsh_rest_api_client].
+//!
+//! The token fetcher fetches and stores access tokens to be used in the DSH Rest API client.
+//!
+//! ## Example
+//! Recommended usage is to use the [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher.
+//! However, you can also create a new instance of the token fetcher directly.
+//! ```no_run
+//! use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};
+//! use dsh_rest_api_client::Client;
+//!
+//! const CLIENT_SECRET: &str = "";
+//! const TENANT: &str = "tenant-name";
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let platform = Platform::NpLz;
+//!     let client = Client::new(platform.endpoint_rest_api());
+//!
+//!     let tf = ManagementApiTokenFetcherBuilder::new(platform)
+//!         .tenant_name(TENANT.to_string())
+//!         .client_secret(CLIENT_SECRET.to_string())
+//!         .build()
+//!         .unwrap();
+//!
+//!     let response = client
+//!         .topic_get_by_tenant_topic(TENANT, &tf.get_token().await.unwrap())
+//!         .await;
+//!     println!("Available topics: {:#?}", response);
+//! }
+//! ```
+
+use std::fmt::Debug;
+use std::ops::Add;
+use std::sync::Mutex;
+use std::time::{Duration, Instant};
+
+use log::debug;
+use serde::Deserialize;
+
+use crate::error::DshRestTokenError;
+use crate::utils::Platform;
+
+/// Access token of the authentication service of DSH.
+///
+/// This is the response when requesting a new access token.
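+///
+/// A successful response body has roughly the following shape (illustrative values,
+/// mirroring the test fixture at the bottom of this module):
+///
+/// ```json
+/// {
+///     "access_token": "secret_access_token",
+///     "expires_in": 600,
+///     "refresh_expires_in": 0,
+///     "token_type": "Bearer",
+///     "not-before-policy": 0,
+///     "scope": "email"
+/// }
+/// ```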
+///
+/// ## Recommended usage
+/// Use [ManagementApiTokenFetcher::get_token] to get the bearer token; the token fetcher will automatically fetch a new token if the current token is no longer valid.
+#[derive(Debug, Clone, Deserialize)]
+pub struct AccessToken {
+    access_token: String,
+    expires_in: u64,
+    refresh_expires_in: u32,
+    token_type: String,
+    #[serde(rename(deserialize = "not-before-policy"))]
+    not_before_policy: u32,
+    scope: String,
+}
+
+impl AccessToken {
+    /// Get the formatted token
+    pub fn formatted_token(&self) -> String {
+        format!("{} {}", self.token_type, self.access_token)
+    }
+
+    /// Get the access token
+    pub fn access_token(&self) -> &str {
+        &self.access_token
+    }
+
+    /// Get the expires in of the access token
+    pub fn expires_in(&self) -> u64 {
+        self.expires_in
+    }
+
+    /// Get the refresh expires in of the access token
+    pub fn refresh_expires_in(&self) -> u32 {
+        self.refresh_expires_in
+    }
+
+    /// Get the token type of the access token
+    pub fn token_type(&self) -> &str {
+        &self.token_type
+    }
+
+    /// Get the not before policy of the access token
+    pub fn not_before_policy(&self) -> u32 {
+        self.not_before_policy
+    }
+
+    /// Get the scope of the access token
+    pub fn scope(&self) -> &str {
+        &self.scope
+    }
+}
+
+impl Default for AccessToken {
+    fn default() -> Self {
+        Self {
+            access_token: "".to_string(),
+            expires_in: 0,
+            refresh_expires_in: 0,
+            token_type: "".to_string(),
+            not_before_policy: 0,
+            scope: "".to_string(),
+        }
+    }
+}
+
+/// Fetch and store access tokens to be used in the DSH Rest API client
+///
+/// This struct fetches and caches access tokens for the DSH Rest API client.
+/// It will automatically fetch a new token if the current token is no longer valid.
+pub struct ManagementApiTokenFetcher {
+    access_token: Mutex<AccessToken>,
+    fetched_at: Mutex<Instant>,
+    client_id: String,
+    client_secret: String,
+    client: reqwest::Client,
+    auth_url: String,
+}
+
+impl ManagementApiTokenFetcher {
+    /// Create a new instance of the token fetcher
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::{ManagementApiTokenFetcher, Platform};
+    /// use dsh_rest_api_client::Client;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let platform = Platform::NpLz;
+    ///     let client_id = platform.rest_client_id("my-tenant");
+    ///     let client_secret = "my-secret".to_string();
+    ///     let token_fetcher = ManagementApiTokenFetcher::new(client_id, client_secret, platform.endpoint_rest_access_token().to_string());
+    ///     let token = token_fetcher.get_token().await.unwrap();
+    /// }
+    /// ```
+    pub fn new(client_id: String, client_secret: String, auth_url: String) -> Self {
+        Self::new_with_client(
+            client_id,
+            client_secret,
+            auth_url,
+            reqwest::Client::default(),
+        )
+    }
+
+    /// Create a new instance of the token fetcher with a custom reqwest client
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::{ManagementApiTokenFetcher, Platform};
+    /// use dsh_rest_api_client::Client;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let platform = Platform::NpLz;
+    ///     let client_id = platform.rest_client_id("my-tenant");
+    ///     let client_secret = "my-secret".to_string();
+    ///     let client = reqwest::Client::new();
+    ///     let token_fetcher = ManagementApiTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_rest_access_token().to_string(), client);
+    ///     let token = token_fetcher.get_token().await.unwrap();
+    /// }
+    /// ```
+    pub fn new_with_client(
+        client_id: String,
+        client_secret: String,
+        auth_url: String,
+        client: reqwest::Client,
+    ) -> Self {
+        Self {
+            access_token: Mutex::new(AccessToken::default()),
+            fetched_at: Mutex::new(Instant::now()),
+            client_id,
+            client_secret,
+            client,
+            auth_url,
+        }
+    }
+
+    /// Get token from the token fetcher
+    ///
+    /// If the cached token is not valid, it will fetch a new token from the server.
+    /// It will return the token as a string, formatted as `{token_type} {token}`.
+    /// If the request for a new token fails, it will return a [DshRestTokenError::FailureTokenFetch] error.
+    /// This will contain the underlying reqwest error.
+    pub async fn get_token(&self) -> Result<String, DshRestTokenError> {
+        match self.is_valid() {
+            true => Ok(self.access_token.lock().unwrap().formatted_token()),
+            false => {
+                debug!("Token is expired, fetching new token");
+                let access_token = self.fetch_access_token_from_server().await?;
+                let mut token = self.access_token.lock().unwrap();
+                let mut fetched_at = self.fetched_at.lock().unwrap();
+                *token = access_token;
+                *fetched_at = Instant::now();
+                Ok(token.formatted_token())
+            }
+        }
+    }
+
+    /// Check if the current access token is still valid
+    ///
+    /// If the token has expired, it will return false.
+    pub fn is_valid(&self) -> bool {
+        let access_token = self.access_token.lock().unwrap_or_else(|mut e| {
+            **e.get_mut() = AccessToken::default();
+            self.access_token.clear_poison();
+            e.into_inner()
+        });
+        let fetched_at = self.fetched_at.lock().unwrap_or_else(|e| {
+            self.fetched_at.clear_poison();
+            e.into_inner()
+        });
+        // Check if expires_in has elapsed (+ safety margin of 5 seconds)
+        fetched_at.elapsed().add(Duration::from_secs(5))
+            < Duration::from_secs(access_token.expires_in)
+    }
+
+    /// Fetch a new access token from the server
+    ///
+    /// This will fetch a new access token from the server and return it.
+    /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error.
+    /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error.
+    /// If the request is successful, it will return the [AccessToken].
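+    ///
+    /// ## Example
+    /// A minimal sketch of calling this method directly (normally [ManagementApiTokenFetcher::get_token]
+    /// is preferred, as it caches the token):
+    /// ```no_run
+    /// # use dsh_sdk::{ManagementApiTokenFetcher, Platform};
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let platform = Platform::NpLz;
+    /// let tf = ManagementApiTokenFetcher::new(
+    ///     platform.rest_client_id("my-tenant"),
+    ///     "my-secret".to_string(),
+    ///     platform.endpoint_rest_access_token().to_string(),
+    /// );
+    /// let access_token = tf.fetch_access_token_from_server().await.unwrap();
+    /// println!("Token type: {}, expires in: {}s", access_token.token_type(), access_token.expires_in());
+    /// # }
+    /// ```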
+    pub async fn fetch_access_token_from_server(&self) -> Result<AccessToken, DshRestTokenError> {
+        let response = self
+            .client
+            .post(&self.auth_url)
+            .form(&[
+                ("client_id", self.client_id.as_ref()),
+                ("client_secret", self.client_secret.as_ref()),
+                ("grant_type", "client_credentials"),
+            ])
+            .send()
+            .await
+            .map_err(DshRestTokenError::FailureTokenFetch)?;
+        if !response.status().is_success() {
+            Err(DshRestTokenError::StatusCode {
+                status_code: response.status(),
+                error_body: response,
+            })
+        } else {
+            response
+                .json::<AccessToken>()
+                .await
+                .map_err(DshRestTokenError::FailureTokenFetch)
+        }
+    }
+}
+
+impl Debug for ManagementApiTokenFetcher {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ManagementApiTokenFetcher")
+            .field("access_token", &self.access_token)
+            .field("fetched_at", &self.fetched_at)
+            .field("client_id", &self.client_id)
+            .field("client_secret", &"xxxxxx")
+            .field("auth_url", &self.auth_url)
+            .finish()
+    }
+}
+
+/// Builder for the token fetcher
+pub struct ManagementApiTokenFetcherBuilder {
+    client: Option<reqwest::Client>,
+    client_id: Option<String>,
+    client_secret: Option<String>,
+    platform: Platform,
+    tenant_name: Option<String>,
+}
+
+impl ManagementApiTokenFetcherBuilder {
+    /// Get a new instance of the token fetcher builder
+    ///
+    /// # Arguments
+    /// * `platform` - The target platform to use for the token fetcher
+    pub fn new(platform: Platform) -> Self {
+        Self {
+            client: None,
+            client_id: None,
+            client_secret: None,
+            platform,
+            tenant_name: None,
+        }
+    }
+
+    /// Set the client_id for the client
+    ///
+    /// Alternatively, set `tenant_name` to generate the client_id.
+    /// `client_id` takes precedence over `tenant_name`.
+    pub fn client_id(mut self, client_id: String) -> Self {
+        self.client_id = Some(client_id);
+        self
+    }
+
+    /// Set the client_secret for the client
+    pub fn client_secret(mut self, client_secret: String) -> Self {
+        self.client_secret = Some(client_secret);
+        self
+    }
+
+    /// Set the tenant_name for the client; this will generate the client_id
+    ///
+    /// Alternatively, set `client_id` directly.
+    /// An explicitly set `client_id` takes precedence over `tenant_name`.
+    pub fn tenant_name(mut self, tenant_name: String) -> Self {
+        self.tenant_name = Some(tenant_name);
+        self
+    }
+
+    /// Provide a custom configured Reqwest client for the token fetcher
+    ///
+    /// This is optional; if not provided, a default client will be used.
+    pub fn client(mut self, client: reqwest::Client) -> Self {
+        self.client = Some(client);
+        self
+    }
+
+    /// Build the token fetcher
+    ///
+    /// This will build the token fetcher based on the given parameters.
+    /// It returns the configured [ManagementApiTokenFetcher].
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};
+    /// let platform = Platform::NpLz;
+    /// let client_id = "robot:dev-lz-dsh:my-tenant".to_string();
+    /// let client_secret = "secret".to_string();
+    /// let tf = ManagementApiTokenFetcherBuilder::new(platform)
+    ///     .client_id(client_id)
+    ///     .client_secret(client_secret)
+    ///     .build()
+    ///     .unwrap();
+    /// ```
+    pub fn build(self) -> Result<ManagementApiTokenFetcher, DshRestTokenError> {
+        let client_secret = self
+            .client_secret
+            .ok_or(DshRestTokenError::UnknownClientSecret)?;
+        let client_id = self
+            .client_id
+            .or_else(|| {
+                self.tenant_name
+                    .as_ref()
+                    .map(|tenant_name| self.platform.rest_client_id(tenant_name))
+            })
+            .ok_or(DshRestTokenError::UnknownClientId)?;
+        let client = self.client.unwrap_or_default();
+        let token_fetcher = ManagementApiTokenFetcher::new_with_client(
+            client_id,
+            client_secret,
+            self.platform.endpoint_rest_access_token().to_string(),
+            client,
+        );
+        Ok(token_fetcher)
+    }
+}
+
+impl Debug for ManagementApiTokenFetcherBuilder {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let client_secret = self
+            .client_secret
+            .as_ref()
+            .map(|_| "Some(\"client_secret\")");
+        f.debug_struct("ManagementApiTokenFetcherBuilder")
+            .field("client_id", &self.client_id)
+            .field("client_secret", &client_secret)
+            .field("platform", &self.platform)
+            .field("tenant_name", &self.tenant_name)
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    fn create_mock_tf() -> ManagementApiTokenFetcher {
+        ManagementApiTokenFetcher {
+            access_token: Mutex::new(AccessToken::default()),
+            fetched_at: Mutex::new(Instant::now()),
+            client_id: "client_id".to_string(),
+            client_secret: "client_secret".to_string(),
+            client: reqwest::Client::new(),
+            auth_url: "http://localhost".to_string(),
+        }
+    }
+
+    #[test]
+    fn test_access_token() {
+        let token_str = r#"{
+          "access_token": "secret_access_token",
+          "expires_in": 600,
+          "refresh_expires_in": 0,
+          "token_type": "Bearer",
+          "not-before-policy": 0,
+          "scope": "email"
+        }"#;
+        let token: AccessToken = serde_json::from_str(token_str).unwrap();
+        assert_eq!(token.access_token(), "secret_access_token");
+        assert_eq!(token.expires_in(), 600);
+        assert_eq!(token.refresh_expires_in(), 0);
+        assert_eq!(token.token_type(), "Bearer");
+        assert_eq!(token.not_before_policy(), 0);
+        assert_eq!(token.scope(), "email");
+        assert_eq!(token.formatted_token(), "Bearer secret_access_token");
+    }
+
+    #[test]
+    fn test_access_token_default() {
+        let token = AccessToken::default();
+        assert_eq!(token.access_token(), "");
+        assert_eq!(token.expires_in(), 0);
+        assert_eq!(token.refresh_expires_in(), 0);
+        assert_eq!(token.token_type(), "");
+        assert_eq!(token.not_before_policy(), 0);
+        assert_eq!(token.scope(), "");
+        assert_eq!(token.formatted_token(), " ");
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_default_token() {
+        // Test is_valid when validating default token (should expire in 0 seconds)
+        let tf = create_mock_tf();
+        assert!(!tf.is_valid());
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_valid_token() {
+        let tf = create_mock_tf();
+        tf.access_token.lock().unwrap().expires_in = 600;
+        assert!(tf.is_valid());
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_expired_token() {
+        // Test is_valid when validating an expired token
+        let tf = create_mock_tf();
+        tf.access_token.lock().unwrap().expires_in = 600;
+        *tf.fetched_at.lock().unwrap() = Instant::now() - Duration::from_secs(600);
+        assert!(!tf.is_valid());
+    }
+
+    #[test]
+    fn 
test_rest_token_fetcher_is_valid_poisoned_token() { + // Test is_valid when token is poisoned + let tf = create_mock_tf(); + tf.access_token.lock().unwrap().expires_in = 600; + let tf_arc = std::sync::Arc::new(tf); + let tf_clone = tf_arc.clone(); + assert!(tf_arc.is_valid(), "Token should be valid"); + let h = std::thread::spawn(move || { + let _unused = tf_clone.access_token.lock().unwrap(); + panic!("Poison token") + }); + let _ = h.join(); + assert!(!tf_arc.is_valid(), "Token should be invalid"); + } + + #[tokio::test] + async fn test_fetch_access_token_from_server() { + let mut auth_server = mockito::Server::new_async().await; + auth_server + .mock("POST", "/") + .with_status(200) + .with_body( + r#"{ + "access_token": "secret_access_token", + "expires_in": 600, + "refresh_expires_in": 0, + "token_type": "Bearer", + "not-before-policy": 0, + "scope": "email" + }"#, + ) + .create(); + let mut tf = create_mock_tf(); + tf.auth_url = auth_server.url(); + let token = tf.fetch_access_token_from_server().await.unwrap(); + assert_eq!(token.access_token(), "secret_access_token"); + assert_eq!(token.expires_in(), 600); + assert_eq!(token.refresh_expires_in(), 0); + assert_eq!(token.token_type(), "Bearer"); + assert_eq!(token.not_before_policy(), 0); + assert_eq!(token.scope(), "email"); + } + + #[tokio::test] + async fn test_fetch_access_token_from_server_error() { + let mut auth_server = mockito::Server::new_async().await; + auth_server + .mock("POST", "/") + .with_status(400) + .with_body("Bad request") + .create(); + let mut tf = create_mock_tf(); + tf.auth_url = auth_server.url(); + let err = tf.fetch_access_token_from_server().await.unwrap_err(); + match err { + DshRestTokenError::StatusCode { + status_code, + error_body, + } => { + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body.text().await.unwrap(), "Bad request"); + } + _ => panic!("Unexpected error: {:?}", err), + } + } + + #[test] + fn test_token_fetcher_builder_client_id() { + let platform = Platform::NpLz; + let client_id = "robot:dev-lz-dsh:my-tenant"; + let client_secret = "secret"; + let tf = ManagementApiTokenFetcherBuilder::new(platform) + .client_id(client_id.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_tenant_name() { + let platform = Platform::NpLz; + let tenant_name = "my-tenant"; + let client_secret = "secret"; + let tf = ManagementApiTokenFetcherBuilder::new(platform) + .tenant_name(tenant_name.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!( + tf.client_id, + format!("robot:{}:{}", Platform::NpLz.realm(), tenant_name) + ); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_custom_client() { + let platform = Platform::NpLz; + let client_id = "robot:dev-lz-dsh:my-tenant"; + let client_secret = "secret"; + let custom_client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); + let tf = ManagementApiTokenFetcherBuilder::new(platform) + .client_id(client_id.to_string()) + .client_secret(client_secret.to_string()) + .client(custom_client.clone()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, 
Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_client_id_precedence() { + let platform = Platform::NpLz; + let tenant = "my-tenant"; + let client_id_override = "override"; + let client_secret = "secret"; + let tf = ManagementApiTokenFetcherBuilder::new(platform) + .tenant_name(tenant.to_string()) + .client_id(client_id_override.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id_override); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_build_error() { + let err = ManagementApiTokenFetcherBuilder::new(Platform::NpLz) + .client_secret("client_secret".to_string()) + .build() + .unwrap_err(); + assert!(matches!(err, DshRestTokenError::UnknownClientId)); + + let err = ManagementApiTokenFetcherBuilder::new(Platform::NpLz) + .tenant_name("tenant_name".to_string()) + .build() + .unwrap_err(); + assert!(matches!(err, DshRestTokenError::UnknownClientSecret)); + } +} diff --git a/dsh_sdk/src/metrics.rs b/dsh_sdk/src/metrics.rs index 11d38e6..bf7bde0 100644 --- a/dsh_sdk/src/metrics.rs +++ b/dsh_sdk/src/metrics.rs @@ -183,19 +183,19 @@ fn full>(chunk: T) -> BoxBody { #[cfg(test)] mod tests { + use super::*; use http_body_util::Empty; use hyper::body::Body; use hyper::client::conn; use hyper::client::conn::http1::{Connection, SendRequest}; use hyper::http::HeaderValue; use hyper::Uri; + use serial_test::serial; use tokio::net::TcpStream; - use super::*; - lazy_static! { - pub static ref HIGH_FIVE_COUNTER: IntCounter = - register_int_counter!("highfives", "Number of high fives recieved").unwrap(); + pub static ref HIGH_FIVE_COUNTER_OLD: IntCounter = + register_int_counter!("highfives_old", "Number of high fives recieved").unwrap(); } async fn create_client( @@ -224,9 +224,10 @@ mod tests { } #[tokio::test] + #[serial(port_usage)] async fn test_http_metric_response() { // Increment the counter - HIGH_FIVE_COUNTER.inc(); + HIGH_FIVE_COUNTER_OLD.inc(); // Call the function let res = get_metrics(); @@ -247,12 +248,13 @@ mod tests { } #[tokio::test] + #[serial(port_usage)] async fn test_start_http_server() { // Start HTTP server let server = start_http_server(8080); // increment the counter - HIGH_FIVE_COUNTER.inc(); + HIGH_FIVE_COUNTER_OLD.inc(); // Give the server a moment to start tokio::time::sleep(std::time::Duration::from_secs(1)).await; @@ -288,6 +290,7 @@ mod tests { } #[tokio::test] + #[serial(port_usage)] async fn test_unknown_path() { // Start HTTP server let server = start_http_server(9900); @@ -323,7 +326,7 @@ mod tests { #[test] fn test_metrics_to_string() { - HIGH_FIVE_COUNTER.inc(); + HIGH_FIVE_COUNTER_OLD.inc(); let res = metrics_to_string().unwrap(); assert!(res.contains("highfives")); } diff --git a/dsh_sdk/src/mqtt_token_fetcher.rs b/dsh_sdk/src/mqtt_token_fetcher.rs index 43da3de..0fcdf55 100644 --- a/dsh_sdk/src/mqtt_token_fetcher.rs +++ b/dsh_sdk/src/mqtt_token_fetcher.rs @@ -1,858 +1,24 @@ -//! # MQTT Token Fetcher -//! -//! `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH. 
-use std::fmt::{Display, Formatter}; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use tokio::sync::Mutex; +pub use crate::protocol_adapters::token_fetcher::*; -use dashmap::DashMap; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use sha2::{Digest, Sha256}; +use crate::Platform; -use crate::{error::DshError, Platform}; +#[deprecated( + since = "0.5.0", + note = "`MqttTokenFetcher` is renamed to `ProtocolTokenFetcher`" +)] +pub struct MqttTokenFetcher; -/// `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH. -/// -/// It ensures that the tokens are valid, and if not, it refreshes them automatically. The struct -/// is thread-safe and can be shared across multiple threads. - -pub struct MqttTokenFetcher { - tenant_name: String, - rest_api_key: String, - rest_token: Mutex, - rest_auth_url: String, - mqtt_token: DashMap, // Mapping from Client ID to MqttToken - mqtt_auth_url: String, - client: reqwest::Client, - //token_lifetime: Option, // TODO: Implement option of passing token lifetime to request token for specific duration - // port: Port or connection_type: Connection // TODO: Platform provides two connection options, current implemetation only provides connecting over SSL, enable WebSocket too -} - -/// Constructs a new `MqttTokenFetcher`. -/// -/// # Arguments -/// -/// * `tenant_name` - The tenant name in DSH. -/// * `rest_api_key` - The REST API key used for authentication. -/// * `platform` - The DSH platform environment -/// -/// # Returns -/// -/// Returns a `Result` containing a `MqttTokenFetcher` instance or a `DshError`. impl MqttTokenFetcher { - /// Constructs a new `MqttTokenFetcher`. - /// - /// # Arguments - /// - /// * `tenant_name` - The tenant name of DSH. - /// * `api_key` - The realted API key of tenant used for authentication to fetech Token for MQTT. - /// * `platform` - The target DSH platform environment. - /// - /// # Example - /// - /// ```no_run - /// use dsh_sdk::mqtt_token_fetcher::MqttTokenFetcher; - /// use dsh_sdk::Platform; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let tenant_name = "test_tenant".to_string(); - /// let api_key = "aAbB123".to_string(); - /// let platform = Platform::NpLz; - /// - /// let fetcher = MqttTokenFetcher::new(tenant_name, api_key, platform); - /// let token = fetcher.get_token("test_client", None).await.unwrap(); - /// # } - /// ``` - pub fn new(tenant_name: String, api_key: String, platform: Platform) -> MqttTokenFetcher { - const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); - - let reqwest_client = reqwest::Client::builder() - .timeout(DEFAULT_TIMEOUT) - .http1_only() - .build() - .expect("Failed to build reqwest client"); - Self::new_with_client(tenant_name, api_key, platform, reqwest_client) + pub fn new(tenant_name: String, api_key: String, platform: Platform) -> ProtocolTokenFetcher { + ProtocolTokenFetcher::new(tenant_name, api_key, platform) } - /// Constructs a new `MqttTokenFetcher` with a custom reqwest client. - /// On this Reqwest client, you can set custom timeouts, headers, Rustls etc. - /// - /// # Arguments - /// - /// * `tenant_name` - The tenant name of DSH. - /// * `api_key` - The realted API key of tenant used for authentication to fetech Token for MQTT. - /// * `platform` - The target DSH platform environment. 
- /// * `client` - User configured reqwest client to be used for fetching tokens - /// - /// # Example - /// - /// ```no_run - /// use dsh_sdk::mqtt_token_fetcher::MqttTokenFetcher; - /// use dsh_sdk::Platform; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let tenant_name = "test_tenant".to_string(); - /// let api_key = "aAbB123".to_string(); - /// let platform = Platform::NpLz; - /// let client = reqwest::Client::new(); - /// let fetcher = MqttTokenFetcher::new_with_client(tenant_name, api_key, platform, client); - /// let token = fetcher.get_token("test_client", None).await.unwrap(); - /// # } - /// ``` pub fn new_with_client( tenant_name: String, api_key: String, platform: Platform, client: reqwest::Client, - ) -> MqttTokenFetcher { - let rest_token = RestToken::default(); - Self { - tenant_name, - rest_api_key: api_key, - rest_token: Mutex::new(rest_token), - rest_auth_url: platform.endpoint_rest_token().to_string(), - mqtt_token: DashMap::new(), - mqtt_auth_url: platform.endpoint_mqtt_token().to_string(), - client, - } - } - /// Retrieves an MQTT token for the specified client ID. - /// - /// If the token is expired or does not exist, it fetches a new token. - /// - /// # Arguments - /// - /// * `client_id` - The identifier for the MQTT client. - /// * `claims` - Optional claims for the MQTT token. - /// - /// # Returns - /// - /// Returns a `Result` containing the `MqttToken` or a `DshError`. - pub async fn get_token( - &self, - client_id: &str, - claims: Option>, - ) -> Result { - match self.mqtt_token.entry(client_id.to_string()) { - dashmap::Entry::Occupied(mut entry) => { - let mqtt_token = entry.get_mut(); - if !mqtt_token.is_valid() { - *mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?; - }; - Ok(mqtt_token.clone()) - } - dashmap::Entry::Vacant(entry) => { - let mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?; - entry.insert(mqtt_token.clone()); - Ok(mqtt_token) - } - } - } - - /// Fetches a new MQTT token from the platform. - /// - /// This method handles token validation and fetching the token - async fn fetch_new_mqtt_token( - &self, - client_id: &str, - claims: Option>, - ) -> Result { - let mut rest_token = self.rest_token.lock().await; - - if !rest_token.is_valid() { - *rest_token = RestToken::get( - &self.client, - &self.tenant_name, - &self.rest_api_key, - &self.rest_auth_url, - ) - .await? - } - - let authorization_header = format!("Bearer {}", rest_token.raw_token); - - let mqtt_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?; - let payload = serde_json::to_value(&mqtt_token_request)?; - - let response = mqtt_token_request - .send( - &self.client, - &self.mqtt_auth_url, - &authorization_header, - &payload, - ) - .await?; - - MqttToken::new(response) - } -} - -/// Represent Claims information for MQTT request -/// * `action` - can be subscribe or publish -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Claims { - resource: Resource, - action: String, -} - -impl Claims { - pub fn new(resource: Resource, action: String) -> Claims { - Claims { resource, action } - } -} - -/// Enumeration representing possible actions in MQTT claims. -pub enum Actions { - Publish, - Subscribe, -} - -impl Display for Actions { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match self { - Actions::Publish => write!(f, "Publish"), - Actions::Subscribe => write!(f, "Subscribe"), - } - } -} - -/// Represents a resource in the MQTT claim. 
-/// -/// The resource defines what the client can access in terms of stream, prefix, topic, and type. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Resource { - stream: String, - prefix: String, - topic: String, - #[serde(rename = "type")] - type_: Option, -} - -impl Resource { - /// Creates a new `Resource` instance. Please check DSH MQTT Documentation for further explanation of the fields. - /// - /// # Arguments - /// - /// * `stream` - The data stream name. - /// * `prefix` - The prefix of the topic. - /// * `topic` - The topic name. - /// * `type_` - The optional type of the resource. - /// - /// - /// # Returns - /// - /// Returns a new `Resource` instance. - pub fn new(stream: String, prefix: String, topic: String, type_: Option) -> Resource { - Resource { - stream, - prefix, - topic, - type_, - } - } -} - -#[derive(Serialize)] -struct MqttTokenRequest { - id: String, - tenant: String, - claims: Option>, -} - -impl MqttTokenRequest { - fn new( - client_id: &str, - tenant: &str, - claims: Option>, - ) -> Result { - let mut hasher = Sha256::new(); - hasher.update(client_id); - let result = hasher.finalize(); - let id = format!("{:x}", result); - - Ok(Self { - id, - tenant: tenant.to_string(), - claims, - }) - } - - async fn send( - &self, - reqwest_client: &reqwest::Client, - mqtt_auth_url: &str, - authorization_header: &str, - payload: &serde_json::Value, - ) -> Result { - let response = reqwest_client - .post(mqtt_auth_url) - .header("Authorization", authorization_header) - .json(payload) - .send() - .await?; - - if response.status().is_success() { - Ok(response.text().await?) - } else { - Err(DshError::DshCallError { - url: mqtt_auth_url.to_string(), - status_code: response.status(), - error_body: response.text().await?, - }) - } - } -} - -/// Represents attributes associated with a mqtt token. -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "kebab-case")] -struct MqttTokenAttributes { - gen: i32, - endpoint: String, - iss: String, - claims: Option>, - exp: i32, - client_id: String, - iat: i32, - tenant_id: String, -} - -/// Represents a token used for MQTT connections. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MqttToken { - exp: i32, - raw_token: String, -} - -impl MqttToken { - /// Creates a new instance of `MqttToken` from a raw token string. - /// - /// # Arguments - /// - /// * `raw_token` - The raw token string. - /// - /// # Returns - /// - /// A Result containing the created MqttToken or an error. - pub fn new(raw_token: String) -> Result { - let header_payload = extract_header_and_payload(&raw_token)?; - - let decoded_token = decode_base64(header_payload)?; - - let token_attributes: MqttTokenAttributes = serde_json::from_slice(&decoded_token)?; - let token = MqttToken { - exp: token_attributes.exp, - raw_token, - }; - - Ok(token) - } - - /// Checks if the MQTT token is still valid. - fn is_valid(&self) -> bool { - let current_unixtime = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - .as_secs() as i32; - self.exp >= current_unixtime + 5 - } -} - -/// Represents attributes associated with a Rest token. 
-#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "kebab-case")] -struct RestTokenAttributes { - gen: i64, - endpoint: String, - iss: String, - claims: RestClaims, - exp: i32, - tenant_id: String, -} - -#[derive(Serialize, Deserialize, Debug)] -struct RestClaims { - #[serde(rename = "datastreams/v0/mqtt/token")] - datastreams_token: DatastreamsData, -} - -#[derive(Serialize, Deserialize, Debug)] -struct DatastreamsData {} - -/// Represents a rest token with its raw value and attributes. -#[derive(Serialize, Deserialize, Debug)] -struct RestToken { - raw_token: String, - exp: i32, -} - -impl RestToken { - /// Retrieves a new REST token from the platform. - /// - /// # Arguments - /// - /// * `tenant` - The tenant name associated with the DSH platform. - /// * `api_key` - The REST API key used for authentication. - /// * `env` - The platform environment (e.g., production, staging). - /// - /// # Returns - /// - /// A Result containing the created `RestToken` or a `DshError`. - async fn get( - client: &reqwest::Client, - tenant: &str, - api_key: &str, - auth_url: &str, - ) -> Result { - let raw_token = Self::fetch_token(client, tenant, api_key, auth_url).await?; - - let header_payload = extract_header_and_payload(&raw_token)?; - - let decoded_token = decode_base64(header_payload)?; - - let token_attributes: RestTokenAttributes = serde_json::from_slice(&decoded_token)?; - let token = RestToken { - raw_token, - exp: token_attributes.exp, - }; - - Ok(token) - } - - // Checks if the REST token is still valid. - fn is_valid(&self) -> bool { - let current_unixtime = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - .as_secs() as i32; - self.exp >= current_unixtime + 5 - } - - async fn fetch_token( - client: &reqwest::Client, - tenant: &str, - api_key: &str, - auth_url: &str, - ) -> Result { - let json_body = json!({"tenant": tenant}); - - let response = client - .post(auth_url) - .header("apikey", api_key) - .json(&json_body) - .send() - .await?; - - let status = response.status(); - let body_text = response.text().await?; - match status { - reqwest::StatusCode::OK => Ok(body_text), - _ => Err(DshError::DshCallError { - url: auth_url.to_string(), - status_code: status, - error_body: body_text, - }), - } - } -} - -impl Default for RestToken { - fn default() -> Self { - Self { - raw_token: "".to_string(), - exp: 0, - } - } -} - -/// Extracts the header and payload part of a JWT token. -/// -/// # Arguments -/// -/// * `raw_token` - The raw JWT token string. -/// -/// # Returns -/// -/// A Result containing the header and payload part of the JWT token or a `DshError`. -fn extract_header_and_payload(raw_token: &str) -> Result<&str, DshError> { - let parts: Vec<&str> = raw_token.split('.').collect(); - parts - .get(1) - .copied() - .ok_or_else(|| DshError::ParseDnError("Header and payload are missing".to_string())) -} - -/// Decodes a Base64-encoded string. -/// -/// # Arguments -/// -/// * `payload` - The Base64-encoded string. -/// -/// # Returns -/// -/// A Result containing the decoded byte vector or a `DshError`. 
-fn decode_base64(payload: &str) -> Result, DshError> { - use base64::{alphabet, engine, read}; - use std::io::Read; - - let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::NO_PAD); - let mut decoder = read::DecoderReader::new(payload.as_bytes(), &engine); - - let mut decoded_token = Vec::new(); - decoder - .read_to_end(&mut decoded_token) - .map_err(DshError::IoError)?; - - Ok(decoded_token) -} - -#[cfg(test)] -mod tests { - use super::*; - use mockito::Matcher; - use tokio::sync::Mutex; - - fn create_valid_fetcher() -> MqttTokenFetcher { - let exp_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600; - println!("exp_time: {}", exp_time); - let rest_token: RestToken = RestToken { - exp: exp_time as i32, - raw_token: "valid.token.payload".to_string(), - }; - let mqtt_token = MqttToken { - exp: exp_time, - raw_token: "valid.token.payload".to_string(), - }; - let mqtt_token_map = DashMap::new(); - mqtt_token_map.insert("test_client".to_string(), mqtt_token.clone()); - MqttTokenFetcher { - tenant_name: "test_tenant".to_string(), - rest_api_key: "test_api_key".to_string(), - rest_token: Mutex::new(rest_token), - rest_auth_url: "test_auth_url".to_string(), - mqtt_token: mqtt_token_map, - client: reqwest::Client::new(), - mqtt_auth_url: "test_auth_url".to_string(), - } - } - - #[tokio::test] - async fn test_mqtt_token_fetcher_new() { - let tenant_name = "test_tenant".to_string(); - let rest_api_key = "test_api_key".to_string(); - let platform = Platform::NpLz; - - let fetcher = MqttTokenFetcher::new(tenant_name, rest_api_key, platform); - - assert!(fetcher.mqtt_token.is_empty()); - } - - #[tokio::test] - async fn test_mqtt_token_fetcher_new_with_client() { - let tenant_name = "test_tenant".to_string(); - let rest_api_key = "test_api_key".to_string(); - let platform = Platform::NpLz; - - let client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); - let fetcher = - MqttTokenFetcher::new_with_client(tenant_name, rest_api_key, platform, client); - - assert!(fetcher.mqtt_token.is_empty()); - } - - #[tokio::test] - async fn test_fetch_new_mqtt_token() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server.mock("POST", "/rest_auth_url") - .with_status(200) - .with_body(r#"{"raw_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"}"#) - .create_async() - .await; - let _m2 = mockito_server.mock("POST", "/mqtt_auth_url") - .with_status(200) - .with_body(r#"{"mqtt_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#) - .create(); - - let client = reqwest::Client::new(); - let rest_token = RestToken { - raw_token: "initial_token".to_string(), - exp: 0, - }; - - let fetcher = MqttTokenFetcher { - client, - tenant_name: "test_tenant".to_string(), - rest_api_key: "test_api_key".to_string(), - mqtt_token: DashMap::new(), - rest_auth_url: mockito_server.url() + "/rest_auth_url", - mqtt_auth_url: mockito_server.url() + "/mqtt_auth_url", - rest_token: Mutex::new(rest_token), - }; - - let result = 
fetcher.fetch_new_mqtt_token("test_client_id", None).await; - println!("{:?}", result); - assert!(result.is_ok()); - let mqtt_token = result.unwrap(); - assert_eq!(mqtt_token.exp, 1); - } - - #[tokio::test] - async fn test_mqtt_token_fetcher_get_token() { - let fetcher = create_valid_fetcher(); - let token = fetcher.get_token("test_client", None).await.unwrap(); - assert_eq!(token.raw_token, "valid.token.payload"); - } - - #[test] - fn test_actions_display() { - let action = Actions::Publish; - assert_eq!(action.to_string(), "Publish"); - let action = Actions::Subscribe; - assert_eq!(action.to_string(), "Subscribe"); - } - - #[test] - fn test_token_request_new() { - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); - assert_eq!(request.id.len(), 64); - assert_eq!(request.tenant, "test_tenant"); - } - - #[tokio::test] - async fn test_send_success() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/mqtt_auth_url") - .match_header("Authorization", "Bearer test_token") - .match_body(Matcher::Json(json!({"key": "value"}))) - .with_status(200) - .with_body("success_response") - .create(); - - let client = reqwest::Client::new(); - let payload = json!({"key": "value"}); - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); - let result = request - .send( - &client, - &format!("{}/mqtt_auth_url", mockito_server.url()), - "Bearer test_token", - &payload, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "success_response"); - } - - #[tokio::test] - async fn test_send_failure() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/mqtt_auth_url") - .match_header("Authorization", "Bearer test_token") - .match_body(Matcher::Json(json!({"key": "value"}))) - .with_status(400) - .with_body("error_response") - .create(); - - let client = reqwest::Client::new(); - let payload = json!({"key": "value"}); - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); - let result = request - .send( - &client, - &format!("{}/mqtt_auth_url", mockito_server.url()), - "Bearer test_token", - &payload, - ) - .await; - - assert!(result.is_err()); - if let Err(DshError::DshCallError { - url, - status_code, - error_body, - }) = result - { - assert_eq!(url, format!("{}/mqtt_auth_url", mockito_server.url())); - assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); - assert_eq!(error_body, "error_response"); - } else { - panic!("Expected DshCallError"); - } - } - - #[test] - fn test_claims_new() { - let resource = Resource::new( - "stream".to_string(), - "prefix".to_string(), - "topic".to_string(), - None, - ); - let action = "publish".to_string(); - - let claims = Claims::new(resource.clone(), action.clone()); - - assert_eq!(claims.resource.stream, "stream"); - assert_eq!(claims.action, "publish"); - } - - #[test] - fn test_resource_new() { - let resource = Resource::new( - "stream".to_string(), - "prefix".to_string(), - "topic".to_string(), - None, - ); - - assert_eq!(resource.stream, "stream"); - assert_eq!(resource.prefix, "prefix"); - assert_eq!(resource.topic, "topic"); - } - - #[test] - fn test_mqtt_token_is_valid() { - let raw_token = "valid.token.payload".to_string(); - let token = MqttToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600, - raw_token, - }; - - assert!(token.is_valid()); - } - #[test] - fn test_mqtt_token_is_invalid() { - let 
raw_token = "valid.token.payload".to_string(); - let token = MqttToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32, - raw_token, - }; - - assert!(!token.is_valid()); - } - - #[test] - fn test_rest_token_is_valid() { - let token = RestToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600, - raw_token: "valid.token.payload".to_string(), - }; - - assert!(token.is_valid()); - } - - #[test] - fn test_rest_token_is_invalid() { - let token = RestToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32, - raw_token: "valid.token.payload".to_string(), - }; - - assert!(!token.is_valid()); - } - - #[test] - fn test_rest_token_default_is_invalid() { - let token = RestToken::default(); - - assert!(!token.is_valid()); - } - - #[test] - fn test_extract_header_and_payload() { - let raw = "header.payload.signature"; - let result = extract_header_and_payload(raw).unwrap(); - assert_eq!(result, "payload"); - - let raw = "header.payload"; - let result = extract_header_and_payload(raw).unwrap(); - assert_eq!(result, "payload"); - - let raw = "header"; - let result = extract_header_and_payload(raw); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_fetch_token_success() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/auth_url") - .match_header("apikey", "test_api_key") - .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) - .with_status(200) - .with_body("test_token") - .create(); - - let client = reqwest::Client::new(); - let result = RestToken::fetch_token( - &client, - "test_tenant", - "test_api_key", - &format!("{}/auth_url", mockito_server.url()), - ) - .await; - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "test_token"); - } - - #[tokio::test] - async fn test_fetch_token_failure() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/auth_url") - .match_header("apikey", "test_api_key") - .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) - .with_status(400) - .with_body("error_response") - .create(); - - let client = reqwest::Client::new(); - let result = RestToken::fetch_token( - &client, - "test_tenant", - "test_api_key", - &format!("{}/auth_url", mockito_server.url()), - ) - .await; - - assert!(result.is_err()); - if let Err(DshError::DshCallError { - url, - status_code, - error_body, - }) = result - { - assert_eq!(url, format!("{}/auth_url", mockito_server.url())); - assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); - assert_eq!(error_body, "error_response"); - } else { - panic!("Expected DshCallError"); - } + ) -> ProtocolTokenFetcher { + ProtocolTokenFetcher::new_with_client(tenant_name, api_key, platform, client) } } diff --git a/dsh_sdk/src/protocol_adapters/http_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/http_protocol/mod.rs new file mode 100644 index 0000000..1fc864a --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/http_protocol/mod.rs @@ -0,0 +1,3 @@ +//! DSH Http Protocol client +//! +//! To be created diff --git a/dsh_sdk/src/dsh/config.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs similarity index 76% rename from dsh_sdk/src/dsh/config.rs rename to dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs index 287b480..a258804 100644 --- a/dsh_sdk/src/dsh/config.rs +++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs @@ -1,34 +1,33 @@ -//! 
Additional optional configuration for kafka producer and consumer
+//! Kafka configuration
+//!
+//! This module contains the configuration for the Kafka protocol adapter.
+
 use crate::utils::get_env_var;
 use crate::*;
+use std::sync::OnceLock;
+
+static KAFKA_CONFIG: OnceLock<KafkaConfig> = OnceLock::new();
 
-/// Additional configuration for Consumer config
+/// Kafka config
 ///
 /// ## Environment variables
 /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
 /// on configuring the consumer via environment variables.
 #[derive(Debug, Clone)]
-pub struct ConsumerConfig {
+pub struct KafkaConfig {
+    // Consumer specific config
     enable_auto_commit: bool,
     auto_offset_reset: String,
     session_timeout: Option<i32>,
     queued_buffering_max_messages_kbytes: Option<i32>,
-}
-
-/// Additional configuration for Producer config
-///
-/// ## Environment variables
-/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
-/// configuring the producer via environment variables.
-#[derive(Debug, Clone, Default)]
-pub struct ProducerConfig {
+    // Producer specific config
     batch_num_messages: Option<i32>,
     queue_buffering_max_messages: Option<i32>,
     queue_buffering_max_kbytes: Option<i32>,
     queue_buffering_max_ms: Option<i32>,
 }
 
-impl ConsumerConfig {
+impl KafkaConfig {
     pub fn new() -> Self {
         let enable_auto_commit = get_env_var(VAR_KAFKA_ENABLE_AUTO_COMMIT)
             .ok()
@@ -43,40 +42,6 @@ impl ConsumerConfig {
         get_env_var(VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES)
             .ok()
             .and_then(|v| v.parse().ok());
-        ConsumerConfig {
-            enable_auto_commit,
-            auto_offset_reset,
-            session_timeout,
-            queued_buffering_max_messages_kbytes,
-        }
-    }
-    pub fn enable_auto_commit(&self) -> bool {
-        self.enable_auto_commit
-    }
-    pub fn auto_offset_reset(&self) -> String {
-        self.auto_offset_reset.clone()
-    }
-    pub fn session_timeout(&self) -> Option<i32> {
-        self.session_timeout
-    }
-    pub fn queued_buffering_max_messages_kbytes(&self) -> Option<i32> {
-        self.queued_buffering_max_messages_kbytes
-    }
-}
-
-impl Default for ConsumerConfig {
-    fn default() -> Self {
-        ConsumerConfig {
-            enable_auto_commit: false,
-            auto_offset_reset: "earliest".to_string(),
-            session_timeout: None,
-            queued_buffering_max_messages_kbytes: None,
-        }
-    }
-}
-
-impl ProducerConfig {
-    pub fn new() -> Self {
         let batch_num_messages = get_env_var(VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES)
             .ok()
             .and_then(|v| v.parse().ok());
@@ -90,14 +55,33 @@ impl ProducerConfig {
         let queue_buffering_max_ms = get_env_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS)
             .ok()
             .and_then(|v| v.parse().ok());
-        ProducerConfig {
+        Self {
+            enable_auto_commit,
+            auto_offset_reset,
+            session_timeout,
+            queued_buffering_max_messages_kbytes,
            batch_num_messages,
            queue_buffering_max_messages,
            queue_buffering_max_kbytes,
            queue_buffering_max_ms,
        }
    }
-
+    // TODO: Check, does this make sense?
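+    /// Get a shared `KafkaConfig`, lazily initialized from the environment on first use.
+    ///
+    /// A minimal sketch of the intended (crate-internal) use, since this module is
+    /// currently `pub(crate)`; the getters below are the ones defined in this module:
+    /// ```ignore
+    /// let config = KafkaConfig::get();
+    /// if config.enable_auto_commit() {
+    ///     // wire auto-commit into the rdkafka ClientConfig
+    /// }
+    /// ```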
+    pub fn get() -> &'static KafkaConfig {
+        KAFKA_CONFIG.get_or_init(KafkaConfig::new)
+    }
+    pub fn enable_auto_commit(&self) -> bool {
+        self.enable_auto_commit
+    }
+    pub fn auto_offset_reset(&self) -> String {
+        self.auto_offset_reset.clone()
+    }
+    pub fn session_timeout(&self) -> Option<i32> {
+        self.session_timeout
+    }
+    pub fn queued_buffering_max_messages_kbytes(&self) -> Option<i32> {
+        self.queued_buffering_max_messages_kbytes
+    }
     pub fn batch_num_messages(&self) -> Option<i32> {
         self.batch_num_messages
     }
@@ -112,6 +96,21 @@ impl ProducerConfig {
     }
 }
 
+impl Default for KafkaConfig {
+    fn default() -> Self {
+        Self {
+            enable_auto_commit: false,
+            auto_offset_reset: "earliest".to_string(),
+            session_timeout: None,
+            queued_buffering_max_messages_kbytes: None,
+            batch_num_messages: None,
+            queue_buffering_max_messages: None,
+            queue_buffering_max_kbytes: None,
+            queue_buffering_max_ms: None,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -119,26 +118,36 @@ mod tests {
     use std::env;
 
     #[test]
-    fn test_consumer_config() {
-        let consumer_config = ConsumerConfig::new();
+    #[serial(env_dependency)]
+    fn test_kafka_config() {
+        let consumer_config = KafkaConfig::new();
         assert_eq!(consumer_config.enable_auto_commit(), false);
         assert_eq!(consumer_config.auto_offset_reset(), "earliest");
         assert_eq!(consumer_config.session_timeout(), None);
         assert_eq!(consumer_config.queued_buffering_max_messages_kbytes(), None);
+        assert_eq!(consumer_config.batch_num_messages(), None);
+        assert_eq!(consumer_config.queue_buffering_max_messages(), None);
+        assert_eq!(consumer_config.queue_buffering_max_kbytes(), None);
+        assert_eq!(consumer_config.queue_buffering_max_ms(), None);
     }
 
     #[test]
-    fn test_consumer_config_default() {
-        let consumer_config = ConsumerConfig::default();
+    #[serial(env_dependency)]
+    fn test_kafka_config_default() {
+        let consumer_config = KafkaConfig::default();
         assert_eq!(consumer_config.enable_auto_commit(), false);
         assert_eq!(consumer_config.auto_offset_reset(), "earliest");
         assert_eq!(consumer_config.session_timeout(), None);
         assert_eq!(consumer_config.queued_buffering_max_messages_kbytes(), None);
+        assert_eq!(consumer_config.batch_num_messages(), None);
+        assert_eq!(consumer_config.queue_buffering_max_messages(), None);
+        assert_eq!(consumer_config.queue_buffering_max_kbytes(), None);
+        assert_eq!(consumer_config.queue_buffering_max_ms(), None);
     }
 
     #[test]
     #[serial(env_dependency)]
-    fn test_consumer_config_env() {
+    fn test_consumer_kafka_config_env() {
         env::set_var(VAR_KAFKA_ENABLE_AUTO_COMMIT, "true");
         env::set_var(VAR_KAFKA_AUTO_OFFSET_RESET, "latest");
         env::set_var(VAR_KAFKA_CONSUMER_SESSION_TIMEOUT_MS, "1000");
         env::set_var(
@@ -146,7 +155,7 @@ mod tests {
             VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES,
             "1000",
         );
-        let consumer_config = ConsumerConfig::new();
+        let consumer_config = KafkaConfig::new();
         assert_eq!(consumer_config.enable_auto_commit(), true);
         assert_eq!(consumer_config.auto_offset_reset(), "latest");
         assert_eq!(consumer_config.session_timeout(), Some(1000));
@@ -160,32 +169,14 @@ mod tests {
         env::remove_var(VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES);
     }
 
-    #[test]
-    fn test_producer_config() {
-        let producer_config = ProducerConfig::new();
-        assert_eq!(producer_config.batch_num_messages(), None);
-        assert_eq!(producer_config.queue_buffering_max_messages(), None);
-        assert_eq!(producer_config.queue_buffering_max_kbytes(), None);
-        assert_eq!(producer_config.queue_buffering_max_ms(), None);
-    }
-
-    #[test]
-    fn test_producer_config_default() {
-        let producer_config 
= ProducerConfig::default();
-        assert_eq!(producer_config.batch_num_messages(), None);
-        assert_eq!(producer_config.queue_buffering_max_messages(), None);
-        assert_eq!(producer_config.queue_buffering_max_kbytes(), None);
-        assert_eq!(producer_config.queue_buffering_max_ms(), None);
-    }
-
     #[test]
     #[serial(env_dependency)]
-    fn test_producer_config_env() {
+    fn test_producer_kafka_config_env() {
         env::set_var(VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES, "1000");
         env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES, "1000");
         env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES, "1000");
         env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS, "1000");
-        let producer_config = ProducerConfig::new();
+        let producer_config = KafkaConfig::new();
         assert_eq!(producer_config.batch_num_messages(), Some(1000));
         assert_eq!(producer_config.queue_buffering_max_messages(), Some(1000));
         assert_eq!(producer_config.queue_buffering_max_kbytes(), Some(1000));
diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs
new file mode 100644
index 0000000..3c4b952
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs
@@ -0,0 +1,12 @@
+pub(crate) mod config; // TODO: should we make this public? What benefits would that bring?
+
+pub trait DshKafkaConfig {
+    /// Set all required configurations to consume messages from DSH.
+    fn dsh_consumer_config(&mut self) -> &mut Self;
+    /// Set all required configurations to produce messages to DSH.
+    fn dsh_producer_config(&mut self) -> &mut Self;
+    /// Set a DSH compatible group id.
+    ///
+    /// DSH requires a group id with the prefix of the tenant name.
+    fn set_group_id(&mut self, group_id: &str) -> &mut Self;
+}
diff --git a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs
new file mode 100644
index 0000000..e3dc107
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/mod.rs
@@ -0,0 +1,11 @@
+#[cfg(feature = "http-protocol-adapter")]
+pub mod http_protocol;
+pub mod kafka_protocol;
+#[cfg(feature = "mqtt-protocol-adapter")]
+pub mod mqtt_protocol;
+#[cfg(feature = "protocol-token-fetcher")]
+pub mod token_fetcher;
+
+#[cfg(feature = "protocol-token-fetcher")]
+#[doc(inline)]
+pub use token_fetcher::ProtocolTokenFetcher;
diff --git a/dsh_sdk/src/protocol_adapters/mqtt_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/mqtt_protocol/mod.rs
new file mode 100644
index 0000000..21e885b
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/mqtt_protocol/mod.rs
@@ -0,0 +1,3 @@
+//! DSH Mqtt Protocol client
+//!
+//! To be created
diff --git a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
new file mode 100644
index 0000000..a029cf0
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
@@ -0,0 +1,860 @@
+//! # Protocol Token Fetcher
+//!
+//! `ProtocolTokenFetcher` is responsible for fetching and managing the protocol tokens for DSH.
+use std::collections::{hash_map::Entry, HashMap};
+use std::fmt::{Display, Formatter};
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use sha2::{Digest, Sha256};
+use tokio::sync::RwLock;
+
+use crate::{error::DshError, Platform};
+
+/// `ProtocolTokenFetcher` is responsible for fetching and managing tokens for the DSH MQTT and HTTP protocol adapters.
+///
+/// It ensures that the tokens are valid, and if not, it refreshes them automatically. The struct
+/// is thread-safe and can be shared across multiple threads.
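+///
+/// A minimal sharing sketch (assuming a tokio runtime; wrapping the fetcher in an `Arc`
+/// is one way to hand it to multiple tasks):
+/// ```no_run
+/// # use std::sync::Arc;
+/// # use dsh_sdk::protocol_adapters::ProtocolTokenFetcher;
+/// # use dsh_sdk::Platform;
+/// # async fn example() {
+/// let fetcher = Arc::new(ProtocolTokenFetcher::new(
+///     "my-tenant".to_string(),
+///     "my-api-key".to_string(),
+///     Platform::NpLz,
+/// ));
+/// let fetcher_clone = Arc::clone(&fetcher);
+/// tokio::spawn(async move {
+///     let _token = fetcher_clone.get_token("client-a", None).await;
+/// });
+/// # }
+/// ```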
+
+pub struct ProtocolTokenFetcher {
+    tenant_name: String,
+    rest_api_key: String,
+    rest_token: RwLock<RestToken>,
+    rest_auth_url: String,
+    mqtt_token: RwLock<HashMap<String, MqttToken>>, // Mapping from Client ID to MqttToken
+    mqtt_auth_url: String,
+    client: reqwest::Client,
+    //token_lifetime: Option<Duration>, // TODO: Implement option of passing token lifetime to request a token for a specific duration
+    // port: Port or connection_type: Connection // TODO: Platform provides two connection options, current implementation only provides connecting over SSL, enable WebSocket too
+}
+
+impl ProtocolTokenFetcher {
+    /// Constructs a new `ProtocolTokenFetcher`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tenant_name` - The tenant name of DSH.
+    /// * `api_key` - The related API key of the tenant, used for authentication when fetching the MQTT token.
+    /// * `platform` - The target DSH platform environment.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use dsh_sdk::protocol_adapters::ProtocolTokenFetcher;
+    /// use dsh_sdk::Platform;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let tenant_name = "test_tenant".to_string();
+    /// let api_key = "aAbB123".to_string();
+    /// let platform = Platform::NpLz;
+    ///
+    /// let fetcher = ProtocolTokenFetcher::new(tenant_name, api_key, platform);
+    /// let token = fetcher.get_token("test_client", None).await.unwrap();
+    /// # }
+    /// ```
+    pub fn new(tenant_name: String, api_key: String, platform: Platform) -> Self {
+        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
+
+        let reqwest_client = reqwest::Client::builder()
+            .timeout(DEFAULT_TIMEOUT)
+            .http1_only()
+            .build()
+            .expect("Failed to build reqwest client");
+        Self::new_with_client(tenant_name, api_key, platform, reqwest_client)
+    }
+
+    /// Constructs a new `ProtocolTokenFetcher` with a custom reqwest client.
+    /// On this reqwest client, you can set custom timeouts, headers, Rustls etc.
+    ///
+    /// # Arguments
+    ///
+    /// * `tenant_name` - The tenant name of DSH.
+    /// * `api_key` - The related API key of the tenant, used for authentication when fetching the MQTT token.
+    /// * `platform` - The target DSH platform environment.
+    /// * `client` - User configured reqwest client to be used for fetching tokens
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use dsh_sdk::protocol_adapters::ProtocolTokenFetcher;
+    /// use dsh_sdk::Platform;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let tenant_name = "test_tenant".to_string();
+    /// let api_key = "aAbB123".to_string();
+    /// let platform = Platform::NpLz;
+    /// let client = reqwest::Client::new();
+    /// let fetcher = ProtocolTokenFetcher::new_with_client(tenant_name, api_key, platform, client);
+    /// let token = fetcher.get_token("test_client", None).await.unwrap();
+    /// # }
+    /// ```
+    pub fn new_with_client(
+        tenant_name: String,
+        api_key: String,
+        platform: Platform,
+        client: reqwest::Client,
+    ) -> Self {
+        let rest_token = RestToken::default();
+        Self {
+            tenant_name,
+            rest_api_key: api_key,
+            rest_token: RwLock::new(rest_token),
+            rest_auth_url: platform.endpoint_rest_token().to_string(),
+            mqtt_token: RwLock::new(HashMap::new()),
+            mqtt_auth_url: platform.endpoint_mqtt_token().to_string(),
+            client,
+        }
+    }
+    /// Retrieves an MQTT token for the specified client ID.
+    ///
+    /// If the token is expired or does not exist, it fetches a new token.
+    ///
+    /// # Arguments
+    ///
+    /// * `client_id` - The identifier for the MQTT client.
+    /// * `claims` - Optional claims for the MQTT token.
+    ///
+    /// # Returns
+    ///
+    /// Returns a `Result` containing the `MqttToken` or a `DshError`.
+    pub async fn get_token(
+        &self,
+        client_id: &str,
+        claims: Option<Vec<Claims>>,
+    ) -> Result<MqttToken, DshError> {
+        match self.mqtt_token.write().await.entry(client_id.to_string()) {
+            Entry::Occupied(mut entry) => {
+                let mqtt_token = entry.get_mut();
+                if !mqtt_token.is_valid() {
+                    *mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?;
+                };
+                Ok(mqtt_token.clone())
+            }
+            Entry::Vacant(entry) => {
+                let mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?;
+                entry.insert(mqtt_token.clone());
+                Ok(mqtt_token)
+            }
+        }
+    }
+
+    /// Fetches a new MQTT token from the platform.
+    ///
+    /// This method refreshes the internal REST token when needed, then requests a new MQTT token.
+    async fn fetch_new_mqtt_token(
+        &self,
+        client_id: &str,
+        claims: Option<Vec<Claims>>,
+    ) -> Result<MqttToken, DshError> {
+        let mut rest_token = self.rest_token.write().await;
+
+        if !rest_token.is_valid() {
+            *rest_token = RestToken::get(
+                &self.client,
+                &self.tenant_name,
+                &self.rest_api_key,
+                &self.rest_auth_url,
+            )
+            .await?
+        }
+
+        let authorization_header = format!("Bearer {}", rest_token.raw_token);
+
+        let mqtt_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?;
+        let payload = serde_json::to_value(&mqtt_token_request)?;
+
+        let response = mqtt_token_request
+            .send(
+                &self.client,
+                &self.mqtt_auth_url,
+                &authorization_header,
+                &payload,
+            )
+            .await?;
+
+        MqttToken::new(response)
+    }
+}
+
+/// Represents claims information for an MQTT token request.
+/// * `action` - can be subscribe or publish
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Claims {
+    resource: Resource,
+    action: String,
+}
+
+impl Claims {
+    pub fn new(resource: Resource, action: String) -> Claims {
+        Claims { resource, action }
+    }
+}
+
+/// Enumeration representing possible actions in MQTT claims.
+pub enum Actions {
+    Publish,
+    Subscribe,
+}
+
+impl Display for Actions {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        match self {
+            Actions::Publish => write!(f, "Publish"),
+            Actions::Subscribe => write!(f, "Subscribe"),
+        }
+    }
+}
+
+/// Represents a resource in the MQTT claim.
+///
+/// The resource defines what the client can access in terms of stream, prefix, topic, and type.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Resource {
+    stream: String,
+    prefix: String,
+    topic: String,
+    #[serde(rename = "type")]
+    type_: Option<String>,
+}
+
+impl Resource {
+    /// Creates a new `Resource` instance. Please check the DSH MQTT documentation for further explanation of the fields.
+    ///
+    /// # Arguments
+    ///
+    /// * `stream` - The data stream name.
+    /// * `prefix` - The prefix of the topic.
+    /// * `topic` - The topic name.
+    /// * `type_` - The optional type of the resource.
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `Resource` instance.
+    pub fn new(stream: String, prefix: String, topic: String, type_: Option<String>) -> Resource {
+        Resource {
+            stream,
+            prefix,
+            topic,
+            type_,
+        }
+    }
+}
+
+#[derive(Serialize)]
+struct MqttTokenRequest {
+    id: String,
+    tenant: String,
+    claims: Option<Vec<Claims>>,
+}
+
+impl MqttTokenRequest {
+    fn new(
+        client_id: &str,
+        tenant: &str,
+        claims: Option<Vec<Claims>>,
+    ) -> Result<Self, DshError> {
+        let mut hasher = Sha256::new();
+        hasher.update(client_id);
+        let result = hasher.finalize();
+        let id = format!("{:x}", result);
+
+        Ok(Self {
+            id,
+            tenant: tenant.to_string(),
+            claims,
+        })
+    }
+
+    async fn send(
+        &self,
+        reqwest_client: &reqwest::Client,
+        mqtt_auth_url: &str,
+        authorization_header: &str,
+        payload: &serde_json::Value,
+    ) -> Result<String, DshError> {
+        let response = reqwest_client
+            .post(mqtt_auth_url)
+            .header("Authorization", authorization_header)
+            .json(payload)
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            Ok(response.text().await?)
+        } else {
+            Err(DshError::DshCallError {
+                url: mqtt_auth_url.to_string(),
+                status_code: response.status(),
+                error_body: response.text().await?,
+            })
+        }
+    }
+}
+
+/// Represents attributes associated with an MQTT token.
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "kebab-case")]
+struct MqttTokenAttributes {
+    gen: i32,
+    endpoint: String,
+    iss: String,
+    claims: Option<Vec<Claims>>,
+    exp: i32,
+    client_id: String,
+    iat: i32,
+    tenant_id: String,
+}
+
+/// Represents a token used for MQTT connections.
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct MqttToken {
+    exp: i32,
+    raw_token: String,
+}
+
+impl MqttToken {
+    /// Creates a new instance of `MqttToken` from a raw token string.
+    ///
+    /// # Arguments
+    ///
+    /// * `raw_token` - The raw token string.
+    ///
+    /// # Returns
+    ///
+    /// A Result containing the created MqttToken or an error.
+    pub fn new(raw_token: String) -> Result<MqttToken, DshError> {
+        let header_payload = extract_header_and_payload(&raw_token)?;
+
+        let decoded_token = decode_base64(header_payload)?;
+
+        let token_attributes: MqttTokenAttributes = serde_json::from_slice(&decoded_token)?;
+        let token = MqttToken {
+            exp: token_attributes.exp,
+            raw_token,
+        };
+
+        Ok(token)
+    }
+
+    /// Checks if the MQTT token is still valid.
+    fn is_valid(&self) -> bool {
+        let current_unixtime = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("SystemTime before UNIX EPOCH!")
+            .as_secs() as i32;
+        self.exp >= current_unixtime + 5
+    }
+}
+
+/// Represents attributes associated with a REST token.
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "kebab-case")]
+struct RestTokenAttributes {
+    gen: i64,
+    endpoint: String,
+    iss: String,
+    claims: RestClaims,
+    exp: i32,
+    tenant_id: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+struct RestClaims {
+    #[serde(rename = "datastreams/v0/mqtt/token")]
+    datastreams_token: DatastreamsData,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+struct DatastreamsData {}
+
+/// Represents a REST token with its raw value and attributes.
+#[derive(Serialize, Deserialize, Debug)]
+struct RestToken {
+    raw_token: String,
+    exp: i32,
+}
+
+impl RestToken {
+    /// Retrieves a new REST token from the platform.
+    ///
+    /// # Arguments
+    ///
+    /// * `client` - The reqwest client used for the request.
+    /// * `tenant` - The tenant name associated with the DSH platform.
+    /// * `api_key` - The REST API key used for authentication.
+    /// * `auth_url` - The authentication endpoint to fetch the token from.
+    ///
+    /// # Returns
+    ///
+    /// A Result containing the created `RestToken` or a `DshError`.
+    async fn get(
+        client: &reqwest::Client,
+        tenant: &str,
+        api_key: &str,
+        auth_url: &str,
+    ) -> Result<RestToken, DshError> {
+        let raw_token = Self::fetch_token(client, tenant, api_key, auth_url).await?;
+
+        let header_payload = extract_header_and_payload(&raw_token)?;
+
+        let decoded_token = decode_base64(header_payload)?;
+
+        let token_attributes: RestTokenAttributes = serde_json::from_slice(&decoded_token)?;
+        let token = RestToken {
+            raw_token,
+            exp: token_attributes.exp,
+        };
+
+        Ok(token)
+    }
+
+    // Checks if the REST token is still valid.
+    fn is_valid(&self) -> bool {
+        let current_unixtime = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("SystemTime before UNIX EPOCH!")
+            .as_secs() as i32;
+        self.exp >= current_unixtime + 5
+    }
+
+    async fn fetch_token(
+        client: &reqwest::Client,
+        tenant: &str,
+        api_key: &str,
+        auth_url: &str,
+    ) -> Result<String, DshError> {
+        let json_body = json!({"tenant": tenant});
+
+        let response = client
+            .post(auth_url)
+            .header("apikey", api_key)
+            .json(&json_body)
+            .send()
+            .await?;
+
+        let status = response.status();
+        let body_text = response.text().await?;
+        match status {
+            reqwest::StatusCode::OK => Ok(body_text),
+            _ => Err(DshError::DshCallError {
+                url: auth_url.to_string(),
+                status_code: status,
+                error_body: body_text,
+            }),
+        }
+    }
+}
+
+impl Default for RestToken {
+    fn default() -> Self {
+        Self {
+            raw_token: "".to_string(),
+            exp: 0,
+        }
+    }
+}
+
+/// Extracts the payload part of a JWT token (the segment between the first and second dot).
+///
+/// # Arguments
+///
+/// * `raw_token` - The raw JWT token string.
+///
+/// # Returns
+///
+/// A Result containing the payload part of the JWT token or a `DshError`.
+fn extract_header_and_payload(raw_token: &str) -> Result<&str, DshError> {
+    let parts: Vec<&str> = raw_token.split('.').collect();
+    parts
+        .get(1)
+        .copied()
+        .ok_or_else(|| DshError::ParseDnError("Header and payload are missing".to_string()))
+}
+
+/// Decodes a Base64-encoded string.
+///
+/// # Arguments
+///
+/// * `payload` - The Base64-encoded string.
+///
+/// # Returns
+///
+/// A Result containing the decoded byte vector or a `DshError`.
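+///
+/// A small illustration of the decoding step (a sketch; this is a private helper):
+/// ```ignore
+/// // "eyJleHAiOjB9" is unpadded standard base64 for the JSON `{"exp":0}`
+/// let bytes = decode_base64("eyJleHAiOjB9")?;
+/// assert_eq!(bytes, br#"{"exp":0}"#);
+/// ```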
+fn decode_base64(payload: &str) -> Result<Vec<u8>, DshError> {
+    use base64::{alphabet, engine, read};
+    use std::io::Read;
+
+    let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::NO_PAD);
+    let mut decoder = read::DecoderReader::new(payload.as_bytes(), &engine);
+
+    let mut decoded_token = Vec::new();
+    decoder
+        .read_to_end(&mut decoded_token)
+        .map_err(DshError::IoError)?;
+
+    Ok(decoded_token)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use mockito::Matcher;
+
+    async fn create_valid_fetcher() -> ProtocolTokenFetcher {
+        let exp_time = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_secs() as i32
+            + 3600;
+        println!("exp_time: {}", exp_time);
+        let rest_token: RestToken = RestToken {
+            exp: exp_time,
+            raw_token: "valid.token.payload".to_string(),
+        };
+        let mqtt_token = MqttToken {
+            exp: exp_time,
+            raw_token: "valid.token.payload".to_string(),
+        };
+        let mqtt_token_map = RwLock::new(HashMap::new());
+        mqtt_token_map
+            .write()
+            .await
+            .insert("test_client".to_string(), mqtt_token.clone());
+        ProtocolTokenFetcher {
+            tenant_name: "test_tenant".to_string(),
+            rest_api_key: "test_api_key".to_string(),
+            rest_token: RwLock::new(rest_token),
+            rest_auth_url: "test_auth_url".to_string(),
+            mqtt_token: mqtt_token_map,
+            client: reqwest::Client::new(),
+            mqtt_auth_url: "test_auth_url".to_string(),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_mqtt_token_fetcher_new() {
+        let tenant_name = "test_tenant".to_string();
+        let rest_api_key = "test_api_key".to_string();
+        let platform = Platform::NpLz;
+
+        let fetcher = ProtocolTokenFetcher::new(tenant_name, rest_api_key, platform);
+
+        assert!(fetcher.mqtt_token.read().await.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_mqtt_token_fetcher_new_with_client() {
+        let tenant_name = "test_tenant".to_string();
+        let rest_api_key = "test_api_key".to_string();
+        let platform = Platform::NpLz;
+
+        let client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
+        let fetcher =
+            ProtocolTokenFetcher::new_with_client(tenant_name, rest_api_key, platform, client);
+
+        assert!(fetcher.mqtt_token.read().await.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_fetch_new_mqtt_token() {
+        let mut mockito_server = mockito::Server::new_async().await;
+        let _m = mockito_server.mock("POST", "/rest_auth_url")
+            .with_status(200)
+            .with_body(r#"{"raw_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"}"#)
+            .create_async()
+            .await;
+        let _m2 = mockito_server.mock("POST", "/mqtt_auth_url")
+            .with_status(200)
+            .with_body(r#"{"mqtt_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#)
+            .create();
+
+        let client = reqwest::Client::new();
+        let rest_token = RestToken {
+            raw_token: "initial_token".to_string(),
+            exp: 0,
+        };
+
+        let fetcher = ProtocolTokenFetcher {
+            client,
+            tenant_name: "test_tenant".to_string(),
+            rest_api_key: "test_api_key".to_string(),
+            mqtt_token: RwLock::new(HashMap::new()),
+            rest_auth_url: mockito_server.url() + "/rest_auth_url",
+            mqtt_auth_url: mockito_server.url() + "/mqtt_auth_url",
+            rest_token: 
RwLock::new(rest_token), + }; + + let result = fetcher.fetch_new_mqtt_token("test_client_id", None).await; + println!("{:?}", result); + assert!(result.is_ok()); + let mqtt_token = result.unwrap(); + assert_eq!(mqtt_token.exp, 1); + } + + #[tokio::test] + async fn test_mqtt_token_fetcher_get_token() { + let fetcher = create_valid_fetcher().await; + let token = fetcher.get_token("test_client", None).await.unwrap(); + assert_eq!(token.raw_token, "valid.token.payload"); + } + + #[test] + fn test_actions_display() { + let action = Actions::Publish; + assert_eq!(action.to_string(), "Publish"); + let action = Actions::Subscribe; + assert_eq!(action.to_string(), "Subscribe"); + } + + #[test] + fn test_token_request_new() { + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + assert_eq!(request.id.len(), 64); + assert_eq!(request.tenant, "test_tenant"); + } + + #[tokio::test] + async fn test_send_success() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/mqtt_auth_url") + .match_header("Authorization", "Bearer test_token") + .match_body(Matcher::Json(json!({"key": "value"}))) + .with_status(200) + .with_body("success_response") + .create(); + + let client = reqwest::Client::new(); + let payload = json!({"key": "value"}); + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let result = request + .send( + &client, + &format!("{}/mqtt_auth_url", mockito_server.url()), + "Bearer test_token", + &payload, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "success_response"); + } + + #[tokio::test] + async fn test_send_failure() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/mqtt_auth_url") + .match_header("Authorization", "Bearer test_token") + .match_body(Matcher::Json(json!({"key": "value"}))) + .with_status(400) + .with_body("error_response") + .create(); + + let client = reqwest::Client::new(); + let payload = json!({"key": "value"}); + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let result = request + .send( + &client, + &format!("{}/mqtt_auth_url", mockito_server.url()), + "Bearer test_token", + &payload, + ) + .await; + + assert!(result.is_err()); + if let Err(DshError::DshCallError { + url, + status_code, + error_body, + }) = result + { + assert_eq!(url, format!("{}/mqtt_auth_url", mockito_server.url())); + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body, "error_response"); + } else { + panic!("Expected DshCallError"); + } + } + + #[test] + fn test_claims_new() { + let resource = Resource::new( + "stream".to_string(), + "prefix".to_string(), + "topic".to_string(), + None, + ); + let action = "publish".to_string(); + + let claims = Claims::new(resource.clone(), action.clone()); + + assert_eq!(claims.resource.stream, "stream"); + assert_eq!(claims.action, "publish"); + } + + #[test] + fn test_resource_new() { + let resource = Resource::new( + "stream".to_string(), + "prefix".to_string(), + "topic".to_string(), + None, + ); + + assert_eq!(resource.stream, "stream"); + assert_eq!(resource.prefix, "prefix"); + assert_eq!(resource.topic, "topic"); + } + + #[test] + fn test_mqtt_token_is_valid() { + let raw_token = "valid.token.payload".to_string(); + let token = MqttToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32 + + 3600, + raw_token, + }; + + assert!(token.is_valid()); + 
} + #[test] + fn test_mqtt_token_is_invalid() { + let raw_token = "valid.token.payload".to_string(); + let token = MqttToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32, + raw_token, + }; + + assert!(!token.is_valid()); + } + + #[test] + fn test_rest_token_is_valid() { + let token = RestToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32 + + 3600, + raw_token: "valid.token.payload".to_string(), + }; + + assert!(token.is_valid()); + } + + #[test] + fn test_rest_token_is_invalid() { + let token = RestToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32, + raw_token: "valid.token.payload".to_string(), + }; + + assert!(!token.is_valid()); + } + + #[test] + fn test_rest_token_default_is_invalid() { + let token = RestToken::default(); + + assert!(!token.is_valid()); + } + + #[test] + fn test_extract_header_and_payload() { + let raw = "header.payload.signature"; + let result = extract_header_and_payload(raw).unwrap(); + assert_eq!(result, "payload"); + + let raw = "header.payload"; + let result = extract_header_and_payload(raw).unwrap(); + assert_eq!(result, "payload"); + + let raw = "header"; + let result = extract_header_and_payload(raw); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_fetch_token_success() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/auth_url") + .match_header("apikey", "test_api_key") + .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) + .with_status(200) + .with_body("test_token") + .create(); + + let client = reqwest::Client::new(); + let result = RestToken::fetch_token( + &client, + "test_tenant", + "test_api_key", + &format!("{}/auth_url", mockito_server.url()), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test_token"); + } + + #[tokio::test] + async fn test_fetch_token_failure() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/auth_url") + .match_header("apikey", "test_api_key") + .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) + .with_status(400) + .with_body("error_response") + .create(); + + let client = reqwest::Client::new(); + let result = RestToken::fetch_token( + &client, + "test_tenant", + "test_api_key", + &format!("{}/auth_url", mockito_server.url()), + ) + .await; + + assert!(result.is_err()); + if let Err(DshError::DshCallError { + url, + status_code, + error_body, + }) = result + { + assert_eq!(url, format!("{}/auth_url", mockito_server.url())); + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body, "error_response"); + } else { + panic!("Expected DshCallError"); + } + } +} diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs index 172fcee..8b06635 100644 --- a/dsh_sdk/src/rest_api_token_fetcher.rs +++ b/dsh_sdk/src/rest_api_token_fetcher.rs @@ -1,147 +1,17 @@ -//! Module for fetching and storing access tokens for the DSH Rest API client -//! -//! This module is meant to be used together with the [dsh_rest_api_client]. -//! -//! The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. -//! -//! ## Example -//! Recommended usage is to use the [RestTokenFetcherBuilder] to create a new instance of the token fetcher. -//! However, you can also create a new instance of the token fetcher directly. -//! ```no_run -//! 
use dsh_sdk::{RestTokenFetcherBuilder, Platform}; -//! use dsh_rest_api_client::Client; -//! -//! const CLIENT_SECRET: &str = ""; -//! const TENANT: &str = "tenant-name"; -//! -//! #[tokio::main] -//! async fn main() { -//! let platform = Platform::NpLz; -//! let client = Client::new(platform.endpoint_rest_api()); -//! -//! let tf = RestTokenFetcherBuilder::new(platform) -//! .tenant_name(TENANT.to_string()) -//! .client_secret(CLIENT_SECRET.to_string()) -//! .build() -//! .unwrap(); -//! -//! let response = client -//! .topic_get_by_tenant_topic(TENANT, &tf.get_token().await.unwrap()) -//! .await; -//! println!("Available topics: {:#?}", response); -//! } -//! ``` - -use std::fmt::Debug; -use std::ops::Add; -use std::sync::Mutex; -use std::time::{Duration, Instant}; - -use log::debug; -use serde::Deserialize; - -use crate::error::DshRestTokenError; -use crate::utils::Platform; - -/// Access token of the authentication serveice of DSH. -/// -/// This is the response whem requesting for a new access token. -/// -/// ## Recommended usage -/// Use the [RestTokenFetcher::get_token] to get the bearer token, the `TokenFetcher` will automatically fetch a new token if the current token is not valid. -#[derive(Debug, Clone, Deserialize)] -pub struct AccessToken { - access_token: String, - expires_in: u64, - refresh_expires_in: u32, - token_type: String, - #[serde(rename(deserialize = "not-before-policy"))] - not_before_policy: u32, - scope: String, -} - -impl AccessToken { - /// Get the formatted token - pub fn formatted_token(&self) -> String { - format!("{} {}", self.token_type, self.access_token) - } - - /// Get the access token - pub fn access_token(&self) -> &str { - &self.access_token - } - - /// Get the expires in of the access token - pub fn expires_in(&self) -> u64 { - self.expires_in - } - - /// Get the refresh expires in of the access token - pub fn refresh_expires_in(&self) -> u32 { - self.refresh_expires_in - } - - /// Get the token type of the access token - pub fn token_type(&self) -> &str { - &self.token_type - } - - /// Get the not before policy of the access token - pub fn not_before_policy(&self) -> u32 { - self.not_before_policy - } - - /// Get the scope of the access token - pub fn scope(&self) -> &str { - &self.scope - } -} - -impl Default for AccessToken { - fn default() -> Self { - Self { - access_token: "".to_string(), - expires_in: 0, - refresh_expires_in: 0, - token_type: "".to_string(), - not_before_policy: 0, - scope: "".to_string(), - } - } -} - /// Fetch and store access tokens to be used in the DSH Rest API client /// /// This struct will fetch and store access tokens to be used in the DSH Rest API client. /// It will automatically fetch a new token if the current token is not valid. 
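+///
+/// As of this refactor it only delegates to `ManagementApiTokenFetcher`; a minimal
+/// migration sketch (assuming the new builder keeps the previous method names):
+/// ```no_run
+/// use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};
+///
+/// let tf = ManagementApiTokenFetcherBuilder::new(Platform::NpLz)
+///     .tenant_name("my-tenant".to_string())
+///     .client_secret("my-secret".to_string())
+///     .build()
+///     .unwrap();
+/// ```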
-pub struct RestTokenFetcher { - access_token: Mutex, - fetched_at: Mutex, - client_id: String, - client_secret: String, - client: reqwest::Client, - auth_url: String, -} +pub struct RestTokenFetcher; impl RestTokenFetcher { /// Create a new instance of the token fetcher - /// - /// ## Example - /// ```no_run - /// use dsh_sdk::{RestTokenFetcher, Platform}; - /// use dsh_rest_api_client::Client; - /// - /// #[tokio::main] - /// async fn main() { - /// let platform = Platform::NpLz; - /// let client_id = platform.rest_client_id("my-tenant"); - /// let client_secret = "my-secret".to_string(); - /// let token_fetcher = RestTokenFetcher::new(client_id, client_secret, platform.endpoint_rest_access_token().to_string()); - /// let token = token_fetcher.get_token().await.unwrap(); - /// } - /// ``` - pub fn new(client_id: String, client_secret: String, auth_url: String) -> Self { - Self::new_with_client( + pub fn new( + client_id: String, + client_secret: String, + auth_url: String, + ) -> crate::ManagementApiTokenFetcher { + crate::ManagementApiTokenFetcher::new_with_client( client_id, client_secret, auth_url, @@ -150,448 +20,30 @@ impl RestTokenFetcher { } /// Create a new instance of the token fetcher with custom reqwest client - /// - /// ## Example - /// ```no_run - /// use dsh_sdk::{RestTokenFetcher, Platform}; - /// use dsh_rest_api_client::Client; - /// - /// #[tokio::main] - /// async fn main() { - /// let platform = Platform::NpLz; - /// let client_id = platform.rest_client_id("my-tenant"); - /// let client_secret = "my-secret".to_string(); - /// let client = reqwest::Client::new(); - /// let token_fetcher = RestTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_rest_access_token().to_string(), client); - /// let token = token_fetcher.get_token().await.unwrap(); - /// } - /// ``` pub fn new_with_client( client_id: String, client_secret: String, auth_url: String, client: reqwest::Client, - ) -> Self { - Self { - access_token: Mutex::new(AccessToken::default()), - fetched_at: Mutex::new(Instant::now()), + ) -> crate::ManagementApiTokenFetcher { + crate::ManagementApiTokenFetcher::new_with_client( client_id, client_secret, - client, auth_url, - } - } - - /// Get token from the token fetcher - /// - /// If the cached token is not valid, it will fetch a new token from the server. - /// It will return the token as a string, formatted as "{token_type} {token}" - /// If the request fails for a new token, it will return a [DshRestTokenError::FailureTokenFetch] error. - /// This will contain the underlying reqwest error. - pub async fn get_token(&self) -> Result { - match self.is_valid() { - true => Ok(self.access_token.lock().unwrap().formatted_token()), - false => { - debug!("Token is expired, fetching new token"); - let access_token = self.fetch_access_token_from_server().await?; - let mut token = self.access_token.lock().unwrap(); - let mut fetched_at = self.fetched_at.lock().unwrap(); - *token = access_token; - *fetched_at = Instant::now(); - Ok(token.formatted_token()) - } - } - } - - /// Check if the current access token is still valid - /// - /// If the token has expired, it will return false. 
- pub fn is_valid(&self) -> bool { - let access_token = self.access_token.lock().unwrap_or_else(|mut e| { - **e.get_mut() = AccessToken::default(); - self.access_token.clear_poison(); - e.into_inner() - }); - let fetched_at = self.fetched_at.lock().unwrap_or_else(|e| { - self.fetched_at.clear_poison(); - e.into_inner() - }); - // Check if expires in has elapsed (+ safety margin of 5 seconds) - fetched_at.elapsed().add(Duration::from_secs(5)) - < Duration::from_secs(access_token.expires_in) - } - - /// Fetch a new access token from the server - /// - /// This will fetch a new access token from the server and return it. - /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error. - /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error. - /// If the request is successful, it will return the [AccessToken]. - pub async fn fetch_access_token_from_server(&self) -> Result { - let response = self - .client - .post(&self.auth_url) - .form(&[ - ("client_id", self.client_id.as_ref()), - ("client_secret", self.client_secret.as_ref()), - ("grant_type", "client_credentials"), - ]) - .send() - .await - .map_err(DshRestTokenError::FailureTokenFetch)?; - if !response.status().is_success() { - Err(DshRestTokenError::StatusCode { - status_code: response.status(), - error_body: response, - }) - } else { - response - .json::() - .await - .map_err(DshRestTokenError::FailureTokenFetch) - } - } -} - -impl Debug for RestTokenFetcher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RestTokenFetcher") - .field("access_token", &self.access_token) - .field("fetched_at", &self.fetched_at) - .field("client_id", &self.client_id) - .field("client_secret", &"xxxxxx") - .field("auth_url", &self.auth_url) - .finish() + client, + ) } } /// Builder for the token fetcher -pub struct RestTokenFetcherBuilder { - client: Option, - client_id: Option, - client_secret: Option, - platform: Platform, - tenant_name: Option, -} +pub struct RestTokenFetcherBuilder; impl RestTokenFetcherBuilder { /// Get a new instance of the ClientBuilder /// /// # Arguments /// * `platform` - The target platform to use for the token fetcher - pub fn new(platform: Platform) -> Self { - Self { - client: None, - client_id: None, - client_secret: None, - platform, - tenant_name: None, - } - } - - /// Set the client_id for the client - /// - /// Alternatively, set `tenant_name` to generate the client_id. - /// `Client_id` does have precedence over `tenant_name`. - pub fn client_id(mut self, client_id: String) -> Self { - self.client_id = Some(client_id); - self - } - - /// Set the client_secret for the client - pub fn client_secret(mut self, client_secret: String) -> Self { - self.client_secret = Some(client_secret); - self - } - - /// Set the tenant_name for the client, this will generate the client_id - /// - /// Alternatively, set `client_id` directly. - /// `Tenant_name` does have precedence over `client_id`. - pub fn tenant_name(mut self, tenant_name: String) -> Self { - self.tenant_name = Some(tenant_name); - self - } - - /// Provide a custom configured Reqwest client for the token - /// - /// This is optional, if not provided, a default client will be used. - pub fn client(mut self, client: reqwest::Client) -> Self { - self.client = Some(client); - self - } - - /// Build the client and token fetcher - /// - /// This will build the client and token fetcher based on the given parameters. 
- /// It will return a tuple with the client and token fetcher. - /// - /// ## Example - /// ``` - /// # use dsh_sdk::{RestTokenFetcherBuilder, Platform}; - /// let platform = Platform::NpLz; - /// let client_id = "robot:dev-lz-dsh:my-tenant".to_string(); - /// let client_secret = "secret".to_string(); - /// let tf = RestTokenFetcherBuilder::new(platform) - /// .client_id(client_id) - /// .client_secret(client_secret) - /// .build() - /// .unwrap(); - /// ``` - pub fn build(self) -> Result { - let client_secret = self - .client_secret - .ok_or(DshRestTokenError::UnknownClientSecret)?; - let client_id = self - .client_id - .or_else(|| { - self.tenant_name - .as_ref() - .map(|tenant_name| self.platform.rest_client_id(tenant_name)) - }) - .ok_or(DshRestTokenError::UnknownClientId)?; - let client = self.client.unwrap_or_default(); - let token_fetcher = RestTokenFetcher::new_with_client( - client_id, - client_secret, - self.platform.endpoint_rest_access_token().to_string(), - client, - ); - Ok(token_fetcher) - } -} - -impl Debug for RestTokenFetcherBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let client_secret = self - .client_secret - .as_ref() - .map(|_| "Some(\"client_secret\")"); - f.debug_struct("RestTokenFetcherBuilder") - .field("client_id", &self.client_id) - .field("client_secret", &client_secret) - .field("platform", &self.platform) - .field("tenant_name", &self.tenant_name) - .finish() - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn create_mock_tf() -> RestTokenFetcher { - RestTokenFetcher { - access_token: Mutex::new(AccessToken::default()), - fetched_at: Mutex::new(Instant::now()), - client_id: "client_id".to_string(), - client_secret: "client_secret".to_string(), - client: reqwest::Client::new(), - auth_url: "http://localhost".to_string(), - } - } - - #[test] - fn test_access_token() { - let token_str = r#"{ - "access_token": "secret_access_token", - "expires_in": 600, - "refresh_expires_in": 0, - "token_type": "Bearer", - "not-before-policy": 0, - "scope": "email" - }"#; - let token: AccessToken = serde_json::from_str(token_str).unwrap(); - assert_eq!(token.access_token(), "secret_access_token"); - assert_eq!(token.expires_in(), 600); - assert_eq!(token.refresh_expires_in(), 0); - assert_eq!(token.token_type(), "Bearer"); - assert_eq!(token.not_before_policy(), 0); - assert_eq!(token.scope(), "email"); - assert_eq!(token.formatted_token(), "Bearer secret_access_token"); - } - - #[test] - fn test_access_token_default() { - let token = AccessToken::default(); - assert_eq!(token.access_token(), ""); - assert_eq!(token.expires_in(), 0); - assert_eq!(token.refresh_expires_in(), 0); - assert_eq!(token.token_type(), ""); - assert_eq!(token.not_before_policy(), 0); - assert_eq!(token.scope(), ""); - assert_eq!(token.formatted_token(), " "); - } - - #[test] - fn test_rest_token_fetcher_is_valid_default_token() { - // Test is_valid when validating default token (should expire in 0 seconds) - let tf = create_mock_tf(); - assert!(!tf.is_valid()); - } - - #[test] - fn test_rest_token_fetcher_is_valid_valid_token() { - let tf = create_mock_tf(); - tf.access_token.lock().unwrap().expires_in = 600; - assert!(tf.is_valid()); - } - - #[test] - fn test_rest_token_fetcher_is_valid_expired_token() { - // Test is_valid when validating an expired token - let tf = create_mock_tf(); - tf.access_token.lock().unwrap().expires_in = 600; - *tf.fetched_at.lock().unwrap() = Instant::now() - Duration::from_secs(600); - assert!(!tf.is_valid()); - } - - #[test] - fn 
test_rest_token_fetcher_is_valid_poisoned_token() { - // Test is_valid when token is poisoned - let tf = create_mock_tf(); - tf.access_token.lock().unwrap().expires_in = 600; - let tf_arc = std::sync::Arc::new(tf); - let tf_clone = tf_arc.clone(); - assert!(tf_arc.is_valid(), "Token should be valid"); - let h = std::thread::spawn(move || { - let _unused = tf_clone.access_token.lock().unwrap(); - panic!("Poison token") - }); - let _ = h.join(); - assert!(!tf_arc.is_valid(), "Token should be invalid"); - } - - #[tokio::test] - async fn test_fetch_access_token_from_server() { - let mut auth_server = mockito::Server::new_async().await; - auth_server - .mock("POST", "/") - .with_status(200) - .with_body( - r#"{ - "access_token": "secret_access_token", - "expires_in": 600, - "refresh_expires_in": 0, - "token_type": "Bearer", - "not-before-policy": 0, - "scope": "email" - }"#, - ) - .create(); - let mut tf = create_mock_tf(); - tf.auth_url = auth_server.url(); - let token = tf.fetch_access_token_from_server().await.unwrap(); - assert_eq!(token.access_token(), "secret_access_token"); - assert_eq!(token.expires_in(), 600); - assert_eq!(token.refresh_expires_in(), 0); - assert_eq!(token.token_type(), "Bearer"); - assert_eq!(token.not_before_policy(), 0); - assert_eq!(token.scope(), "email"); - } - - #[tokio::test] - async fn test_fetch_access_token_from_server_error() { - let mut auth_server = mockito::Server::new_async().await; - auth_server - .mock("POST", "/") - .with_status(400) - .with_body("Bad request") - .create(); - let mut tf = create_mock_tf(); - tf.auth_url = auth_server.url(); - let err = tf.fetch_access_token_from_server().await.unwrap_err(); - match err { - DshRestTokenError::StatusCode { - status_code, - error_body, - } => { - assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); - assert_eq!(error_body.text().await.unwrap(), "Bad request"); - } - _ => panic!("Unexpected error: {:?}", err), - } - } - - #[test] - fn test_token_fetcher_builder_client_id() { - let platform = Platform::NpLz; - let client_id = "robot:dev-lz-dsh:my-tenant"; - let client_secret = "secret"; - let tf = RestTokenFetcherBuilder::new(platform) - .client_id(client_id.to_string()) - .client_secret(client_secret.to_string()) - .build() - .unwrap(); - assert_eq!(tf.client_id, client_id); - assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); - } - - #[test] - fn test_token_fetcher_builder_tenant_name() { - let platform = Platform::NpLz; - let tenant_name = "my-tenant"; - let client_secret = "secret"; - let tf = RestTokenFetcherBuilder::new(platform) - .tenant_name(tenant_name.to_string()) - .client_secret(client_secret.to_string()) - .build() - .unwrap(); - assert_eq!( - tf.client_id, - format!("robot:{}:{}", Platform::NpLz.realm(), tenant_name) - ); - assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); - } - - #[test] - fn test_token_fetcher_builder_custom_client() { - let platform = Platform::NpLz; - let client_id = "robot:dev-lz-dsh:my-tenant"; - let client_secret = "secret"; - let custom_client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); - let tf = RestTokenFetcherBuilder::new(platform) - .client_id(client_id.to_string()) - .client_secret(client_secret.to_string()) - .client(custom_client.clone()) - .build() - .unwrap(); - assert_eq!(tf.client_id, client_id); - assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, 
Platform::NpLz.endpoint_rest_access_token());
-    }
-
-    #[test]
-    fn test_token_fetcher_builder_client_id_precedence() {
-        let platform = Platform::NpLz;
-        let tenant = "my-tenant";
-        let client_id_override = "override";
-        let client_secret = "secret";
-        let tf = RestTokenFetcherBuilder::new(platform)
-            .tenant_name(tenant.to_string())
-            .client_id(client_id_override.to_string())
-            .client_secret(client_secret.to_string())
-            .build()
-            .unwrap();
-        assert_eq!(tf.client_id, client_id_override);
-        assert_eq!(tf.client_secret, client_secret);
-        assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token());
-    }
-
-    #[test]
-    fn test_token_fetcher_builder_build_error() {
-        let err = RestTokenFetcherBuilder::new(Platform::NpLz)
-            .client_secret("client_secret".to_string())
-            .build()
-            .unwrap_err();
-        assert!(matches!(err, DshRestTokenError::UnknownClientId));
-
-        let err = RestTokenFetcherBuilder::new(Platform::NpLz)
-            .tenant_name("tenant_name".to_string())
-            .build()
-            .unwrap_err();
-        assert!(matches!(err, DshRestTokenError::UnknownClientSecret));
+    pub fn new(platform: crate::Platform) -> crate::ManagementApiTokenFetcherBuilder {
+        crate::ManagementApiTokenFetcherBuilder::new(platform)
     }
 }
diff --git a/dsh_sdk/src/utils/dlq.rs b/dsh_sdk/src/utils/dlq.rs
new file mode 100644
index 0000000..84d605f
--- /dev/null
+++ b/dsh_sdk/src/utils/dlq.rs
@@ -0,0 +1,562 @@
+//! # Dead Letter Queue
+//! This optional module contains an implementation of pushing unprocessable/invalid messages towards a Dead Letter Queue (DLQ).
+//!
+//! Add the `dlq` feature to your Cargo.toml to enable this module.
+//!
+//! ### NOTE:
+//! This module is meant for pushing messages towards a dead/retry topic only; it does not and WILL not handle any logic for retrying messages.
+//! The reason is that the right strategy for retrying messages and handling dead letters differs per use case.
+//!
+//! It is up to the user to implement the strategy and logic for retrying messages.
+//!
+//! ### How it works
+//! The `Dlq` struct runs a producer in a dedicated task; every message sent to its channel is forwarded to the dead or retry topic, based on the `Retryable` value attached to it.
+//!
+//! ## How to use
+//! 1. Implement the `ErrorToDlq` trait on top of your (custom) error type.
+//! 2. Initialize the `Dlq` struct in your service in main.
+//! 3. Get the dlq channel sender from the `Dlq` struct and use this channel to communicate with the `Dlq` struct from other threads.
+//! 4. Run the `Dlq` struct in a separate tokio task. This will run the producer that will produce towards the dead/retry topics.
+//!
+//! The topics are set via the environment variables DLQ_DEAD_TOPIC and DLQ_RETRY_TOPIC.
+//!
+//! ### Example:
+//! See the examples folder on github for a working example.
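+//!
+//! A condensed sketch of the wiring described above (`ConsumerError` is the error type
+//! from the `ErrorToDlq` example further below; `dsh_properties` and `owned_message` are
+//! placeholders for your `Properties` instance and a consumed Kafka message):
+//! ```ignore
+//! let shutdown = Shutdown::new();
+//! let mut dlq = Dlq::new(&dsh_properties, shutdown)?;
+//! let dlq_tx = dlq.dlq_records_tx();
+//! tokio::spawn(async move { dlq.run().await });
+//! // From a consumer task, push an unprocessable message towards the dead/retry topics:
+//! let mut tx = dlq_tx.clone();
+//! ConsumerError::DeserializeError("invalid payload".to_string())
+//!     .to_dlq(owned_message)
+//!     .send(&mut tx)
+//!     .await;
+//! ```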
+
+use std::collections::HashMap;
+use std::env;
+use std::str::from_utf8;
+
+use log::{debug, error, info, warn};
+
+use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage};
+use rdkafka::producer::{FutureProducer, FutureRecord};
+
+use tokio::sync::mpsc;
+
+use crate::graceful_shutdown::Shutdown;
+use crate::Properties;
+
+/// Trait to convert an error to a dlq message.
+/// Implement this trait for all errors that can and should be converted to a dlq message.
+///
+/// Example:
+/// ```
+/// use dsh_sdk::utils::dlq;
+/// use std::backtrace::Backtrace;
+/// use thiserror::Error;
+///
+/// #[derive(Error, Debug)]
+/// enum ConsumerError {
+///     #[error("Deserialization error: {0}")]
+///     DeserializeError(String),
+/// }
+///
+/// impl dlq::ErrorToDlq for ConsumerError {
+///     fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq {
+///         dlq::SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None)
+///     }
+///     fn retryable(&self) -> dlq::Retryable {
+///         match self {
+///             ConsumerError::DeserializeError(e) => dlq::Retryable::NonRetryable,
+///         }
+///     }
+/// }
+/// ```
+pub trait ErrorToDlq {
+    /// Convert error message to a dlq message
+    fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq;
+    /// Match error if the original message is able to be retried or not
+    fn retryable(&self) -> Retryable;
+}
+
+/// Struct with required details to send a channel message to the dlq.
+/// The error needs to be sent as a string, as it is not possible to send a struct that implements the Error trait.
+pub struct SendToDlq {
+    kafka_message: OwnedMessage,
+    retryable: Retryable,
+    error: String,
+    stack_trace: Option<String>,
+}
+
+impl SendToDlq {
+    /// Create new SendToDlq message
+    pub fn new(
+        kafka_message: OwnedMessage,
+        retryable: Retryable,
+        error: String,
+        stack_trace: Option<String>,
+    ) -> Self {
+        Self {
+            kafka_message,
+            retryable,
+            error,
+            stack_trace,
+        }
+    }
+    /// Send message to dlq channel
+    pub async fn send(self, dlq_tx: &mut mpsc::Sender<SendToDlq>) {
+        match dlq_tx.send(self).await {
+            Ok(_) => debug!("Message sent to DLQ channel"),
+            Err(e) => error!("Error sending message to DLQ: {}", e),
+        }
+    }
+
+    fn get_original_msg(&self) -> OwnedMessage {
+        self.kafka_message.clone()
+    }
+}
+
+/// Helper enum to decide which topic the message should be sent to.
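+///
+/// `Retryable` is routed to the retry topic; `NonRetryable` and `Other` both go to the
+/// dead topic (see `Dlq::dlq_topic` below).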
+#[derive(Debug, Clone, Copy)]
+pub enum Retryable {
+    Retryable,
+    NonRetryable,
+    Other,
+}
+
+impl std::fmt::Display for Retryable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Retryable::Retryable => write!(f, "Retryable"),
+            Retryable::NonRetryable => write!(f, "NonRetryable"),
+            Retryable::Other => write!(f, "Other"),
+        }
+    }
+}
+
+/// Struct with implementation to send messages to the dlq
+pub struct Dlq {
+    dlq_producer: FutureProducer,
+    dlq_rx: mpsc::Receiver<SendToDlq>,
+    dlq_tx: mpsc::Sender<SendToDlq>,
+    dlq_dead_topic: String,
+    dlq_retry_topic: String,
+    shutdown: Shutdown,
+}
+
+impl Dlq {
+    /// Create new Dlq struct
+    pub fn new(
+        dsh_prop: &Properties,
+        shutdown: Shutdown,
+    ) -> Result<Self, Box<dyn std::error::Error>> {
+        use crate::datastream::ReadWriteAccess;
+        let (dlq_tx, dlq_rx) = mpsc::channel(200);
+        let dlq_producer = Self::build_producer(dsh_prop)?;
+        let dlq_dead_topic = env::var("DLQ_DEAD_TOPIC")?;
+        let dlq_retry_topic = env::var("DLQ_RETRY_TOPIC")?;
+        dsh_prop.datastream().verify_list_of_topics(
+            &vec![&dlq_dead_topic, &dlq_retry_topic],
+            ReadWriteAccess::Write,
+        )?;
+        Ok(Self {
+            dlq_producer,
+            dlq_rx,
+            dlq_tx,
+            dlq_dead_topic,
+            dlq_retry_topic,
+            shutdown,
+        })
+    }
+
+    /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics.
+    /// This function will run until the shutdown channel is closed.
+    pub async fn run(&mut self) {
+        info!("DLQ started");
+        loop {
+            tokio::select! {
+                _ = self.shutdown.recv() => {
+                    warn!("DLQ shutdown");
+                    return;
+                },
+                Some(mut dlq_message) = self.dlq_rx.recv() => {
+                    match self.send(&mut dlq_message).await {
+                        Ok(_) => {},
+                        Err(e) => error!("Error sending message to DLQ: {}", e),
+                    };
+                }
+            }
+        }
+    }
+
+    /// Get the dlq channel sender. To be used in your service to send messages to the dlq in case of errors.
+    ///
+    /// This channel can be used to send messages to the dlq from different threads.
+    pub fn dlq_records_tx(&self) -> mpsc::Sender<SendToDlq> {
+        self.dlq_tx.clone()
+    }
+
+    /// Create and send message towards the dlq
+    async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), rdkafka::error::KafkaError> {
+        let original_kafka_msg: OwnedMessage = dlq_message.get_original_msg();
+        let headers = original_kafka_msg
+            .generate_dlq_headers(dlq_message)
+            .to_owned_headers();
+        let topic = self.dlq_topic(dlq_message.retryable);
+        let key: &[u8] = original_kafka_msg.key().unwrap_or_default();
+        let payload = original_kafka_msg.payload().unwrap_or_default();
+        debug!("Sending message to DLQ topic: {}", topic);
+        let record = FutureRecord::to(topic)
+            .payload(payload)
+            .key(key)
+            .headers(headers);
+        let s = self.dlq_producer.send(record, None).await;
+        match s {
+            Ok((p, o)) => warn!(
+                "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}",
+                from_utf8(key),
+                topic,
+                p,
+                o
+            ),
+            Err((e, _)) => return Err(e),
+        };
+        Ok(())
+    }
+
+    fn dlq_topic(&self, retryable: Retryable) -> &str {
+        match retryable {
+            Retryable::Retryable => &self.dlq_retry_topic,
+            Retryable::NonRetryable => &self.dlq_dead_topic,
+            Retryable::Other => &self.dlq_dead_topic,
+        }
+    }
+
+    fn build_producer(dsh_prop: &Properties) -> Result<FutureProducer, rdkafka::error::KafkaError> {
+        dsh_prop.producer_rdkafka_config().create()
+    }
+}
+
+trait DlqHeaders {
+    fn generate_dlq_headers<'a>(
+        &'a self,
+        dlq_message: &'a mut SendToDlq,
+    ) -> HashMap<&'a str, Option<Vec<u8>>>;
+}
+
+impl DlqHeaders for OwnedMessage {
+    fn generate_dlq_headers<'a>(
+        &'a self,
+        dlq_message: &'a mut SendToDlq,
+    ) -> HashMap<&'a str, Option<Vec<u8>>> {
+        let mut hashmap_headers: HashMap<&str, Option<Vec<u8>>> = HashMap::new();
+        // Get original headers and add to hashmap
+        if let Some(headers) = self.headers() {
+            for header in headers.iter() {
+                hashmap_headers.insert(header.key, header.value.map(|v| v.to_vec()));
+            }
+        }
+
+        // Add dlq headers if they do not exist yet (we don't want to overwrite original dlq headers if the message already failed earlier)
+        let partition = self.partition().to_string().as_bytes().to_vec();
+        let offset = self.offset().to_string().as_bytes().to_vec();
+        let timestamp = self
+            .timestamp()
+            .to_millis()
+            .unwrap_or(-1)
+            .to_string()
+            .as_bytes()
+            .to_vec();
+        hashmap_headers
+            .entry("dlq_topic_origin")
+            .or_insert_with(|| Some(self.topic().as_bytes().to_vec()));
+        hashmap_headers
+            .entry("dlq_partition_origin")
+            .or_insert_with(move || Some(partition));
+        hashmap_headers
+            .entry("dlq_partition_offset_origin")
+            .or_insert_with(move || Some(offset));
+        hashmap_headers
+            .entry("dlq_timestamp_origin")
+            .or_insert_with(move || Some(timestamp));
+        // Overwrite if they exist
+        hashmap_headers.insert(
+            "dlq_retryable",
+            Some(dlq_message.retryable.to_string().as_bytes().to_vec()),
+        );
+        hashmap_headers.insert(
+            "dlq_error",
+            Some(dlq_message.error.to_string().as_bytes().to_vec()),
+        );
+        if let Some(stack_trace) = &dlq_message.stack_trace {
+            hashmap_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec()));
+        }
+        // Update dlq_retries with +1 if it exists, else add dlq_retries with 1
+        let retries = hashmap_headers
+            .get("dlq_retries")
+            .map(|v| {
+                let mut retries = [0; 4];
+                retries.copy_from_slice(v.as_ref().unwrap());
+                i32::from_be_bytes(retries)
+            })
+            .unwrap_or(0);
+        hashmap_headers.insert("dlq_retries", Some((retries + 1).to_be_bytes().to_vec()));
+
+        hashmap_headers
+    }
+}
+
+trait HashMapToKafkaHeaders {
+    fn to_owned_headers(&self) -> 
OwnedHeaders; +} + +impl HashMapToKafkaHeaders for HashMap<&str, Option>> { + fn to_owned_headers(&self) -> OwnedHeaders { + // Convert to OwnedHeaders + let mut owned_headers = OwnedHeaders::new_with_capacity(self.len()); + for header in self { + let value = header.1.as_ref().map(|value| value.as_slice()); + owned_headers = owned_headers.insert(Header { + key: header.0, + value, + }); + } + owned_headers + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rdkafka::config::ClientConfig; + use rdkafka::mocking::MockCluster; + + #[derive(Debug)] + enum MockError { + MockErrorRetryable(String), + MockErrorDead(String), + } + impl MockError { + fn to_string(&self) -> String { + match self { + MockError::MockErrorRetryable(e) => e.to_string(), + MockError::MockErrorDead(e) => e.to_string(), + } + } + } + + impl std::fmt::Display for MockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MockError::MockErrorRetryable(e) => write!(f, "{}", e), + MockError::MockErrorDead(e) => write!(f, "{}", e), + } + } + } + + impl ErrorToDlq for MockError { + fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq { + let backtrace = "some_backtrace"; + SendToDlq::new( + kafka_message, + self.retryable(), + self.to_string(), + Some(backtrace.to_string()), + ) + } + + fn retryable(&self) -> Retryable { + match self { + MockError::MockErrorRetryable(_) => Retryable::Retryable, + MockError::MockErrorDead(_) => Retryable::NonRetryable, + } + } + } + + #[test] + fn test_dlq_get_original_msg() { + let topic = "original_topic"; + let partition = 0; + let offset = 123; + let timestamp = 456; + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "some_key", + value: Some("some_value".as_bytes()), + }); + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + topic.to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + let dlq_message = + MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone()); + let result = dlq_message.get_original_msg(); + assert_eq!( + result.payload(), + dlq_message.kafka_message.payload(), + "payoad does not match" + ); + assert_eq!( + result.key(), + dlq_message.kafka_message.key(), + "key does not match" + ); + assert_eq!( + result.topic(), + dlq_message.kafka_message.topic(), + "topic does not match" + ); + assert_eq!( + result.partition(), + dlq_message.kafka_message.partition(), + "partition does not match" + ); + assert_eq!( + result.offset(), + dlq_message.kafka_message.offset(), + "offset does not match" + ); + assert_eq!( + result.timestamp(), + dlq_message.kafka_message.timestamp(), + "timestamp does not match" + ); + } + + #[test] + fn test_dlq_hashmap_to_owned_headers() { + let mut hashmap: HashMap<&str, Option>> = HashMap::new(); + hashmap.insert("some_key", Some(b"key_value".to_vec())); + hashmap.insert("some_other_key", None); + let result: Vec<(&str, Option<&[u8]>)> = + vec![("some_key", Some(b"key_value")), ("some_other_key", None)]; + + let owned_headers = hashmap.to_owned_headers(); + for header in owned_headers.iter() { + assert!(result.contains(&(header.key, header.value))); + } + } + + #[test] + fn test_dlq_topic() { + let mock_cluster = MockCluster::new(1).unwrap(); + let mut producer = ClientConfig::new(); + producer.set("bootstrap.servers", mock_cluster.bootstrap_servers()); + let producer = producer.create().unwrap(); + let dlq = 
Dlq { + dlq_producer: producer, + dlq_rx: mpsc::channel(200).1, + dlq_tx: mpsc::channel(200).0, + dlq_dead_topic: "dead_topic".to_string(), + dlq_retry_topic: "retry_topic".to_string(), + shutdown: Shutdown::new(), + }; + let error = MockError::MockErrorRetryable("some_error".to_string()); + let topic = dlq.dlq_topic(error.retryable()); + assert_eq!(topic, "retry_topic"); + let error = MockError::MockErrorDead("some_error".to_string()); + let topic = dlq.dlq_topic(error.retryable()); + assert_eq!(topic, "dead_topic"); + } + + #[test] + fn test_dlq_generate_dlq_headers() { + let topic = "original_topic"; + let partition = 0; + let offset = 123; + let timestamp = 456; + let error = Box::new(MockError::MockErrorRetryable("some_error".to_string())); + + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "some_key", + value: Some("some_value".as_bytes()), + }); + + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + topic.to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + + let mut dlq_message = error.to_dlq(owned_message.clone()); + + let mut expected_headers: HashMap<&str, Option>> = HashMap::new(); + expected_headers.insert("some_key", Some(b"some_value".to_vec())); + expected_headers.insert("dlq_topic_origin", Some(topic.as_bytes().to_vec())); + expected_headers.insert( + "dlq_partition_origin", + Some(partition.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_partition_offset_origin", + Some(offset.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_timestamp_origin", + Some(timestamp.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_retryable", + Some(Retryable::Retryable.to_string().as_bytes().to_vec()), + ); + expected_headers.insert("dlq_retries", Some(1_i32.to_be_bytes().to_vec())); + expected_headers.insert("dlq_error", Some(error.to_string().as_bytes().to_vec())); + if let Some(stack_trace) = &dlq_message.stack_trace { + expected_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); + } + + let result = owned_message.generate_dlq_headers(&mut dlq_message); + for header in result.iter() { + assert_eq!( + header.1, + expected_headers.get(header.0).unwrap_or(&None), + "Header {} does not match", + header.0 + ); + } + + // Test if dlq headers are correctly overwritten when to be retried message was already retried before + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "dlq_error", + value: Some( + "to_be_overwritten_error_as_this_was_the_original_error_from_1st_retry".as_bytes(), + ), + }); + original_headers = original_headers.insert(Header { + key: "dlq_topic_origin", + value: Some(topic.as_bytes()), + }); + original_headers = original_headers.insert(Header { + key: "dlq_retries", + value: Some(&1_i32.to_be_bytes().to_vec()), + }); + + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + "retry_topic".to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + let result = owned_message.generate_dlq_headers(&mut dlq_message); + assert_eq!( + result.get("dlq_error").unwrap(), + &Some(error.to_string().as_bytes().to_vec()) + ); + assert_eq!( + result.get("dlq_topic_origin").unwrap(), + &Some(topic.as_bytes().to_vec()) + ); + assert_eq!( + 
result.get("dlq_retries").unwrap(), + &Some(2_i32.to_be_bytes().to_vec()) + ); + } +} diff --git a/dsh_sdk/src/utils/graceful_shutdown.rs b/dsh_sdk/src/utils/graceful_shutdown.rs new file mode 100644 index 0000000..0354bf1 --- /dev/null +++ b/dsh_sdk/src/utils/graceful_shutdown.rs @@ -0,0 +1,203 @@ +//! Graceful shutdown +//! +//! This module provides a shutdown handle for graceful shutdown of (tokio tasks within) your service. +//! It listens for SIGTERM requests and sends out shutdown requests to all shutdown handles. +//! +//! It creates a clonable object which can be used to send shutdown request to all tasks. +//! Based on this request you are able to handle your shutdown procedure. +//! +//! This appproach is based on Tokio's graceful shutdown example: +//! +//! +//! # Example: +//! +//! ```no_run +//! use dsh_sdk::graceful_shutdown::Shutdown; +//! +//! // your process task +//! async fn process_task(shutdown: Shutdown) { +//! loop { +//! tokio::select! { +//! _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => { +//! // Do something here, e.g. consume messages from Kafka +//! println!("Still processing the task") +//! }, +//! _ = shutdown.recv() => { +//! // shutdown request received, include your shutdown procedure here e.g. close db connection +//! println!("Gracefully exiting process_task"); +//! break; +//! }, +//! } +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! // Create shutdown handle +//! let shutdown = dsh_sdk::graceful_shutdown::Shutdown::new(); +//! // Create your process task with a cloned shutdown handle +//! let process_task = process_task(shutdown.clone()); +//! // Spawn your process task in a tokio runtime +//! let process_task_handle = tokio::spawn(async move { +//! process_task.await; +//! }); +//! +//! // Listen for shutdown request or if process task stopped +//! // If your process stops, start shutdown procedure to stop other tasks (if any) +//! tokio::select! { +//! _ = shutdown.signal_listener() => println!("Exit signal received!"), +//! _ = process_task_handle => {println!("process_task stopped"); shutdown.start()}, +//! } +//! // Wait till shutdown procedures is finished +//! let _ = shutdown.complete().await; +//! println!("Exiting main...") +//! } +//! ``` + +use log::{info, warn}; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +/// Shutdown handle to interact on SIGTERM of DSH for a graceful shutdown. +/// +/// Use original to wait for shutdown complete. +/// Use clone to send shutdown request to all shutdown handles. +/// +/// see [dsh_sdk::graceful_shutdown](index.html) for full implementation example. +#[derive(Debug)] +pub struct Shutdown { + cancel_token: CancellationToken, + shutdown_complete_tx: mpsc::Sender<()>, + shutdown_complete_rx: Option>, +} + +impl Shutdown { + /// Create new shutdown handle. + /// Returns shutdown handle and shutdown complete receiver. + /// Shutdown complete receiver is used to wait for all tasks to finish. + /// + /// NOTE: Make sure to clone shutdown handles to use it in other components/tasks. + /// Use orignal in main and receive shutdown complete. + pub fn new() -> Self { + let cancel_token = CancellationToken::new(); + let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); + Self { + cancel_token, + shutdown_complete_tx, + shutdown_complete_rx: Some(shutdown_complete_rx), + } + } + + /// Send out internal shutdown request to all Shutdown handles, so they can start their shutdown procedure. 
+    pub fn start(&self) {
+        self.cancel_token.cancel();
+    }
+
+    /// Listen to internal shutdown request.
+    /// Based on this you can start the shutdown procedure in your component/task.
+    pub async fn recv(&self) {
+        self.cancel_token.cancelled().await;
+    }
+
+    /// Listen for an external shutdown request coming from DSH (SIGTERM) or CTRL-C/SIGINT and start the shutdown procedure.
+    ///
+    /// Compatible with Unix (SIGINT and SIGTERM) and Windows (SIGINT).
+    pub async fn signal_listener(&self) {
+        let ctrl_c_signal = tokio::signal::ctrl_c();
+        #[cfg(unix)]
+        let mut sigterm_signal =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
+        #[cfg(unix)]
+        tokio::select! {
+            _ = ctrl_c_signal => {},
+            _ = sigterm_signal.recv() => {}
+        }
+        #[cfg(windows)]
+        let _ = ctrl_c_signal.await;
+
+        warn!("Shutdown signal received!");
+        self.start();
+    }
+
+    /// This function can only be called by the original shutdown handle.
+    ///
+    /// Check if all tasks are finished and shutdown is complete.
+    /// This function should be awaited after all tasks are spawned.
+    pub async fn complete(self) {
+        // drop original shutdown_complete_tx, else it would await forever
+        drop(self.shutdown_complete_tx);
+        self.shutdown_complete_rx.unwrap().recv().await;
+        info!("Shutdown complete!")
+    }
+}
+
+impl Default for Shutdown {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl std::clone::Clone for Shutdown {
+    /// Clone shutdown handle.
+    ///
+    /// Use this handle in your components/tasks.
+    fn clone(&self) -> Self {
+        Self {
+            cancel_token: self.cancel_token.clone(),
+            shutdown_complete_tx: self.shutdown_complete_tx.clone(),
+            shutdown_complete_rx: None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::{Arc, Mutex};
+
+    use super::*;
+    use tokio::time::Duration;
+
+    #[tokio::test]
+    async fn test_shutdown_recv() {
+        let shutdown = Shutdown::new();
+        let shutdown_clone = shutdown.clone();
+        // receive shutdown task
+        let task = tokio::spawn(async move {
+            shutdown_clone.recv().await;
+            1
+        });
+        // start shutdown task after 200 ms
+        tokio::spawn(async move {
+            tokio::time::sleep(Duration::from_millis(200)).await;
+            shutdown.start();
+        });
+        // if shutdown is not received within 5 seconds, fail test
+        let check_value = tokio::select! {
+            _ = tokio::time::sleep(Duration::from_secs(5)) => panic!("Shutdown not received within 5 seconds"),
+            v = task => v.unwrap(),
+        };
+        assert_eq!(check_value, 1);
+    }
+
+    #[tokio::test]
+    async fn test_shutdown_wait_for_complete() {
+        let shutdown = Shutdown::new();
+        let shutdown_clone = shutdown.clone();
+        let check_value: Arc<Mutex<bool>> = Arc::new(Mutex::new(false));
+        let check_value_clone = Arc::clone(&check_value);
+        // receive shutdown task
+        tokio::spawn(async move {
+            shutdown_clone.recv().await;
+            tokio::time::sleep(Duration::from_millis(200)).await;
+            let mut check: std::sync::MutexGuard<'_, bool> = check_value_clone.lock().unwrap();
+            *check = true;
+        });
+        shutdown.start();
+        shutdown.complete().await;
+        let check = check_value.lock().unwrap();
+        assert_eq!(
+            *check, true,
+            "shutdown did not successfully wait for complete"
+        );
+    }
+}
diff --git a/dsh_sdk/src/utils/metrics.rs b/dsh_sdk/src/utils/metrics.rs
new file mode 100644
index 0000000..5c05587
--- /dev/null
+++ b/dsh_sdk/src/utils/metrics.rs
@@ -0,0 +1,334 @@
+//! This module wraps the prometheus metrics library and provides an HTTP server to expose the metrics.
+//!
+//! Technically, it re-exports the prometheus metrics library with some additional functions.
+//!
+//! # Create custom metrics
+//!
+//! To define custom metrics, the prometheus macros can be used. They are re-exported in this module.
+//!
+//! As they are a pub static reference, you can use them anywhere in your code.
+//!
+//! See [prometheus](https://docs.rs/prometheus/0.13.3/prometheus/index.html#macros) for more information.
+//!
+//! ### Example
+//! ```
+//! use dsh_sdk::metrics::*;
+//!
+//! lazy_static! {
+//!     pub static ref HIGH_FIVE_COUNTER: IntCounter =
+//!         register_int_counter!("highfives", "Number of high fives received").unwrap();
+//! }
+//!
+//! HIGH_FIVE_COUNTER.inc();
+//! ```
+//!
+//! # Expose metrics to DSH / HTTP Server
+//!
+//! This module provides an HTTP server to expose the metrics to DSH. A port number needs to be defined.
+//!
+//! ### Example:
+//! ```
+//! use dsh_sdk::metrics::start_http_server;
+//! #[tokio::main]
+//! async fn main() {
+//!     start_http_server(9090);
+//! }
+//! ```
+//! After starting the HTTP server, the metrics can be found at http://localhost:9090/metrics.
+//! To expose the metrics to DSH, the port number needs to be defined in the DSH service configuration.
+//!
+//! ```json
+//! "metrics": {
+//!     "port": 9090,
+//!     "path": "/metrics"
+//! },
+//! ```
+//!
+//! And in your dockerfile expose the port:
+//! ```dockerfile
+//! EXPOSE 9090
+//! ```
+
+use std::net::SocketAddr;
+
+use bytes::Bytes;
+use http_body_util::{BodyExt, Full};
+use hyper::body::Incoming;
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper::{header, Method, Request, Response, StatusCode};
+use hyper_util::rt::TokioIo;
+pub use lazy_static::lazy_static;
+use log::{error, warn};
+pub use prometheus::*;
+use tokio::net::TcpListener;
+use tokio::task::JoinHandle;
+
+use crate::error::DshError;
+
+type DshResult<T> = std::result::Result<T, DshError>;
+type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
+
+static NOTFOUND: &[u8] = b"404: Not Found";
+
+/// Start an HTTP server to expose prometheus metrics.
+///
+/// The exposed endpoint is /metrics and the port number needs to be defined. The server will run on a separate thread
+/// and this function will return a JoinHandle of the thread. It is optional to handle the thread status. If left unhandled,
+/// the server will run until the main thread is stopped.
+///
+/// # Note!
+/// Don't forget to expose the port in your dockerfile and add the port number to the DSH service configuration.
+/// ```Dockerfile
+/// EXPOSE 9090
+/// ```
+///
+/// # Example
+/// This starts an HTTP server on port 9090 on a separate thread. The server will run until the main thread is stopped.
+/// ```rust
+/// use dsh_sdk::metrics::start_http_server;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     start_http_server(9090);
+/// }
+/// ```
+///
+/// # Optional: Check http server thread status
+/// Await the JoinHandle in a tokio select alongside your application logic to check if the server is still running.
+/// ```rust
+/// use dsh_sdk::metrics::start_http_server;
+/// # use tokio::time::sleep;
+/// # use std::time::Duration;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let server = start_http_server(9090);
+///     tokio::select! 
{
+///         // Replace sleep with your application logic
+///         _ = sleep(Duration::from_secs(1)) => {println!("Application is stopped!")},
+///         // Check if the server is still running
+///         tokio_result = server => {
+///             match tokio_result {
+///                 Ok(server_result) => if let Err(e) = server_result {
+///                     eprintln!("Metrics server operation failed: {}", e);
+///                 },
+///                 Err(e) => println!("Server thread stopped unexpectedly: {}", e),
+///             }
+///         }
+///     }
+/// }
+/// ```
+pub fn start_http_server(port: u16) -> JoinHandle<DshResult<()>> {
+    tokio::spawn(async move {
+        let result = run_server(port).await;
+        warn!("HTTP server stopped: {:?}", result);
+        result
+    })
+}
+
+/// Encode metrics to a string (UTF-8)
+pub fn metrics_to_string() -> DshResult<String> {
+    let encoder = prometheus::TextEncoder::new();
+
+    let mut buffer = Vec::new();
+    encoder.encode(&prometheus::gather(), &mut buffer)?;
+    Ok(String::from_utf8(buffer)?)
+}
+
+async fn run_server(port: u16) -> DshResult<()> {
+    let addr = SocketAddr::from(([0, 0, 0, 0], port));
+    let listener = TcpListener::bind(addr).await?;
+
+    loop {
+        let (stream, _) = listener.accept().await?;
+        tokio::spawn(handle_connection(stream));
+    }
+}
+
+async fn handle_connection(stream: tokio::net::TcpStream) {
+    let io = TokioIo::new(stream);
+    let service = service_fn(routes);
+
+    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
+        error!("Failed to serve connection: {:?}", err);
+    }
+}
+
+async fn routes(req: Request<Incoming>) -> DshResult<Response<BoxBody>> {
+    match (req.method(), req.uri().path()) {
+        (&Method::GET, "/metrics") => get_metrics(),
+        (_, _) => not_found(),
+    }
+}
+
+fn get_metrics() -> DshResult<Response<BoxBody>> {
+    Ok(Response::builder()
+        .status(StatusCode::OK)
+        .header(header::CONTENT_TYPE, prometheus::TEXT_FORMAT)
+        .body(full(metrics_to_string().unwrap_or_default()))?)
+}
+
+fn not_found() -> DshResult<Response<BoxBody>> {
+    Ok(Response::builder()
+        .status(StatusCode::NOT_FOUND)
+        .body(full(NOTFOUND))?)
+}
+
+fn full<T: Into<Bytes>>(chunk: T) -> BoxBody {
+    Full::new(chunk.into())
+        .map_err(|never| match never {})
+        .boxed()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use http_body_util::Empty;
+    use hyper::body::Body;
+    use hyper::client::conn;
+    use hyper::client::conn::http1::{Connection, SendRequest};
+    use hyper::http::HeaderValue;
+    use hyper::Uri;
+    use serial_test::serial;
+    use tokio::net::TcpStream;
+
+    const PORT: u16 = 9090;
+
+    lazy_static! 
{
+        pub static ref HIGH_FIVE_COUNTER: IntCounter =
+            register_int_counter!("highfives", "Number of high fives received").unwrap();
+    }
+
+    async fn create_client(
+        url: &Uri,
+    ) -> (
+        SendRequest<Empty<Bytes>>,
+        Connection<TokioIo<TcpStream>, Empty<Bytes>>,
+    ) {
+        let host = url.host().expect("uri has no host");
+        let port = url.port_u16().unwrap_or(PORT);
+        let addr = format!("{}:{}", host, port);
+
+        let stream = TcpStream::connect(addr).await.unwrap();
+        let io = TokioIo::new(stream);
+
+        conn::http1::handshake(io).await.unwrap()
+    }
+
+    fn to_get_req(url: &Uri) -> Request<Empty<Bytes>> {
+        Request::builder()
+            .uri(url)
+            .method(Method::GET)
+            .header(header::HOST, url.authority().unwrap().clone().as_str())
+            .body(Empty::<Bytes>::new())
+            .unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_http_metric_response() {
+        // Increment the counter
+        HIGH_FIVE_COUNTER.inc();
+
+        // Call the function
+        let res = get_metrics();
+
+        // Check if the function returns a result
+        assert!(res.is_ok());
+
+        // Check if the result is not an empty string
+        let response = res.unwrap();
+        let status_code = response.status();
+
+        assert_eq!(status_code, StatusCode::OK);
+        assert!(response.body().size_hint().exact().unwrap() > 0);
+        assert_eq!(
+            response.headers().get(header::CONTENT_TYPE).unwrap(),
+            HeaderValue::from_static(prometheus::TEXT_FORMAT)
+        );
+    }
+
+    #[tokio::test]
+    #[serial(port_usage)]
+    async fn test_start_http_server() {
+        // Start HTTP server
+        let server = start_http_server(PORT);
+
+        // increment the counter
+        HIGH_FIVE_COUNTER.inc();
+
+        // Give the server a moment to start
+        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+
+        let url: Uri = format!("http://localhost:{PORT}/metrics").parse().unwrap();
+        let (mut request_sender, connection) = create_client(&url).await;
+        tokio::task::spawn(async move {
+            if let Err(err) = connection.await {
+                error!("Connection failed: {:?}", err);
+            }
+        });
+
+        // Send a request to the server
+        let request = to_get_req(&url);
+        let response = request_sender.send_request(request).await.unwrap();
+
+        // Check if the server returns a 200 status
+        assert_eq!(response.status(), StatusCode::OK);
+        assert_eq!(
+            response.headers().get(header::CONTENT_TYPE).unwrap(),
+            HeaderValue::from_static(prometheus::TEXT_FORMAT)
+        );
+
+        // Check if the response body is not empty
+        let buf = response.collect().await.unwrap().to_bytes();
+        let res = String::from_utf8(buf.to_vec()).unwrap();
+
+        println!("{}", res);
+        assert!(!res.is_empty());
+
+        // Terminate the server
+        server.abort();
+    }
+
+    #[tokio::test]
+    #[serial(port_usage)]
+    async fn test_unknown_path() {
+        // Start HTTP server
+        let server = start_http_server(PORT);
+
+        // Give the server a moment to start
+        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+
+        let url: Uri = format!("http://localhost:{PORT}").parse().unwrap();
+        let (mut request_sender, connection) = create_client(&url).await;
+        tokio::task::spawn(async move {
+            if let Err(err) = connection.await {
+                error!("Connection failed: {:?}", err);
+            }
+        });
+
+        // Send a request to the server
+        let request = to_get_req(&url);
+
+        let response = request_sender.send_request(request).await.unwrap();
+
+        // Check if the server returns a 404 status
+        assert_eq!(response.status(), StatusCode::NOT_FOUND);
+
+        // Check if the response body is not empty
+        let buf = response.collect().await.unwrap().to_bytes();
+        let res = String::from_utf8(buf.to_vec()).unwrap();
+
+        assert_eq!(res, String::from_utf8_lossy(NOTFOUND));
+
+        // Terminate the server
+        server.abort();
+    }
+
+    #[test]
+    fn 
test_metrics_to_string() {
+        HIGH_FIVE_COUNTER.inc();
+        let res = metrics_to_string().unwrap();
+        assert!(res.contains("highfives"));
+    }
+}
diff --git a/dsh_sdk/src/utils.rs b/dsh_sdk/src/utils/mod.rs
similarity index 82%
rename from dsh_sdk/src/utils.rs
rename to dsh_sdk/src/utils/mod.rs
index f736a24..6209fb8 100644
--- a/dsh_sdk/src/utils.rs
+++ b/dsh_sdk/src/utils/mod.rs
@@ -1,9 +1,19 @@
-//! Utility functions for the SDK
+//! Utilities for DSH
+//!
+//! This module contains helpful functions and utilities for interacting with DSH.
+use std::env;
+
+use log::{debug, info, warn};
 
 use super::{VAR_APP_ID, VAR_DSH_TENANT_NAME};
 use crate::error::DshError;
-use log::{debug, info, warn};
-use std::env;
+
+#[cfg(feature = "dlq")]
+pub mod dlq;
+#[cfg(feature = "graceful-shutdown")]
+pub mod graceful_shutdown;
+#[cfg(feature = "metrics")]
+pub mod metrics;
 
 /// Available DSH platforms plus its related metadata
 ///
@@ -131,8 +141,21 @@ impl Platform {
 
 /// Get the configured topics from the environment variable TOPICS
 /// Topics can be delimited by a comma
+///
+/// ## Example
+/// ```
+/// # use dsh_sdk::utils::get_configured_topics;
+/// std::env::set_var("TOPICS", "topic1, topic2, topic3");
+///
+/// let topics = get_configured_topics().unwrap();
+///
+/// assert_eq!(topics[0], "topic1");
+/// assert_eq!(topics[1], "topic2");
+/// assert_eq!(topics[2], "topic3");
+/// # std::env::remove_var("TOPICS");
+/// ```
 pub fn get_configured_topics() -> Result<Vec<String>, DshError> {
-    let kafka_topic_string = env::var("TOPICS")?;
+    let kafka_topic_string = get_env_var("TOPICS")?;
     Ok(kafka_topic_string
         .split(',')
         .map(str::trim)
@@ -142,9 +165,30 @@
 
 /// Get the tenant name from the environment variables
 ///
-/// Derive the tenant name from the MARATHON_APP_ID or DSH_TENANT_NAME environment variables.
+/// Derive the tenant name from the `MARATHON_APP_ID` or `DSH_TENANT_NAME` environment variables.
 /// Returns `NoTenantName` error if neither of the environment variables are set.
-pub(crate) fn tenant_name() -> Result<String, DshError> {
+///
+/// ## Example
+/// ```
+/// # use dsh_sdk::utils::tenant_name;
+/// # use dsh_sdk::error::DshError;
+/// std::env::set_var("MARATHON_APP_ID", "/dsh-tenant-name/app-name"); // Injected by DSH by default
+///
+/// let tenant = tenant_name().unwrap();
+/// assert_eq!(&tenant, "dsh-tenant-name");
+/// # std::env::remove_var("MARATHON_APP_ID");
+///
+/// std::env::set_var("DSH_TENANT_NAME", "your-tenant-name"); // Set by user, useful when running outside of DSH together with Kafka Proxy or VPN
+/// let tenant = tenant_name().unwrap();
+/// assert_eq!(&tenant, "your-tenant-name");
+/// # std::env::remove_var("DSH_TENANT_NAME");
+///
+/// // If neither of the environment variables is set, it returns an error
+/// let result = tenant_name();
+/// assert!(matches!(result, Err(DshError::NoTenantName)));
+/// ```
+pub fn tenant_name() -> Result<String, DshError> {
     if let Ok(app_id) = get_env_var(VAR_APP_ID) {
         let tenant_name = app_id.split('/').nth(1);
         match tenant_name {
@@ -160,6 +204,7 @@
     } else if let Ok(tenant_name) = get_env_var(VAR_DSH_TENANT_NAME) {
         Ok(tenant_name)
     } else {
+        log::error!("{} and {} are not set, this may cause unexpected behaviour when connecting to DSH Kafka cluster! 
Please set one of these environment variables.", VAR_DSH_TENANT_NAME, VAR_APP_ID); Err(DshError::NoTenantName) } } @@ -174,7 +219,7 @@ pub(crate) fn get_env_var(var_name: &str) -> Result { Ok(value) => Ok(value), Err(e) => { info!("{} is not set", var_name); - Err(e.into()) + Err(DshError::EnvVarError(var_name.to_string(), e)) } } } diff --git a/example_dsh_service/src/main.rs b/example_dsh_service/src/main.rs index 474b7da..71a6704 100644 --- a/example_dsh_service/src/main.rs +++ b/example_dsh_service/src/main.rs @@ -64,7 +64,7 @@ async fn main() -> Result<(), Box> { // Validate your configured topic if it has read access (optional) dsh_properties .datastream() - .verify_list_of_topics(&topics, dsh_sdk::dsh::datastream::ReadWriteAccess::Read)?; + .verify_list_of_topics(&topics, dsh_sdk::dsh_old::datastream::ReadWriteAccess::Read)?; // Initialize the shutdown handler (This will handle SIGTERM and SIGINT signals, and you can act on them) let shutdown = Shutdown::new(); From ee2bfb2b09b730d1c20f243e02eb1ee32cc78485 Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Tue, 24 Dec 2024 15:38:22 +0100 Subject: [PATCH 02/23] Feature/schema store api (#96) * Add schema store api --- .github/workflows/branch.yaml | 2 +- .gitignore | 2 + dsh_sdk/CHANGELOG.md | 4 + dsh_sdk/Cargo.toml | 23 +- dsh_sdk/src/certificates/mod.rs | 41 +- dsh_sdk/src/datastream.rs | 3 + dsh_sdk/src/dsh.rs | 43 +- dsh_sdk/src/dsh_old/properties.rs | 6 +- dsh_sdk/src/error.rs | 2 +- dsh_sdk/src/lib.rs | 5 +- dsh_sdk/src/schema_store/api.rs | 431 ++++++++++++++++++ dsh_sdk/src/schema_store/client.rs | 371 +++++++++++++++ dsh_sdk/src/schema_store/error.rs | 28 ++ dsh_sdk/src/schema_store/mod.rs | 70 +++ dsh_sdk/src/schema_store/request.rs | 133 ++++++ .../src/schema_store/types/compatibility.rs | 69 +++ dsh_sdk/src/schema_store/types/mod.rs | 12 + dsh_sdk/src/schema_store/types/schema/id.rs | 25 + dsh_sdk/src/schema_store/types/schema/mod.rs | 10 + .../schema_store/types/schema/raw_schema.rs | 325 +++++++++++++ .../src/schema_store/types/schema/schema.rs | 44 ++ dsh_sdk/src/schema_store/types/schema/type.rs | 41 ++ .../schema_store/types/subject_strategy.rs | 319 +++++++++++++ dsh_sdk/src/schema_store/types/subjects.rs | 151 ++++++ dsh_sdk/src/utils/http_client.rs | 110 +++++ dsh_sdk/src/utils/mod.rs | 2 + example_dsh_service/Cargo.toml | 2 +- 27 files changed, 2224 insertions(+), 50 deletions(-) create mode 100644 dsh_sdk/src/schema_store/api.rs create mode 100644 dsh_sdk/src/schema_store/client.rs create mode 100644 dsh_sdk/src/schema_store/error.rs create mode 100644 dsh_sdk/src/schema_store/mod.rs create mode 100644 dsh_sdk/src/schema_store/request.rs create mode 100644 dsh_sdk/src/schema_store/types/compatibility.rs create mode 100644 dsh_sdk/src/schema_store/types/mod.rs create mode 100644 dsh_sdk/src/schema_store/types/schema/id.rs create mode 100644 dsh_sdk/src/schema_store/types/schema/mod.rs create mode 100644 dsh_sdk/src/schema_store/types/schema/raw_schema.rs create mode 100644 dsh_sdk/src/schema_store/types/schema/schema.rs create mode 100644 dsh_sdk/src/schema_store/types/schema/type.rs create mode 100644 dsh_sdk/src/schema_store/types/subject_strategy.rs create mode 100644 dsh_sdk/src/schema_store/types/subjects.rs create mode 100644 dsh_sdk/src/utils/http_client.rs diff --git a/.github/workflows/branch.yaml b/.github/workflows/branch.yaml index 9ca519e..dd0c600 100644 --- a/.github/workflows/branch.yaml +++ b/.github/workflows/branch.yaml @@ -57,5 +57,5 @@ jobs: run: cargo 
install cargo-hack --locked if: matrix.version == 'stable' - name: cargo check all features - run: cargo hack check --feature-powerset --no-dev-deps + run: cargo hack check --feature-powerset --depth 3 --no-dev-deps if: matrix.version == 'stable' diff --git a/.gitignore b/.gitignore index 22633c6..d581112 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ test_files/ *.iws *.iml *.ipr + +.env \ No newline at end of file diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index c0eac99..4548363 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add support reading private key in DER format when reading from PKI_CONFIG_DIR ### Changed +- **Breaking change:** `dsh_sdk::Dsh::reqwest_client_config` now returns `reqwest::ClientConfig` instead of `Result` +- **Breaking change:** `dsh_sdk::Dsh::reqwest_blocking_client_config` now returns `reqwest::ClientConfig` instead of `Result` + +### Moved - Moved `dsh_sdk::dsh::properties` to `dsh_sdk::propeties` - Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher` - **NOTE** Cargo.toml feature flag falls now under `management_api` (`rest-token-fetcher` will be removed in v0.6.0) diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 0dfa5b3..7dd35be 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -9,43 +9,51 @@ license.workspace = true name = "dsh_sdk" readme = 'README.md' repository.workspace = true -version = "0.4.10" +version = "0.5.0" [package.metadata.docs.rs] all-features = true [dependencies] +apache-avro = {version = "0.17", optional = true } base64 = {version = "0.22", optional = true } bytes = { version = "1.6", optional = true } http-body-util = { version = "0.1", optional = true } -hyper = { version = "1.3", features = ["server", "http1"], optional = true } -hyper-util = { version = "0.1", features = ["tokio"], optional = true } +hyper = { version = "1.5", features = ["client", "http1"], optional = true } +hyper-util = { version = "0.1", features = ["tokio","client-legacy"], optional = true } +hyper-rustls = { version = "0.27",features = ["ring","http1", "native-tokio", "logging"], default-features = false, optional = true } +http = { version = "1.2", optional = true } +rustls = { version = "0.23", features = ["ring", "tls12", "logging"], default-features = false, optional = true } +rustls-pemfile = { version = "2.2", optional = true } lazy_static = { version = "1.5", optional = true } log = "0.4" pem = {version = "3", optional = true } prometheus = { version = "0.13", features = ["process"], optional = true } +protofish = { version = "0.5.2", optional = true } rcgen = { version = "0.13", optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "blocking"], optional = true } rdkafka = { version = "0.36", features = ["cmake-build"], optional = true } serde = { version = "1.0", features = ["derive"] } -serde_json = { version = "1.0", optional = true } +serde_json = { version = "1.0", features = ["preserve_order"], optional = true } sha2 = { version = "0.10", optional = true} thiserror = "2.0" tokio = { version = "^1.35", features = ["signal", "sync", "time", "macros"], optional = true } tokio-util = { version = "0.7", default-features = false, optional = true } [features] -default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl"] +default = 
["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl", "schema-store"] bootstrap = ["certificate", "serde_json", "tokio/rt-multi-thread"] certificate = ["rcgen", "reqwest", "pem"] -metrics = ["prometheus", "hyper", "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] +schema-store = ["bootstrap", "reqwest", "serde_json", "apache-avro", "protofish"] +metrics = ["prometheus", "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] dlq = ["tokio", "bootstrap", "rdkafka-ssl", "graceful-shutdown"] graceful-shutdown = ["tokio", "tokio-util"] management-api = ["reqwest"] protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] # http-protocol-adapter = ["protocol-token-fetcher"] # mqtt-protocol-adapter = ["protocol-token-fetcher"] +# hyper-client = ["hyper", "hyper-util", "hyper-rustls", "rustls", "http", "rustls-pemfile"] # TODO: Remove the following features at v0.6.0 rdkafka-ssl-vendored = ["rdkafka", "rdkafka/ssl-vendored", "rdkafka/cmake-build"] @@ -60,4 +68,5 @@ tokio = { version = "^1.35", features = ["full"] } hyper = { version = "1.3", features = ["full"]} serial_test = "3.1.0" dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.2.0"} -dsh_sdk = { features = ["dlq"], path = "." } \ No newline at end of file +dsh_sdk = { features = ["dlq"], path = "." } +env_logger = "0.11" \ No newline at end of file diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs index c917399..4caea9e 100644 --- a/dsh_sdk/src/certificates/mod.rs +++ b/dsh_sdk/src/certificates/mod.rs @@ -134,30 +134,38 @@ impl Cert { /// Build an async reqwest client with the DSH Kafka certificate included. /// With this client we can retrieve datastreams.json and conenct to Schema Registry. - pub fn reqwest_client_config(&self) -> Result { + #[deprecated( + since = "0.5.0", + note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead" + )] + pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder { let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client( self.dsh_kafka_certificate_pem(), &self.private_key_pem(), self.dsh_ca_certificate_pem(), - )?; - Ok(reqwest::Client::builder() + ); + reqwest::Client::builder() .add_root_certificate(reqwest_cert) .identity(pem_identity) - .use_rustls_tls()) + .use_rustls_tls() } /// Build a reqwest client with the DSH Kafka certificate included. /// With this client we can retrieve datastreams.json and conenct to Schema Registry. - pub fn reqwest_blocking_client_config(&self) -> Result { + #[deprecated( + since = "0.5.0", + note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead" + )] + pub fn reqwest_blocking_client_config(&self) -> ClientBuilder { let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client( self.dsh_kafka_certificate_pem(), &self.private_key_pem(), self.dsh_ca_certificate_pem(), - )?; - Ok(Client::builder() + ); + Client::builder() .add_root_certificate(reqwest_cert) .identity(pem_identity) - .use_rustls_tls()) + .use_rustls_tls() } /// Get the root certificate as PEM string. Equivalent to ca.crt. @@ -231,15 +239,21 @@ impl Cert { reqwest::Identity::from_pem(&ident) } + /// Panics when the certificate or key is not valid. + /// However, these are already validated during the creation of the `Cert` struct and converted if nedded. 
fn prepare_reqwest_client( kafka_certificate: &str, private_key: &str, ca_certificate: &str, - ) -> Result<(reqwest::Identity, reqwest::tls::Certificate), DshError> { + ) -> (reqwest::Identity, reqwest::tls::Certificate) { let pem_identity = - Cert::create_identity(kafka_certificate.as_bytes(), private_key.as_bytes())?; - let reqwest_cert = reqwest::tls::Certificate::from_pem(ca_certificate.as_bytes())?; - Ok((pem_identity, reqwest_cert)) + Cert::create_identity(kafka_certificate.as_bytes(), private_key.as_bytes()).expect( + "Error creating identity. The kafka certificate or key is not valid. Please check the certificate and key.", + ); + let reqwest_cert = reqwest::tls::Certificate::from_pem(ca_certificate.as_bytes()).expect( + "Error parsing CA certificate as PEM to be used in Reqwest. The certificate is not valid. Please check the certificate.", + ); + (pem_identity, reqwest_cert) } } @@ -358,21 +372,18 @@ mod tests { &cert.private_key_pem(), cert.dsh_ca_certificate_pem(), ); - assert!(result.is_ok()); } #[test] fn test_reqwest_client_config() { let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); let client = cert.reqwest_client_config(); - assert!(client.is_ok()); } #[test] fn test_reqwest_blocking_client_config() { let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); let client = cert.reqwest_blocking_client_config(); - assert!(client.is_ok()); } #[test] diff --git a/dsh_sdk/src/datastream.rs b/dsh_sdk/src/datastream.rs index 3687268..99c020c 100644 --- a/dsh_sdk/src/datastream.rs +++ b/dsh_sdk/src/datastream.rs @@ -242,6 +242,9 @@ impl Datastream { if let Ok(brokers) = utils::get_env_var(VAR_KAFKA_BOOTSTRAP_SERVERS) { datastream.brokers = brokers.split(',').map(|s| s.to_string()).collect(); } + if let Ok(schema_store) = utils::get_env_var(VAR_SCHEMA_REGISTRY_HOST) { + datastream.schema_store = schema_store; + } Ok(datastream) } } diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs index dc3f3e9..a948b85 100644 --- a/dsh_sdk/src/dsh.rs +++ b/dsh_sdk/src/dsh.rs @@ -151,8 +151,8 @@ impl Dsh { }; let fetched_datastreams = certificates.as_ref().and_then(|cert| { cert.reqwest_blocking_client_config() + .build() .ok() - .and_then(|cb| cb.build().ok()) .and_then(|client| { Datastream::fetch_blocking(&client, &config_host, &tenant_name, &task_id).ok() }) @@ -169,8 +169,6 @@ impl Dsh { /// Get reqwest async client config to connect to DSH Schema Registry. /// If certificates are present, it will use SSL to connect to Schema Registry. /// - /// Use [schema_registry_converter](https://crates.io/crates/schema_registry_converter) to connect to Schema Registry. 
- /// /// # Example /// ``` /// # use dsh_sdk::Dsh; @@ -178,16 +176,20 @@ impl Dsh { /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// let dsh_properties = Dsh::get(); - /// let client = dsh_properties.reqwest_client_config()?.build()?; + /// let client = dsh_properties.reqwest_client_config().build()?; /// # Ok(()) /// # } /// ``` - pub fn reqwest_client_config(&self) -> Result { - let mut client_builder = reqwest::Client::builder(); + #[deprecated( + since = "0.5.0", + note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead" + )] + pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder { if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_client_config()?; + certificates.reqwest_client_config() + } else { + reqwest::Client::builder() } - Ok(client_builder) } /// Get reqwest blocking client config to connect to DSH Schema Registry. @@ -202,18 +204,20 @@ impl Dsh { /// # use dsh_sdk::error::DshError; /// # fn main() -> Result<(), DshError> { /// let dsh_properties = Dsh::get(); - /// let client = dsh_properties.reqwest_blocking_client_config()?.build()?; + /// let client = dsh_properties.reqwest_blocking_client_config().build()?; /// # Ok(()) /// # } - pub fn reqwest_blocking_client_config( - &self, - ) -> Result { - let mut client_builder: reqwest::blocking::ClientBuilder = - reqwest::blocking::Client::builder(); + /// ``` + #[deprecated( + since = "0.5.0", + note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead" + )] + pub fn reqwest_blocking_client_config(&self) -> reqwest::blocking::ClientBuilder { if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_blocking_client_config()?; + certificates.reqwest_blocking_client_config() + } else { + reqwest::blocking::Client::builder() } - Ok(client_builder) } /// Get the certificates and private key. Returns an error when running on local machine. 
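The two hunks above change the reqwest config methods to return a `ClientBuilder` directly, so `build()` becomes the only fallible step. A minimal sketch of the resulting call pattern (hedged: assumes the `Dsh::get()` accessor shown in the doc examples above; both methods are deprecated in 0.5.0 in favor of `dsh_sdk::schema_store::SchemaStoreClient`):

```rust
use dsh_sdk::Dsh;

fn blocking_client() -> Result<reqwest::blocking::Client, reqwest::Error> {
    // The builder itself can no longer fail; only build() returns a Result.
    Dsh::get().reqwest_blocking_client_config().build()
}
```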
@@ -227,6 +231,7 @@ impl Dsh { /// let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem(); /// # Ok(()) /// # } + /// ``` pub fn certificates(&self) -> Result<&Cert, DshError> { if let Some(cert) = &self.certificates { Ok(cert) @@ -271,7 +276,6 @@ impl Dsh { let client = ASYNC_CLIENT.get_or_init(|| { self.reqwest_client_config() - .expect("Failed loading certificates into reqwest client config") .build() .expect("Could not build reqwest client for fetching datastream") }); @@ -292,7 +296,6 @@ impl Dsh { let client = BLOCKING_CLIENT.get_or_init(|| { self.reqwest_blocking_client_config() - .expect("Failed loading certificates into reqwest client config") .build() .expect("Could not build reqwest client for fetching datastream") }); @@ -511,8 +514,8 @@ mod tests { #[serial(env_dependency)] fn test_reqwest_client_config() { let properties = Dsh::default(); - let config = properties.reqwest_client_config(); - assert!(config.is_ok()); + let _ = properties.reqwest_client_config(); + assert!(true) } #[test] diff --git a/dsh_sdk/src/dsh_old/properties.rs b/dsh_sdk/src/dsh_old/properties.rs index 32e347d..3e8a971 100644 --- a/dsh_sdk/src/dsh_old/properties.rs +++ b/dsh_sdk/src/dsh_old/properties.rs @@ -149,8 +149,8 @@ impl Properties { }; let fetched_datastreams = certificates.as_ref().and_then(|cert| { cert.reqwest_blocking_client_config() + .build() .ok() - .and_then(|cb| cb.build().ok()) .and_then(|client| { datastream::Datastream::fetch_blocking( &client, @@ -189,7 +189,7 @@ impl Properties { pub fn reqwest_client_config(&self) -> Result { let mut client_builder = reqwest::Client::builder(); if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_client_config()?; + client_builder = certificates.reqwest_client_config(); } Ok(client_builder) } @@ -215,7 +215,7 @@ impl Properties { let mut client_builder: reqwest::blocking::ClientBuilder = reqwest::blocking::Client::builder(); if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_blocking_client_config()?; + client_builder = certificates.reqwest_blocking_client_config(); } Ok(client_builder) } diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs index 1996b1d..ee8d5e4 100644 --- a/dsh_sdk/src/error.rs +++ b/dsh_sdk/src/error.rs @@ -21,7 +21,7 @@ pub enum DshError { #[cfg(feature = "bootstrap")] #[error("Certificates are not set")] NoCertificates, - #[cfg(any(feature = "bootstrap", feature = "pki-config-dir"))] + #[cfg(feature = "bootstrap")] #[error("Invalid PEM certificate: {0}")] PemError(#[from] pem::PemError), #[cfg(any(feature = "certificate", feature = "protocol-token-fetcher"))] diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index 47e58b3..b2fd253 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -72,8 +72,6 @@ //! The DLQ is implemented by running the `Dlq` struct to push messages towards the DLQ topics. //! The `ErrorToDlq` trait can be implemented on your defined errors, to be able to send messages towards the DLQ Struct. 
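Implementing `ErrorToDlq` follows the same pattern as the `MockError` used in the `utils/dlq.rs` tests above; a minimal sketch on a custom error type (hedged: `ConsumerError` and the `thiserror` derive are illustrative, not part of the SDK):

```rust
use dsh_sdk::utils::dlq::{ErrorToDlq, Retryable, SendToDlq};
use rdkafka::message::OwnedMessage;

#[derive(Debug, thiserror::Error)]
enum ConsumerError {
    #[error("deserialization failed: {0}")]
    Deserialize(String),
}

impl ErrorToDlq for ConsumerError {
    fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq {
        // SendToDlq::new(message, retryable, error string, optional stack trace)
        SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None)
    }

    fn retryable(&self) -> Retryable {
        match self {
            ConsumerError::Deserialize(_) => Retryable::NonRetryable,
        }
    }
}
```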
-#![allow(deprecated)] - // to be kept in v0.6.0 #[cfg(feature = "certificate")] pub mod certificates; @@ -87,6 +85,9 @@ pub mod management_api; pub mod protocol_adapters; pub mod utils; +#[cfg(feature = "schema-store")] +pub mod schema_store; + #[cfg(feature = "bootstrap")] #[doc(inline)] pub use dsh::Dsh; diff --git a/dsh_sdk/src/schema_store/api.rs b/dsh_sdk/src/schema_store/api.rs new file mode 100644 index 0000000..47f98a2 --- /dev/null +++ b/dsh_sdk/src/schema_store/api.rs @@ -0,0 +1,431 @@ +use super::types::*; + +use super::request::Request; +use super::Result; + +use super::SchemaStoreClient; + +/// Low level SchemaStoreApi trait +/// +/// This trait follows the definition of the Schema Registry API as stated in the OpenAPI specification. +/// It is recomonded to only use the high level [SchemaStoreClient]. +pub trait SchemaStoreApi { + /// Get glabal compatibility level + /// + /// {base_url}/config/{subject} + async fn get_config_subject(&self, subject: String) -> Result; + /// Set compatibility on subject level. With 1 schema stored in the subject, you can change it to any compatibility level. Else, you can only change into a less restrictive level. Must be one of BACKWARD, BACKWARD_TRANSITIVE, FORWARD, FORWARD_TRANSITIVE, FULL, FULL_TRANSITIVE, NONE + /// + /// {base_url}/config/{subject} + async fn put_config_subject(&self, subject: String, body: Compatibility) -> Result; + /// Get a list of registered subjects + /// + /// {base_url}/subjects + async fn get_subjects(&self) -> Result>; + + /// Check if a schema has already been registered under the specified subject. + /// If so, this returns the schema string along with its globally unique identifier, + /// its version under this subject and the subject name: \"{ \\\"schema\\\": {\\\"name\\\": \\\"username\\\", \\\"type\\\": \\\"string\\\"} }\" + /// + /// {base_url}/subjects/{subject} + async fn post_subjects_subject( + &self, + subject: String, + body: RawSchemaWithType, + ) -> Result; + + /// Get a list of versions registered under the specified subject. + /// + /// {base_url}/subjects/{subject} + async fn get_subjects_subject_versions(&self, subject: String) -> Result>; + + /// Get a specific version of the schema registered under this subject. + /// + /// subjects/{subject}/versions/{id} + async fn get_subjects_subject_versions_id( + &self, + subject: String, + id: String, + ) -> Result; + + /// Register a new schema under the specified subject. + /// + /// If successfully registered, this returns the unique identifier of this schema in the registry. + /// The returned identifier should be used to retrieve this schema from the schemas resource and is different from the schema’s version which is associated with the subject. + /// If the same schema is registered under a different subject, the same identifier will be returned. + /// However, the version of the schema may be different under different subjects. + /// A schema should be compatible with the previously registered schema or schemas (if there are any) as per the configured compatibility level. + /// The configured compatibility level can be obtained by issuing a GET http:get:: /config/(string: subject). + /// If that returns null, then GET http:get:: /config. + /// + /// {base_url}/subjects/{subject}/versions + async fn post_subjects_subject_versions( + &self, + subject: String, + body: RawSchemaWithType, + ) -> Result; + + /// Test input schema against a particular version of a subject’s schema for compatibility. 
+ /// Note that the compatibility level applied for the check is the configured compatibility level for the subject (GET /config/(string: subject)). + /// If this subject’s compatibility level was never changed, then the global compatibility level applies (GET /config). + /// + /// {base_url}/compatibility/subjects/{subject}/versions/{id} + async fn post_compatibility_subjects_subject_versions_id( + &self, + subject: String, + id: String, + body: RawSchemaWithType, + ) -> Result; + + /// "Get the schema for the specified version of this subject. The unescaped schema only is returned. + /// + /// {base_url}/subjects/{subject}/versions/{id}/schema + async fn get_subjects_subject_versions_id_schema( + &self, + subject: String, + version_id: String, + ) -> Result; + + /// Get the schema for the specified version of schema. + /// + /// {base_url}/schemas/ids/{id} + async fn get_schemas_ids_id(&self, id: i32) -> Result; + + /// Get the related subjects vesrion for the specified schema. + /// + /// {base_url}/schemas/ids/{id}/versions + async fn get_schemas_ids_id_versions(&self, id: i32) -> Result>; +} + +impl SchemaStoreApi for SchemaStoreClient +where + C: Request, +{ + async fn get_config_subject(&self, subject: String) -> Result { + let url = format!("{}/config/{}", self.base_url, subject); + Ok(self.client.get_request(url).await?) + } + + async fn put_config_subject(&self, subject: String, body: Compatibility) -> Result { + let url = format!("{}/config/{}", self.base_url, subject); + Ok(self.client.put_request(url, body).await?) + } + + async fn get_subjects(&self) -> Result> { + let url = format!("{}/subjects", self.base_url); + Ok(self.client.get_request(url).await?) + } + + async fn post_subjects_subject( + &self, + subject: String, + body: RawSchemaWithType, + ) -> Result { + let url = format!("{}/subjects/{}", self.base_url, subject); + Ok(self.client.post_request(url, body).await?) + } + + async fn get_subjects_subject_versions(&self, subject: String) -> Result> { + let url = format!("{}/subjects/{}/versions", self.base_url, subject); + Ok(self.client.get_request(url).await?) + } + + async fn get_subjects_subject_versions_id( + &self, + subject: String, + version_id: String, + ) -> Result { + let url = format!( + "{}/subjects/{}/versions/{}", + self.base_url, subject, version_id + ); + Ok(self.client.get_request(url).await?) + } + + async fn post_subjects_subject_versions( + &self, + subject: String, + body: RawSchemaWithType, + ) -> Result { + let url = format!("{}/subjects/{}/versions", self.base_url, subject); + Ok(self.client.post_request(url, body).await?) + } + + async fn post_compatibility_subjects_subject_versions_id( + &self, + subject: String, + version_id: String, + body: RawSchemaWithType, + ) -> Result { + let url = format!( + "{}/compatibility/subjects/{}/versions/{}", + self.base_url, subject, version_id + ); + Ok(self.client.post_request(url, body).await?) + } + + async fn get_subjects_subject_versions_id_schema( + &self, + subject: String, + version_id: String, + ) -> Result { + let url = format!( + "{}/subjects/{}/versions/{}/schema", + self.base_url, subject, version_id + ); + Ok(self.client.get_request_plain(url).await?) + } + + async fn get_schemas_ids_id(&self, id: i32) -> Result { + let url = format!("{}/schemas/ids/{}", self.base_url, id); + Ok(self.client.get_request(url).await?) 
+ } + + async fn get_schemas_ids_id_versions(&self, id: i32) -> Result> { + let url = format!("{}/schemas/ids/{}/versions", self.base_url, id); + Ok(self.client.get_request(url).await?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_config_subject() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/config/test-value") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"compatibilityLevel":"FULL"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .get_config_subject("test-value".to_string()) + .await + .unwrap(); + assert_eq!(result.compatibility_level, Compatibility::FULL); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_put_config_subject() { + let mut ss = mockito::Server::new_async().await; + ss.mock("PUT", "/config/test-value") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"compatibility":"FULL"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .put_config_subject("test-value".to_string(), Compatibility::FULL) + .await + .unwrap(); + assert_eq!(result.compatibility, Compatibility::FULL); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_subjects() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/subjects") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"["test-value", "topic-key"]"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client.get_subjects().await.unwrap(); + assert_eq!(result, vec!["test-value", "topic-key"]); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_post_subjects_subject() { + let mut ss = mockito::Server::new_async().await; + let avro_schema = RawSchemaWithType { + schema_type: SchemaType::AVRO, + schema: r#"{"type":"string"}"#.to_string(), + }; + ss.mock( "POST", "/subjects/test-value") + .match_body(mockito::Matcher::Json(serde_json::to_value(&avro_schema).unwrap())) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"subject":"test-value", "version":1, "id":1, "schema":"{\"type\":\"string\"}"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .post_subjects_subject("test-value".to_string(), avro_schema) + .await + .unwrap(); + assert_eq!(result.subject, "test-value"); + assert_eq!(result.version, 1); + assert_eq!(result.id, 1); + assert_eq!(result.schema, r#"{"type":"string"}"#); + assert_eq!(result.schema_type, SchemaType::AVRO); + + let proto_schema = RawSchemaWithType { + schema_type: SchemaType::PROTOBUF, + schema: r#"syntax = "proto3";package com.kpn.protobuf;message SimpleMessage {string content = 1;string date_time = 2;}"#.to_string(), + }; + ss.mock( "POST", "/subjects/protobuf-value") + .match_body(mockito::Matcher::Json(serde_json::to_value(&proto_schema).unwrap())) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"subject":"protobuf-value", "version":1, "id":1, "schemaType":"PROTOBUF", "schema":"syntax = \"proto3\";package com.kpn.protobuf;message SimpleMessage {string content = 1;string date_time = 2;}"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .post_subjects_subject("protobuf-value".to_string(), proto_schema) + .await + .unwrap(); + 
assert_eq!(result.subject, "protobuf-value"); + assert_eq!(result.version, 1); + assert_eq!(result.id, 1); + assert_eq!( + result.schema, + r#"syntax = "proto3";package com.kpn.protobuf;message SimpleMessage {string content = 1;string date_time = 2;}"# + ); + assert_eq!(result.schema_type, SchemaType::PROTOBUF); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_subjects_subject_versions() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/subjects/test-value/versions") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"[1, 2, 3]"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .get_subjects_subject_versions("test-value".to_string()) + .await + .unwrap(); + assert_eq!(result, vec![1, 2, 3]); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_subjects_subject_versions_id() { + let mut ss = mockito::Server::new_async().await; + ss.mock( "GET", "/subjects/test-value/versions/1") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"subject":"test-value", "version":1, "id":1, "schema":"{\"type\":\"string\"}"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .get_subjects_subject_versions_id("test-value".to_string(), "1".to_string()) + .await + .unwrap(); + assert_eq!(result.subject, "test-value"); + assert_eq!(result.version, 1); + assert_eq!(result.id, 1); + assert_eq!(result.schema, r#"{"type":"string"}"#); + assert_eq!(result.schema_type, SchemaType::AVRO); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_post_subjects_subject_versions() { + let mut ss = mockito::Server::new_async().await; + let avro_schema = RawSchemaWithType { + schema_type: SchemaType::AVRO, + schema: r#"{"type":"string"}"#.to_string(), + }; + ss.mock("POST", "/subjects/test-value/versions") + .match_body(mockito::Matcher::Json( + serde_json::to_value(&avro_schema).unwrap(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"id":1}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .post_subjects_subject_versions("test-value".to_string(), avro_schema) + .await + .unwrap(); + assert_eq!(result.id, 1); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_post_compatibility_subjects_subject_versions_id() { + let mut ss = mockito::Server::new_async().await; + let avro_schema = RawSchemaWithType { + schema_type: SchemaType::AVRO, + schema: r#"{"type":"string"}"#.to_string(), + }; + ss.mock("POST", "/compatibility/subjects/test-value/versions/1") + .match_body(mockito::Matcher::Json( + serde_json::to_value(&avro_schema).unwrap(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"is_compatible":true}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .post_compatibility_subjects_subject_versions_id( + "test-value".to_string(), + "1".to_string(), + avro_schema, + ) + .await + .unwrap(); + assert_eq!(result.is_compatible(), true); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_subjects_subject_versions_id_schema() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/subjects/test-value/versions/1/schema") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"type":"string"}"#) + .create(); + let client = 
SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client + .get_subjects_subject_versions_id_schema("test-value".to_string(), "1".to_string()) + .await + .unwrap(); + assert_eq!(result, r#"{"type":"string"}"#); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_schemas_ids_id() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/schemas/ids/1") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"schema":"{\"type\":\"string\"}"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client.get_schemas_ids_id(1).await.unwrap(); + assert_eq!(result.schema, r#"{"type":"string"}"#); + assert_eq!(result.schema_type, SchemaType::AVRO); + + ss.mock("GET", "/schemas/ids/2") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"schema":"syntax = \"proto3\";package com.kpn.protobuf;message SimpleMessage {string content = 1;string date_time = 2;}", "schemaType": "PROTOBUF"}"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client.get_schemas_ids_id(2).await.unwrap(); + assert_eq!( + result.schema, + r#"syntax = "proto3";package com.kpn.protobuf;message SimpleMessage {string content = 1;string date_time = 2;}"# + ); + assert_eq!(result.schema_type, SchemaType::PROTOBUF); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_get_schemas_ids_id_versions() { + let mut ss = mockito::Server::new_async().await; + ss.mock("GET", "/schemas/ids/1/versions") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"[{"subject":"test-value", "version":1, "id":1, "schema":"{\"type\":\"string\"}"}]"#) + .create(); + let client = SchemaStoreClient::new_with_base_url(&ss.url()); + let result = client.get_schemas_ids_id_versions(1).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].subject, "test-value"); + assert_eq!(result[0].version, 1); + } +} diff --git a/dsh_sdk/src/schema_store/client.rs b/dsh_sdk/src/schema_store/client.rs new file mode 100644 index 0000000..91fc395 --- /dev/null +++ b/dsh_sdk/src/schema_store/client.rs @@ -0,0 +1,371 @@ +use super::api::SchemaStoreApi; +use super::request::Request; +use super::types::*; +use super::{Result, SchemaStoreError}; +use crate::Dsh; + +/// High level Schema Store Client +/// +/// Client to interact with the Schema Store API. 
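+///
+/// A minimal sketch of constructing a client (the local URL is an arbitrary placeholder):
+/// ```no_run
+/// use dsh_sdk::schema_store::SchemaStoreClient;
+///
+/// // On DSH, the base URL is resolved automatically from the datastreams info
+/// let client = SchemaStoreClient::new();
+///
+/// // For a local or proxied Schema Registry, provide the base URL yourself
+/// let client = SchemaStoreClient::new_with_base_url("http://localhost:8081");
+/// ```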
+pub struct SchemaStoreClient<C = reqwest::Client> {
+    pub(crate) base_url: String,
+    pub(crate) client: C,
+}
+
+impl SchemaStoreClient<reqwest::Client> {
+    pub fn new() -> Self {
+        Self::new_with_base_url(Dsh::get().schema_registry_host())
+    }
+
+    /// Create a SchemaStoreClient with a custom base URL
+    pub fn new_with_base_url(base_url: &str) -> Self {
+        Self {
+            base_url: base_url.trim_end_matches('/').to_string(),
+            client: Request::new_client(),
+        }
+    }
+}
+
+impl<C> SchemaStoreClient<C>
+where
+    C: Request,
+{
+    /// Get the compatibility level for a subject
+    ///
+    /// ## Returns
+    /// Returns a Result with the compatibility level of the given subject
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// println!("Config: {:?}", client.subject_compatibility("scratch.example-topic.tenant-value").await);
+    /// # }
+    /// ```
+    pub async fn subject_compatibility<Sn>(&self, subject: Sn) -> Result<Compatibility>
+    where
+        Sn: Into<SubjectName>,
+    {
+        Ok(self.get_config_subject(subject.into().name()).await?.into())
+    }
+
+    /// Set the compatibility level for a subject
+    ///
+    /// While only one schema is stored under the subject, you can change it to any compatibility level;
+    /// after that, you can only change to a less restrictive level.
+    ///
+    /// ## Returns
+    /// Returns a Result with the new compatibility level
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::Compatibility;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// client.subject_compatibility_update("scratch.example-topic.tenant-value", Compatibility::FULL).await.unwrap();
+    /// # }
+    /// ```
+    ///
+    /// TODO: untested, as this API method does not seem to work at all on DSH
+    pub async fn subject_compatibility_update<Sn>(
+        &self,
+        subject: Sn,
+        compatibility: Compatibility,
+    ) -> Result<Compatibility>
+    where
+        Sn: Into<SubjectName>,
+    {
+        Ok(self
+            .put_config_subject(subject.into().name(), compatibility)
+            .await?
+            .into())
+    }
+
+    /// Get a list of all registered subjects
+    ///
+    /// ## Returns
+    /// Returns a Result with all registered subjects from the schema registry
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// println!("Subjects: {:?}", client.subjects().await);
+    /// # }
+    /// ```
+    pub async fn subjects(&self) -> Result<Vec<String>> {
+        self.get_subjects().await
+    }
+
+    /// Get a list of all versions of a subject
+    ///
+    /// ## Returns
+    /// Returns a Result with all version IDs of a subject from the schema registry
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// println!("Available versions: {:?}", client.subject_versions("scratch.example-topic.tenant-value").await);
+    /// # }
+    /// ```
+    pub async fn subject_versions<Sn>(&self, subject: Sn) -> Result<Vec<i32>>
+    where
+        Sn: Into<SubjectName>,
+    {
+        self.get_subjects_subject_versions(subject.into().name())
+            .await
+    }
+
+    /// Get the subject for a specific version
+    ///
+    /// ## Returns
+    /// Returns a Result with the schema for the given subject and version
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::{SubjectName, SubjectVersion};
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    ///
+    /// // Get the latest version of the schema
+    /// let subject = client.subject(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}, SubjectVersion::Latest).await.unwrap();
+    /// let raw_schema = subject.schema;
+    ///
+    /// // Get a specific version of the schema
+    /// let subject = client.subject("scratch.example-topic.tenant-value", SubjectVersion::Version(1)).await.unwrap();
+    /// let raw_schema = subject.schema;
+    /// # }
+    /// ```
+    pub async fn subject<Sn, V>(&self, subject: Sn, version: V) -> Result<Subject>
+    where
+        // Sn: TryInto<SubjectName>,
+        Sn: Into<SubjectName>,
+        V: Into<SubjectVersion>,
+    {
+        let subject = subject.into().name();
+        let version = version.into();
+        self.get_subjects_subject_versions_id(subject, version.to_string())
+            .await
+    }
+
+    /// Get the raw schema string for the specified version of a subject.
+    ///
+    /// ## Arguments
+    /// - `subject`: Anything that can be converted into a [SubjectName] (returns an error on an invalid subject strategy)
+    /// - `version`: Anything that can be converted into a [SubjectVersion]
+    ///
+    /// ## Returns
+    /// Returns a Result with the raw schema string for the given subject and version
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// let raw_schema = client.subject_raw_schema("scratch.example-topic.tenant-value", 1).await.unwrap();
+    /// # }
+    /// ```
+    pub async fn subject_raw_schema<Sn, V>(&self, subject: Sn, version: V) -> Result<String>
+    where
+        Sn: Into<SubjectName>,
+        V: Into<SubjectVersion>,
+    {
+        self.get_subjects_subject_versions_id_schema(
+            subject.into().name(),
+            version.into().to_string(),
+        )
+        .await
+    }
+
+    /// Get all schemas for a subject
+    ///
+    /// ## Arguments
+    /// - `subject`: Anything that can be converted into a [SubjectName] (returns an error on an invalid subject strategy)
+    ///
+    /// ## Returns
+    /// Returns a Result with all schemas for the given subject
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::SubjectName;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// let subjects = client.subject_all_schemas(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}).await.unwrap();
+    /// # }
+    /// ```
+    pub async fn subject_all_schemas<Sn>(&self, subject: Sn) -> Result<Vec<Subject>>
+    where
+        Sn: Into<SubjectName> + Clone,
+    {
+        let versions = self.subject_versions(subject.clone()).await?;
+        let mut subjects = Vec::new();
+        for version in versions {
+            let subject = self.subject(subject.clone(), version).await?;
+            subjects.push(subject);
+        }
+        Ok(subjects)
+    }
+
+    /// Post a new schema for a (new) subject
+    ///
+    /// ## Errors
+    /// - If the given schema cannot be parsed as the given schema type
+    /// - The API call will return an error when
+    ///     - the subject already has a schema and its compatibility level does not allow the change
+    ///     - the subject already has a schema with a different schema type
+    ///     - the schema is invalid
+    ///
+    /// ## Arguments
+    /// - `subject`: Anything that can be converted into a [SubjectName] (returns an error on an invalid subject strategy)
+    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    ///
+    /// ## Returns
+    /// Returns a Result with the new schema ID.
+    /// If the schema already exists, it returns the existing schema ID.
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SchemaType};
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    ///
+    /// // You can provide the schema as a raw string (the schema type is optional; it will be detected automatically)
+    /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
+    /// let schema_version = client.subject_add_schema("scratch.example-topic.tenant-value", (raw_schema, SchemaType::AVRO)).await.unwrap();
+    ///
+    /// // Or if you have a schema object
+    /// let avro_schema = apache_avro::Schema::parse_str(raw_schema).unwrap(); // or a Protobuf or JSON schema
+    /// let schema_version = client.subject_add_schema("scratch.example-topic.tenant-value", avro_schema).await.unwrap();
+    /// # }
+    /// ```
+    pub async fn subject_add_schema<Sn, Sc>(&self, subject: Sn, schema: Sc) -> Result<i32>
+    where
+        Sn: Into<SubjectName>,
+        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
+    {
+        let schema = schema.try_into()?;
+        Ok(self
+            .post_subjects_subject_versions(subject.into().name(), schema)
+            .await?
+            .id())
+    }
+
+    /// Check if a schema has already been registered for a subject
+    ///
+    /// If it returns 404, the schema is not yet registered (even when the response states "unable to process")
+    ///
+    /// ## Errors
+    /// - If the given schema cannot be parsed as the given schema type
+    /// - The API call will return an error when
+    ///     - the provided schema is different from the registered one
+    ///     - the schema is invalid
+    ///
+    /// ## Arguments
+    /// - `subject`: Anything that can be converted into a [SubjectName] (returns an error on an invalid subject strategy)
+    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    ///
+    /// ## Returns
+    /// If the schema exists, it returns the existing version and schema ID.
+    ///
+    /// ## Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::{SubjectName, SchemaType};
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    ///
+    /// // You can provide the schema as a raw string (the schema type is optional; it will be detected automatically)
+    /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
+    /// let subject = client.subject_schema_exist("scratch.example-topic.tenant-value", (raw_schema, SchemaType::AVRO)).await.unwrap();
+    /// # }
+    /// ```
+    pub async fn subject_schema_exist<Sn, Sc>(&self, subject: Sn, schema: Sc) -> Result<Subject>
+    where
+        Sn: Into<SubjectName>,
+        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
+    {
+        let schema = schema.try_into()?;
+        self.post_subjects_subject(subject.into().name(), schema)
+            .await
+    }
+
+    /// Check if a schema is compatible with a specific version of a subject, based on the compatibility level
+    ///
+    /// Note that the compatibility level applied for the check is the configured compatibility level for the subject.
+    /// If this subject's compatibility level was never changed, the global compatibility level applies.
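+    ///
+    /// For example, a sketch of checking a candidate schema against version 1 of a subject
+    /// (the subject name and schema are placeholders):
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let client = SchemaStoreClient::new();
+    /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
+    /// let compatible = client.subject_new_schema_compatibility("scratch.example-topic.tenant-value", 1, raw_schema).await.unwrap();
+    /// # }
+    /// ```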
+    ///
+    /// ## Arguments
+    /// - `subject`: Anything that can be converted into a [SubjectName] (returns an error on an invalid subject strategy)
+    /// - `version`: Anything that can be converted into a [SubjectVersion]
+    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    ///
+    /// ## Returns
+    /// Returns a Result with a boolean indicating whether the schema is compatible with the given version of the subject
+    pub async fn subject_new_schema_compatibility<Sn, Sv, Sc>(
+        &self,
+        subject: Sn,
+        version: Sv,
+        schema: Sc,
+    ) -> Result<bool>
+    where
+        Sn: Into<SubjectName>,
+        Sv: Into<SubjectVersion>,
+        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
+    {
+        let schema = schema.try_into()?;
+        Ok(self
+            .post_compatibility_subjects_subject_versions_id(
+                subject.into().name(),
+                version.into().to_string(),
+                schema,
+            )
+            .await?
+            .is_compatible())
+    }
+
+    /// Get the schema based on its schema ID.
+    ///
+    /// ## Arguments
+    /// - `id`: The schema ID (Into<[i32]>)
+    pub async fn schema<Si>(&self, id: Si) -> Result<RawSchemaWithType>
+    where
+        Si: Into<i32>,
+    {
+        self.get_schemas_ids_id(id.into()).await
+    }
+
+    /// Get all subjects that use the given schema
+    ///
+    /// ## Arguments
+    /// - `id`: The schema ID (Into<[i32]>)
+    pub async fn schema_subjects<Si>(&self, id: Si) -> Result<Vec<SubjectVersionInfo>>
+    where
+        Si: Into<i32>,
+    {
+        self.get_schemas_ids_id_versions(id.into()).await
+    }
+}
diff --git a/dsh_sdk/src/schema_store/error.rs b/dsh_sdk/src/schema_store/error.rs
new file mode 100644
index 0000000..231f72c
--- /dev/null
+++ b/dsh_sdk/src/schema_store/error.rs
@@ -0,0 +1,28 @@
+use thiserror::Error;
+
+/// Error type for the SchemaStore
+#[derive(Error, Debug)]
+#[non_exhaustive]
+pub enum SchemaStoreError {
+    #[error("Reqwest error: {0}")]
+    ReqwestError(#[from] reqwest::Error),
+    #[error("SerdeJson error: {0}")]
+    SerdeJson(#[from] serde_json::Error),
+    #[error("Could not parse raw schema to a valid schema {:?}", .0)]
+    FailedToParseSchema(Option<super::types::SchemaType>),
+    #[error("Invalid status code: {status_code} for {url} ({error})")]
+    InvalidStatusCode {
+        status_code: u16,
+        url: String,
+        error: String,
+    },
+    #[error("Empty payload")]
+    EmptyPayload,
+    #[error("Failed to decode payload: {0}")]
+    FailedToDecode(String),
+    #[error("Failed to parse value onto struct")]
+    FailedParseToStruct,
+
+    #[error("Protobuf to struct not (yet) implemented")]
+    NotImplementedProtobufDeserialize,
+}
diff --git a/dsh_sdk/src/schema_store/mod.rs b/dsh_sdk/src/schema_store/mod.rs
new file mode 100644
index 0000000..fbd4659
--- /dev/null
+++ b/dsh_sdk/src/schema_store/mod.rs
@@ -0,0 +1,70 @@
+//! Schema Store client
+//!
+//! This module contains the SchemaStoreClient struct, which is the main entry point for interacting with the DSH Schema Registry API.
+//!
+//! It automatically connects to the Schema Registry API with the proper certificates and uses the base URL provided by the datastreams.json.
+//!
+//! When connecting via the proxy or to a local Schema Registry, you can provide the base URL yourself via the [SchemaStoreClient::new_with_base_url] function or by setting the `SCHEMA_REGISTRY_HOST` environment variable.
+//!
+//! ## Example
+//! ```no_run
+//! use dsh_sdk::schema_store::SchemaStoreClient;
+//! use dsh_sdk::schema_store::types::*;
+//!
+//! # #[tokio::main]
+//! # async fn main() {
+//! let client = SchemaStoreClient::new();
+//!
+//! // List all subjects
+//! let subjects = client.subjects().await.unwrap();
+//!
+//! // Get the latest version of a subject's value schema
+//! let subject = client.subject(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}, SubjectVersion::Latest).await.unwrap();
+//! let raw_schema = subject.schema;
+//! # }
+//! ```
+//!
+//! ## Input arguments
+//! Note that for all input types [TryInto] or [Into] is implemented. This means you can use the following types as input:
+//! ```
+//! use dsh_sdk::schema_store::types::*;
+//!
+//! // From the original type
+//! let from_struct = SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false};
+//!
+//! // From a string
+//! let from_str: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); // Note that `-value` is appended; otherwise it returns an error as it is not a valid SubjectName
+//! assert_eq!(from_str, from_struct);
+//!
+//! // From a tuple
+//! let from_tuple: SubjectName = ("scratch.example-topic.tenant", false).into();
+//! assert_eq!(from_tuple, from_struct);
+//! ```
+//!
+//! This means you can easily provide the input arguments from other types without converting them yourself.
+//! For example:
+//! ```no_run
+//! use dsh_sdk::schema_store::SchemaStoreClient;
+//! use dsh_sdk::schema_store::types::*;
+//!
+//! # #[tokio::main]
+//! # async fn main() {
+//! let client = SchemaStoreClient::new();
+//!
+//! let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
+//! client.subject_add_schema("scratch.example-topic.tenant-value", raw_schema).await.unwrap(); // Returns an error if the schema is not valid
+//! # }
+//! ```
+
+mod api;
+mod client;
+mod error;
+mod request;
+pub mod types;
+
+#[doc(inline)]
+pub use client::SchemaStoreClient;
+#[doc(inline)]
+pub use error::SchemaStoreError;
+
+type Result<T> = std::result::Result<T, SchemaStoreError>;
diff --git a/dsh_sdk/src/schema_store/request.rs b/dsh_sdk/src/schema_store/request.rs
new file mode 100644
index 0000000..29fa727
--- /dev/null
+++ b/dsh_sdk/src/schema_store/request.rs
@@ -0,0 +1,133 @@
+use log::trace;
+
+use crate::Dsh;
+
+use super::{Result, SchemaStoreError};
+
+const DEFAULT_CONTENT_TYPE: &str = "application/vnd.schemaregistry.v1+json";
+
+pub trait Request {
+    fn new_client() -> Self;
+    fn get_request<R>(&self, url: String) -> impl std::future::Future<Output = Result<R>> + Send
+    where
+        R: serde::de::DeserializeOwned;
+    fn get_request_plain(
+        &self,
+        url: String,
+    ) -> impl std::future::Future<Output = Result<String>> + Send;
+    fn post_request<R, B>(
+        &self,
+        url: String,
+        body: B,
+    ) -> impl std::future::Future<Output = Result<R>> + Send
+    where
+        R: serde::de::DeserializeOwned,
+        B: serde::Serialize + Send;
+    fn put_request<R, B>(
+        &self,
+        url: String,
+        body: B,
+    ) -> impl std::future::Future<Output = Result<R>> + Send
+    where
+        R: serde::de::DeserializeOwned,
+        B: serde::Serialize + Send;
+}
+
+impl Request for reqwest::Client {
+    fn new_client() -> Self {
+        // TODO: replace with hyper client
+        Dsh::get()
+            .reqwest_client_config()
+            .build()
+            .expect("Failed to build reqwest client")
+    }
+    async fn get_request<R>(&self, url: String) -> Result<R>
+    where
+        R: serde::de::DeserializeOwned,
+    {
+        trace!("GET {}", url);
+        let request = self
+            .get(&url)
+            .header("Content-Type", DEFAULT_CONTENT_TYPE)
+            .header("Accept", DEFAULT_CONTENT_TYPE);
+        let response = request.send().await?;
+        trace!("Response: {:?}", response);
+        if response.status().is_success() {
+            Ok(response.json().await?)
+ } else { + Err(SchemaStoreError::InvalidStatusCode { + status_code: response.status().as_u16(), + url: url.to_string(), + error: response.text().await.unwrap_or_default(), + }) + } + } + + async fn get_request_plain(&self, url: String) -> Result { + trace!("GET {}", url); + let request = self.get(&url); + let response = request.send().await?; + trace!("Response: {:?}", response); + if response.status().is_success() { + Ok(response.text().await?) + } else { + Err(SchemaStoreError::InvalidStatusCode { + status_code: response.status().as_u16(), + url: url.to_string(), + error: response.text().await.unwrap_or_default(), + }) + } + } + + /// Helper function to send a POST request and return the response with the expected type (serde with as JSON) + async fn post_request(&self, url: String, body: B) -> Result + where + R: serde::de::DeserializeOwned, + B: serde::Serialize + Send, + { + trace!("POST {}", url); + let json_body = serde_json::to_vec(&body)?; + let request = self + .post(&url) + .body(json_body) + .header("Content-Type", DEFAULT_CONTENT_TYPE) + .header("Accept", DEFAULT_CONTENT_TYPE); + let response = request.send().await?; + trace!("Response {:?}", response); + if response.status().is_success() { + Ok(response.json().await?) + } else { + Err(SchemaStoreError::InvalidStatusCode { + status_code: response.status().as_u16(), + url: url.to_string(), + error: response.text().await.unwrap_or_default(), + }) + } + } + + /// Helper function to send a PUT request and return the response with the expected type (serde with as JSON) + async fn put_request(&self, url: String, body: B) -> Result + where + R: serde::de::DeserializeOwned, + B: serde::Serialize + Send, + { + trace!("PUT {}", url); + let json_body = serde_json::to_vec(&body)?; + let request = self + .put(&url) + .body(json_body) + .header("Content-Type", DEFAULT_CONTENT_TYPE) + .header("Accept", DEFAULT_CONTENT_TYPE); + let response = request.send().await?; + trace!("Response {:?}", response); + if response.status().is_success() { + Ok(response.json().await?) 
+ } else { + Err(SchemaStoreError::InvalidStatusCode { + status_code: response.status().as_u16(), + url: url.to_string(), + error: response.text().await.unwrap_or_default(), + }) + } + } +} diff --git a/dsh_sdk/src/schema_store/types/compatibility.rs b/dsh_sdk/src/schema_store/types/compatibility.rs new file mode 100644 index 0000000..b3c7a73 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/compatibility.rs @@ -0,0 +1,69 @@ +use serde::{Deserialize, Serialize}; + +/// Schema compatibility level +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +pub enum Compatibility { + BACKWARD, + BACKWARD_TRANSITIVE, + FORWARD, + FORWARD_TRANSITIVE, + FULL, + FULL_TRANSITIVE, + NONE, +} + +/// Schema config containing compatibility level +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ConfigGet { + pub compatibility_level: Compatibility, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Schema config containing compatibility level +/// +/// For some reason the body is different compared from the get response +pub(crate) struct ConfigPut { + pub compatibility: Compatibility, +} + +/// Response from compatibility check +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct CompatibilityCheck { + pub is_compatible: bool, +} + +impl CompatibilityCheck { + pub fn is_compatible(&self) -> bool { + self.is_compatible + } +} + +impl From for Compatibility { + fn from(value: ConfigGet) -> Self { + value.compatibility_level + } +} + +impl From for Compatibility { + fn from(value: ConfigPut) -> Self { + value.compatibility + } +} + +impl std::fmt::Display for Compatibility { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BACKWARD => write!(f, "BACKWARD"), + Self::BACKWARD_TRANSITIVE => write!(f, "BACKWARD_TRANSITIVE"), + Self::FORWARD => write!(f, "FORWARD"), + Self::FORWARD_TRANSITIVE => write!(f, "FORWARD_TRANSITIVE"), + Self::FULL => write!(f, "FULL"), + Self::FULL_TRANSITIVE => write!(f, "FULL_TRANSITIVE"), + Self::NONE => write!(f, "NONE"), + } + } +} diff --git a/dsh_sdk/src/schema_store/types/mod.rs b/dsh_sdk/src/schema_store/types/mod.rs new file mode 100644 index 0000000..b3c2de2 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/mod.rs @@ -0,0 +1,12 @@ +//! Schema store types +//! +//! This module contains all types used to interact with the schema store. 
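+//!
+//! Most types can be built from plain strings and integers via [From]/[TryFrom].
+//! A small sketch (topic names are placeholders):
+//! ```
+//! use dsh_sdk::schema_store::types::{SubjectName, SubjectVersion};
+//!
+//! let subject: SubjectName = "scratch.example-topic.tenant-key".into();
+//! assert_eq!(subject.name(), "scratch.example-topic.tenant-key");
+//!
+//! let version: SubjectVersion = 3.into();
+//! assert_eq!(version, SubjectVersion::Version(3));
+//! ```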
+mod compatibility;
+mod schema;
+mod subject_strategy;
+mod subjects;
+
+pub use compatibility::*;
+pub use schema::*;
+pub use subject_strategy::*;
+pub use subjects::*;
diff --git a/dsh_sdk/src/schema_store/types/schema/id.rs b/dsh_sdk/src/schema_store/types/schema/id.rs
new file mode 100644
index 0000000..8f6f142
--- /dev/null
+++ b/dsh_sdk/src/schema_store/types/schema/id.rs
@@ -0,0 +1,25 @@
+use serde::{Deserialize, Serialize};
+
+/// Schema id
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct SchemaId {
+    pub id: i32,
+}
+
+impl SchemaId {
+    pub fn id(&self) -> i32 {
+        self.id
+    }
+}
+
+impl From<i32> for SchemaId {
+    fn from(value: i32) -> Self {
+        Self { id: value }
+    }
+}
+
+impl From<SchemaId> for i32 {
+    fn from(value: SchemaId) -> Self {
+        value.id
+    }
+}
diff --git a/dsh_sdk/src/schema_store/types/schema/mod.rs b/dsh_sdk/src/schema_store/types/schema/mod.rs
new file mode 100644
index 0000000..aca351c
--- /dev/null
+++ b/dsh_sdk/src/schema_store/types/schema/mod.rs
@@ -0,0 +1,10 @@
+mod id;
+mod raw_schema;
+// mod schema; // TODO: add schema and decoders
+mod r#type;
+
+pub use id::*;
+pub use r#type::*;
+pub use raw_schema::*;
+
+use super::Subject;
diff --git a/dsh_sdk/src/schema_store/types/schema/raw_schema.rs b/dsh_sdk/src/schema_store/types/schema/raw_schema.rs
new file mode 100644
index 0000000..1222ef4
--- /dev/null
+++ b/dsh_sdk/src/schema_store/types/schema/raw_schema.rs
@@ -0,0 +1,325 @@
+use super::SchemaType;
+use crate::schema_store::SchemaStoreError;
+use serde::{Deserialize, Serialize};
+
+/// Structure to post a (new) schema to a (new) subject, or to verify whether a schema already exists for a subject
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct RawSchemaWithType {
+    #[serde(default)]
+    pub schema_type: SchemaType,
+    pub schema: String,
+}
+
+impl RawSchemaWithType {
+    /// Create a [RawSchemaWithType] from a schema string.
+    ///
+    /// It returns an error if the string cannot be parsed into an Avro, JSON or Protobuf schema.
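+    ///
+    /// A short sketch (the type is detected automatically; this string parses as an Avro record, so AVRO is picked):
+    /// ```
+    /// use dsh_sdk::schema_store::types::{RawSchemaWithType, SchemaType};
+    ///
+    /// let raw_schema = r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#;
+    /// let schema = RawSchemaWithType::parse(raw_schema).unwrap();
+    /// assert_eq!(schema.schema_type(), SchemaType::AVRO);
+    /// ```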
+ pub fn parse(schema: S) -> Result + where + S: AsRef, + { + schema.as_ref().try_into() + } + + /// Raw schema string + pub fn schema(&self) -> &str { + &self.schema + } + + /// Schema type (AVRO, JSON, PROTOBUF) + pub fn schema_type(&self) -> SchemaType { + self.schema_type + } +} + +impl From for RawSchemaWithType { + fn from(value: super::Subject) -> Self { + Self { + schema_type: value.schema_type, + schema: value.schema, + } + } +} + +impl TryFrom for RawSchemaWithType { + type Error = SchemaStoreError; + + fn try_from(value: String) -> Result { + value.as_str().try_into() + } +} + +impl TryFrom<&str> for RawSchemaWithType { + type Error = SchemaStoreError; + + fn try_from(value: &str) -> Result { + if let Ok(avro) = apache_avro::Schema::parse_str(value) { + log::debug!("avro: {}", avro.canonical_form()); + Ok(Self { + schema_type: SchemaType::AVRO, + schema: avro.canonical_form(), + }) + } else if let Ok(json) = serde_json::from_str::(value) { + log::debug!("json: {:?}", json); + Ok(Self { + schema_type: SchemaType::JSON, + schema: json.to_string(), + }) + } else if let Ok(_) = protofish::context::Context::parse(&[value]) { + // TODO: Add parser for protobuf + log::debug!("protobuf: {:?}", value); + Ok(Self { + schema_type: SchemaType::PROTOBUF, + schema: value.to_string(), + }) + } else { + Err(SchemaStoreError::FailedToParseSchema(None)) + } + } +} + +impl TryFrom for RawSchemaWithType { + type Error = SchemaStoreError; + + fn try_from(value: apache_avro::Schema) -> Result { + Ok(Self { + schema_type: SchemaType::AVRO, + schema: value.canonical_form(), + }) + } +} + +impl TryFrom for RawSchemaWithType { + type Error = SchemaStoreError; + + fn try_from(value: serde_json::Value) -> Result { + Ok(if let Ok(avro) = apache_avro::Schema::parse(&value) { + Self { + schema_type: SchemaType::AVRO, + schema: avro.canonical_form(), + } + } else { + Self { + schema_type: SchemaType::JSON, + schema: value.to_string(), + } + }) + } +} + +impl TryFrom<(S, SchemaType)> for RawSchemaWithType +where + S: AsRef, +{ + type Error = SchemaStoreError; + + fn try_from(value: (S, SchemaType)) -> Result { + let schema_type = value.1; + let raw_schema = value.0.as_ref(); + let raw_schema = match schema_type { + SchemaType::JSON => { + let _ = serde_json::from_str::(raw_schema) + .map_err(|_| SchemaStoreError::FailedToParseSchema(Some(schema_type)))?; + raw_schema + } + SchemaType::AVRO => { + let _ = apache_avro::Schema::parse_str(raw_schema) + .map_err(|_| SchemaStoreError::FailedToParseSchema(Some(schema_type)))?; + raw_schema + } + SchemaType::PROTOBUF => { + let _ = protofish::context::Context::parse(&[raw_schema]) + .map_err(|_| SchemaStoreError::FailedToParseSchema(Some(schema_type)))?; + raw_schema + } + }; + Ok(Self { + schema_type, + schema: raw_schema.to_string(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use apache_avro::Schema as AvroSchema; + use serde_json::Value as JsonValue; + + #[test] + fn test_parse_avro() { + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::parse(raw_schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_parse_json() { + let raw_schema = r#"{"fields":[{"name":"name","type":"string"}],"name":"User"}"#; + let schema = RawSchemaWithType::parse(raw_schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::JSON); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn 
test_parse_protobuf() { + let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#; + let schema = RawSchemaWithType::parse(raw_schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::PROTOBUF); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_parse_invalid() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}"#; + let schema = RawSchemaWithType::parse(raw_schema); + assert!(schema.is_err()); + + let raw_schema = r#"not a schema"#; + let schema = RawSchemaWithType::parse(raw_schema); + assert!(schema.is_err()); + } + + #[test] + fn test_try_from_avro() { + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let schema = apache_avro::Schema::parse_str(raw_schema).unwrap(); + let schema = RawSchemaWithType::try_from(schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_json() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#; + let schema = serde_json::from_str::(raw_schema).unwrap(); + let schema = RawSchemaWithType::try_from(schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::JSON); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_protobuf() { + let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#; + let schema = RawSchemaWithType::try_from(raw_schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::PROTOBUF); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_tuple_json() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::JSON)).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::JSON); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_tuple_avro() { + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::AVRO)).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_tuple_protobuf() { + let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::PROTOBUF)).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::PROTOBUF); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_tuple_invalid() { + let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::JSON)); + assert!(schema.is_err()); + + let raw_schema = r#"name":"User","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::AVRO)); + assert!(schema.is_err()); + + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::PROTOBUF)); + assert!(schema.is_err()); + + let raw_schema = r#"not a schema"#; + let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::PROTOBUF)); + assert!(schema.is_err()); + } + + #[test] + fn test_try_from_string() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap(); + 
assert_eq!(schema.schema_type(), SchemaType::JSON); + assert_eq!(schema.schema(), raw_schema); + + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + + let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::PROTOBUF); + assert_eq!(schema.schema(), raw_schema); + + let raw_schema = r#"not a schema"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()); + assert!(schema.is_err()); + } + + #[test] + fn test_try_from_string_invalid() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()); + assert!(schema.is_err()); + + let raw_schema = r#"not a schema"#; + let schema = RawSchemaWithType::try_from(raw_schema.to_string()); + assert!(schema.is_err()); + } + + #[test] + fn test_from_subject() { + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let subject = crate::schema_store::types::Subject { + version: 1, + id: 1, + subject: "test".to_string(), + schema_type: SchemaType::AVRO, + schema: raw_schema.to_string(), + }; + let schema = RawSchemaWithType::from(subject); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_json_value() { + let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#; + let schema = serde_json::from_str::(raw_schema).unwrap(); + let schema = RawSchemaWithType::try_from(schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::JSON); + assert_eq!(schema.schema(), raw_schema); + } + + #[test] + fn test_try_from_json_value_avro() { + let raw_schema = + r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#; + let schema = AvroSchema::parse_str(raw_schema).unwrap(); + let schema = RawSchemaWithType::try_from(schema).unwrap(); + assert_eq!(schema.schema_type(), SchemaType::AVRO); + assert_eq!(schema.schema(), raw_schema); + } +} diff --git a/dsh_sdk/src/schema_store/types/schema/schema.rs b/dsh_sdk/src/schema_store/types/schema/schema.rs new file mode 100644 index 0000000..efa5a44 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/schema/schema.rs @@ -0,0 +1,44 @@ +use std::io::Cursor; + +use apache_avro::Schema as AvroSchema; +use protofish::context::Context as ProtoSchema; +use serde_json::Value as JsonValue; + +use crate::schema_store::error::SchemaStoreError; + +/// Schema object +/// +/// Common object to apply schema operations on +enum SchemaObject { + Avro(AvroSchema), + Json(JsonValue), + Proto(ProtoSchema), +} + +impl SchemaObject { + /// Deserialize bytes into struct + /// + /// Bytes should only contain the encoded data and not the magic bytes + pub fn decode(&self, bytes: &[u8]) -> Result + where T: serde::de::DeserializeOwned { + match self { + Self::Avro(schema) => { + apache_avro::from_value(&to_avro_value(schema, bytes)?).map_err(|_| SchemaStoreError::FailedParseToStruct) + } + Self::Json(schema) => serde_json::from_slice(bytes).map_err(SchemaStoreError::from), + Self::Proto(_) => Err(SchemaStoreError::NotImplementedProtobufDeserialize), + } + } +} + + + + +fn to_avro_value(schema: &AvroSchema, bytes: 
&[u8]) -> Result { + let mut buf = Cursor::new(bytes); + let value= apache_avro::from_avro_datum(schema, & mut buf, None).map_err(|e| { + log::warn!("Failed to decode value for Avro schema {}: {}", schema.name().map(|n|n.name.clone()).unwrap_or_default() , e); + SchemaStoreError::FailedToDecode(e.to_string()) + })?; + Ok(value) +} \ No newline at end of file diff --git a/dsh_sdk/src/schema_store/types/schema/type.rs b/dsh_sdk/src/schema_store/types/schema/type.rs new file mode 100644 index 0000000..caab083 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/schema/type.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; + +/// Schema type +/// +/// Available schema types +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +pub enum SchemaType { + JSON, + PROTOBUF, + AVRO, +} + +impl Default for SchemaType { + fn default() -> Self { + Self::AVRO + } +} + +impl From for SchemaType +where + S: AsRef, +{ + fn from(value: S) -> Self { + match value.as_ref() { + "JSON" | "json" => Self::JSON, + "PROTOBUF" | "protobuf" => Self::PROTOBUF, + "AVRO" | "avro" => Self::AVRO, + _ => Self::AVRO, + } + } +} + +impl std::fmt::Display for SchemaType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::JSON => write!(f, "JSON"), + Self::PROTOBUF => write!(f, "PROTOBUF"), + Self::AVRO => write!(f, "AVRO"), + } + } +} diff --git a/dsh_sdk/src/schema_store/types/subject_strategy.rs b/dsh_sdk/src/schema_store/types/subject_strategy.rs new file mode 100644 index 0000000..db756c8 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/subject_strategy.rs @@ -0,0 +1,319 @@ +use std::hash::{Hash, Hasher}; + +/// Subject name strategy +/// +/// Defines the strategy to use for the subject name +/// +/// ## Variants +/// Currently only the `TopicNameStrategy` is supported +/// +/// - `TopicNameStrategy`: Use the topic name as the subject name and suffix of '-key' or '-value' for the key and value schemas +/// +/// Example: +///``` +/// # use dsh_sdk::schema_store::types::SubjectName; +/// SubjectName::TopicNameStrategy{topic: "scratch.example.tenant".to_string(), key: false}; // "scratch.example.tenant-value" +/// ``` +#[derive(Debug, Clone, Eq)] +pub enum SubjectName { + /// Use the topic name as the subject name and suffix of '-key' or '-value' for the key and value schemas + /// + /// Example: + ///``` + /// # use dsh_sdk::schema_store::types::SubjectName; + /// SubjectName::TopicNameStrategy{topic: "scratch.example.tenant".to_string(), key: false}; // "scratch.example.tenant-value" + /// ``` + TopicNameStrategy { topic: String, key: bool }, +} + +impl SubjectName { + pub fn new(topic: S, key: bool) -> Self + where + S: AsRef, + { + Self::TopicNameStrategy { + topic: topic.as_ref().to_string(), + key, + } + } + pub fn name(&self) -> String { + match self { + Self::TopicNameStrategy { topic, key } => { + if *key { + format!("{}-key", topic) + } else { + format!("{}-value", topic) + } + } + } + } + + pub fn topic(&self) -> &str { + match self { + Self::TopicNameStrategy { topic, .. } => topic, + } + } + + pub fn key(&self) -> bool { + match self { + Self::TopicNameStrategy { key, .. 
} => *key, + } + } +} + +impl From<&str> for SubjectName { + fn from(value: &str) -> Self { + let (topic, key) = if value.ends_with("-key") { + (value.trim_end_matches("-key"), true) + } else if value.ends_with("-value") { + (value.trim_end_matches("-value"), false) + } else { + (value, false) + }; + Self::TopicNameStrategy { + topic: topic.to_string(), + key, + } + } +} + +impl From for SubjectName { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<(&str, bool)> for SubjectName { + fn from(value: (&str, bool)) -> Self { + Self::TopicNameStrategy { + topic: value.0.to_string(), + key: value.1, + } + } +} + +impl From<(String, bool)> for SubjectName { + fn from(value: (String, bool)) -> Self { + { + Self::TopicNameStrategy { + topic: value.0, + key: value.1, + } + } + } +} + +impl std::fmt::Display for SubjectName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::TopicNameStrategy { topic, key } => { + write!(f, "{}-{}", topic, if *key { "key" } else { "value" }) + } + } + } +} + +impl PartialEq for SubjectName { + fn eq(&self, other: &SubjectName) -> bool { + self.to_string() == other.to_string() // TODO: not the fastest way to compare, but it works for now + } +} + +impl Hash for SubjectName { + fn hash(&self, state: &mut H) { + self.to_string().hash(state); + } +} + +#[cfg(test)] +mod tests { + use openssl::hash; + + use super::*; + use std::hash::DefaultHasher; + + #[test] + fn test_subject_name_funcitons() { + let subject = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + assert_eq!(subject.topic(), "scratch.example.tenant"); + assert_eq!(subject.key(), false); + assert_eq!(subject.name(), "scratch.example.tenant-value"); + + let subject = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + assert_eq!(subject.topic(), "scratch.example.tenant"); + assert_eq!(subject.key(), true); + assert_eq!(subject.name(), "scratch.example.tenant-key"); + } + + #[test] + fn test_subject_name_new() { + let subject = SubjectName::new("scratch.example.tenant", false); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false + } + ); + } + + #[test] + fn test_subject_name_from_string() { + let subject: SubjectName = "scratch.example.tenant-value".into(); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false + } + ); + + let subject: SubjectName = "scratch.example.tenant-key".into(); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true + } + ); + } + + #[test] + fn test_subject_name_from_tuple() { + let subject: SubjectName = ("scratch.example.tenant".to_string(), false).into(); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false + } + ); + + let subject: SubjectName = ("scratch.example.tenant".to_string(), true).into(); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true + } + ); + } + + #[test] + fn test_subject_name_from_string_ref() { + let string = "scratch.example.tenant-value".to_string(); + let subject: SubjectName = string.into(); + assert_eq!( + subject, + SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false + } + ); + } + + #[test] + fn 
test_subject_name_display() { + let subject = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + assert_eq!(subject.to_string(), "scratch.example.tenant-value"); + + let subject = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + assert_eq!(subject.to_string(), "scratch.example.tenant-key"); + } + + #[test] + fn test_subject_name_eq() { + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + assert_eq!(subject1, subject2); + + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + assert_eq!(subject1, subject2); + + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + assert_ne!(subject1, subject2); + } + + #[test] + fn test_subject_name_hash() { + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + let mut hasher = DefaultHasher::new(); + subject1.hash(&mut hasher); + let hash1 = hasher.finish(); + let mut hasher = DefaultHasher::new(); + subject2.hash(&mut hasher); + let hash2 = hasher.finish(); + assert_eq!(hash1, hash2); + + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + let mut hasher = DefaultHasher::new(); + subject1.hash(&mut hasher); + let hash1 = hasher.finish(); + let mut hasher = DefaultHasher::new(); + subject2.hash(&mut hasher); + let hash2 = hasher.finish(); + assert_eq!(hash1, hash2); + + let subject1 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: false, + }; + let subject2 = SubjectName::TopicNameStrategy { + topic: "scratch.example.tenant".to_string(), + key: true, + }; + let mut hasher = DefaultHasher::new(); + subject1.hash(&mut hasher); + let hash1 = hasher.finish(); + let mut hasher = DefaultHasher::new(); + subject2.hash(&mut hasher); + let hash2 = hasher.finish(); + assert_ne!(hash1, hash2); + } +} diff --git a/dsh_sdk/src/schema_store/types/subjects.rs b/dsh_sdk/src/schema_store/types/subjects.rs new file mode 100644 index 0000000..48af082 --- /dev/null +++ b/dsh_sdk/src/schema_store/types/subjects.rs @@ -0,0 +1,151 @@ +use serde::{Deserialize, Serialize}; + +use super::SchemaType; + +/// Subject version +/// +/// Select a specific `version` of the subject or the `latest` version +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum SubjectVersion { + Latest, + Version(i32), +} + +/// Subject +/// +/// All related info related subject +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Subject { + pub subject: String, + pub id: i32, + pub version: i32, + #[serde(default)] + pub schema_type: SchemaType, + pub schema: String, +} + +/// Subject version +/// +/// Subjects related version 
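+/// (a `subject` name plus the `version` under which a schema is registered),
+/// as returned when listing all subjects that use a given schema ID.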
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubjectVersionInfo { + pub subject: String, + pub version: i32, +} + +impl Default for SubjectVersion { + fn default() -> Self { + Self::Latest + } +} + +impl From for SubjectVersion { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for SubjectVersion { + fn from(value: &str) -> Self { + match value { + "latest" => Self::Latest, + version => match version.parse::() { + Ok(version) => Self::Version(version), + Err(_) => Self::Latest, + }, + } + } +} + +impl From for SubjectVersion { + fn from(value: i32) -> Self { + Self::Version(value) + } +} + +impl std::fmt::Display for SubjectVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Latest => write!(f, "latest"), + Self::Version(version) => write!(f, "{}", version), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_subject_version_from_string() { + let latest: SubjectVersion = "latest".into(); + assert_eq!(latest, SubjectVersion::Latest); + + let version: SubjectVersion = "1".into(); + assert_eq!(version, SubjectVersion::Version(1)); + + let version: SubjectVersion = "2".into(); + assert_eq!(version, SubjectVersion::Version(2)); + } + + #[test] + fn test_subject_version_from_i32() { + let version: SubjectVersion = 1.into(); + assert_eq!(version, SubjectVersion::Version(1)); + + let version: SubjectVersion = 2.into(); + assert_eq!(version, SubjectVersion::Version(2)); + } + + #[test] + fn test_subject_version_display() { + let latest = SubjectVersion::Latest; + assert_eq!(latest.to_string(), "latest"); + + let version = SubjectVersion::Version(1); + assert_eq!(version.to_string(), "1"); + } + + #[test] + fn test_subject_version_default() { + let default = SubjectVersion::default(); + assert_eq!(default, SubjectVersion::Latest); + } + + #[test] + fn test_subject_version_from_string_default() { + let default: SubjectVersion = "invalid".into(); + assert_eq!(default, SubjectVersion::Latest); + } + + #[test] + fn test_subject_serde() { + let subject_proto = + r#"{"subject":"test","id":1,"version":1,"schemaType":"PROTOBUF","schema":"schema"}"#; + let subject: Subject = serde_json::from_str(subject_proto).unwrap(); + assert_eq!(subject.subject, "test"); + assert_eq!(subject.id, 1); + assert_eq!(subject.version, 1); + assert_eq!(subject.schema_type, SchemaType::PROTOBUF); + assert_eq!(subject.schema, "schema"); + + let subject_json = + r#"{"subject":"test","id":1,"version":1,"schemaType":"JSON","schema":"schema"}"#; + let subject: Subject = serde_json::from_str(subject_json).unwrap(); + assert_eq!(subject.subject, "test"); + assert_eq!(subject.id, 1); + assert_eq!(subject.version, 1); + assert_eq!(subject.schema_type, SchemaType::JSON); + assert_eq!(subject.schema, "schema"); + + let subject_avro = r#"{"subject":"test","id":1,"version":1,"schema":"schema"}"#; + let subject: Subject = serde_json::from_str(subject_avro).unwrap(); + assert_eq!(subject.subject, "test"); + assert_eq!(subject.id, 1); + assert_eq!(subject.version, 1); + assert_eq!(subject.schema_type, SchemaType::AVRO); + assert_eq!(subject.schema, "schema"); + } +} diff --git a/dsh_sdk/src/utils/http_client.rs b/dsh_sdk/src/utils/http_client.rs new file mode 100644 index 0000000..5e8ab33 --- /dev/null +++ b/dsh_sdk/src/utils/http_client.rs @@ -0,0 +1,110 @@ +//! Lightweight HTTP client. +//! +//! 
This module provides a simple HTTP client to send requests to the server and other internal endpoints. +//! +//! # NOTE: +//! This client is not meant to be exposed as a public API, it is only meant to be used internally. + +use hyper::Request; +use hyper_util::client::legacy::{connect::HttpConnector, Client}; + +use http::Uri; +use http_body_util::{BodyExt, Empty}; +use hyper::body::Bytes; +use hyper_rustls::ConfigBuilderExt; +use hyper_util::rt::TokioExecutor; +use rustls::RootCertStore; +use std::str::FromStr; + +async fn http_client() -> Result<(), Box> { + let url = "http://httpbin.org/ip"; + // HTTPS requires picking a TLS implementation, so give a better + // warning if the user tries to request an 'https' URL. + let url = url.parse::()?; + if url.scheme_str() != Some("http") { + eprintln!("This example only works with 'http' URLs."); + return Ok(()); + } + + let client = Client::builder(TokioExecutor::new()).build(HttpConnector::new()); + + let req = Request::builder() + .uri(url) + .body(Empty::::new())?; + + let resp = client.request(req).await?; + + eprintln!("{:?} {:?}", resp.version(), resp.status()); + eprintln!("{:#?}", resp.headers()); + + Ok(()) +} + +async fn https_client() -> io::Result<()> { + // Set a process wide default crypto provider. + let _ = rustls::crypto::ring::default_provider().install_default(); + + let url = "http://httpbin.org/ip"; + + // Second parameter is custom Root-CA store (optional, defaults to native cert store). + let mut ca = match env::args().nth(2) { + Some(ref path) => { + let f = fs::File::open(path) + .map_err(|e| error(format!("failed to open {}: {}", path, e)))?; + let rd = io::BufReader::new(f); + Some(rd) + } + None => None, + }; + + // Prepare the TLS client config + let tls = match ca { + Some(ref mut rd) => { + // Read trust roots + let certs = rustls_pemfile::certs(rd).collect::, _>>()?; + let mut roots = RootCertStore::empty(); + roots.add_parsable_certificates(certs); + // TLS client config using the custom CA store for lookups + rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth() + } + // Default TLS client config with native roots + None => rustls::ClientConfig::builder() + .with_native_roots()? + .with_no_client_auth(), + }; + // Prepare the HTTPS connector + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(tls) + .https_or_http() + .enable_http1() + .build(); + + // Build the hyper client from the HTTPS connector. + let client: Client<_, Empty> = Client::builder(TokioExecutor::new()).build(https); + + // Prepare a chain of futures which sends a GET request, inspects + // the returned headers, collects the whole body and prints it to + // stdout. + let fut = async move { + let res = client + .get(url) + .await + .map_err(|e| error(format!("Could not get: {:?}", e)))?; + println!("Status:\n{}", res.status()); + println!("Headers:\n{:#?}", res.headers()); + + let body = res + .into_body() + .collect() + .await + .map_err(|e| error(format!("Could not get body: {:?}", e)))? 
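+            // `collect()` buffered the entire body; expose it as contiguous bytes.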
+ .to_bytes(); + println!("Body:\n{}", String::from_utf8_lossy(&body)); + + Ok(()) + }; + + fut.await +} diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs index 6209fb8..365c4c5 100644 --- a/dsh_sdk/src/utils/mod.rs +++ b/dsh_sdk/src/utils/mod.rs @@ -12,6 +12,8 @@ use crate::error::DshError; pub mod dlq; #[cfg(feature = "graceful-shutdown")] pub mod graceful_shutdown; +#[cfg(feature = "hyper-client")] +pub(crate) mod http_client; #[cfg(feature = "metrics")] pub mod metrics; diff --git a/example_dsh_service/Cargo.toml b/example_dsh_service/Cargo.toml index fb68079..09c1af7 100644 --- a/example_dsh_service/Cargo.toml +++ b/example_dsh_service/Cargo.toml @@ -5,7 +5,7 @@ description = "An example of DSH service using the dsh-sdk crate" edition = "2021" [dependencies] -dsh_sdk = { path = "../dsh_sdk", version = "0.4", features = ["rdkafka-ssl-vendored"] } +dsh_sdk = { path = "../dsh_sdk", version = "0.5", features = ["rdkafka-ssl-vendored"] } log = "0.4" env_logger = "0.11" tokio = { version = "^1.35", features = ["full"] } \ No newline at end of file From 8cc424f4edeff89134049c8faaa75c0f0b34ce1e Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Fri, 3 Jan 2025 14:46:25 +0100 Subject: [PATCH 03/23] Dsh kafka config (#97) only on PR * Add DshKafkaConfig trait and update DLQ * Update readme and minor bugs --- .github/workflows/branch.yaml | 2 +- dsh_sdk/CHANGELOG.md | 3 + dsh_sdk/Cargo.toml | 25 +-- dsh_sdk/README.md | 43 +++-- dsh_sdk/examples/dlq_implementation.rs | 99 +++++----- dsh_sdk/examples/produce_consume.rs | 30 ++- dsh_sdk/src/dsh.rs | 160 ++++++---------- dsh_sdk/src/dsh_old/mod.rs | 3 +- dsh_sdk/src/dsh_old/properties.rs | 141 +++----------- dsh_sdk/src/error.rs | 2 +- dsh_sdk/src/lib.rs | 34 ++-- dsh_sdk/src/management_api/token_fetcher.rs | 6 +- .../kafka_protocol/config.rs | 175 +++++++++++++++--- .../protocol_adapters/kafka_protocol/mod.rs | 38 +++- .../kafka_protocol/rdkafka.rs | 88 +++++++++ dsh_sdk/src/protocol_adapters/mod.rs | 1 + .../schema_store/types/subject_strategy.rs | 2 - dsh_sdk/src/utils/dlq.rs | 157 +++++++++------- dsh_sdk/src/utils/mod.rs | 2 +- example_dsh_service/Cargo.toml | 3 +- example_dsh_service/src/custom_metrics.rs | 2 +- example_dsh_service/src/main.rs | 19 +- 22 files changed, 607 insertions(+), 428 deletions(-) create mode 100644 dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs diff --git a/.github/workflows/branch.yaml b/.github/workflows/branch.yaml index dd0c600..9ca519e 100644 --- a/.github/workflows/branch.yaml +++ b/.github/workflows/branch.yaml @@ -57,5 +57,5 @@ jobs: run: cargo install cargo-hack --locked if: matrix.version == 'stable' - name: cargo check all features - run: cargo hack check --feature-powerset --depth 3 --no-dev-deps + run: cargo hack check --feature-powerset --no-dev-deps if: matrix.version == 'stable' diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index 4548363..8b4e336 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **Breaking change:** `dsh_sdk::Dsh::reqwest_client_config` now returns `reqwest::ClientConfig` instead of `Result` - **Breaking change:** `dsh_sdk::Dsh::reqwest_blocking_client_config` now returns `reqwest::ClientConfig` instead of `Result` +- **Breaking change:** `dsh_sdk::utils::Dlq` does not require `Dsh`/`Properties` as argument anymore ### Moved - Moved `dsh_sdk::dsh::properties` to `dsh_sdk::propeties` @@ -30,6 
+31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics` ### Removed +- Removed `dsh_sdk::rdkafka` public re-export, import `rdkafka` directly + - **NOTE** Feature-flag `rdkafka-ssl` and `rdkafka-ssl-vendored` are removed! - Removed `Default` trait for `Dsh` (original `Properties`) struct as this should be public ### Fixed diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 7dd35be..4c85dca 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true name = "dsh_sdk" readme = 'README.md' repository.workspace = true -version = "0.5.0" +version = "0.5.0-rc.1" [package.metadata.docs.rs] all-features = true @@ -32,7 +32,7 @@ prometheus = { version = "0.13", features = ["process"], optional = true } protofish = { version = "0.5.2", optional = true } rcgen = { version = "0.13", optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "blocking"], optional = true } -rdkafka = { version = "0.36", features = ["cmake-build"], optional = true } +rdkafka = { version = ">=0.36", default-features = false, optional = true } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"], optional = true } sha2 = { version = "0.10", optional = true} @@ -41,32 +41,33 @@ tokio = { version = "^1.35", features = ["signal", "sync", "time", "macros"], op tokio-util = { version = "0.7", default-features = false, optional = true } [features] -default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl", "schema-store"] +# default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl", "schema-store"] +default = ["bootstrap", "kafka"] bootstrap = ["certificate", "serde_json", "tokio/rt-multi-thread"] +kafka = ["bootstrap"] +rdkafka-config = ["rdkafka", "kafka"] # Impl of config trait only + certificate = ["rcgen", "reqwest", "pem"] schema-store = ["bootstrap", "reqwest", "serde_json", "apache-avro", "protofish"] metrics = ["prometheus", "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] -dlq = ["tokio", "bootstrap", "rdkafka-ssl", "graceful-shutdown"] +dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/ssl-vendored", "rdkafka/libz", "rdkafka/tokio", "graceful-shutdown"] graceful-shutdown = ["tokio", "tokio-util"] -management-api = ["reqwest"] +management-api-token-fetcher = ["reqwest"] protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] + # http-protocol-adapter = ["protocol-token-fetcher"] # mqtt-protocol-adapter = ["protocol-token-fetcher"] # hyper-client = ["hyper", "hyper-util", "hyper-rustls", "rustls", "http", "rustls-pemfile"] -# TODO: Remove the following features at v0.6.0 -rdkafka-ssl-vendored = ["rdkafka", "rdkafka/ssl-vendored", "rdkafka/cmake-build"] -rdkafka-ssl = ["rdkafka", "rdkafka/ssl", "rdkafka/cmake-build"] -#mqtt-token-fetcher = ["protocol-token-fetcher"] -#rest-token-fetcher = ["management-api"] [dev-dependencies] mockito = "1.1.1" openssl = "0.10" tokio = { version = "^1.35", features = ["full"] } -hyper = { version = "1.3", features = ["full"]} +hyper = { version = "1.3", features = ["full"] } serial_test = "3.1.0" dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.2.0"} dsh_sdk = { features = ["dlq"], path = "." 
}
-env_logger = "0.11"
\ No newline at end of file
+env_logger = "0.11"
+rdkafka = { version = ">=0.36", features = ["cmake-build", "ssl-vendored"], default-features = true }
\ No newline at end of file
diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md
index 4d7639a..7f8edac 100644
--- a/dsh_sdk/README.md
+++ b/dsh_sdk/README.md
@@ -5,6 +5,14 @@
 [![dependency status](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs/status.svg)](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 
+# NOTE
+As this is a release candidate, it may contain bugs, incomplete features or incorrect documentation, and future updates may contain breaking changes.
+
+Please report any issues you encounter.
+
+## Migration guide 0.4.X -> 0.5.X
+See the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for more information on how to migrate from 0.4.X to 0.5.X.
+
 ## Description
 This library can be used to interact with the DSH Platform. It is intended to be used as a base for services that will be used to interact with DSH. Features include:
 - Connect to DSH
@@ -27,27 +35,28 @@ To use this SDK with the default features in your project, add the following to
 
 ```toml
 [dependencies]
-dsh_sdk = "0.4"
+dsh_sdk = "0.5"
 ```
 
 However, if you would like to use only specific features, you can specify them in your Cargo.toml file. For example, if you would like to use the Kafka integration via `rdkafka`, add the following to your Cargo.toml file:
 
 ```toml
 [dependencies]
-dsh_sdk = { version = "0.4", default-features = false, features = ["bootstrap"] }
+dsh_sdk = { version = "0.5", default-features = false, features = ["rdkafka"] }
+rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
 ```
 
 See [feature flags](#feature-flags) for more information on the available features.
 
 To use this SDK in your project:
 ```rust
-use dsh_sdk::Properties;
-use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
+use dsh_sdk::DshKafkaConfig;
+use rdkafka::consumer::{Consumer, StreamConsumer};
+use rdkafka::ClientConfig;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let dsh_properties = Properties::get();
     // get a rdkafka consumer config for example
-    let consumer: StreamConsumer = dsh_properties.consumer_rdkafka_config().create()?;
+    let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+    Ok(())
 }
 ```
 
@@ -56,18 +65,22 @@ The SDK is compatible with running in a container on a DSH tenant, on DSH System
 See [CONNECT_PROXY_VPN_LOCAL](CONNECT_PROXY_VPN_LOCAL.md) for more info.
 
 ## Feature flags
-The following features are available in this library and can be enabled/disabled in your Cargo.toml file.:
+See the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for more information on the changes in feature flags since the v0.5.X update.
+
+The following features are available in this library and can be enabled/disabled in your Cargo.toml file:
 
 | **feature** | **default** | **Description** |
 |---|---|---|
-| `bootstrap` | ✓ | Generate signed certificate and fetch datastreams info
Also makes certificates available, to be used as lowlevel API | -| `rest-token-fetcher` | ✗ | Fetch tokens to use DSH Rest API | -| `mqtt-token-fetcher` | ✗ | Fetch tokens to use DSH MQTT | -| `metrics` | ✓ | Enable (custom) metrics for your service | -| `graceful-shutdown` | ✓ | Create a signal handler for implementing a graceful shutdown | -| `dlq` | ✗ | Dead Letter Queue implementation (experimental) | -| `rdkafka-ssl` | ✓ | Dynamically link to librdkafka to a locally installed OpenSSL | -| `rdkafka-ssl-vendored` | ✗ | Build OpenSSL during compile and statically link librdkafka
(No initial install required in environment, slower compile time) |
+| `bootstrap` | ✓ | Generate signed certificate and fetch datastreams info |
+| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH |
+| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka |
+| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) |
+| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API |
+| `metrics` | ✗ | Enable Prometheus metrics including an HTTP server |
+| `graceful-shutdown` | ✗ | Tokio-based graceful shutdown handler |
+| `dlq` | ✗ | Dead Letter Queue implementation |
+| `rest-token-fetcher` | ✗ | Replaced by `management-api-token-fetcher` |
+| `mqtt-token-fetcher` | ✗ | Replaced by `protocol-token-fetcher` |
 
 See the API documentation for more information on how to use these features.
diff --git a/dsh_sdk/examples/dlq_implementation.rs b/dsh_sdk/examples/dlq_implementation.rs
index 69b157e..c3ffaa0 100644
--- a/dsh_sdk/examples/dlq_implementation.rs
+++ b/dsh_sdk/examples/dlq_implementation.rs
@@ -1,14 +1,10 @@
-// make sure to use the dlq feature in your Cargo.toml
-// dsh_sdk = { version = "0.4", features = ["dlq"] }
-//
-// To run this example, run the following command:
-// cargo run --features dlq --example dlq_implementation
+use dsh_sdk::utils::dlq::{self, DlqChannel, ErrorToDlq};
+use dsh_sdk::utils::graceful_shutdown::Shutdown;
+use dsh_sdk::DshKafkaConfig;
 
-use dsh_sdk::dlq::{self, ErrorToDlq};
-use dsh_sdk::graceful_shutdown::Shutdown;
-use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
-use dsh_sdk::rdkafka::Message;
-use dsh_sdk::Properties;
+use rdkafka::consumer::{Consumer, StreamConsumer};
+use rdkafka::message::{BorrowedMessage, Message, OwnedMessage};
+use rdkafka::ClientConfig;
 use std::backtrace::Backtrace;
 use thiserror::Error;
 use tokio::sync::mpsc;
@@ -20,9 +16,9 @@ enum ConsumerError {
     DeserializeError(#[from] std::string::FromUtf8Error),
 }
 
-// implement the ErrorToDlq trait for your custom error type (or exusting error types)
-impl dlq::ErrorToDlq for ConsumerError {
-    fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq {
+// implement the `ErrorToDlq` trait for your custom error type (or existing error types)
+impl ErrorToDlq for ConsumerError {
+    fn to_dlq(&self, kafka_message: OwnedMessage) -> dlq::SendToDlq {
         let backtrace = Backtrace::force_capture(); // this is optional as it is heavy on performance
         dlq::SendToDlq::new(
             kafka_message,
@@ -31,7 +27,7 @@ impl dlq::ErrorToDlq for ConsumerError {
             Some(backtrace.to_string()),
         )
     }
-    // Definition if error is retryable or not
+    // Define if error is retryable or not
    fn retryable(&self) -> dlq::Retryable {
        match self {
            ConsumerError::DeserializeError(_) => dlq::Retryable::NonRetryable,
@@ -40,62 +36,81 @@ impl dlq::ErrorToDlq for ConsumerError {
 }
 
 // simple deserialization function that returns a Result of String or the defined ConsumerError
-fn deserialize(msg: &dsh_sdk::rdkafka::message::OwnedMessage) -> Result<String, ConsumerError> {
+fn deserialize(msg: &BorrowedMessage) -> Result<String, ConsumerError> {
     match msg.payload() {
         Some(payload) => Ok(String::from_utf8(payload.to_vec())?),
         None => Ok("".to_string()),
     }
 }
 
-// simple consumer function
-async fn consume(consumer: StreamConsumer, dlq_tx: &mut mpsc::Sender<SendToDlq>) {
+// simple consumer function with shutdown handling
+async fn consume(
+    consumer: StreamConsumer,
+    topic: &str,
+    mut dlq_channel: DlqChannel,
+    shutdown: Shutdown,
+) {
     consumer
.subscribe(&["sub_to_your_topic"]) + .subscribe(&[topic]) .expect("Can't subscribe to topic"); - while let Ok(msg) = consumer.recv().await { - let owned_msg = msg.detach(); - match deserialize(&owned_msg) { - // send message to dlq if error occurs - Err(e) => e.to_dlq(owned_msg).send(dlq_tx).await, - // process message, in this case print payload - Ok(payload) => { - println!("Payload: {}", payload) + + loop { + tokio::select! { + msg = consumer.recv() => match msg { + Ok(msg) => { + match deserialize(&msg) { + // send message to dlq if error occurs + Err(e) => e.to_dlq(msg.detach()).send(&mut dlq_channel).await, + // process message, in this case print payload + Ok(payload) => { + println!("Payload: {}", payload) + } + } + } + Err(e) => { + eprintln!("Error while receiving message: {}", e); } + }, + _ = shutdown.signal_listener() => { + println!("Shutting down consumer"); + break; + } } } } #[tokio::main] async fn main() -> Result<(), Box> { - // set the dlq topics + // set the dlq topics (required) std::env::set_var("DLQ_DEAD_TOPIC", "scratch.dlq.local-tenant"); std::env::set_var("DLQ_RETRY_TOPIC", "scratch.dlq.local-tenant"); - let dsh = Properties::get(); + + // Topic to subscribe to (change to your topic) + let topic = "your_topic_name"; + let shutdown = Shutdown::new(); - let consumer: StreamConsumer = dsh.consumer_rdkafka_config().create()?; + let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?; - let mut dlq = dlq::Dlq::new(dsh, shutdown.clone())?; - // get the dlq channel sender to send messages to the dlq - // for example in your consumer - let mut dlq_tx = dlq.dlq_records_tx(); + // Start the `dlq` service, returns a sender to send messages to the dlq + let dlq_channel = dlq::Dlq::start(shutdown.clone())?; + + // run the `consumer` in a separate tokio task + let shutdown_clone = shutdown.clone(); let consumer_handle = tokio::spawn(async move { - consume(consumer, &mut dlq_tx).await; - }); - // run the dlq in a separate tokio task - let dlq_handle = tokio::spawn(async move { - dlq.run().await; + consume(consumer, topic, dlq_channel, shutdown_clone).await; }); + + // wait for `consumer` or `dlq` to finish or for shutdown signal tokio::select! 
        _ = consumer_handle => {
            println!("Consumer finished");
        }
-        _ = dlq_handle => {
-            println!("DLQ finished");
-        }
        _ = shutdown.signal_listener() => {
            println!("Shutting down");
        }
    }
-    shutdown.complete().await; // wait for graceful shutdown to complete
+
+    // wait for graceful shutdown to complete
+    shutdown.complete().await;
    Ok(())
 }
diff --git a/dsh_sdk/examples/produce_consume.rs b/dsh_sdk/examples/produce_consume.rs
index bbe3c12..ef2788f 100644
--- a/dsh_sdk/examples/produce_consume.rs
+++ b/dsh_sdk/examples/produce_consume.rs
@@ -1,12 +1,13 @@
-use dsh_sdk::rdkafka::consumer::CommitMode;
-use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
-use dsh_sdk::rdkafka::producer::{FutureProducer, FutureRecord};
-use dsh_sdk::rdkafka::Message;
-use dsh_sdk::Properties;
+use dsh_sdk::DshKafkaConfig;
+use rdkafka::consumer::CommitMode;
+use rdkafka::consumer::{Consumer, StreamConsumer};
+use rdkafka::producer::{FutureProducer, FutureRecord};
+use rdkafka::ClientConfig;
+use rdkafka::Message;
 
 const TOTAL_MESSAGES: usize = 10;
 
-async fn produce(producer: &mut FutureProducer, topic: &str) {
+async fn produce(producer: FutureProducer, topic: &str) {
    for key in 0..TOTAL_MESSAGES {
        let payload = format!("hello world {}", key);
        let msg = producer
@@ -24,7 +25,7 @@ async fn produce(producer: &mut FutureProducer, topic: &str) {
    }
 }
 
-async fn consume(consumer: &mut StreamConsumer, topic: &str) {
+async fn consume(consumer: StreamConsumer, topic: &str) {
    consumer.subscribe(&[topic]).unwrap();
    let mut i = 0;
    while i < TOTAL_MESSAGES {
@@ -39,21 +40,18 @@ async fn consume(consumer: &mut StreamConsumer, topic: &str) {
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    // Create a new DSH Properties instance (requires local_datastreams.json in root of project, as it runs in local mode)
-    let dsh_properties = Properties::get();
-
    // Define your topic
    let topic = "test";
 
-    // Create a new producer based on the properties default config
-    let mut producer: FutureProducer = dsh_properties.producer_rdkafka_config().create()?;
+    // Create a new producer from the rdkafka `ClientConfig` together with `dsh_producer_config` from the `DshKafkaConfig` trait
+    let producer: FutureProducer = ClientConfig::new().dsh_producer_config().create()?;
 
    // Produce messages towards topic
-    produce(&mut producer, topic).await;
+    produce(producer, topic).await;
 
-    // Create a new consumer based on the properties default config
-    let mut consumer: StreamConsumer = dsh_properties.consumer_rdkafka_config().create()?;
+    // Create a new consumer from the rdkafka `ClientConfig` together with `dsh_consumer_config` from the `DshKafkaConfig` trait
+    let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
 
-    consume(&mut consumer, topic).await;
+    consume(consumer, topic).await;
 
    Ok(())
 }
diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs
index a948b85..8e7e244 100644
--- a/dsh_sdk/src/dsh.rs
+++ b/dsh_sdk/src/dsh.rs
@@ -27,16 +27,18 @@
 //! ```
 use log::{error, warn};
 use std::env;
-use std::sync::OnceLock;
+use std::sync::{Arc, OnceLock};
 
 use crate::certificates::Cert;
 use crate::datastream::Datastream;
 use crate::error::DshError;
-use crate::protocol_adapters::kafka_protocol::config;
 use crate::utils;
 use crate::*;
 
+#[cfg(feature = "kafka")]
+use crate::protocol_adapters::kafka_protocol::config::KafkaConfig;
+
-// TODP: Remove at v0.6.0
+// TODO: Remove at v0.6.0
 pub use crate::dsh_old::*;
 
 /// DSH properties struct. Create new to initialize all related components to connect to the DSH Kafka clusters
@@ -47,30 +49,16 @@ pub use crate::dsh_old::*;
 /// ## Environment variables
 /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for
 /// more information configuring the consumer or producer via environment variables.
-///
-/// # Example
-/// ```
-/// use dsh_sdk::Dsh;
-/// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
-///
-/// #[tokio::main]
-/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
-///     let dsh_properties = Dsh::get();
-///
-///     let consumer_config = dsh_properties.consumer_rdkafka_config();
-///     let consumer: StreamConsumer = consumer_config.create()?;
-///
-///     Ok(())
-/// }
-/// ```
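+///
+/// # Example
+/// A minimal sketch of reading the platform metadata held by this struct (method names as used elsewhere in this crate):
+/// ```no_run
+/// use dsh_sdk::Dsh;
+///
+/// let dsh = Dsh::get();
+/// println!("tenant: {}", dsh.tenant_name());
+/// let brokers = dsh.datastream().get_brokers_string();
+/// ```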
 #[derive(Debug, Clone)]
 pub struct Dsh {
    config_host: String,
    task_id: String,
    tenant_name: String,
-    datastream: Datastream,
+    datastream: Arc<Datastream>,
    certificates: Option<Cert>,
+    #[cfg(feature = "kafka")]
+    kafka_config: KafkaConfig,
 }
 
 impl Dsh {
    /// New `Dsh` struct
    pub(crate) fn new(
@@ -82,12 +70,15 @@ impl Dsh {
        datastream: Datastream,
        certificates: Option<Cert>,
    ) -> Self {
+        let datastream = Arc::new(datastream);
        Self {
            config_host,
            task_id,
            tenant_name,
-            datastream,
+            datastream: datastream.clone(),
            certificates,
+            #[cfg(feature = "kafka")]
+            kafka_config: KafkaConfig::new(Some(datastream)),
        }
    }
    /// Get the DSH `Dsh` struct in a lazy way. If not already initialized, it will initialize the properties
@@ -107,12 +98,11 @@
    /// # Example
    /// ```
    /// use dsh_sdk::Dsh;
-    /// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let dsh_properties = Dsh::get();
-    /// let consumer: StreamConsumer = dsh_properties.consumer_rdkafka_config().create()?;
+    /// let dsh = Dsh::get();
+    /// let datastreams = dsh.datastream();
    /// # Ok(())
    /// # }
    /// ```
@@ -259,7 +249,7 @@ impl Dsh {
    ///
    /// This datastream is fetched at initialization of the properties, and can not be updated during runtime.
    pub fn datastream(&self) -> &Datastream {
-        &self.datastream
+        self.datastream.as_ref()
    }
 
    /// High level method to fetch the kafka properties provided by DSH (datastreams.json)
@@ -307,6 +297,11 @@ impl Dsh {
        self.datastream().schema_store()
    }
 
+    #[cfg(feature = "kafka")]
+    #[deprecated(
+        since = "0.5.0",
+        note = "Moved to `Dsh::kafka_config().kafka_brokers()` and is part of the `kafka` feature"
+    )]
    /// Get the Kafka brokers.
    ///
    /// ## Environment variables
    /// You can set the following environment variable to overwrite the default value.
    ///
    /// ### `KAFKA_BOOTSTRAP_SERVERS`
    /// - Usage: Overwrite hostnames of brokers
    /// - Default: Brokers based on datastreams
    /// - Required: `false`
    pub fn kafka_brokers(&self) -> String {
        self.datastream().get_brokers_string()
    }
 
+    #[cfg(feature = "kafka")]
+    #[deprecated(
+        since = "0.5.0",
+        note = "Moved to `Dsh::kafka_config().group_id()` and is part of the `kafka` feature"
+    )]
    /// Get the kafka_group_id.
    ///
    /// ## Environment variables
@@ -351,6 +351,11 @@
        }
    }
 
+    #[cfg(feature = "kafka")]
+    #[deprecated(
+        since = "0.5.0",
+        note = "Moved to `Dsh::kafka_config().enable_auto_commit()` and is part of the `kafka` feature"
+    )]
    /// Get the configured kafka auto commit settings.
    ///
    /// ## Environment variables
    /// You can set the following environment variable to overwrite the default value.
    ///
    /// ### `KAFKA_ENABLE_AUTO_COMMIT`
    /// - Usage: Enable/Disable auto commit
    /// - Default: `false`
    /// - Required: `false`
    /// - Options: `true`, `false`
    pub fn kafka_auto_commit(&self) -> bool {
-        config::KafkaConfig::get().enable_auto_commit()
+        self.kafka_config.enable_auto_commit()
    }
 
+    #[cfg(feature = "kafka")]
+    #[deprecated(
+        since = "0.5.0",
+        note = "Moved to `Dsh::kafka_config().auto_offset_reset()` and is part of the `kafka` feature"
+    )]
    /// Get the kafka auto offset reset settings.
/// /// ## Environment variables @@ -376,86 +386,36 @@ impl Dsh { /// - Required: `false` /// - Options: smallest, earliest, beginning, largest, latest, end pub fn kafka_auto_offset_reset(&self) -> String { - config::KafkaConfig::get().auto_offset_reset() + self.kafka_config.auto_offset_reset().to_string() } - #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] + + #[cfg(feature = "kafka")] + /// Get the kafka config from initiated Dsh struct. + pub fn kafka_config(&self) -> &KafkaConfig { + &self.kafka_config + } + + #[deprecated( + since = "0.5.0", + note = "Use `Dsh::DshKafkaConfig` trait instead, see https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)" + )] + #[cfg(feature = "rdkafka-config")] pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { - let consumer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get(); + use crate::protocol_adapters::kafka_protocol::DshKafkaConfig; let mut config = rdkafka::config::ClientConfig::new(); - config - .set("bootstrap.servers", self.kafka_brokers()) - .set("group.id", self.kafka_group_id()) - .set("client.id", self.client_id()) - .set("enable.auto.commit", self.kafka_auto_commit().to_string()) - .set("auto.offset.reset", self.kafka_auto_offset_reset()); - if let Some(session_timeout) = consumer_config.session_timeout() { - config.set("session.timeout.ms", session_timeout.to_string()); - } - if let Some(queued_buffering_max_messages_kbytes) = - consumer_config.queued_buffering_max_messages_kbytes() - { - config.set( - "queued.max.messages.kbytes", - queued_buffering_max_messages_kbytes.to_string(), - ); - } - log::debug!("Consumer config: {:#?}", config); - // Set SSL if certificates are present - if let Ok(certificates) = &self.certificates() { - config - .set("security.protocol", "ssl") - .set("ssl.key.pem", certificates.private_key_pem()) - .set( - "ssl.certificate.pem", - certificates.dsh_kafka_certificate_pem(), - ) - .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); - } else { - config.set("security.protocol", "plaintext"); - } + config.dsh_consumer_config(); config } - #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] + #[deprecated( + since = "0.5.0", + note = "Use `Dsh::DshKafkaConfig` trait instead, see https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)" + )] + #[cfg(feature = "rdkafka-config")] pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { - let producer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get(); + use crate::protocol_adapters::kafka_protocol::DshKafkaConfig; let mut config = rdkafka::config::ClientConfig::new(); - config - .set("bootstrap.servers", self.kafka_brokers()) - .set("client.id", self.client_id()); - if let Some(batch_num_messages) = producer_config.batch_num_messages() { - config.set("batch.num.messages", batch_num_messages.to_string()); - } - if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() { - config.set( - "queue.buffering.max.messages", - queue_buffering_max_messages.to_string(), - ); - } - if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() { - config.set( - "queue.buffering.max.kbytes", - queue_buffering_max_kbytes.to_string(), - ); - } - if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() { - config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string()); - } - 
log::debug!("Producer config: {:#?}", config);
-
-        // Set SSL if certificates are present
-        if let Ok(certificates) = self.certificates() {
-            config
-                .set("security.protocol", "ssl")
-                .set("ssl.key.pem", certificates.private_key_pem())
-                .set(
-                    "ssl.certificate.pem",
-                    certificates.dsh_kafka_certificate_pem(),
-                )
-                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
-        } else {
-            config.set("security.protocol", "plaintext");
-        }
+        config.dsh_producer_config();
        config
    }
 }
@@ -469,13 +429,15 @@ mod tests {
 
    impl Default for Dsh {
        fn default() -> Self {
-            let datastream = Datastream::load_local_datastreams().unwrap_or_default();
+            let datastream = Arc::new(Datastream::load_local_datastreams().unwrap_or_default());
            Self {
                task_id: "local_task_id".to_string(),
                tenant_name: "local_tenant".to_string(),
                config_host: "http://localhost/".to_string(),
                datastream,
                certificates: None,
+                #[cfg(feature = "kafka")]
+                kafka_config: KafkaConfig::default(),
            }
        }
    }
diff --git a/dsh_sdk/src/dsh_old/mod.rs b/dsh_sdk/src/dsh_old/mod.rs
index c0024aa..4f6e16f 100644
--- a/dsh_sdk/src/dsh_old/mod.rs
+++ b/dsh_sdk/src/dsh_old/mod.rs
@@ -10,7 +10,7 @@
 //! # Example
 //! ```
 //! use dsh_sdk::Properties;
-//! use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
+//! use rdkafka::consumer::{Consumer, StreamConsumer};
 //!
 //! # #[tokio::main]
 //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -21,7 +21,6 @@
 //! # Ok(())
 //! # }
 //! ```
-
 #[deprecated(
    since = "0.5.0",
    note = "`dsh_sdk::dsh::certificates` is moved to `dsh_sdk::certificates`"
diff --git a/dsh_sdk/src/dsh_old/properties.rs b/dsh_sdk/src/dsh_old/properties.rs
index 3e8a971..9c4994d 100644
--- a/dsh_sdk/src/dsh_old/properties.rs
+++ b/dsh_sdk/src/dsh_old/properties.rs
@@ -10,29 +10,14 @@
 //! ## Environment variables
 //! See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for
 //! more information configuring the consumer or producer via environment variables.
-//!
-//! # Example
-//! ```
-//! use dsh_sdk::Properties;
-//! use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
-//!
-//! # #[tokio::main]
-//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let dsh_properties = Properties::get();
-//! let consumer_config = dsh_properties.consumer_rdkafka_config();
-//! let consumer: StreamConsumer = consumer_config.create()?;
-//!
-//! # Ok(())
-//! # }
-//! ```
 use log::{error, warn};
 use std::env;
-use std::sync::OnceLock;
+use std::sync::{Arc, OnceLock};
 
 use crate::certificates::Cert;
 use crate::datastream;
 use crate::error::DshError;
-use crate::protocol_adapters::kafka_protocol::config;
+
 use crate::utils;
 use crate::*;
@@ -44,22 +29,6 @@
 /// ## Environment variables
 /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for
 /// more information configuring the consumer or producer via environment variables.
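+///
+/// # Migration
+/// A minimal sketch of moving to the replacement API (the `Dsh` struct introduced in this release):
+/// ```no_run
+/// // before (deprecated):
+/// // let dsh_properties = dsh_sdk::Properties::get();
+/// // after:
+/// let dsh = dsh_sdk::Dsh::get();
+/// let datastream = dsh.datastream();
+/// ```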
-///
-/// # Example
-/// ```
-/// use dsh_sdk::Properties;
-/// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
-///
-/// #[tokio::main]
-/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
-///     let dsh_properties = Properties::get();
-///
-///     let consumer_config = dsh_properties.consumer_rdkafka_config();
-///     let consumer: StreamConsumer = consumer_config.create()?;
-///
-///     Ok(())
-/// }
-/// ```
 
+#[deprecated(since = "0.5.0", note = "`Properties` is renamed to `dsh_sdk::Dsh`")]
 #[derive(Debug, Clone)]
 pub struct Properties {
    config_host: String,
    task_id: String,
    tenant_name: String,
-    datastream: datastream::Datastream,
+    datastream: Arc<datastream::Datastream>,
    certificates: Option<Cert>,
 }
 
@@ -80,6 +49,7 @@ impl Properties {
        datastream: datastream::Datastream,
        certificates: Option<Cert>,
    ) -> Self {
+        let datastream = Arc::new(datastream);
        Self {
            config_host,
            task_id,
@@ -105,7 +75,7 @@ impl Properties {
    /// # Example
    /// ```
    /// use dsh_sdk::Properties;
-    /// use dsh_sdk::rdkafka::consumer::{Consumer, StreamConsumer};
+    /// use rdkafka::consumer::{Consumer, StreamConsumer};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -358,6 +328,7 @@ impl Properties {
        }
    }
 
+    #[cfg(feature = "kafka")]
    /// Get the configured kafka auto commit settings.
    ///
    /// ## Environment variables
@@ -369,9 +340,13 @@
    /// - Required: `false`
    /// - Options: `true`, `false`
    pub fn kafka_auto_commit(&self) -> bool {
-        config::KafkaConfig::get().enable_auto_commit()
+        crate::protocol_adapters::kafka_protocol::config::KafkaConfig::new(Some(
+            self.datastream.clone(),
+        ))
+        .enable_auto_commit()
    }
 
+    #[cfg(feature = "kafka")]
    /// Get the kafka auto offset reset settings.
    ///
    /// ## Environment variables
@@ -383,7 +358,11 @@
    /// - Required: `false`
    /// - Options: smallest, earliest, beginning, largest, latest, end
    pub fn kafka_auto_offset_reset(&self) -> String {
-        config::KafkaConfig::get().auto_offset_reset()
+        crate::protocol_adapters::kafka_protocol::config::KafkaConfig::new(Some(
+            self.datastream.clone(),
+        ))
+        .auto_offset_reset()
+        .to_string()
    }
 
    /// Get default RDKafka Consumer config to connect to Kafka on DSH.
@@ -400,8 +379,7 @@
    /// # Example
    /// ```
    /// use dsh_sdk::Properties;
-    /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel;
-    /// use dsh_sdk::rdkafka::consumer::stream_consumer::StreamConsumer;
+    /// use rdkafka::consumer::stream_consumer::StreamConsumer;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -433,42 +411,9 @@
    /// ## Environment variables
    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
    /// configuring the consumer via environment variables.
-    #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))]
+    #[cfg(feature = "rdkafka-config")]
    pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
-        let consumer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get();
-        let mut config = rdkafka::config::ClientConfig::new();
-        config
-            .set("bootstrap.servers", self.kafka_brokers())
-            .set("group.id", self.kafka_group_id())
-            .set("client.id", self.client_id())
-            .set("enable.auto.commit", self.kafka_auto_commit().to_string())
-            .set("auto.offset.reset", self.kafka_auto_offset_reset());
-        if let Some(session_timeout) = consumer_config.session_timeout() {
-            config.set("session.timeout.ms", session_timeout.to_string());
-        }
-        if let Some(queued_buffering_max_messages_kbytes) =
-            consumer_config.queued_buffering_max_messages_kbytes()
-        {
-            config.set(
-                "queued.max.messages.kbytes",
-                queued_buffering_max_messages_kbytes.to_string(),
-            );
-        }
-        log::debug!("Consumer config: {:#?}", config);
-        // Set SSL if certificates are present
-        if let Ok(certificates) = &self.certificates() {
-            config
-                .set("security.protocol", "ssl")
-                .set("ssl.key.pem", certificates.private_key_pem())
-                .set(
-                    "ssl.certificate.pem",
-                    certificates.dsh_kafka_certificate_pem(),
-                )
-                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
-        } else {
-            config.set("security.protocol", "plaintext");
-        }
-        config
+        crate::Dsh::get().consumer_rdkafka_config()
    }
 
    /// Get default RDKafka Producer config to connect to Kafka on DSH.
@@ -479,8 +424,7 @@
    ///
    /// # Example
    /// ```
-    /// use dsh_sdk::rdkafka::config::RDKafkaLogLevel;
-    /// use dsh_sdk::rdkafka::producer::FutureProducer;
+    /// use rdkafka::producer::FutureProducer;
    /// use dsh_sdk::Properties;
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -509,53 +453,16 @@
    /// ## Environment variables
    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
    /// configuring the producer via environment variables.
-    #[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))]
+    #[cfg(feature = "rdkafka-config")]
    pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
-        let producer_config = crate::protocol_adapters::kafka_protocol::config::KafkaConfig::get();
-        let mut config = rdkafka::config::ClientConfig::new();
-        config
-            .set("bootstrap.servers", self.kafka_brokers())
-            .set("client.id", self.client_id());
-        if let Some(batch_num_messages) = producer_config.batch_num_messages() {
-            config.set("batch.num.messages", batch_num_messages.to_string());
-        }
-        if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() {
-            config.set(
-                "queue.buffering.max.messages",
-                queue_buffering_max_messages.to_string(),
-            );
-        }
-        if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() {
-            config.set(
-                "queue.buffering.max.kbytes",
-                queue_buffering_max_kbytes.to_string(),
-            );
-        }
-        if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() {
-            config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string());
-        }
-        log::debug!("Producer config: {:#?}", config);
-
-        // Set SSL if certificates are present
-        if let Ok(certificates) = self.certificates() {
-            config
-                .set("security.protocol", "ssl")
-                .set("ssl.key.pem", certificates.private_key_pem())
-                .set(
-                    "ssl.certificate.pem",
-                    certificates.dsh_kafka_certificate_pem(),
-                )
-                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
-        } else {
-            config.set("security.protocol", "plaintext");
-        }
-        config
+        crate::Dsh::get().producer_rdkafka_config()
    }
 }
 
 impl Default for Properties {
    fn default() -> Self {
-        let datastream = datastream::Datastream::load_local_datastreams().unwrap_or_default();
+        let datastream =
+            Arc::new(datastream::Datastream::load_local_datastreams().unwrap_or_default());
        Self {
            task_id: "local_task_id".to_string(),
            tenant_name: "local_tenant".to_string(),
diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs
index ee8d5e4..af93b43 100644
--- a/dsh_sdk/src/error.rs
+++ b/dsh_sdk/src/error.rs
@@ -55,7 +55,7 @@ pub enum DshError {
    HyperError(#[from] hyper::http::Error),
 }
 
-#[cfg(feature = "management-api")]
+#[cfg(feature = "management-api-token-fetcher")]
 #[derive(Error, Debug)]
 #[non_exhaustive]
 pub enum DshRestTokenError {
diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs
index b2fd253..79e3922 100644
--- a/dsh_sdk/src/lib.rs
+++ b/dsh_sdk/src/lib.rs
@@ -12,13 +12,13 @@
 //!
 //! ### Example:
 //! ```
-//! use dsh_sdk::Properties;
-//! use dsh_sdk::rdkafka::consumer::stream_consumer::StreamConsumer;
+//! use dsh_sdk::DshKafkaConfig;
+//! use rdkafka::ClientConfig;
+//! use rdkafka::consumer::stream_consumer::StreamConsumer;
 //!
 //! # #[tokio::main]
 //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let dsh_properties = Properties::get();
-//! let consumer: StreamConsumer = dsh_properties.consumer_rdkafka_config().create()?;
+//! let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
 //! # Ok(())
 //! # }
 //! ```
@@ -28,14 +28,14 @@
 //!
 //! ### Example:
 //! ```no_run
-//! # use dsh_sdk::Properties;
-//! # use dsh_sdk::rdkafka::consumer::stream_consumer::StreamConsumer;
+//! use dsh_sdk::Dsh;
+
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! # let dsh_properties = Properties::get();
+//! let dsh = Dsh::get();
//! // check for write access to topic
-//! let write_access = dsh_properties.datastream().get_stream("scratch.local.local-tenant").expect("Topic not found").write_access();
+//!
let write_access = dsh.datastream().get_stream("scratch.local.local-tenant").expect("Topic not found").write_access(); //! // get the certificates, for example DSH_KAFKA_CERTIFICATE -//! let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem(); +//! let dsh_kafka_certificate = dsh.certificates()?.dsh_kafka_certificate_pem(); //! # Ok(()) //! # } //! ``` @@ -72,6 +72,8 @@ //! The DLQ is implemented by running the `Dlq` struct to push messages towards the DLQ topics. //! The `ErrorToDlq` trait can be implemented on your defined errors, to be able to send messages towards the DLQ Struct. +#![allow(deprecated)] + // to be kept in v0.6.0 #[cfg(feature = "certificate")] pub mod certificates; @@ -80,7 +82,7 @@ pub mod datastream; #[cfg(feature = "bootstrap")] pub mod dsh; pub mod error; -#[cfg(feature = "management-api")] +#[cfg(feature = "management-api-token-fetcher")] pub mod management_api; pub mod protocol_adapters; pub mod utils; @@ -92,7 +94,11 @@ pub mod schema_store; #[doc(inline)] pub use dsh::Dsh; -#[cfg(feature = "management-api")] +#[cfg(feature = "kafka")] +#[doc(inline)] +pub use protocol_adapters::kafka_protocol::DshKafkaConfig; + +#[cfg(feature = "management-api-token-fetcher")] pub use management_api::token_fetcher::{ ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder, }; @@ -130,8 +136,6 @@ pub mod graceful_shutdown; )] pub mod metrics; -#[cfg(any(feature = "rdkafka-ssl", feature = "rdkafka-ssl-vendored"))] -pub use rdkafka; #[cfg(feature = "protocol-token-fetcher")] #[deprecated( since = "0.5.0", @@ -141,13 +145,13 @@ pub mod mqtt_token_fetcher; #[cfg(feature = "bootstrap")] pub use dsh_old::Properties; -#[cfg(feature = "management-api")] +#[cfg(feature = "management-api-token-fetcher")] #[deprecated( since = "0.5.0", note = "`RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to `ManagementApiTokenFetcher` and `ManagementApiTokenFetcherBuilder`" )] mod rest_api_token_fetcher; -#[cfg(feature = "management-api")] +#[cfg(feature = "management-api-token-fetcher")] pub use rest_api_token_fetcher::{RestTokenFetcher, RestTokenFetcherBuilder}; // Environment variables diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs index eda43be..eb43d9a 100644 --- a/dsh_sdk/src/management_api/token_fetcher.rs +++ b/dsh_sdk/src/management_api/token_fetcher.rs @@ -5,10 +5,10 @@ //! The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. //! //! ## Example -//! Recommended usage is to use the [RestTokenFetcherBuilder] to create a new instance of the token fetcher. +//! Recommended usage is to use the [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher. //! However, you can also create a new instance of the token fetcher directly. //! ```no_run -//! use dsh_sdk::{RestTokenFetcherBuilder, Platform}; +//! use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform}; //! use dsh_rest_api_client::Client; //! //! const CLIENT_SECRET: &str = ""; @@ -19,7 +19,7 @@ //! let platform = Platform::NpLz; //! let client = Client::new(platform.endpoint_rest_api()); //! -//! let tf = RestTokenFetcherBuilder::new(platform) +//! let tf = ManagementApiTokenFetcherBuilder::new(platform) //! .tenant_name(TENANT.to_string()) //! .client_secret(CLIENT_SECRET.to_string()) //! 
.build()
diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs
index a258804..57dee3b 100644
--- a/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs
+++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/config.rs
@@ -1,12 +1,11 @@
 //! Kafka configuration
 //!
 //! This module contains the configuration for the Kafka protocol adapter.
+use std::sync::Arc;
 
+use crate::datastream::Datastream;
 use crate::utils::get_env_var;
 use crate::*;
 
-use std::sync::OnceLock;
-
-static KAFKA_CONFIG: OnceLock<KafkaConfig> = OnceLock::new();
 
 /// Kafka config
 ///
@@ -15,6 +14,8 @@ static KAFKA_CONFIG: OnceLock<KafkaConfig> = OnceLock::new();
 /// configuring the consumer via environment variables.
 #[derive(Debug, Clone)]
 pub struct KafkaConfig {
+    // Datastreams
+    datastream: Arc<Datastream>,
    // Consumer specific config
    enable_auto_commit: bool,
    auto_offset_reset: String,
@@ -28,7 +29,9 @@ pub struct KafkaConfig {
 }
 
 impl KafkaConfig {
-    pub fn new() -> Self {
+    pub fn new(datastream: Option<Arc<Datastream>>) -> Self {
+        let datastream = datastream
+            .unwrap_or_else(|| Arc::new(Datastream::load_local_datastreams().unwrap_or_default()));
        let enable_auto_commit = get_env_var(VAR_KAFKA_ENABLE_AUTO_COMMIT)
            .ok()
            .and_then(|v| v.parse().ok())
@@ -56,6 +59,7 @@ impl KafkaConfig {
            .ok()
            .and_then(|v| v.parse().ok());
        Self {
+            datastream,
            enable_auto_commit,
            auto_offset_reset,
            session_timeout,
@@ -66,31 +70,167 @@ impl KafkaConfig {
            queue_buffering_max_ms,
        }
    }
-    // TODO: Check, does this make sense?
-    pub fn get() -> &'static KafkaConfig {
-        KAFKA_CONFIG.get_or_init(KafkaConfig::new)
+
+    /// Get the kafka properties provided by DSH (datastreams.json)
+    ///
+    /// This datastream is fetched at initialization of the properties, and can not be updated during runtime.
+    pub fn datastream(&self) -> &Datastream {
+        self.datastream.as_ref()
    }
+
+    /// Get the Kafka brokers.
+    ///
+    /// ## Environment variable
+    /// You can set the following environment variable to overwrite the default value.
+    ///
+    /// ### `KAFKA_BOOTSTRAP_SERVERS`
+    /// - Usage: Overwrite hostnames of brokers
+    /// - Default: Brokers based on datastreams
+    /// - Required: `false`
+    pub fn kafka_brokers(&self) -> String {
+        self.datastream().get_brokers_string()
+    }
+
+    /// Get the kafka_group_id
+    ///
+    /// ## Environment variables
+    /// You can set the following environment variables to overwrite the default value.
+    ///
+    /// ### `KAFKA_CONSUMER_GROUP_TYPE`
+    /// - Usage: Picks group_id based on type from datastreams
+    /// - Default: Shared
+    /// - Options: private, shared
+    /// - Required: `false`
+    ///
+    /// ### `KAFKA_GROUP_ID`
+    /// - Usage: Custom group id
+    /// - Default: NA
+    /// - Required: `false`
+    /// - Remark: Overrules `KAFKA_CONSUMER_GROUP_TYPE`. Mandatory to start with tenant name. (will prefix tenant name automatically if not set)
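+    ///
+    /// # Example
+    /// A sketch of the prefixing behaviour described above; the tenant name `my-tenant` is hypothetical:
+    /// ```no_run
+    /// use dsh_sdk::protocol_adapters::kafka_protocol::config::KafkaConfig;
+    ///
+    /// std::env::set_var("KAFKA_GROUP_ID", "my-group");
+    /// let config = KafkaConfig::new(None);
+    /// // for tenant "my-tenant" this resolves to "my-tenant_my-group",
+    /// // as the tenant prefix is added automatically when missing
+    /// let group_id = config.group_id();
+    /// ```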
+ /// + /// ### `KAFKA_ENABLE_AUTO_COMMIT` + /// - Usage: Enable/Disable auto commit + /// - Default: `false` + /// - Required: `false` + /// - Options: `true`, `false` pub fn enable_auto_commit(&self) -> bool { self.enable_auto_commit } - pub fn auto_offset_reset(&self) -> String { - self.auto_offset_reset.clone() + + /// Get the kafka auto offset reset settings. + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_AUTO_OFFSET_RESET` + /// - Usage: Set the offset reset settings to start consuming from set option. + /// - Default: earliest + /// - Required: `false` + /// - Options: smallest, earliest, beginning, largest, latest, end + pub fn auto_offset_reset(&self) -> &str { + &self.auto_offset_reset } + + /// Session timeout in milliseconds for consuming messages + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_CONSUMER_SESSION_TIMEOUT_MS` + /// - Usage: Set the session timeout in milliseconds + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn session_timeout(&self) -> Option { self.session_timeout } + + /// Queued buffering max messages kbytes while consiuming + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES` + /// - Usage: Set the queued buffering max messages kbytes + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn queued_buffering_max_messages_kbytes(&self) -> Option { self.queued_buffering_max_messages_kbytes } + + /// Batch number of messages to be produced + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_PRODUCER_BATCH_NUM_MESSAGES` + /// - Usage: Set the batch number of messages to be produced + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn batch_num_messages(&self) -> Option { self.batch_num_messages } + + /// Maximum number of messages allowed on the producer queue + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES` + /// - Usage: Set the maximum number of messages allowed on the producer queue + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn queue_buffering_max_messages(&self) -> Option { self.queue_buffering_max_messages } + + /// Maximum total message size in KBYTES sum allowed on the producer queue + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. + /// + /// ### `KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES` + /// - Usage: Set the maximum total message size in KBYTES sum allowed on the producer queue + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn queue_buffering_max_kbytes(&self) -> Option { self.queue_buffering_max_kbytes } + + /// Delay in milliseconds to wait for messages in the producer queue to accumulate before sending in batch + /// + /// ## Environment variable + /// You can set the following environment variable to overwrite the default value. 
+ /// + /// ### `KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS` + /// - Usage: Set the delay in milliseconds to wait for messages in the producer queue to accumulate before sending in batch + /// - Default: LibRdKafka default + /// - Required: `false` + /// - Options: Any integer pub fn queue_buffering_max_ms(&self) -> Option { self.queue_buffering_max_ms } @@ -98,16 +238,7 @@ impl KafkaConfig { impl Default for KafkaConfig { fn default() -> Self { - Self { - enable_auto_commit: false, - auto_offset_reset: "earliest".to_string(), - session_timeout: None, - queued_buffering_max_messages_kbytes: None, - batch_num_messages: None, - queue_buffering_max_messages: None, - queue_buffering_max_kbytes: None, - queue_buffering_max_ms: None, - } + Self::new(None) } } @@ -120,7 +251,7 @@ mod tests { #[test] #[serial(env_dependency)] fn test_kafka_config() { - let consumer_config = KafkaConfig::new(); + let consumer_config = KafkaConfig::new(None); assert_eq!(consumer_config.enable_auto_commit(), false); assert_eq!(consumer_config.auto_offset_reset(), "earliest"); assert_eq!(consumer_config.session_timeout(), None); @@ -155,7 +286,7 @@ mod tests { VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES, "1000", ); - let consumer_config = KafkaConfig::new(); + let consumer_config = KafkaConfig::default(); assert_eq!(consumer_config.enable_auto_commit(), true); assert_eq!(consumer_config.auto_offset_reset(), "latest"); assert_eq!(consumer_config.session_timeout(), Some(1000)); @@ -176,7 +307,7 @@ mod tests { env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES, "1000"); env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES, "1000"); env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS, "1000"); - let producer_config = KafkaConfig::new(); + let producer_config = KafkaConfig::default(); assert_eq!(producer_config.batch_num_messages(), Some(1000)); assert_eq!(producer_config.queue_buffering_max_messages(), Some(1000)); assert_eq!(producer_config.queue_buffering_max_kbytes(), Some(1000)); diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs index 3c4b952..20f4b4c 100644 --- a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs +++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs @@ -1,12 +1,42 @@ -pub(crate) mod config; // TODO: should we make this public? What benefits would that bring? +pub mod config; // TODO: should we make this public? What benefits would that bring? + +#[cfg(feature = "rdkafka")] +mod rdkafka; pub trait DshKafkaConfig { - /// Set all required configurations to consume messages from DSH. + /// Set all required configurations to consume messages from DSH Kafka Cluster. 
+    ///
+    /// | **config**                | **Default value**                | **Remark**                                                               |
+    /// |---------------------------|----------------------------------|--------------------------------------------------------------------------|
+    /// | `bootstrap.servers`       | Brokers based on datastreams     | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS`                    |
+    /// | `group.id`                | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`   |
+    /// | `client.id`               | Task_id of service               |                                                                          |
+    /// | `enable.auto.commit`      | `false`                          | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT`                        |
+    /// | `auto.offset.reset`       | `earliest`                       | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET`                         |
+    /// | `security.protocol`       | ssl (DSH) / plaintext (local)    | Security protocol                                                        |
+    /// | `ssl.key.pem`             | private key                      | Generated when sdk is initiated                                          |
+    /// | `ssl.certificate.pem`     | dsh kafka certificate            | Signed certificate to connect to kafka cluster                           |
+    /// | `ssl.ca.pem`              | CA certificate                   | CA certificate, provided by DSH.                                         |
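+    ///
+    /// # Example
+    /// A minimal sketch of wiring this into `rdkafka` (assumes the `rdkafka-config` feature, which provides this trait's implementation for `rdkafka::ClientConfig`):
+    /// ```no_run
+    /// use dsh_sdk::DshKafkaConfig;
+    /// use rdkafka::consumer::StreamConsumer;
+    /// use rdkafka::ClientConfig;
+    ///
+    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+    /// # Ok(())
+    /// # }
+    /// ```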
    fn dsh_consumer_config(&mut self) -> &mut Self;
-    /// Set all required configurations to produce messages to DSH.
+    /// Set all required configurations to produce messages to DSH Kafka Cluster.
+    ///
+    /// ## Configurations
+    /// | **config**          | **Default value**              | **Remark**                                                                           |
+    /// |---------------------|--------------------------------|---------------------------------------------------------------------------------------|
+    /// | bootstrap.servers   | Brokers based on datastreams   | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS`                                 |
+    /// | client.id           | task_id of service             | Based on task_id of running service                                                    |
+    /// | security.protocol   | ssl (DSH) / plaintext (local)  | Security protocol                                                                      |
+    /// | ssl.key.pem         | private key                    | Generated when bootstrap is initiated                                                  |
+    /// | ssl.certificate.pem | dsh kafka certificate          | Signed certificate to connect to kafka cluster (signed when bootstrap is initiated)    |
+    /// | ssl.ca.pem          | CA certificate                 | CA certificate, provided by DSH.                                                       |
    fn dsh_producer_config(&mut self) -> &mut Self;
    /// Set a DSH compatible group id.
    ///
    /// DSH requires a group id with the prefix of the tenant name.
-    fn set_group_id(&mut self, group_id: &str) -> &mut Self;
+    fn set_dsh_group_id(&mut self, group_id: &str) -> &mut Self;
+    /// Set the required DSH Certificates.
+    ///
+    /// This function will set the required SSL configurations if the certificates are present.
+    /// Otherwise it falls back to plaintext (for connections to a local Kafka cluster).
+    fn set_dsh_certificates(&mut self) -> &mut Self;
 }
diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs
new file mode 100644
index 0000000..9457dc9
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs
@@ -0,0 +1,88 @@
+use rdkafka::ClientConfig;
+
+use super::DshKafkaConfig;
+use crate::Dsh;
+
+impl DshKafkaConfig for ClientConfig {
+    fn dsh_consumer_config(&mut self) -> &mut Self {
+        let dsh = Dsh::get();
+        let client_id = dsh.client_id();
+        let config = dsh.kafka_config();
+
+        self.set("bootstrap.servers", config.kafka_brokers())
+            .set("group.id", config.group_id())
+            .set("client.id", client_id)
+            .set(
+                "enable.auto.commit",
+                config.enable_auto_commit().to_string(),
+            )
+            .set("auto.offset.reset", config.auto_offset_reset());
+        if let Some(session_timeout) = config.session_timeout() {
+            self.set("session.timeout.ms", session_timeout.to_string());
+        }
+        if let Some(queued_buffering_max_messages_kbytes) =
+            config.queued_buffering_max_messages_kbytes()
+        {
+            self.set(
+                "queued.max.messages.kbytes",
+                queued_buffering_max_messages_kbytes.to_string(),
+            );
+        }
+        log::debug!("Consumer config: {:#?}", self);
+        self.set_dsh_certificates();
+        self
+    }
+
+    fn dsh_producer_config(&mut self) -> &mut Self {
+        let dsh = Dsh::get();
+        let client_id = dsh.client_id();
+        let config = dsh.kafka_config();
+        self.set("bootstrap.servers", config.kafka_brokers())
+            .set("client.id", client_id);
+        if let Some(batch_num_messages) = config.batch_num_messages() {
+            self.set("batch.num.messages", batch_num_messages.to_string());
+        }
+        if let Some(queue_buffering_max_messages) = config.queue_buffering_max_messages() {
+            self.set(
+                "queue.buffering.max.messages",
+                queue_buffering_max_messages.to_string(),
+            );
+        }
+        if let Some(queue_buffering_max_kbytes) = config.queue_buffering_max_kbytes() {
+            self.set(
+                "queue.buffering.max.kbytes",
+                queue_buffering_max_kbytes.to_string(),
+            );
+        }
+        if let Some(queue_buffering_max_ms) = config.queue_buffering_max_ms() {
+            self.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string());
+        }
+        log::debug!("Producer config: {:#?}", self);
+        self.set_dsh_certificates();
+        self
+    }
+
+    fn set_dsh_group_id(&mut self, group_id: &str) -> &mut Self {
+        let tenant = Dsh::get().tenant_name();
+        if group_id.starts_with(tenant) {
+            self.set("group.id", group_id)
+        } else {
+            self.set("group.id", &format!("{}_{}", tenant, group_id))
+        }
+    }
+
+    fn set_dsh_certificates(&mut self) -> &mut Self {
+        let dsh = Dsh::get();
+        if let Ok(certificates) = dsh.certificates() {
+            self.set("security.protocol", "ssl")
+                .set("ssl.key.pem", certificates.private_key_pem())
+                .set(
+                    "ssl.certificate.pem",
+                    certificates.dsh_kafka_certificate_pem(),
+                )
+                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem())
+        } else {
+            self.set("security.protocol", "plaintext")
+        }
+    }
+}
diff --git
a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs
index e3dc107..64a79f7 100644
--- a/dsh_sdk/src/protocol_adapters/mod.rs
+++ b/dsh_sdk/src/protocol_adapters/mod.rs
@@ -1,5 +1,6 @@
 #[cfg(feature = "http-protocol-adapter")]
 pub mod http_protocol;
+#[cfg(feature = "kafka")]
 pub mod kafka_protocol;
 #[cfg(feature = "mqtt-protocol-adapter")]
 pub mod mqtt_protocol;
diff --git a/dsh_sdk/src/schema_store/types/subject_strategy.rs b/dsh_sdk/src/schema_store/types/subject_strategy.rs
index db756c8..310a0c9 100644
--- a/dsh_sdk/src/schema_store/types/subject_strategy.rs
+++ b/dsh_sdk/src/schema_store/types/subject_strategy.rs
@@ -127,8 +127,6 @@ impl Hash for SubjectName {
 #[cfg(test)]
 mod tests {
-    use openssl::hash;
-
    use super::*;
    use std::hash::DefaultHasher;
diff --git a/dsh_sdk/src/utils/dlq.rs b/dsh_sdk/src/utils/dlq.rs
index 84d605f..004ceaa 100644
--- a/dsh_sdk/src/utils/dlq.rs
+++ b/dsh_sdk/src/utils/dlq.rs
@@ -1,7 +1,8 @@
 //! # Dead Letter Queue
 //! This optional module contains an implementation of pushing unprocessable/invalid messages towards a Dead Letter Queue (DLQ).
+//! It is implemented with [rdkafka] and [tokio].
 //!
-//! add feature `dlq` to your Cargo.toml to enable this module
+//! Add feature `dlq` to your Cargo.toml to enable this module.
 //!
 //! ### NOTE:
 //! This module is meant for pushing messages towards a dead/retry topic only, it does not and WILL not handle any logic for retrying messages.
@@ -13,29 +14,32 @@ //! The DLQ struct can
 //!
 //! ## How to use
-//! 1. Implement the `ErrorToDlq` trait on top of your (custom) error type.
-//! 2. Initialize the `Dlq` struct in your service in main.
-//! 3. Get the dlq channel sender from the `Dlq` struct and use this channel to communicate with the `Dlq` struct from other threads.
-//! 4. Run the `Dlq` struct in a separate tokio thread. This will run the producer that will produce towards the dead/retry topics.
+//! 1. Implement the [ErrorToDlq] trait on top of your (custom) error type.
+//! 2. Use [Dlq::start] in your main or at the start of your process logic. (this will start the DLQ in a separate tokio task)
+//! 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method.
 //!
-//! The topics are set via environment variables DLQ_DEAD_TOPIC and DLQ_RETRY_TOPIC.
+//! The topics are set via environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`.
 //!
 //! ### Example:
-//! See the examples folder on github for a working example.
+//!
https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs
 
 use std::collections::HashMap;
-use std::env;
 use std::str::from_utf8;
 
 use log::{debug, error, info, warn};
-
+use rdkafka::client::DefaultClientContext;
+use rdkafka::error::KafkaError;
 use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage};
 use rdkafka::producer::{FutureProducer, FutureRecord};
-
+use rdkafka::ClientConfig;
 use tokio::sync::mpsc;
 
-use crate::graceful_shutdown::Shutdown;
-use crate::Properties;
+use crate::utils::get_env_var;
+use crate::utils::graceful_shutdown::Shutdown;
+use crate::DshKafkaConfig;
+
+/// Channel to send messages to the dead letter queue
+pub type DlqChannel = mpsc::Sender<SendToDlq>;
 
 /// Trait to convert an error to a dlq message
 /// This trait is implemented for all errors that can and should be converted to a dlq message
@@ -95,7 +99,7 @@ impl SendToDlq {
        }
    }
    /// Send message to dlq channel
-    pub async fn send(self, dlq_tx: &mut mpsc::Sender<SendToDlq>) {
+    pub async fn send(self, dlq_tx: &mut DlqChannel) {
        match dlq_tx.send(self).await {
            Ok(_) => debug!("Message sent to DLQ channel"),
            Err(e) => error!("Error sending message to DLQ: {}", e),
@@ -125,70 +129,104 @@ impl std::fmt::Display for Retryable {
    }
 }
 
-/// Struct with implementation to send messages to the dlq
+/// The dead letter queue
+///
+/// ## How to use
+/// 1. Implement the [ErrorToDlq] trait on top of your (custom) error type.
+/// 2. Use [Dlq::start] in your main or at the start of your process logic. (this will start the DLQ in a separate tokio task)
+/// 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method.
+///
+/// # Example
+/// See full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs)
 pub struct Dlq {
    dlq_producer: FutureProducer,
    dlq_rx: mpsc::Receiver<SendToDlq>,
-    dlq_tx: mpsc::Sender<SendToDlq>,
    dlq_dead_topic: String,
    dlq_retry_topic: String,
-    shutdown: Shutdown,
+    _shutdown: Shutdown, // hold the shutdown alive until exit
 }
 
 impl Dlq {
-    /// Create new Dlq struct
-    pub fn new(
-        dsh_prop: &Properties,
-        shutdown: Shutdown,
-    ) -> Result<Self, Box<dyn std::error::Error>> {
-        use crate::datastream::ReadWriteAccess;
+    /// Start the dlq on a tokio task
+    ///
+    /// The DLQ will run until the returned `Sender` is dropped.
+    ///
+    /// # Arguments
+    /// * `shutdown` - The shutdown is required to keep the DLQ alive until the DLQ Sender is dropped
+    ///
+    /// # Returns
+    /// * The [DlqChannel] to send messages to the DLQ
+    ///
+    /// # Note
+    /// **NEVER** borrow the [DlqChannel] to your consumer, always use an owned [DlqChannel].
+    /// This is required to gracefully shut down the DLQ, as it depends on the [DlqChannel] being dropped.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// use dsh_sdk::utils::dlq::{Dlq, DlqChannel, SendToDlq};
+    ///
+    /// async fn consume(dlq_channel: DlqChannel) {
+    ///     // Your consumer logic together with error handling
+    ///     loop {
+    ///         tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+    ///     }
+    /// }
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let shutdown = Shutdown::new();
+    ///     let dlq_channel = Dlq::start(shutdown.clone()).unwrap();
+    ///
+    ///     tokio::select! {
{ + /// _ = async move { + /// // Your consumer logic together with the owned dlq_channel + /// dlq_channel + /// } => {} + /// _ = shutdown.signal_listener() => { + /// println!("Shutting down consumer"); + /// } + /// } + /// // wait for graceful shutdown to complete + /// // NOTE that the `dlq_channel` will go out of scope when shutdown is called and the DLQ will stop + /// shutdown.complete().await; + /// } + /// ``` + pub fn start(shutdown: Shutdown) -> Result> { let (dlq_tx, dlq_rx) = mpsc::channel(200); - let dlq_producer = Self::build_producer(dsh_prop)?; - let dlq_dead_topic = env::var("DLQ_DEAD_TOPIC")?; - let dlq_retry_topic = env::var("DLQ_RETRY_TOPIC")?; - dsh_prop.datastream().verify_list_of_topics( - &vec![&dlq_dead_topic, &dlq_retry_topic], - ReadWriteAccess::Write, - )?; - Ok(Self { + let dlq_producer: FutureProducer = + ClientConfig::new().dsh_producer_config().create()?; + let dlq_dead_topic = get_env_var("DLQ_DEAD_TOPIC")?; + let dlq_retry_topic = get_env_var("DLQ_RETRY_TOPIC")?; + let dlq = Self { dlq_producer, dlq_rx, - dlq_tx, dlq_dead_topic, dlq_retry_topic, - shutdown, - }) + _shutdown: shutdown, + }; + tokio::spawn(dlq.run()); + Ok(dlq_tx) } /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics /// This function will run until the shutdown channel is closed - pub async fn run(&mut self) { + async fn run(mut self) { info!("DLQ started"); loop { - tokio::select! { - _ = self.shutdown.recv() => { - warn!("DLQ shutdown"); - return; - }, - Some(mut dlq_message) = self.dlq_rx.recv() => { - match self.send(&mut dlq_message).await { - Ok(_) => {}, - Err(e) => error!("Error sending message to DLQ: {}", e), - }; - } + if let Some(mut dlq_message) = self.dlq_rx.recv().await { + match self.send(&mut dlq_message).await { + Ok(_) => {} + Err(e) => error!("Error sending message to DLQ: {}", e), + }; + } else { + warn!("DLQ stopped as there is no active DLQ Channel"); + break; } } } - - /// Get the dlq channel sender. To be used in your service to send messages to the dlq in case of errors. - /// - /// This channel can be used to send messages to the dlq from different threads. 
-    pub fn dlq_records_tx(&self) -> mpsc::Sender<SendToDlq> {
-        self.dlq_tx.clone()
-    }
-
     /// Create and send message towards the dlq
-    async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), rdkafka::error::KafkaError> {
+    async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), KafkaError> {
         let orignal_kafka_msg: OwnedMessage = dlq_message.get_original_msg();
         let headers = orignal_kafka_msg
             .generate_dlq_headers(dlq_message)
@@ -201,8 +239,8 @@ impl Dlq {
             .payload(payload)
             .key(key)
             .headers(headers);
-        let s = self.dlq_producer.send(record, None).await;
-        match s {
+        let send = self.dlq_producer.send(record, None).await;
+        match send {
             Ok((p, o)) => warn!(
                 "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}",
                 from_utf8(key),
@@ -222,10 +260,6 @@ impl Dlq {
             Retryable::Other => &self.dlq_dead_topic,
         }
     }
-
-    fn build_producer(dsh_prop: &Properties) -> Result<FutureProducer, KafkaError> {
-        dsh_prop.producer_rdkafka_config().create()
-    }
 }

 trait DlqHeaders {
@@ -445,10 +479,9 @@ mod tests {
         let dlq = Dlq {
             dlq_producer: producer,
             dlq_rx: mpsc::channel(200).1,
-            dlq_tx: mpsc::channel(200).0,
             dlq_dead_topic: "dead_topic".to_string(),
             dlq_retry_topic: "retry_topic".to_string(),
-            shutdown: Shutdown::new(),
+            _shutdown: Shutdown::new(),
         };
         let error = MockError::MockErrorRetryable("some_error".to_string());
         let topic = dlq.dlq_topic(error.retryable());
diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs
index 365c4c5..684c887 100644
--- a/dsh_sdk/src/utils/mod.rs
+++ b/dsh_sdk/src/utils/mod.rs
@@ -12,7 +12,7 @@ use crate::error::DshError;
 pub mod dlq;
 #[cfg(feature = "graceful-shutdown")]
 pub mod graceful_shutdown;
-#[cfg(feature = "hyper-client")]
+#[cfg(feature = "hyper-client")] // TODO: to be implemented
 pub(crate) mod http_client;
 #[cfg(feature = "metrics")]
 pub mod metrics;
diff --git a/example_dsh_service/Cargo.toml b/example_dsh_service/Cargo.toml
index 09c1af7..6df8ac9 100644
--- a/example_dsh_service/Cargo.toml
+++ b/example_dsh_service/Cargo.toml
@@ -5,7 +5,8 @@ description = "An example of DSH service using the dsh-sdk crate"
 edition = "2021"

 [dependencies]
-dsh_sdk = { path = "../dsh_sdk", version = "0.5", features = ["rdkafka-ssl-vendored"] }
+dsh_sdk = { path = "../dsh_sdk", version = "0.5.0-rc.1", features = ["rdkafka-config", "metrics", "graceful-shutdown"] }
+rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
 log = "0.4"
 env_logger = "0.11"
 tokio = { version = "^1.35", features = ["full"] }
\ No newline at end of file
diff --git a/example_dsh_service/src/custom_metrics.rs b/example_dsh_service/src/custom_metrics.rs
index d42924e..a3e93d3 100644
--- a/example_dsh_service/src/custom_metrics.rs
+++ b/example_dsh_service/src/custom_metrics.rs
@@ -1,4 +1,4 @@
-use dsh_sdk::metrics::*;
+use dsh_sdk::utils::metrics::*;

 lazy_static! {
     pub static ref CONSUMED_MESSAGES: IntCounter =
diff --git a/example_dsh_service/src/main.rs b/example_dsh_service/src/main.rs
index 71a6704..e5e7e8d 100644
--- a/example_dsh_service/src/main.rs
+++ b/example_dsh_service/src/main.rs
@@ -1,7 +1,7 @@
-use dsh_sdk::graceful_shutdown::Shutdown;
-use dsh_sdk::rdkafka::consumer::{CommitMode, Consumer, StreamConsumer};
-use dsh_sdk::rdkafka::message::{BorrowedMessage, Message};
-use dsh_sdk::Properties;
+use dsh_sdk::utils::graceful_shutdown::Shutdown;
+use dsh_sdk::Dsh;
+use rdkafka::consumer::{CommitMode, Consumer, StreamConsumer};
+use rdkafka::message::{BorrowedMessage, Message};

 use log::{error, info};

@@ -52,20 +52,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .init();

     // Start http server for exposing prometheus metrics, note that in Dockerfile we expose port 8080 as well
-    dsh_sdk::metrics::start_http_server(8080);
+    dsh_sdk::utils::metrics::start_http_server(8080);

     // Create a new properties instance (connects to the DSH server and fetches the datastream)
-    let dsh_properties = Properties::get();
+    let dsh_properties = Dsh::get();

     // Get the configured topics from env variable TOPICS (comma separated)
     let topics_string = std::env::var("TOPICS").expect("TOPICS env variable not set");
     let topics = topics_string.split(',').collect::<Vec<&str>>();

-    // Validate your configured topic if it has read access (optional)
-    dsh_properties
-        .datastream()
-        .verify_list_of_topics(&topics, dsh_sdk::dsh_old::datastream::ReadWriteAccess::Read)?;
-
     // Initialize the shutdown handler (This will handle SIGTERM and SIGINT signals, and you can act on them)
     let shutdown = Shutdown::new();

@@ -73,7 +68,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let mut consumer_client_config = dsh_properties.consumer_rdkafka_config();

     // Override some default values (optional)
-    consumer_client_config.set("auto.offset.reset", "latest");
+    consumer_client_config.set("auto.offset.reset", "earliest");

     // Create a new consumer instance
     let consumer: StreamConsumer = consumer_client_config.create()?;

From 17c552b938a675cac68a133ad42d42d7777d8589 Mon Sep 17 00:00:00 2001
From: Frank Hol <96832951+toelo3@users.noreply.github.com>
Date: Fri, 3 Jan 2025 15:12:04 +0100
Subject: [PATCH 04/23] Fix missing rdkafka-config default features

---
 dsh_sdk/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml
index 4c85dca..abe7858 100644
--- a/dsh_sdk/Cargo.toml
+++ b/dsh_sdk/Cargo.toml
@@ -42,7 +42,7 @@ tokio-util = { version = "0.7", default-features = false, optional = true }

 [features]
 # default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl", "schema-store"]
-default = ["bootstrap", "kafka"]
+default = ["bootstrap", "kafka", "rdkafka-config"]

 bootstrap = ["certificate", "serde_json", "tokio/rt-multi-thread"]
 kafka = ["bootstrap"]

From bfcd92288da05387aeb482d1a77e6a791d95459c Mon Sep 17 00:00:00 2001
From: Frank Hol <96832951+toelo3@users.noreply.github.com>
Date: Wed, 8 Jan 2025 14:37:02 +0100
Subject: [PATCH 05/23] Feature/update examples (#103)

* small fix in readme
* update existing examples
* Update readme/changelog/security
* update examples
* align methods with RDkafka
* cargo fmt
* Remove certificate flag
* improve schema store api (WIP)
* improve schema store
---
 SECURITY.md                                   |   4 +-
 dsh_sdk/CHANGELOG.md                          |   4 +-
 dsh_sdk/Cargo.toml                            |  14 +-
 dsh_sdk/README.md                             |  77 +++----
 dsh_sdk/examples/custom_metrics.rs            |   2 +-
 dsh_sdk/examples/dlq_implementation.rs        |  24 ++-
 dsh_sdk/examples/expose_metrics.rs            |   2 +-
 dsh_sdk/examples/graceful_shutdown.rs         |  16 +-
.../{produce_consume.rs => kafka_example.rs} | 10 +- dsh_sdk/examples/kafka_proxy.rs | 43 ++++ ...her.rs => management_api_token_fetcher.rs} | 4 +- ...n_fetcher.rs => protocol_token_fetcher.rs} | 5 +- ...protocol_token_fetcher_specific_claims.rs} | 15 +- dsh_sdk/examples/schema_store_api.rs | 50 +++++ dsh_sdk/src/certificates/mod.rs | 22 -- dsh_sdk/src/dsh.rs | 4 +- dsh_sdk/src/dsh_old/certificates.rs | 24 --- dsh_sdk/src/error.rs | 19 +- dsh_sdk/src/lib.rs | 4 +- dsh_sdk/src/management_api/error.rs | 17 ++ dsh_sdk/src/management_api/mod.rs | 1 + dsh_sdk/src/management_api/token_fetcher.rs | 36 ++-- .../protocol_adapters/kafka_protocol/mod.rs | 4 +- .../kafka_protocol/rdkafka.rs | 5 +- .../protocol_adapters/token_fetcher/mod.rs | 19 +- dsh_sdk/src/schema_store/client.rs | 202 ++++++++++-------- dsh_sdk/src/schema_store/error.rs | 11 +- dsh_sdk/src/schema_store/mod.rs | 11 +- .../schema_store/types/subject_strategy.rs | 42 ++-- dsh_sdk/src/utils/dlq.rs | 2 +- example_dsh_service/src/main.rs | 23 +- 31 files changed, 399 insertions(+), 317 deletions(-) rename dsh_sdk/examples/{produce_consume.rs => kafka_example.rs} (82%) create mode 100644 dsh_sdk/examples/kafka_proxy.rs rename dsh_sdk/examples/{rest_api_token_fetcher.rs => management_api_token_fetcher.rs} (89%) rename dsh_sdk/examples/{mqtt_token_fetcher.rs => protocol_token_fetcher.rs} (69%) rename dsh_sdk/examples/{mqtt_token_fetcher_specific_claims.rs => protocol_token_fetcher_specific_claims.rs} (61%) create mode 100644 dsh_sdk/examples/schema_store_api.rs create mode 100644 dsh_sdk/src/management_api/error.rs diff --git a/SECURITY.md b/SECURITY.md index 78d67d9..ae6cc72 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -15,7 +15,8 @@ The following versions of this project are currently being supported with securi | Version | Supported | | ------- | ------------------ | -| 0.4.x | :white_check_mark: | +| 0.5.x | :white_check_mark: | +| 0.4.x | :white_check_mark: (till 28/02/2025) | | 0.3.x | :x: | | 0.2.x | :x: | | 0.1.x | :x: | @@ -25,7 +26,6 @@ The following versions of this project are currently being supported with securi [![dependency status](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs/status.svg)](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs). - ## Reporting a Vulnerability If you have found a vulnerability or bug, you can report it to unibox@kpn.com. 
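For reference, the DLQ refactor in the commits above replaces the owned `Dlq` + `run()` pattern with `Dlq::start`, which spawns the queue on its own tokio task and hands back only a `DlqChannel`. The following is a minimal sketch of that flow, assuming the `dlq` and `graceful-shutdown` features are enabled and the `DLQ_DEAD_TOPIC`/`DLQ_RETRY_TOPIC` environment variables are set (error conversion via `ErrorToDlq::to_dlq` is elided):

```rust
use dsh_sdk::utils::dlq::{Dlq, DlqChannel};
use dsh_sdk::utils::graceful_shutdown::Shutdown;

// Placeholder consumer: a real service would consume from Kafka and push
// unprocessable messages into the owned `DlqChannel` via `ErrorToDlq::to_dlq`.
async fn consume(dlq_channel: DlqChannel) {
    let _owned_channel = dlq_channel; // moved in, dropped when this task ends
    loop {
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    }
}

#[tokio::main]
async fn main() {
    let shutdown = Shutdown::new();
    // Spawns the DLQ producer on its own tokio task; only the sender is returned.
    let dlq_channel = Dlq::start(shutdown.clone()).unwrap();

    tokio::select! {
        _ = consume(dlq_channel) => println!("Consumer stopped"),
        _ = shutdown.signal_listener() => println!("Shutting down consumer"),
    }
    // The consumer future is dropped here, which drops the `DlqChannel`
    // and lets the DLQ task drain and exit before shutdown completes.
    shutdown.complete().await;
}
```

Dropping the channel, rather than a separate shutdown signal, is what stops the DLQ task; this is why the docs insist on passing an owned `DlqChannel`.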
diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index 8b4e336..cb2271d 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.5.0] - unreleased ### Added +- DSH Kafka Config trait to configure kafka client with RDKafka implementation +- DSH Schema store API Client - New public functions `dsh_sdk::certificates::Cert` - Bootstrap to DSH - Read certificates from PKI_CONFIG_DIR @@ -25,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates` - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module - Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` where it is renamed to `ProtocolTokenFetcher` - - **NOTE** Cargo.toml feature flag falls now under `mqtt-protocol` (`mqtt_token_fetcher` will be removed in v0.6.0) + - **NOTE** Cargo.toml feature flag falls now under `protocol-token-fetcher` (`mqtt-token-fetcher` will be removed in v0.6.0) - Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq` - Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown` - Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics` diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index abe7858..f9a79dd 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -32,7 +32,7 @@ prometheus = { version = "0.13", features = ["process"], optional = true } protofish = { version = "0.5.2", optional = true } rcgen = { version = "0.13", optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "blocking"], optional = true } -rdkafka = { version = ">=0.36", default-features = false, optional = true } +rdkafka = { version = "0.37", default-features = false, optional = true } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"], optional = true } sha2 = { version = "0.10", optional = true} @@ -44,17 +44,15 @@ tokio-util = { version = "0.7", default-features = false, optional = true } # default = ["bootstrap", "graceful-shutdown", "metrics", "rdkafka-ssl", "schema-store"] default = ["bootstrap", "kafka", "rdkafka-config"] -bootstrap = ["certificate", "serde_json", "tokio/rt-multi-thread"] +bootstrap = ["rcgen", "reqwest", "pem", "serde_json", "tokio/rt-multi-thread"] kafka = ["bootstrap"] rdkafka-config = ["rdkafka", "kafka"] # Impl of config trait only - -certificate = ["rcgen", "reqwest", "pem"] schema-store = ["bootstrap", "reqwest", "serde_json", "apache-avro", "protofish"] -metrics = ["prometheus", "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] -dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/ssl-vendored", "rdkafka/libz", "rdkafka/tokio", "graceful-shutdown"] graceful-shutdown = ["tokio", "tokio-util"] management-api-token-fetcher = ["reqwest"] protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] +metrics = ["prometheus", "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] +dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/ssl-vendored", "rdkafka/libz", "rdkafka/tokio", "graceful-shutdown"] # http-protocol-adapter = ["protocol-token-fetcher"] # mqtt-protocol-adapter = ["protocol-token-fetcher"] @@ -67,7 +65,7 @@ openssl = 
"0.10" tokio = { version = "^1.35", features = ["full"] } hyper = { version = "1.3", features = ["full"] } serial_test = "3.1.0" -dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.2.0"} +dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.2.0" } dsh_sdk = { features = ["dlq"], path = "." } env_logger = "0.11" -rdkafka = { version = ">=0.36", features = ["cmake-build", "ssl-vendored"], default-features = true } \ No newline at end of file +rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"], default-features = true } \ No newline at end of file diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index 7f8edac..d5ca76f 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -15,20 +15,16 @@ See [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migrat ## Description This library can be used to interact with the DSH Platform. It is intended to be used as a base for services that will be used to interact with DSH. Features include: -- Connect to DSH -- Fetch Kafka Properties and certificates -- Rest API Token Fetcher (to be used with [dsh_rest_api_client](https://crates.io/crates/dsh_rest_api_client)) -- MQTT Token Fetcher -- Common functions - - Preconfigured RDKafka client config - - Preconfigured Reqwest client config (for schema store) -- Graceful shutdown -- Prometheus Metrics (web server and re-export of metrics crate) -- Dead Letter Queue (experimental) - -### Note -Rdkafka and thereby this library is dependent on CMAKE. Make sure it is installed in your environment and/or Dockerfile where you are compiling. -See [dockerfile](../example_dsh_service/Dockerfile) for an example. +- Connect to DSH Kafka (DSH, Kafka Proxy, VPN, System Space, Local) + - Bootstrap (fetch datastreams info and generate signed certificate) + - PKI Config Directory (for Kafka Proxy/VPN) +- Kafka config for DSH (incl. RDKafka) +- Management API Token Fetcher (to be used with [dsh_rest_api_client](https://crates.io/crates/dsh_rest_api_client)) +- Protocol Token Fetcher (MQTT and HTTP) +- Common utilities + - Prometheus Metrics (web server and re-export of metrics crate) + - Graceful shutdown + - Dead Letter Queue ## Usage To use this SDK with the default features in your project, add the following to your Cargo.toml file: @@ -36,16 +32,8 @@ To use this SDK with the default features in your project, add the following to ```toml [dependencies] dsh_sdk = "0.5" +rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] } ``` - -However, if you would like to use only specific features, you can specify them in your Cargo.toml file. For example, if you would like to use only the bootstrap feature, add the following to your Cargo.toml file: - -```toml -[dependencies] -dsh_sdk = { version = "0.5", default-features = false, features = ["rdkafka"] } -rdkafka = { version = "0.37", features = ["cmake-buld", "ssl-vendored"] } -``` - See [feature flags](#feature-flags) for more information on the available features. 
To use this SDK in your project
@@ -56,7 +44,7 @@ use rdkafka::ClientConfig;

 fn main() -> Result<(), Box<dyn std::error::Error>>{
     // get a rdkafka consumer config for example
-    let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+    let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;
 }
 ```

@@ -69,30 +57,33 @@ See the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Mi

 The following features are available in this library and can be enabled/disabled in your Cargo.toml file:

-| **feature** | **default** | **Description** |
-|---|---|---|
-| `bootstrap` | ✓ | Generate signed certificate and fetch datastreams info |
-| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH |
-| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka |
-| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) |
-| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API |
-| `metrics` | ✗ | Enable prometheus metrics including http server |
-| `graceful-shutdown` | ✗ | Tokio based gracefull shutdown handler |
-| `dlq` | ✗ | Dead Letter Queue implementation |
-| `rest-token-fetcher` | ✗ | Replaced by `management-api-token-fetcher` |
-| `mqtt-token-fetcher` | ✗ | Replaced by `protocol-token-fetcher` |
-
-See api documentation for more information on how to use these features including.
+| **feature** | **default** | **Description** | **Example** |
+| --- |--- | --- | --- |
+| `bootstrap` | ✓ | Certificate signing process and fetch datastreams info | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
+| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
+| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
+| `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) |
+| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) |
+| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [Token fetcher](./examples/management_api_token_fetcher.rs) |
+| `metrics` | ✗ | Enable prometheus metrics including http server | [expose metrics](./examples/expose_metrics.rs) / [custom metrics](./examples/custom_metrics.rs) |
+| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) |
+| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) |

-## Environment variables
-The default RDKafka config can be overwritten by setting environment variables. See [ENV_VARIABLES.md](ENV_VARIABLES.md) for more information.
+See the [api documentation](https://docs.rs/dsh_sdk/latest/dsh_sdk/) for more information on how to use these features.
+
+If you would like to use specific features, you can specify them in your Cargo.toml file. This can save compile time and reduce dependencies.
+For example, if you only want to use the Management API token fetcher feature, add the following to your Cargo.toml file:
+```toml
+[dependencies]
+dsh_sdk = { version = "0.5", default-features = false, features = ["management-api-token-fetcher"] }
+```

-## Api doc
-See the [api documentation](https://docs.rs/dsh_sdk/latest/dsh_sdk/) for more information on how to use this library.
+## Environment variables
+The SDK checks environment variables to change its configuration. See [ENV_VARIABLES.md](ENV_VARIABLES.md) for more information.

 ## Examples
-See folder [dsh_sdk/examples](/examples/) for simple examples on how to use the SDK.
+See folder [dsh_sdk/examples](./examples/) for simple examples on how to use the SDK.

 ### Full service example
 See folder [example_dsh_service](../example_dsh_service/) for a full service, including how to build the Rust project and post it to Harbor. See [readme](../example_dsh_service/README.md) for more information.
diff --git a/dsh_sdk/examples/custom_metrics.rs b/dsh_sdk/examples/custom_metrics.rs
index d66cd25..5874ade 100644
--- a/dsh_sdk/examples/custom_metrics.rs
+++ b/dsh_sdk/examples/custom_metrics.rs
@@ -1,4 +1,4 @@
-use dsh_sdk::metrics::*;
+use dsh_sdk::utils::metrics::*;

 lazy_static! {
     pub static ref HIGH_FIVE_COUNTER: IntCounter =
diff --git a/dsh_sdk/examples/dlq_implementation.rs b/dsh_sdk/examples/dlq_implementation.rs
index c3ffaa0..18919ea 100644
--- a/dsh_sdk/examples/dlq_implementation.rs
+++ b/dsh_sdk/examples/dlq_implementation.rs
@@ -7,7 +7,11 @@ use rdkafka::message::{BorrowedMessage, Message, OwnedMessage};
 use rdkafka::ClientConfig;
 use std::backtrace::Backtrace;
 use thiserror::Error;
-use tokio::sync::mpsc;
+
+// Required environment variables for DLQ
+const DLQ_DEAD_TOPIC: &str = "scratch.dlq.local-tenant"; // Topic to send non-retryable messages to
+const DLQ_RETRY_TOPIC: &str = "scratch.dlq.local-tenant"; // Topic to send retryable messages to (can be the same as DLQ_DEAD_TOPIC)
+const TOPIC: &str = "scratch.topic-name.local-tenant"; // topic to consume from

 // Define your custom error type
 #[derive(Error, Debug)]
@@ -81,26 +85,24 @@ async fn consume(

 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    // set the dlq topics (required)
-    std::env::set_var("DLQ_DEAD_TOPIC", "scratch.dlq.local-tenant");
-    std::env::set_var("DLQ_RETRY_TOPIC", "scratch.dlq.local-tenant");
-
-    // Topic to subscribe to (change to your topic)
-    let topic = "your_topic_name";
+    // Set the dlq topics (required)
+    // Normally injected via DSH Config
+    std::env::set_var("DLQ_DEAD_TOPIC", DLQ_DEAD_TOPIC);
+    std::env::set_var("DLQ_RETRY_TOPIC", DLQ_RETRY_TOPIC);

     let shutdown = Shutdown::new();
-    let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+    let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;

-    // Start the `dlq` service, returns a sender to send messages to the dlq
+    // Start the `Dlq` service, returns a sender to send messages to the dlq
     let dlq_channel = dlq::Dlq::start(shutdown.clone())?;

     // run the `consumer` in a separate tokio task
     let shutdown_clone = shutdown.clone();
     let consumer_handle = tokio::spawn(async move {
-        consume(consumer, topic, dlq_channel, shutdown_clone).await;
+        consume(consumer, TOPIC, dlq_channel, shutdown_clone).await;
     });

-    // wait for `consumer` or `dlq` to finish or for shutdown signal
+    // wait for the `consumer` to finish or for a shutdown signal
     tokio::select! {
         _ = consumer_handle => {
             println!("Consumer finished");
diff --git a/dsh_sdk/examples/expose_metrics.rs b/dsh_sdk/examples/expose_metrics.rs
index c55343f..59906dc 100644
--- a/dsh_sdk/examples/expose_metrics.rs
+++ b/dsh_sdk/examples/expose_metrics.rs
@@ -1,4 +1,4 @@
-use dsh_sdk::metrics::*;
+use dsh_sdk::utils::metrics::*;

 lazy_static! {
     pub static ref HIGH_FIVE_COUNTER: IntCounter =
diff --git a/dsh_sdk/examples/graceful_shutdown.rs b/dsh_sdk/examples/graceful_shutdown.rs
index f2114d9..fb85d3d 100644
--- a/dsh_sdk/examples/graceful_shutdown.rs
+++ b/dsh_sdk/examples/graceful_shutdown.rs
@@ -1,4 +1,11 @@
-use dsh_sdk::graceful_shutdown::Shutdown;
+//! Example on how to implement a graceful shutdown in a tokio application.
+//!
+//! Run the example with:
+//! ```bash
+//! cargo run --example graceful_shutdown
+//! ```
+
+use dsh_sdk::utils::graceful_shutdown::Shutdown;

 // your process task
 async fn process_task(shutdown: Shutdown) {
@@ -6,7 +13,7 @@ async fn process_task(shutdown: Shutdown) {
         tokio::select! {
             _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {
                 // Do something here, e.g. consume messages from Kafka
-                println!("Still processing the task, press Ctrl+C to exit")
+                println!("Still processing the task, press Ctrl+C or send SIGTERM to exit")
             },
             _ = shutdown.recv() => {
                 // shutdown request received, include your shutdown procedure here e.g. close db connection
@@ -20,14 +27,13 @@ async fn process_task(shutdown: Shutdown) {
 #[tokio::main]
 async fn main() {
     // Create shutdown handle
-    let shutdown = dsh_sdk::graceful_shutdown::Shutdown::new();
+    let shutdown = Shutdown::new();
     // Create your process task with a cloned shutdown handle
     let process_task = process_task(shutdown.clone());
     // Spawn your process task in a tokio runtime
     let process_task_handle = tokio::spawn(async move {
         process_task.await;
     });
-
     // Listen for shutdown request or if process task stopped
     // If your process stops, start shutdown procedure to stop other tasks (if any)
     tokio::select! {
@@ -35,6 +41,6 @@ async fn main() {
         _ = process_task_handle => {println!("process_task stopped"); shutdown.start()},
     }
     // Wait till shutdown procedures is finished
-    let _ = shutdown.complete().await;
+    shutdown.complete().await;
     println!("Exiting main...")
 }
diff --git a/dsh_sdk/examples/produce_consume.rs b/dsh_sdk/examples/kafka_example.rs
similarity index 82%
rename from dsh_sdk/examples/produce_consume.rs
rename to dsh_sdk/examples/kafka_example.rs
index ef2788f..0324f78 100644
--- a/dsh_sdk/examples/produce_consume.rs
+++ b/dsh_sdk/examples/kafka_example.rs
@@ -1,3 +1,9 @@
+//! Simple producer and consumer example that sends and receives messages from a Kafka topic
+//!
+//! Run the example with:
+//! ```bash
+//! cargo run --example kafka_example --features rdkafka-config
+//! ```
 use dsh_sdk::DshKafkaConfig;
 use rdkafka::consumer::CommitMode;
 use rdkafka::consumer::{Consumer, StreamConsumer};
@@ -44,13 +50,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let topic = "test";

     // Create a new producer from the RDkafka Client Config together with set_dsh_producer_config from the DshKafkaConfig trait
-    let producer: FutureProducer = ClientConfig::new().dsh_producer_config().create()?;
+    let producer: FutureProducer = ClientConfig::new().set_dsh_producer_config().create()?;

     // Produce messages towards topic
     produce(producer, topic).await;

     // Create a new consumer from the RDkafka Client Config together with set_dsh_consumer_config from the DshKafkaConfig trait
-    let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+    let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;

     consume(consumer, topic).await;

     Ok(())
diff --git a/dsh_sdk/examples/kafka_proxy.rs b/dsh_sdk/examples/kafka_proxy.rs
new file mode 100644
index 0000000..25ddd9c
--- /dev/null
+++ b/dsh_sdk/examples/kafka_proxy.rs
@@ -0,0 +1,43 @@
+use std::env;
+
+use dsh_sdk::DshKafkaConfig;
+use rdkafka::consumer::{Consumer, StreamConsumer};
+use rdkafka::ClientConfig;
+use rdkafka::Message;
+
+// Enter your details here
+const KAFKA_BOOTSTRAP_SERVERS: &str = "kafkaproxy urls"; // example "broker-0.kafka.tenant.kpn-dsh.com:9091,broker-1.kafka.tenant.kpn-dsh.com:9091,broker-2.kafka.tenant.kpn-dsh.com:9091"
+const PKI_CONFIG_DIR: &str = "path/to/pki/config/dir"; // example /Documents/pki_config_dir/tenant
+const DSH_TENANT_NAME: &str = "tenant"; // enter your tenant name (required for creating group id)
+const TOPIC: &str = "scratch.topic-name.tenant"; // enter your topic name
+
+/// Simple consumer that consumes messages from a Kafka topic
+async fn consume(consumer: StreamConsumer) {
+    consumer.subscribe(&[TOPIC]).unwrap();
+    loop {
+        let msg = consumer.recv().await.unwrap();
+        let payload = String::from_utf8_lossy(msg.payload().unwrap());
+        println!(
+            "Received message: key: {:?}, payload: {}",
+            msg.key(),
+            payload
+        );
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Set the environment variables (normally you would set them outside of the code)
+    env::set_var("KAFKA_BOOTSTRAP_SERVERS", KAFKA_BOOTSTRAP_SERVERS);
+    env::set_var("PKI_CONFIG_DIR", PKI_CONFIG_DIR);
+    env::set_var("DSH_TENANT_NAME", DSH_TENANT_NAME);
+
+    // Create a new consumer from the RDkafka Client Config together with set_dsh_consumer_config from the DshKafkaConfig trait
+    // The config picks up the settings from the environment variables and loads the certificates from the PKI_CONFIG_DIR
+    // This makes it easy to switch between Kafka Proxy and normal usage without changing the code
+    let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;
+
+    // start consuming messages from the topic
+    consume(consumer).await;
+    Ok(())
+}
diff --git a/dsh_sdk/examples/rest_api_token_fetcher.rs b/dsh_sdk/examples/management_api_token_fetcher.rs
similarity index 89%
rename from dsh_sdk/examples/rest_api_token_fetcher.rs
rename to dsh_sdk/examples/management_api_token_fetcher.rs
index dd1df21..899819d 100644
--- a/dsh_sdk/examples/rest_api_token_fetcher.rs
+++ b/dsh_sdk/examples/management_api_token_fetcher.rs
@@ -6,7 +6,7 @@
 //! CLIENT_SECRET=your_client_secret TENANT=your_tenant cargo run --features management-api-token-fetcher --example management_api_token_fetcher
 //! ```
 use dsh_rest_api_client::Client;
-use dsh_sdk::{Platform, RestTokenFetcherBuilder};
+use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};
 use std::env;

 #[tokio::main]
@@ -16,7 +16,7 @@ async fn main() {
         env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set as environment variable");
     let tenant = env::var("TENANT").expect("TENANT must be set as environment variable");
     let client = Client::new(platform.endpoint_rest_api());
-    let tf = RestTokenFetcherBuilder::new(platform)
+    let tf = ManagementApiTokenFetcherBuilder::new(platform)
         .tenant_name(tenant.clone())
         .client_secret(client_secret)
         .build()
diff --git a/dsh_sdk/examples/mqtt_token_fetcher.rs b/dsh_sdk/examples/protocol_token_fetcher.rs
similarity index 69%
rename from dsh_sdk/examples/mqtt_token_fetcher.rs
rename to dsh_sdk/examples/protocol_token_fetcher.rs
index 9b48c95..73e09eb 100644
--- a/dsh_sdk/examples/mqtt_token_fetcher.rs
+++ b/dsh_sdk/examples/protocol_token_fetcher.rs
@@ -1,12 +1,13 @@
 use std::env;

-use dsh_sdk::mqtt_token_fetcher::{MqttToken, MqttTokenFetcher};
+use dsh_sdk::protocol_adapters::token_fetcher::*;

 #[tokio::main]
 async fn main() {
     let tenant_name = env::var("TENANT").unwrap().to_string();
     let api_key = env::var("API_KEY").unwrap().to_string();
-    let mqtt_token_fetcher = MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);
+    let mqtt_token_fetcher =
+        ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);
     let token: MqttToken = mqtt_token_fetcher
         .get_token("Client-id", None) //Claims = None fetches all possible claims
         .await
diff --git a/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs b/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
similarity index 61%
rename from dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs
rename to dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
index 3439128..4fb6ae2 100644
--- a/dsh_sdk/examples/mqtt_token_fetcher_specific_claims.rs
+++ b/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
@@ -1,25 +1,28 @@
 use std::env;

-use dsh_sdk::mqtt_token_fetcher::{Actions, Claims, MqttToken, MqttTokenFetcher, Resource};
+use dsh_sdk::protocol_adapters::token_fetcher::*;

 #[tokio::main]
 async fn main() {
+    // Get the config and secret from the environment
     let tenant_name = env::var("TENANT").unwrap().to_string();
     let api_key = env::var("API_KEY").unwrap().to_string();
     let stream = env::var("STREAM").unwrap().to_string();
+
     let topic = "#".to_string(); // check MQTT documentation for better understanding of wildcards
     let resource = Resource::new(stream, "/tt".to_string(), topic, Some("topic".to_string()));

-    let claims_sub = Claims::new(resource.clone(), Actions::Subscribe.to_string());
+    let claims_sub = Claims::new(resource.clone(), Actions::Subscribe);

-    let claims_pub = Claims::new(resource, Actions::Publish.to_string());
+    let claims_pub = Claims::new(resource, Actions::Publish);

-    let claims_vector = vec![claims_sub, claims_pub];
+    let claims = vec![claims_sub, claims_pub];

-    let mqtt_token_fetcher = MqttTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);
+    let mqtt_token_fetcher =
+        ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);

     let token: MqttToken = mqtt_token_fetcher
-        .get_token("Client-id", Some(claims_vector))
+        .get_token("Client-id", Some(claims))
         .await
         .unwrap();
     println!("MQTT Token: {:?}", token);
diff --git a/dsh_sdk/examples/schema_store_api.rs b/dsh_sdk/examples/schema_store_api.rs
new file mode 100644
index 0000000..445b15b
--- /dev/null
+++ b/dsh_sdk/examples/schema_store_api.rs
@@ -0,0 +1,50 @@
+use dsh_sdk::schema_store::types::*;
+use dsh_sdk::schema_store::SchemaStoreClient;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Create a new SchemaStoreClient, connects to the Schema Registry based on datastreams.json
+    // However, you can override it by setting the environment variable SCHEMA_REGISTRY_HOST or SchemaStoreClient::new_with_base_url("http://localhost:8081")
+    let client = SchemaStoreClient::new();
+
+    // Register a new schema (and subject if not exists)
+    let schema = r#"
+        {
+            "type": "record",
+            "name": "Test",
+            "fields": [
+                {"name": "name", "type": "string"},
+                {"name": "age", "type": "int"}
+            ]
+        }
+        "#;
+    let subject_name =
+        SubjectName::new_topic_name_strategy("scratch.topic-name.tenant-name", false); // "scratch.topic-name.tenant-name-value"
+    let schema: RawSchemaWithType = schema.try_into()?;
+    let schema_id = client.subject_add_schema(&subject_name, schema).await?;
+    println!("Registered schema with id: {}\n", schema_id);
+
+    // Get schema by id
+    let raw_schema = client.schema(schema_id).await?;
+    println!("Schema by id {}: {:#?}\n", schema_id, raw_schema);
+
+    // List all subjects
+    let schemas = client.subjects().await?;
+    println!("List all registered subjects: {:#?}\n", schemas);
+
+    let subject_name: SubjectName = "scratch.topic-name.tenant-name-value".try_into()?;
+    // List all schemas for a subject
+    let schemas_for_subject = client.subject_all_schemas(&subject_name).await?;
+    println!("List all schemas for subject: {:#?}\n", schemas_for_subject);
+
+    // Get the latest schema for a subject
+    let latest_schema = client
+        .subject_raw_schema(
+            &"scratch.topic-name.tenant-name-value".try_into()?,
+            SubjectVersion::Latest,
+        )
+        .await?;
+    println!("Latest schema for subject: {:#?}\n", latest_schema);
+
+    Ok(())
+}
diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs
index 4caea9e..3976cc7 100644
--- a/dsh_sdk/src/certificates/mod.rs
+++ b/dsh_sdk/src/certificates/mod.rs
@@ -364,28 +364,6 @@ mod tests {
         assert!(identity.is_ok());
     }

-    #[test]
-    fn test_prepare_reqwest_client() {
-        let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
-        let result = Cert::prepare_reqwest_client(
-            cert.dsh_kafka_certificate_pem(),
-            &cert.private_key_pem(),
-            cert.dsh_ca_certificate_pem(),
-        );
-    }
-
-    #[test]
-    fn test_reqwest_client_config() {
-        let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
-        let client = cert.reqwest_client_config();
-    }
-
-    #[test]
-    fn test_reqwest_blocking_client_config() {
-        let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
-        let client = cert.reqwest_blocking_client_config();
-    }
-
     #[test]
     fn test_ensure_https_prefix() {
         let host = "http://example.com".to_string();
diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs
index 8e7e244..182d24d 100644
--- a/dsh_sdk/src/dsh.rs
+++ b/dsh_sdk/src/dsh.rs
@@ -403,7 +403,7 @@ impl Dsh {
     pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
         use crate::protocol_adapters::kafka_protocol::DshKafkaConfig;
         let mut config = rdkafka::config::ClientConfig::new();
-        config.dsh_consumer_config();
+        config.set_dsh_consumer_config();
         config
     }
@@ -415,7 +415,7 @@ impl Dsh {
     pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
         use crate::protocol_adapters::kafka_protocol::DshKafkaConfig;
         let mut config = rdkafka::config::ClientConfig::new();
-        config.dsh_producer_config();
+        config.set_dsh_producer_config();
         config
     }
 }
diff --git a/dsh_sdk/src/dsh_old/certificates.rs b/dsh_sdk/src/dsh_old/certificates.rs
index 04daaf8..f11e6f3 100644
--- a/dsh_sdk/src/dsh_old/certificates.rs
+++ b/dsh_sdk/src/dsh_old/certificates.rs
@@ -151,15 +151,6 @@ impl Cert {
     }
 }

-/// Helper function to ensure that the host starts with `https://` (or `http://`)
-fn ensure_https_prefix(host: String) -> String {
-    if host.starts_with("https://") || host.starts_with("http://") {
-        host
-    } else {
-        format!("https://{}", host)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -286,19 +277,4 @@ mod tests {
         let client = cert.reqwest_blocking_client_config();
         assert!(client.is_ok());
     }
-
-    #[test]
-    fn test_ensure_https_prefix() {
-        let host = "http://example.com".to_string();
-        let result = ensure_https_prefix(host);
-        assert_eq!(result, "http://example.com");
-
-        let host = "https://example.com".to_string();
-        let result = ensure_https_prefix(host);
-        assert_eq!(result, "https://example.com");
-
-        let host = "example.com".to_string();
-        let result = ensure_https_prefix(host);
-        assert_eq!(result, "https://example.com");
-    }
 }
diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs
index af93b43..8946112 100644
--- a/dsh_sdk/src/error.rs
+++ b/dsh_sdk/src/error.rs
@@ -24,7 +24,7 @@ pub enum DshError {
     #[cfg(feature = "bootstrap")]
     #[error("Invalid PEM certificate: {0}")]
     PemError(#[from] pem::PemError),
-    #[cfg(any(feature = "certificate", feature = "protocol-token-fetcher"))]
+    #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))]
     #[error("Reqwest: {0}")]
     ReqwestError(#[from] reqwest::Error),
     #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))]
@@ -54,20 +54,3 @@ pub enum DshError {
     #[error("Hyper error: {0}")]
     HyperError(#[from] hyper::http::Error),
 }
-
-#[cfg(feature = "management-api-token-fetcher")]
-#[derive(Error, Debug)]
-#[non_exhaustive]
-pub enum DshRestTokenError {
-    #[error("Client ID is unknown")]
-    UnknownClientId,
-    #[error("Client secret not set")]
-    UnknownClientSecret,
-    #[error("Unexpected failure while fetching token from server: {0}")]
-    FailureTokenFetch(reqwest::Error),
-    #[error("Unexpected status code: {status_code}, error body: {error_body:#?}")]
-    StatusCode {
-        status_code: reqwest::StatusCode,
-        error_body: reqwest::Response,
-    },
-}
diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs
index 79e3922..2f5cbeb 100644
--- a/dsh_sdk/src/lib.rs
+++ b/dsh_sdk/src/lib.rs
@@ -18,7 +18,7 @@
 //!
 //! # #[tokio::main]
 //! # async fn main() -> Result<(), Box<dyn std::error::Error>>{
-//! let consumer: StreamConsumer = ClientConfig::new().dsh_consumer_config().create()?;
+//! let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;
 //! # Ok(())
 //! # }
 //! ```
@@ -75,7 +75,7 @@

 #![allow(deprecated)] // to be kept in v0.6.0

-#[cfg(feature = "certificate")]
+#[cfg(feature = "bootstrap")]
 pub mod certificates;
 #[cfg(feature = "bootstrap")]
 pub mod datastream;
diff --git a/dsh_sdk/src/management_api/error.rs b/dsh_sdk/src/management_api/error.rs
new file mode 100644
index 0000000..62e90bd
--- /dev/null
+++ b/dsh_sdk/src/management_api/error.rs
@@ -0,0 +1,17 @@
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+#[non_exhaustive]
+pub enum ManagementTokenError {
+    #[error("Client ID is unknown")]
+    UnknownClientId,
+    #[error("Client secret not set")]
+    UnknownClientSecret,
+    #[error("Unexpected failure while fetching token from server: {0}")]
+    FailureTokenFetch(reqwest::Error),
+    #[error("Unexpected status code: {status_code}, error body: {error_body}")]
+    StatusCode {
+        status_code: reqwest::StatusCode,
+        error_body: String,
+    },
+}
diff --git a/dsh_sdk/src/management_api/mod.rs b/dsh_sdk/src/management_api/mod.rs
index a3e273a..8f11a52 100644
--- a/dsh_sdk/src/management_api/mod.rs
+++ b/dsh_sdk/src/management_api/mod.rs
@@ -1 +1,2 @@
+pub mod error;
 pub mod token_fetcher;
diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs
index eb43d9a..44c680d 100644
--- a/dsh_sdk/src/management_api/token_fetcher.rs
+++ b/dsh_sdk/src/management_api/token_fetcher.rs
@@ -40,7 +40,7 @@ use std::time::{Duration, Instant};
 use log::debug;
 use serde::Deserialize;

-use crate::error::DshRestTokenError;
+use super::error::ManagementTokenError;
 use crate::utils::Platform;

 /// Access token of the authentication service of DSH.
@@ -186,9 +186,9 @@ impl ManagementApiTokenFetcher {
     ///
     /// If the cached token is not valid, it will fetch a new token from the server.
     /// It will return the token as a string, formatted as "{token_type} {token}"
-    /// If the request fails for a new token, it will return a [DshRestTokenError::FailureTokenFetch] error.
+    /// If the request fails for a new token, it will return a [ManagementTokenError::FailureTokenFetch] error.
     /// This will contain the underlying reqwest error.
-    pub async fn get_token(&self) -> Result<String, DshRestTokenError> {
+    pub async fn get_token(&self) -> Result<String, ManagementTokenError> {
         match self.is_valid() {
             true => Ok(self.access_token.lock().unwrap().formatted_token()),
             false => {
@@ -224,10 +224,12 @@ impl ManagementApiTokenFetcher {
     /// Fetch a new access token from the server
     ///
     /// This will fetch a new access token from the server and return it.
-    /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error.
-    /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error.
+    /// If the request fails, it will return a [ManagementTokenError::FailureTokenFetch] error.
+    /// If the status code is not successful, it will return a [ManagementTokenError::StatusCode] error.
     /// If the request is successful, it will return the [AccessToken].
-    pub async fn fetch_access_token_from_server(&self) -> Result<AccessToken, DshRestTokenError> {
+    pub async fn fetch_access_token_from_server(
+        &self,
+    ) -> Result<AccessToken, ManagementTokenError> {
         let response = self
             .client
             .post(&self.auth_url)
@@ -238,17 +240,17 @@
             ])
             .send()
             .await
-            .map_err(DshRestTokenError::FailureTokenFetch)?;
+            .map_err(ManagementTokenError::FailureTokenFetch)?;
         if !response.status().is_success() {
-            Err(DshRestTokenError::StatusCode {
+            Err(ManagementTokenError::StatusCode {
                 status_code: response.status(),
-                error_body: response,
+                error_body: response.text().await.unwrap_or_default(),
             })
         } else {
             response
                 .json::<AccessToken>()
                 .await
-                .map_err(DshRestTokenError::FailureTokenFetch)
+                .map_err(ManagementTokenError::FailureTokenFetch)
         }
     }
 }
@@ -338,10 +340,10 @@ impl ManagementApiTokenFetcherBuilder {
     ///     .build()
     ///     .unwrap();
     /// ```
-    pub fn build(self) -> Result<ManagementApiTokenFetcher, DshRestTokenError> {
+    pub fn build(self) -> Result<ManagementApiTokenFetcher, ManagementTokenError> {
         let client_secret = self
             .client_secret
-            .ok_or(DshRestTokenError::UnknownClientSecret)?;
+            .ok_or(ManagementTokenError::UnknownClientSecret)?;
         let client_id = self
             .client_id
             .or_else(|| {
@@ -349,7 +351,7 @@ impl ManagementApiTokenFetcherBuilder {
                 .as_ref()
                 .map(|tenant_name| self.platform.rest_client_id(tenant_name))
             })
-            .ok_or(DshRestTokenError::UnknownClientId)?;
+            .ok_or(ManagementTokenError::UnknownClientId)?;
         let client = self.client.unwrap_or_default();
         let token_fetcher = ManagementApiTokenFetcher::new_with_client(
             client_id,
@@ -502,12 +504,12 @@ mod test {
         tf.auth_url = auth_server.url();
         let err = tf.fetch_access_token_from_server().await.unwrap_err();
         match err {
-            DshRestTokenError::StatusCode {
+            ManagementTokenError::StatusCode {
                 status_code,
                 error_body,
             } => {
                 assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST);
-                assert_eq!(error_body.text().await.unwrap(), "Bad request");
+                assert_eq!(error_body, "Bad request");
             }
             _ => panic!("Unexpected error: {:?}", err),
         }
@@ -586,12 +588,12 @@
             .client_secret("client_secret".to_string())
             .build()
             .unwrap_err();
-        assert!(matches!(err, DshRestTokenError::UnknownClientId));
+        assert!(matches!(err, ManagementTokenError::UnknownClientId));

         let err = ManagementApiTokenFetcherBuilder::new(Platform::NpLz)
             .tenant_name("tenant_name".to_string())
             .build()
             .unwrap_err();
-        assert!(matches!(err, DshRestTokenError::UnknownClientSecret));
+        assert!(matches!(err, ManagementTokenError::UnknownClientSecret));
     }
 }
diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs
index 20f4b4c..9ea71d8 100644
--- a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs
+++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs
@@ -17,7 +17,7 @@ pub trait DshKafkaConfig {
     /// | `ssl.key.pem` | private key | Generated when sdk is initiated |
     /// | `ssl.certificate.pem` | dsh kafka certificate | Signed certificate to connect to kafka cluster |
     /// | `ssl.ca.pem` | CA certificate | CA certificate, provided by DSH. |
-    fn dsh_consumer_config(&mut self) -> &mut Self;
+    fn set_dsh_consumer_config(&mut self) -> &mut Self;
     /// Set all required configurations to produce messages to DSH Kafka Cluster.
     ///
     /// ## Configurations
@@ -29,7 +29,7 @@
     /// | ssl.key.pem | private key | Generated when bootstrap is initiated |
     /// | ssl.certificate.pem | dsh kafka certificate | Signed certificate to connect to kafka cluster (signed when bootstrap is initiated) |
     /// | ssl.ca.pem | CA certificate | CA certificate, provided by DSH. |
-    fn dsh_producer_config(&mut self) -> &mut Self;
+    fn set_dsh_producer_config(&mut self) -> &mut Self;
     /// Set a DSH compatible group id.
     ///
     /// DSH Requires a group id with the prefix of the tenant name.
diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs
index 9457dc9..74c06e1 100644
--- a/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs
+++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs
@@ -1,10 +1,11 @@
+#[cfg(feature = "rdkafka-config")]
 use rdkafka::ClientConfig;

 use super::DshKafkaConfig;
 use crate::Dsh;

 impl DshKafkaConfig for ClientConfig {
-    fn dsh_consumer_config(&mut self) -> &mut Self {
+    fn set_dsh_consumer_config(&mut self) -> &mut Self {
         let dsh = Dsh::get();
         let client_id = dsh.client_id();
         let config = dsh.kafka_config();
@@ -33,7 +34,7 @@ impl DshKafkaConfig for ClientConfig {
         self
     }

-    fn dsh_producer_config(&mut self) -> &mut Self {
+    fn set_dsh_producer_config(&mut self) -> &mut Self {
         let dsh = Dsh::get();
         let client_id = dsh.client_id();
         let config = dsh.kafka_config();
diff --git a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
index a029cf0..245a61d 100644
--- a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
+++ b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
@@ -199,8 +199,11 @@ pub struct Claims {
 }

 impl Claims {
-    pub fn new(resource: Resource, action: String) -> Claims {
-        Claims { resource, action }
+    pub fn new(resource: Resource, action: Actions) -> Claims {
+        Claims {
+            resource,
+            action: action.to_string(),
+        }
     }
 }

@@ -213,8 +216,8 @@ pub enum Actions {
 impl Display for Actions {
     fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
         match self {
-            Actions::Publish => write!(f, "Publish"),
-            Actions::Subscribe => write!(f, "Subscribe"),
+            Actions::Publish => write!(f, "publish"),
+            Actions::Subscribe => write!(f, "subscribe"),
         }
     }
 }
@@ -615,9 +618,9 @@ mod tests {
     #[test]
     fn test_actions_display() {
         let action = Actions::Publish;
-        assert_eq!(action.to_string(), "Publish");
+        assert_eq!(action.to_string(), "publish");
         let action = Actions::Subscribe;
-        assert_eq!(action.to_string(), "Subscribe");
+        assert_eq!(action.to_string(), "subscribe");
     }

     #[test]
@@ -700,9 +703,9 @@
             "topic".to_string(),
             None,
         );
-        let action = "publish".to_string();
+        let action = Actions::Publish;

-        let claims = Claims::new(resource.clone(), action.clone());
+        let claims = Claims::new(resource.clone(), action);

         assert_eq!(claims.resource.stream, "stream");
         assert_eq!(claims.action, "publish");
diff --git a/dsh_sdk/src/schema_store/client.rs b/dsh_sdk/src/schema_store/client.rs
index 91fc395..622f57c 100644
--- a/dsh_sdk/src/schema_store/client.rs
+++ b/dsh_sdk/src/schema_store/client.rs
@@ -1,7 +1,7 @@
 use super::api::SchemaStoreApi;
 use super::request::Request;
 use super::types::*;
-use super::{Result, SchemaStoreError};
+use super::Result;
 use crate::Dsh;

 /// High level Schema Store Client
@@ -35,21 +35,24 @@ where
     /// ## Returns
     /// Returns a Result of the compatibility level of given subject
     ///
+    /// ## Arguments
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    ///
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::SubjectName;
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
-    /// println!("Config: {:?}", client.subject_compatibility("scratch.example-topic.tenant-value").await);
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// println!("Config: {:?}", client.subject_compatibility(&subject_name).await);
+    /// # Ok(())
     /// # }
     ///
-    pub async fn subject_compatibility<Sn>(&self, subject: Sn) -> Result<Compatibility>
-    where
-        Sn: Into<SubjectName>,
-    {
-        Ok(self.get_config_subject(subject.into().name()).await?.into())
+    pub async fn subject_compatibility(&self, subject: &SubjectName) -> Result<Compatibility> {
+        Ok(self.get_config_subject(subject.name()).await?.into())
     }

     /// Set the compatibility level for a subject
     ///
     /// Set compatibility on subject level. With 1 schema stored in the subject, you can change it to any compatibility level.
     /// Else, you can only change into a less restrictive level.
     ///
+    /// ## Arguments
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    ///
     /// ## Returns
     /// Returns a Result of the new compatibility level
     ///
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::Compatibility;
+    /// use dsh_sdk::schema_store::types::{Compatibility, SubjectName};
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
-    /// client.subject_compatibility_update("scratch.example-topic.tenant-value", Compatibility::FULL).await.unwrap();
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// client.subject_compatibility_update(&subject_name, Compatibility::FULL).await?;
+    /// # Ok(())
     /// # }
     /// ```
-    ///
-    /// TODO: untested as this API method does not seem to work at all on DSH
-    pub async fn subject_compatibility_update<Sn>(
+    pub async fn subject_compatibility_update(
         &self,
-        subject: Sn,
+        subject: &SubjectName,
         compatibility: Compatibility,
-    ) -> Result<Compatibility>
-    where
-        Sn: Into<SubjectName>,
-    {
+    ) -> Result<Compatibility> {
         Ok(self
-            .put_config_subject(subject.into().name(), compatibility)
+            .put_config_subject(subject.name(), compatibility)
             .await?
             .into())
     }
@@ -114,19 +117,18 @@ where
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::SubjectName;
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
-    /// println!("Available versions: {:?}", client.subject_versions("scratch.example-topic.tenant-value").await);
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// println!("Available versions: {:?}", client.subject_versions(&subject_name).await);
+    /// # Ok(())
     /// # }
     /// ```
-    pub async fn subject_versions<Sn>(&self, subject: Sn) -> Result<Vec<i32>>
-    where
-        Sn: Into<SubjectName>,
-    {
-        self.get_subjects_subject_versions(subject.into().name())
-            .await
+    pub async fn subject_versions(&self, subject: &SubjectName) -> Result<Vec<i32>> {
+        self.get_subjects_subject_versions(subject.name()).await
     }

     /// Get subject for specific version
@@ -140,25 +142,25 @@
     /// use dsh_sdk::schema_store::types::{SubjectName, SubjectVersion};
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
     ///
     /// // Get the latest version of the schema
-    /// let subject = client.subject(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}, SubjectVersion::Latest).await.unwrap();
+    /// let subject = client.subject(&subject_name, SubjectVersion::Latest).await?;
     /// let raw_schema = subject.schema;
     ///
     /// // Get a specific version of the schema
-    /// let subject = client.subject("scratch.example-topic.tenant-value", SubjectVersion::Version(1)).await.unwrap();
+    /// let subject = client.subject(&subject_name, SubjectVersion::Version(1)).await?;
     /// let raw_schema = subject.schema;
+    /// # Ok(())
     /// # }
     /// ```
-    pub async fn subject<Sn, V>(&self, subject: Sn, version: V) -> Result<Subject>
+    pub async fn subject<V>(&self, subject: &SubjectName, version: V) -> Result<Subject>
     where
-        //Sn: TryInto<SubjectName>,
-        Sn: Into<SubjectName>,
         V: Into<SubjectVersion>,
     {
-        let subject = subject.into().name();
+        let subject = subject.name();
         let version = version.into();
         self.get_subjects_subject_versions_id(subject, version.to_string())
             .await
@@ -167,8 +169,8 @@ where
     /// Get the raw schema string for the specified version of subject.
     ///
     /// ## Arguments
-    /// - `subject`: Anything that can be converted into a [SubjectName] (Returns error if invalid SubjectStrategy)
-    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
     ///
     /// ## Returns
     /// Returns a Result of the raw schema string for the given subject and version
     ///
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::SubjectName;
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
-    /// let raw_schema = client.subject_raw_schema("scratch.example-topic.tenant-value", 1).await.unwrap();
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// let raw_schema = client.subject_raw_schema(&subject_name, 1).await?;
+    /// # Ok(())
     /// # }
     /// ```
-    pub async fn subject_raw_schema<Sn, V>(&self, subject: Sn, version: V) -> Result<String>
+    pub async fn subject_raw_schema<V>(&self, subject: &SubjectName, version: V) -> Result<String>
     where
-        Sn: Into<SubjectName>,
         V: Into<SubjectVersion>,
     {
-        self.get_subjects_subject_versions_id_schema(
-            subject.into().name(),
-            version.into().to_string(),
-        )
-        .await
+        self.get_subjects_subject_versions_id_schema(subject.name(), version.into().to_string())
+            .await
     }

     /// Get all schemas for a subject
     ///
     /// ## Arguments
-    /// - `subject`: Anything that can be converted into a [SubjectName] (Returns error if invalid SubjectStrategy)
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
     ///
     /// ## Returns
     /// Returns a Result of all schemas for the given subject
     ///
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
     /// use dsh_sdk::schema_store::types::SubjectName;
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
-    /// let subjects = client.subject_all_schemas(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}).await.unwrap();
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// let subjects = client.subject_all_schemas(&subject_name).await?;
+    /// # Ok(())
     /// # }
-    pub async fn subject_all_schemas<Sn>(&self, subject: Sn) -> Result<Vec<Subject>>
-    where
-        Sn: Into<SubjectName> + Clone,
-    {
-        let versions = self.subject_versions(subject.clone()).await?;
+    pub async fn subject_all_schemas(&self, subject: &SubjectName) -> Result<Vec<Subject>> {
+        let versions = self.subject_versions(subject).await?;
         let mut subjects = Vec::new();
         for version in versions {
-            let subject = self.subject(subject.clone(), version).await?;
+            let subject = self.subject(subject, version).await?;
             subjects.push(subject);
         }
         Ok(subjects)
     }

+    /// Get all schemas for a topic
+    ///
+    /// ## Arguments
+    /// - `topic`: &str/String of the topic name
+    ///
+    /// ## Returns
+    ///
+    // pub async fn topic_all_schemas<S>(&self, topic: S) -> Result<(Vec<Subject>, Vec<Subject>)>
+    // where
+    //     S: AsRef<str>,
+    // {
+    //     let key_schemas = self.subject_all_schemas((topic.as_ref(), true)).await?;
+    //     let value_schemas = self.subject_all_schemas((topic.as_ref(), false)).await?;
+    //     Ok(subjects)
+    // }
+
     /// Post a new schema for a (new) subject
     ///
     /// ## Errors
@@ -236,8 +252,8 @@ where
     /// - schema is invalid
     ///
     /// ## Arguments
-    /// - `subject`: Anything that can be converted into a [SubjectName] (Returns error if invalid SubjectStrategy)
-    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
     ///
     /// ## Returns
     /// Returns a Result of the new schema ID.
@@ -245,29 +261,34 @@
     ///
     /// ## Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SchemaType};
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// use dsh_sdk::schema_store::types::{RawSchemaWithType, SubjectName};
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
     ///
+    /// // Get the subject name (note that it ends with "-value")
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    ///
     /// // You can provide the schema as a raw string (Schema type is optional, it will be detected automatically)
     /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
-    /// let schema_version = client.subject_add_schema("scratch.example-topic.tenant-value", (raw_schema, SchemaType::AVRO)).await.unwrap();
+    /// let schema_with_type: RawSchemaWithType = raw_schema.try_into()?;
+    /// let schema_version = client.subject_add_schema(&subject_name, schema_with_type).await?;
     ///
     /// // Or if you have a schema object
-    /// let avro_schema = apache_avro::Schema::parse_str(raw_schema).unwrap(); // or ProtoBuf or JSON schema
-    /// let schema_version = client.subject_add_schema("scratch.example-topic.tenant-value", avro_schema).await.unwrap();
+    /// let avro_schema: RawSchemaWithType = apache_avro::Schema::parse_str(raw_schema)?.try_into()?; // or ProtoBuf or JSON schema
+    /// let schema_version = client.subject_add_schema(&subject_name, avro_schema).await?;
+    /// # Ok(())
     /// # }
     /// ```
-    pub async fn subject_add_schema<Sn, Sc>(&self, subject: Sn, schema: Sc) -> Result<i32>
-    where
-        Sn: Into<SubjectName>,
-        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
-    {
-        let schema = schema.try_into()?;
+    pub async fn subject_add_schema(
+        &self,
+        subject: &SubjectName,
+        schema: RawSchemaWithType,
+    ) -> Result<i32> {
         Ok(self
-            .post_subjects_subject_versions(subject.into().name(), schema)
+            .post_subjects_subject_versions(subject.name(), schema)
             .await?
             .id())
     }
@@ -283,8 +304,8 @@ where
     /// - schema is invalid
     ///
     /// ## Arguments
-    /// - `subject`: Anything that can be converted into a [SubjectName] (Returns error if invalid SubjectStrategy)
-    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
     ///
     /// ## Returns
     /// If schema exists, it will return with the existing version and schema ID.
@@ -292,25 +313,25 @@
     /// ## Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::{SubjectName, SchemaType};
+    /// use dsh_sdk::schema_store::types::{SubjectName, SchemaType, RawSchemaWithType};
     ///
     /// # #[tokio::main]
-    /// # async fn main() {
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
     /// let client = SchemaStoreClient::new();
     ///
     /// // You can provide the schema as a raw string (Schema type is optional, it will be detected automatically)
-    /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
-    /// let subject = client.subject_schema_exist("scratch.example-topic.tenant-value", (raw_schema, SchemaType::AVRO)).await.unwrap();
+    /// let raw_schema: RawSchemaWithType = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#.try_into()?;
+    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// let subject = client.subject_schema_exist(&subject_name, raw_schema).await?;
+    /// # Ok(())
     /// # }
     /// ```
-    pub async fn subject_schema_exist<Sn, Sc>(&self, subject: Sn, schema: Sc) -> Result<Subject>
-    where
-        Sn: Into<SubjectName>,
-        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
-    {
-        let schema = schema.try_into()?;
-        self.post_subjects_subject(subject.into().name(), schema)
-            .await
+    pub async fn subject_schema_exist(
+        &self,
+        subject: &SubjectName,
+        schema: RawSchemaWithType,
+    ) -> Result<Subject> {
+        self.post_subjects_subject(subject.name(), schema).await
     }

     /// Check if schema is compatible with a specific version of a subject based on the compatibility level
     ///
     /// If this subject’s compatibility level was never changed, then the global compatibility level applies.
     ///
     /// ## Arguments
-    /// - `subject`: Anything that can be converted into a [SubjectName] (Returns error if invalid SubjectStrategy)
+    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
     /// - `version`: Anything that can be converted into a [SubjectVersion]
-    /// - `schema`: Anything that can be converted into a [RawSchemaWithType]
+    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
     ///
     /// ## Returns
     /// Returns a Result of a boolean if the schema is compatible with the given version of the subject
-    pub async fn subject_new_schema_compatibility<Sn, Sv, Sc>(
+    pub async fn subject_new_schema_compatibility<Sv>(
         &self,
-        subject: Sn,
+        subject: &SubjectName,
         version: Sv,
-        schema: Sc,
+        schema: RawSchemaWithType,
     ) -> Result<bool>
     where
-        Sn: Into<SubjectName>,
         Sv: Into<SubjectVersion>,
-        Sc: TryInto<RawSchemaWithType, Error = SchemaStoreError>,
     {
-        let schema = schema.try_into()?;
         Ok(self
             .post_compatibility_subjects_subject_versions_id(
-                subject.into().name(),
+                subject.name(),
                 version.into().to_string(),
                 schema,
             )
diff --git a/dsh_sdk/src/schema_store/error.rs b/dsh_sdk/src/schema_store/error.rs
index 231f72c..497813b 100644
--- a/dsh_sdk/src/schema_store/error.rs
+++ b/dsh_sdk/src/schema_store/error.rs
@@ -16,13 +16,6 @@ pub enum SchemaStoreError {
         url: String,
         error: String,
     },
-    #[error("Empty payload")]
-    EmptyPayload,
-    #[error("Failed to decode payload: {0}")]
-    FailedToDecode(String),
-    #[error("Failed to parse value onto struct")]
-    FailedParseToStruct,
-
-    #[error("Protobuf to struct not (yet) implemented")]
-    NotImplementedProtobufDeserialize,
+    #[error("Invalid subject name: {0}")]
+    InvalidSubjectName(String),
 }
diff --git a/dsh_sdk/src/schema_store/mod.rs
b/dsh_sdk/src/schema_store/mod.rs index fbd4659..77201a6 100644 --- a/dsh_sdk/src/schema_store/mod.rs +++ b/dsh_sdk/src/schema_store/mod.rs @@ -19,7 +19,8 @@ //! let subjects = client.subjects().await.unwrap(); //! //! // Get the latest version of a subjects value schema -//! let subject = client.subject(SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}, SubjectVersion::Latest).await.unwrap(); +//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); +//! let subject = client.subject(&subject_name, SubjectVersion::Latest).await.unwrap(); //! let raw_schema = subject.schema; //! # } //! ``` @@ -41,7 +42,7 @@ //! assert_eq!(from_tuple, from_struct); //! ``` //! -//! This means you can easily provide the input arguments from other types without converting it yourself. +//! This means you can easily convert into [types::SubjectName] and [types::RawSchemaWithType]. //! For example: //! ```no_run //! use dsh_sdk::schema_store::SchemaStoreClient; @@ -51,11 +52,11 @@ //! # async fn main() { //! let client = SchemaStoreClient::new(); //! -//! let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#; -//! client.subject_add_schema("scratch.example-topic.tenant-value", raw_schema).await.unwrap(); // Returns error if schema is not valid +//! let raw_schema: RawSchemaWithType = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#.try_into().unwrap(); +//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); +//! client.subject_add_schema(&subject_name, raw_schema).await.unwrap(); // Returns error if schema is not valid //! # } //! ``` - mod api; mod client; mod error; diff --git a/dsh_sdk/src/schema_store/types/subject_strategy.rs b/dsh_sdk/src/schema_store/types/subject_strategy.rs index 310a0c9..88d82f4 100644 --- a/dsh_sdk/src/schema_store/types/subject_strategy.rs +++ b/dsh_sdk/src/schema_store/types/subject_strategy.rs @@ -1,5 +1,7 @@ use std::hash::{Hash, Hasher}; +use crate::schema_store::SchemaStoreError; + /// Subject name strategy /// /// Defines the strategy to use for the subject name @@ -27,7 +29,7 @@ pub enum SubjectName { } impl SubjectName { - pub fn new(topic: S, key: bool) -> Self + pub fn new_topic_name_strategy(topic: S, key: bool) -> Self where S: AsRef, { @@ -61,25 +63,29 @@ impl SubjectName { } } -impl From<&str> for SubjectName { - fn from(value: &str) -> Self { - let (topic, key) = if value.ends_with("-key") { - (value.trim_end_matches("-key"), true) +impl TryFrom<&str> for SubjectName { + type Error = SchemaStoreError; + fn try_from(value: &str) -> Result { + if value.ends_with("-key") { + Ok(Self::TopicNameStrategy { + topic: value.trim_end_matches("-key").to_string(), + key: true, + }) } else if value.ends_with("-value") { - (value.trim_end_matches("-value"), false) + Ok(Self::TopicNameStrategy { + topic: value.trim_end_matches("-value").to_string(), + key: false, + }) } else { - (value, false) - }; - Self::TopicNameStrategy { - topic: topic.to_string(), - key, + Err(SchemaStoreError::InvalidSubjectName(value.to_string())) } } } -impl From for SubjectName { - fn from(value: String) -> Self { - value.as_str().into() +impl TryFrom for SubjectName { + type Error = SchemaStoreError; + fn try_from(value: String) -> Result { + value.as_str().try_into() } } @@ -151,7 +157,7 @@ mod tests { #[test] fn test_subject_name_new() { - let subject = 
SubjectName::new("scratch.example.tenant", false); + let subject = SubjectName::new_topic_name_strategy("scratch.example.tenant", false); assert_eq!( subject, SubjectName::TopicNameStrategy { @@ -163,7 +169,7 @@ mod tests { #[test] fn test_subject_name_from_string() { - let subject: SubjectName = "scratch.example.tenant-value".into(); + let subject: SubjectName = "scratch.example.tenant-value".try_into().unwrap(); assert_eq!( subject, SubjectName::TopicNameStrategy { @@ -172,7 +178,7 @@ mod tests { } ); - let subject: SubjectName = "scratch.example.tenant-key".into(); + let subject: SubjectName = "scratch.example.tenant-key".try_into().unwrap(); assert_eq!( subject, SubjectName::TopicNameStrategy { @@ -206,7 +212,7 @@ mod tests { #[test] fn test_subject_name_from_string_ref() { let string = "scratch.example.tenant-value".to_string(); - let subject: SubjectName = string.into(); + let subject: SubjectName = string.try_into().unwrap(); assert_eq!( subject, SubjectName::TopicNameStrategy { diff --git a/dsh_sdk/src/utils/dlq.rs b/dsh_sdk/src/utils/dlq.rs index 004ceaa..bb73430 100644 --- a/dsh_sdk/src/utils/dlq.rs +++ b/dsh_sdk/src/utils/dlq.rs @@ -195,7 +195,7 @@ impl Dlq { pub fn start(shutdown: Shutdown) -> Result> { let (dlq_tx, dlq_rx) = mpsc::channel(200); let dlq_producer: FutureProducer = - ClientConfig::new().dsh_producer_config().create()?; + ClientConfig::new().set_dsh_producer_config().create()?; let dlq_dead_topic = get_env_var("DLQ_DEAD_TOPIC")?; let dlq_retry_topic = get_env_var("DLQ_RETRY_TOPIC")?; let dlq = Self { diff --git a/example_dsh_service/src/main.rs b/example_dsh_service/src/main.rs index e5e7e8d..c89ae38 100644 --- a/example_dsh_service/src/main.rs +++ b/example_dsh_service/src/main.rs @@ -1,18 +1,21 @@ use dsh_sdk::utils::graceful_shutdown::Shutdown; -use dsh_sdk::Dsh; +use dsh_sdk::DshKafkaConfig; + use rdkafka::consumer::{CommitMode, Consumer, StreamConsumer}; use rdkafka::message::{BorrowedMessage, Message}; +use rdkafka::ClientConfig; use log::{error, info}; mod custom_metrics; +/// Deserialize and print the message fn deserialize_and_print(msg: &BorrowedMessage) { let payload = String::from_utf8_lossy(msg.payload().unwrap_or(b"")); let key = String::from_utf8_lossy(msg.key().unwrap_or(b"")); - info!( - "Received message from topic {} partition {} offset {} with key {:?} and payload {}", + println!( + "Received message from topic: {}, partition: {}, offset: {}, key: {}, and payload:\n{}", msg.topic(), msg.partition(), msg.offset(), @@ -21,6 +24,7 @@ fn deserialize_and_print(msg: &BorrowedMessage) { ); } +/// Simple consumer that consumes messages from Kafka and prints them async fn consume(consumer: StreamConsumer, shutdown: Shutdown) { loop { tokio::select! 
{ @@ -54,9 +58,6 @@ async fn main() -> Result<(), Box> { // Start http server for exposing prometheus metrics, note that in Dockerfile we expose port 8080 as well dsh_sdk::utils::metrics::start_http_server(8080); - // Create a new properties instance (connects to the DSH server and fetches the datastream) - let dsh_properties = Dsh::get(); - // Get the configured topics from env variable TOPICS (comma separated) let topics_string = std::env::var("TOPICS").expect("TOPICS env variable not set"); let topics = topics_string.split(',').collect::>(); @@ -64,11 +65,11 @@ async fn main() -> Result<(), Box> { // Initialize the shutdown handler (This will handle SIGTERM and SIGINT signals, and you can act on them) let shutdown = Shutdown::new(); - // Get the consumer config from the Properties instance - let mut consumer_client_config = dsh_properties.consumer_rdkafka_config(); + // Create RDKafka Client config + let mut consumer_client_config = ClientConfig::new(); - // Override some default values (optional) - consumer_client_config.set("auto.offset.reset", "earliest"); + // Load the Kafka configuration from the SDK (this method comes from the `DshKafkaConfig` trait) + consumer_client_config.set_dsh_consumer_config(); // Create a new consumer instance let consumer: StreamConsumer = consumer_client_config.create()?; @@ -93,7 +94,7 @@ async fn main() -> Result<(), Box> { } } - // Wait till the shutdown is complete + // Wait till the graceful shutdown is finished shutdown.complete().await; Ok(()) } From 2fd83010f0605457a9256ab39c1843a6c7651f8e Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Wed, 8 Jan 2025 15:31:47 +0100 Subject: [PATCH 06/23] restore original Properties to have proper deprecation warnings --- dsh_sdk/src/dsh_old/bootstrap.rs | 320 ++++++++++++++++++++++++ dsh_sdk/src/dsh_old/certificates.rs | 76 +++++- dsh_sdk/src/dsh_old/config.rs | 198 +++++++++++++++ dsh_sdk/src/dsh_old/datastream.rs | 187 ++++++++++++-- dsh_sdk/src/dsh_old/error.rs | 38 +++ dsh_sdk/src/dsh_old/mod.rs | 31 +-- dsh_sdk/src/dsh_old/pki_config_dir.rs | 242 ++++++++++++++++++ dsh_sdk/src/dsh_old/properties.rs | 337 +++++++++++++++++--------- dsh_sdk/src/lib.rs | 4 +- 9 files changed, 1253 insertions(+), 180 deletions(-) create mode 100644 dsh_sdk/src/dsh_old/bootstrap.rs create mode 100644 dsh_sdk/src/dsh_old/config.rs create mode 100644 dsh_sdk/src/dsh_old/error.rs create mode 100644 dsh_sdk/src/dsh_old/pki_config_dir.rs diff --git a/dsh_sdk/src/dsh_old/bootstrap.rs b/dsh_sdk/src/dsh_old/bootstrap.rs new file mode 100644 index 0000000..096d5e4 --- /dev/null +++ b/dsh_sdk/src/dsh_old/bootstrap.rs @@ -0,0 +1,320 @@ +//! Module for bootstrapping the DSH client. +//! +//! This module contains the logic to connect to DSH and retrieve the certificates and datastreams.json +//! to create the properties struct. It follows the certificate signing request pattern as normally +//! used in the get_signed_certificates_json.sh script. +//! +//! ## Note +//! This module is not intended to be used directly, but through the `Properties` struct. It will +//! always be used when getting a `Properties` struct via dsh::Properties::get(). 
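+//!
+//! ## Illustrative usage
+//! A minimal sketch (not part of the patch itself): the bootstrap runs implicitly the first
+//! time `Properties::get()` is called, assuming the DSH environment variables (config host,
+//! tenant name, task id and secret token) are present:
+//! ```no_run
+//! use dsh_sdk::Properties;
+//!
+//! // First call performs the bootstrap (certificates + datastreams.json)
+//! let dsh_properties = Properties::get();
+//! let brokers = dsh_properties.kafka_brokers();
+//! ```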
+use log::{debug, info}; +use reqwest::blocking::Client; + +use super::error::DshError; + +use super::certificates::Cert; +use crate::utils; +use crate::{VAR_DSH_CA_CERTIFICATE, VAR_DSH_SECRET_TOKEN, VAR_DSH_SECRET_TOKEN_PATH}; + +/// Connect to DSH and retrieve the certificates and datastreams.json to create the properties struct +pub(crate) fn bootstrap( + config_host: &str, + tenant_name: &str, + task_id: &str, +) -> Result { + let dsh_config = DshConfig::new(config_host, tenant_name, task_id)?; + let client = reqwest_ca_client(dsh_config.dsh_ca_certificate.as_bytes())?; + let dn = DshBootstapCall::Dn(&dsh_config).retryable_call(&client)?; + let dn = Dn::parse_string(&dn)?; + let certificates = Cert::get_signed_client_cert(dn, &dsh_config, &client)?; + info!("Successfully connected to DSH"); + Ok(certificates) +} + +/// Build a request client with the DSH CA certificate. +fn reqwest_ca_client(dsh_ca_certificate: &[u8]) -> Result { + let reqwest_cert = reqwest::tls::Certificate::from_pem(dsh_ca_certificate)?; + let client = Client::builder() + .add_root_certificate(reqwest_cert) + .build()?; + Ok(client) +} + +/// Helper struct to store the config needed for bootstrapping to DSH +#[derive(Debug)] +pub(crate) struct DshConfig<'a> { + config_host: &'a str, + tenant_name: &'a str, + task_id: &'a str, + dsh_secret_token: String, + dsh_ca_certificate: String, +} +impl<'a> DshConfig<'a> { + fn new(config_host: &'a str, tenant_name: &'a str, task_id: &'a str) -> Result { + let dsh_secret_token = match utils::get_env_var(VAR_DSH_SECRET_TOKEN) { + Ok(token) => token, + Err(_) => { + // if DSH_SECRET_TOKEN is not set, try to read it from a file (for system space applications) + debug!("trying to read DSH_SECRET_TOKEN from file"); + let secret_token_path = utils::get_env_var(VAR_DSH_SECRET_TOKEN_PATH) + .map_err(|_| DshError::EnvVarError(VAR_DSH_SECRET_TOKEN_PATH))?; + let path = std::path::PathBuf::from(secret_token_path); + std::fs::read_to_string(path)? + } + }; + let dsh_ca_certificate = utils::get_env_var(VAR_DSH_CA_CERTIFICATE) + .map_err(|_| DshError::EnvVarError(VAR_DSH_CA_CERTIFICATE))?; + Ok(DshConfig { + config_host, + task_id, + tenant_name, + dsh_secret_token, + dsh_ca_certificate, + }) + } + + pub(crate) fn dsh_ca_certificate(&self) -> &str { + &self.dsh_ca_certificate + } +} + +pub(crate) enum DshBootstapCall<'a> { + /// Call to retreive distinguished name. + Dn(&'a DshConfig<'a>), + /// Call to post the certificate signing request. + CertificateSignRequest { + config: &'a DshConfig<'a>, + csr: &'a str, + }, +} + +impl DshBootstapCall<'_> { + fn url_for_call(&self) -> String { + match self { + DshBootstapCall::Dn(config) => { + format!( + "{}/dn/{}/{}", + config.config_host, config.tenant_name, config.task_id + ) + } + DshBootstapCall::CertificateSignRequest { config, .. } => { + format!( + "{}/sign/{}/{}", + config.config_host, config.tenant_name, config.task_id + ) + } + } + } + + fn request_builder(&self, client: &Client) -> reqwest::blocking::RequestBuilder { + let url = self.url_for_call(); + match self { + DshBootstapCall::Dn(..) => client.get(url), + DshBootstapCall::CertificateSignRequest { config, csr, .. 
} => client + .post(url) + .header("X-Kafka-Config-Token", &config.dsh_secret_token) + .body(csr.to_string()), + } + } + + fn perform_call(&self, client: &Client) -> Result { + let response = self.request_builder(client).send()?; + if !response.status().is_success() { + return Err(DshError::DshCallError { + url: self.url_for_call(), + status_code: response.status(), + error_body: response.text().unwrap_or_default(), + }); + } + Ok(response.text()?) + } + + pub(crate) fn retryable_call(&self, client: &Client) -> Result { + let mut retries = 0; + loop { + match self.perform_call(client) { + Ok(response) => return Ok(response), + Err(err) => { + if retries >= 30 { + return Err(err); + } + retries += 1; + // sleep exponentially + let sleep: u64 = std::cmp::min(2u64.pow(retries), 60); + log::warn!( + "Retrying call to DSH in {sleep} seconds due to error: {}", + crate::error::report(&err) + ); + std::thread::sleep(std::time::Duration::from_secs(sleep)); + } + } + } + } +} + +/// Struct to parse DN string into separate fields. +/// Needed for Picky solution. +#[derive(Debug)] +pub(crate) struct Dn { + cn: String, + ou: String, + o: String, +} + +impl Dn { + /// Parse the DN string into Dn struct. + pub(crate) fn parse_string(dn_string: &str) -> Result { + let mut cn = None; + let mut ou = None; + let mut o = None; + + for segment in dn_string.split(',') { + let parts: Vec<&str> = segment.split('=').collect(); + if parts.len() == 2 { + match parts[0] { + "CN" => cn = Some(parts[1].to_string()), + "OU" => ou = Some(parts[1].to_string()), + "O" => o = Some(parts[1].to_string()), + _ => (), + } + } + } + + Ok(Dn { + cn: cn.ok_or(DshError::ParseDnError( + "CN is missing in DN string".to_string(), + ))?, + ou: ou.ok_or(DshError::ParseDnError( + "OU is missing in DN string".to_string(), + ))?, + o: o.ok_or(DshError::ParseDnError( + "O is missing in DN string".to_string(), + ))?, + }) + } + pub(crate) fn cn(&self) -> &str { + &self.cn + } + + pub(crate) fn ou(&self) -> &str { + &self.ou + } + + pub(crate) fn o(&self) -> &str { + &self.o + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::env; + use std::str::from_utf8; + + #[test] + fn test_dsh_call_request_builder() { + let dsh_config = DshConfig { + config_host: "https://test_host", + tenant_name: "test_tenant_name", + task_id: "test_task_id", + dsh_secret_token: "test_token".to_string(), + dsh_ca_certificate: "test_ca_certificate".to_string(), + }; + let builder: reqwest::blocking::RequestBuilder = + DshBootstapCall::Dn(&dsh_config).request_builder(&Client::new()); + let request = builder.build().unwrap(); + assert_eq!(request.method().as_str(), "GET"); + let csr = "-----BEGIN test_type-----\n-----END test_type-----"; + let builder: reqwest::blocking::RequestBuilder = DshBootstapCall::CertificateSignRequest { + config: &dsh_config, + csr, + } + .request_builder(&Client::new()); + let request = builder.build().unwrap(); + assert_eq!(request.method().as_str(), "POST"); + assert_eq!( + request + .headers() + .get("X-Kafka-Config-Token") + .unwrap() + .to_str() + .unwrap(), + "test_token" + ); + let body = from_utf8(request.body().unwrap().as_bytes().unwrap()).unwrap(); + assert_eq!(body, csr); + } + + #[test] + fn test_dsh_call_perform() { + // Create a mock for the expected HTTP request + let mut dsh = mockito::Server::new(); + let dn = "CN=test_cn,OU=test_ou,O=test_o"; + dsh.mock("GET", "/dn/tenant/test_task_id") + .with_status(200) + .with_header("content-type", "text/plain") + .with_body(dn) + .create(); + // 
simple reqwest client + let client = Client::new(); + // create a DshConfig struct + let dsh_config = DshConfig { + config_host: &dsh.url(), + tenant_name: "tenant", + task_id: "test_task_id", + dsh_secret_token: "test_token".to_string(), + dsh_ca_certificate: "test_ca_certificate".to_string(), + }; + // call the function + let response = DshBootstapCall::Dn(&dsh_config) + .perform_call(&client) + .unwrap(); + assert_eq!(response, dn); + } + + #[test] + fn test_dsh_parse_dn() { + let dn_string = "CN=test_cn,OU=test_ou,O=test_o"; + let dn = Dn::parse_string(dn_string).unwrap(); + assert_eq!(dn.cn, "test_cn"); + assert_eq!(dn.ou, "test_ou"); + assert_eq!(dn.o, "test_o"); + } + + #[test] + #[serial(env_dependency)] + fn test_dsh_config_new() { + // normal situation where DSH variables are set + env::set_var(VAR_DSH_SECRET_TOKEN, "test_token"); + env::set_var(VAR_DSH_CA_CERTIFICATE, "test_ca_certificate"); + let config_host = "https://test_host"; + let tenant_name = "test_tenant"; + let task_id = "test_task_id"; + let dsh_config = DshConfig::new(config_host, tenant_name, task_id).unwrap(); + assert_eq!(dsh_config.config_host, "https://test_host"); + assert_eq!(dsh_config.task_id, "test_task_id"); + assert_eq!(dsh_config.tenant_name, "test_tenant"); + assert_eq!(dsh_config.dsh_secret_token, "test_token"); + assert_eq!(dsh_config.dsh_ca_certificate, "test_ca_certificate"); + // DSH_SECRET_TOKEN is not set, but DSH_SECRET_TOKEN_PATH is set + env::remove_var(VAR_DSH_SECRET_TOKEN); + let test_token_dir = "test_files"; + std::fs::create_dir_all(test_token_dir).unwrap(); + let test_token_dir = format!("{}/test_token", test_token_dir); + let _ = std::fs::remove_file(&test_token_dir); + env::set_var(VAR_DSH_SECRET_TOKEN_PATH, &test_token_dir); + let result = DshConfig::new(config_host, tenant_name, task_id); + assert!(result.is_err()); + std::fs::write(test_token_dir.as_str(), "test_token_from_file").unwrap(); + let dsh_config = DshConfig::new(config_host, tenant_name, task_id).unwrap(); + assert_eq!(dsh_config.dsh_secret_token, "test_token_from_file"); + // fail if DSH_CA_CERTIFICATE is not set + env::remove_var(VAR_DSH_CA_CERTIFICATE); + let result = DshConfig::new(config_host, tenant_name, task_id); + assert!(result.is_err()); + env::remove_var(VAR_DSH_SECRET_TOKEN); + env::remove_var(VAR_DSH_CA_CERTIFICATE); + env::remove_var(VAR_DSH_SECRET_TOKEN_PATH); + } +} diff --git a/dsh_sdk/src/dsh_old/certificates.rs b/dsh_sdk/src/dsh_old/certificates.rs index 96d8562..800ab20 100644 --- a/dsh_sdk/src/dsh_old/certificates.rs +++ b/dsh_sdk/src/dsh_old/certificates.rs @@ -11,28 +11,32 @@ //! To create the ca.crt, client.pem, and client.key files in a desired directory, use the //! `to_files` method. //! ```no_run -//! use dsh_sdk::certificates::Cert; +//! use dsh_sdk::Properties; //! use std::path::PathBuf; //! //! # fn main() -> Result<(), Box> { -//! let certificates = Cert::from_env()?; -//! let directory = PathBuf::from("path/to/dir"); -//! certificates.to_files(&directory)?; +//! let dsh_properties = Properties::get(); +//! let directory = PathBuf::from("dir"); +//! dsh_properties.certificates()?.to_files(&directory)?; //! # Ok(()) //! # } //! ``` //! //! ## Reqwest Client //! With this request client we can retrieve datastreams.json and connect to Schema Registry. 
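+//!
+//! A hedged sketch of that flow (method names as used elsewhere in this module; error
+//! handling elided):
+//! ```no_run
+//! use dsh_sdk::Properties;
+//!
+//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
+//! let certificates = Properties::get().certificates()?;
+//! // Blocking client with the DSH identity and CA certificate included
+//! let client = certificates.reqwest_blocking_client_config()?.build()?;
+//! # Ok(())
+//! # }
+//! ```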
-use std::path::PathBuf; use std::sync::Arc; use log::info; -use rcgen::KeyPair; use reqwest::blocking::{Client, ClientBuilder}; use reqwest::Identity; +use std::path::PathBuf; + +use super::bootstrap::{Dn, DshBootstapCall, DshConfig}; -use crate::error::DshError; +use super::error::DshError; + +use pem; +use rcgen::{CertificateParams, CertificateSigningRequest, DnType, KeyPair}; /// Hold all relevant certificates and keys to connect to DSH Kafka Cluster and Schema Store. #[derive(Debug, Clone)] @@ -76,7 +80,7 @@ impl Cert { key_pair, )) } - + /// Build an async reqwest client with the DSH Kafka certificate included. /// With this client we can retrieve datastreams.json and conenct to Schema Registry. pub fn reqwest_client_config(&self) -> Result { @@ -160,6 +164,19 @@ impl Cert { Ok(()) } + /// Generate the certificate signing request. + fn generate_csr(key_pair: &KeyPair, dn: Dn) -> Result { + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, dn.cn()); + params + .distinguished_name + .push(DnType::OrganizationalUnitName, dn.ou()); + params + .distinguished_name + .push(DnType::OrganizationName, dn.o()); + Ok(params.serialize_request(key_pair)?) + } + fn create_file>(path: PathBuf, contents: C) -> Result<(), DshError> { std::fs::write(&path, contents)?; info!("File created ({})", path.display()); @@ -192,6 +209,7 @@ mod tests { use std::sync::OnceLock; use openssl::pkey::PKey; + use openssl::x509::X509Req; static TEST_CERTIFICATES: OnceLock = OnceLock::new(); @@ -199,11 +217,12 @@ mod tests { let subject_alt_names = vec!["hello.world.example".to_string(), "localhost".to_string()]; let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names).unwrap(); - Cert { - dsh_ca_certificate_pem: cert.pem(), - dsh_client_certificate_pem: cert.pem(), - key_pair: Arc::new(key_pair), - } + Cert::new(cert.pem(), cert.pem(), key_pair) + //Cert { + // dsh_ca_certificate_pem: CA_CERT.to_string(), + // dsh_client_certificate_pem: KAFKA_CERT.to_string(), + // key_pair: Arc::new(KeyPair::generate().unwrap()), + //} } #[test] @@ -277,6 +296,37 @@ mod tests { assert!(std::path::Path::new(&format!("{}/client.key", dir)).exists()); } + #[test] + fn test_dsh_certificate_sign_request() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); + let csr = Cert::generate_csr(&cert.key_pair, dn).unwrap(); + let req = csr.pem().unwrap(); + assert!(req.starts_with("-----BEGIN CERTIFICATE REQUEST-----")); + assert!(req.trim().ends_with("-----END CERTIFICATE REQUEST-----")); + } + + #[test] + fn test_verify_csr() { + let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); + let dn = Dn::parse_string("CN=Test CN,OU=Test OU,O=Test Org").unwrap(); + let csr = Cert::generate_csr(&cert.key_pair, dn).unwrap(); + let csr_pem = csr.pem().unwrap(); + let key = cert.private_key_pkcs8(); + let pkey = PKey::private_key_from_der(&key).unwrap(); + + let req = X509Req::from_pem(csr_pem.as_bytes()).unwrap(); + req.verify(&pkey).unwrap(); + let subject = req + .subject_name() + .entries() + .into_iter() + .map(|e| e.data().as_utf8().unwrap().to_string()) + .collect::>() + .join(","); + assert_eq!(subject, "Test CN,Test OU,Test Org"); + } + #[test] fn test_create_identity() { let cert = TEST_CERTIFICATES.get_or_init(set_test_cert); diff --git a/dsh_sdk/src/dsh_old/config.rs b/dsh_sdk/src/dsh_old/config.rs new file mode 100644 index 0000000..287b480 --- /dev/null +++ 
b/dsh_sdk/src/dsh_old/config.rs @@ -0,0 +1,198 @@ +//! Additional optional configuration for kafka producer and consumer +use crate::utils::get_env_var; +use crate::*; + +/// Additional configuration for Consumer config +/// +/// ## Environment variables +/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information +/// configuring the consmer via environment variables. +#[derive(Debug, Clone)] +pub struct ConsumerConfig { + enable_auto_commit: bool, + auto_offset_reset: String, + session_timeout: Option, + queued_buffering_max_messages_kbytes: Option, +} + +/// Additional configuration for Producer config +/// +/// ## Environment variables +/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information +/// configuring the producer via environment variables. +#[derive(Debug, Clone, Default)] +pub struct ProducerConfig { + batch_num_messages: Option, + queue_buffering_max_messages: Option, + queue_buffering_max_kbytes: Option, + queue_buffering_max_ms: Option, +} + +impl ConsumerConfig { + pub fn new() -> Self { + let enable_auto_commit = get_env_var(VAR_KAFKA_ENABLE_AUTO_COMMIT) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(false); + let auto_offset_reset = + get_env_var(VAR_KAFKA_AUTO_OFFSET_RESET).unwrap_or("earliest".to_string()); + let session_timeout = get_env_var(VAR_KAFKA_CONSUMER_SESSION_TIMEOUT_MS) + .ok() + .and_then(|v| v.parse().ok()); + let queued_buffering_max_messages_kbytes = + get_env_var(VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES) + .ok() + .and_then(|v| v.parse().ok()); + ConsumerConfig { + enable_auto_commit, + auto_offset_reset, + session_timeout, + queued_buffering_max_messages_kbytes, + } + } + pub fn enable_auto_commit(&self) -> bool { + self.enable_auto_commit + } + pub fn auto_offset_reset(&self) -> String { + self.auto_offset_reset.clone() + } + pub fn session_timeout(&self) -> Option { + self.session_timeout + } + pub fn queued_buffering_max_messages_kbytes(&self) -> Option { + self.queued_buffering_max_messages_kbytes + } +} + +impl Default for ConsumerConfig { + fn default() -> Self { + ConsumerConfig { + enable_auto_commit: false, + auto_offset_reset: "earliest".to_string(), + session_timeout: None, + queued_buffering_max_messages_kbytes: None, + } + } +} + +impl ProducerConfig { + pub fn new() -> Self { + let batch_num_messages = get_env_var(VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES) + .ok() + .and_then(|v| v.parse().ok()); + let queue_buffering_max_messages = + get_env_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES) + .ok() + .and_then(|v| v.parse().ok()); + let queue_buffering_max_kbytes = get_env_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES) + .ok() + .and_then(|v| v.parse().ok()); + let queue_buffering_max_ms = get_env_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS) + .ok() + .and_then(|v| v.parse().ok()); + ProducerConfig { + batch_num_messages, + queue_buffering_max_messages, + queue_buffering_max_kbytes, + queue_buffering_max_ms, + } + } + + pub fn batch_num_messages(&self) -> Option { + self.batch_num_messages + } + pub fn queue_buffering_max_messages(&self) -> Option { + self.queue_buffering_max_messages + } + pub fn queue_buffering_max_kbytes(&self) -> Option { + self.queue_buffering_max_kbytes + } + pub fn queue_buffering_max_ms(&self) -> Option { + self.queue_buffering_max_ms + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::env; + + #[test] + fn 
test_consumer_config() { + let consumer_config = ConsumerConfig::new(); + assert_eq!(consumer_config.enable_auto_commit(), false); + assert_eq!(consumer_config.auto_offset_reset(), "earliest"); + assert_eq!(consumer_config.session_timeout(), None); + assert_eq!(consumer_config.queued_buffering_max_messages_kbytes(), None); + } + + #[test] + fn test_consumer_config_default() { + let consumer_config = ConsumerConfig::default(); + assert_eq!(consumer_config.enable_auto_commit(), false); + assert_eq!(consumer_config.auto_offset_reset(), "earliest"); + assert_eq!(consumer_config.session_timeout(), None); + assert_eq!(consumer_config.queued_buffering_max_messages_kbytes(), None); + } + + #[test] + #[serial(env_dependency)] + fn test_consumer_config_env() { + env::set_var(VAR_KAFKA_ENABLE_AUTO_COMMIT, "true"); + env::set_var(VAR_KAFKA_AUTO_OFFSET_RESET, "latest"); + env::set_var(VAR_KAFKA_CONSUMER_SESSION_TIMEOUT_MS, "1000"); + env::set_var( + VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES, + "1000", + ); + let consumer_config = ConsumerConfig::new(); + assert_eq!(consumer_config.enable_auto_commit(), true); + assert_eq!(consumer_config.auto_offset_reset(), "latest"); + assert_eq!(consumer_config.session_timeout(), Some(1000)); + assert_eq!( + consumer_config.queued_buffering_max_messages_kbytes(), + Some(1000) + ); + env::remove_var(VAR_KAFKA_ENABLE_AUTO_COMMIT); + env::remove_var(VAR_KAFKA_AUTO_OFFSET_RESET); + env::remove_var(VAR_KAFKA_CONSUMER_SESSION_TIMEOUT_MS); + env::remove_var(VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES); + } + + #[test] + fn test_producer_config() { + let producer_config = ProducerConfig::new(); + assert_eq!(producer_config.batch_num_messages(), None); + assert_eq!(producer_config.queue_buffering_max_messages(), None); + assert_eq!(producer_config.queue_buffering_max_kbytes(), None); + assert_eq!(producer_config.queue_buffering_max_ms(), None); + } + + #[test] + fn test_producer_config_default() { + let producer_config = ProducerConfig::default(); + assert_eq!(producer_config.batch_num_messages(), None); + assert_eq!(producer_config.queue_buffering_max_messages(), None); + assert_eq!(producer_config.queue_buffering_max_kbytes(), None); + assert_eq!(producer_config.queue_buffering_max_ms(), None); + } + + #[test] + #[serial(env_dependency)] + fn test_producer_config_env() { + env::set_var(VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES, "1000"); + env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES, "1000"); + env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES, "1000"); + env::set_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS, "1000"); + let producer_config = ProducerConfig::new(); + assert_eq!(producer_config.batch_num_messages(), Some(1000)); + assert_eq!(producer_config.queue_buffering_max_messages(), Some(1000)); + assert_eq!(producer_config.queue_buffering_max_kbytes(), Some(1000)); + assert_eq!(producer_config.queue_buffering_max_ms(), Some(1000)); + env::remove_var(VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES); + env::remove_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES); + env::remove_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES); + env::remove_var(VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS); + } +} diff --git a/dsh_sdk/src/dsh_old/datastream.rs b/dsh_sdk/src/dsh_old/datastream.rs index 4d86a67..e8e8fb6 100644 --- a/dsh_sdk/src/dsh_old/datastream.rs +++ b/dsh_sdk/src/dsh_old/datastream.rs @@ -16,12 +16,20 @@ //! let schema_store = datastream.schema_store(); //! 
```
 use std::collections::HashMap;
+use std::env;
+use std::fs::File;
+use std::io::Read;

-use log::info;
+use log::{debug, error, info, warn};
 use serde::{Deserialize, Serialize};

-use crate::error::DshError;
-use crate::{utils, VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_SCHEMA_REGISTRY_HOST};
+use super::error::DshError;
+use crate::{
+    utils, VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_KAFKA_CONSUMER_GROUP_TYPE, VAR_LOCAL_DATASTREAMS_JSON,
+    VAR_SCHEMA_REGISTRY_HOST,
+};
+
+const FILE_NAME: &str = "local_datastreams.json";

 /// This struct is equivalent to the datastreams.json
 ///
@@ -91,20 +99,20 @@ impl Datastream {
     pub fn verify_list_of_topics(
         &self,
         topics: &Vec,
-        access: crate::datastream::ReadWriteAccess,
+        access: ReadWriteAccess,
     ) -> Result<(), DshError> {
         let read_topics = self
             .streams()
             .values()
             .map(|datastream| match access {
-                crate::datastream::ReadWriteAccess::Read => datastream
+                ReadWriteAccess::Read => datastream
                     .read
                     .split('.')
                     .take(2)
                     .collect::<Vec<&str>>()
                     .join(".")
                     .replace('\\', ""),
-                crate::datastream::ReadWriteAccess::Write => datastream
+                ReadWriteAccess::Write => datastream
                     .write
                     .split('.')
                     .take(2)
@@ -143,7 +151,7 @@ impl Datastream {
     ///
     /// # Example
     /// ```no_run
-    /// # use dsh_sdk::datastream::Datastream;
+    /// # use dsh_sdk::dsh_old::datastream::Datastream;
     /// # let datastream = Datastream::default();
     /// let path = std::path::PathBuf::from("/path/to/directory");
     /// datastream.to_file(&path).unwrap();
@@ -202,6 +210,40 @@ impl Datastream {
     pub(crate) fn datastreams_endpoint(host: &str, tenant: &str, task_id: &str) -> String {
         format!("{}/kafka/config/{}/{}", host, tenant, task_id)
     }
+
+    /// If local_datastreams.json is found, it will load the datastreams from this file.
+    /// If the file does not parse, or if it is not found at the location set in the environment variable, it will panic.
+    /// If the environment variable is not set, it will look in the current directory. If it is not found there,
+    /// it will return an Error on the Result, in which case the default Datastream is used.
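+    ///
+    /// Illustrative sketch (crate-internal API; the env var constant is the one imported above, the path is made up):
+    /// ```ignore
+    /// std::env::set_var(VAR_LOCAL_DATASTREAMS_JSON, "/path/to/local_datastreams.json");
+    /// let datastream = Datastream::load_local_datastreams().expect("file must exist and parse");
+    /// ```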
+ pub(crate) fn load_local_datastreams() -> Result { + let path_buf = if let Ok(path) = utils::get_env_var(VAR_LOCAL_DATASTREAMS_JSON) { + let path = std::path::PathBuf::from(path); + if !path.exists() { + panic!("{} not found", path.display()); + } else { + path + } + } else { + std::env::current_dir().unwrap().join(FILE_NAME) + }; + debug!("Reading local datastreams from {}", path_buf.display()); + let mut file = File::open(&path_buf).map_err(|e| { + debug!( + "Failed opening local_datastreams.json ({}): {}", + path_buf.display(), + e + ); + DshError::IoError(e) + })?; + let mut contents = String::new(); + file.read_to_string(&mut contents).unwrap(); + let mut datastream: Datastream = serde_json::from_str(&contents) + .unwrap_or_else(|e| panic!("Failed to parse {}, {:?}", path_buf.display(), e)); + if let Ok(brokers) = utils::get_env_var(VAR_KAFKA_BOOTSTRAP_SERVERS) { + datastream.brokers = brokers.split(',').map(|s| s.to_string()).collect(); + } + Ok(datastream) + } } impl Default for Datastream { @@ -313,7 +355,7 @@ impl Stream { } else { Err(DshError::TopicPermissionsError( self.name.clone(), - crate::datastream::ReadWriteAccess::Read, + ReadWriteAccess::Read, )) } } @@ -328,22 +370,58 @@ impl Stream { } else { Err(DshError::TopicPermissionsError( self.name.clone(), - crate::datastream::ReadWriteAccess::Write, + ReadWriteAccess::Write, )) } } } -pub use crate::datastream::GroupType; -pub use crate::datastream::ReadWriteAccess; +/// Enum to indicate if we want to check the read or write topics +#[derive(Debug, Clone, PartialEq)] +pub enum ReadWriteAccess { + Read, + Write, +} + +#[derive(Debug, PartialEq)] +pub enum GroupType { + Private(usize), + Shared(usize), +} + +impl GroupType { + /// Get the group type from the environment variable KAFKA_CONSUMER_GROUP_TYPE + /// If KAFKA_CONSUMER_GROUP_TYPE is not (properly) set, it defaults to shared + pub fn from_env() -> Self { + let group_type = env::var(VAR_KAFKA_CONSUMER_GROUP_TYPE); + match group_type { + Ok(s) if s.to_lowercase() == *"private" => GroupType::Private(0), + Ok(s) if s.to_lowercase() == *"shared" => GroupType::Shared(0), + Ok(_) => { + error!("KAFKA_CONSUMER_GROUP_TYPE is not set with \"shared\" or \"private\", defaulting to shared group type."); + GroupType::Shared(0) + } + Err(_) => { + warn!("KAFKA_CONSUMER_GROUP_TYPE is not set, defaulting to shared group type."); + GroupType::Shared(0) + } + } + } +} + +impl std::fmt::Display for GroupType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GroupType::Private(i) => write!(f, "private; index: {}", i), + GroupType::Shared(i) => write!(f, "shared; index: {}", i), + } + } +} #[cfg(test)] mod tests { use super::*; use serial_test::serial; - use std::io::Read; - - use crate::VAR_KAFKA_CONSUMER_GROUP_TYPE; // Define a reusable Properties instance fn datastream() -> Datastream { @@ -485,13 +563,13 @@ mod tests { #[serial(env_dependency)] fn test_datastream_get_group_type_from_env() { // Set the KAFKA_CONSUMER_GROUP_TYPE environment variable to "private" - std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "private"); + env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "private"); assert_eq!(GroupType::from_env(), GroupType::Private(0),); - std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "shared"); + env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "shared"); assert_eq!(GroupType::from_env(), GroupType::Shared(0),); - std::env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "invalid-type"); + env::set_var(VAR_KAFKA_CONSUMER_GROUP_TYPE, "invalid-type"); 
assert_eq!(GroupType::from_env(), GroupType::Shared(0),); - std::env::remove_var(VAR_KAFKA_CONSUMER_GROUP_TYPE); + env::remove_var(VAR_KAFKA_CONSUMER_GROUP_TYPE); assert_eq!(GroupType::from_env(), GroupType::Shared(0),); } @@ -600,6 +678,79 @@ mod tests { assert!(result.is_ok()) } + #[test] + #[serial(env_dependency)] + fn test_load_local_valid_datastreams() { + // load from root directory + let datastream = Datastream::load_local_datastreams().is_ok(); + assert!(datastream); + // load from custom directory + let current_dir = env::current_dir().unwrap(); + let file_location = format!( + "{}/test_resources/valid_datastreams.json", + current_dir.display() + ); + println!("file_location: {}", file_location); + env::set_var(VAR_LOCAL_DATASTREAMS_JSON, file_location); + let datastream = Datastream::load_local_datastreams().is_ok(); + assert!(datastream); + env::remove_var(VAR_LOCAL_DATASTREAMS_JSON); + } + + #[test] + #[serial(env_dependency)] + fn test_load_local_nonexisting_datastreams() { + let current_dir = env::current_dir().unwrap(); + let file_location = format!( + "{}/test_resoources/nonexisting_datastreams.json", + current_dir.display() + ); + env::set_var(VAR_LOCAL_DATASTREAMS_JSON, file_location); + // let it panic in a thread + let join_handle = std::thread::spawn(move || { + let _ = Datastream::load_local_datastreams(); + }); + let result = join_handle.join(); + assert!(result.is_err()); + env::remove_var(VAR_LOCAL_DATASTREAMS_JSON); + } + + #[test] + #[serial(env_dependency)] + fn test_load_local_invalid_datastreams() { + let current_dir = env::current_dir().unwrap(); + let file_location = format!( + "{}/test_resources/invalid_datastreams.json", + current_dir.display() + ); + env::set_var(VAR_LOCAL_DATASTREAMS_JSON, file_location); + // let it panic in a thread + let join_handle = std::thread::spawn(move || { + let _ = Datastream::load_local_datastreams(); + }); + let result = join_handle.join(); + assert!(result.is_err()); + env::remove_var(VAR_LOCAL_DATASTREAMS_JSON); + } + + #[test] + #[serial(env_dependency)] + fn test_load_local_invalid_json() { + let current_dir = env::current_dir().unwrap(); + let file_location = format!( + "{}/test_resources/invalid_datastreams_missing_field.json", + current_dir.display() + ); + env::set_var(VAR_LOCAL_DATASTREAMS_JSON, file_location); + // let it panic in a thread + let join_handle = std::thread::spawn(move || { + let _ = Datastream::load_local_datastreams(); + }); + let result = join_handle.join(); + assert!(result.is_err()); + env::remove_var(VAR_LOCAL_DATASTREAMS_JSON); + } + #[test] fn test_datastream_endpoint() { let host = "http://localhost:8080"; diff --git a/dsh_sdk/src/dsh_old/error.rs b/dsh_sdk/src/dsh_old/error.rs new file mode 100644 index 0000000..7a4310d --- /dev/null +++ b/dsh_sdk/src/dsh_old/error.rs @@ -0,0 +1,38 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum DshError { + #[error("IO Error: {0}")] + IoError(#[from] std::io::Error), + #[error("Env var error: {0}")] + EnvVarError(&'static str), + #[error("Convert bytes to utf8 error: {0}")] + Utf8(#[from] std::string::FromUtf8Error), + #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")] + DshCallError { + url: String, + status_code: reqwest::StatusCode, + error_body: String, + }, + #[error("Certificates are not set")] + NoCertificates, + #[error("Invalid PEM certificate: {0}")] + PemError(#[from] pem::PemError), + #[error("Reqwest: {0}")] + ReqwestError(#[from] reqwest::Error), + #[error("Serde_json 
error: {0}")]
+    JsonError(#[from] serde_json::Error),
+    #[error("Rcgen error: {0}")]
+    PrivateKeyError(#[from] rcgen::Error),
+    #[error("Error parsing: {0}")]
+    ParseDnError(String),
+    #[error("Error getting group id, index out of bounds for {0}")]
+    IndexGroupIdError(super::datastream::GroupType),
+    #[error("No tenant name found")]
+    NoTenantName,
+    #[error("Error getting topic name {0}, Topic not found in datastreams.")]
+    NotFoundTopicError(String),
+    #[error("Error in topic permissions: {0} does not have {1:?} permissions.")]
+    TopicPermissionsError(String, super::datastream::ReadWriteAccess),
+}
diff --git a/dsh_sdk/src/dsh_old/mod.rs b/dsh_sdk/src/dsh_old/mod.rs
index 4f6e16f..5de9d3e 100644
--- a/dsh_sdk/src/dsh_old/mod.rs
+++ b/dsh_sdk/src/dsh_old/mod.rs
@@ -21,35 +21,14 @@
 //! # Ok(())
 //! # }
 //! ```
-#[deprecated(
-    since = "0.5.0",
-    note = "`dsh_sdk::dsh::certificates` is moved to `dsh_sdk::certificates`"
-)]
+mod bootstrap;
 pub mod certificates;
-#[deprecated(
-    since = "0.5.0",
-    note = "`dsh_sdk::dsh::datastream` is moved to `dsh_sdk::datastream`"
-)]
+mod config;
 pub mod datastream;
-#[deprecated(
-    since = "0.5.0",
-    note = "`dsh_sdk::dsh::properties` is moved to `dsh_sdk::dsh`"
-)]
+mod error;
+mod pki_config_dir;
 pub mod properties;

 // Re-export the properties struct to avoid braking changes
-
-#[deprecated(
-    since = "0.5.0",
-    note = "get_configured_topics is moved to `dsh_sdk::utils::get_configured_topics`"
-)]
-pub fn get_configured_topics() -> Result<Vec<String>, crate::error::DshError> {
-    let kafka_topic_string = crate::utils::get_env_var("TOPICS")?;
-    Ok(kafka_topic_string
-        .split(',')
-        .map(str::trim)
-        .map(String::from)
-        .collect())
-}
-
+pub use super::utils::get_configured_topics;
 pub use properties::Properties;
diff --git a/dsh_sdk/src/dsh_old/pki_config_dir.rs b/dsh_sdk/src/dsh_old/pki_config_dir.rs
new file mode 100644
index 0000000..ac3a44a
--- /dev/null
+++ b/dsh_sdk/src/dsh_old/pki_config_dir.rs
@@ -0,0 +1,242 @@
+//! Module for loading PKI files from the $PKI_CONFIG_DIR directory.
+//!
+//! If for some reason the legacy bootstrap script is executed, these PKI files should be used,
+//! instead of initiating a new PKI request.
+//!
+//! This also makes it possible to use the DSH SDK with Kafka Proxy
+//! or VPN outside of the DSH environment.
+use super::certificates::Cert;
+use super::error::DshError;
+use crate::{utils, VAR_PKI_CONFIG_DIR};
+
+use log::{debug, info, warn};
+use pem::{self, Pem};
+use rcgen::KeyPair;
+use std::path::{Path, PathBuf};
+
+pub(crate) fn get_pki_cert() -> Result<Cert, DshError> {
+    let config_dir = PathBuf::from(
+        utils::get_env_var(VAR_PKI_CONFIG_DIR)
+            .map_err(|_| DshError::EnvVarError(VAR_PKI_CONFIG_DIR))?,
+    );
+    let ca_cert_paths = get_file_path_bufs("ca", PkiFileType::Cert, &config_dir)?;
+    let dsh_ca_certificate_pem = get_certificate(ca_cert_paths)?;
+    let client_cert_paths = get_file_path_bufs("client", PkiFileType::Cert, &config_dir)?;
+    let dsh_client_certificate_pem = get_certificate(client_cert_paths)?;
+    let client_key_paths = get_file_path_bufs("client", PkiFileType::Key, &config_dir)?;
+    let key_pair = get_key_pair(client_key_paths)?;
+    info!("Certificates loaded from PKI config directory");
+    Ok(Cert::new(
+        pem::encode_many(&dsh_ca_certificate_pem),
+        pem::encode_many(&dsh_client_certificate_pem),
+        key_pair,
+    ))
+}
+
+/// Get certificate from the PKI config directory
+///
+/// Looks for all files containing client*.pem and client.crt in the PKI config directory.
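+///
+/// Illustrative call (crate-private helper; the path here is made up):
+/// ```ignore
+/// let paths = vec![std::path::PathBuf::from("/pki/client.pem")];
+/// let pems = get_certificate(paths)?;
+/// ```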
+fn get_certificate(mut cert_paths: Vec<PathBuf>) -> Result<Vec<Pem>, DshError> {
+    'file: while let Some(file) = cert_paths.pop() {
+        info!("{} - Reading certificate file", file.display());
+        if let Ok(ca_cert) = std::fs::read(&file) {
+            let pem_result = pem::parse_many(&ca_cert);
+            match pem_result {
+                Ok(pem) => {
+                    debug!(
+                        "{} - Certificate parsed as PEM ({} certificate in file)",
+                        file.display(),
+                        pem.len()
+                    );
+                    for p in &pem {
+                        if !p.tag().eq_ignore_ascii_case("CERTIFICATE") {
+                            warn!("{} - Certificate tag is not 'CERTIFICATE'", file.display());
+                            continue 'file;
+                        }
+                    }
+                    return Ok(pem);
+                }
+                Err(e) => warn!("{} - Error parsing certificate: {:?}", file.display(), e),
+            }
+        }
+    }
+    info!("No (valid) certificates found in the PKI config directory");
+    Err(DshError::NoCertificates)
+}
+
+/// Get the client key pair from the PKI config directory
+///
+/// Looks for all files containing client*.key in the PKI config directory (a .key.pem suffix is also allowed).
+fn get_key_pair(mut key_paths: Vec<PathBuf>) -> Result<KeyPair, DshError> {
+    while let Some(file) = key_paths.pop() {
+        info!("{} - Reading key file", file.display());
+        if let Ok(bytes) = std::fs::read(&file) {
+            if let Ok(string) = std::string::String::from_utf8(bytes) {
+                debug!("{} - Key parsed as string", file.display());
+                match rcgen::KeyPair::from_pem(&string) {
+                    Ok(key_pair) => {
+                        debug!("{} - Key parsed as KeyPair from string", file.display());
+                        return Ok(key_pair);
+                    }
+                    Err(e) => warn!("{} - Error parsing key: {:?}", file.display(), e),
+                }
+            }
+        }
+    }
+    info!("No (valid) key found in the PKI config directory");
+    Err(DshError::NoCertificates)
+}
+
+/// Collect the file paths in the PKI config directory that match the given prefix and file type
+fn get_file_path_bufs<P>
( + prefix: &str, + contains: PkiFileType, + config_dir: P, +) -> Result, DshError> +where + P: AsRef, +{ + let file_paths: Vec = config_dir + .as_ref() + .read_dir()? + .filter_map(|entry| { + entry.ok().and_then(|e| { + let filename = e.file_name().to_string_lossy().into_owned(); + if filename.contains(prefix) + && match contains { + PkiFileType::Cert => { + filename.ends_with(".crt") || filename.ends_with(".pem") + } + PkiFileType::Key => filename.contains(".key"), //.key.pem is allowed + } + { + Some(e.path()) + } else { + None + } + }) + }) + .collect(); + + if file_paths.len() > 1 { + warn!("Found multiple files: {:?}", file_paths); + } + + Ok(file_paths) +} + +/// Helper Enum for the type of PKI file +enum PkiFileType { + Cert, + Key, +} + +#[cfg(test)] +mod tests { + use super::*; + use openssl; + use openssl::pkey::PKey; + use serial_test::serial; + + const PKI_CONFIG_DIR: &str = "test_files/pki_config_dir"; + const PKI_KEY_FILE_NAME: &str = "client.key"; + const PKI_CERT_FILE_NAME: &str = "client.pem"; + const PKI_CA_FILE_NAME: &str = "ca.crt"; + + fn create_test_pki_config_dir() { + let path = PathBuf::from(PKI_CONFIG_DIR); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); + let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); + if path_key.exists() && path_cert.exists() && path_ca.exists() { + return; + } + let _ = std::fs::create_dir(path); + let priv_key = openssl::rsa::Rsa::generate(2048).unwrap(); + let pkey = PKey::from_rsa(priv_key).unwrap(); + let key = pkey.private_key_to_pem_pkcs8().unwrap(); + let mut x509_name = openssl::x509::X509NameBuilder::new().unwrap(); + x509_name.append_entry_by_text("CN", "test_ca").unwrap(); + let x509_name = x509_name.build(); + let mut x509 = openssl::x509::X509::builder().unwrap(); + x509.set_version(2).unwrap(); + x509.set_subject_name(&x509_name).unwrap(); + x509.set_not_before(&openssl::asn1::Asn1Time::days_from_now(0).unwrap()) + .unwrap(); + x509.set_not_after(&openssl::asn1::Asn1Time::days_from_now(365).unwrap()) + .unwrap(); + x509.set_pubkey(&pkey).unwrap(); + x509.sign(&pkey, openssl::hash::MessageDigest::sha256()) + .unwrap(); + let x509 = x509.build(); + let ca_cert = x509.to_pem().unwrap(); + let cert = x509.to_pem().unwrap(); + std::fs::write(path_key, key).unwrap(); + std::fs::write(path_ca, ca_cert).unwrap(); + std::fs::write(path_cert, cert).unwrap(); + } + + #[test] + #[serial(pki)] + fn test_get_certificate() { + create_test_pki_config_dir(); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); + let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); + let path_ne = PathBuf::from(PKI_CONFIG_DIR).join("not_existing.crt"); + let result = get_certificate(vec![path_cert.clone()]).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].tag(), "CERTIFICATE"); + let result = get_certificate(vec![path_ca.clone()]).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].tag(), "CERTIFICATE"); + let result = get_certificate(vec![ + path_key.clone(), + path_ne.clone(), + path_cert.clone(), + path_ca.clone(), + ]) + .unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].tag(), "CERTIFICATE"); + let result = get_certificate(vec![path_key]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + let result = get_certificate(vec![]).unwrap_err(); + assert!(matches!(result, 
DshError::NoCertificates)); + let result = get_certificate(vec![path_ne]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + } + + #[test] + #[serial(pki)] + fn test_get_key_pair() { + create_test_pki_config_dir(); + let path_key = PathBuf::from(PKI_CONFIG_DIR).join(PKI_KEY_FILE_NAME); + let path_cert = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CERT_FILE_NAME); + let path_ca = PathBuf::from(PKI_CONFIG_DIR).join(PKI_CA_FILE_NAME); + let path_ne = PathBuf::from(PKI_CONFIG_DIR).join("not_existing.key"); + let result = get_key_pair(vec![path_key.clone()]); + assert!(result.is_ok()); + let result = get_key_pair(vec![path_ne.clone(), path_key.clone()]); + assert!(result.is_ok()); + let result = + get_key_pair(vec![path_ne.clone(), path_cert.clone(), path_ca.clone()]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + let result = get_key_pair(vec![]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + let result = get_key_pair(vec![path_ne]).unwrap_err(); + assert!(matches!(result, DshError::NoCertificates)); + } + + #[test] + #[serial(pki, env_dependency)] + fn test_get_pki_cert() { + create_test_pki_config_dir(); + let result = get_pki_cert().unwrap_err(); + assert!(matches!(result, DshError::EnvVarError(_))); + std::env::set_var(VAR_PKI_CONFIG_DIR, PKI_CONFIG_DIR); + let result = get_pki_cert(); + assert!(result.is_ok()); + std::env::remove_var(VAR_PKI_CONFIG_DIR); + } +} diff --git a/dsh_sdk/src/dsh_old/properties.rs b/dsh_sdk/src/dsh_old/properties.rs index 56cd5de..f971495 100644 --- a/dsh_sdk/src/dsh_old/properties.rs +++ b/dsh_sdk/src/dsh_old/properties.rs @@ -10,16 +10,33 @@ //! ## Environment variables //! See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for //! more information configuring the consmer or producer via environment variables. -use log::{error, warn}; +//! +//! # Example +//! ``` +//! use dsh_sdk::Properties; +//! use rdkafka::consumer::{Consumer, StreamConsumer}; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let dsh_properties = Properties::get(); +//! let consumer_config = dsh_properties.consumer_rdkafka_config(); +//! let consumer: StreamConsumer = consumer_config.create()?; +//! +//! # Ok(()) +//! # } +//! ``` +use log::{debug, error, warn}; use std::env; -use std::sync::{Arc, OnceLock}; - -use crate::certificates::Cert; -use crate::datastream; -use crate::error::DshError; +use std::sync::OnceLock; +use super::bootstrap::bootstrap; +use super::error::DshError; +use super::{certificates, config, datastream, pki_config_dir}; use crate::utils; use crate::*; +static PROPERTIES: OnceLock = OnceLock::new(); +static CONSUMER_CONFIG: OnceLock = OnceLock::new(); +static PRODUCER_CONFIG: OnceLock = OnceLock::new(); /// DSH properties struct. Create new to initialize all related components to connect to the DSH kafka clusters /// - Contains info from datastreams.json @@ -29,14 +46,29 @@ use crate::*; /// ## Environment variables /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for /// more information configuring the consmer or producer via environment variables. 
+/// +/// # Example +/// ``` +/// use dsh_sdk::Properties; +/// use rdkafka::consumer::{Consumer, StreamConsumer}; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let dsh_properties = Properties::get(); +/// +/// let consumer_config = dsh_properties.consumer_rdkafka_config(); +/// let consumer: StreamConsumer = consumer_config.create()?; +/// +/// Ok(()) +/// } +/// ``` -#[deprecated(since = "0.5.0", note = "`Properties` is renamed to `dsh_sdk::Dsh`")] #[derive(Debug, Clone)] pub struct Properties { config_host: String, task_id: String, tenant_name: String, - datastream: Arc, + datastream: datastream::Datastream, certificates: Option, } @@ -49,7 +81,6 @@ impl Properties { datastream: datastream::Datastream, certificates: Option, ) -> Self { - let datastream = Arc::new(datastream); Self { config_host, task_id, @@ -85,7 +116,6 @@ impl Properties { /// # } /// ``` pub fn get() -> &'static Self { - static PROPERTIES: OnceLock = OnceLock::new(); PROPERTIES.get_or_init(|| tokio::task::block_in_place(Self::init)) } @@ -116,8 +146,8 @@ impl Properties { let config_host = config_host.unwrap_or(DEFAULT_CONFIG_HOST.to_string()); // Default is for running on local machine with VPN let fetched_datastreams = certificates.as_ref().and_then(|cert| { cert.reqwest_blocking_client_config() - .build() .ok() + .and_then(|cb| cb.build().ok()) .and_then(|client| { datastream::Datastream::fetch_blocking( &client, @@ -137,6 +167,172 @@ impl Properties { Self::new(config_host, task_id, tenant_name, datastream, certificates) } + /// Get default RDKafka Consumer config to connect to Kafka on DSH. + /// + /// Note: This config is set to auto commit to false. You need to manually commit offsets. + /// You can overwrite this config by setting the enable.auto.commit and enable.auto.offset.store property to `true`. + /// + /// # Group ID + /// There are 2 types of group id's in DSH: private and shared. Private will have a unique group id per running instance. + /// Shared will have the same group id for all running instances. With this you can horizontally scale your service. + /// The group type can be manipulated by environment variable KAFKA_CONSUMER_GROUP_TYPE. + /// If not set, it will default to shared. + /// + /// # Example + /// ``` + /// use dsh_sdk::Properties; + /// use rdkafka::config::RDKafkaLogLevel; + /// use rdkafka::consumer::stream_consumer::StreamConsumer; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let dsh_properties = Properties::get(); + /// let mut consumer_config = dsh_properties.consumer_rdkafka_config(); + /// let consumer: StreamConsumer = consumer_config.create()?; + /// Ok(()) + /// } + /// ``` + /// + /// # Default configs + /// See full list of configs properties in case you want to add/overwrite the config: + /// + /// + /// Some configurations are overwitable by environment variables. 
+    ///
+    /// | **config**                | **Default value**                | **Remark**                                                              |
+    /// |---------------------------|----------------------------------|------------------------------------------------------------------------|
+    /// | `bootstrap.servers`       | Brokers based on datastreams     | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS`                  |
+    /// | `group.id`                | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE` |
+    /// | `client.id`               | Task_id of service               |                                                                         |
+    /// | `enable.auto.commit`      | `false`                          | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT`                      |
+    /// | `auto.offset.reset`       | `earliest`                       | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET`                       |
+    /// | `security.protocol`       | ssl (DSH) / plaintext (local)    | Security protocol                                                       |
+    /// | `ssl.key.pem`             | private key                      | Generated when bootstrap is initiated                                   |
+    /// | `ssl.certificate.pem`     | dsh kafka certificate            | Signed certificate to connect to kafka cluster                          |
+    /// | `ssl.ca.pem`              | CA certificate                   | CA certificate, provided by DSH.                                        |
+    ///
+    /// ## Environment variables
+    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
+    /// configuring the consumer via environment variables.
+    #[cfg(feature = "rdkafka-config")]
+    pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
+        let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new);
+        let mut config = rdkafka::config::ClientConfig::new();
+        config
+            .set("bootstrap.servers", self.kafka_brokers())
+            .set("group.id", self.kafka_group_id())
+            .set("client.id", self.client_id())
+            .set("enable.auto.commit", self.kafka_auto_commit().to_string())
+            .set("auto.offset.reset", self.kafka_auto_offset_reset());
+        if let Some(session_timeout) = consumer_config.session_timeout() {
+            config.set("session.timeout.ms", session_timeout.to_string());
+        }
+        if let Some(queued_buffering_max_messages_kbytes) =
+            consumer_config.queued_buffering_max_messages_kbytes()
+        {
+            config.set(
+                "queued.max.messages.kbytes",
+                queued_buffering_max_messages_kbytes.to_string(),
+            );
+        }
+        debug!("Consumer config: {:#?}", config);
+        // Set SSL if certificates are present
+        if let Ok(certificates) = &self.certificates() {
+            config
+                .set("security.protocol", "ssl")
+                .set("ssl.key.pem", certificates.private_key_pem())
+                .set(
+                    "ssl.certificate.pem",
+                    certificates.dsh_kafka_certificate_pem(),
+                )
+                .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem());
+        } else {
+            config.set("security.protocol", "plaintext");
+        }
+        config
+    }
+
+    /// Get default RDKafka Producer config to connect to Kafka on DSH.
+    /// If certificates are present, it will use SSL to connect to Kafka.
+    /// If not, it will use plaintext so it can connect to local as well.
+    ///
+    /// Note: The default config is set to auto commit to false. You need to manually commit offsets.
+ /// + /// # Example + /// ``` + /// use rdkafka::config::RDKafkaLogLevel; + /// use rdkafka::producer::FutureProducer; + /// use dsh_sdk::Properties; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box>{ + /// let dsh_properties = Properties::get(); + /// let mut producer_config = dsh_properties.producer_rdkafka_config(); + /// let producer: FutureProducer = producer_config.create().expect("Producer creation failed"); + /// Ok(()) + /// } + /// ``` + /// + /// # Default configs + /// See full list of configs properties in case you want to manually add/overwrite the config: + /// + /// + /// | **config** | **Default value** | **Remark** | + /// |---------------------|--------------------------------|-----------------------------------------------------------------------------------------| + /// | bootstrap.servers | Brokers based on datastreams | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS` | + /// | client.id | task_id of service | Based on task_id of running service | + /// | security.protocol | ssl (DSH)) / plaintext (local) | Security protocol | + /// | ssl.key.pem | private key | Generated when bootstrap is initiated | + /// | ssl.certificate.pem | dsh kafka certificate | Signed certificate to connect to kafka cluster
(signed when bootstrap is initiated) | + /// | ssl.ca.pem | CA certifacte | CA certificate, provided by DSH. | + /// | log_level | Info | Log level of rdkafka | + /// + /// ## Environment variables + /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information + /// configuring the producer via environment variables. + #[cfg(feature = "rdkafka-config")] + pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { + let producer_config = PRODUCER_CONFIG.get_or_init(config::ProducerConfig::new); + let mut config = rdkafka::config::ClientConfig::new(); + config + .set("bootstrap.servers", self.kafka_brokers()) + .set("client.id", self.client_id()); + if let Some(batch_num_messages) = producer_config.batch_num_messages() { + config.set("batch.num.messages", batch_num_messages.to_string()); + } + if let Some(queue_buffering_max_messages) = producer_config.queue_buffering_max_messages() { + config.set( + "queue.buffering.max.messages", + queue_buffering_max_messages.to_string(), + ); + } + if let Some(queue_buffering_max_kbytes) = producer_config.queue_buffering_max_kbytes() { + config.set( + "queue.buffering.max.kbytes", + queue_buffering_max_kbytes.to_string(), + ); + } + if let Some(queue_buffering_max_ms) = producer_config.queue_buffering_max_ms() { + config.set("queue.buffering.max.ms", queue_buffering_max_ms.to_string()); + } + debug!("Producer config: {:#?}", config); + + // Set SSL if certificates are present + if let Ok(certificates) = self.certificates() { + config + .set("security.protocol", "ssl") + .set("ssl.key.pem", certificates.private_key_pem()) + .set( + "ssl.certificate.pem", + certificates.dsh_kafka_certificate_pem(), + ) + .set("ssl.ca.pem", certificates.dsh_ca_certificate_pem()); + } else { + config.set("security.protocol", "plaintext"); + } + config + } + /// Get reqwest async client config to connect to DSH Schema Registry. /// If certificates are present, it will use SSL to connect to Schema Registry. /// @@ -156,7 +352,7 @@ impl Properties { pub fn reqwest_client_config(&self) -> Result { let mut client_builder = reqwest::Client::builder(); if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_client_config(); + client_builder = certificates.reqwest_client_config()?; } Ok(client_builder) } @@ -171,7 +367,7 @@ impl Properties { /// # use dsh_sdk::Properties; /// # use reqwest::blocking::Client; /// # use dsh_sdk::error::DshError; - /// # fn main() -> Result<(), DshError> { + /// # fn main() -> Result<(), Box> { /// let dsh_properties = Properties::get(); /// let client = dsh_properties.reqwest_blocking_client_config()?.build()?; /// # Ok(()) @@ -182,7 +378,7 @@ impl Properties { let mut client_builder: reqwest::blocking::ClientBuilder = reqwest::blocking::Client::builder(); if let Ok(certificates) = &self.certificates() { - client_builder = certificates.reqwest_blocking_client_config(); + client_builder = certificates.reqwest_blocking_client_config()?; } Ok(client_builder) } @@ -193,7 +389,7 @@ impl Properties { /// ```no_run /// # use dsh_sdk::Properties; /// # use dsh_sdk::error::DshError; - /// # fn main() -> Result<(), DshError> { + /// # fn main() -> Result<(), Box>{ /// let dsh_properties = Properties::get(); /// let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem(); /// # Ok(()) @@ -325,7 +521,6 @@ impl Properties { } } - #[cfg(feature = "kafka")] /// Get the confifured kafka auto commit setinngs. 
/// /// ## Environment variables @@ -337,13 +532,10 @@ impl Properties { /// - Required: `false` /// - Options: `true`, `false` pub fn kafka_auto_commit(&self) -> bool { - crate::protocol_adapters::kafka_protocol::config::KafkaConfig::new(Some( - self.datastream.clone(), - )) - .enable_auto_commit() + let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new); + consumer_config.enable_auto_commit() } - #[cfg(feature = "kafka")] /// Get the kafka auto offset reset settings. /// /// ## Environment variables @@ -355,111 +547,14 @@ impl Properties { /// - Required: `false` /// - Options: smallest, earliest, beginning, largest, latest, end pub fn kafka_auto_offset_reset(&self) -> String { - crate::protocol_adapters::kafka_protocol::config::KafkaConfig::new(Some( - self.datastream.clone(), - )) - .auto_offset_reset() - .to_string() - } - - /// Get default RDKafka Consumer config to connect to Kafka on DSH. - /// - /// Note: This config is set to auto commit to false. You need to manually commit offsets. - /// You can overwrite this config by setting the enable.auto.commit and enable.auto.offset.store property to `true`. - /// - /// # Group ID - /// There are 2 types of group id's in DSH: private and shared. Private will have a unique group id per running instance. - /// Shared will have the same group id for all running instances. With this you can horizontally scale your service. - /// The group type can be manipulated by environment variable KAFKA_CONSUMER_GROUP_TYPE. - /// If not set, it will default to shared. - /// - /// # Example - /// ``` - /// use dsh_sdk::Properties; - /// use rdkafka::consumer::stream_consumer::StreamConsumer; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box> { - /// let dsh_properties = Properties::get(); - /// let mut consumer_config = dsh_properties.consumer_rdkafka_config(); - /// let consumer: StreamConsumer = consumer_config.create()?; - /// Ok(()) - /// } - /// ``` - /// - /// # Default configs - /// See full list of configs properties in case you want to add/overwrite the config: - /// - /// - /// Some configurations are overwitable by environment variables. - /// - /// | **config** | **Default value** | **Remark** | - /// |---------------------------|----------------------------------|------------------------------------------------------------------------| - /// | `bootstrap.servers` | Brokers based on datastreams | Overwritable by env variable KAFKA_BOOTSTRAP_SERVERS` | - /// | `group.id` | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`| - /// | `client.id` | Task_id of service | | - /// | `enable.auto.commit` | `false` | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT` | - /// | `auto.offset.reset` | `earliest` | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET` | - /// | `security.protocol` | ssl (DSH) / plaintext (local) | Security protocol | - /// | `ssl.key.pem` | private key | Generated when bootstrap is initiated | - /// | `ssl.certificate.pem` | dsh kafka certificate | Signed certificate to connect to kafka cluster | - /// | `ssl.ca.pem` | CA certifacte | CA certificate, provided by DSH. | - /// - /// ## Environment variables - /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information - /// configuring the consmer via environment variables. 
- #[cfg(feature = "rdkafka-config")] - pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { - crate::Dsh::get().consumer_rdkafka_config() - } - - /// Get default RDKafka Producer config to connect to Kafka on DSH. - /// If certificates are present, it will use SSL to connect to Kafka. - /// If not, it will use plaintext so it can connect to local as well. - /// - /// Note: The default config is set to auto commit to false. You need to manually commit offsets. - /// - /// # Example - /// ``` - /// use rdkafka::producer::FutureProducer; - /// use dsh_sdk::Properties; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box>{ - /// let dsh_properties = Properties::get(); - /// let mut producer_config = dsh_properties.producer_rdkafka_config(); - /// let producer: FutureProducer = producer_config.create().expect("Producer creation failed"); - /// Ok(()) - /// } - /// ``` - /// - /// # Default configs - /// See full list of configs properties in case you want to manually add/overwrite the config: - /// - /// - /// | **config** | **Default value** | **Remark** | - /// |---------------------|--------------------------------|-----------------------------------------------------------------------------------------| - /// | bootstrap.servers | Brokers based on datastreams | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS` | - /// | client.id | task_id of service | Based on task_id of running service | - /// | security.protocol | ssl (DSH)) / plaintext (local) | Security protocol | - /// | ssl.key.pem | private key | Generated when bootstrap is initiated | - /// | ssl.certificate.pem | dsh kafka certificate | Signed certificate to connect to kafka cluster
(signed when bootstrap is initiated) |
-    /// | ssl.ca.pem          | CA certificate                 | CA certificate, provided by DSH.                                                          |
-    /// | log_level           | Info                           | Log level of rdkafka                                                                      |
-    ///
-    /// ## Environment variables
-    /// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for more information
-    /// configuring the producer via environment variables.
-    #[cfg(feature = "rdkafka-config")]
-    pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig {
-        crate::Dsh::get().producer_rdkafka_config()
+        let consumer_config = CONSUMER_CONFIG.get_or_init(config::ConsumerConfig::new);
+        consumer_config.auto_offset_reset()
     }
 }
 
 impl Default for Properties {
     fn default() -> Self {
-        let datastream =
-            Arc::new(datastream::Datastream::load_local_datastreams().unwrap_or_default());
+        let datastream = datastream::Datastream::load_local_datastreams().unwrap_or_default();
         Self {
             task_id: "local_task_id".to_string(),
             tenant_name: "local_tenant".to_string(),
diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs
index 2f5cbeb..a305e2a 100644
--- a/dsh_sdk/src/lib.rs
+++ b/dsh_sdk/src/lib.rs
@@ -114,8 +114,8 @@ pub mod dlq;
 #[cfg(feature = "bootstrap")]
 #[deprecated(
     since = "0.5.0",
-    note = "The `dsh` as module is phased out. Use
-    `dsh_sdk::Dsh` for all info about your running container;
+    note = "The `Properties` struct is phased out. Use
+    `dsh_sdk::Dsh` for an all-in-one struct, similar to the original `Properties`;
     `dsh_sdk::certificates` for all certificate related info;
     `dsh_sdk::datastream` for all datastream related info;
     "

From 5344c6f4669f9a0a4059a8c9d44f07449cb90da9 Mon Sep 17 00:00:00 2001
From: Frank Hol <96832951+toelo3@users.noreply.github.com>
Date: Wed, 8 Jan 2025 16:39:13 +0100
Subject: [PATCH 07/23] restore deprecated code to original to have proper
 deprecated warnings (#104)

---
 dsh_sdk/src/dsh_old/properties.rs             |   2 +
 dsh_sdk/src/mqtt_token_fetcher.rs             | 860 +++++++++++++++++-
 .../protocol_adapters/token_fetcher/mod.rs    | 101 +-
 dsh_sdk/src/rest_api_token_fetcher.rs         | 576 +++++++++++-
 dsh_sdk/src/utils/mod.rs                      |  20 +-
 5 files changed, 1486 insertions(+), 73 deletions(-)

diff --git a/dsh_sdk/src/dsh_old/properties.rs b/dsh_sdk/src/dsh_old/properties.rs
index f971495..e0f539b 100644
--- a/dsh_sdk/src/dsh_old/properties.rs
+++ b/dsh_sdk/src/dsh_old/properties.rs
@@ -737,12 +737,14 @@ mod tests {
     }
 
     #[test]
+    #[serial(env_dependency)]
     fn test_kafka_auto_commit() {
         let properties = Properties::default();
         assert!(!properties.kafka_auto_commit());
     }
 
     #[test]
+    #[serial(env_dependency)]
     fn test_kafka_auto_offset_reset() {
         let properties = Properties::default();
         assert_eq!(properties.kafka_auto_offset_reset(), "earliest");
diff --git a/dsh_sdk/src/mqtt_token_fetcher.rs b/dsh_sdk/src/mqtt_token_fetcher.rs
index 0fcdf55..19be160 100644
--- a/dsh_sdk/src/mqtt_token_fetcher.rs
+++ b/dsh_sdk/src/mqtt_token_fetcher.rs
@@ -1,24 +1,862 @@
-pub use crate::protocol_adapters::token_fetcher::*;
+//! # MQTT Token Fetcher
+//!
+//! `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH.
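+//!
+//! A minimal sketch of the intended use (placeholder values; the same pattern
+//! appears in the method-level examples further down):
+//!
+//! ```no_run
+//! use dsh_sdk::mqtt_token_fetcher::MqttTokenFetcher;
+//! use dsh_sdk::Platform;
+//!
+//! # #[tokio::main]
+//! # async fn main() {
+//! let fetcher = MqttTokenFetcher::new(
+//!     "my-tenant".to_string(),
+//!     "my-api-key".to_string(),
+//!     Platform::NpLz,
+//! );
+//! let token = fetcher.get_token("my-client-id", None).await.unwrap();
+//! # }
+//! ```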
+use std::fmt::{Display, Formatter};
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use tokio::sync::Mutex;
+
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use sha2::{Digest, Sha256};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::{error::DshError, Platform};
+
+/// `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH.
+///
+/// It ensures that the tokens are valid, and if not, it refreshes them automatically. The struct
+/// is thread-safe and can be shared across multiple threads.
+
+pub struct MqttTokenFetcher {
+    tenant_name: String,
+    rest_api_key: String,
+    rest_token: Mutex<RestToken>,
+    rest_auth_url: String,
+    mqtt_token: Arc<Mutex<HashMap<String, MqttToken>>>, // Mapping from Client ID to MqttToken
+    mqtt_auth_url: String,
+    client: reqwest::Client,
+    //token_lifetime: Option<Duration>, // TODO: Implement option of passing token lifetime to request token for specific duration
+    // port: Port or connection_type: Connection // TODO: Platform provides two connection options, current implementation only provides connecting over SSL, enable WebSocket too
+}
+
+/// Constructs a new `MqttTokenFetcher`.
+///
+/// # Arguments
+///
+/// * `tenant_name` - The tenant name in DSH.
+/// * `rest_api_key` - The REST API key used for authentication.
+/// * `platform` - The DSH platform environment.
+///
+/// # Returns
+///
+/// Returns a `Result` containing a `MqttTokenFetcher` instance or a `DshError`.
 impl MqttTokenFetcher {
+    /// Constructs a new `MqttTokenFetcher`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tenant_name` - The tenant name of DSH.
+    /// * `api_key` - The related API key of the tenant, used for authentication to fetch a token for MQTT.
+    /// * `platform` - The target DSH platform environment.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use dsh_sdk::mqtt_token_fetcher::MqttTokenFetcher;
+    /// use dsh_sdk::Platform;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let tenant_name = "test_tenant".to_string();
+    /// let api_key = "aAbB123".to_string();
+    /// let platform = Platform::NpLz;
+    ///
+    /// let fetcher = MqttTokenFetcher::new(tenant_name, api_key, platform);
+    /// let token = fetcher.get_token("test_client", None).await.unwrap();
+    /// # }
+    /// ```
+    pub fn new(tenant_name: String, api_key: String, platform: Platform) -> MqttTokenFetcher {
+        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
+
+        let reqwest_client = reqwest::Client::builder()
+            .timeout(DEFAULT_TIMEOUT)
+            .http1_only()
+            .build()
+            .expect("Failed to build reqwest client");
+        Self::new_with_client(tenant_name, api_key, platform, reqwest_client)
+    }
+
+    /// Constructs a new `MqttTokenFetcher` with a custom reqwest client.
+    /// On this Reqwest client, you can set custom timeouts, headers, Rustls etc.
+    ///
+    /// # Arguments
+    ///
+    /// * `tenant_name` - The tenant name of DSH.
+    /// * `api_key` - The related API key of the tenant, used for authentication to fetch a token for MQTT.
+    /// * `platform` - The target DSH platform environment.
+    /// * `client` - User configured reqwest client to be used for fetching tokens
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use dsh_sdk::mqtt_token_fetcher::MqttTokenFetcher;
+    /// use dsh_sdk::Platform;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let tenant_name = "test_tenant".to_string();
+    /// let api_key = "aAbB123".to_string();
+    /// let platform = Platform::NpLz;
+    /// let client = reqwest::Client::new();
+    /// let fetcher = MqttTokenFetcher::new_with_client(tenant_name, api_key, platform, client);
+    /// let token = fetcher.get_token("test_client", None).await.unwrap();
+    /// # }
+    /// ```
     pub fn new_with_client(
         tenant_name: String,
         api_key: String,
         platform: Platform,
         client: reqwest::Client,
     ) -> MqttTokenFetcher {
+        let rest_token = RestToken::default();
+        Self {
+            tenant_name,
+            rest_api_key: api_key,
+            rest_token: Mutex::new(rest_token),
+            rest_auth_url: platform.endpoint_rest_token().to_string(),
+            mqtt_token: Arc::new(Mutex::new(HashMap::new())),
+            mqtt_auth_url: platform.endpoint_mqtt_token().to_string(),
+            client,
+        }
+    }
+
+    /// Retrieves an MQTT token for the specified client ID.
+    ///
+    /// If the token is expired or does not exist, it fetches a new token.
+    ///
+    /// # Arguments
+    ///
+    /// * `client_id` - The identifier for the MQTT client.
+    /// * `claims` - Optional claims for the MQTT token.
+    ///
+    /// # Returns
+    ///
+    /// Returns a `Result` containing the `MqttToken` or a `DshError`.
+    pub async fn get_token(
+        &self,
+        client_id: &str,
+        claims: Option<Vec<Claims>>,
+    ) -> Result<MqttToken, DshError> {
+        match self.mqtt_token.lock().await.entry(client_id.to_string()) {
+            std::collections::hash_map::Entry::Occupied(mut entry) => {
+                let mqtt_token = entry.get_mut();
+                if !mqtt_token.is_valid() {
+                    *mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?;
+                };
+                Ok(mqtt_token.clone())
+            }
+            std::collections::hash_map::Entry::Vacant(entry) => {
+                let mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?;
+                entry.insert(mqtt_token.clone());
+                Ok(mqtt_token)
+            }
+        }
+    }
+
+    /// Fetches a new MQTT token from the platform.
+    ///
+    /// This method handles token validation and fetching the token.
+    async fn fetch_new_mqtt_token(
+        &self,
+        client_id: &str,
+        claims: Option<Vec<Claims>>,
+    ) -> Result<MqttToken, DshError> {
+        let mut rest_token = self.rest_token.lock().await;
+
+        if !rest_token.is_valid() {
+            *rest_token = RestToken::get(
+                &self.client,
+                &self.tenant_name,
+                &self.rest_api_key,
+                &self.rest_auth_url,
+            )
+            .await?
+        }
+
+        let authorization_header = format!("Bearer {}", rest_token.raw_token);
+
+        let mqtt_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?;
+        let payload = serde_json::to_value(&mqtt_token_request)?;
+
+        let response = mqtt_token_request
+            .send(
+                &self.client,
+                &self.mqtt_auth_url,
+                &authorization_header,
+                &payload,
+            )
+            .await?;
+
+        MqttToken::new(response)
+    }
+}
+
+/// Represents Claims information for an MQTT request.
+/// * `action` - can be subscribe or publish
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Claims {
+    resource: Resource,
+    action: String,
+}
+
+impl Claims {
+    pub fn new(resource: Resource, action: String) -> Claims {
+        Claims { resource, action }
+    }
+}
+
+/// Enumeration representing possible actions in MQTT claims.
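+///
+/// Illustrative only: the `Display` impl below renders these as the strings
+/// `"Publish"` and `"Subscribe"`, which can then be passed to [`Claims::new`],
+/// e.g. `Claims::new(resource, Actions::Subscribe.to_string())`.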
+pub enum Actions { + Publish, + Subscribe, +} + +impl Display for Actions { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Actions::Publish => write!(f, "Publish"), + Actions::Subscribe => write!(f, "Subscribe"), + } + } +} + +/// Represents a resource in the MQTT claim. +/// +/// The resource defines what the client can access in terms of stream, prefix, topic, and type. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Resource { + stream: String, + prefix: String, + topic: String, + #[serde(rename = "type")] + type_: Option, +} + +impl Resource { + /// Creates a new `Resource` instance. Please check DSH MQTT Documentation for further explanation of the fields. + /// + /// # Arguments + /// + /// * `stream` - The data stream name. + /// * `prefix` - The prefix of the topic. + /// * `topic` - The topic name. + /// * `type_` - The optional type of the resource. + /// + /// + /// # Returns + /// + /// Returns a new `Resource` instance. + pub fn new(stream: String, prefix: String, topic: String, type_: Option) -> Resource { + Resource { + stream, + prefix, + topic, + type_, + } + } +} + +#[derive(Serialize)] +struct MqttTokenRequest { + id: String, + tenant: String, + claims: Option>, +} + +impl MqttTokenRequest { + fn new( + client_id: &str, + tenant: &str, + claims: Option>, + ) -> Result { + let mut hasher = Sha256::new(); + hasher.update(client_id); + let result = hasher.finalize(); + let id = format!("{:x}", result); + + Ok(Self { + id, + tenant: tenant.to_string(), + claims, + }) + } + + async fn send( + &self, + reqwest_client: &reqwest::Client, + mqtt_auth_url: &str, + authorization_header: &str, + payload: &serde_json::Value, + ) -> Result { + let response = reqwest_client + .post(mqtt_auth_url) + .header("Authorization", authorization_header) + .json(payload) + .send() + .await?; + + if response.status().is_success() { + Ok(response.text().await?) + } else { + Err(DshError::DshCallError { + url: mqtt_auth_url.to_string(), + status_code: response.status(), + error_body: response.text().await?, + }) + } + } +} + +/// Represents attributes associated with a mqtt token. +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "kebab-case")] +struct MqttTokenAttributes { + gen: i32, + endpoint: String, + iss: String, + claims: Option>, + exp: i32, + client_id: String, + iat: i32, + tenant_id: String, +} + +/// Represents a token used for MQTT connections. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MqttToken { + exp: i32, + raw_token: String, +} + +impl MqttToken { + /// Creates a new instance of `MqttToken` from a raw token string. + /// + /// # Arguments + /// + /// * `raw_token` - The raw token string. + /// + /// # Returns + /// + /// A Result containing the created MqttToken or an error. + pub fn new(raw_token: String) -> Result { + let header_payload = extract_header_and_payload(&raw_token)?; + + let decoded_token = decode_base64(header_payload)?; + + let token_attributes: MqttTokenAttributes = serde_json::from_slice(&decoded_token)?; + let token = MqttToken { + exp: token_attributes.exp, + raw_token, + }; + + Ok(token) + } + + /// Checks if the MQTT token is still valid. + fn is_valid(&self) -> bool { + let current_unixtime = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("SystemTime before UNIX EPOCH!") + .as_secs() as i32; + self.exp >= current_unixtime + 5 + } +} + +/// Represents attributes associated with a Rest token. 
+#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "kebab-case")] +struct RestTokenAttributes { + gen: i64, + endpoint: String, + iss: String, + claims: RestClaims, + exp: i32, + tenant_id: String, +} + +#[derive(Serialize, Deserialize, Debug)] +struct RestClaims { + #[serde(rename = "datastreams/v0/mqtt/token")] + datastreams_token: DatastreamsData, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DatastreamsData {} + +/// Represents a rest token with its raw value and attributes. +#[derive(Serialize, Deserialize, Debug)] +struct RestToken { + raw_token: String, + exp: i32, +} + +impl RestToken { + /// Retrieves a new REST token from the platform. + /// + /// # Arguments + /// + /// * `tenant` - The tenant name associated with the DSH platform. + /// * `api_key` - The REST API key used for authentication. + /// * `env` - The platform environment (e.g., production, staging). + /// + /// # Returns + /// + /// A Result containing the created `RestToken` or a `DshError`. + async fn get( + client: &reqwest::Client, + tenant: &str, + api_key: &str, + auth_url: &str, + ) -> Result { + let raw_token = Self::fetch_token(client, tenant, api_key, auth_url).await?; + + let header_payload = extract_header_and_payload(&raw_token)?; + + let decoded_token = decode_base64(header_payload)?; + + let token_attributes: RestTokenAttributes = serde_json::from_slice(&decoded_token)?; + let token = RestToken { + raw_token, + exp: token_attributes.exp, + }; + + Ok(token) + } + + // Checks if the REST token is still valid. + fn is_valid(&self) -> bool { + let current_unixtime = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("SystemTime before UNIX EPOCH!") + .as_secs() as i32; + self.exp >= current_unixtime + 5 + } + + async fn fetch_token( + client: &reqwest::Client, + tenant: &str, + api_key: &str, + auth_url: &str, + ) -> Result { + let json_body = json!({"tenant": tenant}); + + let response = client + .post(auth_url) + .header("apikey", api_key) + .json(&json_body) + .send() + .await?; + + let status = response.status(); + let body_text = response.text().await?; + match status { + reqwest::StatusCode::OK => Ok(body_text), + _ => Err(DshError::DshCallError { + url: auth_url.to_string(), + status_code: status, + error_body: body_text, + }), + } + } +} + +impl Default for RestToken { + fn default() -> Self { + Self { + raw_token: "".to_string(), + exp: 0, + } + } +} + +/// Extracts the header and payload part of a JWT token. +/// +/// # Arguments +/// +/// * `raw_token` - The raw JWT token string. +/// +/// # Returns +/// +/// A Result containing the header and payload part of the JWT token or a `DshError`. +fn extract_header_and_payload(raw_token: &str) -> Result<&str, DshError> { + let parts: Vec<&str> = raw_token.split('.').collect(); + parts + .get(1) + .copied() + .ok_or_else(|| DshError::ParseDnError("Header and payload are missing".to_string())) +} + +/// Decodes a Base64-encoded string. +/// +/// # Arguments +/// +/// * `payload` - The Base64-encoded string. +/// +/// # Returns +/// +/// A Result containing the decoded byte vector or a `DshError`. 
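+///
+/// A sketch of how this pairs with [`extract_header_and_payload`] (the value
+/// below is illustrative; `"eyJleHAiOjB9"` is the Base64 encoding of `{"exp":0}`):
+///
+/// ```ignore
+/// let payload = extract_header_and_payload("header.eyJleHAiOjB9.signature")?;
+/// let claims_bytes = decode_base64(payload)?;
+/// assert_eq!(claims_bytes, br#"{"exp":0}"#);
+/// ```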
+fn decode_base64(payload: &str) -> Result, DshError> { + use base64::{alphabet, engine, read}; + use std::io::Read; + + let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::NO_PAD); + let mut decoder = read::DecoderReader::new(payload.as_bytes(), &engine); + + let mut decoded_token = Vec::new(); + decoder + .read_to_end(&mut decoded_token) + .map_err(DshError::IoError)?; + + Ok(decoded_token) +} + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Matcher; + use tokio::sync::Mutex; + + async fn create_valid_fetcher() -> MqttTokenFetcher { + let exp_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32 + + 3600; + println!("exp_time: {}", exp_time); + let rest_token: RestToken = RestToken { + exp: exp_time as i32, + raw_token: "valid.token.payload".to_string(), + }; + let mqtt_token = MqttToken { + exp: exp_time, + raw_token: "valid.token.payload".to_string(), + }; + let mqtt_token_map = Arc::new(Mutex::new(HashMap::new())); + mqtt_token_map + .lock() + .await + .insert("test_client".to_string(), mqtt_token.clone()); + MqttTokenFetcher { + tenant_name: "test_tenant".to_string(), + rest_api_key: "test_api_key".to_string(), + rest_token: Mutex::new(rest_token), + rest_auth_url: "test_auth_url".to_string(), + mqtt_token: mqtt_token_map, + client: reqwest::Client::new(), + mqtt_auth_url: "test_auth_url".to_string(), + } + } + + #[tokio::test] + async fn test_mqtt_token_fetcher_new() { + let tenant_name = "test_tenant".to_string(); + let rest_api_key = "test_api_key".to_string(); + let platform = Platform::NpLz; + + let fetcher = MqttTokenFetcher::new(tenant_name, rest_api_key, platform); + + assert!(fetcher.mqtt_token.lock().await.is_empty()); + } + + #[tokio::test] + async fn test_mqtt_token_fetcher_new_with_client() { + let tenant_name = "test_tenant".to_string(); + let rest_api_key = "test_api_key".to_string(); + let platform = Platform::NpLz; + + let client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); + let fetcher = + MqttTokenFetcher::new_with_client(tenant_name, rest_api_key, platform, client); + + assert!(fetcher.mqtt_token.lock().await.is_empty()); + } + + #[tokio::test] + async fn test_fetch_new_mqtt_token() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server.mock("POST", "/rest_auth_url") + .with_status(200) + .with_body(r#"{"raw_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"}"#) + .create_async() + .await; + let _m2 = mockito_server.mock("POST", "/mqtt_auth_url") + .with_status(200) + .with_body(r#"{"mqtt_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#) + .create(); + + let client = reqwest::Client::new(); + let rest_token = RestToken { + raw_token: "initial_token".to_string(), + exp: 0, + }; + + let fetcher = MqttTokenFetcher { + client, + tenant_name: "test_tenant".to_string(), + rest_api_key: "test_api_key".to_string(), + mqtt_token: Arc::new(Mutex::new(HashMap::new())), + rest_auth_url: mockito_server.url() + "/rest_auth_url", + mqtt_auth_url: mockito_server.url() + "/mqtt_auth_url", 
+ rest_token: Mutex::new(rest_token), + }; + + let result = fetcher.fetch_new_mqtt_token("test_client_id", None).await; + println!("{:?}", result); + assert!(result.is_ok()); + let mqtt_token = result.unwrap(); + assert_eq!(mqtt_token.exp, 1); + } + + #[tokio::test] + async fn test_mqtt_token_fetcher_get_token() { + let fetcher = create_valid_fetcher().await; + let token = fetcher.get_token("test_client", None).await.unwrap(); + assert_eq!(token.raw_token, "valid.token.payload"); + } + + #[test] + fn test_actions_display() { + let action = Actions::Publish; + assert_eq!(action.to_string(), "Publish"); + let action = Actions::Subscribe; + assert_eq!(action.to_string(), "Subscribe"); + } + + #[test] + fn test_token_request_new() { + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + assert_eq!(request.id.len(), 64); + assert_eq!(request.tenant, "test_tenant"); + } + + #[tokio::test] + async fn test_send_success() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/mqtt_auth_url") + .match_header("Authorization", "Bearer test_token") + .match_body(Matcher::Json(json!({"key": "value"}))) + .with_status(200) + .with_body("success_response") + .create(); + + let client = reqwest::Client::new(); + let payload = json!({"key": "value"}); + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let result = request + .send( + &client, + &format!("{}/mqtt_auth_url", mockito_server.url()), + "Bearer test_token", + &payload, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "success_response"); + } + + #[tokio::test] + async fn test_send_failure() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/mqtt_auth_url") + .match_header("Authorization", "Bearer test_token") + .match_body(Matcher::Json(json!({"key": "value"}))) + .with_status(400) + .with_body("error_response") + .create(); + + let client = reqwest::Client::new(); + let payload = json!({"key": "value"}); + let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let result = request + .send( + &client, + &format!("{}/mqtt_auth_url", mockito_server.url()), + "Bearer test_token", + &payload, + ) + .await; + + assert!(result.is_err()); + if let Err(DshError::DshCallError { + url, + status_code, + error_body, + }) = result + { + assert_eq!(url, format!("{}/mqtt_auth_url", mockito_server.url())); + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body, "error_response"); + } else { + panic!("Expected DshCallError"); + } + } + + #[test] + fn test_claims_new() { + let resource = Resource::new( + "stream".to_string(), + "prefix".to_string(), + "topic".to_string(), + None, + ); + let action = "publish".to_string(); + + let claims = Claims::new(resource.clone(), action.clone()); + + assert_eq!(claims.resource.stream, "stream"); + assert_eq!(claims.action, "publish"); + } + + #[test] + fn test_resource_new() { + let resource = Resource::new( + "stream".to_string(), + "prefix".to_string(), + "topic".to_string(), + None, + ); + + assert_eq!(resource.stream, "stream"); + assert_eq!(resource.prefix, "prefix"); + assert_eq!(resource.topic, "topic"); + } + + #[test] + fn test_mqtt_token_is_valid() { + let raw_token = "valid.token.payload".to_string(); + let token = MqttToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32 + + 3600, + raw_token, + }; + + 
assert!(token.is_valid()); + } + #[test] + fn test_mqtt_token_is_invalid() { + let raw_token = "valid.token.payload".to_string(); + let token = MqttToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32, + raw_token, + }; + + assert!(!token.is_valid()); + } + + #[test] + fn test_rest_token_is_valid() { + let token = RestToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32 + + 3600, + raw_token: "valid.token.payload".to_string(), + }; + + assert!(token.is_valid()); + } + + #[test] + fn test_rest_token_is_invalid() { + let token = RestToken { + exp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i32, + raw_token: "valid.token.payload".to_string(), + }; + + assert!(!token.is_valid()); + } + + #[test] + fn test_rest_token_default_is_invalid() { + let token = RestToken::default(); + + assert!(!token.is_valid()); + } + + #[test] + fn test_extract_header_and_payload() { + let raw = "header.payload.signature"; + let result = extract_header_and_payload(raw).unwrap(); + assert_eq!(result, "payload"); + + let raw = "header.payload"; + let result = extract_header_and_payload(raw).unwrap(); + assert_eq!(result, "payload"); + + let raw = "header"; + let result = extract_header_and_payload(raw); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_fetch_token_success() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/auth_url") + .match_header("apikey", "test_api_key") + .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) + .with_status(200) + .with_body("test_token") + .create(); + + let client = reqwest::Client::new(); + let result = RestToken::fetch_token( + &client, + "test_tenant", + "test_api_key", + &format!("{}/auth_url", mockito_server.url()), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test_token"); + } + + #[tokio::test] + async fn test_fetch_token_failure() { + let mut mockito_server = mockito::Server::new_async().await; + let _m = mockito_server + .mock("POST", "/auth_url") + .match_header("apikey", "test_api_key") + .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) + .with_status(400) + .with_body("error_response") + .create(); + + let client = reqwest::Client::new(); + let result = RestToken::fetch_token( + &client, + "test_tenant", + "test_api_key", + &format!("{}/auth_url", mockito_server.url()), + ) + .await; + + assert!(result.is_err()); + if let Err(DshError::DshCallError { + url, + status_code, + error_body, + }) = result + { + assert_eq!(url, format!("{}/auth_url", mockito_server.url())); + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body, "error_response"); + } else { + panic!("Expected DshCallError"); + } } } diff --git a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs index 245a61d..de392bc 100644 --- a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs +++ b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs @@ -22,8 +22,8 @@ pub struct ProtocolTokenFetcher { rest_api_key: String, rest_token: RwLock, rest_auth_url: String, - mqtt_token: RwLock>, // Mapping from Client ID to MqttToken - mqtt_auth_url: String, + protocol_token: RwLock>, // Mapping from Client ID to MqttToken + protocol_auth_url: String, client: reqwest::Client, //token_lifetime: Option, // TODO: Implement option of passing token lifetime to request token for specific duration // port: Port or 
connection_type: Connection // TODO: Platform provides two connection options, current implemetation only provides connecting over SSL, enable WebSocket too @@ -89,7 +89,7 @@ impl ProtocolTokenFetcher { /// # Example /// /// ```no_run - /// use dsh_sdk::mqtt_token_fetcher::ProtocolTokenFetcher; + /// use dsh_sdk::protocol_adapters::ProtocolTokenFetcher; /// use dsh_sdk::Platform; /// /// # #[tokio::main] @@ -114,8 +114,8 @@ impl ProtocolTokenFetcher { rest_api_key: api_key, rest_token: RwLock::new(rest_token), rest_auth_url: platform.endpoint_rest_token().to_string(), - mqtt_token: RwLock::new(HashMap::new()), - mqtt_auth_url: platform.endpoint_mqtt_token().to_string(), + protocol_token: RwLock::new(HashMap::new()), + protocol_auth_url: platform.endpoint_protocol_token().to_string(), client, } } @@ -136,18 +136,23 @@ impl ProtocolTokenFetcher { client_id: &str, claims: Option>, ) -> Result { - match self.mqtt_token.write().await.entry(client_id.to_string()) { + match self + .protocol_token + .write() + .await + .entry(client_id.to_string()) + { Entry::Occupied(mut entry) => { - let mqtt_token = entry.get_mut(); - if !mqtt_token.is_valid() { - *mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?; + let protocol_token = entry.get_mut(); + if !protocol_token.is_valid() { + *protocol_token = self.fetch_new_protocol_token(client_id, claims).await?; }; - Ok(mqtt_token.clone()) + Ok(protocol_token.clone()) } Entry::Vacant(entry) => { - let mqtt_token = self.fetch_new_mqtt_token(client_id, claims).await?; - entry.insert(mqtt_token.clone()); - Ok(mqtt_token) + let protocol_token = self.fetch_new_protocol_token(client_id, claims).await?; + entry.insert(protocol_token.clone()); + Ok(protocol_token) } } } @@ -155,7 +160,7 @@ impl ProtocolTokenFetcher { /// Fetches a new MQTT token from the platform. /// /// This method handles token validation and fetching the token - async fn fetch_new_mqtt_token( + async fn fetch_new_protocol_token( &self, client_id: &str, claims: Option>, @@ -174,13 +179,13 @@ impl ProtocolTokenFetcher { let authorization_header = format!("Bearer {}", rest_token.raw_token); - let mqtt_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?; - let payload = serde_json::to_value(&mqtt_token_request)?; + let protocol_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?; + let payload = serde_json::to_value(&protocol_token_request)?; - let response = mqtt_token_request + let response = protocol_token_request .send( &self.client, - &self.mqtt_auth_url, + &self.protocol_auth_url, &authorization_header, &payload, ) @@ -286,12 +291,12 @@ impl MqttTokenRequest { async fn send( &self, reqwest_client: &reqwest::Client, - mqtt_auth_url: &str, + protocol_auth_url: &str, authorization_header: &str, payload: &serde_json::Value, ) -> Result { let response = reqwest_client - .post(mqtt_auth_url) + .post(protocol_auth_url) .header("Authorization", authorization_header) .json(payload) .send() @@ -301,7 +306,7 @@ impl MqttTokenRequest { Ok(response.text().await?) 
} else { Err(DshError::DshCallError { - url: mqtt_auth_url.to_string(), + url: protocol_auth_url.to_string(), status_code: response.status(), error_body: response.text().await?, }) @@ -528,39 +533,39 @@ mod tests { exp: exp_time as i32, raw_token: "valid.token.payload".to_string(), }; - let mqtt_token = MqttToken { + let protocol_token = MqttToken { exp: exp_time, raw_token: "valid.token.payload".to_string(), }; - let mqtt_token_map = RwLock::new(HashMap::new()); - mqtt_token_map + let protocol_token_map = RwLock::new(HashMap::new()); + protocol_token_map .write() .await - .insert("test_client".to_string(), mqtt_token.clone()); + .insert("test_client".to_string(), protocol_token.clone()); ProtocolTokenFetcher { tenant_name: "test_tenant".to_string(), rest_api_key: "test_api_key".to_string(), rest_token: RwLock::new(rest_token), rest_auth_url: "test_auth_url".to_string(), - mqtt_token: mqtt_token_map, + protocol_token: protocol_token_map, client: reqwest::Client::new(), - mqtt_auth_url: "test_auth_url".to_string(), + protocol_auth_url: "test_auth_url".to_string(), } } #[tokio::test] - async fn test_mqtt_token_fetcher_new() { + async fn test_protocol_token_fetcher_new() { let tenant_name = "test_tenant".to_string(); let rest_api_key = "test_api_key".to_string(); let platform = Platform::NpLz; let fetcher = ProtocolTokenFetcher::new(tenant_name, rest_api_key, platform); - assert!(fetcher.mqtt_token.read().await.is_empty()); + assert!(fetcher.protocol_token.read().await.is_empty()); } #[tokio::test] - async fn test_mqtt_token_fetcher_new_with_client() { + async fn test_protocol_token_fetcher_new_with_client() { let tenant_name = "test_tenant".to_string(); let rest_api_key = "test_api_key".to_string(); let platform = Platform::NpLz; @@ -569,20 +574,20 @@ mod tests { let fetcher = ProtocolTokenFetcher::new_with_client(tenant_name, rest_api_key, platform, client); - assert!(fetcher.mqtt_token.read().await.is_empty()); + assert!(fetcher.protocol_token.read().await.is_empty()); } #[tokio::test] - async fn test_fetch_new_mqtt_token() { + async fn test_fetch_new_protocol_token() { let mut mockito_server = mockito::Server::new_async().await; let _m = mockito_server.mock("POST", "/rest_auth_url") .with_status(200) .with_body(r#"{"raw_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"}"#) .create_async() .await; - let _m2 = mockito_server.mock("POST", "/mqtt_auth_url") + let _m2 = mockito_server.mock("POST", "/protocol_auth_url") .with_status(200) - .with_body(r#"{"mqtt_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#) + .with_body(r#"{"protocol_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#) .create(); let client = reqwest::Client::new(); @@ -595,21 +600,23 @@ mod tests { client, tenant_name: "test_tenant".to_string(), rest_api_key: "test_api_key".to_string(), - mqtt_token: RwLock::new(HashMap::new()), + protocol_token: 
RwLock::new(HashMap::new()), rest_auth_url: mockito_server.url() + "/rest_auth_url", - mqtt_auth_url: mockito_server.url() + "/mqtt_auth_url", + protocol_auth_url: mockito_server.url() + "/protocol_auth_url", rest_token: RwLock::new(rest_token), }; - let result = fetcher.fetch_new_mqtt_token("test_client_id", None).await; + let result = fetcher + .fetch_new_protocol_token("test_client_id", None) + .await; println!("{:?}", result); assert!(result.is_ok()); - let mqtt_token = result.unwrap(); - assert_eq!(mqtt_token.exp, 1); + let protocol_token = result.unwrap(); + assert_eq!(protocol_token.exp, 1); } #[tokio::test] - async fn test_mqtt_token_fetcher_get_token() { + async fn test_protocol_token_fetcher_get_token() { let fetcher = create_valid_fetcher().await; let token = fetcher.get_token("test_client", None).await.unwrap(); assert_eq!(token.raw_token, "valid.token.payload"); @@ -634,7 +641,7 @@ mod tests { async fn test_send_success() { let mut mockito_server = mockito::Server::new_async().await; let _m = mockito_server - .mock("POST", "/mqtt_auth_url") + .mock("POST", "/protocol_auth_url") .match_header("Authorization", "Bearer test_token") .match_body(Matcher::Json(json!({"key": "value"}))) .with_status(200) @@ -647,7 +654,7 @@ mod tests { let result = request .send( &client, - &format!("{}/mqtt_auth_url", mockito_server.url()), + &format!("{}/protocol_auth_url", mockito_server.url()), "Bearer test_token", &payload, ) @@ -661,7 +668,7 @@ mod tests { async fn test_send_failure() { let mut mockito_server = mockito::Server::new_async().await; let _m = mockito_server - .mock("POST", "/mqtt_auth_url") + .mock("POST", "/protocol_auth_url") .match_header("Authorization", "Bearer test_token") .match_body(Matcher::Json(json!({"key": "value"}))) .with_status(400) @@ -674,7 +681,7 @@ mod tests { let result = request .send( &client, - &format!("{}/mqtt_auth_url", mockito_server.url()), + &format!("{}/protocol_auth_url", mockito_server.url()), "Bearer test_token", &payload, ) @@ -687,7 +694,7 @@ mod tests { error_body, }) = result { - assert_eq!(url, format!("{}/mqtt_auth_url", mockito_server.url())); + assert_eq!(url, format!("{}/protocol_auth_url", mockito_server.url())); assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); assert_eq!(error_body, "error_response"); } else { @@ -726,7 +733,7 @@ mod tests { } #[test] - fn test_mqtt_token_is_valid() { + fn test_protocol_token_is_valid() { let raw_token = "valid.token.payload".to_string(); let token = MqttToken { exp: SystemTime::now() @@ -740,7 +747,7 @@ mod tests { assert!(token.is_valid()); } #[test] - fn test_mqtt_token_is_invalid() { + fn test_protocol_token_is_invalid() { let raw_token = "valid.token.payload".to_string(); let token = MqttToken { exp: SystemTime::now() diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs index 8b06635..ed539e6 100644 --- a/dsh_sdk/src/rest_api_token_fetcher.rs +++ b/dsh_sdk/src/rest_api_token_fetcher.rs @@ -1,17 +1,147 @@ +//! Module for fetching and storing access tokens for the DSH Rest API client +//! +//! This module is meant to be used together with the [dsh_rest_api_client]. +//! +//! The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. +//! +//! ## Example +//! Recommended usage is to use the [RestTokenFetcherBuilder] to create a new instance of the token fetcher. +//! However, you can also create a new instance of the token fetcher directly. +//! ```no_run +//! use dsh_sdk::{RestTokenFetcherBuilder, Platform}; +//! 
use dsh_rest_api_client::Client;
+//!
+//! const CLIENT_SECRET: &str = "";
+//! const TENANT: &str = "tenant-name";
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let platform = Platform::NpLz;
+//!     let client = Client::new(platform.endpoint_rest_api());
+//!
+//!     let tf = RestTokenFetcherBuilder::new(platform)
+//!         .tenant_name(TENANT.to_string())
+//!         .client_secret(CLIENT_SECRET.to_string())
+//!         .build()
+//!         .unwrap();
+//!
+//!     let response = client
+//!         .topic_get_by_tenant_topic(TENANT, &tf.get_token().await.unwrap())
+//!         .await;
+//!     println!("Available topics: {:#?}", response);
+//! }
+//! ```
+
+use std::fmt::Debug;
+use std::ops::Add;
+use std::sync::Mutex;
+use std::time::{Duration, Instant};
+
+use log::debug;
+use serde::Deserialize;
+
+use crate::management_api::error::ManagementTokenError as DshRestTokenError;
+use crate::utils::Platform;
+
+/// Access token of the authentication service of DSH.
+///
+/// This is the response when requesting a new access token.
+///
+/// ## Recommended usage
+/// Use [RestTokenFetcher::get_token] to get the bearer token; the `TokenFetcher` will automatically fetch a new token if the current one is not valid.
+#[derive(Debug, Clone, Deserialize)]
+pub struct AccessToken {
+    access_token: String,
+    expires_in: u64,
+    refresh_expires_in: u32,
+    token_type: String,
+    #[serde(rename(deserialize = "not-before-policy"))]
+    not_before_policy: u32,
+    scope: String,
+}
+
+impl AccessToken {
+    /// Get the formatted token
+    pub fn formatted_token(&self) -> String {
+        format!("{} {}", self.token_type, self.access_token)
+    }
+
+    /// Get the access token
+    pub fn access_token(&self) -> &str {
+        &self.access_token
+    }
+
+    /// Get the expires in of the access token
+    pub fn expires_in(&self) -> u64 {
+        self.expires_in
+    }
+
+    /// Get the refresh expires in of the access token
+    pub fn refresh_expires_in(&self) -> u32 {
+        self.refresh_expires_in
+    }
+
+    /// Get the token type of the access token
+    pub fn token_type(&self) -> &str {
+        &self.token_type
+    }
+
+    /// Get the not before policy of the access token
+    pub fn not_before_policy(&self) -> u32 {
+        self.not_before_policy
+    }
+
+    /// Get the scope of the access token
+    pub fn scope(&self) -> &str {
+        &self.scope
+    }
+}
+
+impl Default for AccessToken {
+    fn default() -> Self {
+        Self {
+            access_token: "".to_string(),
+            expires_in: 0,
+            refresh_expires_in: 0,
+            token_type: "".to_string(),
+            not_before_policy: 0,
+            scope: "".to_string(),
+        }
+    }
+}
+
 /// Fetch and store access tokens to be used in the DSH Rest API client
 ///
 /// This struct will fetch and store access tokens to be used in the DSH Rest API client.
 /// It will automatically fetch a new token if the current token is not valid.
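+/// Internally the cached token and its fetch time are kept behind `Mutex`es,
+/// so one fetcher can be shared; `get_token` returns the cached token while it
+/// is valid and transparently fetches a new one otherwise.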
-pub struct RestTokenFetcher; +pub struct RestTokenFetcher { + access_token: Mutex, + fetched_at: Mutex, + client_id: String, + client_secret: String, + client: reqwest::Client, + auth_url: String, +} impl RestTokenFetcher { /// Create a new instance of the token fetcher - pub fn new( - client_id: String, - client_secret: String, - auth_url: String, - ) -> crate::ManagementApiTokenFetcher { - crate::ManagementApiTokenFetcher::new_with_client( + /// + /// ## Example + /// ```no_run + /// use dsh_sdk::{RestTokenFetcher, Platform}; + /// use dsh_rest_api_client::Client; + /// + /// #[tokio::main] + /// async fn main() { + /// let platform = Platform::NpLz; + /// let client_id = platform.rest_client_id("my-tenant"); + /// let client_secret = "my-secret".to_string(); + /// let token_fetcher = RestTokenFetcher::new(client_id, client_secret, platform.endpoint_rest_access_token().to_string()); + /// let token = token_fetcher.get_token().await.unwrap(); + /// } + /// ``` + pub fn new(client_id: String, client_secret: String, auth_url: String) -> Self { + Self::new_with_client( client_id, client_secret, auth_url, @@ -20,30 +150,448 @@ impl RestTokenFetcher { } /// Create a new instance of the token fetcher with custom reqwest client + /// + /// ## Example + /// ```no_run + /// use dsh_sdk::{RestTokenFetcher, Platform}; + /// use dsh_rest_api_client::Client; + /// + /// #[tokio::main] + /// async fn main() { + /// let platform = Platform::NpLz; + /// let client_id = platform.rest_client_id("my-tenant"); + /// let client_secret = "my-secret".to_string(); + /// let client = reqwest::Client::new(); + /// let token_fetcher = RestTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_rest_access_token().to_string(), client); + /// let token = token_fetcher.get_token().await.unwrap(); + /// } + /// ``` pub fn new_with_client( client_id: String, client_secret: String, auth_url: String, client: reqwest::Client, - ) -> crate::ManagementApiTokenFetcher { - crate::ManagementApiTokenFetcher::new_with_client( + ) -> Self { + Self { + access_token: Mutex::new(AccessToken::default()), + fetched_at: Mutex::new(Instant::now()), client_id, client_secret, - auth_url, client, - ) + auth_url, + } + } + + /// Get token from the token fetcher + /// + /// If the cached token is not valid, it will fetch a new token from the server. + /// It will return the token as a string, formatted as "{token_type} {token}" + /// If the request fails for a new token, it will return a [DshRestTokenError::FailureTokenFetch] error. + /// This will contain the underlying reqwest error. + pub async fn get_token(&self) -> Result { + match self.is_valid() { + true => Ok(self.access_token.lock().unwrap().formatted_token()), + false => { + debug!("Token is expired, fetching new token"); + let access_token = self.fetch_access_token_from_server().await?; + let mut token = self.access_token.lock().unwrap(); + let mut fetched_at = self.fetched_at.lock().unwrap(); + *token = access_token; + *fetched_at = Instant::now(); + Ok(token.formatted_token()) + } + } + } + + /// Check if the current access token is still valid + /// + /// If the token has expired, it will return false. 
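+    ///
+    /// Note that a safety margin of 5 seconds is applied, so the token is
+    /// treated as expired slightly before `expires_in` has fully elapsed.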
+    pub fn is_valid(&self) -> bool {
+        let access_token = self.access_token.lock().unwrap_or_else(|mut e| {
+            **e.get_mut() = AccessToken::default();
+            self.access_token.clear_poison();
+            e.into_inner()
+        });
+        let fetched_at = self.fetched_at.lock().unwrap_or_else(|e| {
+            self.fetched_at.clear_poison();
+            e.into_inner()
+        });
+        // Check if expires in has elapsed (+ safety margin of 5 seconds)
+        fetched_at.elapsed().add(Duration::from_secs(5))
+            < Duration::from_secs(access_token.expires_in)
+    }
+
+    /// Fetch a new access token from the server
+    ///
+    /// This will fetch a new access token from the server and return it.
+    /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error.
+    /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error.
+    /// If the request is successful, it will return the [AccessToken].
+    pub async fn fetch_access_token_from_server(&self) -> Result<AccessToken, DshRestTokenError> {
+        let response = self
+            .client
+            .post(&self.auth_url)
+            .form(&[
+                ("client_id", self.client_id.as_ref()),
+                ("client_secret", self.client_secret.as_ref()),
+                ("grant_type", "client_credentials"),
+            ])
+            .send()
+            .await
+            .map_err(DshRestTokenError::FailureTokenFetch)?;
+        if !response.status().is_success() {
+            Err(DshRestTokenError::StatusCode {
+                status_code: response.status(),
+                error_body: response.text().await.unwrap_or_default(),
+            })
+        } else {
+            response
+                .json::<AccessToken>()
+                .await
+                .map_err(DshRestTokenError::FailureTokenFetch)
+        }
+    }
+}
+
+impl Debug for RestTokenFetcher {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RestTokenFetcher")
+            .field("access_token", &self.access_token)
+            .field("fetched_at", &self.fetched_at)
+            .field("client_id", &self.client_id)
+            .field("client_secret", &"xxxxxx")
+            .field("auth_url", &self.auth_url)
+            .finish()
+    }
+}
 
 /// Builder for the token fetcher
-pub struct RestTokenFetcherBuilder;
+pub struct RestTokenFetcherBuilder {
+    client: Option<reqwest::Client>,
+    client_id: Option<String>,
+    client_secret: Option<String>,
+    platform: Platform,
+    tenant_name: Option<String>,
+}
 
 impl RestTokenFetcherBuilder {
     /// Get a new instance of the ClientBuilder
     ///
     /// # Arguments
     /// * `platform` - The target platform to use for the token fetcher
-    pub fn new(platform: crate::Platform) -> crate::ManagementApiTokenFetcherBuilder {
-        crate::ManagementApiTokenFetcherBuilder::new(platform)
+    pub fn new(platform: Platform) -> Self {
+        Self {
+            client: None,
+            client_id: None,
+            client_secret: None,
+            platform,
+            tenant_name: None,
+        }
+    }
+
+    /// Set the client_id for the client
+    ///
+    /// Alternatively, set `tenant_name` to generate the client_id.
+    /// An explicitly set `client_id` takes precedence over `tenant_name`.
+    pub fn client_id(mut self, client_id: String) -> Self {
+        self.client_id = Some(client_id);
+        self
+    }
+
+    /// Set the client_secret for the client
+    pub fn client_secret(mut self, client_secret: String) -> Self {
+        self.client_secret = Some(client_secret);
+        self
+    }
+
+    /// Set the tenant_name for the client, this will generate the client_id
+    ///
+    /// Alternatively, set `client_id` directly.
+    /// Note that an explicitly set `client_id` takes precedence over `tenant_name`.
+    pub fn tenant_name(mut self, tenant_name: String) -> Self {
+        self.tenant_name = Some(tenant_name);
+        self
+    }
+
+    /// Provide a custom configured Reqwest client for the token fetcher
+    ///
+    /// This is optional, if not provided, a default client will be used.
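+    ///
+    /// ## Example
+    /// A sketch with a custom TLS client (mirrors the builder tests further below):
+    /// ```no_run
+    /// # use dsh_sdk::{RestTokenFetcherBuilder, Platform};
+    /// let custom_client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
+    /// let tf = RestTokenFetcherBuilder::new(Platform::NpLz)
+    ///     .tenant_name("my-tenant".to_string())
+    ///     .client_secret("my-secret".to_string())
+    ///     .client(custom_client)
+    ///     .build()
+    ///     .unwrap();
+    /// ```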
+    pub fn client(mut self, client: reqwest::Client) -> Self {
+        self.client = Some(client);
+        self
+    }
+
+    /// Build the client and token fetcher
+    ///
+    /// This will build the token fetcher based on the given parameters and
+    /// return the configured token fetcher.
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::{RestTokenFetcherBuilder, Platform};
+    /// let platform = Platform::NpLz;
+    /// let client_id = "robot:dev-lz-dsh:my-tenant".to_string();
+    /// let client_secret = "secret".to_string();
+    /// let tf = RestTokenFetcherBuilder::new(platform)
+    ///     .client_id(client_id)
+    ///     .client_secret(client_secret)
+    ///     .build()
+    ///     .unwrap();
+    /// ```
+    pub fn build(self) -> Result<RestTokenFetcher, DshRestTokenError> {
+        let client_secret = self
+            .client_secret
+            .ok_or(DshRestTokenError::UnknownClientSecret)?;
+        let client_id = self
+            .client_id
+            .or_else(|| {
+                self.tenant_name
+                    .as_ref()
+                    .map(|tenant_name| self.platform.rest_client_id(tenant_name))
+            })
+            .ok_or(DshRestTokenError::UnknownClientId)?;
+        let client = self.client.unwrap_or_default();
+        let token_fetcher = RestTokenFetcher::new_with_client(
+            client_id,
+            client_secret,
+            self.platform.endpoint_rest_access_token().to_string(),
+            client,
+        );
+        Ok(token_fetcher)
+    }
+}
+
+impl Debug for RestTokenFetcherBuilder {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let client_secret = self
+            .client_secret
+            .as_ref()
+            .map(|_| "Some(\"client_secret\")");
+        f.debug_struct("RestTokenFetcherBuilder")
+            .field("client_id", &self.client_id)
+            .field("client_secret", &client_secret)
+            .field("platform", &self.platform)
+            .field("tenant_name", &self.tenant_name)
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    fn create_mock_tf() -> RestTokenFetcher {
+        RestTokenFetcher {
+            access_token: Mutex::new(AccessToken::default()),
+            fetched_at: Mutex::new(Instant::now()),
+            client_id: "client_id".to_string(),
+            client_secret: "client_secret".to_string(),
+            client: reqwest::Client::new(),
+            auth_url: "http://localhost".to_string(),
+        }
+    }
+
+    #[test]
+    fn test_access_token() {
+        let token_str = r#"{
+            "access_token": "secret_access_token",
+            "expires_in": 600,
+            "refresh_expires_in": 0,
+            "token_type": "Bearer",
+            "not-before-policy": 0,
+            "scope": "email"
+        }"#;
+        let token: AccessToken = serde_json::from_str(token_str).unwrap();
+        assert_eq!(token.access_token(), "secret_access_token");
+        assert_eq!(token.expires_in(), 600);
+        assert_eq!(token.refresh_expires_in(), 0);
+        assert_eq!(token.token_type(), "Bearer");
+        assert_eq!(token.not_before_policy(), 0);
+        assert_eq!(token.scope(), "email");
+        assert_eq!(token.formatted_token(), "Bearer secret_access_token");
+    }
+
+    #[test]
+    fn test_access_token_default() {
+        let token = AccessToken::default();
+        assert_eq!(token.access_token(), "");
+        assert_eq!(token.expires_in(), 0);
+        assert_eq!(token.refresh_expires_in(), 0);
+        assert_eq!(token.token_type(), "");
+        assert_eq!(token.not_before_policy(), 0);
+        assert_eq!(token.scope(), "");
+        assert_eq!(token.formatted_token(), " ");
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_default_token() {
+        // Test is_valid when validating default token (should expire in 0 seconds)
+        let tf = create_mock_tf();
+        assert!(!tf.is_valid());
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_valid_token() {
+        let tf = create_mock_tf();
+        tf.access_token.lock().unwrap().expires_in = 600;
+        assert!(tf.is_valid());
+    }
+
+    #[test]
+    fn test_rest_token_fetcher_is_valid_expired_token() {
+        // Test is_valid when
validating an expired token + let tf = create_mock_tf(); + tf.access_token.lock().unwrap().expires_in = 600; + *tf.fetched_at.lock().unwrap() = Instant::now() - Duration::from_secs(600); + assert!(!tf.is_valid()); + } + + #[test] + fn test_rest_token_fetcher_is_valid_poisoned_token() { + // Test is_valid when token is poisoned + let tf = create_mock_tf(); + tf.access_token.lock().unwrap().expires_in = 600; + let tf_arc = std::sync::Arc::new(tf); + let tf_clone = tf_arc.clone(); + assert!(tf_arc.is_valid(), "Token should be valid"); + let h = std::thread::spawn(move || { + let _unused = tf_clone.access_token.lock().unwrap(); + panic!("Poison token") + }); + let _ = h.join(); + assert!(!tf_arc.is_valid(), "Token should be invalid"); + } + + #[tokio::test] + async fn test_fetch_access_token_from_server() { + let mut auth_server = mockito::Server::new_async().await; + auth_server + .mock("POST", "/") + .with_status(200) + .with_body( + r#"{ + "access_token": "secret_access_token", + "expires_in": 600, + "refresh_expires_in": 0, + "token_type": "Bearer", + "not-before-policy": 0, + "scope": "email" + }"#, + ) + .create(); + let mut tf = create_mock_tf(); + tf.auth_url = auth_server.url(); + let token = tf.fetch_access_token_from_server().await.unwrap(); + assert_eq!(token.access_token(), "secret_access_token"); + assert_eq!(token.expires_in(), 600); + assert_eq!(token.refresh_expires_in(), 0); + assert_eq!(token.token_type(), "Bearer"); + assert_eq!(token.not_before_policy(), 0); + assert_eq!(token.scope(), "email"); + } + + #[tokio::test] + async fn test_fetch_access_token_from_server_error() { + let mut auth_server = mockito::Server::new_async().await; + auth_server + .mock("POST", "/") + .with_status(400) + .with_body("Bad request") + .create(); + let mut tf = create_mock_tf(); + tf.auth_url = auth_server.url(); + let err = tf.fetch_access_token_from_server().await.unwrap_err(); + match err { + DshRestTokenError::StatusCode { + status_code, + error_body, + } => { + assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(error_body, "Bad request"); + } + _ => panic!("Unexpected error: {:?}", err), + } + } + + #[test] + fn test_token_fetcher_builder_client_id() { + let platform = Platform::NpLz; + let client_id = "robot:dev-lz-dsh:my-tenant"; + let client_secret = "secret"; + let tf = RestTokenFetcherBuilder::new(platform) + .client_id(client_id.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_tenant_name() { + let platform = Platform::NpLz; + let tenant_name = "my-tenant"; + let client_secret = "secret"; + let tf = RestTokenFetcherBuilder::new(platform) + .tenant_name(tenant_name.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!( + tf.client_id, + format!("robot:{}:{}", Platform::NpLz.realm(), tenant_name) + ); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_custom_client() { + let platform = Platform::NpLz; + let client_id = "robot:dev-lz-dsh:my-tenant"; + let client_secret = "secret"; + let custom_client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); + let tf = RestTokenFetcherBuilder::new(platform) + .client_id(client_id.to_string()) + 
.client_secret(client_secret.to_string()) + .client(custom_client.clone()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_client_id_precedence() { + let platform = Platform::NpLz; + let tenant = "my-tenant"; + let client_id_override = "override"; + let client_secret = "secret"; + let tf = RestTokenFetcherBuilder::new(platform) + .tenant_name(tenant.to_string()) + .client_id(client_id_override.to_string()) + .client_secret(client_secret.to_string()) + .build() + .unwrap(); + assert_eq!(tf.client_id, client_id_override); + assert_eq!(tf.client_secret, client_secret); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + } + + #[test] + fn test_token_fetcher_builder_build_error() { + let err = RestTokenFetcherBuilder::new(Platform::NpLz) + .client_secret("client_secret".to_string()) + .build() + .unwrap_err(); + assert!(matches!(err, DshRestTokenError::UnknownClientId)); + + let err = RestTokenFetcherBuilder::new(Platform::NpLz) + .tenant_name("tenant_name".to_string()) + .build() + .unwrap_err(); + assert!(matches!(err, DshRestTokenError::UnknownClientSecret)); } } diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs index 684c887..d364f8d 100644 --- a/dsh_sdk/src/utils/mod.rs +++ b/dsh_sdk/src/utils/mod.rs @@ -102,12 +102,22 @@ impl Platform { } } + #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")] /// Get the endpoint for fetching DSH Rest Authentication Token /// /// With this token you can authenticate for the mqtt token endpoint /// /// It will return the endpoint for DSH Rest authentication token based on the platform pub fn endpoint_rest_token(&self) -> &str { + self.endpoint_management_api_token() + } + + /// Get the endpoint for fetching DSH Rest Authentication Token + /// + /// With this token you can authenticate for the mqtt token endpoint + /// + /// It will return the endpoint for DSH Rest authentication token based on the platform + pub fn endpoint_management_api_token(&self) -> &str { match self { Self::Prod => "https://api.kpn-dsh.com/auth/v0/token", Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token", @@ -117,10 +127,18 @@ impl Platform { } } - /// Get the endpoint for fetching DSH MQTT token + #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")] + /// Get the endpoint for fetching DSH mqtt token /// /// It will return the endpoint for DSH MQTT Token based on the platform pub fn endpoint_mqtt_token(&self) -> &str { + self.endpoint_protocol_token() + } + + /// Get the endpoint for fetching DSH Protocol token + /// + /// It will return the endpoint for DSH Protocol adapter Token based on the platform + pub fn endpoint_protocol_token(&self) -> &str { match self { Self::Prod => "https://api.kpn-dsh.com/datastreams/v0/mqtt/token", Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/datastreams/v0/mqtt/token", From 2f6ec807194cf30d3081a4dd16a24f4c1271373a Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Thu, 9 Jan 2025 17:41:21 +0100 Subject: [PATCH 08/23] Untangle errors enum and update documentation (#105) * untangle errors * update documentation * update readme --- dsh_sdk/CHANGELOG.md | 19 +- dsh_sdk/README.md | 13 +- dsh_sdk/examples/protocol_token_fetcher.rs | 2 +- .../protocol_token_fetcher_specific_claims.rs | 2 +- 
dsh_sdk/src/certificates/bootstrap.rs | 33 +- dsh_sdk/src/certificates/error.rs | 26 + dsh_sdk/src/certificates/mod.rs | 88 ++- dsh_sdk/src/certificates/pki_config_dir.rs | 34 +- dsh_sdk/src/datastream/error.rs | 22 + .../src/{datastream.rs => datastream/mod.rs} | 70 ++- dsh_sdk/src/dsh.rs | 96 ++- dsh_sdk/src/dsh_old/mod.rs | 2 +- dsh_sdk/src/dsh_old/properties.rs | 2 - dsh_sdk/src/error.rs | 58 +- dsh_sdk/src/lib.rs | 86 +-- dsh_sdk/src/management_api/error.rs | 2 +- dsh_sdk/src/management_api/mod.rs | 43 +- dsh_sdk/src/management_api/token_fetcher.rs | 71 +-- dsh_sdk/src/metrics.rs | 4 +- dsh_sdk/src/mqtt_token_fetcher.rs | 2 +- dsh_sdk/src/protocol_adapters/error.rs | 19 + .../protocol_adapters/kafka_protocol/mod.rs | 19 +- dsh_sdk/src/protocol_adapters/mod.rs | 7 + .../protocol_adapters/token_fetcher/mod.rs | 91 ++- dsh_sdk/src/rest_api_token_fetcher.rs | 2 +- dsh_sdk/src/schema_store/api.rs | 66 +- dsh_sdk/src/schema_store/client.rs | 46 +- dsh_sdk/src/schema_store/mod.rs | 27 +- dsh_sdk/src/schema_store/request.rs | 21 +- dsh_sdk/src/utils/dlq.rs | 595 ------------------ dsh_sdk/src/utils/dlq/dlq.rs | 237 +++++++ dsh_sdk/src/utils/dlq/error.rs | 10 + dsh_sdk/src/utils/dlq/headers.rs | 221 +++++++ dsh_sdk/src/utils/dlq/mod.rs | 76 +++ dsh_sdk/src/utils/dlq/types.rs | 91 +++ dsh_sdk/src/utils/error.rs | 8 + dsh_sdk/src/utils/graceful_shutdown.rs | 4 +- dsh_sdk/src/utils/metrics.rs | 40 +- dsh_sdk/src/utils/mod.rs | 28 +- 39 files changed, 1180 insertions(+), 1103 deletions(-) create mode 100644 dsh_sdk/src/certificates/error.rs create mode 100644 dsh_sdk/src/datastream/error.rs rename dsh_sdk/src/{datastream.rs => datastream/mod.rs} (93%) create mode 100644 dsh_sdk/src/protocol_adapters/error.rs delete mode 100644 dsh_sdk/src/utils/dlq.rs create mode 100644 dsh_sdk/src/utils/dlq/dlq.rs create mode 100644 dsh_sdk/src/utils/dlq/error.rs create mode 100644 dsh_sdk/src/utils/dlq/headers.rs create mode 100644 dsh_sdk/src/utils/dlq/mod.rs create mode 100644 dsh_sdk/src/utils/dlq/types.rs create mode 100644 dsh_sdk/src/utils/error.rs diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index d822689..0620dfe 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -15,22 +15,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add support reading private key in DER format when reading from PKI_CONFIG_DIR ### Changed +- **Breaking change:** `DshError` is now split into error enums per feature flag to untangle mess + - `dsh_sdk::DshError` only applies on `bootstrap` feature flag - **Breaking change:** `dsh_sdk::Dsh::reqwest_client_config` now returns `reqwest::ClientConfig` instead of `Result` - **Breaking change:** `dsh_sdk::Dsh::reqwest_blocking_client_config` now returns `reqwest::ClientConfig` instead of `Result` - **Breaking change:** `dsh_sdk::utils::Dlq` does not require `Dsh`/`Properties` as argument anymore - -### Moved -- Moved `dsh_sdk::dsh::properties` to `dsh_sdk::propeties` +- Deprecated `dsh_sdk::dsh::properties` module - Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher` - - **NOTE** Cargo.toml feature flag falls now under `management_api` (`rest-token-fetcher` will be removed in v0.6.0) +- `dsh_sdk::error::DshRestTokenError` renamed to `dsh_sdk::management_api::error::ManagementApiTokenError` + - **NOTE** Cargo.toml feature flag `rest-token-fetcher` renamed to`management-api-token-fetcher` - Moved `dsh_sdk::dsh::datastreams` to 
`dsh_sdk::datastreams`
 - Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates`
   - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module
-- Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` where it is renamed to `ProtocolTokenFetcher`
-  - **NOTE** Cargo.toml feature flag falls now under `protocol-token-fetcher` (`mqtt-token-fetcher` will be removed in v0.6.0)
-- Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq`
-- Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown`
-- Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics`
+- Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` and renamed to `ProtocolTokenFetcher`
+  - Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher`
+- **Breaking change:** Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq`
+- **Breaking change:** Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown`
+- **Breaking change:** Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics`
 
 ### Removed
 - Removed `dsh_sdk::rdkafka` public re-export, import `rdkafka` directly
diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md
index d5ca76f..9a18035 100644
--- a/dsh_sdk/README.md
+++ b/dsh_sdk/README.md
@@ -42,9 +42,11 @@ use dsh_sdk::DshKafkaConfig;
 use rdkafka::consumer::{Consumer, StreamConsumer};
 use rdkafka::ClientConfig;
 
-fn main() -> Result<(), Box<dyn std::error::Error>>{
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>>{
     // get a rdkafka consumer config for example
     let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;
+    Ok(())
 }
 ```
 
@@ -59,15 +61,15 @@ The following features are available in this library and can be enabled/disabled
 
 | **feature** | **default** | **Description** | **Example** |
 | --- | --- | --- | --- |
-| `bootstrap` | ✓ | Certificate signing process and fetch datastreams info | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
+| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
 | `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
 | `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) |
 | `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) |
 | `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) |
 | `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [Token fetcher](./examples/management_api_token_fetcher.rs) |
-| `metrics` | ✗ | Enable prometheus metrics including http server | [expose metrics](./examples/expose_metrics.rs) / [custom metrics](./examples/custom_metrics.rs) |
+| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs) / [Custom metrics](./examples/custom_metrics.rs) |
 | `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) |
-| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementatione example](./examples/dlq_implementation.rs) |
+| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) |
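// A sketch of an import migration implied by the moves listed in the changelog
// above (old path commented out; `Shutdown` is the struct named in the SDK's
// graceful shutdown docs):
// use dsh_sdk::graceful_shutdown::Shutdown;    // up to 0.4.x
use dsh_sdk::utils::graceful_shutdown::Shutdown; // from 0.5.0 on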
 
 See the [api documentation](https://docs.rs/dsh_sdk/latest/dsh_sdk/) for more information on how to use these features.
 
@@ -80,7 +82,8 @@
 dsh_sdk = { version = "0.5", default-features = false, features = ["management-api-token-fetcher"] }
 ```
 
 ## Environment variables
-The SDK checks environment variables to change configuration for See [ENV_VARIABLES.md](ENV_VARIABLES.md) which .
+The SDK checks environment variables to change configuration for connecting to DSH.
+See [ENV_VARIABLES.md](ENV_VARIABLES.md) for a full overview.
 
 ## Examples
 See folder [dsh_sdk/examples](./examples/) for simple examples on how to use the SDK.
diff --git a/dsh_sdk/examples/protocol_token_fetcher.rs b/dsh_sdk/examples/protocol_token_fetcher.rs
index 73e09eb..7467acb 100644
--- a/dsh_sdk/examples/protocol_token_fetcher.rs
+++ b/dsh_sdk/examples/protocol_token_fetcher.rs
@@ -8,7 +8,7 @@ async fn main() {
     let api_key = env::var("API_KEY").unwrap().to_string();
     let mqtt_token_fetcher =
         ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);
-    let token: MqttToken = mqtt_token_fetcher
+    let token: ProtocolToken = mqtt_token_fetcher
         .get_token("Client-id", None) //Claims = None fetches all possible claims
         .await
         .unwrap();
diff --git a/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs b/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
index 4fb6ae2..c290bd4 100644
--- a/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
+++ b/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs
@@ -21,7 +21,7 @@ async fn main() {
     let mqtt_token_fetcher =
         ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz);
-    let token: MqttToken = mqtt_token_fetcher
+    let token: ProtocolToken = mqtt_token_fetcher
         .get_token("Client-id", Some(claims))
         .await
         .unwrap();
diff --git a/dsh_sdk/src/certificates/bootstrap.rs b/dsh_sdk/src/certificates/bootstrap.rs
index 746ce7b..f91b6c7 100644
--- a/dsh_sdk/src/certificates/bootstrap.rs
+++ b/dsh_sdk/src/certificates/bootstrap.rs
@@ -11,7 +11,7 @@ use reqwest::blocking::Client;
 
 use rcgen::{CertificateParams, CertificateSigningRequest, DnType, KeyPair};
 
-use crate::error::DshError;
+use super::CertificatesError;
 use super::Cert;
 use crate::utils;
@@ -22,7 +22,7 @@ pub(crate) fn bootstrap(
     config_host: &str,
     tenant_name: &str,
     task_id: &str,
-) -> Result<Cert, DshError> {
+) -> Result<Cert, CertificatesError> {
     let dsh_config = DshBootstrapConfig::new(config_host, tenant_name, task_id)?;
     let client = reqwest_ca_client(dsh_config.dsh_ca_certificate.as_bytes())?;
     let dn = DshBootstapCall::Dn(&dsh_config).retryable_call(&client)?;
@@ -46,14 +46,14 @@ fn get_signed_client_cert(
     dn: Dn,
     dsh_config: &DshBootstrapConfig,
     client: &Client,
-) -> Result<Cert, DshError> {
+) -> Result<Cert, CertificatesError> {
     let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P384_SHA384)?;
     let csr = generate_csr(&key_pair, dn)?;
     let client_certificate = DshBootstapCall::CertificateSignRequest {
         config: dsh_config,
         csr: &csr.pem()?,
     }
-    .perform_call(client)?;
+    .retryable_call(client)?;
     let ca_cert = pem::parse_many(&dsh_config.dsh_ca_certificate)?;
     let client_cert = pem::parse_many(client_certificate)?;
     Ok(Cert::new(
@@ -64,7 +64,10 @@
 }
 
 /// Generate the certificate signing request.
-fn generate_csr(key_pair: &KeyPair, dn: Dn) -> Result<CertificateSigningRequest, DshError> {
+fn generate_csr(
+    key_pair: &KeyPair,
+    dn: Dn,
+) -> Result<CertificateSigningRequest, CertificatesError> {
     let mut params = CertificateParams::default();
     params.distinguished_name.push(DnType::CommonName, dn.cn);
     params
@@ -86,7 +89,11 @@ struct DshBootstrapConfig<'a> {
     dsh_ca_certificate: String,
 }
 
 impl<'a> DshBootstrapConfig<'a> {
-    fn new(config_host: &'a str, tenant_name: &'a str, task_id: &'a str) -> Result<Self, DshError> {
+    fn new(
+        config_host: &'a str,
+        tenant_name: &'a str,
+        task_id: &'a str,
+    ) -> Result<Self, CertificatesError> {
         let dsh_secret_token = match utils::get_env_var(VAR_DSH_SECRET_TOKEN) {
             Ok(token) => token,
             Err(_) => {
@@ -147,10 +154,10 @@ impl DshBootstapCall<'_> {
         }
     }
 
-    fn perform_call(&self, client: &Client) -> Result<String, DshError> {
+    fn perform_call(&self, client: &Client) -> Result<String, CertificatesError> {
         let response = self.request_builder(client).send()?;
         if !response.status().is_success() {
-            return Err(DshError::DshCallError {
+            return Err(CertificatesError::DshCallError {
                 url: self.url(),
                 status_code: response.status(),
                 error_body: response.text().unwrap_or_default(),
@@ -159,7 +166,7 @@
         Ok(response.text()?)
     }
 
-    pub(crate) fn retryable_call(&self, client: &Client) -> Result<String, DshError> {
+    pub(crate) fn retryable_call(&self, client: &Client) -> Result<String, CertificatesError> {
         let mut retries = 0;
         loop {
             match self.perform_call(client) {
@@ -193,7 +200,7 @@ struct Dn {
 
 impl Dn {
     /// Parse the DN string into Dn struct.
-    fn parse_string(dn_string: &str) -> Result<Self, DshError> {
+    fn parse_string(dn_string: &str) -> Result<Self, CertificatesError> {
         let mut cn = None;
         let mut ou = None;
         let mut o = None;
@@ -211,13 +218,13 @@
         }
 
         Ok(Dn {
-            cn: cn.ok_or(DshError::ParseDnError(
+            cn: cn.ok_or(CertificatesError::ParseDn(
                 "CN is missing in DN string".to_string(),
             ))?,
-            ou: ou.ok_or(DshError::ParseDnError(
+            ou: ou.ok_or(CertificatesError::ParseDn(
                 "OU is missing in DN string".to_string(),
            ))?,
-            o: o.ok_or(DshError::ParseDnError(
+            o: o.ok_or(CertificatesError::ParseDn(
                 "O is missing in DN string".to_string(),
             ))?,
         })
diff --git a/dsh_sdk/src/certificates/error.rs b/dsh_sdk/src/certificates/error.rs
new file mode 100644
index 0000000..fff3ace
--- /dev/null
+++ b/dsh_sdk/src/certificates/error.rs
@@ -0,0 +1,26 @@
+/// Errors related to certificates
+#[derive(Debug, thiserror::Error)]
+pub enum CertificatesError {
+    #[error("Certificates are not set")]
+    NoCertificates,
+    #[error("Missing required injected variables")]
+    MisisngInjectedVariables,
+    #[error("Error parsing: {0}")]
+    ParseDn(String),
+    #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")]
+    DshCallError {
+        url: String,
+        status_code: reqwest::StatusCode,
+        error_body: String,
+    },
+    #[error("IO Error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("Rcgen error: {0}")]
+    PrivateKey(#[from] rcgen::Error),
+    #[error("Invalid PEM certificate: {0}")]
+    PemError(#[from] pem::PemError),
+    #[error("Utils error: {0}")]
+    UtilsError(#[from] crate::utils::UtilsError),
+    #[error("Reqwest: {0}")]
+    Reqwest(#[from] reqwest::Error),
+}
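// A minimal error-handling sketch against the new per-module enum (assumed
// usage, not part of this patch; the variant names are taken from error.rs
// above, including the `MisisngInjectedVariables` spelling as shipped):
use dsh_sdk::certificates::{Cert, CertificatesError};

fn load_cert() -> Result<Cert, CertificatesError> {
    match Cert::from_env() {
        Err(CertificatesError::MisisngInjectedVariables) => {
            // Not running on DSH; fall back to a local PKI config directory.
            Cert::from_pki_config_dir(Some("./pki"))
        }
        other => other,
    }
}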
diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs
index 3976cc7..9ca97e3 100644
--- a/dsh_sdk/src/certificates/mod.rs
+++ b/dsh_sdk/src/certificates/mod.rs
@@ -1,4 +1,4 @@
-//! This module holds the certificate struct and its methods.
+//! Handle DSH Certificates and bootstrap process
 //!
 //! The certificate struct holds the DSH CA certificate, the DSH Kafka certificate and
 //! the private key. It also has methods to create a reqwest client with the DSH Kafka
@@ -21,23 +21,21 @@
 //! # Ok(())
 //! # }
 //! ```
-//!
-//! ## Reqwest Client
-//! With this request client we can retrieve datastreams.json and connect to Schema Registry.
 
 use std::path::PathBuf;
 use std::sync::Arc;
 
-use log::{info, warn};
+use log::info;
 use rcgen::KeyPair;
 use reqwest::blocking::{Client, ClientBuilder};
 
-use crate::error::DshError;
+#[doc(inline)]
+pub use error::CertificatesError;
+
 use crate::utils;
-use crate::{DEFAULT_CONFIG_HOST, VAR_KAFKA_CONFIG_HOST, VAR_PKI_CONFIG_DIR, VAR_TASK_ID};
+use crate::{VAR_KAFKA_CONFIG_HOST, VAR_TASK_ID};
 
-#[cfg(feature = "bootstrap")]
 mod bootstrap;
-#[cfg(feature = "bootstrap")]
+mod error;
 mod pki_config_dir;
 
 /// Hold all relevant certificates and keys to connect to DSH Kafka Cluster and Schema Store.
@@ -49,7 +47,7 @@ pub struct Cert {
 }
 
 impl Cert {
-    /// Create new `Cert` struct
+    /// Create new [Cert] struct
     fn new(
         dsh_ca_certificate_pem: String,
         dsh_client_certificate_pem: String,
@@ -71,16 +69,15 @@
     /// This method also allows you to easily switch between Kafka Proxy or VPN connection, based on `PKI_CONFIG_DIR` environment variable.
     ///
     /// ## Arguments
-    /// * `config_host` - The DSH config host where the CSR can be send to. (default: `"https://pikachu.dsh.marathon.mesos:4443"`)
+    /// * `config_host` - The DSH config host where the CSR can be sent to.
     /// * `tenant_name` - The tenant name.
     /// * `task_id` - The task id of the running container.
-    #[cfg(feature = "bootstrap")]
     pub fn from_bootstrap(
         config_host: &str,
         tenant_name: &str,
         task_id: &str,
-    ) -> Result<Self, DshError> {
-        bootstrap::bootstrap(config_host, tenant_name, task_id)
+    ) -> Result<Self, CertificatesError> {
+        bootstrap::bootstrap(&config_host, tenant_name, task_id)
     }
 
     /// Bootstrap to DSH and sign the certificates based on the environment variables injected by DSH.
@@ -90,23 +87,17 @@
     ///
     /// Else it will check `KAFKA_CONFIG_HOST`, `MESOS_TASK_ID` and `MARATHON_APP_ID` environment variables to bootstrap to DSH and sign the certificates.
     /// These environment variables are injected by DSH.
-    #[cfg(feature = "bootstrap")]
-    pub fn from_env() -> Result<Self, DshError> {
-        if let Ok(path) = utils::get_env_var(VAR_PKI_CONFIG_DIR) {
-            Self::from_pki_config_dir(Some(path))
-        } else {
-            let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST)
-                .map(|host| ensure_https_prefix(host))
-                .unwrap_or_else(|_| {
-                    warn!(
-                        "{} is not set, using default value {}",
-                        VAR_KAFKA_CONFIG_HOST, DEFAULT_CONFIG_HOST
-                    );
-                    DEFAULT_CONFIG_HOST.to_string()
-                });
-            let task_id = utils::get_env_var(VAR_TASK_ID)?;
-            let tenant_name = utils::tenant_name()?;
+    pub fn from_env() -> Result<Self, CertificatesError> {
+        if let Ok(cert) = Self::from_pki_config_dir::<PathBuf>(None) {
+            Ok(cert)
+        } else if let (Ok(config_host), Ok(task_id), Ok(tenant_name)) = (
+            utils::get_env_var(VAR_KAFKA_CONFIG_HOST),
+            utils::get_env_var(VAR_TASK_ID),
+            utils::tenant_name(),
+        ) {
             Self::from_bootstrap(&config_host, &tenant_name, &task_id)
+        } else {
+            Err(CertificatesError::MisisngInjectedVariables)
         }
     }
 
@@ -124,8 +115,7 @@
     /// ## Note
     /// Only certificates in PEM format are supported.
     /// Key files should be in PKCS8 format and can be DER or PEM files.
-    #[cfg(feature = "bootstrap")]
-    pub fn from_pki_config_dir<P>(path: Option<P>) -> Result<Self, DshError>
+    pub fn from_pki_config_dir<P>(path: Option<P>) -> Result<Self, CertificatesError>
     where
         P: AsRef<std::path::Path>,
     {
@@ -134,10 +124,6 @@ impl Cert {
 
     /// Build an async reqwest client with the DSH Kafka certificate included.
     /// With this client we can retrieve datastreams.json and connect to Schema Registry.
-    #[deprecated(
-        since = "0.5.0",
-        note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead"
-    )]
     pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder {
         let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client(
             self.dsh_kafka_certificate_pem(),
@@ -152,10 +138,6 @@ impl Cert {
 
     /// Build a reqwest client with the DSH Kafka certificate included.
     /// With this client we can retrieve datastreams.json and connect to Schema Registry.
-    #[deprecated(
-        since = "0.5.0",
-        note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead"
-    )]
     pub fn reqwest_blocking_client_config(&self) -> ClientBuilder {
         let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client(
             self.dsh_kafka_certificate_pem(),
@@ -205,17 +187,17 @@ impl Cert {
     /// # Example
     ///
     /// ```no_run
-    /// use dsh_sdk::Properties;
+    /// use dsh_sdk::certificates::Cert;
     /// use std::path::PathBuf;
     ///
     /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let dsh_properties = Properties::get();
-    /// let directory = PathBuf::from("dir");
-    /// dsh_properties.certificates()?.to_files(&directory)?;
+    /// let certificates = Cert::from_env()?;
+    /// let directory = PathBuf::from("path/to/dir");
+    /// certificates.to_files(&directory)?;
     /// # Ok(())
     /// # }
     /// ```
-    pub fn to_files(&self, dir: &PathBuf) -> Result<(), DshError> {
+    pub fn to_files(&self, dir: &PathBuf) -> Result<(), CertificatesError> {
         std::fs::create_dir_all(dir)?;
         Self::create_file(dir.join("ca.crt"), self.dsh_ca_certificate_pem())?;
         Self::create_file(dir.join("client.pem"), self.dsh_kafka_certificate_pem())?;
@@ -223,7 +205,7 @@ impl Cert {
         Ok(())
     }
 
-    fn create_file<C: AsRef<[u8]>>(path: PathBuf, contents: C) -> Result<(), DshError> {
+    fn create_file<C: AsRef<[u8]>>(path: PathBuf, contents: C) -> Result<(), CertificatesError> {
         std::fs::write(&path, contents)?;
         info!("File created ({})", path.display());
         Ok(())
 }
 
 /// Helper function to ensure that the host starts with `https://` (or `http://`)
-fn ensure_https_prefix(host: String) -> String {
-    if host.starts_with("https://") || host.starts_with("http://") {
-        host
+pub(crate) fn ensure_https_prefix(host: impl AsRef<str>) -> String {
+    if host.as_ref().starts_with("http://") || host.as_ref().starts_with("https://") {
+        host.as_ref().to_string()
     } else {
-        format!("https://{}", host)
+        format!("https://{}", host.as_ref())
     }
 }
 
@@ -366,15 +348,15 @@ mod tests {
 
     #[test]
     fn test_ensure_https_prefix() {
-        let host = "http://example.com".to_string();
+        let host = "http://example.com";
         let result = ensure_https_prefix(host);
         assert_eq!(result, "http://example.com");
 
-        let host = "https://example.com".to_string();
+        let host = "https://example.com";
         let result = ensure_https_prefix(host);
         assert_eq!(result, "https://example.com");
 
-        let host = "example.com".to_string();
+        let host = "example.com";
         let result = ensure_https_prefix(host);
         assert_eq!(result, "https://example.com");
     }
diff --git a/dsh_sdk/src/certificates/pki_config_dir.rs b/dsh_sdk/src/certificates/pki_config_dir.rs
index a2665f6..7e6c663 100644
--- a/dsh_sdk/src/certificates/pki_config_dir.rs
+++ b/dsh_sdk/src/certificates/pki_config_dir.rs
@@ -6,7 +6,7 @@
 //! This also makes it possible to use the DSH SDK with Kafka Proxy
 //!
or VPN outside of the DSH environment. use super::Cert; -use crate::error::DshError; +use super::CertificatesError; use crate::{utils, VAR_PKI_CONFIG_DIR}; use log::{debug, info, warn}; @@ -14,7 +14,7 @@ use pem::{self, Pem}; use rcgen::KeyPair; use std::path::{Path, PathBuf}; -pub(crate) fn get_pki_certificates
<P>(pki_config_dir: Option<P>) -> Result<Cert, DshError>
+pub(crate) fn get_pki_certificates<P>(pki_config_dir: Option<P>) -> Result<Cert, CertificatesError>
 where
     P: AsRef<Path>,
 {
@@ -38,7 +38,7 @@
 /// Get certificate from the PKI config directory
 ///
 /// Looks for all files containing client*.pem and client.crt in the PKI config directory.
-fn get_certificate(cert_paths: Vec<PathBuf>) -> Result<Vec<Pem>, DshError> {
+fn get_certificate(cert_paths: Vec<PathBuf>) -> Result<Vec<Pem>, CertificatesError> {
     'file: for file in cert_paths {
         info!("{} - Reading certificate file", file.display());
         if let Ok(ca_cert) = std::fs::read(&file) {
@@ -63,13 +63,13 @@
         }
     }
     info!("No (valid) certificates found in the PKI config directory");
-    Err(DshError::NoCertificates)
+    Err(CertificatesError::NoCertificates)
 }
 
 /// Get key pair from a file in the PKI config directory
 ///
 /// Returns the first successfully converted key pair found in the given list of paths
-fn get_key_pair(key_paths: Vec<PathBuf>) -> Result<KeyPair, DshError> {
+fn get_key_pair(key_paths: Vec<PathBuf>) -> Result<KeyPair, CertificatesError> {
     for file in key_paths {
         info!("{} - Reading key file", file.display());
         if let Ok(bytes) = std::fs::read(&file) {
@@ -91,14 +91,14 @@
         }
     }
     info!("No (valid) key found in the PKI config directory");
-    Err(DshError::NoCertificates)
+    Err(CertificatesError::NoCertificates)
 }
 
 /// Get the paths to the files in the PKI config directory
 fn get_file_path_bufs<P>
(
     prefix: &str,
     contains: PkiFileType,
     config_dir: P,
-) -> Result<Vec<PathBuf>, DshError>
+) -> Result<Vec<PathBuf>, CertificatesError>
 where
     P: AsRef<Path>,
 {
@@ -228,11 +228,11 @@ mod tests {
         assert_eq!(result.len(), 1);
         assert_eq!(result[0].tag(), "CERTIFICATE");
         let result = get_certificate(vec![path_key]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_certificate(vec![]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_certificate(vec![path_ne]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
     }
 
     #[test]
@@ -249,11 +249,11 @@
         assert!(result.is_ok());
         let result =
             get_key_pair(vec![path_ne.clone(), path_cert.clone(), path_ca.clone()]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_key_pair(vec![]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_key_pair(vec![path_ne]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
     }
 
    #[test]
@@ -270,11 +270,11 @@
         assert!(result.is_ok());
         let result =
             get_key_pair(vec![path_ne.clone(), path_cert.clone(), path_ca.clone()]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_key_pair(vec![]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
         let result = get_key_pair(vec![path_ne]).unwrap_err();
-        assert!(matches!(result, DshError::NoCertificates));
+        assert!(matches!(result, CertificatesError::NoCertificates));
     }
 
     #[test]
@@ -282,7 +282,7 @@
     fn test_get_pki_cert() {
         create_test_pki_config_dir();
         let result = get_pki_certificates::<PathBuf>(None).unwrap_err();
-        assert!(matches!(result, DshError::EnvVarError(_, _)));
+        assert!(matches!(result, CertificatesError::UtilsError(_)));
         std::env::set_var(VAR_PKI_CONFIG_DIR, PKI_CONFIG_DIR);
         let result = get_pki_certificates::<PathBuf>(None);
         assert!(result.is_ok());
diff --git a/dsh_sdk/src/datastream/error.rs b/dsh_sdk/src/datastream/error.rs
new file mode 100644
index 0000000..53042f2
--- /dev/null
+++ b/dsh_sdk/src/datastream/error.rs
@@ -0,0 +1,22 @@
+/// Errors related to datastreams
+#[derive(Debug, thiserror::Error)]
+pub enum DatastreamError {
+    #[error("Error getting group id, index out of bounds for {0}")]
+    IndexGroupIdError(crate::datastream::GroupType),
+    #[error("Error getting topic name {0}, Topic not found in datastreams.")]
+    NotFoundTopicError(String),
+    #[error("Error in topic permissions: {0} does not have {1:?} permissions.")]
+    TopicPermissionsError(String, crate::datastream::ReadWriteAccess),
+    #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")]
+    DshCallError {
+        url: String,
+        status_code: reqwest::StatusCode,
+        error_body: String,
+    },
+    #[error("IO Error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("Serde_json error: {0}")]
+    JsonError(#[from] serde_json::Error),
+    #[error("Reqwest: {0}")]
+    Reqwest(#[from] reqwest::Error),
+}
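// Sketch: the new datastream error surfacing in a consumer-group lookup
// (hypothetical helper; `get_group_id` and `GroupType` are shown in
// datastream/mod.rs below):
use dsh_sdk::datastream::{Datastream, DatastreamError, GroupType};

fn shared_group_id(datastream: &Datastream) -> Result<&str, DatastreamError> {
    // Returns DatastreamError::IndexGroupIdError when index 0 does not exist
    datastream.get_group_id(GroupType::Shared(0))
}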
diff --git a/dsh_sdk/src/datastream.rs b/dsh_sdk/src/datastream/mod.rs
similarity index 93%
rename from dsh_sdk/src/datastream.rs
rename to dsh_sdk/src/datastream/mod.rs
index 99c020c..2d027d8 100644
--- a/dsh_sdk/src/datastream.rs
+++ b/dsh_sdk/src/datastream/mod.rs
@@ -1,19 +1,23 @@
-//! Module to handle the datastreams.json file.
+//! Datastream properties
 //!
 //! The datastreams.json can be parsed into a Datastream struct using serde_json.
-//! This struct contains all the information from the datastreams.json file.
-//!
-//! You can get the Datastream struct via the 'Properties' struct.
+//! This struct contains all the information from the datastreams properties file.
 //!
 //! # Example
-//! ```
-//! use dsh_sdk::Properties;
+//! ```no_run
+//! use dsh_sdk::Dsh;
 //!
-//! let properties = Properties::get();
-//! let datastream = properties.datastream();
+//! # #[tokio::main]
+//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
+//! let dsh = Dsh::get();
+//! let datastream = dsh.datastream(); // immutable datastream, fetched at initialization of the SDK
+//! // Or
+//! let datastream = dsh.fetch_datastream().await?; // fetch a fresh datastream from the DSH server
 //!
 //! let brokers = datastream.get_brokers();
-//! let schema_store = datastream.schema_store();
+//! let schema_store_url = datastream.schema_store();
+//! # Ok(())
+//! # }
 //! ```
 use std::collections::HashMap;
 use std::env;
@@ -23,26 +27,31 @@ use std::io::Read;
 use log::{debug, error, info};
 use serde::{Deserialize, Serialize};
 
-use crate::error::DshError;
 use crate::{
     utils, VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_KAFKA_CONSUMER_GROUP_TYPE, VAR_LOCAL_DATASTREAMS_JSON,
     VAR_SCHEMA_REGISTRY_HOST,
 };
 
+#[doc(inline)]
+pub use error::DatastreamError;
+
+mod error;
 
 const FILE_NAME: &str = "local_datastreams.json";
 
-/// This struct is equivalent to the datastreams.json
+/// Datastream properties file
+///
+/// Read from datastreams.json
 ///
 /// # Example
 /// ```
-/// use dsh_sdk::Properties;
+/// use dsh_sdk::Dsh;
 ///
-/// let properties = Properties::get();
+/// let properties = Dsh::get();
 /// let datastream = properties.datastream();
 ///
 /// let brokers = datastream.get_brokers();
 /// let streams = datastream.streams();
-/// let schema_store = datastream.schema_store();
+/// let schema_store_url = datastream.schema_store();
 /// ```
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct Datastream {
@@ -70,14 +79,14 @@ impl Datastream {
     /// # Error
     /// If the index is greater than the amount of groups in the datastreams
     /// (index out of bounds)
-    pub fn get_group_id(&self, group_type: GroupType) -> Result<&str, DshError> {
+    pub fn get_group_id(&self, group_type: GroupType) -> Result<&str, DatastreamError> {
         let group_id = match group_type {
             GroupType::Private(i) => self.private_consumer_groups.get(i),
             GroupType::Shared(i) => self.shared_consumer_groups.get(i),
         };
         match group_id {
             Some(id) => Ok(id),
-            None => Err(DshError::IndexGroupIdError(group_type)),
+            None => Err(DatastreamError::IndexGroupIdError(group_type)),
         }
     }
 
@@ -100,7 +109,7 @@
     pub fn verify_list_of_topics<T: std::fmt::Display>(
         &self,
         topics: &Vec<T>,
         access: ReadWriteAccess,
-    ) -> Result<(), DshError> {
+    ) -> Result<(), DatastreamError> {
         let read_topics = self
             .streams()
             .values()
@@ -129,7 +138,7 @@
             .collect::<Vec<String>>()
             .join(".");
         if !read_topics.contains(&topic_name) {
-            return Err(DshError::NotFoundTopicError(topic.to_string()));
+            return Err(DatastreamError::NotFoundTopicError(topic.to_string()));
         }
     }
     Ok(())
 }
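// Sketch: checking read permission for a list of topics with the API above
// (the topic name is a placeholder):
use dsh_sdk::datastream::{Datastream, ReadWriteAccess};

fn can_read_all(datastream: &Datastream) -> bool {
    let topics = vec!["scratch.local.local-tenant".to_string()];
    datastream
        .verify_list_of_topics(&topics, ReadWriteAccess::Read)
        .is_ok()
}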
@@ -156,7 +165,7 @@
     ///     let path = std::path::PathBuf::from("/path/to/directory");
     ///     datastream.to_file(&path).unwrap();
     /// ```
-    pub fn to_file(&self, path: &std::path::Path) -> Result<(), DshError> {
+    pub fn to_file(&self, path: &std::path::Path) -> Result<(), DatastreamError> {
         let json_string = serde_json::to_string_pretty(self)?;
         std::fs::write(path.join("datastreams.json"), json_string)?;
         info!("File created ({})", path.display());
@@ -172,11 +181,11 @@
         host: &str,
         tenant: &str,
         task_id: &str,
-    ) -> Result<Self, DshError> {
+    ) -> Result<Self, DatastreamError> {
         let url = Self::datastreams_endpoint(host, tenant, task_id);
         let response = client.get(&url).send().await?;
         if !response.status().is_success() {
-            return Err(DshError::DshCallError {
+            return Err(DatastreamError::DshCallError {
                 url,
                 status_code: response.status(),
                 error_body: response.text().await.unwrap_or_default(),
@@ -194,11 +203,11 @@
         host: &str,
         tenant: &str,
         task_id: &str,
-    ) -> Result<Self, DshError> {
+    ) -> Result<Self, DatastreamError> {
         let url = Self::datastreams_endpoint(host, tenant, task_id);
         let response = client.get(&url).send()?;
         if !response.status().is_success() {
-            return Err(DshError::DshCallError {
+            return Err(DatastreamError::DshCallError {
                 url,
                 status_code: response.status(),
                 error_body: response.text().unwrap_or_default(),
@@ -215,7 +224,7 @@
     /// If it does not parse or the file is not found based on the environment variable, it will panic.
     /// If the environment variable is not set, it will look in the current directory. If it is not found,
     /// it will return an Error in the Result; in that case default Datastreams are used.
-    pub(crate) fn load_local_datastreams() -> Result<Self, DshError> {
+    pub(crate) fn load_local_datastreams() -> Result<Self, DatastreamError> {
         let path_buf = if let Ok(path) = utils::get_env_var(VAR_LOCAL_DATASTREAMS_JSON) {
             let path = std::path::PathBuf::from(path);
             if !path.exists() {
@@ -233,7 +242,7 @@
                 path_buf.display(),
                 e
             );
-            DshError::IoError(e)
+            DatastreamError::IoError(e)
         })?;
         let mut contents = String::new();
         file.read_to_string(&mut contents).unwrap();
@@ -352,11 +361,11 @@ impl Stream {
     ///
     /// ## Error
     /// If the topic does not have read access it returns a `TopicPermissionsError`
-    pub fn read_pattern(&self) -> Result<&str, DshError> {
+    pub fn read_pattern(&self) -> Result<&str, DatastreamError> {
         if self.read_access() {
             Ok(&self.read)
         } else {
-            Err(DshError::TopicPermissionsError(
+            Err(DatastreamError::TopicPermissionsError(
                 self.name.clone(),
                 ReadWriteAccess::Read,
             ))
         }
@@ -367,11 +376,11 @@
     ///
     /// ## Error
     /// If the topic does not have write access it returns a `TopicPermissionsError`
-    pub fn write_pattern(&self) -> Result<&str, DshError> {
+    pub fn write_pattern(&self) -> Result<&str, DatastreamError> {
         if self.write_access() {
             Ok(&self.write)
         } else {
-            Err(DshError::TopicPermissionsError(
+            Err(DatastreamError::TopicPermissionsError(
                 self.name.clone(),
                 ReadWriteAccess::Write,
             ))
         }
@@ -386,6 +395,7 @@ pub enum ReadWriteAccess {
     Write,
 }
 
+/// Enum to indicate the group type (private or shared)
 #[derive(Debug, PartialEq)]
 pub enum GroupType {
     Private(usize),
@@ -670,7 +680,7 @@ mod tests {
 
         assert!(matches!(
             e,
-            DshError::TopicPermissionsError(_, ReadWriteAccess::Write)
+            DatastreamError::TopicPermissionsError(_, ReadWriteAccess::Write)
         ));
     }
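// Sketch: resolving a subscribe pattern via the Stream accessors above
// (hypothetical helper; the stream name is a placeholder and `get_stream`
// is assumed to return an Option as in the SDK docs):
use dsh_sdk::datastream::Datastream;

fn subscribe_pattern(datastream: &Datastream) -> Option<&str> {
    let stream = datastream.get_stream("scratch.local.local-tenant")?;
    stream.read_pattern().ok()
}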
diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs
index 182d24d..ce00f73 100644
--- a/dsh_sdk/src/dsh.rs
+++ b/dsh_sdk/src/dsh.rs
@@ -1,35 +1,36 @@
-//! # Dsh
+//! High-level API to interact with DSH.
 //!
-//! This module contains the High-level struct for all related
-//!
-//! From `Dsh` there are level functions to get the correct config to connect to Kafka and schema store.
+//! From [Dsh] there are high-level functions to get the correct config to connect to Kafka and the schema store.
 //! For more low level functions, see
-//! - [datastream](datastream/index.html) module.
-//! - [certificates](certificates/index.html) module.
+//! - [datastream] module.
+//! - [certificates] module.
 //!
 //! ## Environment variables
 //! See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for
 //! more information on configuring the consumer or producer via environment variables.
 //!
 //! # Example
-//! ```
+//! ```no_run
 //! use dsh_sdk::Dsh;
 //! use rdkafka::consumer::{Consumer, StreamConsumer};
 //!
 //! # #[tokio::main]
 //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let dsh_properties = Dsh::get();
-//! let consumer_config = dsh_properties.consumer_rdkafka_config();
-//! let consumer: StreamConsumer = consumer_config.create()?;
+//! let dsh = Dsh::get();
+//! let certificates = dsh.certificates()?;
+//! let datastreams = dsh.datastream();
+//! let kafka_config = dsh.kafka_config();
+//! let tenant_name = dsh.tenant_name();
+//! let task_id = dsh.task_id();
 //!
 //! # Ok(())
 //! # }
 //! ```
 
-use log::{error, warn};
+use log::warn;
 use std::env;
 use std::sync::{Arc, OnceLock};
 
-use crate::certificates::Cert;
+use crate::certificates::{ensure_https_prefix, Cert, CertificatesError};
 use crate::datastream::Datastream;
 use crate::error::DshError;
 use crate::utils;
@@ -41,7 +42,7 @@
 use crate::protocol_adapters::kafka_protocol::config::KafkaConfig;
 
 // TODO: Remove at v0.6.0
 pub use crate::dsh_old::*;
 
-/// DSH properties struct. Create new to initialize all related components to connect to the DSH kafka clusters
+/// Lazily initialize all related components to connect to the DSH
 /// - Contains info from datastreams.json
 /// - Metadata of running container/task
 /// - Certificates for Kafka and DSH Schema Registry
@@ -113,32 +114,22 @@ impl Dsh {
 
     /// Initialize the properties and bootstrap to DSH
     fn init() -> Self {
-        let tenant_name = match utils::tenant_name() {
-            Ok(tenant_name) => tenant_name,
-            Err(_) => {
-                error!("{} and {} are not set, this may cause unexpected behaviour when connecting to DSH Kafka cluster! Please set one of these environment variables.", VAR_APP_ID, VAR_DSH_TENANT_NAME);
-                "local_tenant".to_string()
-            }
-        };
+        let tenant_name = utils::tenant_name().unwrap_or("local_tenant".to_string());
         let task_id = utils::get_env_var(VAR_TASK_ID).unwrap_or("local_task_id".to_string());
-        let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST)
-            .map(|host| format!("https://{}", host))
-            .unwrap_or_else(|_| {
-                warn!(
-                    "{} is not set, using default value {}",
-                    VAR_KAFKA_CONFIG_HOST, DEFAULT_CONFIG_HOST
-                );
-                DEFAULT_CONFIG_HOST.to_string()
-            });
+        let config_host =
+            utils::get_env_var(VAR_KAFKA_CONFIG_HOST).map(|host| ensure_https_prefix(host));
         let certificates = if let Ok(cert) = Cert::from_pki_config_dir::<std::path::PathBuf>(None) {
             Some(cert)
-        } else {
-            Cert::from_bootstrap(&config_host, &tenant_name, &task_id)
+        } else if let Ok(config_host) = &config_host {
+            Cert::from_bootstrap(config_host, &tenant_name, &task_id)
                 .inspect_err(|e| {
                     warn!("Could not bootstrap to DSH, due to: {}", e);
                 })
                 .ok()
+        } else {
+            None
         };
+        let config_host = config_host.unwrap_or(DEFAULT_CONFIG_HOST.to_string());
         let fetched_datastreams = certificates.as_ref().and_then(|cert| {
             cert.reqwest_blocking_client_config()
                 .build()
@@ -165,15 +156,11 @@
     /// # Example
     /// ```
     /// # use dsh_sdk::Dsh;
     /// # use reqwest::Client;
     /// # #[tokio::main]
     /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let dsh_properties = Dsh::get();
-    /// let client = dsh_properties.reqwest_client_config().build()?;
+    /// let dsh = Dsh::get();
+    /// let client = dsh.reqwest_client_config().build()?;
     /// # Ok(())
     /// # }
     /// ```
-    #[deprecated(
-        since = "0.5.0",
-        note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead"
-    )]
     pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder {
         if let Ok(certificates) = &self.certificates() {
             certificates.reqwest_client_config()
@@ -185,23 +172,16 @@
     /// Get reqwest blocking client config to connect to DSH Schema Registry.
     /// If certificates are present, it will use SSL to connect to Schema Registry.
     ///
-    /// Use [schema_registry_converter](https://crates.io/crates/schema_registry_converter) to connect to Schema Registry.
- /// /// # Example /// ``` /// # use dsh_sdk::Dsh; /// # use reqwest::blocking::Client; - /// # use dsh_sdk::error::DshError; - /// # fn main() -> Result<(), DshError> { - /// let dsh_properties = Dsh::get(); - /// let client = dsh_properties.reqwest_blocking_client_config().build()?; + /// # fn main() -> Result<(), Box> { + /// let dsh = Dsh::get(); + /// let client = dsh.reqwest_blocking_client_config().build()?; /// # Ok(()) /// # } /// ``` - #[deprecated( - since = "0.5.0", - note = "Reqwest client is not used in DSH SDK, use `dsh_sdk::schema_store::SchemaStoreClient` instead" - )] pub fn reqwest_blocking_client_config(&self) -> reqwest::blocking::ClientBuilder { if let Ok(certificates) = &self.certificates() { certificates.reqwest_blocking_client_config() @@ -215,19 +195,18 @@ impl Dsh { /// # Example /// ```no_run /// # use dsh_sdk::Dsh; - /// # use dsh_sdk::error::DshError; - /// # fn main() -> Result<(), DshError> { - /// let dsh_properties = Dsh::get(); - /// let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem(); + /// # fn main() -> Result<(), Box> { + /// let dsh = Dsh::get(); + /// let dsh_kafka_certificate = dsh.certificates()?.dsh_kafka_certificate_pem(); /// # Ok(()) /// # } /// ``` pub fn certificates(&self) -> Result<&Cert, DshError> { - if let Some(cert) = &self.certificates { + Ok(if let Some(cert) = &self.certificates { Ok(cert) } else { - Err(DshError::NoCertificates) - } + Err(CertificatesError::NoCertificates) + }?) } /// Get the client id based on the task id. @@ -269,7 +248,7 @@ impl Dsh { .build() .expect("Could not build reqwest client for fetching datastream") }); - Datastream::fetch(client, &self.config_host, &self.tenant_name, &self.task_id).await + Ok(Datastream::fetch(client, &self.config_host, &self.tenant_name, &self.task_id).await?) } /// High level method to fetch the kafka properties provided by DSH (datastreams.json) in a blocking way. @@ -289,7 +268,12 @@ impl Dsh { .build() .expect("Could not build reqwest client for fetching datastream") }); - Datastream::fetch_blocking(client, &self.config_host, &self.tenant_name, &self.task_id) + Ok(Datastream::fetch_blocking( + client, + &self.config_host, + &self.tenant_name, + &self.task_id, + )?) } /// Get schema host of DSH. 
diff --git a/dsh_sdk/src/dsh_old/mod.rs b/dsh_sdk/src/dsh_old/mod.rs index 5de9d3e..647895d 100644 --- a/dsh_sdk/src/dsh_old/mod.rs +++ b/dsh_sdk/src/dsh_old/mod.rs @@ -25,7 +25,7 @@ mod bootstrap; pub mod certificates; mod config; pub mod datastream; -mod error; +pub mod error; mod pki_config_dir; pub mod properties; diff --git a/dsh_sdk/src/dsh_old/properties.rs b/dsh_sdk/src/dsh_old/properties.rs index e0f539b..1c0d310 100644 --- a/dsh_sdk/src/dsh_old/properties.rs +++ b/dsh_sdk/src/dsh_old/properties.rs @@ -366,7 +366,6 @@ impl Properties { /// ``` /// # use dsh_sdk::Properties; /// # use reqwest::blocking::Client; - /// # use dsh_sdk::error::DshError; /// # fn main() -> Result<(), Box> { /// let dsh_properties = Properties::get(); /// let client = dsh_properties.reqwest_blocking_client_config()?.build()?; @@ -388,7 +387,6 @@ impl Properties { /// # Example /// ```no_run /// # use dsh_sdk::Properties; - /// # use dsh_sdk::error::DshError; /// # fn main() -> Result<(), Box>{ /// let dsh_properties = Properties::get(); /// let dsh_kafka_certificate = dsh_properties.certificates()?.dsh_kafka_certificate_pem(); diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs index a540e7d..a5baac5 100644 --- a/dsh_sdk/src/error.rs +++ b/dsh_sdk/src/error.rs @@ -1,56 +1,14 @@ -use thiserror::Error; - -#[derive(Error, Debug)] -#[non_exhaustive] +/// Errors for the DSH SDK +#[derive(Debug, thiserror::Error)] pub enum DshError { - #[error("IO Error: {0}")] - IoError(#[from] std::io::Error), - #[error("Env variable {0} error: {1}")] - EnvVarError(String, std::env::VarError), - #[error("Convert bytes to utf8 error: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))] - #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")] - DshCallError { - url: String, - status_code: reqwest::StatusCode, - error_body: String, - }, - #[cfg(feature = "bootstrap")] - #[error("Certificates are not set")] - NoCertificates, - #[cfg(feature = "bootstrap")] - #[error("Invalid PEM certificate: {0}")] - PemError(#[from] pem::PemError), - #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))] + #[error("Certificates error: {0}")] + CertificatesError(#[from] crate::certificates::CertificatesError), + #[error("Datastream error: {0}")] + DatastreamError(#[from] crate::datastream::DatastreamError), + #[error("Utils error: {0}")] + UtilsError(#[from] crate::utils::UtilsError), #[error("Reqwest: {0}")] ReqwestError(#[from] reqwest::Error), - #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))] - #[error("Serde_json error: {0}")] - JsonError(#[from] serde_json::Error), - #[cfg(feature = "bootstrap")] - #[error("Rcgen error: {0}")] - PrivateKeyError(#[from] rcgen::Error), - #[cfg(any(feature = "bootstrap", feature = "protocol-token-fetcher"))] - #[error("Error parsing: {0}")] - ParseDnError(String), - #[cfg(feature = "bootstrap")] - #[error("Error getting group id, index out of bounds for {0}")] - IndexGroupIdError(crate::datastream::GroupType), - #[error("No tenant name found")] - NoTenantName, - #[cfg(feature = "bootstrap")] - #[error("Error getting topic name {0}, Topic not found in datastreams.")] - NotFoundTopicError(String), - #[cfg(feature = "bootstrap")] - #[error("Error in topic permissions: {0} does not have {1:?} permissions.")] - TopicPermissionsError(String, crate::datastream::ReadWriteAccess), - #[cfg(feature = "metrics")] - #[error("Prometheus error: {0}")] - Prometheus(#[from] 
prometheus::Error), - #[cfg(feature = "metrics")] - #[error("Hyper error: {0}")] - HyperError(#[from] hyper::http::Error), } pub(crate) fn report(mut err: &dyn std::error::Error) -> String { diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index a305e2a..653d525 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -1,77 +1,4 @@ -//! # DSH -//! -//! Dsh properties struct. Create new to initialize all related components to connect to the DSH kafka clusters and get metadata of your tenant. -//! - Availablable datastreams info -//! - Metadata of running container/task -//! - Certificates for Kafka and DSH -//! -//! ## High level API -//! -//! The properties struct contains a high level API to interact with the DSH. -//! This includes generating RDKafka config for creating a consumer/producer and Reqwest config builder for Schema Registry. -//! -//! ### Example: -//! ``` -//! use dsh_sdk::DshKafkaConfig; -//! use rdkafka::ClientConfig; -//! use rdkafka::consumer::stream_consumer::StreamConsumer; -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box>{ -//! let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?; -//! # Ok(()) -//! # } -//! ``` -//! -//! ## Low level API -//! It is also possible to get avaiable metadata or the certificates from the properties struct. -//! -//! ### Example: -//! ```no_run -//! use dsh_sdk::Dsh; - -//! # fn main() -> Result<(), Box>{ -//! let dsh = Dsh::get(); -//! // check for write access to topic -//! let write_access = dsh.datastream().get_stream("scratch.local.local-tenant").expect("Topic not found").write_access(); -//! // get the certificates, for example DSH_KAFKA_CERTIFICATE -//! let dsh_kafka_certificate = dsh.certificates()?.dsh_kafka_certificate_pem(); -//! # Ok(()) -//! # } -//! ``` -//! ## Kafka Proxy / VPN / Local -//! Read [CONNECT_PROXY_VPN_LOCAL.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/CONNECT_PROXY_VPN_LOCAL.md) on how to connect to DSH with Kafka Proxy, VPN or to a local Kafka cluster. -//! -//! # Metrics -//! The metrics module provides a way to expose prometheus metrics. This module is a re-export of the `prometheus` crate. It also contains a function to start a http server to expose the metrics to DSH. -//! -//! See [metrics](metrics/index.html) for more information. -//! -//! # Graceful shutdown -//! To implement a graceful shutdown in your service, you can use the `Shutdown` struct. This struct has an implementation based on the best practices example of Tokio. -//! -//! This gives you the option to properly handle shutdown in your components/tasks. -//! It listens for SIGTERM requests and sends out shutdown requests to all shutdown handles. -//! -//! See [graceful_shutdown](graceful_shutdown/index.html) for more information. -//! -//! # DLQ (Dead Letter Queue) -//! `OPTIONAL feature: dlq` -//! -//! This is an experimental feature and is not yet finalized. -//! -//! This implementation only includes pushing messages towards a kafka topic. (Dead or Retry topic) -//! -//! ### NOTE: -//! This implementation does not (and will not) handle any other DLQ related tasks like: -//! - Retrying messages -//! - Handling messages in DLQ -//! - Monitor the DLQ -//! Above tasks should be handled by a seperate component set up by the user, as these tasks are use case specific and can handle different strategies. -//! -//! The DLQ is implemented by running the `Dlq` struct to push messages towards the DLQ topics. -//! 
The `ErrorToDlq` trait can be implemented on your defined errors, to be able to send messages towards the DLQ Struct. - +#![doc = include_str!("../README.md")] #![allow(deprecated)] // to be kept in v0.6.0 @@ -81,7 +8,9 @@ pub mod certificates; pub mod datastream; #[cfg(feature = "bootstrap")] pub mod dsh; -pub mod error; +#[cfg(feature = "bootstrap")] +mod error; + #[cfg(feature = "management-api-token-fetcher")] pub mod management_api; pub mod protocol_adapters; @@ -92,16 +21,15 @@ pub mod schema_store; #[cfg(feature = "bootstrap")] #[doc(inline)] -pub use dsh::Dsh; +pub use {dsh::Dsh, error::DshError}; #[cfg(feature = "kafka")] #[doc(inline)] pub use protocol_adapters::kafka_protocol::DshKafkaConfig; #[cfg(feature = "management-api-token-fetcher")] -pub use management_api::token_fetcher::{ - ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder, -}; +#[doc(inline)] +pub use management_api::{ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder}; #[doc(inline)] pub use utils::Platform; diff --git a/dsh_sdk/src/management_api/error.rs b/dsh_sdk/src/management_api/error.rs index 62e90bd..01d678a 100644 --- a/dsh_sdk/src/management_api/error.rs +++ b/dsh_sdk/src/management_api/error.rs @@ -2,7 +2,7 @@ use thiserror::Error; #[derive(Error, Debug)] #[non_exhaustive] -pub enum ManagementTokenError { +pub enum ManagementApiTokenError { #[error("Client ID is unknown")] UnknownClientId, #[error("Client secret not set")] diff --git a/dsh_sdk/src/management_api/mod.rs b/dsh_sdk/src/management_api/mod.rs index 8f11a52..4bfe432 100644 --- a/dsh_sdk/src/management_api/mod.rs +++ b/dsh_sdk/src/management_api/mod.rs @@ -1,2 +1,41 @@ -pub mod error; -pub mod token_fetcher; +//! Fetch and store tokens for the DSH Management Rest API client +//! +//! This module is meant to be used together with the [dsh_rest_api_client]. +//! +//! The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. +//! +//! ## Example +//! Recommended usage is to use the [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher. +//! However, you can also create a new instance of the token fetcher directly. +//! ```no_run +//! use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform}; +//! use dsh_rest_api_client::Client; +//! +//! const CLIENT_SECRET: &str = ""; +//! const TENANT: &str = "tenant-name"; +//! +//! #[tokio::main] +//! async fn main() { +//! let platform = Platform::NpLz; +//! let client = Client::new(platform.endpoint_rest_api()); +//! +//! let tf = ManagementApiTokenFetcherBuilder::new(platform) +//! .tenant_name(TENANT.to_string()) +//! .client_secret(CLIENT_SECRET.to_string()) +//! .build() +//! .unwrap(); +//! +//! let response = client +//! .topic_get_by_tenant_topic(TENANT, &tf.get_token().await.unwrap()) +//! .await; +//! println!("Available topics: {:#?}", response); +//! } +//! ``` +mod error; +mod token_fetcher; + +#[doc(inline)] +pub use error::ManagementApiTokenError; + +#[doc(inline)] +pub use token_fetcher::{ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder}; diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs index 44c680d..2da98cd 100644 --- a/dsh_sdk/src/management_api/token_fetcher.rs +++ b/dsh_sdk/src/management_api/token_fetcher.rs @@ -1,37 +1,3 @@ -//! Module for fetching and storing access tokens for the DSH Management Rest API client -//! -//! This module is meant to be used together with the [dsh_rest_api_client]. -//! -//! 
The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. -//! -//! ## Example -//! Recommended usage is to use the [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher. -//! However, you can also create a new instance of the token fetcher directly. -//! ```no_run -//! use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform}; -//! use dsh_rest_api_client::Client; -//! -//! const CLIENT_SECRET: &str = ""; -//! const TENANT: &str = "tenant-name"; -//! -//! #[tokio::main] -//! async fn main() { -//! let platform = Platform::NpLz; -//! let client = Client::new(platform.endpoint_rest_api()); -//! -//! let tf = ManagementApiTokenFetcherBuilder::new(platform) -//! .tenant_name(TENANT.to_string()) -//! .client_secret(CLIENT_SECRET.to_string()) -//! .build() -//! .unwrap(); -//! -//! let response = client -//! .topic_get_by_tenant_topic(TENANT, &tf.get_token().await.unwrap()) -//! .await; -//! println!("Available topics: {:#?}", response); -//! } -//! ``` - use std::fmt::Debug; use std::ops::Add; use std::sync::Mutex; @@ -40,7 +6,7 @@ use std::time::{Duration, Instant}; use log::debug; use serde::Deserialize; -use super::error::ManagementTokenError; +use super::error::ManagementApiTokenError; use crate::utils::Platform; /// Access token of the authentication service of DSH. @@ -149,6 +115,11 @@ impl ManagementApiTokenFetcher { ) } + /// Get a [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher + pub fn builder(platform: Platform) -> ManagementApiTokenFetcherBuilder { + ManagementApiTokenFetcherBuilder::new(platform) + } + /// Create a new instance of the token fetcher with custom reqwest client /// /// ## Example @@ -186,9 +157,9 @@ impl ManagementApiTokenFetcher { /// /// If the cached token is not valid, it will fetch a new token from the server. /// It will return the token as a string, formatted as "{token_type} {token}" - /// If the request fails for a new token, it will return a [ManagementTokenError::FailureTokenFetch] error. + /// If the request fails for a new token, it will return a [ManagementApiTokenError::FailureTokenFetch] error. /// This will contain the underlying reqwest error. - pub async fn get_token(&self) -> Result<String, ManagementTokenError> { + pub async fn get_token(&self) -> Result<String, ManagementApiTokenError> { match self.is_valid() { true => Ok(self.access_token.lock().unwrap().formatted_token()), false => { @@ -224,12 +195,12 @@ impl ManagementApiTokenFetcher { /// Fetch a new access token from the server /// /// This will fetch a new access token from the server and return it. - /// If the request fails, it will return a [ManagementTokenError::FailureTokenFetch] error. - /// If the status code is not successful, it will return a [ManagementTokenError::StatusCode] error. + /// If the request fails, it will return a [ManagementApiTokenError::FailureTokenFetch] error. + /// If the status code is not successful, it will return a [ManagementApiTokenError::StatusCode] error. /// If the request is successful, it will return the [AccessToken].
pub async fn fetch_access_token_from_server( &self, - ) -> Result<AccessToken, ManagementTokenError> { + ) -> Result<AccessToken, ManagementApiTokenError> { let response = self .client .post(&self.auth_url) .form(&[ ("client_id", self.client_id.as_ref()), ("client_secret", self.client_secret.as_ref()), ("grant_type", "client_credentials"), ]) .send() .await - .map_err(ManagementTokenError::FailureTokenFetch)?; + .map_err(ManagementApiTokenError::FailureTokenFetch)?; if !response.status().is_success() { - Err(ManagementTokenError::StatusCode { + Err(ManagementApiTokenError::StatusCode { status_code: response.status(), error_body: response.text().await.unwrap_or_default(), }) } else { response .json::<AccessToken>() .await - .map_err(ManagementTokenError::FailureTokenFetch) + .map_err(ManagementApiTokenError::FailureTokenFetch) } } } @@ -267,7 +238,7 @@ impl Debug for ManagementApiTokenFetcher { } } -/// Builder for the token fetcher +/// Builder for the management API token fetcher pub struct ManagementApiTokenFetcherBuilder { client: Option<reqwest::Client>, client_id: Option<String>, @@ -340,10 +311,10 @@ impl ManagementApiTokenFetcherBuilder { /// .build() /// .unwrap(); /// ``` - pub fn build(self) -> Result<ManagementApiTokenFetcher, ManagementTokenError> { + pub fn build(self) -> Result<ManagementApiTokenFetcher, ManagementApiTokenError> { let client_secret = self .client_secret - .ok_or(ManagementTokenError::UnknownClientSecret)?; + .ok_or(ManagementApiTokenError::UnknownClientSecret)?; let client_id = self .client_id .or_else(|| { self.tenant_name .as_ref() .map(|tenant_name| self.platform.rest_client_id(tenant_name)) }) - .ok_or(ManagementTokenError::UnknownClientId)?; + .ok_or(ManagementApiTokenError::UnknownClientId)?; let client = self.client.unwrap_or_default(); let token_fetcher = ManagementApiTokenFetcher::new_with_client( client_id, @@ -504,7 +475,7 @@ mod test { tf.auth_url = auth_server.url(); let err = tf.fetch_access_token_from_server().await.unwrap_err(); match err { - ManagementTokenError::StatusCode { + ManagementApiTokenError::StatusCode { status_code, error_body, } => { @@ -588,12 +559,12 @@ mod test { .client_secret("client_secret".to_string()) .build() .unwrap_err(); - assert!(matches!(err, ManagementTokenError::UnknownClientId)); + assert!(matches!(err, ManagementApiTokenError::UnknownClientId)); let err = ManagementApiTokenFetcherBuilder::new(Platform::NpLz) .tenant_name("tenant_name".to_string()) .build() .unwrap_err(); - assert!(matches!(err, ManagementTokenError::UnknownClientSecret)); + assert!(matches!(err, ManagementApiTokenError::UnknownClientSecret)); } } diff --git a/dsh_sdk/src/metrics.rs b/dsh_sdk/src/metrics.rs index bf7bde0..8fd77c6 100644 --- a/dsh_sdk/src/metrics.rs +++ b/dsh_sdk/src/metrics.rs @@ -64,9 +64,7 @@ pub use prometheus::*; use tokio::net::TcpListener; use tokio::task::JoinHandle; -use crate::error::DshError; - -type DshResult<T> = std::result::Result<T, DshError>; +type DshResult<T> = std::result::Result<T, Box<dyn std::error::Error>>; type BoxBody = http_body_util::combinators::BoxBody; static NOTFOUND: &[u8] = b"404: Not Found"; diff --git a/dsh_sdk/src/mqtt_token_fetcher.rs b/dsh_sdk/src/mqtt_token_fetcher.rs index 19be160..c4821d6 100644 --- a/dsh_sdk/src/mqtt_token_fetcher.rs +++ b/dsh_sdk/src/mqtt_token_fetcher.rs @@ -11,7 +11,7 @@ use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::Arc; -use crate::{error::DshError, Platform}; +use crate::{dsh_old::error::DshError, Platform}; /// `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH.
/// diff --git a/dsh_sdk/src/protocol_adapters/error.rs b/dsh_sdk/src/protocol_adapters/error.rs new file mode 100644 index 0000000..4c2f797 --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/error.rs @@ -0,0 +1,19 @@ +#[cfg(feature = "protocol-token-fetcher")] +/// Error type for the protocol adapter token fetcher +#[derive(Debug, thiserror::Error)] +pub enum ProtocolTokenError { + #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")] + DshCall { + url: String, + status_code: reqwest::StatusCode, + error_body: String, + }, + #[error("Reqwest: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("Serde_json error: {0}")] + Json(#[from] serde_json::Error), + #[error("IO Error: {0}")] + Io(#[from] std::io::Error), + #[error("JWT Parse error: {0}")] + Jwt(String), +} diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs index 9ea71d8..7f50893 100644 --- a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs +++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs @@ -1,8 +1,25 @@ -pub mod config; // TODO: should we make this public? What benefits would that bring? +//! DSH Configuration for Kafka. +//! +//! This module contains the required configurations to consume and produce messages from DSH Kafka Cluster. +//! +//! ## Example +//! ``` +//! use dsh_sdk::DshKafkaConfig; +//! use rdkafka::ClientConfig; +//! use rdkafka::consumer::StreamConsumer; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?; +//! # Ok(()) +//! # } +//! ``` +pub mod config; #[cfg(feature = "rdkafka")] mod rdkafka; +/// Set all required configurations to consume and produce messages with the DSH Kafka Cluster. pub trait DshKafkaConfig { /// Set all required configurations to consume messages from DSH Kafka Cluster. /// diff --git a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs index 64a79f7..860d687 100644 --- a/dsh_sdk/src/protocol_adapters/mod.rs +++ b/dsh_sdk/src/protocol_adapters/mod.rs @@ -1,3 +1,5 @@ +//! The DSH Protocol adapter clients (HTTP, Kafka, MQTT) +//! #[cfg(feature = "http-protocol-adapter")] pub mod http_protocol; #[cfg(feature = "kafka")] @@ -7,6 +9,11 @@ pub mod mqtt_protocol; #[cfg(feature = "protocol-token-fetcher")] pub mod token_fetcher; +mod error; + +#[cfg(feature = "protocol-token-fetcher")] +#[doc(inline)] +pub use error::ProtocolTokenError; #[cfg(feature = "protocol-token-fetcher")] #[doc(inline)] pub use token_fetcher::ProtocolTokenFetcher; diff --git a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs index de392bc..818f5c4 100644 --- a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs +++ b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs @@ -1,6 +1,6 @@ -//! # MQTT Token Fetcher +//! Protocol Token Fetcher //! -//! `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH. +//! `ProtocolTokenFetcher` is responsible for fetching and managing tokens for the DSH MQTT and HTTP protocol adapters.
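A minimal consumption sketch of the renamed fetcher, using only the API surface visible in this patch (`ProtocolTokenFetcher::get_token` and its re-export from `protocol_adapters`); construction of the fetcher is deliberately left out, as the constructor is not part of these hunks:

```rust
use dsh_sdk::protocol_adapters::ProtocolTokenFetcher;

// The fetcher caches tokens per client id (see the RwLock<HashMap<..>>
// field below); `get_token` returns the cached token while it is still
// valid and fetches a new one otherwise.
async fn print_protocol_token(
    fetcher: &ProtocolTokenFetcher,
    client_id: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let token = fetcher.get_token(client_id, None).await?;
    println!("protocol token: {:?}", token);
    Ok(())
}
```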
use std::collections::{hash_map::Entry, HashMap}; use std::fmt::{Display, Formatter}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -10,7 +10,8 @@ use serde_json::json; use sha2::{Digest, Sha256}; use tokio::sync::RwLock; -use crate::{error::DshError, Platform}; +use super::ProtocolTokenError; +use crate::Platform; /// `ProtocolTokenFetcher` is responsible for fetching and managing tokens for the DSH Mqtt and Http protocol adapters. /// @@ -22,7 +23,7 @@ pub struct ProtocolTokenFetcher { rest_api_key: String, rest_token: RwLock<RestToken>, rest_auth_url: String, - protocol_token: RwLock<HashMap<String, MqttToken>>, // Mapping from Client ID to MqttToken + protocol_token: RwLock<HashMap<String, ProtocolToken>>, // Mapping from Client ID to ProtocolToken protocol_auth_url: String, client: reqwest::Client, //token_lifetime: Option<Duration>, // TODO: Implement option of passing token lifetime to request token for specific duration @@ -39,7 +40,7 @@ pub struct ProtocolTokenFetcher { /// /// # Returns /// -/// Returns a `Result` containing a `MqttTokenFetcher` instance or a `DshError`. +/// Returns a `Result` containing a `ProtocolTokenFetcher` instance or a `ProtocolTokenError`. impl ProtocolTokenFetcher { /// Constructs a new `ProtocolTokenFetcher`. /// @@ -130,12 +131,12 @@ impl ProtocolTokenFetcher { /// /// # Returns /// - /// Returns a `Result` containing the `MqttToken` or a `DshError`. + /// Returns a `Result` containing the `ProtocolToken` or a `ProtocolTokenError`. pub async fn get_token( &self, client_id: &str, claims: Option<Vec<Claims>>, - ) -> Result<MqttToken, DshError> { + ) -> Result<ProtocolToken, ProtocolTokenError> { match self .protocol_token .write() @@ -164,7 +165,7 @@ &self, client_id: &str, claims: Option<Vec<Claims>>, - ) -> Result<MqttToken, DshError> { + ) -> Result<ProtocolToken, ProtocolTokenError> { let mut rest_token = self.rest_token.write().await; if !rest_token.is_valid() { @@ -179,7 +180,7 @@ let authorization_header = format!("Bearer {}", rest_token.raw_token); - let protocol_token_request = MqttTokenRequest::new(client_id, &self.tenant_name, claims)?; + let protocol_token_request = + ProtocolTokenRequest::new(client_id, &self.tenant_name, claims)?; let payload = serde_json::to_value(&protocol_token_request)?; let response = protocol_token_request .send( &self.client, &self.protocol_auth_url, &authorization_header, &payload, ) .await?; - MqttToken::new(response) + ProtocolToken::new(response) } } @@ -264,18 +266,18 @@ impl Resource { } #[derive(Serialize)] -struct MqttTokenRequest { +struct ProtocolTokenRequest { id: String, tenant: String, claims: Option<Vec<Claims>>, } -impl MqttTokenRequest { +impl ProtocolTokenRequest { fn new( client_id: &str, tenant: &str, claims: Option<Vec<Claims>>, - ) -> Result<Self, DshError> { + ) -> Result<Self, ProtocolTokenError> { let mut hasher = Sha256::new(); hasher.update(client_id); let result = hasher.finalize(); @@ -294,7 +296,7 @@ protocol_auth_url: &str, authorization_header: &str, payload: &serde_json::Value, - ) -> Result<String, DshError> { + ) -> Result<String, ProtocolTokenError> { let response = reqwest_client .post(protocol_auth_url) .header("Authorization", authorization_header) .json(payload) .send() .await?; if response.status().is_success() { Ok(response.text().await?) } else { - Err(DshError::DshCallError { + Err(ProtocolTokenError::DshCall { url: protocol_auth_url.to_string(), status_code: response.status(), error_body: response.text().await?, @@ -317,7 +319,7 @@ /// Represents attributes associated with a mqtt token.
#[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "kebab-case")] -struct MqttTokenAttributes { +struct ProtocolTokenAttributes { gen: i32, endpoint: String, iss: String, @@ -330,13 +332,13 @@ /// Represents a token used for MQTT connections. #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MqttToken { +pub struct ProtocolToken { exp: i32, raw_token: String, } -impl MqttToken { - /// Creates a new instance of `MqttToken` from a raw token string. +impl ProtocolToken { + /// Creates a new instance of `ProtocolToken` from a raw token string. /// /// # Arguments /// @@ -344,14 +346,14 @@ /// /// # Returns /// - /// A Result containing the created MqttToken or an error. - pub fn new(raw_token: String) -> Result<MqttToken, DshError> { + /// A Result containing the created ProtocolToken or an error. + pub fn new(raw_token: String) -> Result<ProtocolToken, ProtocolTokenError> { let header_payload = extract_header_and_payload(&raw_token)?; let decoded_token = decode_base64(header_payload)?; - let token_attributes: MqttTokenAttributes = serde_json::from_slice(&decoded_token)?; - let token = MqttToken { + let token_attributes: ProtocolTokenAttributes = serde_json::from_slice(&decoded_token)?; + let token = ProtocolToken { exp: token_attributes.exp, raw_token, }; @@ -408,13 +410,13 @@ impl RestToken { /// /// # Returns /// - /// A Result containing the created `RestToken` or a `DshError`. + /// A Result containing the created `RestToken` or a `ProtocolTokenError`. async fn get( client: &reqwest::Client, tenant: &str, api_key: &str, auth_url: &str, - ) -> Result<RestToken, DshError> { + ) -> Result<RestToken, ProtocolTokenError> { let raw_token = Self::fetch_token(client, tenant, api_key, auth_url).await?; let header_payload = extract_header_and_payload(&raw_token)?; @@ -444,7 +446,7 @@ tenant: &str, api_key: &str, auth_url: &str, - ) -> Result<String, DshError> { + ) -> Result<String, ProtocolTokenError> { let json_body = json!({"tenant": tenant}); let response = client .post(auth_url) .header("apikey", api_key) .json(&json_body) .send() .await?; let status = response.status(); let body_text = response.text().await?; match status { reqwest::StatusCode::OK => Ok(body_text), - _ => Err(DshError::DshCallError { + _ => Err(ProtocolTokenError::DshCall { url: auth_url.to_string(), status_code: status, error_body: body_text, @@ -484,13 +486,12 @@ impl Default for RestToken { /// /// # Returns /// -/// A Result containing the header and payload part of the JWT token or a `DshError`. -fn extract_header_and_payload(raw_token: &str) -> Result<&str, DshError> { +/// A Result containing the header and payload part of the JWT token or a `ProtocolTokenError`. +fn extract_header_and_payload(raw_token: &str) -> Result<&str, ProtocolTokenError> { let parts: Vec<&str> = raw_token.split('.').collect(); - parts - .get(1) - .copied() - .ok_or_else(|| DshError::ParseDnError("Header and payload are missing".to_string())) + parts.get(1).copied().ok_or_else(|| { + ProtocolTokenError::Jwt("Cannot extract header and payload from raw_token".to_string()) + }) } /// Decodes a Base64-encoded string. /// /// # Arguments /// /// * `payload` - The Base64-encoded string to decode. /// /// # Returns /// -/// A Result containing the decoded byte vector or a `DshError`. -fn decode_base64(payload: &str) -> Result<Vec<u8>, DshError> { +/// A Result containing the decoded byte vector or a `ProtocolTokenError`.
+fn decode_base64(payload: &str) -> Result<Vec<u8>, ProtocolTokenError> { use base64::{alphabet, engine, read}; use std::io::Read; @@ -510,9 +511,7 @@ fn decode_base64(payload: &str) -> Result<Vec<u8>, DshError> { let mut decoder = read::DecoderReader::new(payload.as_bytes(), &engine); let mut decoded_token = Vec::new(); - decoder - .read_to_end(&mut decoded_token) - .map_err(DshError::IoError)?; + decoder.read_to_end(&mut decoded_token)?; Ok(decoded_token) } @@ -533,7 +532,7 @@ mod tests { exp: exp_time as i32, raw_token: "valid.token.payload".to_string(), }; - let protocol_token = MqttToken { + let protocol_token = ProtocolToken { exp: exp_time, raw_token: "valid.token.payload".to_string(), }; @@ -632,7 +631,7 @@ mod tests { #[test] fn test_token_request_new() { - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); assert_eq!(request.id.len(), 64); assert_eq!(request.tenant, "test_tenant"); } @@ -650,7 +649,7 @@ let client = reqwest::Client::new(); let payload = json!({"key": "value"}); - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); let result = request .send( &client, @@ -677,7 +676,7 @@ let client = reqwest::Client::new(); let payload = json!({"key": "value"}); - let request = MqttTokenRequest::new("test_client", "test_tenant", None).unwrap(); + let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); let result = request .send( &client, @@ -688,7 +687,7 @@ .await; assert!(result.is_err()); - if let Err(DshError::DshCallError { + if let Err(ProtocolTokenError::DshCall { url, status_code, error_body, @@ -735,7 +734,7 @@ #[test] fn test_protocol_token_is_valid() { let raw_token = "valid.token.payload".to_string(); - let token = MqttToken { + let token = ProtocolToken { exp: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() @@ -749,7 +748,7 @@ #[test] fn test_protocol_token_is_invalid() { let raw_token = "valid.token.payload".to_string(); - let token = MqttToken { + let token = ProtocolToken { exp: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() @@ -854,7 +853,7 @@ .await; assert!(result.is_err()); - if let Err(DshError::DshCallError { + if let Err(ProtocolTokenError::DshCall { url, status_code, error_body, diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs index ed539e6..45542dc 100644 --- a/dsh_sdk/src/rest_api_token_fetcher.rs +++ b/dsh_sdk/src/rest_api_token_fetcher.rs @@ -40,7 +40,7 @@ use std::time::{Duration, Instant}; use log::debug; use serde::Deserialize; -use crate::management_api::error::ManagementTokenError as DshRestTokenError; +use crate::management_api::ManagementApiTokenError as DshRestTokenError; use crate::utils::Platform; /// Access token of the authentication service of DSH.
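Since the deprecated `rest_api_token_fetcher` module now only re-imports the renamed error type (`ManagementApiTokenError as DshRestTokenError`), downstream code mostly needs the new names. A short, hedged sketch of matching on the renamed variants when building a fetcher, using only items visible in this patch (the builder, `Platform::NpLz`, and the error variants exercised by the tests above):

```rust
use dsh_sdk::management_api::ManagementApiTokenError;
use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};

fn main() {
    // Building without a client secret fails with the renamed error variant.
    match ManagementApiTokenFetcherBuilder::new(Platform::NpLz).build() {
        Ok(_) => println!("fetcher ready"),
        Err(ManagementApiTokenError::UnknownClientSecret) => {
            eprintln!("set a client secret before building")
        }
        Err(ManagementApiTokenError::UnknownClientId) => {
            eprintln!("set a client id or tenant name")
        }
        Err(e) => eprintln!("unexpected error: {}", e),
    }
}
```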
diff --git a/dsh_sdk/src/schema_store/api.rs b/dsh_sdk/src/schema_store/api.rs index 47f98a2..fd37875 100644 --- a/dsh_sdk/src/schema_store/api.rs +++ b/dsh_sdk/src/schema_store/api.rs @@ -1,7 +1,7 @@ use super::types::*; use super::request::Request; -use super::Result; +use super::SchemaStoreError; use super::SchemaStoreClient; @@ -13,15 +13,19 @@ pub trait SchemaStoreApi { /// Get glabal compatibility level /// /// {base_url}/config/{subject} - async fn get_config_subject(&self, subject: String) -> Result; + async fn get_config_subject(&self, subject: String) -> Result; /// Set compatibility on subject level. With 1 schema stored in the subject, you can change it to any compatibility level. Else, you can only change into a less restrictive level. Must be one of BACKWARD, BACKWARD_TRANSITIVE, FORWARD, FORWARD_TRANSITIVE, FULL, FULL_TRANSITIVE, NONE /// /// {base_url}/config/{subject} - async fn put_config_subject(&self, subject: String, body: Compatibility) -> Result; + async fn put_config_subject( + &self, + subject: String, + body: Compatibility, + ) -> Result; /// Get a list of registered subjects /// /// {base_url}/subjects - async fn get_subjects(&self) -> Result>; + async fn get_subjects(&self) -> Result, SchemaStoreError>; /// Check if a schema has already been registered under the specified subject. /// If so, this returns the schema string along with its globally unique identifier, @@ -32,12 +36,15 @@ pub trait SchemaStoreApi { &self, subject: String, body: RawSchemaWithType, - ) -> Result; + ) -> Result; /// Get a list of versions registered under the specified subject. /// /// {base_url}/subjects/{subject} - async fn get_subjects_subject_versions(&self, subject: String) -> Result>; + async fn get_subjects_subject_versions( + &self, + subject: String, + ) -> Result, SchemaStoreError>; /// Get a specific version of the schema registered under this subject. /// @@ -46,7 +53,7 @@ pub trait SchemaStoreApi { &self, subject: String, id: String, - ) -> Result; + ) -> Result; /// Register a new schema under the specified subject. /// @@ -63,7 +70,7 @@ pub trait SchemaStoreApi { &self, subject: String, body: RawSchemaWithType, - ) -> Result; + ) -> Result; /// Test input schema against a particular version of a subject’s schema for compatibility. /// Note that the compatibility level applied for the check is the configured compatibility level for the subject (GET /config/(string: subject)). @@ -75,7 +82,7 @@ pub trait SchemaStoreApi { subject: String, id: String, body: RawSchemaWithType, - ) -> Result; + ) -> Result; /// "Get the schema for the specified version of this subject. The unescaped schema only is returned. /// @@ -84,34 +91,41 @@ pub trait SchemaStoreApi { &self, subject: String, version_id: String, - ) -> Result; + ) -> Result; /// Get the schema for the specified version of schema. /// /// {base_url}/schemas/ids/{id} - async fn get_schemas_ids_id(&self, id: i32) -> Result; + async fn get_schemas_ids_id(&self, id: i32) -> Result; /// Get the related subjects vesrion for the specified schema. 
/// /// {base_url}/schemas/ids/{id}/versions - async fn get_schemas_ids_id_versions(&self, id: i32) -> Result>; + async fn get_schemas_ids_id_versions( + &self, + id: i32, + ) -> Result, SchemaStoreError>; } impl SchemaStoreApi for SchemaStoreClient where C: Request, { - async fn get_config_subject(&self, subject: String) -> Result { + async fn get_config_subject(&self, subject: String) -> Result { let url = format!("{}/config/{}", self.base_url, subject); Ok(self.client.get_request(url).await?) } - async fn put_config_subject(&self, subject: String, body: Compatibility) -> Result { + async fn put_config_subject( + &self, + subject: String, + body: Compatibility, + ) -> Result { let url = format!("{}/config/{}", self.base_url, subject); Ok(self.client.put_request(url, body).await?) } - async fn get_subjects(&self) -> Result> { + async fn get_subjects(&self) -> Result, SchemaStoreError> { let url = format!("{}/subjects", self.base_url); Ok(self.client.get_request(url).await?) } @@ -120,12 +134,15 @@ where &self, subject: String, body: RawSchemaWithType, - ) -> Result { + ) -> Result { let url = format!("{}/subjects/{}", self.base_url, subject); Ok(self.client.post_request(url, body).await?) } - async fn get_subjects_subject_versions(&self, subject: String) -> Result> { + async fn get_subjects_subject_versions( + &self, + subject: String, + ) -> Result, SchemaStoreError> { let url = format!("{}/subjects/{}/versions", self.base_url, subject); Ok(self.client.get_request(url).await?) } @@ -134,7 +151,7 @@ where &self, subject: String, version_id: String, - ) -> Result { + ) -> Result { let url = format!( "{}/subjects/{}/versions/{}", self.base_url, subject, version_id @@ -146,7 +163,7 @@ where &self, subject: String, body: RawSchemaWithType, - ) -> Result { + ) -> Result { let url = format!("{}/subjects/{}/versions", self.base_url, subject); Ok(self.client.post_request(url, body).await?) } @@ -156,7 +173,7 @@ where subject: String, version_id: String, body: RawSchemaWithType, - ) -> Result { + ) -> Result { let url = format!( "{}/compatibility/subjects/{}/versions/{}", self.base_url, subject, version_id @@ -168,7 +185,7 @@ where &self, subject: String, version_id: String, - ) -> Result { + ) -> Result { let url = format!( "{}/subjects/{}/versions/{}/schema", self.base_url, subject, version_id @@ -176,12 +193,15 @@ where Ok(self.client.get_request_plain(url).await?) } - async fn get_schemas_ids_id(&self, id: i32) -> Result { + async fn get_schemas_ids_id(&self, id: i32) -> Result { let url = format!("{}/schemas/ids/{}", self.base_url, id); Ok(self.client.get_request(url).await?) } - async fn get_schemas_ids_id_versions(&self, id: i32) -> Result> { + async fn get_schemas_ids_id_versions( + &self, + id: i32, + ) -> Result, SchemaStoreError> { let url = format!("{}/schemas/ids/{}/versions", self.base_url, id); Ok(self.client.get_request(url).await?) 
} diff --git a/dsh_sdk/src/schema_store/client.rs b/dsh_sdk/src/schema_store/client.rs index 622f57c..600278b 100644 --- a/dsh_sdk/src/schema_store/client.rs +++ b/dsh_sdk/src/schema_store/client.rs @@ -1,7 +1,7 @@ use super::api::SchemaStoreApi; use super::request::Request; use super::types::*; -use super::Result; +use super::SchemaStoreError; use crate::Dsh; /// High level Schema Store Client @@ -51,7 +51,10 @@ where /// # Ok(()) /// # } /// - pub async fn subject_compatibility(&self, subject: &SubjectName) -> Result { + pub async fn subject_compatibility( + &self, + subject: &SubjectName, + ) -> Result { Ok(self.get_config_subject(subject.name()).await?.into()) } @@ -83,7 +86,7 @@ where &self, subject: &SubjectName, compatibility: Compatibility, - ) -> Result { + ) -> Result { Ok(self .put_config_subject(subject.name(), compatibility) .await? @@ -105,7 +108,7 @@ where /// println!("Subjects: {:?}", client.subjects().await); /// # } /// ``` - pub async fn subjects(&self) -> Result> { + pub async fn subjects(&self) -> Result, SchemaStoreError> { self.get_subjects().await } @@ -127,7 +130,10 @@ where /// # Ok(()) /// # } /// ``` - pub async fn subject_versions(&self, subject: &SubjectName) -> Result> { + pub async fn subject_versions( + &self, + subject: &SubjectName, + ) -> Result, SchemaStoreError> { self.get_subjects_subject_versions(subject.name()).await } @@ -156,7 +162,11 @@ where /// # Ok(()) /// # } /// ``` - pub async fn subject(&self, subject: &SubjectName, version: V) -> Result + pub async fn subject( + &self, + subject: &SubjectName, + version: V, + ) -> Result where V: Into, { @@ -188,7 +198,11 @@ where /// # Ok(()) /// # } /// ``` - pub async fn subject_raw_schema(&self, subject: &SubjectName, version: V) -> Result + pub async fn subject_raw_schema( + &self, + subject: &SubjectName, + version: V, + ) -> Result where V: Into, { @@ -216,7 +230,10 @@ where /// let subjects = client.subject_all_schemas(&subject_name).await?; /// # Ok(()) /// # } - pub async fn subject_all_schemas(&self, subject: &SubjectName) -> Result> { + pub async fn subject_all_schemas( + &self, + subject: &SubjectName, + ) -> Result, SchemaStoreError> { let versions = self.subject_versions(&subject).await?; let mut subjects = Vec::new(); for version in versions { @@ -286,7 +303,7 @@ where &self, subject: &SubjectName, schema: RawSchemaWithType, - ) -> Result { + ) -> Result { Ok(self .post_subjects_subject_versions(subject.name(), schema) .await? @@ -330,7 +347,7 @@ where &self, subject: &SubjectName, schema: RawSchemaWithType, - ) -> Result { + ) -> Result { self.post_subjects_subject(subject.name(), schema).await } @@ -351,7 +368,7 @@ where subject: &SubjectName, version: Sv, schema: RawSchemaWithType, - ) -> Result + ) -> Result where Sv: Into, { @@ -369,7 +386,7 @@ where /// /// ## Arguments /// - `id`: The schema ID (Into<[i32]>) - pub async fn schema(&self, id: Si) -> Result + pub async fn schema(&self, id: Si) -> Result where Si: Into, { @@ -380,7 +397,10 @@ where /// /// ## Arguments /// - `id`: The schema ID (Into<[i32]>) - pub async fn schema_subjects(&self, id: Si) -> Result> + pub async fn schema_subjects( + &self, + id: Si, + ) -> Result, SchemaStoreError> where Si: Into, { diff --git a/dsh_sdk/src/schema_store/mod.rs b/dsh_sdk/src/schema_store/mod.rs index 77201a6..afb5ea8 100644 --- a/dsh_sdk/src/schema_store/mod.rs +++ b/dsh_sdk/src/schema_store/mod.rs @@ -1,6 +1,6 @@ //! Schema Store client //! -//! 
This module contains the SchemaStoreClient struct which is the main entry point for interacting with the DSH Schema Registry API. +//! This module contains the [SchemaStoreClient] which is the main entry point for interacting with the DSH Schema Registry API. //! //! It automatically connects to the Schema Registry API with proper certificates and uses the base URL provided by the datastreams.json. //! @@ -12,16 +12,17 @@ //! use dsh_sdk::schema_store::types::*; //! //! # #[tokio::main] -//! # async fn main() { +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { //! let client = SchemaStoreClient::new(); //! //! // List all subjects -//! let subjects = client.subjects().await.unwrap(); +//! let subjects = client.subjects().await?; //! //! // Get the latest version of a subject's value schema -//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); -//! let subject = client.subject(&subject_name, SubjectVersion::Latest).await.unwrap(); +//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?; +//! let subject = client.subject(&subject_name, SubjectVersion::Latest).await?; //! let raw_schema = subject.schema; +//! # Ok(()) //! # } //! ``` //! @@ -30,16 +31,19 @@ //! ``` //! use dsh_sdk::schema_store::types::*; //! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { //! // From original type //! let from_struct = SubjectName::TopicNameStrategy{topic: "scratch.example-topic.tenant".to_string(), key: false}; //! //! // From string -//! let from_str: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); // Note that `-value` is added, else it will return error as it is not a valid SubjectName +//! let from_str: SubjectName = "scratch.example-topic.tenant-value".try_into()?; // Note that `-value` is added, else it will return error as it is not a valid SubjectName //! assert_eq!(from_str, from_struct); //! //! // From tuple //! let from_tuple: SubjectName = ("scratch.example-topic.tenant", false).into(); //! assert_eq!(from_tuple, from_struct); +//! # Ok(()) +//! # } //! ``` //! //! This means you can easily convert into [types::SubjectName] and [types::RawSchemaWithType]. @@ -49,12 +53,13 @@ //! use dsh_sdk::schema_store::types::*; //! //! # #[tokio::main] -//! # async fn main() { +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { //! let client = SchemaStoreClient::new(); //! -//! let raw_schema: RawSchemaWithType = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#.try_into().unwrap(); -//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into().unwrap(); -//! client.subject_add_schema(&subject_name, raw_schema).await.unwrap(); // Returns error if schema is not valid +//! let raw_schema: RawSchemaWithType = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#.try_into()?; +//! let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?; +//! client.subject_add_schema(&subject_name, raw_schema).await?; // Returns error if schema is not valid +//! # Ok(()) //! # } //! 
``` mod api; @@ -67,5 +72,3 @@ pub mod types; pub use client::SchemaStoreClient; #[doc(inline)] pub use error::SchemaStoreError; - -type Result = std::result::Result; diff --git a/dsh_sdk/src/schema_store/request.rs b/dsh_sdk/src/schema_store/request.rs index 29fa727..2bc1f7e 100644 --- a/dsh_sdk/src/schema_store/request.rs +++ b/dsh_sdk/src/schema_store/request.rs @@ -2,24 +2,27 @@ use log::trace; use crate::Dsh; -use super::{Result, SchemaStoreError}; +use super::SchemaStoreError; const DEFAULT_CONTENT_TYPE: &str = "application/vnd.schemaregistry.v1+json"; pub trait Request { fn new_client() -> Self; - fn get_request(&self, url: String) -> impl std::future::Future> + Send + fn get_request( + &self, + url: String, + ) -> impl std::future::Future> + Send where R: serde::de::DeserializeOwned; fn get_request_plain( &self, url: String, - ) -> impl std::future::Future> + Send; + ) -> impl std::future::Future> + Send; fn post_request( &self, url: String, body: B, - ) -> impl std::future::Future> + Send + ) -> impl std::future::Future> + Send where R: serde::de::DeserializeOwned, B: serde::Serialize + Send; @@ -27,7 +30,7 @@ pub trait Request { &self, url: String, body: B, - ) -> impl std::future::Future> + Send + ) -> impl std::future::Future> + Send where R: serde::de::DeserializeOwned, B: serde::Serialize + Send; @@ -41,7 +44,7 @@ impl Request for reqwest::Client { .build() .expect("Failed to build reqwest client") } - async fn get_request(&self, url: String) -> Result + async fn get_request(&self, url: String) -> Result where R: serde::de::DeserializeOwned, { @@ -63,7 +66,7 @@ impl Request for reqwest::Client { } } - async fn get_request_plain(&self, url: String) -> Result { + async fn get_request_plain(&self, url: String) -> Result { trace!("GET {}", url); let request = self.get(&url); let response = request.send().await?; @@ -80,7 +83,7 @@ impl Request for reqwest::Client { } /// Helper function to send a POST request and return the response with the expected type (serde with as JSON) - async fn post_request(&self, url: String, body: B) -> Result + async fn post_request(&self, url: String, body: B) -> Result where R: serde::de::DeserializeOwned, B: serde::Serialize + Send, @@ -106,7 +109,7 @@ impl Request for reqwest::Client { } /// Helper function to send a PUT request and return the response with the expected type (serde with as JSON) - async fn put_request(&self, url: String, body: B) -> Result + async fn put_request(&self, url: String, body: B) -> Result where R: serde::de::DeserializeOwned, B: serde::Serialize + Send, diff --git a/dsh_sdk/src/utils/dlq.rs b/dsh_sdk/src/utils/dlq.rs deleted file mode 100644 index bb73430..0000000 --- a/dsh_sdk/src/utils/dlq.rs +++ /dev/null @@ -1,595 +0,0 @@ -//! # Dead Letter Queue -//! This optional module contains an implementation of pushing unprocessable/invalid messages towards a Dead Letter Queue (DLQ). -//! It is implemeted with [rdkafka] and [tokio]. -//! -//! add feature `dlq` to your Cargo.toml to enable this module. -//! -//! ### NOTE: -//! This module is meant for pushing messages towards a dead/retry topic only, it does and WILL not handle any logic for retrying messages. -//! Reason is, it can differ per use case what strategy is needed to retry messages and handle the dead letters. -//! -//! It is up to the user to implement the strategy and logic for retrying messages. -//! -//! ### How it works -//! The DLQ struct can -//! -//! ## How to use -//! 1. Implement the [ErrorToDlq] trait on top your (custom) error type. -//! 2. 
Use the [Dlq::start] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task) -//! 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method. -//! -//! The topics are set via environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`. -//! -//! ### Example: -//! https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs - -use std::collections::HashMap; -use std::str::from_utf8; - -use log::{debug, error, info, warn}; -use rdkafka::client::DefaultClientContext; -use rdkafka::error::KafkaError; -use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage}; -use rdkafka::producer::{FutureProducer, FutureRecord}; -use rdkafka::ClientConfig; -use tokio::sync::mpsc; - -use crate::utils::get_env_var; -use crate::utils::graceful_shutdown::Shutdown; -use crate::DshKafkaConfig; - -/// Channel to send messages to the dead letter queue -pub type DlqChannel = mpsc::Sender; - -/// Trait to convert an error to a dlq message -/// This trait is implemented for all errors that can and should be converted to a dlq message -/// -/// Example: -///``` -/// use dsh_sdk::dlq; -/// use std::backtrace::Backtrace; -/// use thiserror::Error; -/// -/// #[derive(Error, Debug)] -/// enum ConsumerError { -/// #[error("Deserialization error: {0}")] -/// DeserializeError(String), -/// } -/// -/// impl dlq::ErrorToDlq for ConsumerError { -/// fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq { -/// dlq::SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None) -/// } -/// fn retryable(&self) -> dlq::Retryable { -/// match self { -/// ConsumerError::DeserializeError(e) => dlq::Retryable::NonRetryable, -/// } -/// } -/// } -/// ``` -pub trait ErrorToDlq { - /// Convert error message to a dlq message - fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq; - /// Match error if the orignal message is able to be retried or not - fn retryable(&self) -> Retryable; -} - -/// Struct with required details to send a channel message to the dlq -/// Error needs to be send as string, as it is not possible to send a struct that implements Error trait -pub struct SendToDlq { - kafka_message: OwnedMessage, - retryable: Retryable, - error: String, - stack_trace: Option, -} - -impl SendToDlq { - /// Create new SendToDlq message - pub fn new( - kafka_message: OwnedMessage, - retryable: Retryable, - error: String, - stack_trace: Option, - ) -> Self { - Self { - kafka_message, - retryable, - error, - stack_trace, - } - } - /// Send message to dlq channel - pub async fn send(self, dlq_tx: &mut DlqChannel) { - match dlq_tx.send(self).await { - Ok(_) => debug!("Message sent to DLQ channel"), - Err(e) => error!("Error sending message to DLQ: {}", e), - } - } - - fn get_original_msg(&self) -> OwnedMessage { - self.kafka_message.clone() - } -} - -/// Helper enum to decide to which topic the message should be sent to. -#[derive(Debug, Clone, Copy)] -pub enum Retryable { - Retryable, - NonRetryable, - Other, -} - -impl std::fmt::Display for Retryable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Retryable::Retryable => write!(f, "Retryable"), - Retryable::NonRetryable => write!(f, "NonRetryable"), - Retryable::Other => write!(f, "Other"), - } - } -} - -/// The dead letter queue -/// -/// ## How to use -/// 1. 
Implement the [ErrorToDlq] trait on top your (custom) error type. -/// 2. Use the [Dlq::start] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task) -/// 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method. -/// -/// # Example -/// See full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs) -pub struct Dlq { - dlq_producer: FutureProducer, - dlq_rx: mpsc::Receiver, - dlq_dead_topic: String, - dlq_retry_topic: String, - _shutdown: Shutdown, // hold the shutdown alive until exit -} - -impl Dlq { - /// Start the dlq on a tokio task - /// - /// The DLQ will run until the return `Sender` is dropped. - /// - /// # Arguments - /// * `shutdown` - The shutdown is required to keep the DLQ alive until the DLQ Sender is dropped - /// - /// # Returns - /// * The [DlqChannel] to send messages to the DLQ - /// - /// # Note - /// **NEVER** borrow the [DlqChannel] to your consumer, always use an owned [DlqChannel]. - /// This is required to stop the gracefull shutdown the DLQ as it depends on the [DlqChannel] to be dropped. - /// - /// # Example - /// ```no_run - /// use dsh_sdk::utils::graceful_shutdown::Shutdown; - /// use dsh_sdk::utils::dlq::{Dlq, DlqChannel, SendToDlq}; - /// - /// async fn consume(dlq_channel: DlqChannel) { - /// // Your consumer logic together with error handling - /// loop { - /// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - /// } - /// } - /// - /// #[tokio::main] - /// async fn main() { - /// let shutdown = Shutdown::new(); - /// let dlq_channel = Dlq::start(shutdown.clone()).unwrap(); - /// - /// tokio::select! { - /// _ = async move { - /// // Your consumer logic together with the owned dlq_channel - /// dlq_channel - /// } => {} - /// _ = shutdown.signal_listener() => { - /// println!("Shutting down consumer"); - /// } - /// } - /// // wait for graceful shutdown to complete - /// // NOTE that the `dlq_channel` will go out of scope when shutdown is called and the DLQ will stop - /// shutdown.complete().await; - /// } - /// ``` - pub fn start(shutdown: Shutdown) -> Result> { - let (dlq_tx, dlq_rx) = mpsc::channel(200); - let dlq_producer: FutureProducer = - ClientConfig::new().set_dsh_producer_config().create()?; - let dlq_dead_topic = get_env_var("DLQ_DEAD_TOPIC")?; - let dlq_retry_topic = get_env_var("DLQ_RETRY_TOPIC")?; - let dlq = Self { - dlq_producer, - dlq_rx, - dlq_dead_topic, - dlq_retry_topic, - _shutdown: shutdown, - }; - tokio::spawn(dlq.run()); - Ok(dlq_tx) - } - - /// Run the dlq. 
This will consume messages from the dlq channel and send them to the dlq topics - /// This function will run until the shutdown channel is closed - async fn run(mut self) { - info!("DLQ started"); - loop { - if let Some(mut dlq_message) = self.dlq_rx.recv().await { - match self.send(&mut dlq_message).await { - Ok(_) => {} - Err(e) => error!("Error sending message to DLQ: {}", e), - }; - } else { - warn!("DLQ stopped as there is no active DLQ Channel"); - break; - } - } - } - /// Create and send message towards the dlq - async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), KafkaError> { - let orignal_kafka_msg: OwnedMessage = dlq_message.get_original_msg(); - let headers = orignal_kafka_msg - .generate_dlq_headers(dlq_message) - .to_owned_headers(); - let topic = self.dlq_topic(dlq_message.retryable); - let key: &[u8] = orignal_kafka_msg.key().unwrap_or_default(); - let payload = orignal_kafka_msg.payload().unwrap_or_default(); - debug!("Sending message to DLQ topic: {}", topic); - let record = FutureRecord::to(topic) - .payload(payload) - .key(key) - .headers(headers); - let send = self.dlq_producer.send(record, None).await; - match send { - Ok((p, o)) => warn!( - "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}", - from_utf8(key), - topic, - p, - o - ), - Err((e, _)) => return Err(e), - }; - Ok(()) - } - - fn dlq_topic(&self, retryable: Retryable) -> &str { - match retryable { - Retryable::Retryable => &self.dlq_retry_topic, - Retryable::NonRetryable => &self.dlq_dead_topic, - Retryable::Other => &self.dlq_dead_topic, - } - } -} - -trait DlqHeaders { - fn generate_dlq_headers<'a>( - &'a self, - dlq_message: &'a mut SendToDlq, - ) -> HashMap<&'a str, Option>>; -} - -impl DlqHeaders for OwnedMessage { - fn generate_dlq_headers<'a>( - &'a self, - dlq_message: &'a mut SendToDlq, - ) -> HashMap<&'a str, Option>> { - let mut hashmap_headers: HashMap<&str, Option>> = HashMap::new(); - // Get original headers and add to hashmap - if let Some(headers) = self.headers() { - for header in headers.iter() { - hashmap_headers.insert(header.key, header.value.map(|v| v.to_vec())); - } - } - - // Add dlq headers if not exist (we don't want to overwrite original dlq headers if message already failed earlier) - let partition = self.partition().to_string().as_bytes().to_vec(); - let offset = self.offset().to_string().as_bytes().to_vec(); - let timestamp = self - .timestamp() - .to_millis() - .unwrap_or(-1) - .to_string() - .as_bytes() - .to_vec(); - hashmap_headers - .entry("dlq_topic_origin") - .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); - hashmap_headers - .entry("dlq_partition_origin") - .or_insert_with(move || Some(partition)); - hashmap_headers - .entry("dlq_partition_offset_origin") - .or_insert_with(move || Some(offset)); - hashmap_headers - .entry("dlq_topic_origin") - .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); - hashmap_headers - .entry("dlq_timestamp_origin") - .or_insert_with(move || Some(timestamp)); - // Overwrite if exist - hashmap_headers.insert( - "dlq_retryable", - Some(dlq_message.retryable.to_string().as_bytes().to_vec()), - ); - hashmap_headers.insert( - "dlq_error", - Some(dlq_message.error.to_string().as_bytes().to_vec()), - ); - if let Some(stack_trace) = &dlq_message.stack_trace { - hashmap_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); - } - // update dlq_retries with +1 if exists, else add dlq_retries wiith 1 - let retries = hashmap_headers - .get("dlq_retries") - .map(|v| { - let mut retries = [0; 
4]; - retries.copy_from_slice(v.as_ref().unwrap()); - i32::from_be_bytes(retries) - }) - .unwrap_or(0); - hashmap_headers.insert("dlq_retries", Some((retries + 1).to_be_bytes().to_vec())); - - hashmap_headers - } -} - -trait HashMapToKafkaHeaders { - fn to_owned_headers(&self) -> OwnedHeaders; -} - -impl HashMapToKafkaHeaders for HashMap<&str, Option>> { - fn to_owned_headers(&self) -> OwnedHeaders { - // Convert to OwnedHeaders - let mut owned_headers = OwnedHeaders::new_with_capacity(self.len()); - for header in self { - let value = header.1.as_ref().map(|value| value.as_slice()); - owned_headers = owned_headers.insert(Header { - key: header.0, - value, - }); - } - owned_headers - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rdkafka::config::ClientConfig; - use rdkafka::mocking::MockCluster; - - #[derive(Debug)] - enum MockError { - MockErrorRetryable(String), - MockErrorDead(String), - } - impl MockError { - fn to_string(&self) -> String { - match self { - MockError::MockErrorRetryable(e) => e.to_string(), - MockError::MockErrorDead(e) => e.to_string(), - } - } - } - - impl std::fmt::Display for MockError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MockError::MockErrorRetryable(e) => write!(f, "{}", e), - MockError::MockErrorDead(e) => write!(f, "{}", e), - } - } - } - - impl ErrorToDlq for MockError { - fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq { - let backtrace = "some_backtrace"; - SendToDlq::new( - kafka_message, - self.retryable(), - self.to_string(), - Some(backtrace.to_string()), - ) - } - - fn retryable(&self) -> Retryable { - match self { - MockError::MockErrorRetryable(_) => Retryable::Retryable, - MockError::MockErrorDead(_) => Retryable::NonRetryable, - } - } - } - - #[test] - fn test_dlq_get_original_msg() { - let topic = "original_topic"; - let partition = 0; - let offset = 123; - let timestamp = 456; - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "some_key", - value: Some("some_value".as_bytes()), - }); - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - topic.to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - let dlq_message = - MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone()); - let result = dlq_message.get_original_msg(); - assert_eq!( - result.payload(), - dlq_message.kafka_message.payload(), - "payoad does not match" - ); - assert_eq!( - result.key(), - dlq_message.kafka_message.key(), - "key does not match" - ); - assert_eq!( - result.topic(), - dlq_message.kafka_message.topic(), - "topic does not match" - ); - assert_eq!( - result.partition(), - dlq_message.kafka_message.partition(), - "partition does not match" - ); - assert_eq!( - result.offset(), - dlq_message.kafka_message.offset(), - "offset does not match" - ); - assert_eq!( - result.timestamp(), - dlq_message.kafka_message.timestamp(), - "timestamp does not match" - ); - } - - #[test] - fn test_dlq_hashmap_to_owned_headers() { - let mut hashmap: HashMap<&str, Option>> = HashMap::new(); - hashmap.insert("some_key", Some(b"key_value".to_vec())); - hashmap.insert("some_other_key", None); - let result: Vec<(&str, Option<&[u8]>)> = - vec![("some_key", Some(b"key_value")), ("some_other_key", None)]; - - let owned_headers = hashmap.to_owned_headers(); - for header in owned_headers.iter() { - assert!(result.contains(&(header.key, 
header.value))); - } - } - - #[test] - fn test_dlq_topic() { - let mock_cluster = MockCluster::new(1).unwrap(); - let mut producer = ClientConfig::new(); - producer.set("bootstrap.servers", mock_cluster.bootstrap_servers()); - let producer = producer.create().unwrap(); - let dlq = Dlq { - dlq_producer: producer, - dlq_rx: mpsc::channel(200).1, - dlq_dead_topic: "dead_topic".to_string(), - dlq_retry_topic: "retry_topic".to_string(), - _shutdown: Shutdown::new(), - }; - let error = MockError::MockErrorRetryable("some_error".to_string()); - let topic = dlq.dlq_topic(error.retryable()); - assert_eq!(topic, "retry_topic"); - let error = MockError::MockErrorDead("some_error".to_string()); - let topic = dlq.dlq_topic(error.retryable()); - assert_eq!(topic, "dead_topic"); - } - - #[test] - fn test_dlq_generate_dlq_headers() { - let topic = "original_topic"; - let partition = 0; - let offset = 123; - let timestamp = 456; - let error = Box::new(MockError::MockErrorRetryable("some_error".to_string())); - - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "some_key", - value: Some("some_value".as_bytes()), - }); - - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - topic.to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - - let mut dlq_message = error.to_dlq(owned_message.clone()); - - let mut expected_headers: HashMap<&str, Option>> = HashMap::new(); - expected_headers.insert("some_key", Some(b"some_value".to_vec())); - expected_headers.insert("dlq_topic_origin", Some(topic.as_bytes().to_vec())); - expected_headers.insert( - "dlq_partition_origin", - Some(partition.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_partition_offset_origin", - Some(offset.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_timestamp_origin", - Some(timestamp.to_string().as_bytes().to_vec()), - ); - expected_headers.insert( - "dlq_retryable", - Some(Retryable::Retryable.to_string().as_bytes().to_vec()), - ); - expected_headers.insert("dlq_retries", Some(1_i32.to_be_bytes().to_vec())); - expected_headers.insert("dlq_error", Some(error.to_string().as_bytes().to_vec())); - if let Some(stack_trace) = &dlq_message.stack_trace { - expected_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); - } - - let result = owned_message.generate_dlq_headers(&mut dlq_message); - for header in result.iter() { - assert_eq!( - header.1, - expected_headers.get(header.0).unwrap_or(&None), - "Header {} does not match", - header.0 - ); - } - - // Test if dlq headers are correctly overwritten when to be retried message was already retried before - let mut original_headers: OwnedHeaders = OwnedHeaders::new(); - original_headers = original_headers.insert(Header { - key: "dlq_error", - value: Some( - "to_be_overwritten_error_as_this_was_the_original_error_from_1st_retry".as_bytes(), - ), - }); - original_headers = original_headers.insert(Header { - key: "dlq_topic_origin", - value: Some(topic.as_bytes()), - }); - original_headers = original_headers.insert(Header { - key: "dlq_retries", - value: Some(&1_i32.to_be_bytes().to_vec()), - }); - - let owned_message = OwnedMessage::new( - Some(vec![1, 2, 3]), - Some(vec![4, 5, 6]), - "retry_topic".to_string(), - rdkafka::Timestamp::CreateTime(timestamp), - partition, - offset, - Some(original_headers), - ); - let result = owned_message.generate_dlq_headers(&mut 
dlq_message);
-        assert_eq!(
-            result.get("dlq_error").unwrap(),
-            &Some(error.to_string().as_bytes().to_vec())
-        );
-        assert_eq!(
-            result.get("dlq_topic_origin").unwrap(),
-            &Some(topic.as_bytes().to_vec())
-        );
-        assert_eq!(
-            result.get("dlq_retries").unwrap(),
-            &Some(2_i32.to_be_bytes().to_vec())
-        );
-    }
-}
diff --git a/dsh_sdk/src/utils/dlq/dlq.rs b/dsh_sdk/src/utils/dlq/dlq.rs
new file mode 100644
index 0000000..3937175
--- /dev/null
+++ b/dsh_sdk/src/utils/dlq/dlq.rs
@@ -0,0 +1,237 @@
+//! Dead Letter Queue client
+
+use std::str::from_utf8;
+
+use log::{debug, error, info, warn};
+use rdkafka::client::DefaultClientContext;
+use rdkafka::error::KafkaError;
+use rdkafka::message::{Message, OwnedMessage};
+use rdkafka::producer::{FutureProducer, FutureRecord};
+use rdkafka::ClientConfig;
+use tokio::sync::mpsc;
+
+use super::headers::{DlqHeaders, HashMapToKafkaHeaders};
+
+use super::{DlqChannel, DlqErrror, Retryable, SendToDlq};
+use crate::utils::get_env_var;
+use crate::utils::graceful_shutdown::Shutdown;
+use crate::DshKafkaConfig;
+
+/// The dead letter queue
+///
+/// ## How to use
+/// 1. Implement the [ErrorToDlq](super::ErrorToDlq) trait on top of your (custom) error type.
+/// 2. Use [Dlq::start] in your main function or at the start of your process logic (this will start the DLQ in a separate tokio task).
+/// 3. Get the [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq](super::ErrorToDlq::to_dlq) method.
+///
+/// # Example
+/// See the full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs)
+pub struct Dlq {
+    dlq_producer: FutureProducer,
+    dlq_rx: mpsc::Receiver<SendToDlq>,
+    dlq_dead_topic: String,
+    dlq_retry_topic: String,
+    _shutdown: Shutdown, // hold the shutdown alive until exit
+}
+
+impl Dlq {
+    /// Start the dlq on a tokio task
+    ///
+    /// The DLQ will run until the returned `Sender` is dropped.
+    ///
+    /// # Arguments
+    /// * `shutdown` - The shutdown is required to keep the DLQ alive until the DLQ Sender is dropped
+    ///
+    /// # Returns
+    /// * The [DlqChannel] to send messages to the DLQ
+    ///
+    /// # Note
+    /// **NEVER** pass a borrowed [DlqChannel] to your consumer; always use an owned [DlqChannel].
+    /// This is required to gracefully shut down the DLQ, as it depends on the [DlqChannel] being dropped.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// use dsh_sdk::utils::dlq::{Dlq, DlqChannel, SendToDlq};
+    ///
+    /// async fn consume(dlq_channel: DlqChannel) {
+    ///     // Your consumer logic together with error handling
+    ///     loop {
+    ///         tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+    ///     }
+    /// }
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let shutdown = Shutdown::new();
+    ///     let dlq_channel = Dlq::start(shutdown.clone()).unwrap();
+    ///
+    ///     tokio::select! {
+    ///         _ = async move {
+    ///             // Your consumer logic together with the owned dlq_channel
+    ///             dlq_channel
+    ///         } => {}
+    ///         _ = shutdown.signal_listener() => {
+    ///             println!("Shutting down consumer");
+    ///         }
+    ///     }
+    ///     // wait for graceful shutdown to complete
+    ///     // NOTE that the `dlq_channel` will go out of scope when shutdown is called and the DLQ will stop
+    ///     shutdown.complete().await;
+    /// }
+    /// ```
+    pub fn start(shutdown: Shutdown) -> Result<DlqChannel, DlqErrror> {
+        let (dlq_tx, dlq_rx) = mpsc::channel(200);
+        let dlq_producer: FutureProducer =
+            ClientConfig::new().set_dsh_producer_config().create()?;
+        let dlq_dead_topic = get_env_var("DLQ_DEAD_TOPIC")?;
+        let dlq_retry_topic = get_env_var("DLQ_RETRY_TOPIC")?;
+        let dlq = Self {
+            dlq_producer,
+            dlq_rx,
+            dlq_dead_topic,
+            dlq_retry_topic,
+            _shutdown: shutdown,
+        };
+        tokio::spawn(dlq.run());
+        Ok(dlq_tx)
+    }
+
+    /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics.
+    /// This function will run until the shutdown channel is closed.
+    async fn run(mut self) {
+        info!("DLQ started");
+        loop {
+            if let Some(mut dlq_message) = self.dlq_rx.recv().await {
+                match self.send(&mut dlq_message).await {
+                    Ok(_) => {}
+                    Err(e) => error!("Error sending message to DLQ: {}", e),
+                };
+            } else {
+                warn!("DLQ stopped as there is no active DLQ Channel");
+                break;
+            }
+        }
+    }
+    /// Create and send a message towards the dlq
+    async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), KafkaError> {
+        let original_kafka_msg: OwnedMessage = dlq_message.get_original_msg();
+        let headers = original_kafka_msg
+            .generate_dlq_headers(dlq_message)
+            .to_owned_headers();
+        let topic = self.dlq_topic(dlq_message.retryable);
+        let key: &[u8] = original_kafka_msg.key().unwrap_or_default();
+        let payload = original_kafka_msg.payload().unwrap_or_default();
+        debug!("Sending message to DLQ topic: {}", topic);
+        let record = FutureRecord::to(topic)
+            .payload(payload)
+            .key(key)
+            .headers(headers);
+        let send = self.dlq_producer.send(record, None).await;
+        match send {
+            Ok((p, o)) => warn!(
+                "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}",
+                from_utf8(key),
+                topic,
+                p,
+                o
+            ),
+            Err((e, _)) => return Err(e),
+        };
+        Ok(())
+    }
+
+    fn dlq_topic(&self, retryable: Retryable) -> &str {
+        match retryable {
+            Retryable::Retryable => &self.dlq_retry_topic,
+            Retryable::NonRetryable => &self.dlq_dead_topic,
+            Retryable::Other => &self.dlq_dead_topic,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::utils::dlq::tests::MockError;
+    use crate::utils::dlq::types::*;
+    use rdkafka::config::ClientConfig;
+    use rdkafka::message::{Header, OwnedHeaders};
+    use rdkafka::mocking::MockCluster;
+
+    #[test]
+    fn test_dlq_topic() {
+        let mock_cluster = MockCluster::new(1).unwrap();
+        let mut producer = ClientConfig::new();
+        producer.set("bootstrap.servers", mock_cluster.bootstrap_servers());
+        let producer = producer.create().unwrap();
+        let dlq = Dlq {
+            dlq_producer: producer,
+            dlq_rx: mpsc::channel(200).1,
+            dlq_dead_topic: "dead_topic".to_string(),
+            dlq_retry_topic: "retry_topic".to_string(),
+            _shutdown: Shutdown::new(),
+        };
+        let error = MockError::MockErrorRetryable("some_error".to_string());
+        let topic = dlq.dlq_topic(error.retryable());
+        assert_eq!(topic, "retry_topic");
+        let error = MockError::MockErrorDead("some_error".to_string());
+        let topic = dlq.dlq_topic(error.retryable());
+        assert_eq!(topic, "dead_topic");
+    }
+
+    #[test]
+    fn test_dlq_get_original_msg() {
+        let topic =
"original_topic"; + let partition = 0; + let offset = 123; + let timestamp = 456; + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "some_key", + value: Some("some_value".as_bytes()), + }); + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + topic.to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + let dlq_message = + MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone()); + let result = dlq_message.get_original_msg(); + assert_eq!( + result.payload(), + dlq_message.kafka_message.payload(), + "payoad does not match" + ); + assert_eq!( + result.key(), + dlq_message.kafka_message.key(), + "key does not match" + ); + assert_eq!( + result.topic(), + dlq_message.kafka_message.topic(), + "topic does not match" + ); + assert_eq!( + result.partition(), + dlq_message.kafka_message.partition(), + "partition does not match" + ); + assert_eq!( + result.offset(), + dlq_message.kafka_message.offset(), + "offset does not match" + ); + assert_eq!( + result.timestamp(), + dlq_message.kafka_message.timestamp(), + "timestamp does not match" + ); + } +} diff --git a/dsh_sdk/src/utils/dlq/error.rs b/dsh_sdk/src/utils/dlq/error.rs new file mode 100644 index 0000000..1b8e157 --- /dev/null +++ b/dsh_sdk/src/utils/dlq/error.rs @@ -0,0 +1,10 @@ +/// Related errors to the Dead Letter Queue +#[derive(Debug, thiserror::Error)] +pub enum DlqErrror { + #[error("Kafka Error: {0}")] + Kafka(#[from] rdkafka::error::KafkaError), + #[error("DSH Error: {0}")] + Dsh(#[from] crate::error::DshError), + #[error("Utils Error: {0}")] + Utils(#[from] crate::utils::UtilsError), +} diff --git a/dsh_sdk/src/utils/dlq/headers.rs b/dsh_sdk/src/utils/dlq/headers.rs new file mode 100644 index 0000000..548198f --- /dev/null +++ b/dsh_sdk/src/utils/dlq/headers.rs @@ -0,0 +1,221 @@ +//! 
Add dead letter queue metadata to the kafka headers + +use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage}; +use std::collections::HashMap; + +use super::SendToDlq; + +pub trait DlqHeaders { + fn generate_dlq_headers<'a>( + &'a self, + dlq_message: &'a mut SendToDlq, + ) -> HashMap<&'a str, Option>>; +} + +impl DlqHeaders for OwnedMessage { + fn generate_dlq_headers<'a>( + &'a self, + dlq_message: &'a mut SendToDlq, + ) -> HashMap<&'a str, Option>> { + let mut hashmap_headers: HashMap<&str, Option>> = HashMap::new(); + // Get original headers and add to hashmap + if let Some(headers) = self.headers() { + for header in headers.iter() { + hashmap_headers.insert(header.key, header.value.map(|v| v.to_vec())); + } + } + + // Add dlq headers if not exist (we don't want to overwrite original dlq headers if message already failed earlier) + let partition = self.partition().to_string().as_bytes().to_vec(); + let offset = self.offset().to_string().as_bytes().to_vec(); + let timestamp = self + .timestamp() + .to_millis() + .unwrap_or(-1) + .to_string() + .as_bytes() + .to_vec(); + hashmap_headers + .entry("dlq_topic_origin") + .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); + hashmap_headers + .entry("dlq_partition_origin") + .or_insert_with(move || Some(partition)); + hashmap_headers + .entry("dlq_partition_offset_origin") + .or_insert_with(move || Some(offset)); + hashmap_headers + .entry("dlq_topic_origin") + .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); + hashmap_headers + .entry("dlq_timestamp_origin") + .or_insert_with(move || Some(timestamp)); + // Overwrite if exist + hashmap_headers.insert( + "dlq_retryable", + Some(dlq_message.retryable.to_string().as_bytes().to_vec()), + ); + hashmap_headers.insert( + "dlq_error", + Some(dlq_message.error.to_string().as_bytes().to_vec()), + ); + if let Some(stack_trace) = &dlq_message.stack_trace { + hashmap_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); + } + // update dlq_retries with +1 if exists, else add dlq_retries wiith 1 + let retries = hashmap_headers + .get("dlq_retries") + .map(|v| { + let mut retries = [0; 4]; + retries.copy_from_slice(v.as_ref().unwrap()); + i32::from_be_bytes(retries) + }) + .unwrap_or(0); + hashmap_headers.insert("dlq_retries", Some((retries + 1).to_be_bytes().to_vec())); + + hashmap_headers + } +} + +pub trait HashMapToKafkaHeaders { + fn to_owned_headers(&self) -> OwnedHeaders; +} + +impl HashMapToKafkaHeaders for HashMap<&str, Option>> { + fn to_owned_headers(&self) -> OwnedHeaders { + // Convert to OwnedHeaders + let mut owned_headers = OwnedHeaders::new_with_capacity(self.len()); + for header in self { + let value = header.1.as_ref().map(|value| value.as_slice()); + owned_headers = owned_headers.insert(Header { + key: header.0, + value, + }); + } + owned_headers + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::dlq::tests::MockError; + use crate::utils::dlq::types::*; + use rdkafka::message::OwnedMessage; + + #[test] + fn test_dlq_generate_dlq_headers() { + let topic = "original_topic"; + let partition = 0; + let offset = 123; + let timestamp = 456; + let error = Box::new(MockError::MockErrorRetryable("some_error".to_string())); + + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "some_key", + value: Some("some_value".as_bytes()), + }); + + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + topic.to_string(), + 
rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + + let mut dlq_message = error.to_dlq(owned_message.clone()); + + let mut expected_headers: HashMap<&str, Option>> = HashMap::new(); + expected_headers.insert("some_key", Some(b"some_value".to_vec())); + expected_headers.insert("dlq_topic_origin", Some(topic.as_bytes().to_vec())); + expected_headers.insert( + "dlq_partition_origin", + Some(partition.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_partition_offset_origin", + Some(offset.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_timestamp_origin", + Some(timestamp.to_string().as_bytes().to_vec()), + ); + expected_headers.insert( + "dlq_retryable", + Some(Retryable::Retryable.to_string().as_bytes().to_vec()), + ); + expected_headers.insert("dlq_retries", Some(1_i32.to_be_bytes().to_vec())); + expected_headers.insert("dlq_error", Some(error.to_string().as_bytes().to_vec())); + if let Some(stack_trace) = &dlq_message.stack_trace { + expected_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); + } + + let result = owned_message.generate_dlq_headers(&mut dlq_message); + for header in result.iter() { + assert_eq!( + header.1, + expected_headers.get(header.0).unwrap_or(&None), + "Header {} does not match", + header.0 + ); + } + + // Test if dlq headers are correctly overwritten when to be retried message was already retried before + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "dlq_error", + value: Some( + "to_be_overwritten_error_as_this_was_the_original_error_from_1st_retry".as_bytes(), + ), + }); + original_headers = original_headers.insert(Header { + key: "dlq_topic_origin", + value: Some(topic.as_bytes()), + }); + original_headers = original_headers.insert(Header { + key: "dlq_retries", + value: Some(&1_i32.to_be_bytes().to_vec()), + }); + + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + "retry_topic".to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + let result = owned_message.generate_dlq_headers(&mut dlq_message); + assert_eq!( + result.get("dlq_error").unwrap(), + &Some(error.to_string().as_bytes().to_vec()) + ); + assert_eq!( + result.get("dlq_topic_origin").unwrap(), + &Some(topic.as_bytes().to_vec()) + ); + assert_eq!( + result.get("dlq_retries").unwrap(), + &Some(2_i32.to_be_bytes().to_vec()) + ); + } + + #[test] + fn test_dlq_hashmap_to_owned_headers() { + let mut hashmap: HashMap<&str, Option>> = HashMap::new(); + hashmap.insert("some_key", Some(b"key_value".to_vec())); + hashmap.insert("some_other_key", None); + let result: Vec<(&str, Option<&[u8]>)> = + vec![("some_key", Some(b"key_value")), ("some_other_key", None)]; + + let owned_headers = hashmap.to_owned_headers(); + for header in owned_headers.iter() { + assert!(result.contains(&(header.key, header.value))); + } + } +} diff --git a/dsh_sdk/src/utils/dlq/mod.rs b/dsh_sdk/src/utils/dlq/mod.rs new file mode 100644 index 0000000..4ee9c88 --- /dev/null +++ b/dsh_sdk/src/utils/dlq/mod.rs @@ -0,0 +1,76 @@ +//! # Dead Letter Queue +//! This optional module contains an implementation of pushing unprocessable/invalid messages towards a Dead Letter Queue (DLQ). +//! It is implemeted with [rdkafka] and [tokio]. +//! +//! ## Feature flag +//! Add feature `dlq` to your Cargo.toml to enable this module. +//! +//! ### NOTE: +//! 
This module is meant for pushing messages towards a dead/retry topic only; it does not and WILL NOT handle any logic for retrying messages.
+//! The reason is that the strategy needed to retry messages and handle the dead letters differs per use case.
+//!
+//! It is up to the user to implement the strategy and logic for retrying messages.
+//!
+//! ## How to use
+//! 1. Implement the [ErrorToDlq] trait on top of your (custom) error type.
+//! 2. Use [Dlq::start] in your main function or at the start of your process logic (this will start the DLQ in a separate tokio task).
+//! 3. Get the [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method, which is implemented on your error.
+//!
+//! The topics are set via the environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`.
+//!
+//! ### Example:
+//! See the sketch below this diff and the full implementation example in [dlq_implementation.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs).
+mod dlq;
+mod error;
+mod headers;
+mod types;
+
+#[doc(inline)]
+pub use dlq::Dlq;
+#[doc(inline)]
+pub use error::DlqErrror;
+#[doc(inline)]
+pub use types::*;
+/// Channel to send messages to the dead letter queue
+pub type DlqChannel = tokio::sync::mpsc::Sender<SendToDlq>;
+
+// Mock error available in tests
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rdkafka::message::OwnedMessage;
+
+    #[derive(Debug)]
+    pub enum MockError {
+        MockErrorRetryable(String),
+        MockErrorDead(String),
+    }
+
+    impl std::fmt::Display for MockError {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            match self {
+                MockError::MockErrorRetryable(e) => write!(f, "{}", e),
+                MockError::MockErrorDead(e) => write!(f, "{}", e),
+            }
+        }
+    }
+
+    impl ErrorToDlq for MockError {
+        fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq {
+            let backtrace = "some_backtrace";
+            SendToDlq::new(
+                kafka_message,
+                self.retryable(),
+                self.to_string(),
+                Some(backtrace.to_string()),
+            )
+        }
+
+        fn retryable(&self) -> Retryable {
+            match self {
+                MockError::MockErrorRetryable(_) => Retryable::Retryable,
+                MockError::MockErrorDead(_) => Retryable::NonRetryable,
+            }
+        }
+    }
+}
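The example heading in the module docs above has no body; the sketch below shows the wiring they describe end to end, under stated assumptions: `ConsumerError` is a hypothetical error type introduced for illustration, and `DLQ_DEAD_TOPIC`/`DLQ_RETRY_TOPIC` must be set in the environment as noted above.

```rust
use dsh_sdk::utils::dlq::{Dlq, DlqChannel, ErrorToDlq, Retryable, SendToDlq};
use dsh_sdk::utils::graceful_shutdown::Shutdown;
use rdkafka::message::OwnedMessage;

// Hypothetical consumer error; any error type can opt in via ErrorToDlq.
#[derive(Debug, thiserror::Error)]
enum ConsumerError {
    #[error("Deserialization error: {0}")]
    Deserialize(String),
}

impl ErrorToDlq for ConsumerError {
    fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq {
        SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None)
    }
    fn retryable(&self) -> Retryable {
        // A failed deserialization will not succeed on retry.
        Retryable::NonRetryable
    }
}

// The consumer owns the DlqChannel; dropping it stops the DLQ task.
async fn consume(mut dlq_channel: DlqChannel) {
    // ... on a failed message:
    // ConsumerError::Deserialize("bad payload".into())
    //     .to_dlq(owned_message)
    //     .send(&mut dlq_channel)
    //     .await;
}

#[tokio::main]
async fn main() {
    let shutdown = Shutdown::new();
    // Spawns the DLQ task; reads DLQ_DEAD_TOPIC and DLQ_RETRY_TOPIC.
    let dlq_channel = Dlq::start(shutdown.clone()).unwrap();
    tokio::select! {
        _ = consume(dlq_channel) => {}
        _ = shutdown.signal_listener() => {}
    }
    shutdown.complete().await;
}
```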
diff --git a/dsh_sdk/src/utils/dlq/types.rs b/dsh_sdk/src/utils/dlq/types.rs
new file mode 100644
index 0000000..f31c983
--- /dev/null
+++ b/dsh_sdk/src/utils/dlq/types.rs
@@ -0,0 +1,91 @@
+use log::{debug, error};
+use rdkafka::message::OwnedMessage;
+
+use super::DlqChannel;
+
+/// Trait to convert an error to a dlq message
+/// This trait is implemented for all errors that can and should be converted to a dlq message
+///
+/// Example:
+///```
+/// use dsh_sdk::dlq;
+/// use std::backtrace::Backtrace;
+/// use thiserror::Error;
+///
+/// #[derive(Error, Debug)]
+/// enum ConsumerError {
+///     #[error("Deserialization error: {0}")]
+///     DeserializeError(String),
+/// }
+///
+/// impl dlq::ErrorToDlq for ConsumerError {
+///     fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq {
+///         dlq::SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None)
+///     }
+///     fn retryable(&self) -> dlq::Retryable {
+///         match self {
+///             ConsumerError::DeserializeError(e) => dlq::Retryable::NonRetryable,
+///         }
+///     }
+/// }
+/// ```
+pub trait ErrorToDlq {
+    /// Convert an error message to a dlq message
+    fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq;
+    /// Match whether the original message is able to be retried
+    fn retryable(&self) -> Retryable;
+}
+
+/// DLQ message that can be sent to the [DlqChannel]
+pub struct SendToDlq {
+    pub kafka_message: OwnedMessage,
+    pub retryable: Retryable,
+    pub error: String,
+    pub stack_trace: Option<String>,
+}
+
+impl SendToDlq {
+    /// Create new SendToDlq message
+    pub fn new(
+        kafka_message: OwnedMessage,
+        retryable: Retryable,
+        error: String,
+        stack_trace: Option<String>,
+    ) -> Self {
+        Self {
+            kafka_message,
+            retryable,
+            error,
+            stack_trace,
+        }
+    }
+    /// Send message to dlq channel
+    pub async fn send(self, dlq_tx: &mut DlqChannel) {
+        match dlq_tx.send(self).await {
+            Ok(_) => debug!("Message sent to DLQ channel"),
+            Err(e) => error!("Error sending message to DLQ: {}", e),
+        }
+    }
+
+    pub(crate) fn get_original_msg(&self) -> OwnedMessage {
+        self.kafka_message.clone()
+    }
+}
+
+/// Helper enum to decide to which topic the message should be sent.
+#[derive(Debug, Clone, Copy)]
+pub enum Retryable {
+    Retryable,
+    NonRetryable,
+    Other,
+}
+
+impl std::fmt::Display for Retryable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Retryable::Retryable => write!(f, "Retryable"),
+            Retryable::NonRetryable => write!(f, "NonRetryable"),
+            Retryable::Other => write!(f, "Other"),
+        }
+    }
+}
diff --git a/dsh_sdk/src/utils/error.rs b/dsh_sdk/src/utils/error.rs
new file mode 100644
index 0000000..226f54d
--- /dev/null
+++ b/dsh_sdk/src/utils/error.rs
@@ -0,0 +1,8 @@
+/// Error type for the utils module
+#[derive(Debug, thiserror::Error)]
+pub enum UtilsError {
+    #[error("Env variable {0} error: {1}")]
+    EnvVarError(&'static str, std::env::VarError),
+    #[error("No tenant name found")]
+    NoTenantName,
+}
diff --git a/dsh_sdk/src/utils/graceful_shutdown.rs b/dsh_sdk/src/utils/graceful_shutdown.rs
index 0354bf1..ff83279 100644
--- a/dsh_sdk/src/utils/graceful_shutdown.rs
+++ b/dsh_sdk/src/utils/graceful_shutdown.rs
@@ -12,7 +12,7 @@
 //! # Example:
 //!
 //! ```no_run
-//! use dsh_sdk::graceful_shutdown::Shutdown;
+//! use dsh_sdk::utils::graceful_shutdown::Shutdown;
 //!
 //! // your process task
 //! async fn process_task(shutdown: Shutdown) {
@@ -34,7 +34,7 @@
 //! #[tokio::main]
 //! async fn main() {
 //!     // Create shutdown handle
-//!     let shutdown = dsh_sdk::graceful_shutdown::Shutdown::new();
+//!     let shutdown = Shutdown::new();
 //!     // Create your process task with a cloned shutdown handle
 //!     let process_task = process_task(shutdown.clone());
 //!     // Spawn your process task in a tokio runtime
diff --git a/dsh_sdk/src/utils/metrics.rs b/dsh_sdk/src/utils/metrics.rs
index 5c05587..7ba6c35 100644
--- a/dsh_sdk/src/utils/metrics.rs
+++ b/dsh_sdk/src/utils/metrics.rs
@@ -12,7 +12,7 @@
 //!
 //! ### Example
 //! ```
-//! use dsh_sdk::metrics::*;
+//! use dsh_sdk::utils::metrics::*;
 //!
 //! lazy_static! {
 //!     pub static ref HIGH_FIVE_COUNTER: IntCounter =
@@ -28,9 +28,9 @@
 //!
 //! ### Example:
 //! ```
-//! use dsh_sdk::metrics::start_http_server;
-//!#[tokio::main]
-//!async fn main() {
+//! use dsh_sdk::utils::metrics::start_http_server;
+//! #[tokio::main]
+//! async fn main() {
 //!     start_http_server(9090);
 //!}
 //! ```
@@ -61,16 +61,26 @@
 use hyper_util::rt::TokioIo;
 pub use lazy_static::lazy_static;
 use log::{error, warn};
 pub use prometheus::*;
+use thiserror::Error;
 use tokio::net::TcpListener;
 use tokio::task::JoinHandle;
 
-use crate::error::DshError;
-
-type DshResult<T> = std::result::Result<T, DshError>;
+type MetricsResult<T> = std::result::Result<T, MetricsError>;
 
 type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
 
 static NOTFOUND: &[u8] = b"404: Not Found";
 
+#[derive(Error, Debug)]
+#[non_exhaustive]
+pub enum MetricsError {
+    #[error("IO Error: {0}")]
+    IoError(#[from] std::io::Error),
+    #[error("Hyper error: {0}")]
+    HyperError(#[from] hyper::http::Error),
+    #[error("Prometheus error: {0}")]
+    Prometheus(#[from] prometheus::Error),
+}
+
 /// Start a http server to expose prometheus metrics.
 ///
 /// The exposed endpoint is /metrics and port number needs to be defined. The server will run on a separate thread
 /// and this function will return a JoinHandle of the thread. It is optional to handle the thread status. If left unhandled,
 /// the server will run until the main thread is stopped.
 ///
 /// # Note!
 /// Don't forget to expose the port in your dockerfile and add the port number to the DSH service configuration.
 ///```Dockerfile
 /// EXPOSE 9090
 /// ```
 ///
 /// # Example
 /// This starts a http server on port 9090 on a separate thread. The server will run until the main thread is stopped.
 /// ```rust
-/// use dsh_sdk::metrics::start_http_server;
+/// use dsh_sdk::utils::metrics::start_http_server;
 ///
 /// #[tokio::main]
 /// async fn main() {
 ///     start_http_server(9090);
@@ -119,7 +129,7 @@
 ///     }
 /// }
 /// ```
-pub fn start_http_server(port: u16) -> JoinHandle<DshResult<()>> {
+pub fn start_http_server(port: u16) -> JoinHandle<MetricsResult<()>> {
     tokio::spawn(async move {
         let result = run_server(port).await;
         warn!("HTTP server stopped: {:?}", result);
@@ -128,15 +138,15 @@
 }
 
 /// Encode metrics to a string (UTF8)
-pub fn metrics_to_string() -> DshResult<String> {
+pub fn metrics_to_string() -> MetricsResult<String> {
     let encoder = prometheus::TextEncoder::new();
 
     let mut buffer = Vec::new();
     encoder.encode(&prometheus::gather(), &mut buffer)?;
-    Ok(String::from_utf8(buffer)?)
+    Ok(String::from_utf8_lossy(&buffer).into_owned())
 }
 
-async fn run_server(port: u16) -> DshResult<()> {
+async fn run_server(port: u16) -> MetricsResult<()> {
     let addr = SocketAddr::from(([0, 0, 0, 0], port));
     let listener = TcpListener::bind(addr).await?;
 
@@ -155,21 +165,21 @@
     }
 }
 
-async fn routes(req: Request<hyper::body::Incoming>) -> DshResult<Response<BoxBody>> {
+async fn routes(req: Request<hyper::body::Incoming>) -> MetricsResult<Response<BoxBody>> {
     match (req.method(), req.uri().path()) {
         (&Method::GET, "/metrics") => get_metrics(),
         (_, _) => not_found(),
     }
 }
 
-fn get_metrics() -> DshResult<Response<BoxBody>> {
+fn get_metrics() -> MetricsResult<Response<BoxBody>> {
     Ok(Response::builder()
         .status(StatusCode::OK)
         .header(header::CONTENT_TYPE, prometheus::TEXT_FORMAT)
         .body(full(metrics_to_string().unwrap_or_default()))?)
 }
 
-fn not_found() -> DshResult<Response<BoxBody>> {
+fn not_found() -> MetricsResult<Response<BoxBody>> {
     Ok(Response::builder()
         .status(StatusCode::NOT_FOUND)
         .body(full(NOTFOUND))?)
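Taken together, the hunks above give the following usage shape at this point in the series (before the later refactor removes the re-exports): a minimal sketch, assuming the `prometheus`/`lazy_static` re-exports shown in this diff; the counter name `highfives` and port `9090` follow the module's own doc examples.

```rust
use dsh_sdk::utils::metrics::*; // at this stage re-exports `prometheus` and `lazy_static`

lazy_static! {
    pub static ref HIGH_FIVE_COUNTER: IntCounter =
        register_int_counter!("highfives", "Number of high fives received").unwrap();
}

#[tokio::main]
async fn main() {
    // Serves GET /metrics on 0.0.0.0:9090 from a background tokio task.
    let _server = start_http_server(9090);
    HIGH_FIVE_COUNTER.inc();
    // Gathers the default prometheus registry and encodes it as text.
    println!("{}", metrics_to_string().unwrap_or_default());
}
```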
diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs index d364f8d..7023c79 100644 --- a/dsh_sdk/src/utils/mod.rs +++ b/dsh_sdk/src/utils/mod.rs @@ -6,7 +6,9 @@ use std::env; use log::{debug, info, warn}; use super::{VAR_APP_ID, VAR_DSH_TENANT_NAME}; -use crate::error::DshError; + +#[doc(inline)] +pub use error::UtilsError; #[cfg(feature = "dlq")] pub mod dlq; @@ -17,6 +19,8 @@ pub(crate) mod http_client; #[cfg(feature = "metrics")] pub mod metrics; +mod error; + /// Available DSH platforms plus it's related metadata /// /// The platform enum contains @@ -56,9 +60,9 @@ impl Platform { /// ``` pub fn rest_client_id(&self, tenant: T) -> String where - T: AsRef + std::fmt::Display, + T: AsRef, { - format!("robot:{}:{}", self.realm(), tenant) + format!("robot:{}:{}", self.realm(), tenant.as_ref()) } /// Get the endpoint for the DSH Rest API @@ -174,7 +178,7 @@ impl Platform { /// assert_eq!(topics[2], "topic3"); /// # std::env::remove_var("TOPICS"); /// ``` -pub fn get_configured_topics() -> Result, DshError> { +pub fn get_configured_topics() -> Result, UtilsError> { let kafka_topic_string = get_env_var("TOPICS")?; Ok(kafka_topic_string .split(',') @@ -191,7 +195,7 @@ pub fn get_configured_topics() -> Result, DshError> { /// ## Example /// ``` /// # use dsh_sdk::utils::tenant_name; -/// # use dsh_sdk::error::DshError; +/// # use dsh_sdk::utils::UtilsError; /// std::env::set_var("MARATHON_APP_ID", "/dsh-tenant-name/app-name"); // Injected by DSH by default /// /// let tenant = tenant_name().unwrap(); @@ -205,10 +209,10 @@ pub fn get_configured_topics() -> Result, DshError> { /// /// // If neither of the environment variables are set, it will return an error /// let result = tenant_name(); -/// assert!(matches!(result, Err(DshError::NoTenantName))); +/// assert!(matches!(result, Err(UtilsError::NoTenantName))); /// ``` -pub fn tenant_name() -> Result { +pub fn tenant_name() -> Result { if let Ok(app_id) = get_env_var(VAR_APP_ID) { let tenant_name = app_id.split('/').nth(1); match tenant_name { @@ -224,8 +228,8 @@ pub fn tenant_name() -> Result { } else if let Ok(tenant_name) = get_env_var(VAR_DSH_TENANT_NAME) { Ok(tenant_name) } else { - log::error!("{} and {} are not set, this may cause unexpected behaviour when connecting to DSH Kafka cluster!. Please set one of these environment variables.", VAR_DSH_TENANT_NAME, VAR_APP_ID); - Err(DshError::NoTenantName) + log::warn!("{} and {} are not set, this may cause unexpected behaviour when connecting to DSH Kafka cluster as the group ID is based on this!. Please set one of these environment variables.", VAR_DSH_TENANT_NAME, VAR_APP_ID); + Err(UtilsError::NoTenantName) } } @@ -233,13 +237,13 @@ pub fn tenant_name() -> Result { /// /// Returns the value of the environment variable if it is set, otherwise returns /// `VarError` error. 
-pub(crate) fn get_env_var(var_name: &str) -> Result { +pub(crate) fn get_env_var(var_name: &'static str) -> Result { debug!("Reading {} from environment variable", var_name); match env::var(var_name) { Ok(value) => Ok(value), Err(e) => { info!("{} is not set", var_name); - Err(DshError::EnvVarError(var_name.to_string(), e)) + Err(UtilsError::EnvVarError(var_name, e)) } } } @@ -276,7 +280,7 @@ mod tests { #[serial(env_dependency)] fn test_dsh_config_tenant_name() { let result = tenant_name(); - assert!(matches!(result, Err(DshError::NoTenantName))); + assert!(matches!(result, Err(UtilsError::NoTenantName))); env::set_var(VAR_APP_ID, "/parsed-tenant-name/app-name"); let result = tenant_name().unwrap(); assert_eq!(result, "parsed-tenant-name".to_string()); From 19403ab1e9cff9e55b25fc7e7f8fee3a8b2d4c9c Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:59:02 +0100 Subject: [PATCH 09/23] Feature/remove metrics (#106) * remove re-exports of prometheus and lazy_static * remove duplicated tests * fix --- dsh_sdk/CHANGELOG.md | 3 + dsh_sdk/Cargo.toml | 6 +- dsh_sdk/README.md | 2 +- dsh_sdk/examples/custom_metrics.rs | 17 -- dsh_sdk/examples/expose_metrics.rs | 30 +- dsh_sdk/src/dsh.rs | 2 +- dsh_sdk/src/error.rs | 2 - dsh_sdk/src/lib.rs | 2 +- dsh_sdk/src/metrics.rs | 355 +++++++--------------- dsh_sdk/src/utils/metrics.rs | 231 +++++++------- example_dsh_service/Cargo.toml | 1 + example_dsh_service/src/custom_metrics.rs | 20 +- example_dsh_service/src/main.rs | 4 +- 13 files changed, 278 insertions(+), 397 deletions(-) delete mode 100644 dsh_sdk/examples/custom_metrics.rs diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index 0620dfe..ecbe9d3 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed - Removed `dsh_sdk::rdkafka` public re-export, import `rdkafka` directly - **NOTE** Feature-flag `rdkafka-ssl` and `rdkafka-ssl-vendored` are removed! 
+- Removed re-export of `prometheus` and `lazy_static` in `metrics` module, if needed import them directly + - **NOTE** See [examples](./examples/expose_metrics.rs) how to use the http server + - Removed `Default` trait for `Dsh` (original `Properties`) struct as this should be public ### Fixed diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 14ae05e..9ff59f8 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -25,10 +25,8 @@ hyper-rustls = { version = "0.27",features = ["ring","http1", "native-tokio", "l http = { version = "1.2", optional = true } rustls = { version = "0.23", features = ["ring", "tls12", "logging"], default-features = false, optional = true } rustls-pemfile = { version = "2.2", optional = true } -lazy_static = { version = "1.5", optional = true } log = "0.4" pem = {version = "3", optional = true } -prometheus = { version = "0.13", features = ["process"], optional = true } protofish = { version = "0.5.2", optional = true } rcgen = { version = "0.13", optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "blocking"], optional = true } @@ -51,7 +49,7 @@ schema-store = ["bootstrap", "reqwest", "serde_json", "apache-avro", "protofish" graceful-shutdown = ["tokio", "tokio-util"] management-api-token-fetcher = ["reqwest"] protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] -metrics = ["prometheus", "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "lazy_static", "tokio", "bytes"] +metrics = [ "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "tokio", "bytes"] dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/ssl-vendored", "rdkafka/libz", "rdkafka/tokio", "graceful-shutdown"] # http-protocol-adapter = ["protocol-token-fetcher"] @@ -69,3 +67,5 @@ dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.3.0" } dsh_sdk = { features = ["dlq"], path = "." 
} env_logger = "0.11" rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"], default-features = true } +lazy_static = { version = "1.5" } +prometheus = { version = "0.13", features = ["process"] } \ No newline at end of file diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index 9a18035..afe2feb 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -67,7 +67,7 @@ The following features are available in this library and can be enabled/disabled | `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) | | `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) | | `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](./examples/management_api_token_fetcher.rs) | -| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs) / [Custom metrics](./examples/custom_metrics.rs) | +| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs)| | `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) | | `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) | diff --git a/dsh_sdk/examples/custom_metrics.rs b/dsh_sdk/examples/custom_metrics.rs deleted file mode 100644 index 5874ade..0000000 --- a/dsh_sdk/examples/custom_metrics.rs +++ /dev/null @@ -1,17 +0,0 @@ -use dsh_sdk::utils::metrics::*; - -lazy_static! { - pub static ref HIGH_FIVE_COUNTER: IntCounter = - register_int_counter!("highfives", "Number of high fives recieved").unwrap(); - pub static ref SPEEDOMETRE: IntGaugeVec = - register_int_gauge_vec!("speedometre", "Speedometre", &["type"]).unwrap(); -} - -fn main() { - // increment the high five counter - HIGH_FIVE_COUNTER.inc(); - // set the speed to 100 - SPEEDOMETRE.with_label_values(&["speed"]).set(100); - // simple print statement to show the metrics in prometheus format - println!("{}", metrics_to_string().unwrap()) -} diff --git a/dsh_sdk/examples/expose_metrics.rs b/dsh_sdk/examples/expose_metrics.rs index 59906dc..4853ac2 100644 --- a/dsh_sdk/examples/expose_metrics.rs +++ b/dsh_sdk/examples/expose_metrics.rs @@ -1,19 +1,43 @@ -use dsh_sdk::utils::metrics::*; +use dsh_sdk::utils::metrics::start_http_server; +use lazy_static::lazy_static; +use prometheus::{register_int_counter, IntCounter}; +use std::sync::OnceLock; +// Register counter with lazy_static +// (not reccomended to use lazy_static) lazy_static! 
{ pub static ref HIGH_FIVE_COUNTER: IntCounter = register_int_counter!("highfives", "Number of high fives recieved").unwrap(); } +// Register counter with Rust std library +// Recomended way to register metrics +pub fn low_five_counter() -> &'static IntCounter { + static CONSUMED_MESSAGES: OnceLock = OnceLock::new(); + CONSUMED_MESSAGES.get_or_init(|| { + register_int_counter!("consumed_messages", "Number of messages consumed").unwrap() + }) +} + +/// Gather and encode metrics to a string (UTF8) +pub fn encode_metrics() -> String { + let encoder = prometheus::TextEncoder::new(); + encoder + .encode_to_string(&prometheus::gather()) + .unwrap_or_default() +} + #[tokio::main] async fn main() { println!("Starting metrics server on http://localhost:8080/metrics"); - start_http_server(8080); + start_http_server(8080, encode_metrics); - // increment the high five counter every second for 20 times + // increment the counters every second for 20 times for i in 0..20 { println!("High five number: {}", i + 1); HIGH_FIVE_COUNTER.inc(); + println!("Low five number: {}", i + 1); + low_five_counter().inc(); tokio::time::sleep(std::time::Duration::from_secs(1)).await; } } diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs index ce00f73..ed3d09e 100644 --- a/dsh_sdk/src/dsh.rs +++ b/dsh_sdk/src/dsh.rs @@ -1,4 +1,4 @@ -//! High-level API to interact with DSH. +//! High-level API to interact with DSH when your container is running on DSH. //! //! From [Dsh] there are level functions to get the correct config to connect to Kafka and schema store. //! For more low level functions, see diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs index a5baac5..03b6c52 100644 --- a/dsh_sdk/src/error.rs +++ b/dsh_sdk/src/error.rs @@ -7,8 +7,6 @@ pub enum DshError { DatastreamError(#[from] crate::datastream::DatastreamError), #[error("Utils error: {0}")] UtilsError(#[from] crate::utils::UtilsError), - #[error("Reqwest: {0}")] - ReqwestError(#[from] reqwest::Error), } pub(crate) fn report(mut err: &dyn std::error::Error) -> String { diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index 653d525..6ec099f 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -64,7 +64,7 @@ pub mod graceful_shutdown; )] pub mod metrics; -#[cfg(feature = "protocol-token-fetcher")] +#[cfg(all(feature = "protocol-token-fetcher", feature = "bootstrap"))] #[deprecated( since = "0.5.0", note = "`dsh_sdk::mqtt_token_fetcher` is moved to `dsh_sdk::protocol_adapters::token_fetcher`" diff --git a/dsh_sdk/src/metrics.rs b/dsh_sdk/src/metrics.rs index 8fd77c6..1c26a23 100644 --- a/dsh_sdk/src/metrics.rs +++ b/dsh_sdk/src/metrics.rs @@ -1,40 +1,30 @@ -//! This module wraps the prometheus metrics library and provides a http server to expose the metrics. +//! Provides a lightweight HTTP server to expose (prometheus) metrics. //! -//! It is technically a re-exports the prometheus metrics library with some additional functions. +//! ## Expose metrics to DSH / HTTP Server //! -//! # Create custom metrics +//! This module provides a http server to expose the metrics to DSH. A port number and a function that encode the metrics to [String] needs to be defined. //! -//! To define custom metrics, the prometheus macros can be used. They are re-exported in this module. +//! Most metrics libraries provide a way to encode the metrics to a string. For example, +//! 
- [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string. +//! - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string. +//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. //! -//! As they are a pub static reference, you can use them anywhere in your code. -//! -//! See [prometheus](https://docs.rs/prometheus/0.13.3/prometheus/index.html#macros) for more information. -//! -//! ### Example +//! ### Example: //! ``` -//! use dsh_sdk::metrics::*; +//! use dsh_sdk::utils::metrics::start_http_server; //! -//! lazy_static! { -//! pub static ref HIGH_FIVE_COUNTER: IntCounter = -//! register_int_counter!("highfives", "Number of high fives recieved").unwrap(); +//! fn encode_metrics() -> String { +//! // Provide here your logic to gather and encode the metrics to a string +//! // Check your chosen metrics library for the correct implementation +//! "my_metrics 1".to_string() // Dummy example //! } //! -//! HIGH_FIVE_COUNTER.inc(); -//! ``` -//! -//! # Expose metrics to DSH / HTTP Server -//! -//! This module provides a http server to expose the metrics to DSH. A port number needs to be defined. -//! -//! ### Example: -//! ``` -//! use dsh_sdk::metrics::start_http_server; -//!#[tokio::main] -//!async fn main() { -//! start_http_server(9090); +//! #[tokio::main] +//! async fn main() { +//! start_http_server(9090, encode_metrics); //!} //! ``` -//! After starting the http server, the metrics can be found at http://localhost:8080/metrics. +//! After starting the http server, the metrics can be found at http://localhost:9090/metrics. //! To expose the metrics to DSH, the port number needs to be defined in the DSH service configuration. //! //! ```json @@ -58,116 +48,130 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{header, Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; -pub use lazy_static::lazy_static; use log::{error, warn}; -pub use prometheus::*; +use thiserror::Error; use tokio::net::TcpListener; use tokio::task::JoinHandle; -type DshResult = std::result::Result>; type BoxBody = http_body_util::combinators::BoxBody; static NOTFOUND: &[u8] = b"404: Not Found"; -/// Start a http server to expose prometheus metrics. -/// -/// The exposed endpoint is /metrics and port number needs to be defined. The server will run on a separate thread -/// and this function will return a JoinHandle of the thread. It is optional to handle the thread status. If left unhandled, -/// the server will run until the main thread is stopped. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum MetricsError { + #[error("IO Error: {0}")] + IoError(#[from] std::io::Error), + #[error("Hyper error: {0}")] + HyperError(#[from] hyper::http::Error), +} + +/// A lihghtweight HTTP server to expose prometheus metrics. /// -/// # Note! -/// Don't forget to expose the port in your dockerfile and add the port number to the DSH service configuration. -///```Dockerfile -/// EXPOSE 9090 -/// ``` +/// The exposed endpoint is /metrics and port number needs to be defined together with your gather and encode function to string. 
+/// The server will run on a separate thread and this function will return a JoinHandle of the thread. +/// It is optional to handle the thread status. If left unhandled, the server will run until the main thread is stopped. /// /// # Example /// This starts a http server on port 9090 on a separate thread. The server will run until the main thread is stopped. -/// ```rust -/// use dsh_sdk::metrics::start_http_server; -/// -/// #[tokio::main] -/// async fn main() { -/// start_http_server(9090); +/// ``` +/// use dsh_sdk::utils::metrics::start_http_server; +/// +/// fn encode_metrics() -> String { +/// // Provide here your logic to gather and encode the metrics to a string +/// // Check your chosen metrics library for the correct implementation +/// "my_metrics 1".to_string() // Dummy example +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// start_http_server(9090, encode_metrics); /// } -/// ``` +/// ``` /// -/// # Optional: Check http server thread status +/// ## Optional: Check http server thread status /// Await the JoinHandle in a a tokio select besides your application logic to check if the server is still running. /// ```rust -/// use dsh_sdk::metrics::start_http_server; +/// # use dsh_sdk::utils::metrics::start_http_server; /// # use tokio::time::sleep; /// # use std::time::Duration; -/// -/// #[tokio::main] -/// async fn main() { -/// let server = start_http_server(9090); -/// tokio::select! { -/// // Replace sleep with your application logic -/// _ = sleep(Duration::from_secs(1)) => {println!("Application is stoped!")}, -/// // Check if the server is still running -/// tokio_result = server => { -/// match tokio_result { -/// Ok(server_result) => if let Err(e) = server_result { -/// eprintln!("Metrics server operation failed: {}", e); -/// }, -/// Err(e) => println!("Server thread stopped unexpectedly: {}", e), -/// } -/// } -/// } +/// # fn encode_metrics() -> String { +/// # "my_metrics 1".to_string() // Dummy example +/// # } +/// # #[tokio::main] +/// # async fn main() { +/// let server = start_http_server(9090, encode_metrics); +/// tokio::select! { +/// // Replace sleep with your application logic +/// _ = sleep(Duration::from_secs(1)) => {println!("Application is stoped!")}, +/// // Check if the server is still running +/// tokio_result = server => { +/// match tokio_result { +/// Ok(server_result) => if let Err(e) = server_result { +/// eprintln!("Metrics server operation failed: {}", e); +/// }, +/// Err(e) => println!("Server thread stopped unexpectedly: {}", e), +/// } +/// } /// } +/// # } /// ``` -pub fn start_http_server(port: u16) -> JoinHandle> { +pub fn start_http_server( + port: u16, + metrics_encode_fn: fn() -> String, +) -> JoinHandle> { + let server = MetricsServer { + port, + metrics_encode_fn, + }; tokio::spawn(async move { - let result = run_server(port).await; + let result = server.run_server().await; warn!("HTTP server stopped: {:?}", result); result }) } -/// Encode metrics to a string (UTF8) -pub fn metrics_to_string() -> DshResult { - let encoder = prometheus::TextEncoder::new(); - - let mut buffer = Vec::new(); - encoder.encode(&prometheus::gather(), &mut buffer)?; - Ok(String::from_utf8(buffer)?) 
+struct MetricsServer { + port: u16, + metrics_encode_fn: fn() -> String, } -async fn run_server(port: u16) -> DshResult<()> { - let addr = SocketAddr::from(([0, 0, 0, 0], port)); - let listener = TcpListener::bind(addr).await?; +impl MetricsServer { + async fn run_server(&self) -> Result<(), MetricsError> { + let addr = SocketAddr::from(([0, 0, 0, 0], self.port)); + let listener = TcpListener::bind(addr).await?; - loop { - let (stream, _) = listener.accept().await?; - tokio::spawn(handle_connection(stream)); + loop { + let (stream, _) = listener.accept().await?; + self.handle_connection(stream).await; + } } -} - -async fn handle_connection(stream: tokio::net::TcpStream) { - let io = TokioIo::new(stream); - let service = service_fn(routes); - if let Err(err) = http1::Builder::new().serve_connection(io, service).await { - error!("Failed to serve connection: {:?}", err); + async fn handle_connection(&self, stream: tokio::net::TcpStream) { + let io = TokioIo::new(stream); + let service = service_fn(|req| self.routes(req)); + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + error!("Failed to serve metrics connection: {:?}", err); + } } -} -async fn routes(req: Request) -> DshResult> { - match (req.method(), req.uri().path()) { - (&Method::GET, "/metrics") => get_metrics(), - (_, _) => not_found(), + async fn routes(&self, req: Request) -> Result, MetricsError> { + match (req.method(), req.uri().path()) { + (&Method::GET, "/metrics") => self.get_metrics(), + (_, _) => not_found(), + } } -} -fn get_metrics() -> DshResult> { - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, prometheus::TEXT_FORMAT) - .body(full(metrics_to_string().unwrap_or_default()))?) + fn get_metrics(&self) -> Result, MetricsError> { + let body = (self.metrics_encode_fn)(); + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "text/plain") + .body(full(body))?) + } } -fn not_found() -> DshResult> { +fn not_found() -> Result, MetricsError> { Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(full(NOTFOUND))?) @@ -178,154 +182,3 @@ fn full>(chunk: T) -> BoxBody { .map_err(|never| match never {}) .boxed() } - -#[cfg(test)] -mod tests { - use super::*; - use http_body_util::Empty; - use hyper::body::Body; - use hyper::client::conn; - use hyper::client::conn::http1::{Connection, SendRequest}; - use hyper::http::HeaderValue; - use hyper::Uri; - use serial_test::serial; - use tokio::net::TcpStream; - - lazy_static! 
{ - pub static ref HIGH_FIVE_COUNTER_OLD: IntCounter = - register_int_counter!("highfives_old", "Number of high fives recieved").unwrap(); - } - - async fn create_client( - url: &Uri, - ) -> ( - SendRequest>, - Connection, Empty>, - ) { - let host = url.host().expect("uri has no host"); - let port = url.port_u16().unwrap_or(80); - let addr = format!("{}:{}", host, port); - - let stream = TcpStream::connect(addr).await.unwrap(); - let io = TokioIo::new(stream); - - conn::http1::handshake(io).await.unwrap() - } - - fn to_get_req(url: &Uri) -> Request> { - Request::builder() - .uri(url) - .method(Method::GET) - .header(header::HOST, url.authority().unwrap().clone().as_str()) - .body(Empty::::new()) - .unwrap() - } - - #[tokio::test] - #[serial(port_usage)] - async fn test_http_metric_response() { - // Increment the counter - HIGH_FIVE_COUNTER_OLD.inc(); - - // Call the function - let res = get_metrics(); - - // Check if the function returns a result - assert!(res.is_ok()); - - // Check if the result is not an empty string - let response = res.unwrap(); - let status_code = response.status(); - - assert_eq!(status_code, StatusCode::OK); - assert!(response.body().size_hint().exact().unwrap() > 0); - assert_eq!( - response.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static(prometheus::TEXT_FORMAT) - ); - } - - #[tokio::test] - #[serial(port_usage)] - async fn test_start_http_server() { - // Start HTTP server - let server = start_http_server(8080); - - // increment the counter - HIGH_FIVE_COUNTER_OLD.inc(); - - // Give the server a moment to start - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - let url: Uri = "http://localhost:8080/metrics".parse().unwrap(); - let (mut request_sender, connection) = create_client(&url).await; - tokio::task::spawn(async move { - if let Err(err) = connection.await { - error!("Connection failed: {:?}", err); - } - }); - - // Send a request to the server - let request = to_get_req(&url); - let response = request_sender.send_request(request).await.unwrap(); - - // Check if the server returns a 200 status - assert_eq!(response.status(), StatusCode::OK); - assert_eq!( - response.headers().get(header::CONTENT_TYPE).unwrap(), - HeaderValue::from_static(prometheus::TEXT_FORMAT) - ); - - // Check if the response body is not empty - let buf = response.collect().await.unwrap().to_bytes(); - let res = String::from_utf8(buf.to_vec()).unwrap(); - - println!("{}", res); - assert!(!res.is_empty()); - - // Terminate the server - server.abort(); - } - - #[tokio::test] - #[serial(port_usage)] - async fn test_unknown_path() { - // Start HTTP server - let server = start_http_server(9900); - - // Give the server a moment to start - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - let url: Uri = "http://localhost:9900".parse().unwrap(); - let (mut request_sender, connection) = create_client(&url).await; - tokio::task::spawn(async move { - if let Err(err) = connection.await { - error!("Connection failed: {:?}", err); - } - }); - - // Send a request to the server - let request = to_get_req(&url); - - let response = request_sender.send_request(request).await.unwrap(); - - // Check if the server returns a 404 status - assert_eq!(response.status(), StatusCode::NOT_FOUND); - - // Check if the response body is not empty - let buf = response.collect().await.unwrap().to_bytes(); - let res = String::from_utf8(buf.to_vec()).unwrap(); - - assert_eq!(res, String::from_utf8_lossy(NOTFOUND)); - - // Terminate the server - server.abort(); - } - - 
#[test] - fn test_metrics_to_string() { - HIGH_FIVE_COUNTER_OLD.inc(); - let res = metrics_to_string().unwrap(); - assert!(res.contains("highfives")); - } -} diff --git a/dsh_sdk/src/utils/metrics.rs b/dsh_sdk/src/utils/metrics.rs index 7ba6c35..9d8c6d1 100644 --- a/dsh_sdk/src/utils/metrics.rs +++ b/dsh_sdk/src/utils/metrics.rs @@ -1,40 +1,30 @@ -//! This module wraps the prometheus metrics library and provides a http server to expose the metrics. +//! Provides a lightweight HTTP server to expose (prometheus) metrics. //! -//! It is technically a re-exports the prometheus metrics library with some additional functions. +//! ## Expose metrics to DSH / HTTP Server //! -//! # Create custom metrics +//! This module provides a http server to expose the metrics to DSH. A port number and a function that encode the metrics to [String] needs to be defined. //! -//! To define custom metrics, the prometheus macros can be used. They are re-exported in this module. +//! Most metrics libraries provide a way to encode the metrics to a string. For example, +//! - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string. +//! - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string. +//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. //! -//! As they are a pub static reference, you can use them anywhere in your code. -//! -//! See [prometheus](https://docs.rs/prometheus/0.13.3/prometheus/index.html#macros) for more information. -//! -//! ### Example +//! ### Example: //! ``` -//! use dsh_sdk::utils::metrics::*; +//! use dsh_sdk::utils::metrics::start_http_server; //! -//! lazy_static! { -//! pub static ref HIGH_FIVE_COUNTER: IntCounter = -//! register_int_counter!("highfives", "Number of high fives recieved").unwrap(); +//! fn encode_metrics() -> String { +//! // Provide here your logic to gather and encode the metrics to a string +//! // Check your chosen metrics library for the correct implementation +//! "my_metrics 1".to_string() // Dummy example //! } //! -//! HIGH_FIVE_COUNTER.inc(); -//! ``` -//! -//! # Expose metrics to DSH / HTTP Server -//! -//! This module provides a http server to expose the metrics to DSH. A port number needs to be defined. -//! -//! ### Example: -//! ``` -//! use dsh_sdk::utils::metrics::start_http_server; //! #[tokio::main] //! async fn main() { -//! start_http_server(9090); +//! start_http_server(9090, encode_metrics); //!} //! ``` -//! After starting the http server, the metrics can be found at http://localhost:8080/metrics. +//! After starting the http server, the metrics can be found at http://localhost:9090/metrics. //! To expose the metrics to DSH, the port number needs to be defined in the DSH service configuration. //! //! 
```json @@ -58,14 +48,11 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{header, Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; -pub use lazy_static::lazy_static; use log::{error, warn}; -pub use prometheus::*; use thiserror::Error; use tokio::net::TcpListener; use tokio::task::JoinHandle; -type MetricsResult = std::result::Result; type BoxBody = http_body_util::combinators::BoxBody; static NOTFOUND: &[u8] = b"404: Not Found"; @@ -77,109 +64,123 @@ pub enum MetricsError { IoError(#[from] std::io::Error), #[error("Hyper error: {0}")] HyperError(#[from] hyper::http::Error), - #[error("Prometheus error: {0}")] - Prometheus(#[from] prometheus::Error), } -/// Start a http server to expose prometheus metrics. +/// A lihghtweight HTTP server to expose prometheus metrics. /// -/// The exposed endpoint is /metrics and port number needs to be defined. The server will run on a separate thread -/// and this function will return a JoinHandle of the thread. It is optional to handle the thread status. If left unhandled, -/// the server will run until the main thread is stopped. +/// The exposed endpoint is /metrics and port number needs to be defined together with your gather and encode function to string. +/// The server will run on a separate thread and this function will return a JoinHandle of the thread. +/// It is optional to handle the thread status. If left unhandled, the server will run until the main thread is stopped. /// -/// # Note! -/// Don't forget to expose the port in your dockerfile and add the port number to the DSH service configuration. -///```Dockerfile -/// EXPOSE 9090 -/// ``` +/// ## Expose metrics to DSH / HTTP Server /// -/// # Example -/// This starts a http server on port 9090 on a separate thread. The server will run until the main thread is stopped. -/// ```rust -/// use dsh_sdk::utils::metrics::start_http_server; +/// This module provides a http server to expose the metrics to DSH. A port number and a function that encode the metrics to [String] needs to be defined. /// -/// #[tokio::main] -/// async fn main() { -/// start_http_server(9090); +/// Most metrics libraries provide a way to encode the metrics to a string. For example, +/// - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string. +/// - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string. +/// See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. +/// +/// ## Example +/// This starts a http server on port 9090 on a separate thread. The server will run until the main thread is stopped. +/// ``` +/// use dsh_sdk::utils::metrics::start_http_server; +/// +/// fn encode_metrics() -> String { +/// // Provide here your logic to gather and encode the metrics to a string +/// // Check your chosen metrics library for the correct implementation +/// "my_metrics 1".to_string() // Dummy example +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// start_http_server(9090, encode_metrics); /// } -/// ``` +/// ``` /// /// # Optional: Check http server thread status /// Await the JoinHandle in a a tokio select besides your application logic to check if the server is still running. 
 /// ```rust
-/// use dsh_sdk::metrics::start_http_server;
+/// # use dsh_sdk::utils::metrics::start_http_server;
 /// # use tokio::time::sleep;
 /// # use std::time::Duration;
-///
-/// #[tokio::main]
-/// async fn main() {
-///     let server = start_http_server(9090);
-///     tokio::select! {
-///         // Replace sleep with your application logic
-///         _ = sleep(Duration::from_secs(1)) => {println!("Application is stoped!")},
-///         // Check if the server is still running
-///         tokio_result = server => {
-///             match tokio_result {
-///                 Ok(server_result) => if let Err(e) = server_result {
-///                     eprintln!("Metrics server operation failed: {}", e);
-///                 },
-///                 Err(e) => println!("Server thread stopped unexpectedly: {}", e),
-///             }
-///         }
-///     }
+/// # fn encode_metrics() -> String {
+/// #     "my_metrics 1".to_string() // Dummy example
+/// # }
+/// # #[tokio::main]
+/// # async fn main() {
+/// let server = start_http_server(9090, encode_metrics);
+/// tokio::select! {
+///     // Replace sleep with your application logic
+///     _ = sleep(Duration::from_secs(1)) => {println!("Application is stopped!")},
+///     // Check if the server is still running
+///     tokio_result = server => {
+///         match tokio_result {
+///             Ok(server_result) => if let Err(e) = server_result {
+///                 eprintln!("Metrics server operation failed: {}", e);
+///             },
+///             Err(e) => println!("Server thread stopped unexpectedly: {}", e),
+///         }
+///     }
 /// }
+/// # }
 /// ```
-pub fn start_http_server(port: u16) -> JoinHandle<MetricsResult<()>> {
+pub fn start_http_server(
+    port: u16,
+    metrics_encode_fn: fn() -> String,
+) -> JoinHandle<Result<(), MetricsError>> {
+    let server = MetricsServer {
+        port,
+        metrics_encode_fn,
+    };
     tokio::spawn(async move {
-        let result = run_server(port).await;
+        let result = server.run_server().await;
         warn!("HTTP server stopped: {:?}", result);
         result
     })
 }

-/// Encode metrics to a string (UTF8)
-pub fn metrics_to_string() -> MetricsResult<String> {
-    let encoder = prometheus::TextEncoder::new();
-
-    let mut buffer = Vec::new();
-    encoder.encode(&prometheus::gather(), &mut buffer)?;
-    Ok(String::from_utf8_lossy(&buffer).into_owned())
+struct MetricsServer {
+    port: u16,
+    metrics_encode_fn: fn() -> String,
 }

-async fn run_server(port: u16) -> MetricsResult<()> {
-    let addr = SocketAddr::from(([0, 0, 0, 0], port));
-    let listener = TcpListener::bind(addr).await?;
+impl MetricsServer {
+    async fn run_server(&self) -> Result<(), MetricsError> {
+        let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
+        let listener = TcpListener::bind(addr).await?;

-    loop {
-        let (stream, _) = listener.accept().await?;
-        tokio::spawn(handle_connection(stream));
+        loop {
+            let (stream, _) = listener.accept().await?;
+            self.handle_connection(stream).await;
+        }
     }
-}

-async fn handle_connection(stream: tokio::net::TcpStream) {
-    let io = TokioIo::new(stream);
-    let service = service_fn(routes);
-
-    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
-        error!("Failed to serve connection: {:?}", err);
+    async fn handle_connection(&self, stream: tokio::net::TcpStream) {
+        let io = TokioIo::new(stream);
+        let service = service_fn(|req| self.routes(req));
+        if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
+            error!("Failed to serve metrics connection: {:?}", err);
+        }
     }
-}

-async fn routes(req: Request<hyper::body::Incoming>) -> MetricsResult<Response<BoxBody>> {
-    match (req.method(), req.uri().path()) {
-        (&Method::GET, "/metrics") => get_metrics(),
-        (_, _) => not_found(),
+    async fn routes(&self, req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>, MetricsError> {
+        match (req.method(), req.uri().path()) {
+            (&Method::GET, "/metrics") => self.get_metrics(),
+            (_, _) => not_found(),
+        }
     }
-}

-fn get_metrics() -> MetricsResult<Response<BoxBody>> {
-    Ok(Response::builder()
-        .status(StatusCode::OK)
-        .header(header::CONTENT_TYPE, prometheus::TEXT_FORMAT)
-        .body(full(metrics_to_string().unwrap_or_default()))?)
+    fn get_metrics(&self) -> Result<Response<BoxBody>, MetricsError> {
+        let body = (self.metrics_encode_fn)();
+        Ok(Response::builder()
+            .status(StatusCode::OK)
+            .header(header::CONTENT_TYPE, "text/plain")
+            .body(full(body))?)
+    }
 }

-fn not_found() -> MetricsResult<Response<BoxBody>> {
+fn not_found() -> Result<Response<BoxBody>, MetricsError> {
     Ok(Response::builder()
         .status(StatusCode::NOT_FOUND)
         .body(full(NOTFOUND))?)
@@ -198,13 +199,22 @@ mod tests {
     use hyper::body::Body;
     use hyper::client::conn;
     use hyper::client::conn::http1::{Connection, SendRequest};
-    use hyper::http::HeaderValue;
     use hyper::Uri;
+    use lazy_static::lazy_static;
+    use prometheus::{register_int_counter, IntCounter};
     use serial_test::serial;
     use tokio::net::TcpStream;

     const PORT: u16 = 9090;

+    /// Gather and encode metrics to a string (UTF8)
+    pub fn metrics_to_string() -> String {
+        let encoder = prometheus::TextEncoder::new();
+        encoder
+            .encode_to_string(&prometheus::gather())
+            .unwrap_or_default()
+    }
+
     lazy_static! {
         pub static ref HIGH_FIVE_COUNTER: IntCounter =
             register_int_counter!("highfives", "Number of high fives recieved").unwrap();
@@ -240,8 +250,12 @@ mod tests {
         // Increment the counter
         HIGH_FIVE_COUNTER.inc();

+        let server = MetricsServer {
+            port: PORT,
+            metrics_encode_fn: metrics_to_string,
+        };
         // Call the function
-        let res = get_metrics();
+        let res = server.get_metrics();

         // Check if the function returns a result
         assert!(res.is_ok());
@@ -254,7 +268,7 @@ mod tests {
         assert!(response.body().size_hint().exact().unwrap() > 0);
         assert_eq!(
             response.headers().get(header::CONTENT_TYPE).unwrap(),
-            HeaderValue::from_static(prometheus::TEXT_FORMAT)
+            "text/plain"
         );
     }

@@ -262,7 +276,7 @@
     #[serial(port_usage)]
     async fn test_start_http_server() {
         // Start HTTP server
-        let server = start_http_server(PORT);
+        let server = start_http_server(PORT, metrics_to_string);

         // increment the counter
         HIGH_FIVE_COUNTER.inc();
@@ -286,7 +300,7 @@
         assert_eq!(response.status(), StatusCode::OK);
         assert_eq!(
             response.headers().get(header::CONTENT_TYPE).unwrap(),
-            HeaderValue::from_static(prometheus::TEXT_FORMAT)
+            "text/plain"
         );

         // Check if the response body is not empty
@@ -304,7 +318,7 @@
     #[serial(port_usage)]
     async fn test_unknown_path() {
         // Start HTTP server
-        let server = start_http_server(PORT);
+        let server = start_http_server(PORT, metrics_to_string);

         // Give the server a moment to start
         tokio::time::sleep(std::time::Duration::from_secs(1)).await;
@@ -334,11 +348,4 @@
         // Terminate the server
         server.abort();
     }
-
-    #[test]
-    fn test_metrics_to_string() {
-        HIGH_FIVE_COUNTER.inc();
-        let res = metrics_to_string().unwrap();
-        assert!(res.contains("highfives"));
-    }
 }
diff --git a/example_dsh_service/Cargo.toml b/example_dsh_service/Cargo.toml
index 6df8ac9..b4f3b8f 100644
--- a/example_dsh_service/Cargo.toml
+++ b/example_dsh_service/Cargo.toml
@@ -9,4 +9,5 @@ dsh_sdk = { path = "../dsh_sdk", version = "0.5.0-rc.1", features = ["rdkafka-co
 rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
 log = "0.4"
 env_logger = "0.11"
+prometheus = "0.13"
 tokio = { version = "^1.35", features = ["full"] }
\ No newline at end of file
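The refactored server above is deliberately encoder-agnostic: it only needs a `fn() -> String`. As a minimal sketch of such an encoder built on the `prometheus-client` crate mentioned in the module docs (this block is not part of the patch; the registry layout and metric name are illustrative):

```rust
use std::sync::OnceLock;

use prometheus_client::encoding::text::encode;
use prometheus_client::metrics::counter::Counter;
use prometheus_client::registry::Registry;

/// Lazily initialized registry holding one example counter.
fn registry() -> &'static (Registry, Counter) {
    static REGISTRY: OnceLock<(Registry, Counter)> = OnceLock::new();
    REGISTRY.get_or_init(|| {
        let mut registry = Registry::default();
        let counter = Counter::default();
        registry.register("highfives", "Number of high fives received", counter.clone());
        (registry, counter)
    })
}

/// Matches the `fn() -> String` contract expected by `start_http_server`.
fn encode_metrics() -> String {
    let mut buffer = String::new();
    // Render the registry in OpenMetrics text format.
    let _ = encode(&mut buffer, &registry().0);
    buffer
}
```

The resulting `encode_metrics` can be passed to `start_http_server(9090, encode_metrics)` just like the `prometheus::TextEncoder`-based helper used in the tests.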
diff --git a/example_dsh_service/src/custom_metrics.rs b/example_dsh_service/src/custom_metrics.rs
index a3e93d3..310e87d 100644
--- a/example_dsh_service/src/custom_metrics.rs
+++ b/example_dsh_service/src/custom_metrics.rs
@@ -1,6 +1,18 @@
-use dsh_sdk::utils::metrics::*;
+use prometheus::{register_int_counter, IntCounter};
+use std::sync::OnceLock;

-lazy_static! {
-    pub static ref CONSUMED_MESSAGES: IntCounter =
-        register_int_counter!("consumed_messages", "Number of messages consumed").unwrap();
+/// Counter for consumed messages
+pub fn consumed_messages() -> &'static IntCounter {
+    static CONSUMED_MESSAGES: OnceLock<IntCounter> = OnceLock::new();
+    CONSUMED_MESSAGES.get_or_init(|| {
+        register_int_counter!("consumed_messages", "Number of messages consumed").unwrap()
+    })
+}
+
+/// Gather and encode the metrics to string
+pub fn gather_and_encode() -> String {
+    let encoder = prometheus::TextEncoder::new();
+    encoder
+        .encode_to_string(&prometheus::gather())
+        .unwrap_or_default()
 }
diff --git a/example_dsh_service/src/main.rs b/example_dsh_service/src/main.rs
index c89ae38..991adb0 100644
--- a/example_dsh_service/src/main.rs
+++ b/example_dsh_service/src/main.rs
@@ -30,7 +30,7 @@ async fn consume(consumer: StreamConsumer, shutdown: Shutdown) {
         tokio::select! {
             Ok(msg) = consumer.recv() => {
                 // Increment the counter that is defined in src/metrics.rs
-                custom_metrics::CONSUMED_MESSAGES.inc();
+                custom_metrics::consumed_messages().inc();
                 // Deserialize and print the message
                 deserialize_and_print(&msg);
                 // Commit the message
@@ -56,7 +56,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .init();

     // Start http server for exposing prometheus metrics, note that in Dockerfile we expose port 8080 as well
-    dsh_sdk::utils::metrics::start_http_server(8080);
+    dsh_sdk::utils::metrics::start_http_server(8080, custom_metrics::gather_and_encode);

     // Get the configured topics from env variable TOPICS (comma separated)
     let topics_string = std::env::var("TOPICS").expect("TOPICS env variable not set");
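The lazy_static-to-OnceLock migration shown in `custom_metrics.rs` generalizes to other metric types. As an illustrative sketch only (not part of this patch; the histogram name and helper are hypothetical), a second accessor would follow the same shape:

```rust
use prometheus::{register_histogram, Histogram};
use std::sync::OnceLock;

/// Hypothetical histogram for per-message processing time, lazily registered
/// with the same OnceLock pattern as `consumed_messages()` above.
pub fn processing_seconds() -> &'static Histogram {
    static PROCESSING_SECONDS: OnceLock<Histogram> = OnceLock::new();
    PROCESSING_SECONDS.get_or_init(|| {
        register_histogram!("processing_seconds", "Time spent processing a message").unwrap()
    })
}
```

Anything registered this way lands in the default registry, so `prometheus::gather()` inside `gather_and_encode` exposes it without extra wiring.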
From e164f488f79bdfa62f0ad781facfeb9cc61740ca Mon Sep 17 00:00:00 2001
From: Arend-Jan
Date: Mon, 13 Jan 2025 11:21:58 +0000
Subject: [PATCH 10/23] improving README.md (#108)

* improving README.md

* Update dsh_sdk/README.md

---------

Co-authored-by: Frank Hol
Co-authored-by: Frank Hol <96832951+toelo3@users.noreply.github.com>
---
 dsh_sdk/README.md | 187 ++++++++++++++++++++++++++++++++--------------
 1 file changed, 132 insertions(+), 55 deletions(-)

diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md
index afe2feb..c348fa8 100644
--- a/dsh_sdk/README.md
+++ b/dsh_sdk/README.md
@@ -1,3 +1,4 @@
+
 # dsh-sdk-platform-rs

 [![Build Status](https://github.com/kpn-dsh/dsh-sdk-platform-rs/actions/workflows/main.yaml/badge.svg)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/actions/workflows/main.yaml)
@@ -5,103 +6,179 @@
 [![dependency status](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs/status.svg)](https://deps.rs/repo/github/kpn-dsh/dsh-sdk-platform-rs)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

-# NOTE
-As this is a release candidate it may contain bugs and/or incomplete features and incorrect documentation and future updates may contain breaking changes.
+A Rust SDK to interact with the DSH Platform. This library provides convenient building blocks for services that need to connect to DSH Kafka, fetch tokens for various protocols, manage Prometheus metrics, and more.
+
+> **Note**
+> This library (v0.5.x) is a _release candidate_. It may contain incomplete features and/or bugs. Future updates might introduce breaking changes. Please report any issues you find.
+
+---

-Please report any issues you encounter.
+## Table of Contents
+
+1. [Migration Guide 0.4.X -> 0.5.X](#migration-guide-04x---05x)
+2. [Description](#description)
+3. [Usage](#usage)
+4. [Connecting to DSH](#connect-to-dsh)
+5. [Feature Flags](#feature-flags)
+6. [Environment Variables](#environment-variables)
+7. [Examples](#examples)
+8. [Changelog](#changelog)
+9. [Contributing](#contributing)
+10. [License](#license)
+11. [Security](#security)
+
+---

-## Migration guide 0.4.X -> 0.5.X
-See [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for more information on how to migrate from 0.4.X to 0.5.X.
+## Migration Guide 0.4.X -> 0.5.X
+
+If you are migrating from `0.4.X` to `0.5.X`, please see the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for details on breaking changes and how to update your code accordingly.
+
+---

 ## Description
-This library can be used to interact with the DSH Platform. It is intended to be used as a base for services that will be used to interact with DSH. Features include:
-- Connect to DSH Kafka (DSH, Kafka Proxy, VPN, System Space, Local)
-    - Bootstrap (fetch datastreams info and generate signed certificate)
-    - PKI Config Directory (for Kafka Proxy/VPN)
-- Kafka config for DSH (incl. RDKafka)
-- Management API Token Fetcher (to be used with [dsh_rest_api_client](https://crates.io/crates/dsh_rest_api_client))
-- Protocol Token Fetcher (MQTT and HTTP)
-- Common utilities
-    - Prometheus Metrics (web server and re-export of metrics crate)
-    - Graceful shutdown
-    - Dead Letter Queue
+
+The `dsh-sdk-platform-rs` library offers:
+
+- **DSH Kafka Connectivity**
+  - Supports direct DSH, Kafka Proxy, VPN, and local Kafka.
+  - Handles datastream information retrieval, certificate signing (bootstrap), and PKI configuration.
+
+- **Token Fetchers**
+  - **Management API Token Fetcher**: For use with [`dsh_rest_api_client`](https://crates.io/crates/dsh_rest_api_client).
+  - **Protocol Token Fetcher**: Obtain tokens for MQTT and HTTP protocol adapters.
+
+- **DSH Kafka Configuration**
+  - Trait for getting DSH-compatible Kafka clients (DSH, Proxy, VPN, and Local)
+  - **RDKafka** implementation
+
+- **Common Utilities**
+  - Prometheus metrics (built-in HTTP server, plus re-export of the `metrics` crate).
+  - Tokio-based graceful shutdown handling.
+  - Dead Letter Queue (DLQ) functionality.
+
+---

 ## Usage
-To use this SDK with the default features in your project, add the following to your Cargo.toml file:
-
+
+To get started, add the following to your `Cargo.toml`:
+
 ```toml
 [dependencies]
 dsh_sdk = "0.5"
-rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
+rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
 ```
-See [feature flags](#feature-flags) for more information on the available features.

-To use this SDK in your project
+> **Note**
+> By default, this SDK enables several features (see [Feature Flags](#feature-flags)). If you do not need them all, you can disable default features to reduce compile times and dependencies.
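The `DshKafkaConfig` trait applies to both sides of Kafka connectivity; the example that follows configures a consumer. For completeness, here is a producer sketch, under the assumption that the trait exposes a matching `set_dsh_producer_config` method (check the API docs before relying on the exact name):

```rust
use dsh_sdk::DshKafkaConfig;
use rdkafka::producer::FutureProducer;
use rdkafka::ClientConfig;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Apply DSH producer settings (brokers, security, client id), mirroring
    // the consumer configuration shown in the example below.
    let producer: FutureProducer = ClientConfig::new()
        .set_dsh_producer_config()
        .create()?;
    let _ = producer; // ready to produce to a DSH topic
    Ok(())
}
```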
+ +### Example + ```rust -use dsh_sdk::DshKafkaConfig; +use dsh_sdk::DshKafkaConfig; // Trait for applying DSH-specific configurations use rdkafka::consumer::{Consumer, StreamConsumer}; use rdkafka::ClientConfig; #[tokio::main] -async fn main() -> Result<(), Box>{ - // get a rdkafka consumer config for example - let consumer: StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?; +async fn main() -> Result<(), Box> { + // Configure an rdkafka consumer with DSH settings + let consumer: StreamConsumer = ClientConfig::new() + .set_dsh_consumer_config() + .create()?; + + // Your application logic here + Ok(()) } ``` +--- + ## Connect to DSH -The SDK is compatible with running in a container on a DSH tenant, on DSH System Space, on a machine with Kafka Proxy/VPN or on a local machine to a local Kafka. -See [CONNECT_PROXY_VPN_LOCAL](CONNECT_PROXY_VPN_LOCAL.md) for more info. -## Feature flags -See the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for more information on the changes in feature flags since the v0.5.X update. +This SDK accommodates multiple deployment environments: +- Running in a container on a DSH tenant +- Running in DSH System Space +- Running on a machine with Kafka Proxy/VPN +- Running locally with a local Kafka instance + +For more information, see the [CONNECT_PROXY_VPN_LOCAL.md](CONNECT_PROXY_VPN_LOCAL.md) document. -The following features are available in this library and can be enabled/disabled in your Cargo.toml file: +--- + +## Feature Flags -| **feature** | **default** | **Description** | **Example** | -| --- |--- | --- | --- | -| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) | -| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) | -| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](./examples/management_api_token_fetcher.rs) | -| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs)| -| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) | -| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) | +> **Important** +> The feature flags have changed since the `v0.5.X` update. Check the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for details. -See the [api documentation](https://docs.rs/dsh_sdk/latest/dsh_sdk/) for more information on how to use these features. +Below is an overview of the available features: -If you would like to use specific features, you can specify them in your Cargo.toml file. This can save compile time and dependencies. 
-For example, if you only want to use the Management API token fetcher feature, add the following to your Cargo.toml file: +| **feature** | **default** | **Description** | **Example** | +|--------------------------------|-------------|-------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------| +| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | +| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | +| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | +| `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) | +| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) | +| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](./examples/management_api_token_fetcher.rs) | +| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs) | +| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) | +| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) | +### Selecting Features + +To pick only the features you need, disable the default features and enable specific ones. For instance, if you only want the Management API Token Fetcher: ```toml [dependencies] dsh_sdk = { version = "0.5", default-features = false, features = ["management-api-token-fetcher"] } ``` -## Environment variables -The SDK checks environment variables to change configuration for connnecting to DSH. -See [ENV_VARIABLES.md](ENV_VARIABLES.md) which . +--- + +## Environment Variables + +This SDK uses certain environment variables to configure connections to DSH. For a full list of supported variables and their usage, see [ENV_VARIABLES.md](ENV_VARIABLES.md). + +--- ## Examples -See folder [dsh_sdk/examples](./examples/) for simple examples on how to use the SDK. -### Full service example -See folder [example_dsh_service](../example_dsh_service/) for a full service, including how to build the Rust project and post it to Harbor. See [readme](../example_dsh_service/README.md) for more information. +You can find simple usage examples in the [`examples/` directory](./examples/). + +### Full Service Example + +A more complete example is provided in the [`example_dsh_service/`](../example_dsh_service/) directory, showcasing: + +- How to build the Rust project +- How to package and push it to Harbor +- An end-to-end setup of a DSH service + +See the [README](../example_dsh_service/README.md) in that directory for more information. + +--- ## Changelog -See [CHANGELOG.md](CHANGELOG.md) for all changes per version. + +All changes per version are documented in [CHANGELOG.md](CHANGELOG.md). + +--- ## Contributing -See [CONTRIBUTING.md](../CONTRIBUTING.md) for more information on how to contribute to this project. + +Contributions are welcome! 
For details on how to help improve this project, please see [CONTRIBUTING.md](../CONTRIBUTING.md). + +--- ## License -See [LICENSE](../LICENSE) for more information on the license for this project. + +This project is licensed under the [Apache License 2.0](../LICENSE). + +--- ## Security -See [SECURITY.md](../SECURITY.md) for more information on the security policy for this project. + +For information about the security policy of this project, including how to report vulnerabilities, see [SECURITY.md](../SECURITY.md). --- -_Copyright (c) Koninklijke KPN N.V._ + +© Koninklijke KPN N.V. + From a04f85f254d8ed1102560c50d69e5b30640a0a1e Mon Sep 17 00:00:00 2001 From: Arend-Jan Date: Mon, 13 Jan 2025 11:24:20 +0000 Subject: [PATCH 11/23] 110 improve in code documentation (#111) * improving documententation and comments * improving documententation and comments * improving documententation and commentas for dsh module * improving documententation and commentas for certificates module * improving documententation and commentas for datastream module * [FIX] closure problem * improving documententation and commentas. also added test to the errors * remove tests from error * improving documententation and comments for management api * [FIX] test error problem * [FIX] typo * remove error example in documentation * remove error example in documentation * Update dsh_sdk/src/datastream/mod.rs * Apply suggestions from code review --------- Co-authored-by: Frank Hol <96832951+toelo3@users.noreply.github.com> --- dsh_sdk/src/certificates/mod.rs | 170 +++++++++++++++---------- dsh_sdk/src/datastream/mod.rs | 204 +++++++++++++++++++----------- dsh_sdk/src/dsh.rs | 214 ++++++++++++++------------------ dsh_sdk/src/error.rs | 31 ++++- dsh_sdk/src/lib.rs | 98 +++++++++++++-- 5 files changed, 443 insertions(+), 274 deletions(-) diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs index 9ca97e3..9977993 100644 --- a/dsh_sdk/src/certificates/mod.rs +++ b/dsh_sdk/src/certificates/mod.rs @@ -1,15 +1,22 @@ -//! Handle DSH Certificates and bootstrap process +//! Handles DSH certificates and the bootstrap process. //! -//! The certificate struct holds the DSH CA certificate, the DSH Kafka certificate and -//! the private key. It also has methods to create a reqwest client with the DSH Kafka -//! certificate included and to retrieve the certificates and keys as PEM strings. Also -//! it is possible to create the ca.crt, client.pem, and client.key files in a desired -//! directory. +//! The [`Cert`] struct holds the DSH CA certificate, the DSH Kafka certificate, and +//! the corresponding private key. It provides methods to: +//! - Create Reqwest clients (async/blocking) that embed the Kafka certificate for secure connections +//! - Retrieve certificates and keys as PEM strings +//! - Generate certificate files (`ca.crt`, `client.pem`, and `client.key`) in a target directory //! -//! ## Create files +//! # Usage Flow +//! Typically, you either: +//! 1. **Bootstrap**: Generate and sign certificates using [`Cert::from_bootstrap`] or [`Cert::from_env`], +//! which fetches or creates certificates at runtime. +//! 2. **Load**: Read existing certificates from a directory using [`Cert::from_pki_config_dir`]. //! -//! To create the ca.crt, client.pem, and client.key files in a desired directory, use the -//! `to_files` method. +//! After obtaining a [`Cert`] instance, you can create HTTP clients or retrieve the raw certificate/key data. +//! +//! ## Creating Files +//! 
To create the `ca.crt`, `client.pem`, and `client.key` files in a desired directory, use the +//! [`Cert::to_files`] method. //! ```no_run //! use dsh_sdk::certificates::Cert; //! use std::path::PathBuf; @@ -38,7 +45,12 @@ mod bootstrap; mod error; mod pki_config_dir; -/// Hold all relevant certificates and keys to connect to DSH Kafka Cluster and Schema Store. +/// Holds all relevant certificates and private keys to connect to the DSH Kafka cluster and the Schema Store. +/// +/// This struct includes: +/// - `dsh_ca_certificate_pem`: The CA certificate (equivalent to `ca.crt`) +/// - `dsh_client_certificate_pem`: The client (Kafka) certificate (equivalent to `client.pem`) +/// - `key_pair`: The private key used for Kafka connections (equivalent to `client.key`) #[derive(Debug, Clone)] pub struct Cert { dsh_ca_certificate_pem: String, @@ -47,7 +59,7 @@ pub struct Cert { } impl Cert { - /// Create new [Cert] struct + /// Creates a new [`Cert`] struct from the given certificate strings and key pair. fn new( dsh_ca_certificate_pem: String, dsh_client_certificate_pem: String, @@ -60,34 +72,41 @@ impl Cert { } } - /// Bootstrap to DSH and sign the certificates. + /// Bootstraps to DSH and signs the certificates. + /// + /// This fetches the DSH CA certificate, creates/signs a Kafka certificate, and generates a private key. /// - /// This method will get DSH CA certificate, sign the Kafka certificate and generate a private key. + /// # Recommended Approach + /// Use [`Cert::from_env`] if you rely on environment variables injected by DSH (e.g., `KAFKA_CONFIG_HOST`, + /// `MESOS_TASK_ID`). This allows an easier switch between Kafka Proxy, VPN connection, etc. /// - /// ## Recommended - /// Use [Cert::from_env] to get the certificates and keys. As this method will check based on the injected environment variables by DSH. - /// This method also allows you to easily switch between Kafka Proxy or VPN connection, based on `PKI_CONFIG_DIR` environment variable. + /// # Arguments + /// - `config_host`: The DSH config host where the CSR is sent. + /// - `tenant_name`: The tenant name. + /// - `task_id`: The running container’s task ID. /// - /// ## Arguments - /// * `config_host` - The DSH config host where the CSR can be send to. - /// * `tenant_name` - The tenant name. - /// * `task_id` - The task id of running container. + /// # Errors + /// Returns a [`CertificatesError`] if the bootstrap process fails (e.g., network issues or invalid inputs). pub fn from_bootstrap( config_host: &str, tenant_name: &str, task_id: &str, ) -> Result { - bootstrap::bootstrap(&config_host, tenant_name, task_id) + bootstrap::bootstrap(config_host, tenant_name, task_id) } - /// Bootstrap to DSH and sign the certificates based on the injected environment variables by DSH. + /// Bootstraps to DSH and signs certificates based on environment variables injected by DSH. /// - /// This method will first check if `PKI_CONFIG_DIR` environment variable is set. If set, it will use the certificates from the directory. - /// This is usefull when you want to use Kafka Proxy, VPN or when a different process that already created the certificates. More info at [CONNECT_PROXY_VPN_LOCAL.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/CONNECT_PROXY_VPN_LOCAL.md). + /// This method checks if `PKI_CONFIG_DIR` is set: + /// - If it is, certificates are loaded from that directory (e.g., when using Kafka Proxy or VPN). 
+ /// - Otherwise, it uses `KAFKA_CONFIG_HOST`, `MESOS_TASK_ID`, and `MARATHON_APP_ID` to bootstrap + /// and sign certificates. /// - /// Else it will check `KAFKA_CONFIG_HOST`, `MESOS_TASK_ID` and `MARATHON_APP_ID` environment variables to bootstrap to DSH and sign the certificates. - /// These environment variables are injected by DSH. + /// # Errors + /// Returns a [`CertificatesError::MisisngInjectedVariables`] if required environment variables are absent, + /// or if the bootstrap operation fails for another reason. pub fn from_env() -> Result { + // Attempt to load from PKI_CONFIG_DIR if let Ok(cert) = Self::from_pki_config_dir::(None) { Ok(cert) } else if let (Ok(config_host), Ok(task_id), Ok(tenant_name)) = ( @@ -101,20 +120,21 @@ impl Cert { } } - /// Get the certificates from a directory. + /// Loads the certificates from a specified directory (or from `PKI_CONFIG_DIR` if set). + /// + /// Useful if certificates are already created and stored locally (e.g., Kafka Proxy, VPN usage). /// - /// This method is usefull if you already have the certificates in a directory. - /// For example if you are using Kafka Proxy, VPN or when a different process already - /// created the certificates. + /// # Arguments + /// - `path`: An optional path to the directory containing the certificates in PEM format. /// - /// ## Arguments - /// * `path` - Path to the directory where the certificates are stored (Optional). + /// If omitted, the `PKI_CONFIG_DIR` environment variable is used. /// - /// path can be overruled by setting the environment variable `PKI_CONFIG_DIR`. + /// # Note + /// - Only PEM format for certificates is supported. + /// - Key files should be in PKCS#8 format and can be in DER or PEM. /// - /// ## Note - /// Only certificates in PEM format are supported. - /// Key files should be in PKCS8 format and can be DER or PEM files. + /// # Errors + /// Returns a [`CertificatesError`] if files are missing, malformed, or cannot be read. pub fn from_pki_config_dir
<P>(path: Option<P>) -> Result<Self, CertificatesError>
     where
         P: AsRef<std::path::Path>,
     {
         pki_config_dir::get_pki_certificates(path)
     }

-    /// Build an async reqwest client with the DSH Kafka certificate included.
-    /// With this client we can retrieve datastreams.json and conenct to Schema Registry.
+    /// Builds an **async** Reqwest client with the DSH Kafka certificate included.
+    ///
+    /// This client can be used to securely fetch `datastreams.json` or connect to the Schema Registry.
+    ///
+    /// # Panics
+    /// Panics if the certificate or private key is invalid. In practice, this should not occur if
+    /// the [`Cert`] was instantiated successfully.
     pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder {
         let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client(
             self.dsh_kafka_certificate_pem(),
@@ -136,8 +161,13 @@ impl Cert {
             .use_rustls_tls()
     }

-    /// Build a reqwest client with the DSH Kafka certificate included.
-    /// With this client we can retrieve datastreams.json and conenct to Schema Registry.
+    /// Builds a **blocking** Reqwest client with the DSH Kafka certificate included.
+    ///
+    /// This client can be used to securely fetch `datastreams.json` or connect to the Schema Registry.
+    ///
+    /// # Panics
+    /// Panics if the certificate or private key is invalid. This should not occur if
+    /// the [`Cert`] was instantiated successfully.
     pub fn reqwest_blocking_client_config(&self) -> ClientBuilder {
         let (pem_identity, reqwest_cert) = Self::prepare_reqwest_client(
             self.dsh_kafka_certificate_pem(),
@@ -150,42 +180,41 @@ impl Cert {
             .use_rustls_tls()
     }

-    /// Get the root certificate as PEM string. Equivalent to ca.crt.
+    /// Returns the root CA certificate as a PEM string (equivalent to `ca.crt`).
     pub fn dsh_ca_certificate_pem(&self) -> &str {
-        self.dsh_ca_certificate_pem.as_str()
+        &self.dsh_ca_certificate_pem
     }

-    /// Get the kafka certificate as PEM string. Equivalent to client.pem.
+    /// Returns the Kafka certificate as a PEM string (equivalent to `client.pem`).
     pub fn dsh_kafka_certificate_pem(&self) -> &str {
-        self.dsh_client_certificate_pem.as_str()
+        &self.dsh_client_certificate_pem
     }

-    /// Get the private key as PKCS8 and return bytes based on asn1 DER format.
+    /// Returns the private key in PKCS#8 ASN.1 DER-encoded bytes.
     pub fn private_key_pkcs8(&self) -> Vec<u8> {
         self.key_pair.serialize_der()
     }

-    /// Get the private key as PEM string. Equivalent to client.key.
+    /// Returns the private key as a PEM string (equivalent to `client.key`).
     pub fn private_key_pem(&self) -> String {
         self.key_pair.serialize_pem()
     }

-    /// Get the public key as PEM string.
+    /// Returns the public key in PEM format.
     pub fn public_key_pem(&self) -> String {
         self.key_pair.public_key_pem()
     }

-    /// Get the public key as DER bytes.
+    /// Returns the public key as DER bytes.
     pub fn public_key_der(&self) -> Vec<u8> {
         self.key_pair.public_key_der()
     }

-    /// Create the ca.crt, client.pem, and client.key files in a desired directory.
+    /// Creates `ca.crt`, `client.pem`, and `client.key` files in the specified directory.
     ///
-    /// This method will create the directory if it does not exist.
+    /// This method also creates the directory if it doesn't exist.
     ///
     /// # Example
-    ///
     /// ```no_run
     /// use dsh_sdk::certificates::Cert;
     /// use std::path::PathBuf;
@@ -197,6 +226,9 @@ impl Cert {
     /// # Ok(())
     /// # }
     /// ```
+    ///
+    /// # Errors
+    /// Returns a [`CertificatesError`] if files cannot be created or written.
pub fn to_files(&self, dir: &PathBuf) -> Result<(), CertificatesError> { std::fs::create_dir_all(dir)?; Self::create_file(dir.join("ca.crt"), self.dsh_ca_certificate_pem())?; @@ -205,12 +237,17 @@ impl Cert { Ok(()) } + /// Internal helper to create a file with the specified contents. fn create_file>(path: PathBuf, contents: C) -> Result<(), CertificatesError> { std::fs::write(&path, contents)?; info!("File created ({})", path.display()); Ok(()) } + /// Creates a [`reqwest::Identity`] from the certificate and private key bytes. + /// + /// # Errors + /// Returns a `reqwest::Error` if the provided bytes are invalid. fn create_identity( cert: &[u8], private_key: &[u8], @@ -221,25 +258,28 @@ impl Cert { reqwest::Identity::from_pem(&ident) } - /// Panics when the certificate or key is not valid. - /// However, these are already validated during the creation of the `Cert` struct and converted if nedded. + /// Internal helper to set up the [`reqwest::Identity`] and root certificate. + /// + /// # Panics + /// Panics if the certificate or key is invalid, but they should already be validated + /// during [`Cert`] construction. fn prepare_reqwest_client( kafka_certificate: &str, private_key: &str, ca_certificate: &str, ) -> (reqwest::Identity, reqwest::tls::Certificate) { let pem_identity = - Cert::create_identity(kafka_certificate.as_bytes(), private_key.as_bytes()).expect( - "Error creating identity. The kafka certificate or key is not valid. Please check the certificate and key.", - ); - let reqwest_cert = reqwest::tls::Certificate::from_pem(ca_certificate.as_bytes()).expect( - "Error parsing CA certificate as PEM to be used in Reqwest. The certificate is not valid. Please check the certificate.", - ); + Cert::create_identity(kafka_certificate.as_bytes(), private_key.as_bytes()) + .expect("Error creating identity. The Kafka certificate or key is invalid."); + + let reqwest_cert = reqwest::tls::Certificate::from_pem(ca_certificate.as_bytes()) + .expect("Error parsing CA certificate as PEM. The certificate is invalid."); + (pem_identity, reqwest_cert) } } -/// Helper function to ensure that the host starts with `https://` (or `http://`) +/// Helper function to ensure that the host starts with `https://` or `http://`. 
pub(crate) fn ensure_https_prefix(host: impl AsRef<str>) -> String {
     if host.as_ref().starts_with("http://") || host.as_ref().starts_with("https://") {
         host.as_ref().to_string()
@@ -273,7 +313,7 @@ mod tests {
         let pkey_pem_bytes = pkey.private_key_to_pem_pkcs8().unwrap();

         let key_pem = cert.private_key_pem();
-        let pkey_pem = String::from_utf8_lossy(pkey_pem_bytes.as_slice());
+        let pkey_pem = String::from_utf8_lossy(&pkey_pem_bytes);
         assert_eq!(key_pem, pkey_pem);
     }

@@ -281,11 +321,11 @@ fn test_public_key_pem() {
         let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
         let der = cert.key_pair.serialize_der();
-        let pkey = PKey::private_key_from_der(der.as_slice()).unwrap();
+        let pkey = PKey::private_key_from_der(&der).unwrap();
         let pkey_pub_pem_bytes = pkey.public_key_to_pem().unwrap();

         let pub_pem = cert.public_key_pem();
-        let pkey_pub_pem = String::from_utf8_lossy(pkey_pub_pem_bytes.as_slice());
+        let pkey_pub_pem = String::from_utf8_lossy(&pkey_pub_pem_bytes);
         assert_eq!(pub_pem, pkey_pub_pem);
     }

@@ -293,7 +333,7 @@ fn test_public_key_der() {
         let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
         let der = cert.key_pair.serialize_der();
-        let pkey = PKey::private_key_from_der(der.as_slice()).unwrap();
+        let pkey = PKey::private_key_from_der(&der).unwrap();
         let pkey_pub_der = pkey.public_key_to_der().unwrap();

         let pub_der = cert.public_key_der();
         assert_eq!(pub_der, pkey_pub_der);
     }

@@ -304,7 +344,7 @@ fn test_private_key_pkcs8() {
         let cert = TEST_CERTIFICATES.get_or_init(set_test_cert);
         let der = cert.key_pair.serialize_der();
-        let pkey = PKey::private_key_from_der(der.as_slice()).unwrap();
+        let pkey = PKey::private_key_from_der(&der).unwrap();
         let pkey = pkey.private_key_to_pkcs8().unwrap();

         let key = cert.private_key_pkcs8();
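Taken together, the `Cert` API in this diff (`from_env`, `to_files`, and the Reqwest client builders) supports a simple load-then-connect flow. A minimal sketch using only the calls documented above (the target directory is illustrative):

```rust
use std::path::PathBuf;

use dsh_sdk::certificates::Cert;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Loads from PKI_CONFIG_DIR when set, otherwise bootstraps via the
    // DSH-injected environment variables.
    let cert = Cert::from_env()?;

    // Write ca.crt, client.pem and client.key for tools that expect files on disk.
    cert.to_files(&PathBuf::from("/tmp/dsh-certs"))?;

    // Or build a TLS-configured blocking HTTP client, e.g. for the Schema Registry.
    let client = cert.reqwest_blocking_client_config().build()?;
    let _ = client;
    Ok(())
}
```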
diff --git a/dsh_sdk/src/datastream/mod.rs b/dsh_sdk/src/datastream/mod.rs
index 2d027d8..c2908e7 100644
--- a/dsh_sdk/src/datastream/mod.rs
+++ b/dsh_sdk/src/datastream/mod.rs
@@ -1,7 +1,17 @@
-//! Datastream properties
+//! Datastream properties for DSH.
 //!
-//! The datastreams.json can be parsed into a Datastream struct using serde_json.
-//! This struct contains all the information from the datastreams properties file.
+//! This module provides the [`Datastream`] struct, which represents the contents of
+//! a `datastreams.json` file. This file contains the Kafka broker URLs, streams, consumer groups,
+//! and additional metadata needed for interacting with DSH.
+//!
+//! # Usage Overview
+//! - **Local Loading**: By default, you can load `datastreams.json` from the local filesystem
+//!   (see [`load_local_datastreams`] or [`Datastream::default`]) if running on an environment outside of DSH.
+//! - **Server Fetching**: You can also fetch an up-to-date `datastreams.json` from a DSH server (only works when running on DSH)
+//!   using [`Datastream::fetch`] (async) or [`Datastream::fetch_blocking`] (blocking).
+//!
+//! The [`Dsh`](crate::dsh::Dsh) struct uses these methods internally to provide either an
+//! immutable, initialized `Datastream` or a freshly fetched copy.
 //!
 //! # Example
 //! ```no_run
 //! # use dsh_sdk::Dsh;
 //! #
 //! # #[tokio::main]
 //! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
 //! let dsh = Dsh::get();
-//! let datastream = dsh.datastream(); // immutable datastream which is fetched at initialization of SDK
-//! // Or
-//! let datastream = dsh.fetch_datastream().await?; // fetch a fresh datastream from dsh server
+//!
+//! // An immutable Datastream, fetched at SDK initialization
+//! let datastream = dsh.datastream();
+//!
+//! // Or fetch a new Datastream from the DSH server at runtime
+//! let datastream = dsh.fetch_datastream().await?;
 //!
 //! let brokers = datastream.get_brokers();
 //! let schema_store_url = datastream.schema_store();
 //! # Ok(())
-//! }
+//! # }
 //! ```
 use std::collections::HashMap;
 use std::env;
@@ -31,23 +44,29 @@ use crate::{
     utils, VAR_KAFKA_BOOTSTRAP_SERVERS, VAR_KAFKA_CONSUMER_GROUP_TYPE, VAR_LOCAL_DATASTREAMS_JSON,
     VAR_SCHEMA_REGISTRY_HOST,
 };
+
 #[doc(inline)]
 pub use error::DatastreamError;

 mod error;

+/// Default filename for local datastream definitions.
 const FILE_NAME: &str = "local_datastreams.json";

-/// Datastream properties file
+/// The main struct representing the datastream properties file (`datastreams.json`).
 ///
-/// Read from datastreams.json
+/// This file generally includes:
+/// - A list of Kafka brokers
+/// - Configurable private/shared consumer groups
+/// - Mapping of topic names to [`Stream`] configurations
+/// - A Schema Store URL
 ///
 /// # Example
 /// ```
 /// use dsh_sdk::Dsh;
 ///
-/// let properties = Dsh::get();
-/// let datastream = properties.datastream();
+/// let dsh = Dsh::get();
+/// let datastream = dsh.datastream(); // Typically loaded at init
 ///
 /// let brokers = datastream.get_brokers();
 /// let streams = datastream.streams();
@@ -64,21 +83,21 @@ pub struct Datastream {
 }

 impl Datastream {
-    /// Get the kafka brokers from the datastreams as a vector of strings
+    /// Returns a list of Kafka brokers (as `&str`) from this datastream configuration.
     pub fn get_brokers(&self) -> Vec<&str> {
         self.brokers.iter().map(|s| s.as_str()).collect()
     }

-    /// Get the kafka brokers as comma seperated string from the datastreams
+    /// Returns the Kafka brokers as a comma-separated string.
     pub fn get_brokers_string(&self) -> String {
         self.brokers.join(", ")
     }

-    /// Get the group id from the datastreams based on GroupType
+    /// Returns the consumer group ID based on the specified [`GroupType`].
     ///
-    /// # Error
-    /// If the index is greater then amount of groups in the datastreams
-    /// (index out of bounds)
+    /// # Errors
+    /// Returns [`DatastreamError::IndexGroupIdError`] if the index is out of bounds or
+    /// if no such group ID exists.
     pub fn get_group_id(&self, group_type: GroupType) -> Result<&str, DatastreamError> {
         let group_id = match group_type {
             GroupType::Private(i) => self.private_consumer_groups.get(i),
@@ -90,21 +109,27 @@ impl Datastream {
         }
     }

-    /// Get all available datastreams (scratch topics, internal topics and stream topics)
+    /// Returns a reference to the map of all configured streams.
+    ///
+    /// Each entry typically corresponds to a topic or topic group in Kafka.
     pub fn streams(&self) -> &HashMap<String, Stream> {
         &self.streams
     }

-    /// Get a specific datastream based on the topic name
-    /// If the topic is not found, it will return None
+    /// Looks up a specific stream by its topic name (truncating to the first two segments of the topic).
+    ///
+    /// If the topic is not found in the `streams` map, returns `None`.
     pub fn get_stream(&self, topic: &str) -> Option<&Stream> {
-        // if topic name contains 2 dots, get the first 2 parts of the topic name
-        // this is needed because the topic name in datastreams.json is only the first 2 parts
         let topic_name = topic.split('.').take(2).collect::<Vec<&str>>().join(".");
         self.streams().get(&topic_name)
     }

-    /// Check if a list of topics is present in the read topics of datastreams
+    /// Verifies that a list of topic names exist in either the `read` or `write` patterns
+    /// (depending on the specified [`ReadWriteAccess`]).
+ /// + /// # Errors + /// Returns [`DatastreamError::NotFoundTopicError`] if any provided topic is missing + /// the required read or write patterns. pub fn verify_list_of_topics( &self, topics: &Vec, @@ -144,19 +169,17 @@ impl Datastream { Ok(()) } - /// Get schema store url from datastreams. - /// - /// ## How to connect to schema registry - /// Use the Reqwest client from `Cert` to connect to the schema registry. - /// As this client is already configured with the correct certificates. + /// Returns the schema store (registry) URL from this datastream configuration. /// - /// You can use [schema_registry_converter](https://crates.io/crates/schema_registry_converter) - /// to fetch the schema and decode your payload. + /// # Connecting to the Schema Registry + /// Use a [`reqwest::Client`] built from [`crate::certificates::Cert`] to connect securely. + /// Tools like [`schema_registry_converter`](https://crates.io/crates/schema_registry_converter) + /// can help fetch and decode messages. pub fn schema_store(&self) -> &str { &self.schema_store } - /// Write datastreams.json in a directory + /// Writes the current `Datastream` to a file named `datastreams.json` in the specified directory. /// /// # Example /// ```no_run @@ -165,6 +188,9 @@ impl Datastream { /// let path = std::path::PathBuf::from("/path/to/directory"); /// datastream.to_file(&path).unwrap(); /// ``` + /// + /// # Errors + /// Returns [`DatastreamError::IoError`] if the file cannot be written. pub fn to_file(&self, path: &std::path::Path) -> Result<(), DatastreamError> { let json_string = serde_json::to_string_pretty(self)?; std::fs::write(path.join("datastreams.json"), json_string)?; @@ -172,10 +198,15 @@ impl Datastream { Ok(()) } - /// Fetch datastreams from the dsh server (async) + /// Asynchronously fetches a `Datastream` from the DSH server using a provided [`reqwest::Client`]. + /// + /// The client should typically be built from [`crate::certificates::Cert::reqwest_client_config`] + /// to include the required SSL certificates. /// - /// Make sure you use a Reqwest client from `Cert` to connect to the dsh server. - /// As this client is already configured with the correct certificates. + /// # Errors + /// Returns: + /// - [`DatastreamError::DshCallError`] if the server responds with a non-success status code. + /// - Any networking or deserialization errors wrapped by [`DatastreamError`]. pub async fn fetch( client: &reqwest::Client, host: &str, @@ -194,10 +225,15 @@ impl Datastream { Ok(response.json().await?) } - /// Fetch datastreams from the dsh server (blocking) + /// Fetches a `Datastream` from the DSH server in a **blocking** manner using a [`reqwest::blocking::Client`]. + /// + /// The client should typically be built from [`crate::certificates::Cert::reqwest_blocking_client_config`] + /// to include the required SSL certificates. /// - /// Make sure you use a Reqwest client from `Cert` to connect to the dsh server. - /// As this client is already configured with the correct certificates. + /// # Errors + /// Returns: + /// - [`DatastreamError::DshCallError`] if the server responds with a non-success status code. + /// - Any networking or deserialization errors wrapped by [`DatastreamError`]. pub fn fetch_blocking( client: &reqwest::blocking::Client, host: &str, @@ -216,14 +252,19 @@ impl Datastream { Ok(response.json()?) } + /// Constructs the URL endpoint for fetching datastreams from the DSH server. 
pub(crate) fn datastreams_endpoint(host: &str, tenant: &str, task_id: &str) -> String { format!("{}/kafka/config/{}/{}", host, tenant, task_id) } - /// If local_datastreams.json is found, it will load the datastreams from this file. - /// If it does not parse or the file is not found based on on Environment Variable, it will panic. - /// If the Environment Variable is not set, it will look in the current directory. If it is not found, - /// it will return a Error on the Result. Based on this it will use default Datastreams. + /// Attempts to load a local `datastreams.json` from either the current directory or + /// from the path specified by the [`VAR_LOCAL_DATASTREAMS_JSON`] environment variable. + /// + /// If the file cannot be opened or parsed, the method will panic. + /// If the file isn’t found and no environment variable is set, returns an error wrapped in [`DatastreamError`]. + /// + /// # Panics + /// Panics if it finds a file but fails to parse valid JSON. pub(crate) fn load_local_datastreams() -> Result { let path_buf = if let Ok(path) = utils::get_env_var(VAR_LOCAL_DATASTREAMS_JSON) { let path = std::path::PathBuf::from(path); @@ -235,10 +276,11 @@ impl Datastream { } else { std::env::current_dir().unwrap().join(FILE_NAME) }; + debug!("Reading local datastreams from {}", path_buf.display()); let mut file = File::open(&path_buf).map_err(|e| { debug!( - "Failed opening local_datastreams.json ({}): {}", + "Failed to open local_datastreams.json ({}): {}", path_buf.display(), e ); @@ -246,31 +288,44 @@ impl Datastream { })?; let mut contents = String::new(); file.read_to_string(&mut contents).unwrap(); + let mut datastream: Datastream = serde_json::from_str(&contents) - .unwrap_or_else(|e| panic!("Failed to parse {}, {:?}", path_buf.display(), e)); + .unwrap_or_else(|e| panic!("Failed to parse {}: {:?}", path_buf.display(), e)); + + // Allow env vars to override broker or schema store values if let Ok(brokers) = utils::get_env_var(VAR_KAFKA_BOOTSTRAP_SERVERS) { datastream.brokers = brokers.split(',').map(|s| s.to_string()).collect(); } if let Ok(schema_store) = utils::get_env_var(VAR_SCHEMA_REGISTRY_HOST) { datastream.schema_store = schema_store; } + Ok(datastream) } } impl Default for Datastream { + /// Returns a `Datastream` with: + /// - Default or environment-derived brokers + /// - Placeholder consumer groups + /// - A default schema store URL or the environment variable override + /// + /// Typically useful for local development if no `datastreams.json` is present. fn default() -> Self { let group_id = format!( "{}_default_group", utils::tenant_name().unwrap_or("local".to_string()) ); + let brokers = if let Ok(brokers) = utils::get_env_var(VAR_KAFKA_BOOTSTRAP_SERVERS) { brokers.split(',').map(|s| s.to_string()).collect() } else { vec!["localhost:9092".to_string()] }; + let schema_store = utils::get_env_var(VAR_SCHEMA_REGISTRY_HOST) - .unwrap_or("http://localhost:8081/apis/ccompat/v7".to_string()); + .unwrap_or_else(|_| "http://localhost:8081/apis/ccompat/v7".to_string()); + Datastream { brokers, streams: HashMap::new(), @@ -282,7 +337,9 @@ impl Default for Datastream { } } -/// Struct containing all topic information which also is provided in datastreams.json +/// Represents a single stream's information as provided by `datastreams.json`. +/// +/// Includes topic names, partitioning information, read/write access patterns, and more. 
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct Stream { @@ -298,69 +355,69 @@ pub struct Stream { } impl Stream { - /// Get the Stream's name + /// Returns this stream’s `name` field. pub fn name(&self) -> &str { &self.name } - /// Get the Stream's cluster + /// Returns this stream’s `cluster` field (e.g., “/tt”). pub fn cluster(&self) -> &str { &self.cluster } - /// Get the read pattern as stated in datastreams. + /// Returns the read pattern (regex or exact topic name). /// - /// Use `read_pattern` method to validate if read access is allowed. + /// Use [`Self::read_access`] or [`Self::read_pattern`] to confirm read permissions. pub fn read(&self) -> &str { &self.read } - /// Get the write pattern + /// Returns the write pattern (regex or exact topic name). /// - /// Use `write_pattern` method to validate if write access is allowed. + /// Use [`Self::write_access`] or [`Self::write_pattern`] to confirm write permissions. pub fn write(&self) -> &str { &self.write } - /// Get the Stream's number of partitions + /// Returns the number of partitions for this stream. pub fn partitions(&self) -> i32 { self.partitions } - /// Get the Stream's replication factor + /// Returns the replication factor for this stream. pub fn replication(&self) -> i32 { self.replication } - /// Get the Stream's partitioner + /// Returns the partitioner (e.g., “default-partitioner”). pub fn partitioner(&self) -> &str { &self.partitioner } - /// Get the Stream's partitioning depth + /// Returns the partitioning depth (a more advanced Kafka concept). pub fn partitioning_depth(&self) -> i32 { self.partitioning_depth } - /// Get the Stream's can retain value + /// Indicates whether data retention is possible for this stream. pub fn can_retain(&self) -> bool { self.can_retain } - /// Check read access on topic based on datastream + /// Checks if the stream has a `read` pattern configured. pub fn read_access(&self) -> bool { !self.read.is_empty() } - /// Check write access on topic based on datastream + /// Checks if the stream has a `write` pattern configured. pub fn write_access(&self) -> bool { !self.write.is_empty() } - /// Get the Stream's Read whitelist pattern + /// Returns the read pattern, or errors if the stream has no read access. /// - /// ## Error - /// If the topic does not have read access it returns a `TopicPermissionsError` + /// # Errors + /// Returns [`DatastreamError::TopicPermissionsError`] if the stream has no read pattern set. pub fn read_pattern(&self) -> Result<&str, DatastreamError> { if self.read_access() { Ok(&self.read) @@ -372,10 +429,10 @@ impl Stream { } } - /// Get the Stream's Write pattern + /// Returns the write pattern, or errors if the stream has no write access. /// - /// ## Error - /// If the topic does not have write access it returns a `TopicPermissionsError` + /// # Errors + /// Returns [`DatastreamError::TopicPermissionsError`] if the stream has no write pattern set. pub fn write_pattern(&self) -> Result<&str, DatastreamError> { if self.write_access() { Ok(&self.write) @@ -388,14 +445,15 @@ impl Stream { } } -/// Enum to indicate if we want to check the read or write topics +/// Indicates whether the caller needs read or write access. #[derive(Debug, Clone, PartialEq)] pub enum ReadWriteAccess { Read, Write, } -/// Enum to indicate the group type (private or shared) +/// Specifies whether a consumer group is private or shared, along with an index +/// for selecting from the corresponding array in `Datastream`. 
#[derive(Debug, PartialEq)] pub enum GroupType { Private(usize), @@ -403,15 +461,15 @@ pub enum GroupType { } impl GroupType { - /// Get the group type from the environment variable KAFKA_CONSUMER_GROUP_TYPE - /// If KAFKA_CONSUMER_GROUP_TYPE is not (properly) set, it defaults to shared + /// Determines the group type from the `KAFKA_CONSUMER_GROUP_TYPE` environment variable, + /// defaulting to [`GroupType::Shared(0)`] if unset or invalid. pub fn from_env() -> Self { let group_type = env::var(VAR_KAFKA_CONSUMER_GROUP_TYPE); match group_type { - Ok(s) if s.to_lowercase() == *"private" => GroupType::Private(0), - Ok(s) if s.to_lowercase() == *"shared" => GroupType::Shared(0), + Ok(s) if s.eq_ignore_ascii_case("private") => GroupType::Private(0), + Ok(s) if s.eq_ignore_ascii_case("shared") => GroupType::Shared(0), Ok(_) => { - error!("KAFKA_CONSUMER_GROUP_TYPE is not set with \"shared\" or \"private\", defaulting to shared group type."); + error!("KAFKA_CONSUMER_GROUP_TYPE is not set to \"shared\" or \"private\". Defaulting to shared group type."); GroupType::Shared(0) } Err(_) => { @@ -425,8 +483,8 @@ impl GroupType { impl std::fmt::Display for GroupType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - GroupType::Private(i) => write!(f, "private; index: {}", i), - GroupType::Shared(i) => write!(f, "shared; index: {}", i), + GroupType::Private(i) => write!(f, "private; index: {i}"), + GroupType::Shared(i) => write!(f, "shared; index: {i}"), } } } diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs index ed3d09e..21d7d46 100644 --- a/dsh_sdk/src/dsh.rs +++ b/dsh_sdk/src/dsh.rs @@ -1,13 +1,12 @@ -//! High-level API to interact with DSH when your container is running on DSH. +//! High-level API for interacting with DSH when your container is running on DSH. //! -//! From [Dsh] there are level functions to get the correct config to connect to Kafka and schema store. -//! For more low level functions, see -//! - [datastream] module. -//! - [certificates] module. +//! From [`Dsh`] you can retrieve the correct configuration to connect to Kafka and the schema store. //! -//! ## Environment variables -//! See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for -//! more information configuring the consmer or producer via environment variables. +//! For more low-level functions, see the [`datastream`] and [`certificates`] modules. +//! +//! ## Environment Variables +//! Refer to [`ENV_VARIABLES.md`](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) +//! for more information on configuring the consumer or producer via environment variables. //! //! # Example //! ```no_run @@ -17,12 +16,11 @@ //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let dsh = Dsh::get(); -//! let certificates = dsh.certificates()?; +//! let certificates = dsh.certificates()?; //! let datastreams = dsh.datastream(); //! let kafka_config = dsh.kafka_config(); //! let tenant_name = dsh.tenant_name(); //! let task_id = dsh.task_id(); -//! //! # Ok(()) //! # } //! 
``` @@ -42,15 +40,14 @@ use crate::protocol_adapters::kafka_protocol::config::KafkaConfig; // TODO: Remove at v0.6.0 pub use crate::dsh_old::*; -/// Lazily initialize all related components to connect to the DSH -/// - Contains info from datastreams.json -/// - Metadata of running container/task -/// - Certificates for Kafka and DSH Schema Registry +/// Lazily initializes all related components to connect to DSH: +/// - Information from `datastreams.json` +/// - Metadata of the running container/task +/// - Certificates for Kafka and DSH Schema Registry /// -/// ## Environment variables -/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for -/// more information configuring the consmer or producer via environment variables. - +/// ## Environment Variables +/// Refer to [`ENV_VARIABLES.md`](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) +/// for details on configuring the consumer or producer via environment variables. #[derive(Debug, Clone)] pub struct Dsh { config_host: String, @@ -63,7 +60,7 @@ pub struct Dsh { } impl Dsh { - /// New `Dsh` struct + /// Constructs a new `Dsh` struct. This is internal and should typically be accessed via [`Dsh::get()`]. pub(crate) fn new( config_host: String, task_id: String, @@ -82,19 +79,18 @@ impl Dsh { kafka_config: KafkaConfig::new(Some(datastream)), } } - /// Get the DSH Dsh on a lazy way. If not already initialized, it will initialize the properties - /// and bootstrap to DSH. - /// - /// This struct contains all configuration and certificates needed to connect to Kafka and DSH. + + /// Returns a reference to the global `Dsh` instance, initializing it if necessary. /// - /// - Contains a struct equal to datastreams.json - /// - Metadata of running container/task - /// - Certificates for Kafka and DSH + /// This struct contains configuration and certificates needed to connect to Kafka and DSH: + /// - A struct mirroring `datastreams.json` + /// - Metadata for the running container/task + /// - Certificates for Kafka and DSH /// /// # Panics - /// This method can panic when running on local machine and tries to load incorrect [local_datastream.json](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/local_datastreams.json). - /// When no file is available in root or path on env variable `LOCAL_DATASTREAMS_JSON` is not set, it will - /// return a default datastream struct and NOT panic. + /// Panics if attempting to load an incorrect `local_datastream.json` on a local machine. + /// If no file is available or the `LOCAL_DATASTREAMS_JSON` env variable is unset, it returns a default + /// `datastream` struct and does **not** panic. /// /// # Example /// ``` @@ -112,24 +108,24 @@ impl Dsh { PROPERTIES.get_or_init(|| tokio::task::block_in_place(Self::init)) } - /// Initialize the properties and bootstrap to DSH + /// Initializes the properties and bootstraps to DSH. 
fn init() -> Self { - let tenant_name = utils::tenant_name().unwrap_or("local_tenant".to_string()); - let task_id = utils::get_env_var(VAR_TASK_ID).unwrap_or("local_task_id".to_string()); - let config_host = - utils::get_env_var(VAR_KAFKA_CONFIG_HOST).map(|host| ensure_https_prefix(host)); + let tenant_name = utils::tenant_name().unwrap_or_else(|_| "local_tenant".to_string()); + let task_id = + utils::get_env_var(VAR_TASK_ID).unwrap_or_else(|_| "local_task_id".to_string()); + let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST).map(ensure_https_prefix); + let certificates = if let Ok(cert) = Cert::from_pki_config_dir::(None) { Some(cert) } else if let Ok(config_host) = &config_host { Cert::from_bootstrap(config_host, &tenant_name, &task_id) - .inspect_err(|e| { - warn!("Could not bootstrap to DSH, due to: {}", e); - }) + .inspect_err(|e| warn!("Could not bootstrap to DSH, due to: {}", e)) .ok() } else { None }; - let config_host = config_host.unwrap_or(DEFAULT_CONFIG_HOST.to_string()); + + let config_host = config_host.unwrap_or_else(|_| DEFAULT_CONFIG_HOST.to_string()); let fetched_datastreams = certificates.as_ref().and_then(|cert| { cert.reqwest_blocking_client_config() .build() @@ -138,17 +134,19 @@ impl Dsh { Datastream::fetch_blocking(&client, &config_host, &tenant_name, &task_id).ok() }) }); + let datastream = if let Some(datastream) = fetched_datastreams { datastream } else { - warn!("Could not fetch datastreams.json, using local or default datastreams"); + warn!("Could not fetch datastreams.json; using local or default datastreams"); Datastream::load_local_datastreams().unwrap_or_default() }; + Self::new(config_host, task_id, tenant_name, datastream, certificates) } - /// Get reqwest async client config to connect to DSH Schema Registry. - /// If certificates are present, it will use SSL to connect to Schema Registry. + /// Returns a `reqwest::ClientBuilder` configured to connect to the DSH Schema Registry. + /// If certificates are present, SSL is used. Otherwise, it falls back to a non-SSL connection. /// /// # Example /// ``` @@ -158,7 +156,7 @@ impl Dsh { /// # async fn main() -> Result<(), Box> { /// let dsh = Dsh::get(); /// let client = dsh.reqwest_client_config().build()?; - /// # Ok(()) + /// # Ok(()) /// # } /// ``` pub fn reqwest_client_config(&self) -> reqwest::ClientBuilder { @@ -169,8 +167,8 @@ impl Dsh { } } - /// Get reqwest blocking client config to connect to DSH Schema Registry. - /// If certificates are present, it will use SSL to connect to Schema Registry. + /// Returns a `reqwest::blocking::ClientBuilder` configured to connect to the DSH Schema Registry. + /// If certificates are present, SSL is used. Otherwise, it falls back to a non-SSL connection. /// /// # Example /// ``` @@ -179,7 +177,7 @@ impl Dsh { /// # fn main() -> Result<(), Box> { /// let dsh = Dsh::get(); /// let client = dsh.reqwest_blocking_client_config().build()?; - /// # Ok(()) + /// # Ok(()) /// # } /// ``` pub fn reqwest_blocking_client_config(&self) -> reqwest::blocking::ClientBuilder { @@ -190,7 +188,7 @@ impl Dsh { } } - /// Get the certificates and private key. Returns an error when running on local machine. + /// Retrieves the certificates and private key. Returns an error when running on a local machine. 
/// /// # Example /// ```no_run @@ -198,48 +196,43 @@ impl Dsh { /// # fn main() -> Result<(), Box> { /// let dsh = Dsh::get(); /// let dsh_kafka_certificate = dsh.certificates()?.dsh_kafka_certificate_pem(); - /// # Ok(()) + /// # Ok(()) /// # } /// ``` pub fn certificates(&self) -> Result<&Cert, DshError> { - Ok(if let Some(cert) = &self.certificates { - Ok(cert) - } else { - Err(CertificatesError::NoCertificates) - }?) + match &self.certificates { + Some(cert) => Ok(cert), + None => Err(CertificatesError::NoCertificates.into()), + } } - /// Get the client id based on the task id. + /// Returns the client ID derived from the task ID. pub fn client_id(&self) -> &str { &self.task_id } - /// Get the tenant name of running container. + /// Returns the tenant name of the running container. pub fn tenant_name(&self) -> &str { &self.tenant_name } - /// Get the task id of running container. + /// Returns the task ID of the running container. pub fn task_id(&self) -> &str { &self.task_id } - /// Get the kafka properties provided by DSH (datastreams.json) - /// - /// This datastream is fetched at initialization of the properties, and can not be updated during runtime. + /// Returns the current datastream object (fetched at initialization). This cannot be updated at runtime. pub fn datastream(&self) -> &Datastream { self.datastream.as_ref() } - /// High level method to fetch the kafka properties provided by DSH (datastreams.json) - /// This will fetch the datastream from DSH. This can be used to update the datastream during runtime. - /// - /// This method keeps the reqwest client in memory to prevent creating a new client for every request. + /// Fetches the latest datastream (Kafka properties) from DSH asynchronously. + /// This can be used to update the datastream during runtime. /// /// # Panics - /// This method panics when it can't initialize a reqwest client. + /// Panics if it fails to build a reqwest client. /// - /// Use [Datastream::fetch] as a lowlevel method where you can provide your own client. + /// For a lower-level method allowing a custom client, see [`Datastream::fetch`]. pub async fn fetch_datastream(&self) -> Result { static ASYNC_CLIENT: OnceLock = OnceLock::new(); @@ -248,18 +241,17 @@ impl Dsh { .build() .expect("Could not build reqwest client for fetching datastream") }); + Ok(Datastream::fetch(client, &self.config_host, &self.tenant_name, &self.task_id).await?) } - /// High level method to fetch the kafka properties provided by DSH (datastreams.json) in a blocking way. - /// This will fetch the datastream from DSH. This can be used to update the datastream during runtime. - /// - /// This method keeps the reqwest client in memory to prevent creating a new client for every request. + /// Fetches the latest datastream from DSH in a blocking manner. + /// This can be used to update the datastream during runtime. /// /// # Panics - /// This method panics when it can't initialize a reqwest client. + /// Panics if it fails to build a reqwest blocking client. /// - /// Use [Datastream::fetch_blocking] as a lowlevel method where you can provide your own client. + /// For a lower-level method allowing a custom client, see [`Datastream::fetch_blocking`]. pub fn fetch_datastream_blocking(&self) -> Result { static BLOCKING_CLIENT: OnceLock = OnceLock::new(); @@ -268,6 +260,7 @@ impl Dsh { .build() .expect("Could not build reqwest client for fetching datastream") }); + Ok(Datastream::fetch_blocking( client, &self.config_host, @@ -276,7 +269,7 @@ impl Dsh { )?) 
} - /// Get schema host of DSH. + /// Returns the schema registry host as defined by the datastream. pub fn schema_registry_host(&self) -> &str { self.datastream().schema_store() } @@ -284,17 +277,13 @@ impl Dsh { #[cfg(feature = "kafka")] #[deprecated( since = "0.5.0", - note = "Moved to `Dsh::kafka_config().kafka_brokers()` and is part of the `kafka` feature" + note = "Moved to `Dsh::kafka_config().kafka_brokers()`. Part of the `kafka` feature." )] - /// Get the Kafka brokers. - /// - /// ## Environment variables - /// To manipulate the hastnames of the brokers, you can set the following environment variables. + /// Returns the Kafka brokers. /// - /// ### `KAFKA_BOOTSTRAP_SERVERS` - /// - Usage: Overwrite hostnames of brokers - /// - Default: Brokers based on datastreams - /// - Required: `false` + /// ## Environment Variables + /// - `KAFKA_BOOTSTRAP_SERVERS`: Overwrites broker hostnames (optional). + /// Defaults to brokers from the datastream. pub fn kafka_brokers(&self) -> String { self.datastream().get_brokers_string() } @@ -302,24 +291,15 @@ impl Dsh { #[cfg(feature = "kafka")] #[deprecated( since = "0.5.0", - note = "Moved to `Dsh::kafka_config().group_id()` and is part of the `kafka` feature" + note = "Moved to `Dsh::kafka_config().group_id()`. Part of the `kafka` feature." )] - /// Get the kafka_group_id based. + /// Returns the Kafka group ID. /// - /// ## Environment variables - /// To manipulate the group id, you can set the following environment variables. - /// - /// ### `KAFKA_CONSUMER_GROUP_TYPE` - /// - Usage: Picks group_id based on type from datastreams - /// - Default: Shared - /// - Options: private, shared - /// - Required: `false` + /// ## Environment Variables + /// - `KAFKA_CONSUMER_GROUP_TYPE`: Chooses a group ID type (private or shared). + /// - `KAFKA_GROUP_ID`: Custom group ID. Overrules `KAFKA_CONSUMER_GROUP_TYPE`. /// - /// ### `KAFKA_GROUP_ID` - /// - Usage: Custom group id - /// - Default: NA - /// - Required: `false` - /// - Remark: Overrules `KAFKA_CONSUMER_GROUP_TYPE`. Mandatory to start with tenant name. (will prefix tenant name automatically if not set) + /// If the group ID doesn't start with the tenant name, it's automatically prefixed. pub fn kafka_group_id(&self) -> String { if let Ok(group_id) = env::var(VAR_KAFKA_GROUP_ID) { if !group_id.starts_with(self.tenant_name()) { @@ -338,18 +318,12 @@ impl Dsh { #[cfg(feature = "kafka")] #[deprecated( since = "0.5.0", - note = "Moved to `Dsh::kafka_config().enable_auto_commit()` and is part of the `kafka` feature" + note = "Moved to `Dsh::kafka_config().enable_auto_commit()`. Part of the `kafka` feature." )] - /// Get the confifured kafka auto commit setinngs. + /// Returns the configured Kafka auto-commit setting. /// - /// ## Environment variables - /// To manipulate the auto commit settings, you can set the following environment variables. - /// - /// ### `KAFKA_ENABLE_AUTO_COMMIT` - /// - Usage: Enable/Disable auto commit - /// - Default: `false` - /// - Required: `false` - /// - Options: `true`, `false` + /// ## Environment Variables + /// - `KAFKA_ENABLE_AUTO_COMMIT`: Enables/disables auto commit (default: `false`). pub fn kafka_auto_commit(&self) -> bool { self.kafka_config.enable_auto_commit() } @@ -357,33 +331,28 @@ impl Dsh { #[cfg(feature = "kafka")] #[deprecated( since = "0.5.0", - note = "Moved to `Dsh::kafka_config().auto_offset_reset()` and is part of the `kafka` feature" + note = "Moved to `Dsh::kafka_config().auto_offset_reset()`. Part of the `kafka` feature." 
)] - /// Get the kafka auto offset reset settings. - /// - /// ## Environment variables - /// To manipulate the auto offset reset settings, you can set the following environment variables. + /// Returns the Kafka auto-offset-reset setting. /// - /// ### `KAFKA_AUTO_OFFSET_RESET` - /// - Usage: Set the offset reset settings to start consuming from set option. - /// - Default: earliest - /// - Required: `false` - /// - Options: smallest, earliest, beginning, largest, latest, end + /// ## Environment Variables + /// - `KAFKA_AUTO_OFFSET_RESET`: Set the offset reset policy (default: `earliest`). pub fn kafka_auto_offset_reset(&self) -> String { self.kafka_config.auto_offset_reset().to_string() } #[cfg(feature = "kafka")] - /// Get the kafka config from initiated Dsh struct. + /// Returns the [`KafkaConfig`] from this `Dsh` instance. pub fn kafka_config(&self) -> &KafkaConfig { &self.kafka_config } #[deprecated( since = "0.5.0", - note = "Use `Dsh::DshKafkaConfig` trait instead, see https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)" + note = "Use the `DshKafkaConfig` trait instead. See wiki for migration details." )] #[cfg(feature = "rdkafka-config")] + /// Returns an `rdkafka::config::ClientConfig` for a consumer, configured via Dsh. pub fn consumer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { use crate::protocol_adapters::kafka_protocol::DshKafkaConfig; let mut config = rdkafka::config::ClientConfig::new(); @@ -393,9 +362,10 @@ impl Dsh { #[deprecated( since = "0.5.0", - note = "Use `Dsh::DshKafkaConfig` trait instead, see https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)" + note = "Use the `DshKafkaConfig` trait instead. See wiki for migration details." )] #[cfg(feature = "rdkafka-config")] + /// Returns an `rdkafka::config::ClientConfig` for a producer, configured via Dsh. pub fn producer_rdkafka_config(&self) -> rdkafka::config::ClientConfig { use crate::protocol_adapters::kafka_protocol::DshKafkaConfig; let mut config = rdkafka::config::ClientConfig::new(); @@ -426,7 +396,7 @@ mod tests { } } - // maybe replace with local_datastreams.json? + // Helper to load test datastreams from a file. fn datastreams_json() -> String { std::fs::File::open("test_resources/valid_datastreams.json") .map(|mut file| { @@ -437,9 +407,9 @@ mod tests { .unwrap() } - // Define a reusable Dsh instance + // Helper to create a test Datastream. fn datastream() -> Datastream { - serde_json::from_str(datastreams_json().as_str()).unwrap() + serde_json::from_str(&datastreams_json()).unwrap() } #[test] @@ -461,7 +431,7 @@ mod tests { fn test_reqwest_client_config() { let properties = Dsh::default(); let _ = properties.reqwest_client_config(); - assert!(true) + assert!(true); } #[test] diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs index 03b6c52..0bc3a86 100644 --- a/dsh_sdk/src/error.rs +++ b/dsh_sdk/src/error.rs @@ -1,14 +1,43 @@ -/// Errors for the DSH SDK +//! Error types and reporting utilities for the DSH SDK. +//! +//! This module defines the primary error enum, [`DshError`], which aggregates +//! sub-errors from certificates, datastreams, and various utilities. It also +//! includes a helper function, [`report`], for generating a more readable error +//! trace by iterating over source causes. + +/// The main error type for the DSH SDK. 
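The [`report`] helper introduced here flattens an error's `source()` chain into a readable trace. Below is a minimal sketch of that traversal, with hand-rolled `Outer`/`Inner` errors standing in for the SDK's certificate, datastream, and utils variants; the SDK helper's exact output formatting may differ:

```rust
use std::error::Error;
use std::fmt;

// Illustrative low-level cause.
#[derive(Debug)]
struct Inner;
impl fmt::Display for Inner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "connection refused")
    }
}
impl Error for Inner {}

// Illustrative wrapper error that exposes its cause via `source()`.
#[derive(Debug)]
struct Outer(Inner);
impl fmt::Display for Outer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "could not fetch datastreams")
    }
}
impl Error for Outer {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

// Same shape as the `report` helper: the primary message first,
// then every transitive cause on its own line.
fn report(mut err: &dyn Error) -> String {
    let mut s = format!("{}", err);
    while let Some(src) = err.source() {
        s.push_str(&format!("\nCaused by: {}", src));
        err = src;
    }
    s
}

fn main() {
    // Prints:
    // could not fetch datastreams
    // Caused by: connection refused
    println!("{}", report(&Outer(Inner)));
}
```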
+/// +/// This enum wraps more specific errors from different parts of the SDK: +/// - [`CertificatesError`](crate::certificates::CertificatesError) +/// - [`DatastreamError`](crate::datastream::DatastreamError) +/// - [`UtilsError`](crate::utils::UtilsError) +/// +/// Each variant implements `std::error::Error` and can be conveniently converted +/// from the underlying error types (via `#[from]`). +/// #[derive(Debug, thiserror::Error)] pub enum DshError { + /// Wraps an error originating from certificate handling. #[error("Certificates error: {0}")] CertificatesError(#[from] crate::certificates::CertificatesError), + + /// Wraps an error originating from datastream operations or configuration. #[error("Datastream error: {0}")] DatastreamError(#[from] crate::datastream::DatastreamError), + + /// Wraps an error from general utilities or environment lookups. #[error("Utils error: {0}")] UtilsError(#[from] crate::utils::UtilsError), } +/// Generates a user-friendly error trace by traversing all `source()` +/// causes in the given error. +/// +/// The returned `String` contains the primary error message, followed +/// by each causal error (if any) on separate lines, preceded by `"Caused by:"`. +/// +/// This is helpful for logging or displaying the entire chain of errors. +/// pub(crate) fn report(mut err: &dyn std::error::Error) -> String { let mut s = format!("{}", err); while let Some(src) = err.source() { diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index 6ec099f..6cf5a93 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -1,7 +1,22 @@ #![doc = include_str!("../README.md")] #![allow(deprecated)] -// to be kept in v0.6.0 +//! This crate provides core functionalities, environment variable references, and feature-gated +//! modules for working with DSH (Data Services Hub). It also includes deprecated items from +//! older versions of the API that remain available for backward compatibility until v0.6.0. +//! +//! # Crate Overview +//! - **Feature-gated modules**: Certain modules like [`certificates`], [`datastream`], [`dsh`], etc. +//! are only compiled if the `bootstrap` feature is enabled. +//! - **Environment variables**: Constants representing environment variables are declared for +//! configuring Kafka, schema registry, and other components. +//! - **Deprecated modules**: Modules such as [`dlq`], [`dsh_old`], [`graceful_shutdown`], and +//! [`metrics`], among others, are slated for removal in v0.6.0 or have been moved to new +//! namespaces. +//! +//! Refer to the included `README.md` for more information on usage, setup, and features. + +// Keep in v0.6.0 for backward compatibility #[cfg(feature = "bootstrap")] pub mod certificates; #[cfg(feature = "bootstrap")] @@ -11,14 +26,19 @@ pub mod dsh; #[cfg(feature = "bootstrap")] mod error; +// Management API token fetcher feature #[cfg(feature = "management-api-token-fetcher")] pub mod management_api; + +// Protocol adapters and utilities pub mod protocol_adapters; pub mod utils; +// Schema Store feature #[cfg(feature = "schema-store")] pub mod schema_store; +// Re-exports for convenience #[cfg(feature = "bootstrap")] #[doc(inline)] pub use {dsh::Dsh, error::DshError}; @@ -42,11 +62,9 @@ pub mod dlq; #[cfg(feature = "bootstrap")] #[deprecated( since = "0.5.0", - note = "The `Properties` struct phased out. 
Use - `dsh_sdk::Dsh` for an all-in-one struct, similar to the original `Properties`; - `dsh_sdk::certificates` for all certificate related info; - `dsh_sdk::datastream` for all datastream related info; - " + note = "The `Properties` struct is phased out. Use `dsh_sdk::Dsh` for an all-in-one struct; \ + `dsh_sdk::certificates` for certificate management; `dsh_sdk::datastream` for \ + datastream handling." )] pub mod dsh_old; @@ -70,50 +88,104 @@ pub mod metrics; note = "`dsh_sdk::mqtt_token_fetcher` is moved to `dsh_sdk::protocol_adapters::token_fetcher`" )] pub mod mqtt_token_fetcher; + #[cfg(feature = "bootstrap")] pub use dsh_old::Properties; #[cfg(feature = "management-api-token-fetcher")] #[deprecated( since = "0.5.0", - note = "`RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to `ManagementApiTokenFetcher` and `ManagementApiTokenFetcherBuilder`" + note = "`RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to \ + `ManagementApiTokenFetcher` and `ManagementApiTokenFetcherBuilder`" )] mod rest_api_token_fetcher; #[cfg(feature = "management-api-token-fetcher")] pub use rest_api_token_fetcher::{RestTokenFetcher, RestTokenFetcherBuilder}; +// ---------------------------------------------------------------------------- // Environment variables +// ---------------------------------------------------------------------------- + +// These constants define the names of environment variables used throughout the DSH SDK. +// They are grouped into logical sections for clarity. + +// -------------------- General environment variables -------------------- + +/// Environment variable for retrieving the Marathon application ID. const VAR_APP_ID: &str = "MARATHON_APP_ID"; + +/// Environment variable for retrieving the Mesos task ID. const VAR_TASK_ID: &str = "MESOS_TASK_ID"; + +/// Environment variable for retrieving the DSH CA certificate const VAR_DSH_CA_CERTIFICATE: &str = "DSH_CA_CERTIFICATE"; + +/// Inline secret token used to authorize requests to DSH. const VAR_DSH_SECRET_TOKEN: &str = "DSH_SECRET_TOKEN"; + +/// Filesystem path to the DSH secret token. const VAR_DSH_SECRET_TOKEN_PATH: &str = "DSH_SECRET_TOKEN_PATH"; + +/// Tenant name for DSH. const VAR_DSH_TENANT_NAME: &str = "DSH_TENANT_NAME"; + +/// DSH config host, typically pointing to an internal endpoint. const VAR_KAFKA_CONFIG_HOST: &str = "KAFKA_CONFIG_HOST"; -// kafka general +/// PKI configuration directory. +const VAR_PKI_CONFIG_DIR: &str = "PKI_CONFIG_DIR"; + +/// Local datastream configurations in JSON format (e.g., for Kafka Proxy or local testing). +const VAR_LOCAL_DATASTREAMS_JSON: &str = "LOCAL_DATASTREAMS_JSON"; + +// -------------------- Kafka general environment variables -------------------- + +/// Lists Kafka bootstrap servers. const VAR_KAFKA_BOOTSTRAP_SERVERS: &str = "KAFKA_BOOTSTRAP_SERVERS"; + +/// Specifies the schema registry host for Kafka. const VAR_SCHEMA_REGISTRY_HOST: &str = "SCHEMA_REGISTRY_HOST"; -// Consumer +// -------------------- Kafka consumer environment variables -------------------- + +/// Controls how offsets are handled when no offset is available (e.g., "earliest" or "latest"). const VAR_KAFKA_AUTO_OFFSET_RESET: &str = "KAFKA_AUTO_OFFSET_RESET"; + +/// Indicates the Kafka consumer group type. const VAR_KAFKA_CONSUMER_GROUP_TYPE: &str = "KAFKA_CONSUMER_GROUP_TYPE"; + +/// Specifies whether consumer auto-commit is enabled. const VAR_KAFKA_ENABLE_AUTO_COMMIT: &str = "KAFKA_ENABLE_AUTO_COMMIT"; + +/// Defines the group ID for Kafka consumers. 
const VAR_KAFKA_GROUP_ID: &str = "KAFKA_GROUP_ID"; + +/// Kafka consumer session timeout in milliseconds. const VAR_KAFKA_CONSUMER_SESSION_TIMEOUT_MS: &str = "KAFKA_CONSUMER_SESSION_TIMEOUT_MS"; + +/// Kafka consumer's maximum queued buffering in kilobytes. const VAR_KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES: &str = "KAFKA_CONSUMER_QUEUED_BUFFERING_MAX_MESSAGES_KBYTES"; -// Producer +// -------------------- Kafka producer environment variables -------------------- + +/// Kafka producer's number of messages per batch. const VAR_KAFKA_PRODUCER_BATCH_NUM_MESSAGES: &str = "KAFKA_PRODUCER_BATCH_NUM_MESSAGES"; + +/// Kafka producer's maximum queue buffering for messages. const VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES: &str = "KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MESSAGES"; + +/// Kafka producer's maximum queue buffering in kilobytes. const VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES: &str = "KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_KBYTES"; -const VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS: &str = "KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS"; -const VAR_PKI_CONFIG_DIR: &str = "PKI_CONFIG_DIR"; +/// Kafka producer's maximum queue buffering duration in milliseconds. +const VAR_KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS: &str = "KAFKA_PRODUCER_QUEUE_BUFFERING_MAX_MS"; -const VAR_LOCAL_DATASTREAMS_JSON: &str = "LOCAL_DATASTREAMS_JSON"; +// ---------------------------------------------------------------------------- +// Default configuration +// ---------------------------------------------------------------------------- +/// Default configuration host for DSH, used if no environment variable overrides are provided. const DEFAULT_CONFIG_HOST: &str = "https://pikachu.dsh.marathon.mesos:4443"; From cf118f421b3057a8c3fea5eab7da6aa46d399327 Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 13:48:54 +0100 Subject: [PATCH 12/23] move platform --- dsh_sdk/src/utils/mod.rs | 168 +--------------------------------- dsh_sdk/src/utils/platform.rs | 168 ++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 164 deletions(-) create mode 100644 dsh_sdk/src/utils/platform.rs diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs index 7023c79..e716fc4 100644 --- a/dsh_sdk/src/utils/mod.rs +++ b/dsh_sdk/src/utils/mod.rs @@ -19,149 +19,12 @@ pub(crate) mod http_client; #[cfg(feature = "metrics")] pub mod metrics; -mod error; - -/// Available DSH platforms plus it's related metadata -/// -/// The platform enum contains -/// - `Prod` (kpn-dsh.com) -/// - `ProdAz` (az.kpn-dsh.com) -/// - `ProdLz` (dsh-prod.dsh.prod.aws.kpn.com) -/// - `NpLz` (dsh-dev.dsh.np.aws.kpn.com) -/// - `Poc` (poc.kpn-dsh.com) -/// -/// Each platform has it's own realm, endpoint for the DSH Rest API and endpoint for the DSH Rest API access token. 
-#[derive(Clone, Debug)] -#[non_exhaustive] -pub enum Platform { - /// Production platform (kpn-dsh.com) - Prod, - /// Production platform on Azure (az.kpn-dsh.com) - ProdAz, - /// Production Landing Zone on AWS (dsh-prod.dsh.prod.aws.kpn.com) - ProdLz, - /// Non-Production (Dev) Landing Zone on AWS (dsh-dev.dsh.np.aws.kpn.com) - NpLz, - /// Proof of Concept platform (poc.kpn-dsh.com) - Poc, -} - -impl Platform { - /// Get a properly formatted client_id for the Rest API based on the given name of a tenant - /// - /// It will return a string formatted as "robot:{realm}:{tenant_name}" - /// - /// ## Example - /// ``` - /// # use dsh_sdk::Platform; - /// let platform = Platform::NpLz; - /// let client_id = platform.rest_client_id("my-tenant"); - /// assert_eq!(client_id, "robot:dev-lz-dsh:my-tenant"); - /// ``` - pub fn rest_client_id(&self, tenant: T) -> String - where - T: AsRef, - { - format!("robot:{}:{}", self.realm(), tenant.as_ref()) - } - - /// Get the endpoint for the DSH Rest API - /// - /// It will return the endpoint for the DSH Rest API based on the platform - /// - /// ## Example - /// ``` - /// # use dsh_sdk::Platform; - /// let platform = Platform::NpLz; - /// let endpoint = platform.endpoint_rest_api(); - /// assert_eq!(endpoint, "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0"); - /// ``` - pub fn endpoint_rest_api(&self) -> &str { - match self { - Self::Prod => "https://api.kpn-dsh.com/resources/v0", - Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0", - Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/resources/v0", - Self::ProdAz => "https://api.az.kpn-dsh.com/resources/v0", - Self::Poc => "https://api.poc.kpn-dsh.com/resources/v0", - } - } - /// Get the endpoint for the DSH Rest API access token - /// - /// It will return the endpoint for the DSH Rest API access token based on the platform - /// - /// ## Example - /// ``` - /// # use dsh_sdk::Platform; - /// let platform = Platform::NpLz; - /// let endpoint = platform.endpoint_rest_access_token(); - /// assert_eq!(endpoint, "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token"); - /// ``` - pub fn endpoint_rest_access_token(&self) -> &str { - match self { - Self::Prod => "https://auth.prod.cp.kpn-dsh.com/auth/realms/tt-dsh/protocol/openid-connect/token", - Self::NpLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token", - Self::ProdLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/prod-lz-dsh/protocol/openid-connect/token", - Self::ProdAz => "https://auth.prod.cp.kpn-dsh.com/auth/realms/prod-azure-dsh/protocol/openid-connect/token", - Self::Poc => "https://auth.prod.cp.kpn-dsh.com/auth/realms/poc-dsh/protocol/openid-connect/token", - } - } - - #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")] - /// Get the endpoint for fetching DSH Rest Authentication Token - /// - /// With this token you can authenticate for the mqtt token endpoint - /// - /// It will return the endpoint for DSH Rest authentication token based on the platform - pub fn endpoint_rest_token(&self) -> &str { - self.endpoint_management_api_token() - } - - /// Get the endpoint for fetching DSH Rest Authentication Token - /// - /// With this token you can authenticate for the mqtt token endpoint - /// - /// It will return the endpoint for DSH Rest authentication token based on the platform - pub fn endpoint_management_api_token(&self) -> &str { - match self { - Self::Prod => 
"https://api.kpn-dsh.com/auth/v0/token", - Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token", - Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/auth/v0/token", - Self::ProdAz => "https://api.az.kpn-dsh.com/auth/v0/token", - Self::Poc => "https://api.poc.kpn-dsh.com/auth/v0/token", - } - } - - #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")] - /// Get the endpoint for fetching DSH mqtt token - /// - /// It will return the endpoint for DSH MQTT Token based on the platform - pub fn endpoint_mqtt_token(&self) -> &str { - self.endpoint_protocol_token() - } +mod platform; - /// Get the endpoint for fetching DSH Protocol token - /// - /// It will return the endpoint for DSH Protocol adapter Token based on the platform - pub fn endpoint_protocol_token(&self) -> &str { - match self { - Self::Prod => "https://api.kpn-dsh.com/datastreams/v0/mqtt/token", - Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/datastreams/v0/mqtt/token", - Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/datastreams/v0/mqtt/token", - Self::ProdAz => "https://api.az.kpn-dsh.com/datastreams/v0/mqtt/token", - Self::Poc => "https://api.poc.kpn-dsh.com/datastreams/v0/mqtt/token", - } - } +mod error; - pub fn realm(&self) -> &str { - match self { - Self::Prod => "tt-dsh", - Self::NpLz => "dev-lz-dsh", - Self::ProdLz => "prod-lz-dsh", - Self::ProdAz => "prod-azure-dsh", - Self::Poc => "poc-dsh", - } - } -} +#[doc(inline)] +pub use platform::Platform; /// Get the configured topics from the environment variable TOPICS /// Topics can be delimited by a comma @@ -253,29 +116,6 @@ mod tests { use super::*; use serial_test::serial; - #[test] - fn test_platform_realm() { - assert_eq!(Platform::NpLz.realm(), "dev-lz-dsh"); - assert_eq!(Platform::ProdLz.realm(), "prod-lz-dsh"); - assert_eq!(Platform::Poc.realm(), "poc-dsh"); - } - - #[test] - fn test_platform_client_id() { - assert_eq!( - Platform::NpLz.rest_client_id("my-tenant"), - "robot:dev-lz-dsh:my-tenant" - ); - assert_eq!( - Platform::ProdLz.rest_client_id("my-tenant".to_string()), - "robot:prod-lz-dsh:my-tenant" - ); - assert_eq!( - Platform::Poc.rest_client_id("my-tenant"), - "robot:poc-dsh:my-tenant" - ); - } - #[test] #[serial(env_dependency)] fn test_dsh_config_tenant_name() { diff --git a/dsh_sdk/src/utils/platform.rs b/dsh_sdk/src/utils/platform.rs new file mode 100644 index 0000000..7a684de --- /dev/null +++ b/dsh_sdk/src/utils/platform.rs @@ -0,0 +1,168 @@ +/// Available DSH platforms plus it's related metadata +/// +/// The platform enum contains +/// - `Prod` (kpn-dsh.com) +/// - `ProdAz` (az.kpn-dsh.com) +/// - `ProdLz` (dsh-prod.dsh.prod.aws.kpn.com) +/// - `NpLz` (dsh-dev.dsh.np.aws.kpn.com) +/// - `Poc` (poc.kpn-dsh.com) +/// +/// Each platform has it's own realm, endpoint for the DSH Rest API and endpoint for the DSH Rest API access token. 
+#[derive(Clone, Debug)]
+#[non_exhaustive]
+pub enum Platform {
+    /// Production platform (kpn-dsh.com)
+    Prod,
+    /// Production platform on Azure (az.kpn-dsh.com)
+    ProdAz,
+    /// Production Landing Zone on AWS (dsh-prod.dsh.prod.aws.kpn.com)
+    ProdLz,
+    /// Non-Production (Dev) Landing Zone on AWS (dsh-dev.dsh.np.aws.kpn.com)
+    NpLz,
+    /// Proof of Concept platform (poc.kpn-dsh.com)
+    Poc,
+}
+
+impl Platform {
+    /// Get a properly formatted client_id for the Rest API based on the given name of a tenant
+    ///
+    /// It will return a string formatted as "robot:{realm}:{tenant_name}"
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::NpLz;
+    /// let client_id = platform.rest_client_id("my-tenant");
+    /// assert_eq!(client_id, "robot:dev-lz-dsh:my-tenant");
+    /// ```
+    pub fn rest_client_id<T>(&self, tenant: T) -> String
+    where
+        T: AsRef<str>,
+    {
+        format!("robot:{}:{}", self.realm(), tenant.as_ref())
+    }
+
+    /// Get the endpoint for the DSH Rest API
+    ///
+    /// It will return the endpoint for the DSH Rest API based on the platform
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::NpLz;
+    /// let endpoint = platform.endpoint_rest_api();
+    /// assert_eq!(endpoint, "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0");
+    /// ```
+    pub fn endpoint_rest_api(&self) -> &str {
+        match self {
+            Self::Prod => "https://api.kpn-dsh.com/resources/v0",
+            Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0",
+            Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/resources/v0",
+            Self::ProdAz => "https://api.az.kpn-dsh.com/resources/v0",
+            Self::Poc => "https://api.poc.kpn-dsh.com/resources/v0",
+        }
+    }
+    /// Get the endpoint for the DSH Rest API access token
+    ///
+    /// It will return the endpoint for the DSH Rest API access token based on the platform
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::NpLz;
+    /// let endpoint = platform.endpoint_rest_access_token();
+    /// assert_eq!(endpoint, "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token");
+    /// ```
+    pub fn endpoint_rest_access_token(&self) -> &str {
+        match self {
+            Self::Prod => "https://auth.prod.cp.kpn-dsh.com/auth/realms/tt-dsh/protocol/openid-connect/token",
+            Self::NpLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token",
+            Self::ProdLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/prod-lz-dsh/protocol/openid-connect/token",
+            Self::ProdAz => "https://auth.prod.cp.kpn-dsh.com/auth/realms/prod-azure-dsh/protocol/openid-connect/token",
+            Self::Poc => "https://auth.prod.cp.kpn-dsh.com/auth/realms/poc-dsh/protocol/openid-connect/token",
+        }
+    }
+
+    #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")]
+    /// Get the endpoint for fetching DSH Rest Authentication Token
+    ///
+    /// With this token you can authenticate for the mqtt token endpoint
+    ///
+    /// It will return the endpoint for DSH Rest authentication token based on the platform
+    pub fn endpoint_rest_token(&self) -> &str {
+        self.endpoint_management_api_token()
+    }
+
+    /// Get the endpoint for fetching DSH Rest Authentication Token
+    ///
+    /// With this token you can authenticate for the mqtt token endpoint
+    ///
+    /// It will return the endpoint for DSH Rest authentication token based on the platform
+    pub fn endpoint_management_api_token(&self) -> &str {
+        match self {
+            Self::Prod =>
"https://api.kpn-dsh.com/auth/v0/token", + Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token", + Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/auth/v0/token", + Self::ProdAz => "https://api.az.kpn-dsh.com/auth/v0/token", + Self::Poc => "https://api.poc.kpn-dsh.com/auth/v0/token", + } + } + + #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")] + /// Get the endpoint for fetching DSH mqtt token + /// + /// It will return the endpoint for DSH MQTT Token based on the platform + pub fn endpoint_mqtt_token(&self) -> &str { + self.endpoint_protocol_token() + } + + /// Get the endpoint for fetching DSH Protocol token + /// + /// It will return the endpoint for DSH Protocol adapter Token based on the platform + pub fn endpoint_protocol_token(&self) -> &str { + match self { + Self::Prod => "https://api.kpn-dsh.com/datastreams/v0/mqtt/token", + Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/datastreams/v0/mqtt/token", + Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/datastreams/v0/mqtt/token", + Self::ProdAz => "https://api.az.kpn-dsh.com/datastreams/v0/mqtt/token", + Self::Poc => "https://api.poc.kpn-dsh.com/datastreams/v0/mqtt/token", + } + } + + pub fn realm(&self) -> &str { + match self { + Self::Prod => "tt-dsh", + Self::NpLz => "dev-lz-dsh", + Self::ProdLz => "prod-lz-dsh", + Self::ProdAz => "prod-azure-dsh", + Self::Poc => "poc-dsh", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_platform_realm() { + assert_eq!(Platform::NpLz.realm(), "dev-lz-dsh"); + assert_eq!(Platform::ProdLz.realm(), "prod-lz-dsh"); + assert_eq!(Platform::Poc.realm(), "poc-dsh"); + } + + #[test] + fn test_platform_client_id() { + assert_eq!( + Platform::NpLz.rest_client_id("my-tenant"), + "robot:dev-lz-dsh:my-tenant" + ); + assert_eq!( + Platform::ProdLz.rest_client_id("my-tenant".to_string()), + "robot:prod-lz-dsh:my-tenant" + ); + assert_eq!( + Platform::Poc.rest_client_id("my-tenant"), + "robot:poc-dsh:my-tenant" + ); + } +} From beed01a5839e958646c7b5e79446cd57cdc434ce Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 13:53:15 +0100 Subject: [PATCH 13/23] bump to 0.5.0-rc.2 --- dsh_sdk/Cargo.toml | 2 +- dsh_sdk/README.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 9ff59f8..26ebd78 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true name = "dsh_sdk" readme = 'README.md' repository.workspace = true -version = "0.5.0-rc.1" +version = "0.5.0-rc.2" [package.metadata.docs.rs] all-features = true diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index c348fa8..d0400aa 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -64,7 +64,7 @@ To get started, add the following to your `Cargo.toml`: ```toml [dependencies] -dsh_sdk = "0.5" +dsh_sdk = "0.5.0-rc.2" rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] } ``` @@ -129,7 +129,7 @@ To pick only the features you need, disable the default features and enable spec ```toml [dependencies] -dsh_sdk = { version = "0.5", default-features = false, features = ["management-api-token-fetcher"] } +dsh_sdk = { version = "0.5.0-rc.2", default-features = false, features = ["management-api-token-fetcher"] } ``` --- From 36135f746f7eef02872a00ffadbf7a6a702c7a65 Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 14:23:34 +0100 Subject: [PATCH 14/23] FQDN links so it is compatible on docs.rs --- dsh_sdk/README.md | 35 
++++++++++++++++++----------------- dsh_sdk/src/lib.rs | 15 --------------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index d0400aa..70ae188 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -101,7 +101,7 @@ This SDK accommodates multiple deployment environments: - Running on a machine with Kafka Proxy/VPN - Running locally with a local Kafka instance -For more information, see the [CONNECT_PROXY_VPN_LOCAL.md](CONNECT_PROXY_VPN_LOCAL.md) document. +For more information, see the [CONNECT_PROXY_VPN_LOCAL.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/CONNECT_PROXY_VPN_LOCAL.md) document. --- @@ -114,15 +114,16 @@ Below is an overview of the available features: | **feature** | **default** | **Description** | **Example** | |--------------------------------|-------------|-------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------| -| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](./examples/kafka_example.rs) / [Kafka Proxy](./examples/kafka_proxy.rs) | -| `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](./examples/schema_store_api.rs) | -| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](./examples/protocol_token_fetcher.rs) / [with specific claims](./examples/protocol_token_fetcher_specific_claims.rs) | -| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](./examples/management_api_token_fetcher.rs) | -| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](./examples/expose_metrics.rs) | -| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](./examples/graceful_shutdown.rs) | -| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](./examples/dlq_implementation.rs) | +| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/schema_store_api.rs) | +| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token 
fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_token_fetcher.rs) / [with specific claims](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs) | +| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/management_api_token_fetcher.rs) | +| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) | +| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/graceful_shutdown.rs) | +| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs) | + ### Selecting Features To pick only the features you need, disable the default features and enable specific ones. For instance, if you only want the Management API Token Fetcher: @@ -142,41 +143,41 @@ This SDK uses certain environment variables to configure connections to DSH. For ## Examples -You can find simple usage examples in the [`examples/` directory](./examples/). +You can find simple usage examples in the [`examples/` directory](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/). ### Full Service Example -A more complete example is provided in the [`example_dsh_service/`](../example_dsh_service/) directory, showcasing: +A more complete example is provided in the [`example_dsh_service/`](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/example_dsh_service/) directory, showcasing: - How to build the Rust project - How to package and push it to Harbor - An end-to-end setup of a DSH service -See the [README](../example_dsh_service/README.md) in that directory for more information. +See the [README](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/example_dsh_service/README.md) in that directory for more information. --- ## Changelog -All changes per version are documented in [CHANGELOG.md](CHANGELOG.md). +All changes per version are documented in [CHANGELOG.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/CHANGELOG.md). --- ## Contributing -Contributions are welcome! For details on how to help improve this project, please see [CONTRIBUTING.md](../CONTRIBUTING.md). +Contributions are welcome! For details on how to help improve this project, please see [CONTRIBUTING.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/CONTRIBUTING.md). --- ## License -This project is licensed under the [Apache License 2.0](../LICENSE). +This project is licensed under the [Apache License 2.0](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/LICENSE). --- ## Security -For information about the security policy of this project, including how to report vulnerabilities, see [SECURITY.md](../SECURITY.md). +For information about the security policy of this project, including how to report vulnerabilities, see [SECURITY.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/SECURITY.md). --- diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs index 6cf5a93..da7acbb 100644 --- a/dsh_sdk/src/lib.rs +++ b/dsh_sdk/src/lib.rs @@ -1,21 +1,6 @@ #![doc = include_str!("../README.md")] #![allow(deprecated)] -//! 
This crate provides core functionalities, environment variable references, and feature-gated -//! modules for working with DSH (Data Services Hub). It also includes deprecated items from -//! older versions of the API that remain available for backward compatibility until v0.6.0. -//! -//! # Crate Overview -//! - **Feature-gated modules**: Certain modules like [`certificates`], [`datastream`], [`dsh`], etc. -//! are only compiled if the `bootstrap` feature is enabled. -//! - **Environment variables**: Constants representing environment variables are declared for -//! configuring Kafka, schema registry, and other components. -//! - **Deprecated modules**: Modules such as [`dlq`], [`dsh_old`], [`graceful_shutdown`], and -//! [`metrics`], among others, are slated for removal in v0.6.0 or have been moved to new -//! namespaces. -//! -//! Refer to the included `README.md` for more information on usage, setup, and features. - // Keep in v0.6.0 for backward compatibility #[cfg(feature = "bootstrap")] pub mod certificates; From c9f2f9ba2d781133967c404f16239205df9addeb Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 16:56:58 +0100 Subject: [PATCH 15/23] restore dlq for deprecation warning --- dsh_sdk/src/dlq.rs | 546 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 541 insertions(+), 5 deletions(-) diff --git a/dsh_sdk/src/dlq.rs b/dsh_sdk/src/dlq.rs index b74f928..9b4ffa7 100644 --- a/dsh_sdk/src/dlq.rs +++ b/dsh_sdk/src/dlq.rs @@ -23,8 +23,544 @@ //! ### Example: //! See the examples folder on github for a working example. -#[deprecated( - since = "0.5.0", - note = "The DLQ is moved to [crate::utils::dlq](crate::utils::dlq)" -)] -pub use crate::utils::dlq::*; +//pub use crate::utils::dlq::*; + + + +use std::collections::HashMap; +use std::env; +use std::str::from_utf8; + +use log::{debug, error, info, warn}; + +use rdkafka::message::{Header, Headers, Message, OwnedHeaders, OwnedMessage}; +use rdkafka::producer::{FutureProducer, FutureRecord}; + +use tokio::sync::mpsc; + +use crate::graceful_shutdown::Shutdown; +use crate::Properties; + +/// Trait to convert an error to a dlq message +/// This trait is implemented for all errors that can and should be converted to a dlq message +/// +/// Example: +///``` +/// use dsh_sdk::dlq; +/// use std::backtrace::Backtrace; +/// use thiserror::Error; +/// +/// #[derive(Error, Debug)] +/// enum ConsumerError { +/// #[error("Deserialization error: {0}")] +/// DeserializeError(String), +/// } +/// +/// impl dlq::ErrorToDlq for ConsumerError { +/// fn to_dlq(&self, kafka_message: rdkafka::message::OwnedMessage) -> dlq::SendToDlq { +/// dlq::SendToDlq::new(kafka_message, self.retryable(), self.to_string(), None) +/// } +/// fn retryable(&self) -> dlq::Retryable { +/// match self { +/// ConsumerError::DeserializeError(e) => dlq::Retryable::NonRetryable, +/// } +/// } +/// } +/// ``` +pub trait ErrorToDlq { + /// Convert error message to a dlq message + fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq; + /// Match error if the orignal message is able to be retried or not + fn retryable(&self) -> Retryable; +} + +/// Struct with required details to send a channel message to the dlq +/// Error needs to be send as string, as it is not possible to send a struct that implements Error trait +pub struct SendToDlq { + kafka_message: OwnedMessage, + retryable: Retryable, + error: String, + stack_trace: Option, +} + +impl SendToDlq { + /// Create new SendToDlq message + pub fn new( + kafka_message: OwnedMessage, + retryable: Retryable, + 
error: String, + stack_trace: Option, + ) -> Self { + Self { + kafka_message, + retryable, + error, + stack_trace, + } + } + /// Send message to dlq channel + pub async fn send(self, dlq_tx: &mut mpsc::Sender) { + match dlq_tx.send(self).await { + Ok(_) => debug!("Message sent to DLQ channel"), + Err(e) => error!("Error sending message to DLQ: {}", e), + } + } + + fn get_original_msg(&self) -> OwnedMessage { + self.kafka_message.clone() + } +} + +/// Helper enum to decide to which topic the message should be sent to. +#[derive(Debug, Clone, Copy)] +pub enum Retryable { + Retryable, + NonRetryable, + Other, +} + +impl std::fmt::Display for Retryable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Retryable::Retryable => write!(f, "Retryable"), + Retryable::NonRetryable => write!(f, "NonRetryable"), + Retryable::Other => write!(f, "Other"), + } + } +} + +/// Struct with implementation to send messages to the dlq +pub struct Dlq { + dlq_producer: FutureProducer, + dlq_rx: mpsc::Receiver, + dlq_tx: mpsc::Sender, + dlq_dead_topic: String, + dlq_retry_topic: String, + shutdown: Shutdown, +} + +impl Dlq { + /// Create new Dlq struct + pub fn new( + dsh_prop: &Properties, + shutdown: Shutdown, + ) -> Result> { + use crate::dsh_old::datastream::ReadWriteAccess; + let (dlq_tx, dlq_rx) = mpsc::channel(200); + let dlq_producer = Self::build_producer(dsh_prop)?; + let dlq_dead_topic = env::var("DLQ_DEAD_TOPIC")?; + let dlq_retry_topic = env::var("DLQ_RETRY_TOPIC")?; + dsh_prop.datastream().verify_list_of_topics( + &vec![&dlq_dead_topic, &dlq_retry_topic], + ReadWriteAccess::Write, + )?; + Ok(Self { + dlq_producer, + dlq_rx, + dlq_tx, + dlq_dead_topic, + dlq_retry_topic, + shutdown, + }) + } + + /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics + /// This function will run until the shutdown channel is closed + pub async fn run(&mut self) { + info!("DLQ started"); + loop { + tokio::select! { + _ = self.shutdown.recv() => { + warn!("DLQ shutdown"); + return; + }, + Some(mut dlq_message) = self.dlq_rx.recv() => { + match self.send(&mut dlq_message).await { + Ok(_) => {}, + Err(e) => error!("Error sending message to DLQ: {}", e), + }; + } + } + } + } + + /// Get the dlq channel sender. To be used in your service to send messages to the dlq in case of errors. + /// + /// This channel can be used to send messages to the dlq from different threads. 
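The sender handed out for the DLQ is a plain `tokio::sync::mpsc::Sender`, so each worker task can hold its own clone. A minimal sketch of that fan-in pattern, with `String` standing in for `SendToDlq` so it runs without Kafka (assumes the `tokio` crate with the `macros`, `sync`, and `rt-multi-thread` features):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Bounded channel, mirroring the DLQ's mpsc::channel(200).
    let (dlq_tx, mut dlq_rx) = mpsc::channel::<String>(200);

    // Each worker gets its own clone of the sender.
    for worker in 0..3 {
        let tx = dlq_tx.clone();
        tokio::spawn(async move {
            // A failed message would be converted via ErrorToDlq::to_dlq here.
            let _ = tx.send(format!("failed message from worker {worker}")).await;
        });
    }
    // Drop the original sender so the receiver finishes once all clones are gone.
    drop(dlq_tx);

    // The Dlq::run loop plays this role: drain until every sender is dropped.
    while let Some(msg) = dlq_rx.recv().await {
        println!("would forward to DLQ topic: {msg}");
    }
}
```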
+ pub fn dlq_records_tx(&self) -> mpsc::Sender { + self.dlq_tx.clone() + } + + /// Create and send message towards the dlq + async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), rdkafka::error::KafkaError> { + let orignal_kafka_msg: OwnedMessage = dlq_message.get_original_msg(); + let headers = orignal_kafka_msg + .generate_dlq_headers(dlq_message) + .to_owned_headers(); + let topic = self.dlq_topic(dlq_message.retryable); + let key: &[u8] = orignal_kafka_msg.key().unwrap_or_default(); + let payload = orignal_kafka_msg.payload().unwrap_or_default(); + debug!("Sending message to DLQ topic: {}", topic); + let record = FutureRecord::to(topic) + .payload(payload) + .key(key) + .headers(headers); + let s = self.dlq_producer.send(record, None).await; + match s { + Ok((p, o)) => warn!( + "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}", + from_utf8(key), + topic, + p, + o + ), + Err((e, _)) => return Err(e), + }; + Ok(()) + } + + fn dlq_topic(&self, retryable: Retryable) -> &str { + match retryable { + Retryable::Retryable => &self.dlq_retry_topic, + Retryable::NonRetryable => &self.dlq_dead_topic, + Retryable::Other => &self.dlq_dead_topic, + } + } + + fn build_producer(dsh_prop: &Properties) -> Result { + dsh_prop.producer_rdkafka_config().create() + } +} + +trait DlqHeaders { + fn generate_dlq_headers<'a>( + &'a self, + dlq_message: &'a mut SendToDlq, + ) -> HashMap<&'a str, Option>>; +} + +impl DlqHeaders for OwnedMessage { + fn generate_dlq_headers<'a>( + &'a self, + dlq_message: &'a mut SendToDlq, + ) -> HashMap<&'a str, Option>> { + let mut hashmap_headers: HashMap<&str, Option>> = HashMap::new(); + // Get original headers and add to hashmap + if let Some(headers) = self.headers() { + for header in headers.iter() { + hashmap_headers.insert(header.key, header.value.map(|v| v.to_vec())); + } + } + + // Add dlq headers if not exist (we don't want to overwrite original dlq headers if message already failed earlier) + let partition = self.partition().to_string().as_bytes().to_vec(); + let offset = self.offset().to_string().as_bytes().to_vec(); + let timestamp = self + .timestamp() + .to_millis() + .unwrap_or(-1) + .to_string() + .as_bytes() + .to_vec(); + hashmap_headers + .entry("dlq_topic_origin") + .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); + hashmap_headers + .entry("dlq_partition_origin") + .or_insert_with(move || Some(partition)); + hashmap_headers + .entry("dlq_partition_offset_origin") + .or_insert_with(move || Some(offset)); + hashmap_headers + .entry("dlq_topic_origin") + .or_insert_with(|| Some(self.topic().as_bytes().to_vec())); + hashmap_headers + .entry("dlq_timestamp_origin") + .or_insert_with(move || Some(timestamp)); + // Overwrite if exist + hashmap_headers.insert( + "dlq_retryable", + Some(dlq_message.retryable.to_string().as_bytes().to_vec()), + ); + hashmap_headers.insert( + "dlq_error", + Some(dlq_message.error.to_string().as_bytes().to_vec()), + ); + if let Some(stack_trace) = &dlq_message.stack_trace { + hashmap_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec())); + } + // update dlq_retries with +1 if exists, else add dlq_retries wiith 1 + let retries = hashmap_headers + .get("dlq_retries") + .map(|v| { + let mut retries = [0; 4]; + retries.copy_from_slice(v.as_ref().unwrap()); + i32::from_be_bytes(retries) + }) + .unwrap_or(0); + hashmap_headers.insert("dlq_retries", Some((retries + 1).to_be_bytes().to_vec())); + + hashmap_headers + } +} + +trait HashMapToKafkaHeaders { + fn to_owned_headers(&self) -> 
OwnedHeaders; +} + +impl HashMapToKafkaHeaders for HashMap<&str, Option>> { + fn to_owned_headers(&self) -> OwnedHeaders { + // Convert to OwnedHeaders + let mut owned_headers = OwnedHeaders::new_with_capacity(self.len()); + for header in self { + let value = header.1.as_ref().map(|value| value.as_slice()); + owned_headers = owned_headers.insert(Header { + key: header.0, + value, + }); + } + owned_headers + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rdkafka::config::ClientConfig; + use rdkafka::mocking::MockCluster; + + #[derive(Debug)] + enum MockError { + MockErrorRetryable(String), + MockErrorDead(String), + } + impl MockError { + fn to_string(&self) -> String { + match self { + MockError::MockErrorRetryable(e) => e.to_string(), + MockError::MockErrorDead(e) => e.to_string(), + } + } + } + + impl std::fmt::Display for MockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MockError::MockErrorRetryable(e) => write!(f, "{}", e), + MockError::MockErrorDead(e) => write!(f, "{}", e), + } + } + } + + impl ErrorToDlq for MockError { + fn to_dlq(&self, kafka_message: OwnedMessage) -> SendToDlq { + let backtrace = "some_backtrace"; + SendToDlq::new( + kafka_message, + self.retryable(), + self.to_string(), + Some(backtrace.to_string()), + ) + } + + fn retryable(&self) -> Retryable { + match self { + MockError::MockErrorRetryable(_) => Retryable::Retryable, + MockError::MockErrorDead(_) => Retryable::NonRetryable, + } + } + } + + #[test] + fn test_dlq_get_original_msg() { + let topic = "original_topic"; + let partition = 0; + let offset = 123; + let timestamp = 456; + let mut original_headers: OwnedHeaders = OwnedHeaders::new(); + original_headers = original_headers.insert(Header { + key: "some_key", + value: Some("some_value".as_bytes()), + }); + let owned_message = OwnedMessage::new( + Some(vec![1, 2, 3]), + Some(vec![4, 5, 6]), + topic.to_string(), + rdkafka::Timestamp::CreateTime(timestamp), + partition, + offset, + Some(original_headers), + ); + let dlq_message = + MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone()); + let result = dlq_message.get_original_msg(); + assert_eq!( + result.payload(), + dlq_message.kafka_message.payload(), + "payoad does not match" + ); + assert_eq!( + result.key(), + dlq_message.kafka_message.key(), + "key does not match" + ); + assert_eq!( + result.topic(), + dlq_message.kafka_message.topic(), + "topic does not match" + ); + assert_eq!( + result.partition(), + dlq_message.kafka_message.partition(), + "partition does not match" + ); + assert_eq!( + result.offset(), + dlq_message.kafka_message.offset(), + "offset does not match" + ); + assert_eq!( + result.timestamp(), + dlq_message.kafka_message.timestamp(), + "timestamp does not match" + ); + } + + #[test] + fn test_dlq_hashmap_to_owned_headers() { + let mut hashmap: HashMap<&str, Option>> = HashMap::new(); + hashmap.insert("some_key", Some(b"key_value".to_vec())); + hashmap.insert("some_other_key", None); + let result: Vec<(&str, Option<&[u8]>)> = + vec![("some_key", Some(b"key_value")), ("some_other_key", None)]; + + let owned_headers = hashmap.to_owned_headers(); + for header in owned_headers.iter() { + assert!(result.contains(&(header.key, header.value))); + } + } + + #[test] + fn test_dlq_topic() { + let mock_cluster = MockCluster::new(1).unwrap(); + let mut producer = ClientConfig::new(); + producer.set("bootstrap.servers", mock_cluster.bootstrap_servers()); + let producer = producer.create().unwrap(); + let dlq = 
+            dlq_producer: producer,
+            dlq_rx: mpsc::channel(200).1,
+            dlq_tx: mpsc::channel(200).0,
+            dlq_dead_topic: "dead_topic".to_string(),
+            dlq_retry_topic: "retry_topic".to_string(),
+            shutdown: Shutdown::new(),
+        };
+        let error = MockError::MockErrorRetryable("some_error".to_string());
+        let topic = dlq.dlq_topic(error.retryable());
+        assert_eq!(topic, "retry_topic");
+        let error = MockError::MockErrorDead("some_error".to_string());
+        let topic = dlq.dlq_topic(error.retryable());
+        assert_eq!(topic, "dead_topic");
+    }
+
+    #[test]
+    fn test_dlq_generate_dlq_headers() {
+        let topic = "original_topic";
+        let partition = 0;
+        let offset = 123;
+        let timestamp = 456;
+        let error = Box::new(MockError::MockErrorRetryable("some_error".to_string()));
+
+        let mut original_headers: OwnedHeaders = OwnedHeaders::new();
+        original_headers = original_headers.insert(Header {
+            key: "some_key",
+            value: Some("some_value".as_bytes()),
+        });
+
+        let owned_message = OwnedMessage::new(
+            Some(vec![1, 2, 3]),
+            Some(vec![4, 5, 6]),
+            topic.to_string(),
+            rdkafka::Timestamp::CreateTime(timestamp),
+            partition,
+            offset,
+            Some(original_headers),
+        );
+
+        let mut dlq_message = error.to_dlq(owned_message.clone());
+
+        let mut expected_headers: HashMap<&str, Option<Vec<u8>>> = HashMap::new();
+        expected_headers.insert("some_key", Some(b"some_value".to_vec()));
+        expected_headers.insert("dlq_topic_origin", Some(topic.as_bytes().to_vec()));
+        expected_headers.insert(
+            "dlq_partition_origin",
+            Some(partition.to_string().as_bytes().to_vec()),
+        );
+        expected_headers.insert(
+            "dlq_partition_offset_origin",
+            Some(offset.to_string().as_bytes().to_vec()),
+        );
+        expected_headers.insert(
+            "dlq_timestamp_origin",
+            Some(timestamp.to_string().as_bytes().to_vec()),
+        );
+        expected_headers.insert(
+            "dlq_retryable",
+            Some(Retryable::Retryable.to_string().as_bytes().to_vec()),
+        );
+        expected_headers.insert("dlq_retries", Some(1_i32.to_be_bytes().to_vec()));
+        expected_headers.insert("dlq_error", Some(error.to_string().as_bytes().to_vec()));
+        if let Some(stack_trace) = &dlq_message.stack_trace {
+            expected_headers.insert("dlq_stack_trace", Some(stack_trace.as_bytes().to_vec()));
+        }
+
+        let result = owned_message.generate_dlq_headers(&mut dlq_message);
+        for header in result.iter() {
+            assert_eq!(
+                header.1,
+                expected_headers.get(header.0).unwrap_or(&None),
+                "Header {} does not match",
+                header.0
+            );
+        }
+
+        // Test if dlq headers are correctly overwritten when a message that is to be retried had already been retried before
+        let mut original_headers: OwnedHeaders = OwnedHeaders::new();
+        original_headers = original_headers.insert(Header {
+            key: "dlq_error",
+            value: Some(
+                "to_be_overwritten_error_as_this_was_the_original_error_from_1st_retry".as_bytes(),
+            ),
+        });
+        original_headers = original_headers.insert(Header {
+            key: "dlq_topic_origin",
+            value: Some(topic.as_bytes()),
+        });
+        original_headers = original_headers.insert(Header {
+            key: "dlq_retries",
+            value: Some(&1_i32.to_be_bytes().to_vec()),
+        });
+
+        let owned_message = OwnedMessage::new(
+            Some(vec![1, 2, 3]),
+            Some(vec![4, 5, 6]),
+            "retry_topic".to_string(),
+            rdkafka::Timestamp::CreateTime(timestamp),
+            partition,
+            offset,
+            Some(original_headers),
+        );
+        let result = owned_message.generate_dlq_headers(&mut dlq_message);
+        assert_eq!(
+            result.get("dlq_error").unwrap(),
+            &Some(error.to_string().as_bytes().to_vec())
+        );
+        assert_eq!(
+            result.get("dlq_topic_origin").unwrap(),
+            &Some(topic.as_bytes().to_vec())
+        );
+        assert_eq!(
result.get("dlq_retries").unwrap(), + &Some(2_i32.to_be_bytes().to_vec()) + ); + } +} \ No newline at end of file From 3b211594e867cdf01a0290e9a147886b7ec98f0e Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 16:57:19 +0100 Subject: [PATCH 16/23] update changelog and dlq documentation --- dsh_sdk/CHANGELOG.md | 2 +- dsh_sdk/src/utils/dlq/dlq.rs | 28 ++++++++++++++++++++-------- dsh_sdk/src/utils/dlq/mod.rs | 10 +++++++++- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index ecbe9d3..3df79ff 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher` - `dsh_sdk::error::DshRestTokenError` renamed to `dsh_sdk::management_api::error::ManagementApiTokenError` - **NOTE** Cargo.toml feature flag `rest-token-fetcher` renamed to`management-api-token-fetcher` -- Moved `dsh_sdk::dsh::datastreams` to `dsh_sdk::datastreams` +- Moved `dsh_sdk::dsh::datastream` to `dsh_sdk::datastream` - Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates` - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module - Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` and renamed to `ProtocolTokenFetcher` diff --git a/dsh_sdk/src/utils/dlq/dlq.rs b/dsh_sdk/src/utils/dlq/dlq.rs index 3937175..6781a9a 100644 --- a/dsh_sdk/src/utils/dlq/dlq.rs +++ b/dsh_sdk/src/utils/dlq/dlq.rs @@ -19,11 +19,19 @@ use crate::DshKafkaConfig; /// The dead letter queue /// -/// ## How to use -/// 1. Implement the [ErrorToDlq](super::ErrorToDlq) trait on top your (custom) error type. -/// 2. Use the [Dlq::start] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task) -/// 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq](super::ErrorToDlq::to_dlq) method. +/// # How to use +/// 1. Implement the [`ErrorToDlq`](super::ErrorToDlq) trait on top your (custom) error type. +/// 2. Use the [`Dlq::start`] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task) +/// 3. Get the dlq [`DlqChannel`] from the [`Dlq::start`] method and use this channel to communicate errored messages with the [`Dlq`] via the [`ErrorToDlq::to_dlq`](super::ErrorToDlq::to_dlq) method. /// +/// # Importance of `DlqChannel` in the graceful shutdown procedure +/// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages. +/// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down. +/// This is to make sure that **all** messages are properly processed before the application is shut down. +/// +/// **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down. +/// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic. 
+///
 /// # Example
 /// See full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs)
 pub struct Dlq {
@@ -40,14 +48,18 @@ impl Dlq {
     /// The DLQ will run until the return `Sender` is dropped.
     ///
     /// # Arguments
-    /// * `shutdown` - The shutdown is required to keep the DLQ alive until the DLQ Sender is dropped
+    /// * `shutdown` - The [`Shutdown`] is required to keep the DLQ alive until the [`DlqChannel`] is dropped
     ///
     /// # Returns
     /// * The [DlqChannel] to send messages to the DLQ
     ///
-    /// # Note
-    /// **NEVER** borrow the [DlqChannel] to your consumer, always use an owned [DlqChannel].
-    /// This is required to stop the gracefull shutdown the DLQ as it depends on the [DlqChannel] to be dropped.
+    /// # Importance of `DlqChannel` in the graceful shutdown procedure
+    /// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] keeps running until the [`DlqChannel`] is dropped and all remaining messages have been processed.
+    /// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
+    /// This is to make sure that **all** messages are properly processed before the application is shut down.
+    ///
+    /// **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down.
+    /// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
     ///
     /// # Example
     /// ```no_run
diff --git a/dsh_sdk/src/utils/dlq/mod.rs b/dsh_sdk/src/utils/dlq/mod.rs
index 4ee9c88..f9b192f 100644
--- a/dsh_sdk/src/utils/dlq/mod.rs
+++ b/dsh_sdk/src/utils/dlq/mod.rs
@@ -17,6 +17,14 @@
 //! 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method which is implemented on your Error.
 //!
 //! The topics are set via environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`.
+//!
+//! ## Importance of `DlqChannel` in the graceful shutdown procedure
+//! The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] keeps running until the [`DlqChannel`] is dropped and all remaining messages have been processed.
+//! This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
+//! This is to make sure that **all** messages are properly processed before the application is shut down.
+//!
+//! **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down.
+//! It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
 //!
 //! 
@@ -31,7 +39,7 @@ pub use dlq::Dlq; pub use error::DlqErrror; #[doc(inline)] pub use types::*; -/// Channel to send messages to the dead letter queue +/// Channel to send [SendToDlq] messages to the dead letter queue pub type DlqChannel = tokio::sync::mpsc::Sender; // Mock error avaialbnle in tests From cfe6bf08dfb4b0cf3388a94096d6ddc4733bc2ae Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 16:58:38 +0100 Subject: [PATCH 17/23] fmt --- dsh_sdk/src/dlq.rs | 4 +--- dsh_sdk/src/utils/dlq/dlq.rs | 18 +++++++++--------- dsh_sdk/src/utils/dlq/mod.rs | 10 +++++----- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/dsh_sdk/src/dlq.rs b/dsh_sdk/src/dlq.rs index 9b4ffa7..f14646c 100644 --- a/dsh_sdk/src/dlq.rs +++ b/dsh_sdk/src/dlq.rs @@ -25,8 +25,6 @@ //pub use crate::utils::dlq::*; - - use std::collections::HashMap; use std::env; use std::str::from_utf8; @@ -563,4 +561,4 @@ mod tests { &Some(2_i32.to_be_bytes().to_vec()) ); } -} \ No newline at end of file +} diff --git a/dsh_sdk/src/utils/dlq/dlq.rs b/dsh_sdk/src/utils/dlq/dlq.rs index 6781a9a..82b2d46 100644 --- a/dsh_sdk/src/utils/dlq/dlq.rs +++ b/dsh_sdk/src/utils/dlq/dlq.rs @@ -25,13 +25,13 @@ use crate::DshKafkaConfig; /// 3. Get the dlq [`DlqChannel`] from the [`Dlq::start`] method and use this channel to communicate errored messages with the [`Dlq`] via the [`ErrorToDlq::to_dlq`](super::ErrorToDlq::to_dlq) method. /// /// # Importance of `DlqChannel` in the graceful shutdown procedure -/// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages. -/// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down. +/// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages. +/// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down. /// This is to make sure that **all** messages are properly processed before the application is shut down. -/// -/// **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down. +/// +/// **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down. /// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic. -/// +/// /// # Example /// See full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs) pub struct Dlq { @@ -54,11 +54,11 @@ impl Dlq { /// * The [DlqChannel] to send messages to the DLQ /// /// # Importance of `DlqChannel` in the graceful shutdown procedure - /// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages. - /// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down. 
+    /// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] keeps running until the [`DlqChannel`] is dropped and all remaining messages have been processed.
+    /// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
     /// This is to make sure that **all** messages are properly processed before the application is shut down.
-    /// 
-    /// **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down. 
+    ///
+    /// **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down.
     /// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
     ///
     /// # Example
diff --git a/dsh_sdk/src/utils/dlq/mod.rs b/dsh_sdk/src/utils/dlq/mod.rs
index f9b192f..048a8f2 100644
--- a/dsh_sdk/src/utils/dlq/mod.rs
+++ b/dsh_sdk/src/utils/dlq/mod.rs
@@ -17,13 +17,13 @@
 //! 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method which is implemented on your Error.
 //!
 //! The topics are set via environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`.
-//! 
+//!
 //! ## Importance of `DlqChannel` in the graceful shutdown procedure
-//! The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] keeps running until the [`DlqChannel`] is dropped and all remaining messages have been processed. 
-//! This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down. 
+//! The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] keeps running until the [`DlqChannel`] is dropped and all remaining messages have been processed.
+//! This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
 //! This is to make sure that **all** messages are properly processed before the application is shut down.
-//! 
-//! **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down. 
+//!
+//! **NEVER** borrow the [`DlqChannel`]; pass an owned/cloned channel to your processing logic, and **NEVER** keep an owned copy in the main function, as this will result in a **deadlock** and your application will never shut down.
 //! It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
 //!
 //! 
### Example: From 93e45ad357a852b561fb0fdba70e6237d699052a Mon Sep 17 00:00:00 2001 From: Frank Hol Date: Mon, 13 Jan 2025 17:08:16 +0100 Subject: [PATCH 18/23] cargo clippy --- dsh_sdk/src/metrics.rs | 3 ++- .../src/protocol_adapters/kafka_protocol/rdkafka.rs | 2 +- dsh_sdk/src/protocol_adapters/mod.rs | 10 +++++----- dsh_sdk/src/utils/metrics.rs | 4 +++- dsh_sdk/src/utils/mod.rs | 4 ++-- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/dsh_sdk/src/metrics.rs b/dsh_sdk/src/metrics.rs index 1c26a23..7f55ff2 100644 --- a/dsh_sdk/src/metrics.rs +++ b/dsh_sdk/src/metrics.rs @@ -7,7 +7,8 @@ //! Most metrics libraries provide a way to encode the metrics to a string. For example, //! - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string. //! - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string. -//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. +//! +//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. //! //! ### Example: //! ``` diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs index 74c06e1..97fe3bd 100644 --- a/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs +++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/rdkafka.rs @@ -68,7 +68,7 @@ impl DshKafkaConfig for ClientConfig { if group_id.starts_with(tenant) { self.set("group.id", group_id) } else { - self.set("group.id", &format!("{}_{}", tenant, group_id)) + self.set("group.id", format!("{}_{}", tenant, group_id)) } } diff --git a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs index 860d687..ea5e196 100644 --- a/dsh_sdk/src/protocol_adapters/mod.rs +++ b/dsh_sdk/src/protocol_adapters/mod.rs @@ -1,11 +1,11 @@ //! The DSH Protocol adapter clients (HTTP, Kafka, MQTT) -//! -#[cfg(feature = "http-protocol-adapter")] -pub mod http_protocol; + +//#[cfg(feature = "http-protocol-adapter")] +//pub mod http_protocol; #[cfg(feature = "kafka")] pub mod kafka_protocol; -#[cfg(feature = "mqtt-protocol-adapter")] -pub mod mqtt_protocol; +// #[cfg(feature = "mqtt-protocol-adapter")] +// pub mod mqtt_protocol; #[cfg(feature = "protocol-token-fetcher")] pub mod token_fetcher; diff --git a/dsh_sdk/src/utils/metrics.rs b/dsh_sdk/src/utils/metrics.rs index 9d8c6d1..a55a4e0 100644 --- a/dsh_sdk/src/utils/metrics.rs +++ b/dsh_sdk/src/utils/metrics.rs @@ -7,7 +7,8 @@ //! Most metrics libraries provide a way to encode the metrics to a string. For example, //! - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string. //! - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string. -//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation. +//! +//! 
See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation.
 //!
 //! ### Example:
 //! ```
@@ -79,6 +80,7 @@ pub enum MetricsError {
 /// Most metrics libraries provide a way to encode the metrics to a string. For example,
 /// - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string.
 /// - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string.
+///
 /// See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation.
 ///
 /// ## Example
diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs
index e716fc4..0364e56 100644
--- a/dsh_sdk/src/utils/mod.rs
+++ b/dsh_sdk/src/utils/mod.rs
@@ -14,8 +14,8 @@ pub use error::UtilsError;
 pub mod dlq;
 #[cfg(feature = "graceful-shutdown")]
 pub mod graceful_shutdown;
-#[cfg(feature = "hyper-client")] // TODO: to be implemented
-pub(crate) mod http_client;
+// #[cfg(feature = "hyper-client")] // TODO: to be implemented
+// pub(crate) mod http_client;
 #[cfg(feature = "metrics")]
 pub mod metrics;
 

From b681c2acdd6a89894c908c7e121dffab6020226a Mon Sep 17 00:00:00 2001
From: Arend-Jan
Date: Tue, 14 Jan 2025 17:49:17 +0000
Subject: [PATCH 19/23] 112 improve in code documentation 2 (#113)

* improving documentation and comments for token fetcher

* improving documentation and comments for kafka protocol

* improving documentation and comments for schema store client

* improving documentation and comments for utils graceful shutdown

* improving documentation and comments for metric

* improving documentation and comments for platform utils

* improving documentation and comments for dead letter queue utils

* rename schema() to content()

* [fix] doc examples failures on test (incorrect imports)

* fmt

* Review token_fetcher and minor improvements

* Review and fix minor stuff in documentation

* Refactor DLQ documentation

* Review graceful_shutdown and minor improvements

* Review metrics and minor improvements

* Review platform and minor improvements

* fmt

* Fix doc warnings

* Fix doc test (missing env variable)

---------

Co-authored-by: Frank Hol <96832951+toelo3@users.noreply.github.com>
---
 dsh_sdk/src/datastream/mod.rs               |   2 +-
 dsh_sdk/src/error.rs                        |  26 +-
 dsh_sdk/src/management_api/mod.rs           |   2 +-
 dsh_sdk/src/management_api/token_fetcher.rs | 384 +++++++++-----
 .../protocol_adapters/kafka_protocol/mod.rs | 103 ++--
 dsh_sdk/src/rest_api_token_fetcher.rs       |   2 +-
 dsh_sdk/src/schema_store/client.rs          | 480 ++++++++++--------
 .../schema_store/types/schema/raw_schema.rs |  32 +-
 dsh_sdk/src/utils/dlq/dlq.rs                | 196 +++----
 dsh_sdk/src/utils/dlq/mod.rs                |  49 +-
 dsh_sdk/src/utils/graceful_shutdown.rs      | 262 +++++++---
 dsh_sdk/src/utils/metrics.rs                | 216 ++++----
 dsh_sdk/src/utils/platform.rs               | 120 +++--
 13 files changed, 1156 insertions(+), 718 deletions(-)

diff --git a/dsh_sdk/src/datastream/mod.rs b/dsh_sdk/src/datastream/mod.rs
index c2908e7..7fb5540 100644
--- a/dsh_sdk/src/datastream/mod.rs
+++ b/dsh_sdk/src/datastream/mod.rs
@@ -6,7 +6,7 @@
 //!
 //! # Usage Overview
 //! - **Local Loading**: By default, you can load `datastreams.json` from the local filesystem
-//! 
(see [`load_local_datastreams`] or [`Datastream::default`]) if running on an environment outside of DSH. +//! if running on an environment outside of DSH. //! - **Server Fetching**: You can also fetch an up-to-date `datastreams.json` from a DSH server (only works when running on DSH) //! using [`Datastream::fetch`] (async) or [`Datastream::fetch_blocking`] (blocking). //! diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs index 0bc3a86..c193e7e 100644 --- a/dsh_sdk/src/error.rs +++ b/dsh_sdk/src/error.rs @@ -37,7 +37,6 @@ pub enum DshError { /// by each causal error (if any) on separate lines, preceded by `"Caused by:"`. /// /// This is helpful for logging or displaying the entire chain of errors. -/// pub(crate) fn report(mut err: &dyn std::error::Error) -> String { let mut s = format!("{}", err); while let Some(src) = err.source() { @@ -46,3 +45,28 @@ pub(crate) fn report(mut err: &dyn std::error::Error) -> String { } s } + +#[cfg(test)] +mod tests { + use super::*; + use crate::certificates::CertificatesError; + + /// Demonstrates how to construct and print `DshError` variants, + /// as well as how to use `report` to see the full causal chain. + #[test] + fn test_dsh_error_and_report() { + // Create a wrapped DshError (CertificatesError for demonstration) + let cert_err = CertificatesError::NoCertificates; + let dsh_err = DshError::from(cert_err); + + // Verify the display output + let error_message = format!("{}", dsh_err); + println!("{}", error_message); + assert!(error_message.contains("Certificates error: Certificates are not set")); + + // Demonstrate the 'report' function + let report_output = report(&dsh_err); + // Should contain the same info, but also handle possible sources. + assert!(report_output.contains("Certificates are not set")); + } +} diff --git a/dsh_sdk/src/management_api/mod.rs b/dsh_sdk/src/management_api/mod.rs index 4bfe432..6ffa0a2 100644 --- a/dsh_sdk/src/management_api/mod.rs +++ b/dsh_sdk/src/management_api/mod.rs @@ -1,6 +1,6 @@ //! Fetch and store tokens for the DSH Management Rest API client //! -//! This module is meant to be used together with the [dsh_rest_api_client]. +//! This module is meant to be used together with the [dsh_rest_api_client](https://crates.io/crates/dsh_rest_api_client). //! //! The TokenFetcher will fetch and store access tokens to be used in the DSH Rest API client. //! diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs index 2da98cd..87883b8 100644 --- a/dsh_sdk/src/management_api/token_fetcher.rs +++ b/dsh_sdk/src/management_api/token_fetcher.rs @@ -1,3 +1,39 @@ +//! Management API token fetching for DSH. +//! +//! This module provides an interface (`ManagementApiTokenFetcher`) for fetching and +//! caching access tokens required to communicate with DSH’s management (REST) API. +//! Access tokens are automatically refreshed when expired, allowing seamless +//! integrations with the DSH platform. +//! +//! # Overview +//! - **[`AccessToken`]**: Access token from the authentication server. +//! - **[`ManagementApiTokenFetcher`]**: A token fetcher that caches tokens and +//! refreshes them upon expiration. +//! - **[`ManagementApiTokenFetcherBuilder`]**: A builder for customizing the fetcher’s +//! client, credentials, and target platform. +//! +//! # Typical Usage +//! 1. **Instantiate** a fetcher with credentials: +//! ``` +//! use dsh_sdk::management_api::ManagementApiTokenFetcherBuilder; +//! use dsh_sdk::Platform; +//! +//! let platform = Platform::NpLz; +//! 
let token_fetcher = ManagementApiTokenFetcherBuilder::new(platform)
+//!     .tenant_name("my-tenant")
+//!     .client_secret("my-secret")
+//!     .build()
+//!     .unwrap();
+//! ```
+//! 2. **Fetch** the token when needed:
+//! ```ignore
+//! let token = token_fetcher.get_token().await?;
+//! ```
+//! 3. **Reuse** the same fetcher for subsequent calls; it auto-refreshes tokens.
+//!
+//! For more advanced usage (custom [`reqwest::Client`] or different credential sourcing),
+//! see [`ManagementApiTokenFetcher::new_with_client`] or the [`ManagementApiTokenFetcherBuilder`].
+
 use std::fmt::Debug;
 use std::ops::Add;
 use std::sync::Mutex;
@@ -9,12 +45,12 @@
 use serde::Deserialize;
 
 use super::error::ManagementApiTokenError;
 use crate::utils::Platform;
 
-/// Access token of the authentication serveice of DSH.
-///
-/// This is the response whem requesting for a new access token.
+/// Represents the access token returned by DSH’s authentication service.
 ///
-/// ## Recommended usage
-/// Use the [RestTokenFetcher::get_token] to get the bearer token, the `TokenFetcher` will automatically fetch a new token if the current token is not valid.
+/// The fields include information about the token’s validity window,
+/// token type, and scope. Typically, you won’t instantiate `AccessToken` directly:
+/// use [`ManagementApiTokenFetcher::get_token`](crate::management_api::ManagementApiTokenFetcher::get_token)
+/// to automatically obtain or refresh a valid token.
 #[derive(Debug, Clone, Deserialize)]
 pub struct AccessToken {
     access_token: String,
@@ -27,37 +63,37 @@ pub struct AccessToken {
 }
 
 impl AccessToken {
-    /// Get the formatted token
+    /// Returns a complete token string, i.e. `"{token_type} {access_token}"`.
     pub fn formatted_token(&self) -> String {
         format!("{} {}", self.token_type, self.access_token)
     }
 
-    /// Get the access token
+    /// Returns the raw access token string (without the token type).
     pub fn access_token(&self) -> &str {
         &self.access_token
     }
 
-    /// Get the expires in of the access token
+    /// Returns the number of seconds until this token expires.
     pub fn expires_in(&self) -> u64 {
         self.expires_in
     }
 
-    /// Get the refresh expires in of the access token
+    /// Returns the number of seconds until the refresh token expires.
     pub fn refresh_expires_in(&self) -> u32 {
         self.refresh_expires_in
     }
 
-    /// Get the token type of the access token
+    /// Returns the token type (e.g., `"Bearer"`).
     pub fn token_type(&self) -> &str {
         &self.token_type
    }
 
-    /// Get the not before policy of the access token
+    /// Returns the “not before” policy timestamp from the authentication server.
     pub fn not_before_policy(&self) -> u32 {
         self.not_before_policy
     }
 
-    /// Get the scope of the access token
+    /// Returns the scope string (e.g., `"email"`).
     pub fn scope(&self) -> &str {
         &self.scope
     }
@@ -76,10 +112,39 @@ impl Default for AccessToken {
     }
 }
 
-/// Fetch and store access tokens to be used in the DSH Rest API client
+/// A fetcher for obtaining and storing access tokens, enabling authenticated
+/// requests to DSH’s management (REST) API.
+///
+/// This struct caches the token in memory and refreshes it automatically
+/// once expired.
+///
+/// # Usage
+/// - **`new`**: Construct a fetcher with provided credentials.
+/// - **`new_with_client`**: Provide a custom [`reqwest::Client`] if needed.
+/// - **`get_token`**: Returns the current token if still valid, or fetches a new one. 
+///
+/// # Example
+/// ```no_run
+/// use dsh_sdk::management_api::ManagementApiTokenFetcher;
+/// use dsh_sdk::Platform;
 ///
-/// This struct will fetch and store access tokens to be used in the DSH Rest API client.
-/// It will automatically fetch a new token if the current token is not valid.
+/// # use std::error::Error;
+/// # #[tokio::main]
+/// # async fn main() -> Result<(), Box<dyn Error>> {
+/// let platform = Platform::NpLz;
+/// let client_id = platform.rest_client_id("my-tenant");
+/// let client_secret = "my-secret".to_string();
+/// let token_fetcher = ManagementApiTokenFetcher::new(
+///     client_id,
+///     client_secret,
+///     platform.endpoint_rest_access_token().to_string()
+/// );
+///
+/// let token = token_fetcher.get_token().await?;
+/// println!("Obtained token: {}", token);
+/// # Ok(())
+/// # }
+/// ```
 pub struct ManagementApiTokenFetcher {
     access_token: Mutex<AccessToken>,
     fetched_at: Mutex<Instant>,
@@ -90,23 +155,33 @@ pub struct ManagementApiTokenFetcher {
 }
 
 impl ManagementApiTokenFetcher {
-    /// Create a new instance of the token fetcher
+    /// Creates a new token fetcher.
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::{ManagementApiTokenFetcher, Platform};
-    /// use dsh_rest_api_client::Client;
+    /// use dsh_sdk::management_api::ManagementApiTokenFetcher;
+    /// use dsh_sdk::Platform;
     ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let platform = Platform::NpLz;
-    ///     let client_id = platform.rest_client_id("my-tenant");
-    ///     let client_secret = "my-secret".to_string();
-    ///     let token_fetcher = ManagementApiTokenFetcher::new(client_id, client_secret, platform.endpoint_rest_access_token().to_string());
-    ///     let token = token_fetcher.get_token().await.unwrap();
-    /// }
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let platform = Platform::NpLz;
+    /// let client_id = platform.rest_client_id("my-tenant");
+    /// let client_secret = "my-secret";
+    /// let token_fetcher = ManagementApiTokenFetcher::new(
+    ///     client_id,
+    ///     client_secret,
+    ///     platform.endpoint_rest_access_token()
+    /// );
+    ///
+    /// let token = token_fetcher.get_token().await.unwrap();
+    /// println!("Token: {}", token);
+    /// # }
    /// ```
-    pub fn new(client_id: String, client_secret: String, auth_url: String) -> Self {
+    pub fn new(
+        client_id: impl AsRef<str>,
+        client_secret: impl AsRef<str>,
+        auth_url: impl AsRef<str>,
+    ) -> Self {
         Self::new_with_client(
             client_id,
             client_secret,
@@ -115,68 +190,95 @@ impl ManagementApiTokenFetcher {
         )
     }
 
-    /// Get a [ManagementApiTokenFetcherBuilder] to create a new instance of the token fetcher
+    /// Returns a [`ManagementApiTokenFetcherBuilder`] for more flexible creation
+    /// of a token fetcher (e.g., specifying a custom client).
     pub fn builder(platform: Platform) -> ManagementApiTokenFetcherBuilder {
         ManagementApiTokenFetcherBuilder::new(platform)
     }
 
-    /// Create a new instance of the token fetcher with custom reqwest client
+    /// Creates a new fetcher with a **custom** [`reqwest::Client`]. 
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::{ManagementApiTokenFetcher, Platform};
-    /// use dsh_rest_api_client::Client;
+    /// use dsh_sdk::management_api::ManagementApiTokenFetcher;
+    /// use dsh_sdk::Platform;
     ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let platform = Platform::NpLz;
-    ///     let client_id = platform.rest_client_id("my-tenant");
-    ///     let client_secret = "my-secret".to_string();
-    ///     let client = reqwest::Client::new();
-    ///     let token_fetcher = ManagementApiTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_rest_access_token().to_string(), client);
-    ///     let token = token_fetcher.get_token().await.unwrap();
-    /// }
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let platform = Platform::NpLz;
+    /// let client_id = platform.rest_client_id("my-tenant");
+    /// let client_secret = "my-secret";
+    /// let custom_client = reqwest::Client::new();
+    /// let token_fetcher = ManagementApiTokenFetcher::new_with_client(
+    ///     client_id,
+    ///     client_secret,
+    ///     platform.endpoint_rest_access_token().to_string(),
+    ///     custom_client
+    /// );
+    /// let token = token_fetcher.get_token().await.unwrap();
+    /// println!("Token: {}", token);
+    /// # }
     /// ```
     pub fn new_with_client(
-        client_id: String,
-        client_secret: String,
-        auth_url: String,
+        client_id: impl AsRef<str>,
+        client_secret: impl AsRef<str>,
+        auth_url: impl AsRef<str>,
         client: reqwest::Client,
     ) -> Self {
         Self {
             access_token: Mutex::new(AccessToken::default()),
             fetched_at: Mutex::new(Instant::now()),
-            client_id,
-            client_secret,
+            client_id: client_id.as_ref().to_string(),
+            client_secret: client_secret.as_ref().to_string(),
             client,
-            auth_url,
+            auth_url: auth_url.as_ref().to_string(),
         }
     }
 
-    /// Get token from the token fetcher
+    /// Obtains the token from cache if still valid, otherwise fetches a new one.
     ///
-    /// If the cached token is not valid, it will fetch a new token from the server.
-    /// It will return the token as a string, formatted as "{token_type} {token}"
-    /// If the request fails for a new token, it will return a [ManagementApiTokenError::FailureTokenFetch] error.
-    /// This will contain the underlying reqwest error.
+    /// The returned string is formatted as `"{token_type} {access_token}"`.
+    ///
+    /// # Errors
+    /// - [`ManagementApiTokenError::FailureTokenFetch`]:
+    ///   If the network request fails or times out when fetching a new token.
+    /// - [`ManagementApiTokenError::StatusCode`]:
+    ///   If the authentication server returns a non-success HTTP status code. 
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::management_api::ManagementApiTokenFetcher;
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let tf = ManagementApiTokenFetcher::new(
+    ///     "client_id".to_string(),
+    ///     "client_secret".to_string(),
+    ///     "http://example.com/auth".to_string()
+    /// );
+    /// match tf.get_token().await {
+    ///     Ok(token) => println!("Got token: {}", token),
+    ///     Err(e) => eprintln!("Error fetching token: {}", e),
+    /// }
+    /// # }
+    /// ```
     pub async fn get_token(&self) -> Result<String, ManagementApiTokenError> {
-        match self.is_valid() {
-            true => Ok(self.access_token.lock().unwrap().formatted_token()),
-            false => {
-                debug!("Token is expired, fetching new token");
-                let access_token = self.fetch_access_token_from_server().await?;
-                let mut token = self.access_token.lock().unwrap();
-                let mut fetched_at = self.fetched_at.lock().unwrap();
-                *token = access_token;
-                *fetched_at = Instant::now();
-                Ok(token.formatted_token())
-            }
+        if self.is_valid() {
+            Ok(self.access_token.lock().unwrap().formatted_token())
+        } else {
+            debug!("Token is expired, fetching new token");
+            let access_token = self.fetch_access_token_from_server().await?;
+            let mut token = self.access_token.lock().unwrap();
+            let mut fetched_at = self.fetched_at.lock().unwrap();
+            *token = access_token;
+            *fetched_at = Instant::now();
+            Ok(token.formatted_token())
         }
     }
 
-    /// Check if the current access token is still valid
+    /// Determines if the internally cached token is still valid.
     ///
-    /// If the token has expired, it will return false.
+    /// A token is considered valid if its remaining lifetime
+    /// (minus a 5-second safety margin) is greater than zero.
     pub fn is_valid(&self) -> bool {
         let access_token = self.access_token.lock().unwrap_or_else(|mut e| {
             **e.get_mut() = AccessToken::default();
@@ -187,17 +289,18 @@ impl ManagementApiTokenFetcher {
             self.fetched_at.clear_poison();
             e.into_inner()
         });
-        // Check if expires in has elapsed (+ safety margin of 5 seconds)
+        // Check if 'expires_in' has elapsed (+ 5-second safety margin)
         fetched_at.elapsed().add(Duration::from_secs(5))
             < Duration::from_secs(access_token.expires_in)
     }
 
-    /// Fetch a new access token from the server
+    /// Fetches a fresh `AccessToken` from the authentication server.
     ///
-    /// This will fetch a new access token from the server and return it.
-    /// If the request fails, it will return a [ManagementApiTokenError::FailureTokenFetch] error.
-    /// If the status code is not successful, it will return a [ManagementApiTokenError::StatusCode] error.
-    /// If the request is successful, it will return the [AccessToken].
+    /// # Errors
+    /// - [`ManagementApiTokenError::FailureTokenFetch`]:
+    ///   If the network request fails or times out.
+    /// - [`ManagementApiTokenError::StatusCode`]:
+    ///   If the server returns a non-success status code.
     pub async fn fetch_access_token_from_server(
         &self,
     ) -> Result<AccessToken, ManagementApiTokenError> {
         let response = self
             .client
@@ -212,6 +315,7 @@ impl ManagementApiTokenFetcher {
             .send()
             .await
             .map_err(ManagementApiTokenError::FailureTokenFetch)?;
+
         if !response.status().is_success() {
             Err(ManagementApiTokenError::StatusCode {
                 status_code: response.status(),
@@ -232,13 +336,38 @@ impl Debug for ManagementApiTokenFetcher {
             .field("access_token", &self.access_token)
             .field("fetched_at", &self.fetched_at)
             .field("client_id", &self.client_id)
+            // For security, obfuscate the secret
             .field("client_secret", &"xxxxxx")
             .field("auth_url", &self.auth_url)
             .finish()
     }
 }
 
-/// Builder for the managemant api token fetcher
+/// A builder for constructing a [`ManagementApiTokenFetcher`]. 
+///
+/// This builder allows customization of the token fetcher by specifying:
+/// - **client_id** or **tenant_name** (tenant name is used to generate the client_id)
+/// - **client_secret**
+/// - **custom [`reqwest::Client`]** (optional)
+/// - **platform** (e.g., [`Platform::NpLz`] or [`Platform::Poc`])
+///
+/// # Example
+/// ```
+/// use dsh_sdk::management_api::ManagementApiTokenFetcherBuilder;
+/// use dsh_sdk::Platform;
+///
+/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+/// let platform = Platform::NpLz;
+/// let client_id = "robot:dev-lz-dsh:my-tenant".to_string();
+/// let client_secret = "secret".to_string();
+/// let token_fetcher = ManagementApiTokenFetcherBuilder::new(platform)
+///     .client_id(client_id)
+///     .client_secret(client_secret)
+///     .build()?;
+/// // Use `token_fetcher`
+/// # Ok(())
+/// # }
+/// ```
 pub struct ManagementApiTokenFetcherBuilder {
     client: Option<reqwest::Client>,
     client_id: Option<String>,
     client_secret: Option<String>,
     tenant_name: Option<String>,
     platform: Platform,
 }
 
 impl ManagementApiTokenFetcherBuilder {
-    /// Get a new instance of the ClientBuilder
+    /// Creates a new builder configured for the specified [`Platform`].
     ///
     /// # Arguments
-    /// * `platform` - The target platform to use for the token fetcher
+    /// - `platform`: The target platform (e.g., `Platform::NpLz`) to determine
+    ///   default endpoints for fetching tokens.
     pub fn new(platform: Platform) -> Self {
         Self {
             client: None,
@@ -262,59 +392,62 @@ impl ManagementApiTokenFetcherBuilder {
         }
     }
 
-    /// Set the client_id for the client
+    /// Sets an explicit client ID for authentication.
     ///
-    /// Alternatively, set `tenant_name` to generate the client_id.
-    /// `Client_id` does have precedence over `tenant_name`.
-    pub fn client_id(mut self, client_id: String) -> Self {
-        self.client_id = Some(client_id);
+    /// If you also specify `tenant_name`, the client ID here takes precedence.
+    pub fn client_id(mut self, client_id: impl AsRef<str>) -> Self {
+        self.client_id = Some(client_id.as_ref().to_string());
         self
     }
 
-    /// Set the client_secret for the client
-    pub fn client_secret(mut self, client_secret: String) -> Self {
-        self.client_secret = Some(client_secret);
+    /// Sets a client secret required for token fetching.
+    pub fn client_secret(mut self, client_secret: impl AsRef<str>) -> Self {
+        self.client_secret = Some(client_secret.as_ref().to_string());
         self
    }
 
-    /// Set the tenant_name for the client, this will generate the client_id
+    /// Sets a tenant name from which the client ID will be derived.
     ///
-    /// Alternatively, set `client_id` directly.
-    /// `Tenant_name` does have precedence over `client_id`.
-    pub fn tenant_name(mut self, tenant_name: String) -> Self {
-        self.tenant_name = Some(tenant_name);
+    /// This will use `platform.rest_client_id(tenant_name)` unless `client_id`
+    /// is already set.
+    pub fn tenant_name(mut self, tenant_name: impl AsRef<str>) -> Self {
+        self.tenant_name = Some(tenant_name.as_ref().to_string());
         self
     }
 
-    /// Provide a custom configured Reqwest client for the token
-    ///
-    /// This is optional, if not provided, a default client will be used.
+    /// Supplies a custom [`reqwest::Client`] if you need specialized settings
+    /// (e.g., proxy configuration, timeouts, etc.).
     pub fn client(mut self, client: reqwest::Client) -> Self {
         self.client = Some(client);
         self
     }
 
-    /// Build the client and token fetcher
+    /// Builds a [`ManagementApiTokenFetcher`] based on the provided configuration.
     ///
-    /// This will build the client and token fetcher based on the given parameters. 
-    /// It will return a tuple with the client and token fetcher.
+    /// # Errors
+    /// - [`ManagementApiTokenError::UnknownClientSecret`]:
+    ///   If the client secret is unset.
+    /// - [`ManagementApiTokenError::UnknownClientId`]:
+    ///   If neither `client_id` nor `tenant_name` is provided.
     ///
-    /// ## Example
+    /// # Example
     /// ```
-    /// # use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};
-    /// let platform = Platform::NpLz;
-    /// let client_id = "robot:dev-lz-dsh:my-tenant".to_string();
-    /// let client_secret = "secret".to_string();
-    /// let tf = ManagementApiTokenFetcherBuilder::new(platform)
-    ///     .client_id(client_id)
-    ///     .client_secret(client_secret)
-    ///     .build()
-    ///     .unwrap();
+    /// use dsh_sdk::management_api::{ManagementApiTokenFetcherBuilder, ManagementApiTokenError};
+    /// use dsh_sdk::Platform;
+    ///
+    /// # fn main() -> Result<(), ManagementApiTokenError> {
+    /// let fetcher = ManagementApiTokenFetcherBuilder::new(Platform::NpLz)
+    ///     .client_id("robot:dev-lz-dsh:my-tenant".to_string())
+    ///     .client_secret("secret".to_string())
+    ///     .build()?;
+    /// # Ok(())
+    /// # }
    /// ```
     pub fn build(self) -> Result<ManagementApiTokenFetcher, ManagementApiTokenError> {
         let client_secret = self
             .client_secret
             .ok_or(ManagementApiTokenError::UnknownClientSecret)?;
+
         let client_id = self
             .client_id
             .or_else(|| {
@@ -323,6 +456,7 @@ impl ManagementApiTokenFetcherBuilder {
                     .map(|tenant_name| self.platform.rest_client_id(tenant_name))
             })
             .ok_or(ManagementApiTokenError::UnknownClientId)?;
+
         let client = self.client.unwrap_or_default();
         let token_fetcher = ManagementApiTokenFetcher::new_with_client(
             client_id,
@@ -364,6 +498,7 @@ mod test {
         }
     }
 
+    /// Ensures `AccessToken` is properly deserialized and returns expected fields.
     #[test]
     fn test_access_token() {
         let token_str = r#"{
@@ -384,6 +519,7 @@ mod test {
         assert_eq!(token.formatted_token(), "Bearer secret_access_token");
     }
 
+    /// Validates the default constructor yields an empty `AccessToken`.
     #[test]
     fn test_access_token_default() {
         let token = AccessToken::default();
@@ -396,45 +532,53 @@ mod test {
         assert_eq!(token.formatted_token(), " ");
     }
 
+    /// Verifies that a default token is considered invalid since it expires immediately.
     #[test]
     fn test_rest_token_fetcher_is_valid_default_token() {
-        // Test is_valid when validating default token (should expire in 0 seconds)
         let tf = create_mock_tf();
-        assert!(!tf.is_valid());
+        assert!(!tf.is_valid(), "Default token should be invalid");
     }
 
+    /// Demonstrates that `is_valid` returns true if a token is configured with future expiration.
    #[test]
     fn test_rest_token_fetcher_is_valid_valid_token() {
         let tf = create_mock_tf();
         tf.access_token.lock().unwrap().expires_in = 600;
-        assert!(tf.is_valid());
+        assert!(
+            tf.is_valid(),
+            "Token with 600s lifetime should be valid initially"
+        );
     }
 
+    /// Confirms `is_valid` returns false after the token’s entire lifetime has elapsed.
     #[test]
     fn test_rest_token_fetcher_is_valid_expired_token() {
-        // Test is_valid when validating an expired token
         let tf = create_mock_tf();
         tf.access_token.lock().unwrap().expires_in = 600;
         *tf.fetched_at.lock().unwrap() = Instant::now() - Duration::from_secs(600);
-        assert!(!tf.is_valid());
+        assert!(!tf.is_valid(), "Token should expire after 600s have passed");
     }
 
+    /// Tests behavior when a token is “poisoned” (i.e., panicked while locked). 
#[test] fn test_rest_token_fetcher_is_valid_poisoned_token() { - // Test is_valid when token is poisoned let tf = create_mock_tf(); tf.access_token.lock().unwrap().expires_in = 600; let tf_arc = std::sync::Arc::new(tf); let tf_clone = tf_arc.clone(); - assert!(tf_arc.is_valid(), "Token should be valid"); - let h = std::thread::spawn(move || { + assert!(tf_arc.is_valid(), "Token should be valid before poison"); + let handle = std::thread::spawn(move || { let _unused = tf_clone.access_token.lock().unwrap(); - panic!("Poison token") + panic!("Poison token"); }); - let _ = h.join(); - assert!(!tf_arc.is_valid(), "Token should be invalid"); + let _ = handle.join(); + assert!( + !tf_arc.is_valid(), + "Token should be reset to default after poisoning" + ); } + /// Checks success scenario for fetching a new token from a mock server. #[tokio::test] async fn test_fetch_access_token_from_server() { let mut auth_server = mockito::Server::new_async().await; @@ -457,12 +601,9 @@ mod test { let token = tf.fetch_access_token_from_server().await.unwrap(); assert_eq!(token.access_token(), "secret_access_token"); assert_eq!(token.expires_in(), 600); - assert_eq!(token.refresh_expires_in(), 0); - assert_eq!(token.token_type(), "Bearer"); - assert_eq!(token.not_before_policy(), 0); - assert_eq!(token.scope(), "email"); } + /// Checks that an HTTP 400 response is handled as an error. #[tokio::test] async fn test_fetch_access_token_from_server_error() { let mut auth_server = mockito::Server::new_async().await; @@ -486,6 +627,7 @@ mod test { } } + /// Ensures the builder sets `client_id` explicitly. #[test] fn test_token_fetcher_builder_client_id() { let platform = Platform::NpLz; @@ -501,6 +643,7 @@ mod test { assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); } + /// Ensures the builder can auto-generate `client_id` from the `tenant_name`. #[test] fn test_token_fetcher_builder_tenant_name() { let platform = Platform::NpLz; @@ -519,6 +662,7 @@ mod test { assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); } + /// Validates that a custom `reqwest::Client` can be injected into the builder. #[test] fn test_token_fetcher_builder_custom_client() { let platform = Platform::NpLz; @@ -536,6 +680,7 @@ mod test { assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); } + /// Tests precedence of `client_id` over a derived tenant-based client ID. #[test] fn test_token_fetcher_builder_client_id_precedence() { let platform = Platform::NpLz; @@ -553,6 +698,7 @@ mod test { assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); } + /// Ensures builder returns errors if `client_id` or `client_secret` are missing. #[test] fn test_token_fetcher_builder_build_error() { let err = ManagementApiTokenFetcherBuilder::new(Platform::NpLz) diff --git a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs index 7f50893..fa31e2f 100644 --- a/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs +++ b/dsh_sdk/src/protocol_adapters/kafka_protocol/mod.rs @@ -1,8 +1,11 @@ -//! DSH Configuration for Kafka. +//! DSH Configuration for Kafka //! -//! This module contains the required configurations to consume and produce messages from DSH Kafka Cluster. +//! This module provides all necessary configurations for consuming and producing messages +//! to/from the DSH (Data Services Hub) Kafka Cluster. The [`DshKafkaConfig`] trait is at +//! the core of this module, guiding you to set the essential Kafka config parameters +//! 
automatically (e.g., brokers, security certificates, group ID).
 //!
-//! ## Example
+//! # Example
 //! ```
 //! use dsh_sdk::DshKafkaConfig;
 //! use rdkafka::ClientConfig;
 //! 
 //!
 //! # #[tokio::main]
 //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let consumer:StreamConsumer = ClientConfig::new().set_dsh_consumer_config().create()?;
+//! // Build an rdkafka consumer with DSH settings.
+//! let consumer: StreamConsumer = ClientConfig::new()
+//!     .set_dsh_consumer_config()
+//!     .create()?;
+//!
+//! // Use your consumer...
 //! # Ok(())
 //! # }
 //! ```
+
 pub mod config;
 
 #[cfg(feature = "rdkafka")]
 mod rdkafka;
 
-/// Set all required configurations to consume messages from DSH Kafka Cluster.
+/// Trait defining core DSH configurations for Kafka consumers and producers.
+///
+/// Implementing `DshKafkaConfig` ensures that the correct settings (including SSL)
+/// are applied for connecting to a DSH-managed Kafka cluster. The trait provides:
+/// - [`set_dsh_consumer_config`](DshKafkaConfig::set_dsh_consumer_config)
+/// - [`set_dsh_producer_config`](DshKafkaConfig::set_dsh_producer_config)
+/// - [`set_dsh_group_id`](DshKafkaConfig::set_dsh_group_id)
+/// - [`set_dsh_certificates`](DshKafkaConfig::set_dsh_certificates)
+///
+/// # Environment Variables
+/// Via environment variables you can override or supplement certain default settings:
+///
+/// See [ENV_VARIABLES.md](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/ENV_VARIABLES.md) for the full list.
+///
+/// By configuring these variables, you can control broker endpoints, group IDs, and
+/// various Kafka client behaviors without modifying code.
 pub trait DshKafkaConfig {
-    /// Set all required configurations to consume messages from DSH Kafka Cluster.
+    /// Applies all required consumer settings to connect with the DSH Kafka Cluster.
     ///
-    /// | **config**                 | **Default value**                | **Remark**                                                              |
-    /// |---------------------------|----------------------------------|------------------------------------------------------------------------|
-    /// | `bootstrap.servers`       | Brokers based on datastreams     | Overwritable by env variable KAFKA_BOOTSTRAP_SERVERS`                  |
-    /// | `group.id`                | Shared Group ID from datastreams | Overwritable by setting `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`|
-    /// | `client.id`               | Task_id of service               |                                                                        |
-    /// | `enable.auto.commit`      | `false`                          | Overwritable by setting `KAFKA_ENABLE_AUTO_COMMIT`                     |
-    /// | `auto.offset.reset`       | `earliest`                       | Overwritable by setting `KAFKA_AUTO_OFFSET_RESET`                      |
-    /// | `security.protocol`       | ssl (DSH) / plaintext (local)    | Security protocol                                                      |
-    /// | `ssl.key.pem`             | private key                      | Generated when sdk is initiated                                        |
-    /// | `ssl.certificate.pem`     | dsh kafka certificate            | Signed certificate to connect to kafka cluster                         |
-    /// | `ssl.ca.pem`              | CA certifacte                    | CA certificate, provided by DSH.                                       |
+    /// Below is a table of configurations applied by this function:
+    ///
+    /// | **Config Key**             | **Default Value**                | **Overridable?**                                               | **Description**                                                                 |
+    /// |----------------------------|----------------------------------|----------------------------------------------------------------|---------------------------------------------------------------------------------|
+    /// | `bootstrap.servers`        | Brokers from `datastreams.json`  | Env var `KAFKA_BOOTSTRAP_SERVERS`                              | List of Kafka brokers to connect to.                                            |
+    /// | `group.id`                 | Shared group from `datastreams`  | Env vars `KAFKA_GROUP_ID` / `KAFKA_CONSUMER_GROUP_TYPE`        | Consumer group ID (DSH requires tenant prefix). 
| + /// | `client.id` | `task_id` of the service | _No direct override_ | Used for consumer identification in logs/metrics. | + /// | `enable.auto.commit` | `false` | Env var `KAFKA_ENABLE_AUTO_COMMIT` | Controls whether offsets are committed automatically. | + /// | `auto.offset.reset` | `earliest` | Env var `KAFKA_AUTO_OFFSET_RESET` | Defines behavior when no valid offset is available (e.g., `earliest`, `latest`).| + /// | `security.protocol` | `ssl` to DSH, `plaintext` locally| _Internal_ | Chooses SSL if DSH certificates are present, otherwise plaintext. | + /// | `ssl.key.pem` | Private key from certificates | _Auto-configured_ | Loaded from SDK during bootstrap. | + /// | `ssl.certificate.pem` | DSH Kafka certificate | _Auto-configured_ | Signed certificate to connect to the Kafka cluster. | + /// | `ssl.ca.pem` | CA certificate from DSH | _Auto-configured_ | Authority certificate for SSL. | fn set_dsh_consumer_config(&mut self) -> &mut Self; - /// Set all required configurations to produce messages to DSH Kafka Cluster. + + /// Applies all required producer settings to publish messages to the DSH Kafka Cluster. /// - /// ## Configurations - /// | **config** | **Default value** | **Remark** | - /// |---------------------|--------------------------------|-----------------------------------------------------------------------------------------| - /// | bootstrap.servers | Brokers based on datastreams | Overwritable by env variable `KAFKA_BOOTSTRAP_SERVERS` | - /// | client.id | task_id of service | Based on task_id of running service | - /// | security.protocol | ssl (DSH)) / plaintext (local) | Security protocol | - /// | ssl.key.pem | private key | Generated when bootstrap is initiated | - /// | ssl.certificate.pem | dsh kafka certificate | Signed certificate to connect to kafka cluster
(signed when bootstrap is initiated) | - /// | ssl.ca.pem | CA certifacte | CA certificate, provided by DSH. | + /// ## Producer Configurations + /// | **Config Key** | **Default Value** | **Overridable?** | **Description** | + /// |----------------------------|----------------------------------|---------------------------------------------------|-------------------------------------------------------------------------------------| + /// | `bootstrap.servers` | Brokers from `datastreams.json` | Env var `KAFKA_BOOTSTRAP_SERVERS` | List of Kafka brokers to connect to. | + /// | `client.id` | `task_id` of the service | _No direct override_ | Used for producer identification in logs/metrics. | + /// | `security.protocol` | `ssl` in DSH, `plaintext` locally| _Internal_ | Chooses SSL if DSH certificates are present, otherwise plaintext. | + /// | `ssl.key.pem` | Private key from certificates | _Auto-configured_ | Loaded from SDK during bootstrap. | + /// | `ssl.certificate.pem` | DSH Kafka certificate | _Auto-configured_ | Signed certificate (when bootstrapped) to connect to the Kafka cluster. | + /// | `ssl.ca.pem` | CA certificate from DSH | _Auto-configured_ | Authority certificate for SSL. | fn set_dsh_producer_config(&mut self) -> &mut Self; - /// Set a DSH compatible group id. + + /// Applies a DSH-compatible group ID. /// - /// DSH Requires a group id with the prefix of the tenant name. + /// DSH requires the consumer group ID to be prefixed with the tenant name. + /// If an environment variable (e.g., `KAFKA_GROUP_ID` or `KAFKA_CONSUMER_GROUP_TYPE`) + /// is set, that value can override what is found in `datastreams.json`. fn set_dsh_group_id(&mut self, group_id: &str) -> &mut Self; - /// Set the required DSH Certificates. + + /// Sets the required DSH certificates for secure SSL connections. + /// + /// If the required certificates are found (via the DSH bootstrap or + /// environment variables), this function configures SSL. Otherwise, + /// it falls back to plaintext (for local development). /// - /// This function will set the required SSL configurations if the certificates are present. - /// Else it will return plaintext. (for connection to a local kafka cluster) + /// # Note + /// This method typically sets: + /// - `security.protocol` to `ssl` + /// - `ssl.key.pem`, `ssl.certificate.pem`, and `ssl.ca.pem` + /// + /// If certificates are missing, `security.protocol` remains `plaintext`. fn set_dsh_certificates(&mut self) -> &mut Self; } diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs index 45542dc..12fe063 100644 --- a/dsh_sdk/src/rest_api_token_fetcher.rs +++ b/dsh_sdk/src/rest_api_token_fetcher.rs @@ -226,7 +226,7 @@ impl RestTokenFetcher { /// This will fetch a new access token from the server and return it. /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error. /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error. - /// If the request is successful, it will return the [AccessToken]. 
    fn set_dsh_certificates(&mut self) -> &mut Self;
 }
diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs
index 45542dc..12fe063 100644
--- a/dsh_sdk/src/rest_api_token_fetcher.rs
+++ b/dsh_sdk/src/rest_api_token_fetcher.rs
@@ -226,7 +226,7 @@ impl RestTokenFetcher {
     /// This will fetch a new access token from the server and return it.
     /// If the request fails, it will return a [DshRestTokenError::FailureTokenFetch] error.
     /// If the status code is not successful, it will return a [DshRestTokenError::StatusCode] error.
-    /// If the request is successful, it will return the [AccessToken].
+    /// If the request is successful, it will return the fetched [AccessToken].
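+    ///
+    /// # Example
+    /// A minimal sketch, assuming `token_fetcher` is an already constructed
+    /// `RestTokenFetcher` (the `dsh_sdk::RestTokenFetcher` import path is assumed here;
+    /// see the `rest_api_token_fetcher` example for the full setup):
+    /// ```no_run
+    /// # use dsh_sdk::RestTokenFetcher;
+    /// # async fn example(token_fetcher: RestTokenFetcher) {
+    /// let access_token = token_fetcher.fetch_access_token_from_server().await.unwrap();
+    /// # }
+    /// ```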
     pub async fn fetch_access_token_from_server(&self) -> Result<AccessToken, DshRestTokenError> {
         let response = self
             .client
diff --git a/dsh_sdk/src/schema_store/client.rs b/dsh_sdk/src/schema_store/client.rs
index 600278b..4e4e2d8 100644
--- a/dsh_sdk/src/schema_store/client.rs
+++ b/dsh_sdk/src/schema_store/client.rs
@@ -4,20 +4,73 @@ use super::types::*;
 use super::SchemaStoreError;
 use crate::Dsh;
 
-/// High level Schema Store Client
+/// A high-level client for interacting with the DSH Schema Store API.
 ///
-/// Client to interact with the Schema Store API.
+/// This client wraps various schema registry operations, such as:
+/// - Retrieving or setting a subject’s compatibility level.
+/// - Listing all subjects and versions.
+/// - Fetching a specific schema (by subject/version or by schema ID).
+/// - Adding new schemas or checking if they’re already registered.
+/// - Verifying schema compatibility against an existing subject/version.
+///
+/// By default, the client’s base URL is derived from [`Dsh::get().schema_registry_host()`].
+/// You can override this behavior with [`SchemaStoreClient::new_with_base_url`].
+///
+/// Most methods return a [`Result`], which encapsulates
+/// potential network failures or schema parsing issues.
+///
+/// # Example
+/// ```no_run
+/// use dsh_sdk::schema_store::{SchemaStoreClient, types::SubjectName};
+///
+/// #[tokio::main]
+/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+///     let client = SchemaStoreClient::new();
+///     let subject_name: SubjectName = "my.topic-name.tenant-value".try_into()?;
+///     let subjects = client.subjects().await?;
+///     println!("All subjects: {:?}", subjects);
+///     let versions = client.subject_versions(&subject_name).await?;
+///     println!("Versions of {:?}: {:?}", subject_name, versions);
+///     Ok(())
+/// }
+/// ```
 pub struct SchemaStoreClient<C> {
     pub(crate) base_url: String,
     pub(crate) client: C,
 }
 
 impl SchemaStoreClient {
+    /// Creates a new `SchemaStoreClient` using the default schema registry URL from
+    /// [`Dsh::get().schema_registry_host()`].
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let client = SchemaStoreClient::new();
+    ///     // Use `client` to interact with the schema store...
+    /// }
+    /// ```
     pub fn new() -> Self {
         Self::new_with_base_url(Dsh::get().schema_registry_host())
     }
 
-    /// Create SchemaStoreClient with a custom base URL
+    /// Creates a `SchemaStoreClient` with a **custom** base URL.
+    ///
+    /// This is useful if you want to target a non-default or test endpoint.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let base_url = "http://my.custom-registry/api";
+    ///     let client = SchemaStoreClient::new_with_base_url(base_url);
+    /// }
+    /// ```
     pub fn new_with_base_url(base_url: &str) -> Self {
         Self {
             base_url: base_url.trim_end_matches('/').to_string(),
@@ -30,27 +83,30 @@ impl<C> SchemaStoreClient<C>
 where
     C: Request,
 {
-    /// Get the compatibility level for a subject
-    ///
-    /// ## Returns
-    /// Returns a Result of the compatibility level of given subject
+    /// Retrieves the compatibility level for a given subject.
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// # Returns
+    /// - `Ok(Compatibility)` if successful, representing the subject’s configured compatibility level.
     ///
-    /// ## Example
-    /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::SubjectName;
+    /// # Arguments
+    /// - `subject`: A [`SubjectName`]. Conversion from a `&str` or `String` can be done via `try_into()`.
     ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// println!("Config: {:?}", client.subject_compatibility(&subject_name).await);
-    /// # Ok(())
-    /// # }
+    /// # Errors
+    /// Returns [`SchemaStoreError`] if the request fails or if the subject name is invalid.
     ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SubjectName};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject_name: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let comp = client.subject_compatibility(&subject_name).await?;
+    ///     println!("Subject compatibility: {:?}", comp);
+    ///     Ok(())
+    /// }
+    /// ```
     pub async fn subject_compatibility(
         &self,
         subject: &SubjectName,
@@ -58,29 +114,29 @@ where
         Ok(self.get_config_subject(subject.name()).await?.into())
     }
 
-    /// Set the compatibility level for a subject
+    /// Sets (updates) the compatibility level for a given subject.
     ///
-    /// Set compatibility on subject level. With 1 schema stored in the subject, you can change it to any compatibility level.
-    /// Else, you can only change into a less restrictive level.
+    /// - If the subject has no existing schema, you can set any compatibility.
+    /// - If the subject already has schemas, you can only switch to a **less restrictive** level.
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// # Returns
+    /// - `Ok(Compatibility)` representing the **new** compatibility level.
     ///
-    /// ## Returns
-    /// Returns a Result of the new compatibility level
+    /// # Errors
+    /// Returns [`SchemaStoreError`] if the network call fails, if the subject doesn’t exist,
+    /// or if the requested compatibility is not allowed.
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::{Compatibility, SubjectName};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// client.subject_compatibility_update(&subject_name, Compatibility::FULL).await?;
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::{SubjectName, Compatibility}};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     client.subject_compatibility_update(&subject, Compatibility::FULL).await?;
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject_compatibility_update(
         &self,
@@ -93,42 +149,48 @@
             .into())
     }
 
-    /// Get a list of all registered subjects
+    /// Lists **all** registered subjects in the schema registry.
     ///
-    /// ## Returns
-    /// Returns a Result of of all registered subjects from the schema registry
+    /// # Returns
+    /// - `Ok(Vec<String>)` containing subject names.
     ///
-    /// ## Example
+    /// # Errors
+    /// Returns [`SchemaStoreError`] if the HTTP request or JSON deserialization fails.
+    ///
+    /// # Example
     /// ```no_run
     /// use dsh_sdk::schema_store::SchemaStoreClient;
     ///
-    /// # #[tokio::main]
-    /// # async fn main() {
-    /// let client = SchemaStoreClient::new();
-    /// println!("Subjects: {:?}", client.subjects().await);
-    /// # }
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let client = SchemaStoreClient::new();
+    ///     match client.subjects().await {
+    ///         Ok(subjs) => println!("Registered subjects: {:?}", subjs),
+    ///         Err(e) => eprintln!("Error: {}", e),
+    ///     }
+    /// }
     /// ```
     pub async fn subjects(&self) -> Result<Vec<String>, SchemaStoreError> {
         self.get_subjects().await
     }
 
-    /// Get a list of all versions of a subject
+    /// Retrieves the version IDs for a specified subject.
     ///
-    /// ## Returns
-    /// Returns a Result of all version ID's of a subject from the schema registry
+    /// # Returns
+    /// - `Ok(Vec<i32>)` containing the version numbers registered for this subject.
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::SubjectName;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// println!("Available versions: {:?}", client.subject_versions(&subject_name).await);
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SubjectName};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let versions = client.subject_versions(&subject).await?;
+    ///     println!("Subject versions: {:?}", versions);
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject_versions(
         &self,
@@ -137,30 +199,28 @@
         self.get_subjects_subject_versions(subject.name()).await
     }
 
-    /// Get subject for specific version
-    ///
-    /// ## Returns
-    /// Returns a Result of the schema for the given subject and version
+    /// Fetches a specific schema for a given subject at a specified version.
     ///
-    /// ## Example
-    /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::{SubjectName, SubjectVersion};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
+    /// - Use [`SubjectVersion::Latest`] for the latest version.
+    /// - Use [`SubjectVersion::Version(i32)`] for a specific numbered version.
     ///
-    /// // Get the latest version of the schema
-    /// let subject = client.subject(&subject_name, SubjectVersion::Latest).await?;
-    /// let raw_schema = subject.schema;
+    /// # Returns
+    /// - `Ok(Subject)` containing metadata and the schema content.
     ///
-    /// // Get a specific version of the schema
-    /// let subject = client.subject(&subject_name, SubjectVersion::Version(1)).await?;
-    /// let raw_schema = subject.schema;
-    /// # Ok(())
-    /// # }
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::{SubjectName, SubjectVersion}};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     // Latest version
+    ///     let latest = client.subject(&subject, SubjectVersion::Latest).await?;
+    ///     // Specific version
+    ///     let specific = client.subject(&subject, SubjectVersion::Version(2)).await?;
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject(
         &self,
@@ -176,27 +236,25 @@
             .await
     }
 
-    /// Get the raw schema string for the specified version of subject.
+    /// Retrieves **only** the raw schema string for a specified subject version.
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
-    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
+    /// This is useful if you only need the JSON/Avro/Protobuf text, without additional metadata.
    ///
-    /// ## Returns
-    /// Returns a Result of the raw schema string for the given subject and version
+    /// # Returns
+    /// - `Ok(String)` containing the schema definition in its raw form.
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::SubjectName;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// let raw_schema = client.subject_raw_schema(&subject_name, 1).await.unwrap();
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SubjectName};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let raw = client.subject_raw_schema(&subject, 1).await?;
+    ///     println!("Schema text: {}", raw);
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject_raw_schema(
         &self,
@@ -210,94 +268,75 @@
             .await
     }
 
-    /// Get all schemas for a subject
-    ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
+    /// Retrieves **all** schema versions for a specified subject, returning a vector of [`Subject`].
     ///
-    /// ## Returns
-    /// Returns a Result of all schemas for the given subject
+    /// This method simply calls [`subject_versions`](Self::subject_versions) and then iterates
+    /// over each version to fetch the schema details.
+    /// _Note: This can be more expensive than retrieving a single version._
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::SubjectName;
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// let subjects = client.subject_all_schemas(&subject_name).await?;
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::SubjectName};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let all_schemas = client.subject_all_schemas(&subject).await?;
+    ///     println!("All schemas: {:?}", all_schemas);
+    ///     Ok(())
+    /// }
+    /// ```
     pub async fn subject_all_schemas(
         &self,
         subject: &SubjectName,
     ) -> Result<Vec<Subject>, SchemaStoreError> {
-        let versions = self.subject_versions(&subject).await?;
+        let versions = self.subject_versions(subject).await?;
         let mut subjects = Vec::new();
         for version in versions {
-            let subject = self.subject(&subject, version).await?;
-            subjects.push(subject);
+            let subject_schema = self.subject(subject, version).await?;
+            subjects.push(subject_schema);
         }
         Ok(subjects)
     }
 
-    /// Get all schemas for a topic
-    ///
-    /// ## Arguments
-    /// - `topic`: &str/String of the topic name
-    ///
-    /// ## Returns
-    ///
-    // pub async fn topic_all_schemas<S>(&self, topic: S) -> Result<(Vec<Subject>, Vec<Subject>)>
+    // Example of a commented-out method that’s not fully implemented yet:
+    // /// Gets all schemas for a given topic.
+    // /// This might differentiate key vs. value schemas.
+    // pub async fn topic_all_schemas<S>(&self, topic: S) -> Result<(Vec<Subject>, Vec<Subject>), SchemaStoreError>
     // where
     //     S: AsRef<str>,
     // {
-    //     let key_schemas = self.subject_all_schemas((topic.as_ref(), true)).await?;
-    //     let value_schemas = self.subject_all_schemas((topic.as_ref(), false)).await?;
-    //     Ok(subjects)
+    //     // Implementation to fetch key_schemas and value_schemas is pending.
     // }
 
-    /// Post a new schema for a (new) subject
+    /// Registers a **new** schema under the given subject.
     ///
-    /// ## Errors
-    /// - If the given schema cannot be converted into a String with given schema type
-    /// - The API call will retun a error when
-    ///     - subject already has a schema and it's compatibility does not allow it
-    ///     - subject already has a schema with a different schema type
-    ///     - schema is invalid
+    /// - If the subject doesn’t exist, it is created with the provided schema.
+    /// - If the subject **does** exist and is incompatible with this schema, the registry
+    ///   returns an error. If the schema is identical, the existing ID is returned.
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
-    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
+    /// # Returns
+    /// - `Ok(i32)` containing the new or existing schema ID.
     ///
-    /// ## Returns
-    /// Returns a Result of the new schema ID.
-    /// If schema already exists, it will return with the existing schema ID.
+    /// # Errors
+    /// - If the schema can’t be converted into a valid `RawSchemaWithType`.
+    /// - If the API call fails due to network/compatibility issues.
     ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::{RawSchemaWithType, SubjectName};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    ///
-    /// // Get subjectname (note it ends on "-value")
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    ///
-    /// // You can provide the schema as a raw string (Schema type is optional, it will be detected automatically)
-    /// let raw_schema = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#;
-    /// let schema_with_type:RawSchemaWithType = raw_schema.try_into()?;
-    /// let schema_version = client.subject_add_schema(&subject_name, schema_with_type).await?;
-    ///
-    /// // Or if you have a schema object
-    /// let avro_schema:RawSchemaWithType = apache_avro::Schema::parse_str(raw_schema)?.try_into()?; // or ProtoBuf or JSON schema
-    /// let schema_version = client.subject_add_schema(&subject_name, avro_schema).await?;
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::{RawSchemaWithType, SubjectName}};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let raw_schema = r#"{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"#;
+    ///     let schema_with_type: RawSchemaWithType = raw_schema.try_into()?;
+    ///     let schema_id = client.subject_add_schema(&subject, schema_with_type).await?;
+    ///     println!("Schema ID: {}", schema_id);
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject_add_schema(
         &self,
@@ -310,38 +349,31 @@
             .id())
     }
 
-    /// Check if schema already been registred for a subject
+    /// Checks if a given schema already exists under the specified subject.
     ///
-    /// If it returns 404, it means the schema is not yet registered (even when it states "unable to process")
+    /// - Returns 404 if the schema is not registered under that subject.
+    /// - Returns [`Subject`] info (including schema ID) if it **is** already present.
     ///
-    /// ## Errors
-    /// - If the given schema cannot be converted into a String with given schema type
-    /// - The API call will retun a error when
-    ///     - provided schema is different
-    ///     - schema is invalid
+    /// # Returns
+    /// - `Ok(Subject)` if the schema matches an existing registration.
+    /// - `Err(SchemaStoreError)` if the request fails or the schema is invalid.
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
-    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
-    ///
-    /// ## Returns
-    /// If schema exists, it will return with the existing version and schema ID.
-    ///
-    /// ## Example
+    /// # Example
     /// ```no_run
-    /// use dsh_sdk::schema_store::SchemaStoreClient;
-    /// use dsh_sdk::schema_store::types::{SubjectName, SchemaType, RawSchemaWithType};
-    ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = SchemaStoreClient::new();
-    ///
-    /// // You can provide the schema as a raw string (Schema type is optional, it will be detected automatically)
-    /// let raw_schema: RawSchemaWithType = r#"{ "type": "record", "name": "User", "fields": [ { "name": "name", "type": "string" } ] }"#.try_into()?;
-    /// let subject_name: SubjectName = "scratch.example-topic.tenant-value".try_into()?;
-    /// let subject = client.subject_schema_exist(&subject_name, raw_schema).await?;
-    /// # Ok(())
-    /// # }
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::{SubjectName, RawSchemaWithType}};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let raw_schema = r#"{"type":"record","name":"User","fields":[{"name":"age","type":"int"}]}"#;
+    ///     let schema: RawSchemaWithType = raw_schema.try_into()?;
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     match client.subject_schema_exist(&subject, schema).await {
+    ///         Ok(existing) => println!("Schema already registered: {:?}", existing.id),
+    ///         Err(e) => eprintln!("Not found or error: {}", e),
+    ///     }
+    ///     Ok(())
+    /// }
     /// ```
     pub async fn subject_schema_exist(
         &self,
@@ -351,18 +383,28 @@
         self.post_subjects_subject(subject.name(), schema).await
     }
 
-    /// Check if schema is compatible with a specific version of a subject based on the compatibility level
+    /// Checks if a **new** schema is compatible with a specific version of the subject.
     ///
-    /// Note that the compatibility level applied for the check is the configured compatibility level for the subject.
-    /// If this subject’s compatibility level was never changed, then the global compatibility level applies.
+    /// This leverages the configured compatibility level for the subject (or global level if none is explicitly set).
     ///
-    /// ## Arguments
-    /// - `subject`: [SubjectName], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SubjectStrategy)
-    /// - `version`: Anything that can be converted into a [SubjectVersion]
-    /// - `schema`: [RawSchemaWithType], use [TryInto] to convert from &str/String (Returns [SchemaStoreError] error if invalid SchemaType)
+    /// # Returns
+    /// - `Ok(bool)` indicating whether the new schema is compatible (`true`) or incompatible (`false`).
     ///
-    /// ## Returns
-    /// Returns a Result of a boolean if the schema is compatible with the given version of the subject
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::{SchemaStoreClient, types::{SubjectName, RawSchemaWithType, SubjectVersion}};
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    ///     let client = SchemaStoreClient::new();
+    ///     let raw_schema = r#"{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"#;
+    ///     let schema: RawSchemaWithType = raw_schema.try_into()?;
+    ///     let subject: SubjectName = "example-topic.tenant-value".try_into()?;
+    ///     let is_compatible = client.subject_new_schema_compatibility(&subject, SubjectVersion::Latest, schema).await?;
+    ///     println!("Is compatible? {}", is_compatible);
+    ///     Ok(())
+    /// }
+    /// ```
     pub async fn subject_new_schema_compatibility(
         &self,
         subject: &SubjectName,
@@ -382,10 +424,22 @@
             .is_compatible())
     }
 
-    /// Get the schema based in schema ID.
+    /// Retrieves a schema by its **global** schema ID.
     ///
-    /// ## Arguments
-    /// - `id`: The schema ID (Into<[i32]>)
+    /// # Arguments
+    /// - `id`: schema ID (`i32`) referencing the global schema registry ID.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let client = SchemaStoreClient::new();
+    /// let schema = client.schema(123).await?;
+    /// println!("Schema content: {}", schema.content());
+    /// # Ok(())
+    /// # }
+    /// ```
     pub async fn schema<Si>(&self, id: Si) -> Result<RawSchemaWithType, SchemaStoreError>
     where
         Si: Into<i32>,
     {
         self.get_schemas_ids_id(id.into()).await
     }
 
-    /// Get all subjects that are using the given schema
+    /// Lists all subjects that use the specified **global** schema ID.
     ///
-    /// ## Arguments
-    /// - `id`: The schema ID (Into<[i32]>)
+    /// # Returns
+    /// - A list of each subject and version that references the schema.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::schema_store::SchemaStoreClient;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let client = SchemaStoreClient::new();
+    /// let references = client.schema_subjects(123).await?;
+    /// println!("Subjects referencing schema #123: {:?}", references);
+    /// # Ok(())
+    /// # }
+    /// ```
     pub async fn schema_subjects<Si>(
         &self,
         id: Si,
diff --git a/dsh_sdk/src/schema_store/types/schema/raw_schema.rs b/dsh_sdk/src/schema_store/types/schema/raw_schema.rs
index 1222ef4..9cf42f4 100644
--- a/dsh_sdk/src/schema_store/types/schema/raw_schema.rs
+++ b/dsh_sdk/src/schema_store/types/schema/raw_schema.rs
@@ -23,7 +23,7 @@ impl RawSchemaWithType {
     }
 
     /// Raw schema string
-    pub fn schema(&self) -> &str {
+    pub fn content(&self) -> &str {
         &self.schema
     }
 
@@ -153,7 +153,7 @@ mod tests {
             r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#;
         let schema = RawSchemaWithType::parse(raw_schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -161,7 +161,7 @@
         let raw_schema = r#"{"fields":[{"name":"name","type":"string"}],"name":"User"}"#;
         let schema = RawSchemaWithType::parse(raw_schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::JSON);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -169,7 +169,7 @@
         let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#;
         let schema = RawSchemaWithType::parse(raw_schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::PROTOBUF);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -190,7 +190,7 @@
         let schema = apache_avro::Schema::parse_str(raw_schema).unwrap();
         let schema = RawSchemaWithType::try_from(schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -199,7 +199,7 @@
         let schema = serde_json::from_str::<serde_json::Value>(raw_schema).unwrap();
         let schema = RawSchemaWithType::try_from(schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::JSON);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -207,7 +207,7 @@
         let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#;
         let schema = RawSchemaWithType::try_from(raw_schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::PROTOBUF);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -215,7 +215,7 @@
         let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#;
         let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::JSON)).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::JSON);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -224,7 +224,7 @@
             r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#;
         let schema = RawSchemaWithType::try_from((raw_schema, SchemaType::AVRO)).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -232,7 +232,7 @@
         let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#;
         let schema =
             RawSchemaWithType::try_from((raw_schema, SchemaType::PROTOBUF)).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::PROTOBUF);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -259,18 +259,18 @@
         let raw_schema = r#"{"name":"User","fields":[{"name":"name","type":"string"}]}"#;
         let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::JSON);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
 
         let raw_schema =
             r#"{"name":"User","type":"record","fields":[{"name":"name","type":"string"}]}"#;
         let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
 
         let raw_schema = r#"syntax = "proto3"; message User { string name = 1; }"#;
         let schema = RawSchemaWithType::try_from(raw_schema.to_string()).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::PROTOBUF);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
 
         let raw_schema = r#"not a schema"#;
         let schema = RawSchemaWithType::try_from(raw_schema.to_string());
@@ -301,7 +301,7 @@
         };
         let schema = RawSchemaWithType::from(subject);
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -310,7 +310,7 @@
         let raw_schema = r#"{"fields":[{"name":"name","type":"string"}],"name":"User"}"#;
         let schema = serde_json::from_str::<serde_json::Value>(raw_schema).unwrap();
         let schema = RawSchemaWithType::try_from(schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::JSON);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 
     #[test]
@@ -320,6 +320,6 @@
         let schema = AvroSchema::parse_str(raw_schema).unwrap();
         let schema = RawSchemaWithType::try_from(schema).unwrap();
         assert_eq!(schema.schema_type(), SchemaType::AVRO);
-        assert_eq!(schema.schema(), raw_schema);
+        assert_eq!(schema.content(), raw_schema);
     }
 }
diff --git a/dsh_sdk/src/utils/dlq/dlq.rs b/dsh_sdk/src/utils/dlq/dlq.rs
index 82b2d46..266642e 100644
--- a/dsh_sdk/src/utils/dlq/dlq.rs
+++ b/dsh_sdk/src/utils/dlq/dlq.rs
@@ -1,4 +1,4 @@
-//! Dead Letter Queue client
+//! Dead Letter Queue (DLQ) client for handling messages that cannot be processed successfully.
 
 use std::str::from_utf8;
 
@@ -11,93 +11,70 @@
 use rdkafka::ClientConfig;
 use tokio::sync::mpsc;
 
 use super::headers::{DlqHeaders, HashMapToKafkaHeaders};
-
 use super::{DlqChannel, DlqErrror, Retryable, SendToDlq};
 use crate::utils::get_env_var;
 use crate::utils::graceful_shutdown::Shutdown;
 use crate::DshKafkaConfig;
 
-/// The dead letter queue
-///
-/// # How to use
-/// 1. Implement the [`ErrorToDlq`](super::ErrorToDlq) trait on top your (custom) error type.
-/// 2. Use the [`Dlq::start`] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task)
-/// 3. Get the dlq [`DlqChannel`] from the [`Dlq::start`] method and use this channel to communicate errored messages with the [`Dlq`] via the [`ErrorToDlq::to_dlq`](super::ErrorToDlq::to_dlq) method.
+/// Dead Letter Queue (DLQ) struct that runs asynchronously and processes error messages.
 ///
-/// # Importance of `DlqChannel` in the graceful shutdown procedure
-/// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages.
-/// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
-/// This is to make sure that **all** messages are properly processed before the application is shut down.
+/// Once started via [`Dlq::start`], it listens on an `mpsc` channel for [`SendToDlq`] items.
+/// Each received item is routed to the configured “dead” or “retry” Kafka topics,
+/// depending on whether it is [`Retryable::Retryable`] or not.
 ///
-/// **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down.
-/// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
-///
-/// # Example
-/// See full implementation example [here](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs)
+/// A full implementation can be found in the [DLQ example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs).
 pub struct Dlq {
     dlq_producer: FutureProducer,
     dlq_rx: mpsc::Receiver<SendToDlq>,
     dlq_dead_topic: String,
     dlq_retry_topic: String,
-    _shutdown: Shutdown, // hold the shutdown alive until exit
+    // Holding the shutdown handle ensures it remains valid until the DLQ stops.
+    _shutdown: Shutdown,
 }
 
 impl Dlq {
-    /// Start the dlq on a tokio task
-    ///
-    /// The DLQ will run until the return `Sender` is dropped.
+    /// Spawns the DLQ in a dedicated Tokio task, returning a [`DlqChannel`] for sending error messages.
     ///
-    /// # Arguments
-    /// * `shutdown` - The [`Shutdown`] is required to keep the DLQ alive until the [`DlqChannel`] is dropped
+    /// - Internally creates a Kafka producer with [`set_dsh_producer_config`](DshKafkaConfig::set_dsh_producer_config).
+    /// - Reads environment variables for `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC` to determine the topics
+    ///   used for permanently-dead and retryable messages.
     ///
     /// # Returns
-    /// * The [DlqChannel] to send messages to the DLQ
+    /// A `DlqChannel` (an `mpsc::Sender<SendToDlq>`) used by your worker logic to push errored messages.
    ///
-    /// # Importance of `DlqChannel` in the graceful shutdown procedure
-    /// The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages.
-    /// This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
-    /// This is to make sure that **all** messages are properly processed before the application is shut down.
+    /// # Shutdown Procedure
+    /// The DLQ runs until its channel is dropped. Once the channel closes, the DLQ finishes pending
+    /// messages and then stops. The [`Shutdown`](crate::utils::graceful_shutdown::Shutdown) handle
+    /// ensures that the main application waits for the DLQ to finish.
    ///
-    /// **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down.
-    /// It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
+    /// # Errors
+    /// Returns a [`DlqErrror`] if the producer could not be created or if environment variables
+    /// (`DLQ_DEAD_TOPIC`, `DLQ_RETRY_TOPIC`) are missing.
     ///
     /// # Example
+    /// A full implementation can be found in the [DLQ example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs):
     /// ```no_run
     /// use dsh_sdk::utils::graceful_shutdown::Shutdown;
-    /// use dsh_sdk::utils::dlq::{Dlq, DlqChannel, SendToDlq};
-    ///
-    /// async fn consume(dlq_channel: DlqChannel) {
-    ///     // Your consumer logic together with error handling
-    ///     loop {
-    ///         tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
-    ///     }
-    /// }
+    /// use dsh_sdk::utils::dlq::Dlq;
     ///
     /// #[tokio::main]
     /// async fn main() {
     ///     let shutdown = Shutdown::new();
-    ///     let dlq_channel = Dlq::start(shutdown.clone()).unwrap();
-    ///
-    ///     tokio::select! {
-    ///         _ = async move {
-    ///             // Your consumer logic together with the owned dlq_channel
-    ///             dlq_channel
-    ///         } => {}
-    ///         _ = shutdown.signal_listener() => {
-    ///             println!("Shutting down consumer");
-    ///         }
-    ///     }
-    ///     // wait for graceful shutdown to complete
-    ///     // NOTE that the `dlq_channel` will go out of scope when shutdown is called and the DLQ will stop
-    ///     shutdown.complete().await;
+    ///     let dlq_channel = Dlq::start(shutdown.clone()).expect("Failed to start DLQ");
+    ///
+    ///     // Spawn your worker logic, pass `dlq_channel` to handle errors...
     /// }
     /// ```
     pub fn start(shutdown: Shutdown) -> Result<DlqChannel, DlqErrror> {
         let (dlq_tx, dlq_rx) = mpsc::channel(200);
+
+        // Build a Kafka producer with standard DSH config
         let dlq_producer: FutureProducer = ClientConfig::new().set_dsh_producer_config().create()?;
+
         let dlq_dead_topic = get_env_var("DLQ_DEAD_TOPIC")?;
         let dlq_retry_topic = get_env_var("DLQ_RETRY_TOPIC")?;
+
         let dlq = Self {
             dlq_producer,
             dlq_rx,
@@ -105,59 +82,68 @@
             dlq_retry_topic,
             _shutdown: shutdown,
         };
+
+        // Spawn the main DLQ processing loop
         tokio::spawn(dlq.run());
         Ok(dlq_tx)
     }
 
-    /// Run the dlq. This will consume messages from the dlq channel and send them to the dlq topics
-    /// This function will run until the shutdown channel is closed
+    /// Core loop for receiving `SendToDlq` messages and forwarding them to the correct Kafka topic.
+    ///
+    /// Runs until the `mpsc::Receiver` is closed (no more references to the channel exist).
     async fn run(mut self) {
-        info!("DLQ started");
-        loop {
-            if let Some(mut dlq_message) = self.dlq_rx.recv().await {
-                match self.send(&mut dlq_message).await {
-                    Ok(_) => {}
-                    Err(e) => error!("Error sending message to DLQ: {}", e),
-                };
-            } else {
-                warn!("DLQ stopped as there is no active DLQ Channel");
-                break;
-            }
+        info!("DLQ started and awaiting messages...");
+        while let Some(mut dlq_message) = self.dlq_rx.recv().await {
+            match self.send(&mut dlq_message).await {
+                Ok(_) => {}
+                Err(e) => error!("Error sending message to DLQ: {}", e),
+            };
         }
+        warn!("DLQ stopped, channel closed, no further messages.");
     }
 
+    /// Sends an individual message to either the “dead” or “retry” topic based on its [`Retryable`] status.
+    ///
+    /// # Errors
+    /// Returns a [`KafkaError`] if the underlying producer fails to publish the message.
    async fn send(&self, dlq_message: &mut SendToDlq) -> Result<(), KafkaError> {
-        let orignal_kafka_msg: OwnedMessage = dlq_message.get_original_msg();
-        let headers = orignal_kafka_msg
+        let original_kafka_msg: OwnedMessage = dlq_message.get_original_msg();
+
+        // Create Kafka headers with error details.
+        let headers = original_kafka_msg
             .generate_dlq_headers(dlq_message)
             .to_owned_headers();
+
         let topic = self.dlq_topic(dlq_message.retryable);
-        let key: &[u8] = orignal_kafka_msg.key().unwrap_or_default();
-        let payload = orignal_kafka_msg.payload().unwrap_or_default();
-        debug!("Sending message to DLQ topic: {}", topic);
+        let key = original_kafka_msg.key().unwrap_or_default();
+        let payload = original_kafka_msg.payload().unwrap_or_default();
+
+        debug!("Sending DLQ message to topic: {}", topic);
+
        let record = FutureRecord::to(topic)
            .payload(payload)
            .key(key)
            .headers(headers);
-        let send = self.dlq_producer.send(record, None).await;
-        match send {
-            Ok((p, o)) => warn!(
-                "Message {:?} sent to DLQ topic: {}, partition: {}, offset: {}",
+
+        let result = self.dlq_producer.send(record, None).await;
+        match result {
+            Ok((partition, offset)) => warn!(
+                "DLQ message [{:?}] -> topic: {}, partition: {}, offset: {}",
                 from_utf8(key),
                 topic,
-                p,
-                o
+                partition,
+                offset
             ),
-            Err((e, _)) => return Err(e),
+            Err((err, _)) => return Err(err),
         };
         Ok(())
     }
 
+    /// Returns the appropriate DLQ topic, based on whether the message is [`Retryable::Retryable`] or not.
     fn dlq_topic(&self, retryable: Retryable) -> &str {
         match retryable {
             Retryable::Retryable => &self.dlq_retry_topic,
-            Retryable::NonRetryable => &self.dlq_dead_topic,
-            Retryable::Other => &self.dlq_dead_topic,
+            Retryable::NonRetryable | Retryable::Other => &self.dlq_dead_topic,
         }
     }
 }
@@ -177,6 +163,7 @@
         let mut producer = ClientConfig::new();
         producer.set("bootstrap.servers", mock_cluster.bootstrap_servers());
         let producer = producer.create().unwrap();
+
         let dlq = Dlq {
             dlq_producer: producer,
             dlq_rx: mpsc::channel(200).1,
             dlq_dead_topic: "dead_topic".to_string(),
             dlq_retry_topic: "retry_topic".to_string(),
             _shutdown: Shutdown::new(),
         };
-        let error = MockError::MockErrorRetryable("some_error".to_string());
+
+        // Retryable => "retry_topic"
+        let error = MockError::MockErrorRetryable("some_error".into());
         let topic = dlq.dlq_topic(error.retryable());
         assert_eq!(topic, "retry_topic");
-        let error = MockError::MockErrorDead("some_error".to_string());
+
+        // Non-retryable => "dead_topic"
+        let error = MockError::MockErrorDead("some_error".into());
         let topic = dlq.dlq_topic(error.retryable());
         assert_eq!(topic, "dead_topic");
     }
@@ -201,8 +192,9 @@
         let mut original_headers: OwnedHeaders = OwnedHeaders::new();
         original_headers = original_headers.insert(Header {
             key: "some_key",
-            value: Some("some_value".as_bytes()),
+            value: Some(b"some_value"),
         });
+
         let owned_message = OwnedMessage::new(
             Some(vec![1, 2, 3]),
             Some(vec![4, 5, 6]),
             "topic".to_string(),
             timestamp,
             partition,
             offset,
             Some(original_headers),
         );
+
         let dlq_message =
             MockError::MockErrorRetryable("some_error".to_string()).to_dlq(owned_message.clone());
         let result = dlq_message.get_original_msg();
-        assert_eq!(
-            result.payload(),
-            dlq_message.kafka_message.payload(),
-            "payoad does not match"
-        );
-        assert_eq!(
-            result.key(),
-            dlq_message.kafka_message.key(),
-            "key does not match"
-        );
-        assert_eq!(
-            result.topic(),
-            dlq_message.kafka_message.topic(),
-            "topic does not match"
-        );
-        assert_eq!(
-            result.partition(),
-            dlq_message.kafka_message.partition(),
-            "partition does not match"
-        );
-        assert_eq!(
-            result.offset(),
-            dlq_message.kafka_message.offset(),
-            "offset does not match"
-        );
-        assert_eq!(
-            result.timestamp(),
-            dlq_message.kafka_message.timestamp(),
-            "timestamp does not match"
-        );
+
+        assert_eq!(result.payload(), dlq_message.kafka_message.payload());
+        assert_eq!(result.key(), dlq_message.kafka_message.key());
+        assert_eq!(result.topic(), dlq_message.kafka_message.topic());
+        assert_eq!(result.partition(), dlq_message.kafka_message.partition());
+        assert_eq!(result.offset(), dlq_message.kafka_message.offset());
+        assert_eq!(result.timestamp(), dlq_message.kafka_message.timestamp());
     }
 }
diff --git a/dsh_sdk/src/utils/dlq/mod.rs b/dsh_sdk/src/utils/dlq/mod.rs
index 048a8f2..0b45b73 100644
--- a/dsh_sdk/src/utils/dlq/mod.rs
+++ b/dsh_sdk/src/utils/dlq/mod.rs
@@ -1,33 +1,36 @@
-//! # Dead Letter Queue
-//! This optional module contains an implementation of pushing unprocessable/invalid messages towards a Dead Letter Queue (DLQ).
-//! It is implemeted with [rdkafka] and [tokio].
+//! Dead Letter Queue (DLQ) client for handling messages that cannot be processed successfully.
 //!
-//! ## Feature flag
-//! Add feature `dlq` to your Cargo.toml to enable this module.
+//! The [`Dlq`] provides an asynchronous mechanism to route unprocessable messages to special
+//! “dead” or “retry” Kafka topics. It coordinates with [`Shutdown`](crate::utils::graceful_shutdown::Shutdown)
+//! to ensure messages are handled before the application exits.
 //!
-//! ### NOTE:
-//! This module is meant for pushing messages towards a dead/retry topic only, it does and WILL not handle any logic for retrying messages.
-//! Reason is, it can differ per use case what strategy is needed to retry messages and handle the dead letters.
+//! # Overview
+//!
+//! | **Component**   | **Description**                                                                         |
+//! |-----------------|------------------------------------------------------------------------------------------|
+//! | [`Dlq`]         | Main struct managing the producer, dead/retry topics, and queue of failed messages.    |
+//! | [`DlqChannel`]  | An `mpsc` sender returned by [`Dlq::start`], used by tasks to submit errored messages. |
+//! | [`SendToDlq`]   | Wrapper carrying both the original Kafka message and error details.                    |
+//! | [`Retryable`]   | Enum indicating whether a message is retryable or should be permanently “dead.”        |
 //!
-//! It is up to the user to implement the strategy and logic for retrying messages.
-//!
-//! ## How to use
-//! 1. Implement the [ErrorToDlq] trait on top your (custom) error type.
-//! 2. Use the [Dlq::start] in your main or at start of your process logic. (this will start the DLQ in a separate tokio task)
-//! 3. Get the dlq [DlqChannel] from the [Dlq::start] method and use this channel to communicate errored messages with the [Dlq] via the [ErrorToDlq::to_dlq] method which is implemented on your Error.
+//! # Usage Flow
+//! 1. **Implement** the [`ErrorToDlq`] trait on your custom error type.
+//! 2. **Start** the DLQ by calling [`Dlq::start`], which returns a [`DlqChannel`].
+//! 3. **Own** the [`DlqChannel`] in your processing logic (do **not** hold it in `main`!), and
+//!    call [`ErrorToDlq::to_dlq`] when you need to push a message/error into the queue.
+//! 4. **Graceful Shutdown**: The [`DlqChannel`] should naturally drop during shutdown, letting
+//!    the `Dlq` finish processing any remaining messages before the application fully closes.
 //!
 //! The topics are set via environment variables `DLQ_DEAD_TOPIC` and `DLQ_RETRY_TOPIC`.
 //!
-//! ## Importance of `DlqChannel` in the graceful shutdown procedure
-//! The [`Dlq::start`] will return a [`DlqChannel`]. The [`Dlq`] will keep running till the moment [`DlqChannel`] is dropped and finished processing all messages.
-//! This also means that the [`Shutdown`] procedure will wait for the [`Dlq`] to finish processing all messages before the application is shut down.
-//! This is to make sure that **all** messages are properly processed before the application is shut down.
-//!
-//! **NEVER** borrow the [`DlqChannel`] but provide the channel as owned/cloned version to your processing logic and **NEVER** keep an owned version in main function, as this will result in a **deadlock** and your application will never shut down.
-//! It is fine to start the [`Dlq`] in the main function, but make sure the [`DlqChannel`] is moved to your processing logic.
+//! # Important Graceful Shutdown Notes
+//! - The [`Dlq`] remains active until the [`DlqChannel`] is dropped and all messages are processed.
+//! - Keep the [`DlqChannel`] **in your worker logic** and not in `main`, preventing deadlocks.
+//! - The [`Shutdown`](crate::utils::graceful_shutdown::Shutdown) will wait for the DLQ to finish once
+//!   all channels have closed, ensuring no messages are lost.
 //!
-//! ### Example:
-//!
+//! # Example:
+//! A detailed implementation example can be found in the [DLQ example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs).
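+//!
+//! A compact sketch of that flow (the custom error type, its [`ErrorToDlq`] implementation,
+//! and the actual consumer work are application-specific and elided here):
+//! ```no_run
+//! use dsh_sdk::utils::dlq::{Dlq, DlqChannel};
+//! use dsh_sdk::utils::graceful_shutdown::Shutdown;
+//!
+//! async fn consume(dlq_channel: DlqChannel) {
+//!     // Consumer loop: on an unprocessable message, convert your error via
+//!     // `ErrorToDlq::to_dlq(..)` and send it into `dlq_channel`.
+//! }
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let shutdown = Shutdown::new();
+//!     let dlq_channel = Dlq::start(shutdown.clone()).unwrap();
+//!     tokio::spawn(consume(dlq_channel));
+//!     shutdown.signal_listener().await;
+//!     shutdown.complete().await;
+//! }
+//! ```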
 mod dlq;
 mod error;
 mod headers;
diff --git a/dsh_sdk/src/utils/graceful_shutdown.rs b/dsh_sdk/src/utils/graceful_shutdown.rs
index ff83279..2d6b6c2 100644
--- a/dsh_sdk/src/utils/graceful_shutdown.rs
+++ b/dsh_sdk/src/utils/graceful_shutdown.rs
@@ -1,29 +1,44 @@
-//! Graceful shutdown
+//! Graceful Shutdown Module
 //!
-//! This module provides a shutdown handle for graceful shutdown of (tokio tasks within) your service.
-//! It listens for SIGTERM requests and sends out shutdown requests to all shutdown handles.
+//! This module provides a handle for initiating and coordinating graceful shutdown of your
+//! application (e.g., ending background tasks in a controlled manner). It listens for
+//! Unix/MacOS/Windows signals (`SIGTERM`, `SIGINT`) and broadcasts a shutdown request to any
+//! cloned handles. When a shutdown is requested, tasks can finalize operations before
+//! exiting, ensuring a clean teardown of the service.
 //!
-//! It creates a clonable object which can be used to send shutdown request to all tasks.
-//! Based on this request you are able to handle your shutdown procedure.
+//! The design is inspired by [Tokio’s graceful shutdown approach](https://tokio.rs/tokio/topics/shutdown).
 //!
-//! This appproach is based on Tokio's graceful shutdown example:
-//!
+//! # Key Components
+//! - [`Shutdown`] struct: The main handle that tasks can clone to receive or initiate shutdown.
+//! - **Signal Handling**: [`Shutdown::signal_listener`] blocks on system signals and triggers a shutdown.
+//! - **Manual Trigger**: [`Shutdown::start`] can be called to programmatically start shutdown.
+//! - **Completion Wait**: [`Shutdown::complete`] ensures that all tasks have finished before the main thread exits.
 //!
+//! # Table of Methods
+//! | **Method**                    | **Description**                                                                                                 |
+//! |-------------------------------|------------------------------------------------------------------------------------------------------------------|
+//! | [`Shutdown::new`]             | Creates a fresh [`Shutdown`] handle, along with a channel to track completion.                                  |
+//! | [`Shutdown::clone`]           | Clones a [`Shutdown`] handle that stays linked to the original handle.                                          |
+//! | [`Shutdown::start`]           | Signals all clones that a shutdown is in progress, causing each to break out of their loops.                    |
+//! | [`Shutdown::recv`]            | Awaitable method for a cloned handle to detect when a shutdown has started.                                     |
+//! | [`Shutdown::signal_listener`] | Waits for `SIGTERM`/`SIGINT`, then calls [`start`](Shutdown::start) automatically to notify the other handles.  |
+//! | [`Shutdown::complete`]        | Waits until all handles have finished before returning, ensuring a graceful final exit.                         |
 //!
-//! # Example:
+//! # Usage Example
 //! ```no_run
 //! use dsh_sdk::utils::graceful_shutdown::Shutdown;
+//! use tokio::time::{sleep, Duration};
 //!
-//! // your process task
+//! // A background task that runs until shutdown is requested.
 //! async fn process_task(shutdown: Shutdown) {
 //!     loop {
 //!         tokio::select! {
-//!             _ = tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {
-//!                 // Do something here, e.g. consume messages from Kafka
-//!                 println!("Still processing the task")
+//!             _ = sleep(Duration::from_secs(1)) => {
+//!                 // Perform background work (e.g., read from Kafka, handle jobs, etc.)
+//!                 println!("Still processing the task...");
 //!             },
 //!             _ = shutdown.recv() => {
-//!                 // shutdown request received, include your shutdown procedure here e.g. close db connection
+//!                 // A shutdown signal was received; finalize or clean up as needed.
 //!                 println!("Gracefully exiting process_task");
 //!                 break;
 //!             },
@@ -31,26 +46,32 @@
 //!         }
 //!     }
 //! }
 //!
-//! #[tokio::main]
+//! #[tokio::main]
 //! async fn main() {
-//!     // Create shutdown handle
+//!     // Create the primary shutdown handle
 //!     let shutdown = Shutdown::new();
-//!     // Create your process task with a cloned shutdown handle
-//!     let process_task = process_task(shutdown.clone());
-//!     // Spawn your process task in a tokio runtime
+//!
+//!     // Clone the handle for use in background tasks
+//!     let cloned_shutdown = shutdown.clone();
 //!     let process_task_handle = tokio::spawn(async move {
-//!         process_task.await;
+//!         process_task(cloned_shutdown).await;
 //!     });
-//!
-//!     // Listen for shutdown request or if process task stopped
-//!     // If your process stops, start shutdown procedure to stop other tasks (if any)
+//!
+//!     // Concurrently wait for OS signals OR for the background task to exit
 //!     tokio::select! {
+//!         // If a signal (SIGINT or SIGTERM) is received, initiate shutdown
 //!         _ = shutdown.signal_listener() => println!("Exit signal received!"),
-//!         _ = process_task_handle => {println!("process_task stopped"); shutdown.start()},
+//!
+//!         // If the background task completes on its own, start the shutdown
+//!         _ = process_task_handle => {
+//!             println!("process_task stopped");
+//!             shutdown.start();
+//!         },
 //!     }
-//!     // Wait till shutdown procedures is finished
-//!     let _ = shutdown.complete().await;
-//!     println!("Exiting main...")
+//!
+//!     // Wait for all tasks to acknowledge the shutdown and finish
+//!     shutdown.complete().await;
+//!     println!("All tasks have completed. Exiting main...");
 //! }
 //! ```
 
@@ -58,12 +79,18 @@ use log::{info, warn};
 use tokio::sync::mpsc;
 use tokio_util::sync::CancellationToken;
 
-/// Shutdown handle to interact on SIGTERM of DSH for a graceful shutdown.
+/// A handle that facilitates graceful shutdown of the application or individual tasks.
 ///
-/// Use original to wait for shutdown complete.
-/// Use clone to send shutdown request to all shutdown handles.
+/// Cloning this handle allows tasks to listen for shutdown signals (internal or
+/// from the OS). The original handle can trigger the shutdown and subsequently
+/// await the completion of all other handles through [`Shutdown::complete`].
 ///
-/// see [dsh_sdk::graceful_shutdown](index.html) for full implementation example.
+/// # Usage
+/// 1. **Create** a primary handle with [`Shutdown::new`].
+/// 2. **Clone** it to each task that needs to respond to a shutdown signal.
+/// 3. **Optionally** call [`Shutdown::signal_listener`] in your main or top-level task to wait for OS signals (`SIGTERM`, `SIGINT`).
+/// 4. **Call** [`Shutdown::start`] manually if you’d like to trigger a shutdown yourself (e.g., error condition).
+/// 5. **Await** [`Shutdown::complete`] to ensure all tasks are finished.
 #[derive(Debug)]
 pub struct Shutdown {
     cancel_token: CancellationToken,
@@ -72,12 +99,16 @@
 }
 
 impl Shutdown {
-    /// Create new shutdown handle.
-    /// Returns shutdown handle and shutdown complete receiver.
-    /// Shutdown complete receiver is used to wait for all tasks to finish.
+    /// Creates a new shutdown handle and a completion channel.
     ///
-    /// NOTE: Make sure to clone shutdown handles to use it in other components/tasks.
-    /// Use orignal in main and receive shutdown complete.
+    /// # Details
+    /// - The returned handle can be cloned for other tasks.
+    /// - The original handle retains a `Receiver` so it can wait for the final
+    ///   signal indicating all tasks have ended (`complete`).
+    ///
+    /// # Note
+    /// Ensure that you only keep the original handle in your main function or
+    /// manager.
     pub fn new() -> Self {
         let cancel_token = CancellationToken::new();
         let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);
@@ -88,30 +119,89 @@
         }
     }
 
-    /// Send out internal shutdown request to all Shutdown handles, so they can start their shutdown procedure.
+    /// Initiates the shutdown sequence, notifying all clone holders to stop.
+    ///
+    /// This effectively cancels the [`CancellationToken`], causing any tasks
+    /// awaiting [`recv`](Self::recv) to return immediately. With this, all the
+    /// other handles know when to gracefully shut down.
+    ///
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// let shutdown = Shutdown::new();
+    /// // ... spawn tasks ...
+    /// shutdown.start(); // triggers `recv` in all clones
+    /// ```
     pub fn start(&self) {
         self.cancel_token.cancel();
     }
 
-    /// Listen to internal shutdown request.
-    /// Based on this you can start shutdown procedure in your component/task.
+    /// Awaits a shutdown signal.
+    ///
+    /// If [`start`](Self::start) has already been called, this returns immediately.
+    /// Otherwise, it suspends the task until the shutdown is triggered or a signal is received.
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// async fn background_task(shutdown: Shutdown) {
+    ///     loop {
+    ///         // Do work here...
+    ///         tokio::select! {
+    ///             _ = shutdown.recv() => {
+    ///                 // time to clean up
+    ///                 break;
+    ///             }
+    ///         }
+    ///     }
+    /// }
+    /// ```
     pub async fn recv(&self) {
         self.cancel_token.cancelled().await;
     }
 
-    /// Listen for external shutdown request coming from DSH (SIGTERM) or CTRL-C/SIGINT and start shutdown procedure.
+    /// Waits for an external shutdown signal (`SIGINT` or `SIGTERM`) and then calls [`start`](Self::start).
+    ///
+    /// ## Compatibility
+    /// - **Unix**: Waits for `SIGTERM` or `SIGINT`.
+    /// - **Windows**: Waits for `SIGINT` (Ctrl-C).
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let shutdown = Shutdown::new();
+    ///     tokio::spawn({
+    ///         let s = shutdown.clone();
+    ///         async move {
+    ///             // Some worker logic...
+    ///             s.recv().await;
+    ///             // Cleanup worker
+    ///         }
+    ///     });
     ///
-    /// Compatible with Unix (SIGINT and SIGTERM) and Windows (SIGINT).
+    ///     // Main thread checks for signals
+    ///     shutdown.signal_listener().await;
+    ///
+    ///     // All tasks are signaled to shut down
+    ///     shutdown.complete().await;
+    ///     println!("All done!");
+    /// }
+    /// ```
     pub async fn signal_listener(&self) {
         let ctrl_c_signal = tokio::signal::ctrl_c();
+
         #[cfg(unix)]
         let mut sigterm_signal =
             tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
+
         #[cfg(unix)]
         tokio::select! {
             _ = ctrl_c_signal => {},
             _ = sigterm_signal.recv() => {}
         }
+
         #[cfg(windows)]
         let _ = ctrl_c_signal.await;
 
@@ -119,15 +209,37 @@
         self.start();
     }
 
-    /// This function can only be called by the original shutdown handle.
+    /// Waits for all tasks to confirm they have shut down.
+    ///
+    /// This consumes the original [`Shutdown`] handle (the one that includes the
+    /// receiver), dropping the `Sender` so that `recv` eventually returns.
+    /// Useful to ensure that no tasks remain active before final exit.
     ///
-    /// Check if all tasks are finished and shutdown complete.
-    /// This function should be awaited after all tasks are spawned.
+    /// # Note
+    /// Calling `complete` on a cloned handle is invalid, as clones cannot hold
+    /// the completion receiver.
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let shutdown = Shutdown::new();
+    ///     // spawn tasks with shutdown.clone()...
+    ///     shutdown.start();
+    ///     shutdown.complete().await; // blocks until all tasks have reported completion
+    ///     println!("Graceful shutdown finished!");
+    /// }
+    /// ```
     pub async fn complete(self) {
-        // drop original shutdown_complete_tx, else it would await forever
+        // Dropping the transmitter ensures the channel can close: each clone
+        // holds a `Sender`, so `recv` returns only once every clone has been
+        // dropped and no senders remain.
        drop(self.shutdown_complete_tx);
+
+        // Wait for the channel to be closed (i.e., all tasks done).
         self.shutdown_complete_rx.unwrap().recv().await;
-        info!("Shutdown complete!")
+        info!("Shutdown complete!");
     }
 }
 
@@ -137,10 +249,26 @@
     }
 }
 
-impl std::clone::Clone for Shutdown {
-    /// Clone shutdown handle.
+impl Clone for Shutdown {
+    /// Creates a cloned [`Shutdown`] handle that can receive and/or trigger
+    /// shutdown, but does **not** hold the channel receiver for `complete`.
     ///
-    /// Use this handle in your components/tasks.
+    /// # Example
+    /// ```no_run
+    /// # use dsh_sdk::utils::graceful_shutdown::Shutdown;
+    /// async fn worker_task(shutdown: Shutdown) {
+    ///     shutdown.recv().await;
+    ///     // Cleanup...
+    /// }
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let shutdown = Shutdown::new();
+    ///     let worker_shutdown = shutdown.clone();
+    ///     tokio::spawn(worker_task(worker_shutdown));
+    ///     // ...
+    /// }
+    /// ```
     fn clone(&self) -> Self {
         Self {
             cancel_token: self.cancel_token.clone(),
@@ -152,26 +280,28 @@ impl std::clone::Clone for Shutdown {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::{Arc, Mutex};
-
     use super::*;
+    use std::sync::{Arc, Mutex};
     use tokio::time::Duration;
 
     #[tokio::test]
     async fn test_shutdown_recv() {
         let shutdown = Shutdown::new();
         let shutdown_clone = shutdown.clone();
-        // receive shutdown task
+        // This task listens for shutdown:
         let task = tokio::spawn(async move {
             shutdown_clone.recv().await;
             1
         });
-        // start shutdown task after 200 ms
-        tokio::spawn(async move {
-            tokio::time::sleep(Duration::from_millis(200)).await;
-            shutdown.start();
+        // Trigger shutdown after 200ms
+        tokio::spawn({
+            let s = shutdown.clone();
+            async move {
+                tokio::time::sleep(Duration::from_millis(200)).await;
+                s.start();
+            }
         });
-        // if shutdown is not received within 5 seconds, fail test
+        // If no shutdown within 5s, fail
         let check_value = tokio::select! {
             _ = tokio::time::sleep(Duration::from_secs(5)) => panic!("Shutdown not received within 5 seconds"),
             v = task => v.unwrap(),
@@ -183,21 +313,27 @@ mod tests {
     async fn test_shutdown_wait_for_complete() {
         let shutdown = Shutdown::new();
         let shutdown_clone = shutdown.clone();
+
         let check_value: Arc<Mutex<bool>> = Arc::new(Mutex::new(false));
         let check_value_clone = Arc::clone(&check_value);
-        // receive shutdown task
+
+        // A task that waits for shutdown, then sets a flag
         tokio::spawn(async move {
             shutdown_clone.recv().await;
             tokio::time::sleep(Duration::from_millis(200)).await;
-            let mut check: std::sync::MutexGuard<'_, bool> = check_value_clone.lock().unwrap();
-            *check = true;
+            let mut guard = check_value_clone.lock().unwrap();
+            *guard = true;
        });
+
+        // Initiate shutdown
        shutdown.start();
+        // Ensure all tasks are done
        shutdown.complete().await;
-        let check = check_value.lock().unwrap();
-        assert_eq!(
-            *check, true,
-            "shutdown did not succesfully wait for complete"
+
+        let guard = check_value.lock().unwrap();
+        assert!(
+            *guard,
+            "Shutdown did not successfully wait for completion of tasks."
        );
    }
}
diff --git a/dsh_sdk/src/utils/metrics.rs b/dsh_sdk/src/utils/metrics.rs
index a55a4e0..33697bf 100644
--- a/dsh_sdk/src/utils/metrics.rs
+++ b/dsh_sdk/src/utils/metrics.rs
@@ -1,33 +1,53 @@
-//! Provides a lightweight HTTP server to expose (prometheus) metrics.
+//! Provides a lightweight HTTP server to expose (Prometheus) metrics.
 //!
-//! ## Expose metrics to DSH / HTTP Server
+//! This module runs a simple HTTP server that listens on a specified port
+//! and serves an endpoint (`/metrics`) which returns a plain-text string
+//! representation of your metrics. It can be used to expose metrics to DSH
+//! or any Prometheus-compatible monitoring service.
 //!
-//! This module provides a http server to expose the metrics to DSH. A port number and a function that encode the metrics to [String] needs to be defined.
+//! # Overview
 //!
-//! Most metrics libraries provide a way to encode the metrics to a string. For example,
-//! - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string.
-//! - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string.
+//! - **Port**: Chosen at runtime; ensure it’s exposed in your container if using Docker.
+//! 
- **Metrics Encoder**: You supply a function that returns a `String` representation
+//!   of your metrics (e.g., from a Prometheus client library).
+//! - **Thread Model**: The server runs on a separate Tokio task. You can optionally
+//!   keep the resulting `JoinHandle` if you want to monitor or manage its lifecycle.
 //!
-//! See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation.
+//! # Common Usage
+//! 1. **Define a function** that gathers and encodes your metrics to a `String`.
+//! 2. **Call** [`start_http_server`] with the port and your metrics function.
+//! 3. **Access** your metrics at `http://<host>:<port>/metrics`.
+//! 4. **Configure** your DSH or Docker environment accordingly (if needed).
 //!
-//! ### Example:
+//! ## Example
 //! ```
 //! use dsh_sdk::utils::metrics::start_http_server;
 //!
 //! fn encode_metrics() -> String {
-//!     // Provide here your logic to gather and encode the metrics to a string
-//!     // Check your chosen metrics library for the correct implementation
-//!     "my_metrics 1".to_string() // Dummy example
+//!     // Provide custom logic to gather and encode metrics into a string.
+//!     // Example below is a placeholder.
+//!     "my_counter 1".to_string()
 //! }
 //!
 //! #[tokio::main]
 //! async fn main() {
-//!     start_http_server(9090, encode_metrics);
-//!}
+//!     // Launch a metrics server on port 9090
+//!     start_http_server(9090, encode_metrics);
+//!     // The server runs until the main thread stops or is aborted.
+//!     // ...
+//! }
+//! ```
+//!
+//! Once running, you can query your metrics at `http://localhost:9090/metrics`.
+//!
+//! # Configuration with DSH
+//!
+//! In your Dockerfile, be sure to expose that port:
+//! ```dockerfile
+//! EXPOSE 9090
 //! ```
-//! After starting the http server, the metrics can be found at http://localhost:9090/metrics.
-//! To expose the metrics to DSH, the port number needs to be defined in the DSH service configuration.
 //!
+//! Then, in your DSH service configuration, specify the port and path for the metrics:
 //! ```json
 //! "metrics": {
 //!     "port": 9090,
@@ -35,9 +55,40 @@
 //! },
 //! ```
 //!
-//! And in your dockerfile expose the port:
-//! ```dockerfile
-//! EXPOSE 9090
+//! # Monitoring the Server Task
+//!
+//! `start_http_server` spawns a Tokio task which returns a [`JoinHandle`]. You can:
+//! - **Ignore** it: The server continues until the main application exits.
+//! - **Await** it to see if the server encounters an error or closes unexpectedly.
+//!
+//! ```no_run
+//! # use dsh_sdk::utils::metrics::start_http_server;
+//! # use tokio::time::sleep;
+//! # use std::time::Duration;
+//! fn encode_metrics() -> String {
+//!     "my_metrics 1".to_string() // Dummy example
+//! }
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let server_handle = start_http_server(9090, encode_metrics);
+//!     tokio::select! {
+//!         // Some app logic or graceful shutdown condition
+//!         _ = sleep(Duration::from_secs(300)) => {
+//!             println!("Main application stopping...");
+//!         }
+//!
+//!         // If the metrics server stops unexpectedly, handle the error
+//!         result = server_handle => {
+//!             match result {
+//!                 Ok(Ok(())) => println!("Metrics server finished gracefully."),
+//!                 Ok(Err(e)) => eprintln!("Metrics server error: {}", e),
+//!                 Err(join_err) => eprintln!("Metrics server thread panicked: {}", join_err),
+//!             }
+//!         }
+//!     }
+//!     println!("All done!");
+//! }
//! 
```

 use std::net::SocketAddr;
@@ -58,75 +109,47 @@ type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
 
 static NOTFOUND: &[u8] = b"404: Not Found";
 
+/// Errors that can occur while running the metrics server.
 #[derive(Error, Debug)]
 #[non_exhaustive]
 pub enum MetricsError {
-    #[error("IO Error: {0}")]
+    /// An I/O error occurred (e.g., binding the port failed).
+    #[error("I/O Error: {0}")]
     IoError(#[from] std::io::Error),
+
+    /// An HTTP error occurred while building or sending a response.
     #[error("Hyper error: {0}")]
     HyperError(#[from] hyper::http::Error),
 }
 
-/// A lihghtweight HTTP server to expose prometheus metrics.
-///
-/// The exposed endpoint is /metrics and port number needs to be defined together with your gather and encode function to string.
-/// The server will run on a separate thread and this function will return a JoinHandle of the thread.
-/// It is optional to handle the thread status. If left unhandled, the server will run until the main thread is stopped.
+/// Starts a lightweight HTTP server to expose Prometheus-like metrics on `"/metrics"`.
 ///
-/// ## Expose metrics to DSH / HTTP Server
+/// # Parameters
+/// - `port`: The port on which the server listens (e.g., `9090`).
+/// - `metrics_encode_fn`: A function returning a `String` containing all relevant metrics.
 ///
-/// This module provides a http server to expose the metrics to DSH. A port number and a function that encode the metrics to [String] needs to be defined.
+/// # Returns
+/// A [`JoinHandle`] wrapping a [`Result<(), MetricsError>`]. The server:
+/// - Runs until the main process exits or the handle is aborted.
+/// - May exit early if an underlying error (`MetricsError`) occurs.
 ///
-/// Most metrics libraries provide a way to encode the metrics to a string. For example,
-/// - [prometheus-client](https://crates.io/crates/prometheus-client) library provides a [render](https://docs.rs/prometheus-client/latest/prometheus_client/encoding/text/index.html) function to encode the metrics to a string.
-/// - [prometheus](https://crates.io/crates/prometheus) library provides a [TextEncoder](https://docs.rs/prometheus/latest/prometheus/struct.TextEncoder.html) to encode the metrics to a string.
+/// # Example
+/// ```no_run
+/// use dsh_sdk::utils::metrics::start_http_server;
 ///
-/// See [expose_metrics.rs](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) for a full example implementation.
-///
-/// ## Example
-/// This starts a http server on port 9090 on a separate thread. The server will run until the main thread is stopped.
-/// ```
-/// use dsh_sdk::utils::metrics::start_http_server;
-///
-/// fn encode_metrics() -> String {
-///     // Provide here your logic to gather and encode the metrics to a string
-///     // Check your chosen metrics library for the correct implementation
-///     "my_metrics 1".to_string() // Dummy example
-/// }
-///
-/// #[tokio::main]
-/// async fn main() {
-///     start_http_server(9090, encode_metrics);
+/// fn encode_metrics() -> String {
+///     // Provide logic that gathers and encodes your metrics as a string.
+///     "my_counter 123".to_string()
 /// }
-/// ```
/// 
-/// # Optional: Check http server thread status
-/// Await the JoinHandle in a a tokio select besides your application logic to check if the server is still running.
-/// ```rust
-/// # use dsh_sdk::utils::metrics::start_http_server;
-/// # use tokio::time::sleep;
-/// # use std::time::Duration;
-/// # fn encode_metrics() -> String {
-/// #     "my_metrics 1".to_string() // Dummy example
-/// # }
-/// # #[tokio::main]
-/// # async fn main() {
-/// let server = start_http_server(9090, encode_metrics);
-/// tokio::select! {
-///     // Replace sleep with your application logic
-///     _ = sleep(Duration::from_secs(1)) => {println!("Application is stoped!")},
-///     // Check if the server is still running
-///     tokio_result = server => {
-///         match tokio_result {
-///             Ok(server_result) => if let Err(e) = server_result {
-///                 eprintln!("Metrics server operation failed: {}", e);
-///             },
-///             Err(e) => println!("Server thread stopped unexpectedly: {}", e),
-///         }
-///     }
+/// #[tokio::main]
+/// async fn main() {
+///     start_http_server(9090, encode_metrics);
+///     // The server runs in the background until main ends.
/// }
-/// # }
/// ```
+///
+/// See the module-level docs for more details on usage patterns.
 pub fn start_http_server(
     port: u16,
     metrics_encode_fn: fn() -> String,
@@ -137,17 +160,19 @@ pub fn start_http_server(
     };
     tokio::spawn(async move {
         let result = server.run_server().await;
-        warn!("HTTP server stopped: {:?}", result);
+        warn!("HTTP metrics server stopped: {:?}", result);
         result
     })
 }
 
+/// Internal struct containing server configuration and logic.
 struct MetricsServer {
     port: u16,
     metrics_encode_fn: fn() -> String,
 }
 
 impl MetricsServer {
+    /// Runs the server in a loop, accepting connections and handling them.
     async fn run_server(&self) -> Result<(), MetricsError> {
         let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
         let listener = TcpListener::bind(addr).await?;
@@ -158,6 +183,7 @@ impl MetricsServer {
         }
     }
 
+    /// Handles an individual TCP connection by serving HTTP/1.1 requests.
     async fn handle_connection(&self, stream: tokio::net::TcpStream) {
         let io = TokioIo::new(stream);
         let service = service_fn(|req| self.routes(req));
@@ -166,13 +192,15 @@ impl MetricsServer {
         }
     }
 
+    /// Routes requests to the correct handler based on method & path.
     async fn routes(&self, req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>, MetricsError> {
         match (req.method(), req.uri().path()) {
             (&Method::GET, "/metrics") => self.get_metrics(),
-            (_, _) => not_found(),
+            _ => not_found(),
         }
     }
 
+    /// Generates a response containing the metrics string.
     fn get_metrics(&self) -> Result<Response<BoxBody>, MetricsError> {
         let body = (self.metrics_encode_fn)();
         Ok(Response::builder()
@@ -182,12 +210,14 @@ impl MetricsServer {
     }
 }
 
+/// Returns a 404 Not Found response.
 fn not_found() -> Result<Response<BoxBody>, MetricsError> {
     Ok(Response::builder()
         .status(StatusCode::NOT_FOUND)
         .body(full(NOTFOUND))?)
 }
 
+/// Converts a string (or byte slice) into a boxed HTTP body.
 fn full<T: Into<Bytes>>(chunk: T) -> BoxBody {
     Full::new(chunk.into())
         .map_err(|never| match never {})
@@ -209,7 +239,7 @@ mod tests {
 
     const PORT: u16 = 9090;
 
-    /// Gather and encode metrics to a string (UTF8)
+    /// Example function to gather metrics from the `prometheus` crate.
     pub fn metrics_to_string() -> String {
         let encoder = prometheus::TextEncoder::new();
         encoder
@@ -219,7 +249,7 @@ mod tests {
     lazy_static! 
{
        pub static ref HIGH_FIVE_COUNTER: IntCounter =
-            register_int_counter!("highfives", "Number of high fives recieved").unwrap();
+            register_int_counter!("highfives", "Number of high fives received").unwrap();
     }
 
     async fn create_client(
@@ -228,7 +258,7 @@ mod tests {
         SendRequest<Empty<Bytes>>,
         Connection<TokioIo<TcpStream>, Empty<Bytes>>,
     ) {
-        let host = url.host().expect("uri has no host");
+        let host = url.host().expect("URI has no host");
         let port = url.port_u16().unwrap_or(PORT);
         let addr = format!("{}:{}", host, port);
@@ -242,8 +272,8 @@ mod tests {
         Request::builder()
             .uri(url)
             .method(Method::GET)
-            .header(header::HOST, url.authority().unwrap().clone().as_str())
-            .body(Empty::<Bytes>::new())
+            .header(header::HOST, url.authority().unwrap().as_str())
+            .body(Empty::new())
             .unwrap()
     }
 
@@ -256,22 +286,14 @@ mod tests {
             port: PORT,
             metrics_encode_fn: metrics_to_string,
         };
-        // Call the function
-        let res = server.get_metrics();
-
-        // Check if the function returns a result
-        assert!(res.is_ok());
-
-        // Check if the result is not an empty string
-        let response = res.unwrap();
-        let status_code = response.status();
-
-        assert_eq!(status_code, StatusCode::OK);
-        assert!(response.body().size_hint().exact().unwrap() > 0);
+        let response = server.get_metrics().expect("failed to get metrics");
+        assert_eq!(response.status(), StatusCode::OK);
         assert_eq!(
             response.headers().get(header::CONTENT_TYPE).unwrap(),
             "text/plain"
         );
+        // Ensure the body is non-empty
+        assert!(response.body().size_hint().exact().unwrap() > 0);
     }
 
     #[tokio::test]
@@ -308,11 +330,9 @@ mod tests {
         // Check if the response body is not empty
         let buf = response.collect().await.unwrap().to_bytes();
         let res = String::from_utf8(buf.to_vec()).unwrap();
-
-        println!("{}", res);
         assert!(!res.is_empty());
 
-        // Terminate the server
+        // Stop the server
         server.abort();
     }
 
@@ -333,21 +353,17 @@ mod tests {
             }
         });
 
-        // Send a request to the server
+        // Send a request to the server with no path (i.e., "/")
         let request = to_get_req(&url);
-
         let response = request_sender.send_request(request).await.unwrap();
-
-        // Check if the server returns a 404 status
         assert_eq!(response.status(), StatusCode::NOT_FOUND);
 
-        // Check if the response body is not empty
+        // Check the 404 body
         let buf = response.collect().await.unwrap().to_bytes();
         let res = String::from_utf8(buf.to_vec()).unwrap();
-
         assert_eq!(res, String::from_utf8_lossy(NOTFOUND));
 
-        // Terminate the server
+        // Stop the server
         server.abort();
     }
 }
diff --git a/dsh_sdk/src/utils/platform.rs b/dsh_sdk/src/utils/platform.rs
index 7a684de..2603f1f 100644
--- a/dsh_sdk/src/utils/platform.rs
+++ b/dsh_sdk/src/utils/platform.rs
@@ -1,52 +1,69 @@
-/// Available DSH platforms plus it's related metadata
+//! Provides an enum of DSH platforms and related metadata.
+//!
+//! This module defines the [`Platform`] enum, representing different DSH deployments,
+//! each with its own realm, REST API endpoints, and token endpoints. The platform choice
+//! influences how you authenticate and where you send REST/Protocol requests.
+//!
+//! # Platforms
+//! The platforms defined are:
+//! - `Prod` (kpn-dsh.com)
+//! - `ProdAz` (az.kpn-dsh.com)
+//! - `ProdLz` (dsh-prod.dsh.prod.aws.kpn.com)
+//! - `NpLz` (dsh-dev.dsh.np.aws.kpn.com)
+//! - `Poc` (poc.kpn-dsh.com)
+//!
+//! ## Usage
+//! Use a [`Platform`] variant to generate appropriate URLs and client IDs for your environment.
+//! For example, you might select `Platform::NpLz` when deploying a service to the development
+//! landing zone.
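+//!
+//! A minimal sketch of that selection (method names and values taken from the
+//! doctests below):
+//! ```
+//! use dsh_sdk::Platform;
+//!
+//! let platform = Platform::NpLz;
+//! assert_eq!(platform.realm(), "dev-lz-dsh");
+//! assert_eq!(platform.rest_client_id("my-tenant"), "robot:dev-lz-dsh:my-tenant");
+//! ```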
+
+/// Represents an available DSH platform and its related metadata.
 /// 
-/// The platform enum contains
 /// - `Prod` (kpn-dsh.com)
 /// - `ProdAz` (az.kpn-dsh.com)
 /// - `ProdLz` (dsh-prod.dsh.prod.aws.kpn.com)
 /// - `NpLz` (dsh-dev.dsh.np.aws.kpn.com)
 /// - `Poc` (poc.kpn-dsh.com)
-///
-/// Each platform has it's own realm, endpoint for the DSH Rest API and endpoint for the DSH Rest API access token.
 #[derive(Clone, Debug)]
 #[non_exhaustive]
 pub enum Platform {
-    /// Production platform (kpn-dsh.com)
+    /// Production platform (`kpn-dsh.com`).
     Prod,
-    /// Production platform on Azure (az.kpn-dsh.com)
+    /// Production platform on Azure (`az.kpn-dsh.com`).
     ProdAz,
-    /// Production Landing Zone on AWS (dsh-prod.dsh.prod.aws.kpn.com)
+    /// Production Landing Zone on AWS (`dsh-prod.dsh.prod.aws.kpn.com`).
     ProdLz,
-    /// Non-Production (Dev) Landing Zone on AWS (dsh-dev.dsh.np.aws.kpn.com)
+    /// Non-Production (Dev) Landing Zone on AWS (`dsh-dev.dsh.np.aws.kpn.com`).
     NpLz,
-    /// Proof of Concept platform (poc.kpn-dsh.com)
+    /// Proof of Concept platform (`poc.kpn-dsh.com`).
     Poc,
 }
 
 impl Platform {
-    /// Get a properly formatted client_id for the Rest API based on the given name of a tenant
+    /// Returns a properly formatted client ID for the DSH REST API, given a tenant name.
     ///
-    /// It will return a string formatted as "robot:{realm}:{tenant_name}"
+    /// The format is:
+    /// \[
+    /// `"robot:{realm}:{tenant_name}"`
+    /// \]
     ///
-    /// ## Example
+    /// # Example
     /// ```
     /// # use dsh_sdk::Platform;
     /// let platform = Platform::NpLz;
     /// let client_id = platform.rest_client_id("my-tenant");
     /// assert_eq!(client_id, "robot:dev-lz-dsh:my-tenant");
     /// ```
-    pub fn rest_client_id(&self, tenant: T) -> String
-    where
-        T: AsRef<str>,
-    {
+    pub fn rest_client_id(&self, tenant: impl AsRef<str>) -> String {
         format!("robot:{}:{}", self.realm(), tenant.as_ref())
     }
 
-    /// Get the endpoint for the DSH Rest API
+    /// Returns the base URL for the DSH REST API, depending on the platform.
     ///
-    /// It will return the endpoint for the DSH Rest API based on the platform
+    /// This endpoint is typically used for general resource operations in DSH.
     ///
-    /// ## Example
+    /// # Example
     /// ```
     /// # use dsh_sdk::Platform;
     /// let platform = Platform::NpLz;
@@ -62,42 +79,46 @@ impl Platform {
             Self::Poc => "https://api.poc.kpn-dsh.com/resources/v0",
         }
     }
-    /// Get the endpoint for the DSH Rest API access token
-    ///
-    /// It will return the endpoint for the DSH Rest API access token based on the platform
+
+    /// Returns the URL endpoint for retrieving DSH REST API OAuth tokens.
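+    ///
+    /// These are Keycloak `openid-connect/token` URLs within each platform's
+    /// realm (see the match arms below).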
/// - /// ## Example + /// # Example /// ``` /// # use dsh_sdk::Platform; /// let platform = Platform::NpLz; - /// let endpoint = platform.endpoint_rest_access_token(); - /// assert_eq!(endpoint, "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token"); + /// let token_url = platform.endpoint_rest_access_token(); + /// assert_eq!(token_url, "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token"); /// ``` pub fn endpoint_rest_access_token(&self) -> &str { match self { - Self::Prod => "https://auth.prod.cp.kpn-dsh.com/auth/realms/tt-dsh/protocol/openid-connect/token", - Self::NpLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token", + Self::Prod => "https://auth.prod.cp.kpn-dsh.com/auth/realms/tt-dsh/protocol/openid-connect/token", + Self::NpLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token", Self::ProdLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/prod-lz-dsh/protocol/openid-connect/token", Self::ProdAz => "https://auth.prod.cp.kpn-dsh.com/auth/realms/prod-azure-dsh/protocol/openid-connect/token", - Self::Poc => "https://auth.prod.cp.kpn-dsh.com/auth/realms/poc-dsh/protocol/openid-connect/token", + Self::Poc => "https://auth.prod.cp.kpn-dsh.com/auth/realms/poc-dsh/protocol/openid-connect/token", } } - #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")] - /// Get the endpoint for fetching DSH Rest Authentication Token - /// - /// With this token you can authenticate for the mqtt token endpoint + /// (Deprecated) Returns the DSH REST authentication token endpoint. /// - /// It will return the endpoint for DSH Rest authentication token based on the platform + /// *Prefer using [`endpoint_management_api_token`](Self::endpoint_management_api_token) instead.* + #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")] pub fn endpoint_rest_token(&self) -> &str { self.endpoint_management_api_token() } - /// Get the endpoint for fetching DSH Rest Authentication Token + /// Returns the endpoint for fetching a DSH Management API authentication token. /// - /// With this token you can authenticate for the mqtt token endpoint + /// This endpoint is typically used to authenticate requests to certain management or admin-level + /// DSH services. /// - /// It will return the endpoint for DSH Rest authentication token based on the platform + /// # Example + /// ``` + /// # use dsh_sdk::Platform; + /// let platform = Platform::NpLz; + /// let mgmt_token_url = platform.endpoint_management_api_token(); + /// assert_eq!(mgmt_token_url, "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token"); + /// ``` pub fn endpoint_management_api_token(&self) -> &str { match self { Self::Prod => "https://api.kpn-dsh.com/auth/v0/token", @@ -108,17 +129,23 @@ impl Platform { } } - #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")] - /// Get the endpoint for fetching DSH mqtt token + /// (Deprecated) Returns the DSH MQTT token endpoint. 
 ///
-    /// It will return the endpoint for DSH MQTT Token based on the platform
+    /// *Prefer using [`endpoint_protocol_token`](Self::endpoint_protocol_token) instead.*
+    #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")]
     pub fn endpoint_mqtt_token(&self) -> &str {
         self.endpoint_protocol_token()
     }
 
-    /// Get the endpoint for fetching DSH Protocol token
+    /// Returns the endpoint for fetching DSH protocol tokens (e.g., for MQTT).
     ///
-    /// It will return the endpoint for DSH Protocol adapter Token based on the platform
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::Prod;
+    /// let protocol_token_url = platform.endpoint_protocol_token();
+    /// assert_eq!(protocol_token_url, "https://api.kpn-dsh.com/datastreams/v0/mqtt/token");
+    /// ```
     pub fn endpoint_protocol_token(&self) -> &str {
         match self {
             Self::Prod => "https://api.kpn-dsh.com/datastreams/v0/mqtt/token",
@@ -129,6 +156,16 @@ impl Platform {
         }
     }
 
+    /// Returns the Keycloak realm string associated with this platform.
+    ///
+    /// This is used to construct OpenID Connect tokens (e.g., for Kafka or REST API authentication).
+    ///
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let realm = Platform::Prod.realm();
+    /// assert_eq!(realm, "tt-dsh");
+    /// ```
     pub fn realm(&self) -> &str {
         match self {
             Self::Prod => "tt-dsh",
@@ -143,6 +180,7 @@ impl Platform {
 #[cfg(test)]
 mod tests {
     use super::*;
+
     #[test]
     fn test_platform_realm() {
         assert_eq!(Platform::NpLz.realm(), "dev-lz-dsh");

From e1ce115da06362ad91672acde63425a080485241 Mon Sep 17 00:00:00 2001
From: Frank Hol <96832951+toelo3@users.noreply.github.com>
Date: Wed, 15 Jan 2025 14:11:41 +0100
Subject: [PATCH 20/23] rename platform methods to meaningful names (#115)

* rename platform methods to meaningful names
---
 dsh_sdk/CHANGELOG.md                          |  14 +-
 .../examples/management_api_token_fetcher.rs  |   2 +-
 dsh_sdk/src/utils/platform.rs                 | 182 +++++++++++++++++-
 3 files changed, 183 insertions(+), 15 deletions(-)

diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md
index 3df79ff..7a6538e 100644
--- a/dsh_sdk/CHANGELOG.md
+++ b/dsh_sdk/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   - Bootstrap to DSH
   - Read certificates from PKI_CONFIG_DIR
   - Add support reading private key in DER format when reading from PKI_CONFIG_DIR
+- Implement `TryFrom<&str>` and `TryFrom<String>` for `dsh_sdk::Platform`
 
 ### Changed
 - **Breaking change:** `DshError` is now split into error enums per feature flag to untangle mess
@@ -20,18 +21,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - **Breaking change:** `dsh_sdk::Dsh::reqwest_client_config` now returns `reqwest::ClientConfig` instead of `Result`
 - **Breaking change:** `dsh_sdk::Dsh::reqwest_blocking_client_config` now returns `reqwest::ClientConfig` instead of `Result`
 - **Breaking change:** `dsh_sdk::utils::Dlq` does not require `Dsh`/`Properties` as argument anymore
-- Deprecated `dsh_sdk::dsh::properties` module
-- Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher`
-- `dsh_sdk::error::DshRestTokenError` renamed to `dsh_sdk::management_api::error::ManagementApiTokenError`
+- **Breaking change:** `dsh_sdk::utils::Dlq::new` is removed and replaced with `dsh_sdk::utils::Dlq::start` which starts the DLQ and returns a channel to send dlq messages
+- **Breaking change:** Deprecated 
`dsh_sdk::dsh::properties` module
+- **Breaking change:** Moved `dsh_sdk::rest_api_token_fetcher` to `dsh_sdk::management_api::token_fetcher` and renamed `RestApiTokenFetcher` to `ManagementApiTokenFetcher`
+- **Breaking change:** `dsh_sdk::error::DshRestTokenError` renamed to `dsh_sdk::management_api::error::ManagementApiTokenError`
   - **NOTE** Cargo.toml feature flag `rest-token-fetcher` renamed to `management-api-token-fetcher`
 - Moved `dsh_sdk::dsh::datastream` to `dsh_sdk::datastream`
 - Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates`
   - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module
-- Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` and renamed to `ProtocolTokenFetcher`
-  - Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher`
+- **Breaking change:** Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` and renamed to `ProtocolTokenFetcher`
+  - **NOTE** Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher`
+- **Breaking change:** Renamed `dsh_sdk::Platform` methods to more meaningful names
 - **Breaking change:** Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq`
 - **Breaking change:** Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown`
 - **Breaking change:** Moved `dsh_sdk::metrics` to `dsh_sdk::utils::metrics`
+- **Breaking change:** `dsh_sdk::utils::metrics::start_metrics_server` requires `fn() -> String` which gathers and encodes metrics
 
 ### Removed
 - Removed `dsh_sdk::rdkafka` public re-export, import `rdkafka` directly
diff --git a/dsh_sdk/examples/management_api_token_fetcher.rs b/dsh_sdk/examples/management_api_token_fetcher.rs
index 899819d..8ae1b35 100644
--- a/dsh_sdk/examples/management_api_token_fetcher.rs
+++ b/dsh_sdk/examples/management_api_token_fetcher.rs
@@ -15,7 +15,7 @@ async fn main() {
     let client_secret =
         env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set as environment variable");
     let tenant = env::var("TENANT").expect("TENANT must be set as environment variable");
-    let client = Client::new(platform.endpoint_rest_api());
+    let client = Client::new(platform.endpoint_management_api());
     let tf = ManagementApiTokenFetcherBuilder::new(platform)
         .tenant_name(tenant.clone())
         .client_secret(client_secret)
diff --git a/dsh_sdk/src/utils/platform.rs b/dsh_sdk/src/utils/platform.rs
index 2603f1f..6173c91 100644
--- a/dsh_sdk/src/utils/platform.rs
+++ b/dsh_sdk/src/utils/platform.rs
@@ -25,7 +25,9 @@
 /// - `ProdLz` (dsh-prod.dsh.prod.aws.kpn.com)
 /// - `NpLz` (dsh-dev.dsh.np.aws.kpn.com)
 /// - `Poc` (poc.kpn-dsh.com)
-#[derive(Clone, Debug)]
+///
+/// Each platform has its own realm, endpoint for the DSH Rest API and endpoint for the DSH Rest API access token.
+#[derive(Clone, Debug, PartialEq)]
 #[non_exhaustive]
 pub enum Platform {
     /// Production platform (`kpn-dsh.com`).
@@ -41,6 +43,10 @@ pub enum Platform {
 }
 
 impl Platform {
+    #[deprecated(
+        since = "0.5.0",
+        note = "Use `dsh_sdk::Platform::management_api_client_id` instead"
+    )]
     /// Returns a properly formatted client ID for the DSH REST API, given a tenant name.
     ///
     /// The format is:
     /// \[
     /// `"robot:{realm}:{tenant_name}"`
     /// \]
     ///
     /// # Example
     /// ```
     /// # use dsh_sdk::Platform;
     /// let platform = Platform::NpLz;
     /// let client_id = platform.rest_client_id("my-tenant");
     /// assert_eq!(client_id, "robot:dev-lz-dsh:my-tenant");
     /// ```
     pub fn rest_client_id(&self, tenant: impl AsRef<str>) -> String {
+        self.management_api_client_id(tenant)
+    }
+
+    /// Returns a properly formatted client ID for the DSH Management API, given a tenant name. 
+    ///
+    /// The format is:
+    /// \[
+    /// `"robot:{realm}:{tenant_name}"`
+    /// \]
+    ///
+    /// # Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::NpLz;
+    /// let client_id = platform.management_api_client_id("my-tenant");
+    /// assert_eq!(client_id, "robot:dev-lz-dsh:my-tenant");
+    /// ```
+    pub fn management_api_client_id(&self, tenant: impl AsRef<str>) -> String {
         format!("robot:{}:{}", self.realm(), tenant.as_ref())
     }
 
-    /// Returns the base URL for the DSH REST API, depending on the platform.
+    #[deprecated(
+        since = "0.5.0",
+        note = "Use `dsh_sdk::Platform::endpoint_management_api` instead"
+    )]
+    /// Get the endpoint for the DSH Rest API
     ///
     /// This endpoint is typically used for general resource operations in DSH.
     ///
@@ -71,6 +99,21 @@ impl Platform {
     /// assert_eq!(endpoint, "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0");
     /// ```
     pub fn endpoint_rest_api(&self) -> &str {
+        self.endpoint_management_api()
+    }
+
+    /// Returns the endpoint for the DSH Management API
+    ///
+    /// It will return the endpoint for the DSH Management API based on the platform
+    ///
+    /// ## Example
+    /// ```
+    /// # use dsh_sdk::Platform;
+    /// let platform = Platform::NpLz;
+    /// let endpoint = platform.endpoint_management_api();
+    /// assert_eq!(endpoint, "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0");
+    /// ```
+    pub fn endpoint_management_api(&self) -> &str {
         match self {
             Self::Prod => "https://api.kpn-dsh.com/resources/v0",
             Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/resources/v0",
@@ -99,10 +142,15 @@ impl Platform {
         }
     }
 
-    /// (Deprecated) Returns the DSH REST authentication token endpoint.
+    #[deprecated(
+        since = "0.5.0",
+        note = "Use `dsh_sdk::Platform::endpoint_management_api_token` instead"
+    )]
+    /// Get the endpoint for fetching DSH Rest Authentication Token
+    ///
+    /// With this token you can authenticate for the mqtt token endpoint
     ///
-    /// *Prefer using [`endpoint_management_api_token`](Self::endpoint_management_api_token) instead.*
-    #[deprecated(since = "0.5.0", note = "Use `endpoint_management_api_token` instead")]
+    /// It will return the endpoint for DSH Rest authentication token based on the platform
     pub fn endpoint_rest_token(&self) -> &str {
         self.endpoint_management_api_token()
     }
@@ -129,10 +177,13 @@ impl Platform {
         }
     }
 
+    #[deprecated(
+        since = "0.5.0",
+        note = "Use `dsh_sdk::Platform::endpoint_protocol_token` instead"
+    )]
    /// (Deprecated) Returns the DSH MQTT token endpoint. 
 ///
     /// *Prefer using [`endpoint_protocol_token`](Self::endpoint_protocol_token) instead.*
-    #[deprecated(since = "0.5.0", note = "Use `endpoint_protocol_token` instead")]
     pub fn endpoint_mqtt_token(&self) -> &str {
         self.endpoint_protocol_token()
     }
@@ -177,6 +228,29 @@ impl Platform {
     }
 }
 
+impl TryFrom<&str> for Platform {
+    type Error = &'static str;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        match value.to_lowercase().replace("-", "").as_str() {
+            "prod" => Ok(Self::Prod),
+            "prodaz" => Ok(Self::ProdAz),
+            "prodlz" => Ok(Self::ProdLz),
+            "nplz" => Ok(Self::NpLz),
+            "poc" => Ok(Self::Poc),
+            _ => Err("Invalid platform"),
+        }
+    }
+}
+
+impl TryFrom<String> for Platform {
+    type Error = &'static str;
+
+    fn try_from(value: String) -> Result<Self, Self::Error> {
+        Self::try_from(value.as_str())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -191,16 +265,106 @@ mod tests {
 
     #[test]
     fn test_platform_client_id() {
         assert_eq!(
-            Platform::NpLz.rest_client_id("my-tenant"),
+            Platform::NpLz.management_api_client_id("my-tenant"),
             "robot:dev-lz-dsh:my-tenant"
         );
         assert_eq!(
-            Platform::ProdLz.rest_client_id("my-tenant".to_string()),
+            Platform::ProdLz.management_api_client_id("my-tenant".to_string()),
             "robot:prod-lz-dsh:my-tenant"
         );
         assert_eq!(
-            Platform::Poc.rest_client_id("my-tenant"),
+            Platform::Poc.management_api_client_id("my-tenant"),
             "robot:poc-dsh:my-tenant"
         );
     }
+
+    #[test]
+    fn test_try_from_str() {
+        assert_eq!(Platform::try_from("prod").unwrap(), Platform::Prod);
+        assert_eq!(Platform::try_from("PROD").unwrap(), Platform::Prod);
+        assert_eq!(Platform::try_from("prod-az").unwrap(), Platform::ProdAz);
+        assert_eq!(Platform::try_from("PROD-AZ").unwrap(), Platform::ProdAz);
+        assert_eq!(Platform::try_from("prodaz").unwrap(), Platform::ProdAz);
+        assert_eq!(Platform::try_from("PRODAZ").unwrap(), Platform::ProdAz);
+        assert_eq!(Platform::try_from("prod-lz").unwrap(), Platform::ProdLz);
+        assert_eq!(Platform::try_from("PROD-LZ").unwrap(), Platform::ProdLz);
+        assert_eq!(Platform::try_from("prodlz").unwrap(), Platform::ProdLz);
+        assert_eq!(Platform::try_from("PRODLZ").unwrap(), Platform::ProdLz);
+        assert_eq!(Platform::try_from("np-lz").unwrap(), Platform::NpLz);
+        assert_eq!(Platform::try_from("NP-LZ").unwrap(), Platform::NpLz);
+        assert_eq!(Platform::try_from("nplz").unwrap(), Platform::NpLz);
+        assert_eq!(Platform::try_from("NPLZ").unwrap(), Platform::NpLz);
+        assert_eq!(Platform::try_from("poc").unwrap(), Platform::Poc);
+        assert_eq!(Platform::try_from("POC").unwrap(), Platform::Poc);
+        assert!(Platform::try_from("invalid").is_err());
+    }
+
+    #[test]
+    fn test_try_from_string() {
+        assert_eq!(
+            Platform::try_from("prod".to_string()).unwrap(),
+            Platform::Prod
+        );
+        assert_eq!(
+            Platform::try_from("PROD".to_string()).unwrap(),
+            Platform::Prod
+        );
+        assert_eq!(
+            Platform::try_from("prod-az".to_string()).unwrap(),
+            Platform::ProdAz
+        );
+        assert_eq!(
+            Platform::try_from("PROD-AZ".to_string()).unwrap(),
+            Platform::ProdAz
+        );
+        assert_eq!(
+            Platform::try_from("prodaz".to_string()).unwrap(),
+            Platform::ProdAz
+        );
+        assert_eq!(
+            Platform::try_from("PRODAZ".to_string()).unwrap(),
+            Platform::ProdAz
+        );
+        assert_eq!(
+            Platform::try_from("prod-lz".to_string()).unwrap(),
+            Platform::ProdLz
+        );
+        assert_eq!(
+            Platform::try_from("PROD-LZ".to_string()).unwrap(),
+            Platform::ProdLz
+        );
+        assert_eq!(
+            Platform::try_from("prodlz".to_string()).unwrap(),
+            Platform::ProdLz
+        );
+        assert_eq!(
+            Platform::try_from("PRODLZ".to_string()).unwrap(),
+            Platform::ProdLz
+        );
+        assert_eq!(
Platform::try_from("np-lz".to_string()).unwrap(), + Platform::NpLz + ); + assert_eq!( + Platform::try_from("NP-LZ".to_string()).unwrap(), + Platform::NpLz + ); + assert_eq!( + Platform::try_from("nplz".to_string()).unwrap(), + Platform::NpLz + ); + assert_eq!( + Platform::try_from("NPLZ".to_string()).unwrap(), + Platform::NpLz + ); + assert_eq!( + Platform::try_from("poc".to_string()).unwrap(), + Platform::Poc + ); + assert_eq!( + Platform::try_from("POC".to_string()).unwrap(), + Platform::Poc + ); + assert!(Platform::try_from("invalid".to_string()).is_err()); + } } From 21888ed2cd527374bf2023da901dcba37b6c1eff Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Wed, 22 Jan 2025 13:46:58 +0100 Subject: [PATCH 21/23] 114 review protocol token fetcher (#116) * Fix Platforms enum * add mqtt support plus examples * update README and examples * update changelog and fmt --- dsh_sdk/CHANGELOG.md | 5 +- dsh_sdk/Cargo.toml | 5 +- dsh_sdk/README.md | 16 +- dsh_sdk/examples/mqtt_example.rs | 97 ++ dsh_sdk/examples/mqtt_ws_example.rs | 99 ++ .../protocol_authentication_full_mediation.rs | 77 ++ ...otocol_authentication_partial_mediation.rs | 99 ++ dsh_sdk/examples/protocol_token_fetcher.rs | 16 - .../protocol_token_fetcher_specific_claims.rs | 29 - dsh_sdk/src/certificates/mod.rs | 24 - dsh_sdk/src/dsh.rs | 4 +- dsh_sdk/src/management_api/token_fetcher.rs | 16 +- dsh_sdk/src/protocol_adapters/mod.rs | 11 +- .../token/api_client_token_fetcher.rs | 612 ++++++++++++ .../token/data_access_token/claims.rs | 152 +++ .../token/data_access_token/mod.rs | 73 ++ .../token/data_access_token/request.rs | 207 +++++ .../token/data_access_token/token.rs | 230 +++++ .../protocol_adapters/{ => token}/error.rs | 10 +- dsh_sdk/src/protocol_adapters/token/mod.rs | 74 ++ .../token/rest_token/claims.rs | 164 ++++ .../protocol_adapters/token/rest_token/mod.rs | 11 + .../token/rest_token/request.rs | 166 ++++ .../token/rest_token/token.rs | 181 ++++ .../protocol_adapters/token_fetcher/mod.rs | 869 ------------------ dsh_sdk/src/rest_api_token_fetcher.rs | 14 +- dsh_sdk/src/utils/mod.rs | 24 + dsh_sdk/src/utils/platform.rs | 60 +- 28 files changed, 2336 insertions(+), 1009 deletions(-) create mode 100644 dsh_sdk/examples/mqtt_example.rs create mode 100644 dsh_sdk/examples/mqtt_ws_example.rs create mode 100644 dsh_sdk/examples/protocol_authentication_full_mediation.rs create mode 100644 dsh_sdk/examples/protocol_authentication_partial_mediation.rs delete mode 100644 dsh_sdk/examples/protocol_token_fetcher.rs delete mode 100644 dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/api_client_token_fetcher.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/data_access_token/claims.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/data_access_token/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/data_access_token/request.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/data_access_token/token.rs rename dsh_sdk/src/protocol_adapters/{ => token}/error.rs (71%) create mode 100644 dsh_sdk/src/protocol_adapters/token/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/rest_token/mod.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/rest_token/request.rs create mode 100644 dsh_sdk/src/protocol_adapters/token/rest_token/token.rs delete mode 100644 dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs diff --git 
a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md
index 7a6538e..0f78b5c 100644
--- a/dsh_sdk/CHANGELOG.md
+++ b/dsh_sdk/CHANGELOG.md
@@ -29,8 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Moved `dsh_sdk::dsh::datastream` to `dsh_sdk::datastream`
 - Moved `dsh_sdk::dsh::certificates` to `dsh_sdk::certificates`
   - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module
-- **Breaking change:** Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token_fetcher` and renamed to `ProtocolTokenFetcher`
-  - **NOTE** Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher`
+- **Breaking change:** Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token` and renamed to `ApiClientTokenFetcher`
+  - **NOTE** The code is refactored to follow the partial mediation and full mediation pattern
+  - **NOTE** Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher`
 - **Breaking change:** Renamed `dsh_sdk::Platform` methods to more meaningful names
 - **Breaking change:** Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq`
 - **Breaking change:** Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown`
diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml
index 26ebd78..8baaa3a 100644
--- a/dsh_sdk/Cargo.toml
+++ b/dsh_sdk/Cargo.toml
@@ -58,14 +58,17 @@ dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/s
 
 [dev-dependencies]
+# Dependencies for the tests
 mockito = "1.1.1"
 openssl = "0.10"
 tokio = { version = "^1.35", features = ["full"] }
 hyper = { version = "1.3", features = ["full"] }
 serial_test = "3.1.0"
 dsh_rest_api_client = { path = "../dsh_rest_api_client", version = "0.3.0" }
+# Dependencies for the examples
 dsh_sdk = { features = ["dlq"], path = "." 
} env_logger = "0.11" rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"], default-features = true } lazy_static = { version = "1.5" } -prometheus = { version = "0.13", features = ["process"] } \ No newline at end of file +prometheus = { version = "0.13", features = ["process"] } +rumqttc = { version = "0.24", features = ["websocket"] } \ No newline at end of file diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index 70ae188..9876fe6 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -114,15 +114,15 @@ Below is an overview of the available features: | **feature** | **default** | **Description** | **Example** | |--------------------------------|-------------|-------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------| -| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | -| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | -| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `bootstrap` | ✓ | Certificate signing process and fetch datastreams properties | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | +| `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) | | `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/schema_store_api.rs) | -| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Token fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_token_fetcher.rs) / [with specific claims](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs) | -| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [ Token fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/management_api_token_fetcher.rs) | -| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose 
metrics](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs)                                                                  |
-| `graceful-shutdown`            | ✗           | Tokio based graceful shutdown handler                             | [Graceful shutdown](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/graceful_shutdown.rs)                                                    |
-| `dlq`                          | ✗           | Dead Letter Queue implementation                                  | [Full implementation example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs)                                         |
+| `protocol-token-fetcher`       | ✗           | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP)         | [Mqtt client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_example.rs) / [Mqtt websocket client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_ws_example.rs) /
[token fetcher (full mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_full_mediation.rs) / [token fetcher (partial mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_partial_mediation.rs) | +| `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [Token fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/management_api_token_fetcher.rs) | +| `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) | +| `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/graceful_shutdown.rs) | +| `dlq` | ✗ | Dead Letter Queue implementation | [Full implementation example](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/dlq_implementation.rs) | ### Selecting Features diff --git a/dsh_sdk/examples/mqtt_example.rs b/dsh_sdk/examples/mqtt_example.rs new file mode 100644 index 0000000..c19d161 --- /dev/null +++ b/dsh_sdk/examples/mqtt_example.rs @@ -0,0 +1,97 @@ +//! This example demonstrates how to connect to the DSH MQTT broker and consume data using the DSH SDK and Rumqttc. +//! +//! Run example with: +//! ```bash +//! API_KEY={API_KEY} TENANT={TENANT} CLIENT_ID=sdk_example_client cargo run --all-features --example mqtt_example +//! ``` +//! +//! NEVER distribute the API_KEY to an external client, this is only for demonstration purposes. +//! +//! The example will: +//! - Request a DataAccessToken +//! - Create a new MqttOptions based on the fetched token +//! - Create a new async client +//! - Subscribe to a topic +//! - Print received messages +use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; +use dsh_sdk::protocol_adapters::token::data_access_token::{ + DataAccessToken, RequestDataAccessToken, +}; + +use rumqttc::{AsyncClient, MqttOptions, Transport}; + +/// The platform to fetch the token for. 
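+///
+/// NOTE: `NpLz` is only an assumption for this example; substitute the platform
+/// your tenant actually runs on.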
+const PLATFORM: dsh_sdk::Platform = dsh_sdk::Platform::NpLz;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let tenant_name = std::env::var("TENANT").expect("TENANT is not set");
+    let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID is not set");
+
+    // Start logger to Stdout to show what is happening
+    env_logger::builder()
+        .filter(Some("dsh_sdk"), log::LevelFilter::Trace)
+        .target(env_logger::Target::Stdout)
+        .init();
+
+    // Create a request for the Data Access Token (this requests full access)
+    let request = RequestDataAccessToken::new(tenant_name, client_id);
+
+    // Fetch token from your API Client Authentication service
+    let token = ApiClientAuthenticationService::get_data_access_token(request).await;
+
+    // Create a new MqttOptions based on info from token
+    let mut mqttoptions = MqttOptions::new(token.client_id(), token.endpoint(), token.port_mqtt());
+    mqttoptions.set_credentials("", token.raw_token());
+    mqttoptions.set_transport(Transport::tls_with_default_config());
+
+    // Create a new async client
+    let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10);
+
+    // For demonstration purposes, we select the first topic that is available in token
+    let topic = token
+        .claims()
+        .iter()
+        .next()
+        .expect("No available topics")
+        .full_qualified_topic_name();
+
+    // Subscribe to a topic
+    client.subscribe(topic, rumqttc::QoS::AtMostOnce).await?;
+
+    loop {
+        match eventloop.poll().await {
+            Ok(v) => {
+                println!("Received = {:?}", v);
+            }
+            Err(e) => {
+                println!("Error = {e:?}");
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// NEVER use this in an external client, this is only for demonstration purposes.
+/// This should be implemented in your API Client Authentication service.
+///
+/// See the following examples for a more complete example:
+/// - `examples/protocol_authentication_partial_mediation.rs`
+/// - `examples/protocol_authentication_full_mediation.rs`
+struct ApiClientAuthenticationService;
+
+impl ApiClientAuthenticationService {
+    /// This should be properly implemented in your API Client Authentication service.
+    async fn get_data_access_token(request: RequestDataAccessToken) -> DataAccessToken {
+        let api_key = std::env::var("API_KEY").expect("API_KEY is not set");
+
+        let token_fetcher = ApiClientTokenFetcher::new(api_key, PLATFORM);
+
+        token_fetcher
+            .fetch_data_access_token(request)
+            .await
+            .unwrap()
+    }
+}
diff --git a/dsh_sdk/examples/mqtt_ws_example.rs b/dsh_sdk/examples/mqtt_ws_example.rs
new file mode 100644
index 0000000..58ccdc8
--- /dev/null
+++ b/dsh_sdk/examples/mqtt_ws_example.rs
@@ -0,0 +1,99 @@
+//! This example demonstrates how to connect to the DSH MQTT broker over websockets
+//! and consume data using the DSH SDK and Rumqttc.
+//!
+//! Run example with:
+//! ```bash
+//! API_KEY={API_KEY} TENANT={TENANT} CLIENT_ID=sdk_example_client cargo run --all-features --example mqtt_ws_example
+//! ```
+//!
+//! NEVER distribute the API_KEY to an external client, this is only for demonstration purposes.
+//!
+//! The example will:
+//! - Request a DataAccessToken
+//! - Create a new MqttOptions based on the fetched token
+//! - Create a new async client
+//! - Subscribe to a topic
+//! - Print received messages
+use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher;
+use dsh_sdk::protocol_adapters::token::data_access_token::{
+    DataAccessToken, RequestDataAccessToken,
+};
+
+use rumqttc::{AsyncClient, MqttOptions, Transport};
+
+/// The platform to fetch the token for. 
+const PLATFORM: dsh_sdk::Platform = dsh_sdk::Platform::NpLz;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let tenant_name = std::env::var("TENANT").expect("TENANT is not set");
+    let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID is not set");
+
+    // Start logger to Stdout to show what is happening
+    env_logger::builder()
+        .filter_level(log::LevelFilter::Trace)
+        .target(env_logger::Target::Stdout)
+        .init();
+
+    // Create a request for the Data Access Token (this requests full access)
+    let request = RequestDataAccessToken::new(tenant_name, client_id);
+
+    // Fetch token from your API Client Authentication service
+    let token = ApiClientAuthenticationService::get_data_access_token(request).await;
+
+    // Create a new MqttOptions based on info from token
+    let mut mqttoptions =
+        MqttOptions::new(token.client_id(), token.endpoint_wss(), token.port_wss());
+    mqttoptions.set_credentials("", token.raw_token());
+    mqttoptions.set_transport(Transport::wss_with_default_config());
+
+    // Create a new async client
+    let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10);
+
+    // For demonstration purposes, we select the first topic that is available in token
+    let topic = token
+        .claims()
+        .iter()
+        .next()
+        .expect("No available topics")
+        .full_qualified_topic_name();
+
+    // Subscribe to a topic
+    client.subscribe(topic, rumqttc::QoS::AtMostOnce).await?;
+
+    loop {
+        match eventloop.poll().await {
+            Ok(v) => {
+                println!("Received = {:?}", v);
+            }
+            Err(e) => {
+                println!("Error = {e:?}");
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// NEVER use this in an external client, this is only for demonstration purposes.
+/// This should be implemented in your API Client Authentication service.
+///
+/// See the following examples for a more complete example:
+/// - `examples/protocol_authentication_partial_mediation.rs`
+/// - `examples/protocol_authentication_full_mediation.rs`
+struct ApiClientAuthenticationService;
+
+impl ApiClientAuthenticationService {
+    /// This should be properly implemented in your API Client Authentication service.
+    async fn get_data_access_token(request: RequestDataAccessToken) -> DataAccessToken {
+        let api_key = std::env::var("API_KEY").expect("API_KEY is not set");
+
+        let token_fetcher = ApiClientTokenFetcher::new(api_key, PLATFORM);
+
+        token_fetcher
+            .fetch_data_access_token(request)
+            .await
+            .unwrap()
+    }
+}
diff --git a/dsh_sdk/examples/protocol_authentication_full_mediation.rs b/dsh_sdk/examples/protocol_authentication_full_mediation.rs
new file mode 100644
index 0000000..418eec9
--- /dev/null
+++ b/dsh_sdk/examples/protocol_authentication_full_mediation.rs
@@ -0,0 +1,77 @@
+//! Example: API Client Authentication service fetching a DataAccessToken for a device.
+//!
+//! The DataAccessToken allows a device to connect to protocol adapters with specific permissions.
+//!
+//! ## Important Notes:
+//! - **Do NOT implement this logic in device applications or external clients!**
+//! - This logic is exclusive to the **API Client role** in the DSH architecture.
+//! - The API Client uses a long-lived API_KEY to fetch short-lived tokens for devices.
+//! - **The API_KEY must never be distributed.**
+
+use dsh_sdk::protocol_adapters::token::{
+    api_client_token_fetcher::ApiClientTokenFetcher, Action, RequestDataAccessToken,
+    TopicPermission,
+};
+use std::time::{SystemTime, UNIX_EPOCH};
+
+/// Target platform for fetching the token. 
+const PLATFORM: dsh_sdk::Platform = dsh_sdk::Platform::NpLz;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Retrieve environment variables
+    let tenant_name = std::env::var("TENANT").expect("TENANT environment variable is not set");
+    let api_key = std::env::var("API_KEY").expect("API_KEY environment variable is not set");
+
+    // Initialize logger to display detailed SDK activity in stdout
+    env_logger::builder()
+        .filter(Some("dsh_sdk"), log::LevelFilter::Trace)
+        .target(env_logger::Target::Stdout)
+        .init();
+
+    // Example Scenario:
+    // Assume the API Authentication service receives a request from an external client.
+    // We want to delegate a DataAccessToken with the following properties:
+    // - Valid for 10 minutes
+    // - Allows fetching another DataAccessToken with:
+    //   - Maximum expiration of 5 minutes
+    //   - Usage restricted to the external client ID "External-client-id"
+
+    // Instantiate the API Client Token Fetcher
+    let token_fetcher = ApiClientTokenFetcher::new(api_key, PLATFORM);
+
+    // Define the permissions for the DataAccessToken
+    let permissions = vec![TopicPermission::new(
+        Action::Subscribe,
+        "amp",
+        "/tt",
+        format!("state/app/{}/#", tenant_name),
+    )];
+
+    // Create the token request
+    let expiration_time = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .expect("System time is before UNIX epoch")
+        .as_secs() as i64
+        + 600; // 10 minutes in seconds
+
+    let token_request = RequestDataAccessToken::new(&tenant_name, "External-client-id")
+        .set_exp(expiration_time)
+        .set_claims(permissions);
+
+    // Fetch the DataAccessToken
+    let token = token_fetcher
+        .get_or_fetch_data_access_token(token_request)
+        .await?;
+
+    println!(
+        "\nGenerated DataAccessToken with partial permissions: {:?}\n",
+        token
+    );
+
+    // Extract and send the raw token to the external client
+    let raw_token = token.raw_token();
+    println!("Raw token to send to external client: {}", raw_token);
+
+    Ok(())
+}
diff --git a/dsh_sdk/examples/protocol_authentication_partial_mediation.rs b/dsh_sdk/examples/protocol_authentication_partial_mediation.rs
new file mode 100644
index 0000000..47742a1
--- /dev/null
+++ b/dsh_sdk/examples/protocol_authentication_partial_mediation.rs
@@ -0,0 +1,99 @@
+//! Example: API Client Authentication service fetching a REST token for a device.
+//!
+//! The REST token enables a device to fetch its own DataAccessToken to connect to protocol adapters.
+//!
+//! ## Important Notes:
+//! - **Do NOT implement this logic in device applications or external clients!**
+//! - This logic is part of the **API Client role** in the DSH architecture.
+//! - The API Client uses a long-lived API_KEY (REST token) to fetch short-lived tokens for devices.
+//! - **The API_KEY must never be distributed.**
+
+use dsh_sdk::protocol_adapters::token::{
+    api_client_token_fetcher::ApiClientTokenFetcher, DatastreamsMqttTokenClaim,
+    RequestDataAccessToken, RequestRestToken, RestToken,
+};
+use std::time::{SystemTime, UNIX_EPOCH};
+
+/// Target platform for fetching the token. 
+const PLATFORM: dsh_sdk::Platform = dsh_sdk::Platform::NpLz; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Retrieve environment variables + let tenant_name = std::env::var("TENANT").expect("TENANT environment variable is not set"); + let api_key = std::env::var("API_KEY").expect("API_KEY environment variable is not set"); + + // Initialize logger to display detailed SDK activity in stdout + env_logger::builder() + .filter(Some("dsh_sdk"), log::LevelFilter::Trace) + .target(env_logger::Target::Stdout) + .init(); + + // Example Scenario: + // Assume the API Authentication service receives a request from an external client. + // We want to delegate a short-lived REST token with the following properties: + // - REST token: + // - Valid for 10 minutes + // - Allows fetching a DataAccessToken with: + // - Maximum expiration of 5 minutes + // - Usage restricted to the external client ID "External-client-id" + + println!("API Authentication Service Code:\n"); + + // Instantiate the API Client Token Fetcher + let token_fetcher = ApiClientTokenFetcher::new(api_key, PLATFORM); + + // Define the claim for the DatastreamsMqttToken endpoint + let claim = DatastreamsMqttTokenClaim::new() + .set_id("External-client-id") // External client ID (should be unique) + .set_relexp(300); // Relative expiration of 5 minutes (300 seconds) + + // Create a token request with the claim and expiration time + let expiration_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("System time is before UNIX epoch") + .as_secs() as i64 + + 600; // 10 minutes in seconds + + let rest_token_request = RequestRestToken::new(&tenant_name) + .set_exp(expiration_time) + .set_claims(claim); + + // Fetch the REST token + let partial_token = token_fetcher + .get_or_fetch_rest_token(rest_token_request) + .await?; + println!( + "\nGenerated REST token with partial permissions: {:?}", + partial_token + ); + + // Send the raw token to the external client + let raw_token = partial_token.raw_token(); + println!("\nRaw token to send to external client: {}", raw_token); + + // ------------------------------------------------------------------------------------- + // External Client Code: + // + // When the external client receives the raw_token, it can fetch its own DataAccessToken: + // 1. Parse the raw token into a RestToken. + // 2. Prepare a request for a DataAccessToken with the external client ID. + // 3. Fetch the DataAccessToken using the RestToken. 
+ // ------------------------------------------------------------------------------------- + println!("\nExternal Client Code:"); + + // Parse the raw token into a RestToken + let rest_token = RestToken::parse(raw_token)?; + println!("\nParsed REST token: {:?}", rest_token); + + // Prepare a request for a DataAccessToken using the external client ID + let data_access_request = + RequestDataAccessToken::new(rest_token.tenant_id(), "External-client-id"); + + // Use an HTTP client to send the request and fetch the DataAccessToken + let http_client = reqwest::Client::new(); + let data_access_token = data_access_request.send(&http_client, rest_token).await?; + println!("\nFetched DataAccessToken: {:#?}", data_access_token); + + Ok(()) +} diff --git a/dsh_sdk/examples/protocol_token_fetcher.rs b/dsh_sdk/examples/protocol_token_fetcher.rs deleted file mode 100644 index 7467acb..0000000 --- a/dsh_sdk/examples/protocol_token_fetcher.rs +++ /dev/null @@ -1,16 +0,0 @@ -use std::env; - -use dsh_sdk::protocol_adapters::token_fetcher::*; - -#[tokio::main] -async fn main() { - let tenant_name = env::var("TENANT").unwrap().to_string(); - let api_key = env::var("API_KEY").unwrap().to_string(); - let mqtt_token_fetcher = - ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); - let token: ProtocolToken = mqtt_token_fetcher - .get_token("Client-id", None) //Claims = None fetches all possible claims - .await - .unwrap(); - println!("MQTT Token: {:?}", token); -} diff --git a/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs b/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs deleted file mode 100644 index c290bd4..0000000 --- a/dsh_sdk/examples/protocol_token_fetcher_specific_claims.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::env; - -use dsh_sdk::protocol_adapters::token_fetcher::*; - -#[tokio::main] -async fn main() { - // Get the config and secret from the environment - let tenant_name = env::var("TENANT").unwrap().to_string(); - let api_key = env::var("API_KEY").unwrap().to_string(); - let stream = env::var("STREAM").unwrap().to_string(); - - let topic = "#".to_string(); // check MQTT documentation for better understanding of wildcards - let resource = Resource::new(stream, "/tt".to_string(), topic, Some("topic".to_string())); - - let claims_sub = Claims::new(resource.clone(), Actions::Subscribe); - - let claims_pub = Claims::new(resource, Actions::Publish); - - let claims = vec![claims_sub, claims_pub]; - - let mqtt_token_fetcher = - ProtocolTokenFetcher::new(tenant_name, api_key, dsh_sdk::Platform::NpLz); - - let token: ProtocolToken = mqtt_token_fetcher - .get_token("Client-id", Some(claims)) - .await - .unwrap(); - println!("MQTT Token: {:?}", token); -} diff --git a/dsh_sdk/src/certificates/mod.rs b/dsh_sdk/src/certificates/mod.rs index 9977993..da56c1b 100644 --- a/dsh_sdk/src/certificates/mod.rs +++ b/dsh_sdk/src/certificates/mod.rs @@ -279,15 +279,6 @@ impl Cert { } } -/// Helper function to ensure that the host starts with `https://` or `http://`. 
-pub(crate) fn ensure_https_prefix(host: impl AsRef) -> String { - if host.as_ref().starts_with("http://") || host.as_ref().starts_with("https://") { - host.as_ref().to_string() - } else { - format!("https://{}", host.as_ref()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -385,19 +376,4 @@ mod tests { ); assert!(identity.is_ok()); } - - #[test] - fn test_ensure_https_prefix() { - let host = "http://example.com"; - let result = ensure_https_prefix(host); - assert_eq!(result, "http://example.com"); - - let host = "https://example.com"; - let result = ensure_https_prefix(host); - assert_eq!(result, "https://example.com"); - - let host = "example.com"; - let result = ensure_https_prefix(host); - assert_eq!(result, "https://example.com"); - } } diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs index 21d7d46..1c7948f 100644 --- a/dsh_sdk/src/dsh.rs +++ b/dsh_sdk/src/dsh.rs @@ -28,7 +28,7 @@ use log::warn; use std::env; use std::sync::{Arc, OnceLock}; -use crate::certificates::{ensure_https_prefix, Cert, CertificatesError}; +use crate::certificates::{Cert, CertificatesError}; use crate::datastream::Datastream; use crate::error::DshError; use crate::utils; @@ -113,7 +113,7 @@ impl Dsh { let tenant_name = utils::tenant_name().unwrap_or_else(|_| "local_tenant".to_string()); let task_id = utils::get_env_var(VAR_TASK_ID).unwrap_or_else(|_| "local_task_id".to_string()); - let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST).map(ensure_https_prefix); + let config_host = utils::get_env_var(VAR_KAFKA_CONFIG_HOST).map(utils::ensure_https_prefix); let certificates = if let Ok(cert) = Cert::from_pki_config_dir::(None) { Some(cert) diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs index 87883b8..19af4fd 100644 --- a/dsh_sdk/src/management_api/token_fetcher.rs +++ b/dsh_sdk/src/management_api/token_fetcher.rs @@ -137,7 +137,7 @@ impl Default for AccessToken { /// let token_fetcher = ManagementApiTokenFetcher::new( /// client_id, /// client_secret, -/// platform.endpoint_rest_access_token().to_string() +/// platform.endpoint_management_api_token().to_string() /// ); /// /// let token = token_fetcher.get_token().await?; @@ -170,7 +170,7 @@ impl ManagementApiTokenFetcher { /// let token_fetcher = ManagementApiTokenFetcher::new( /// client_id, /// client_secret, - /// platform.endpoint_rest_access_token() + /// platform.endpoint_management_api_token() /// ); /// /// let token = token_fetcher.get_token().await.unwrap(); @@ -212,7 +212,7 @@ impl ManagementApiTokenFetcher { /// let token_fetcher = ManagementApiTokenFetcher::new_with_client( /// client_id, /// client_secret, - /// platform.endpoint_rest_access_token().to_string(), + /// platform.endpoint_management_api_token().to_string(), /// custom_client /// ); /// let token = token_fetcher.get_token().await.unwrap(); @@ -461,7 +461,7 @@ impl ManagementApiTokenFetcherBuilder { let token_fetcher = ManagementApiTokenFetcher::new_with_client( client_id, client_secret, - self.platform.endpoint_rest_access_token().to_string(), + self.platform.endpoint_management_api_token().to_string(), client, ); Ok(token_fetcher) @@ -640,7 +640,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } /// Ensures the builder can auto-generate `client_id` from the `tenant_name`. 
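For reference, a minimal sketch of how calling code obtains a management API token via the renamed endpoint method (assumptions: the `management_api` feature is enabled, `CLIENT_ID`/`CLIENT_SECRET` environment variables hold valid tenant credentials, and the NpLz platform is targeted; this mirrors the doc examples in the hunks above):

```rust
use dsh_sdk::management_api::token_fetcher::ManagementApiTokenFetcher;
use dsh_sdk::Platform;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client_id = std::env::var("CLIENT_ID")?;
    let client_secret = std::env::var("CLIENT_SECRET")?;

    // The auth URL now comes from `endpoint_management_api_token`.
    let token_fetcher = ManagementApiTokenFetcher::new(
        client_id,
        client_secret,
        Platform::NpLz.endpoint_management_api_token().to_string(),
    );

    // Fetch (or reuse a cached) access token.
    let _token = token_fetcher.get_token().await?;
    Ok(())
}
```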
@@ -659,7 +659,7 @@ mod test { format!("robot:{}:{}", Platform::NpLz.realm(), tenant_name) ); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } /// Validates that a custom `reqwest::Client` can be injected into the builder. @@ -677,7 +677,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } /// Tests precedence of `client_id` over a derived tenant-based client ID. @@ -695,7 +695,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id_override); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } /// Ensures builder returns errors if `client_id` or `client_secret` are missing. diff --git a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs index ea5e196..d969edb 100644 --- a/dsh_sdk/src/protocol_adapters/mod.rs +++ b/dsh_sdk/src/protocol_adapters/mod.rs @@ -7,13 +7,4 @@ pub mod kafka_protocol; // #[cfg(feature = "mqtt-protocol-adapter")] // pub mod mqtt_protocol; #[cfg(feature = "protocol-token-fetcher")] -pub mod token_fetcher; - -mod error; - -#[cfg(feature = "protocol-token-fetcher")] -#[doc(inline)] -pub use error::ProtocolTokenError; -#[cfg(feature = "protocol-token-fetcher")] -#[doc(inline)] -pub use token_fetcher::ProtocolTokenFetcher; +pub mod token; diff --git a/dsh_sdk/src/protocol_adapters/token/api_client_token_fetcher.rs b/dsh_sdk/src/protocol_adapters/token/api_client_token_fetcher.rs new file mode 100644 index 0000000..cd19518 --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/token/api_client_token_fetcher.rs @@ -0,0 +1,612 @@ +//! Token fetcher for fetching [`RestToken`]s and [`DataAccessToken`]s by an API Client authentication service. +//! +//! An API client is an organization that develops or supplies services (applications) +//! and devices (external clients) that will read data from or publish data on the public +//! data streams on the DSH platform. +//! +//! # Note +//! **NEVER** use this fetcher in a device or external client. +//! +//! This fetcher should be used by your API Client authentication services that delegates [`RestToken`] +//! and/or [`DataAccessToken`] to external clients. +use std::collections::HashMap; +use std::hash::{DefaultHasher, Hash, Hasher}; + +use tokio::sync::RwLock; + +use super::data_access_token::*; +use super::rest_token::*; +use super::ProtocolTokenError; +use crate::Platform; + +/// Fetcher for Rest and Data Access tokens by an API Client authentication service. +/// +/// # Note +/// **NEVER** implement this fetcher in a device or external client. +/// +/// This fetcher should be used by your API Client authentication services that delegates [`RestToken`] +/// and/or [`DataAccessToken`] to external clients. +pub struct ApiClientTokenFetcher { + api_key: String, + auth_url: String, + cache_rest_tokens: RwLock>, + cache_data_access_tokens: RwLock>, + reqwest_client: reqwest::Client, +} + +impl ApiClientTokenFetcher { + /// Creates a new instance of the API client token fetcher. + /// + /// # Note + /// **NEVER** implement this fetcher in a device or external client. 
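+    /// The fetcher authenticates with the tenant's long-lived API key; shipping that key to a
+    /// device or external client would let any holder request tokens on the tenant's behalf.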
+ /// + /// # Arguments + /// - `api_key` - The API key to authenticate to DSH. + /// - `platform` - The DSH [`Platform`] to fetch the token for. + /// + /// # Returns + /// A new instance of the API client token fetcher. + /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// use dsh_sdk::Platform; + /// + /// // Get the API key from the environment variable + /// let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// + /// // Create a token fetcher with the API key and platform + /// let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// ``` + pub fn new(api_key: impl Into, platform: Platform) -> Self { + let reqwest_client = reqwest::Client::builder() + .use_rustls_tls() + .build() + .expect("Failed to create reqwest client with rustls as tls backend"); + Self { + api_key: api_key.into(), + auth_url: platform.endpoint_protocol_rest_token().to_string(), + cache_rest_tokens: RwLock::new(HashMap::new()), + cache_data_access_tokens: RwLock::new(HashMap::new()), + reqwest_client, + } + } + + /// Creates a new instance of the API client token fetcher with custom [reqwest::Client]. + /// + /// # Note + /// **NEVER** implement this fetcher in a device or external client. + /// + /// # Arguments + /// - `api_key` - The API key to authenticate to DSH. + /// - `platform` - The DSH [`Platform`] to fetch the token for. + /// + /// # Returns + /// A new instance of the API client token fetcher. + /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// use dsh_sdk::Platform; + /// + /// // Get the API key from the environment variable + /// let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// + /// // Create a token fetcher with the API key and platform + /// let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// ``` + pub fn new_with_client( + api_key: impl Into, + platform: Platform, + client: reqwest::Client, + ) -> Self { + Self { + api_key: api_key.into(), + auth_url: platform.endpoint_protocol_rest_token().to_string(), + cache_rest_tokens: RwLock::new(HashMap::new()), + cache_data_access_tokens: RwLock::new(HashMap::new()), + reqwest_client: client, + } + } + + /// Clears the cache of all [`RestToken`]s. + pub async fn clear_cache_rest_tokens(&self) { + self.cache_rest_tokens.write().await.clear(); + } + + /// Clears the cache of all [`DataAccessToken`]s. + pub async fn clear_cache_data_access_tokens(&self) { + self.cache_data_access_tokens.write().await.clear(); + } + + /// Clears the cache of all tokens. + pub async fn clear_cache(&self) { + self.clear_cache_rest_tokens().await; + self.clear_cache_data_access_tokens().await; + } + + /// Fetches a new [`RestToken`] from the DSH platform. + /// + /// # Arguments + /// - `request` - The [`RequestRestToken`] to fetch the token. + /// + /// # Returns + /// The [`RestToken`] fetched from the Cache or from the DSH platform. 
+ /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::rest_token::RequestRestToken; + /// # use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// # use dsh_sdk::Platform; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// # let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// // Create a token request + /// let request = RequestRestToken::new("example-tenant"); + /// + /// // Fetch token + /// let token = token_fetcher.fetch_rest_token(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn fetch_rest_token( + &self, + request: RequestRestToken, + ) -> Result { + let token = request + .send(&self.reqwest_client, &self.api_key, &self.auth_url) + .await; + log::trace!( + "Fetched new token for tenant '{}' client '{}': {:?}", + request.tenant(), + request.client_id().unwrap_or(request.tenant()), + token, + ); + token + } + + /// Get a [`RestToken`] from Cache if valid or fetch a new one from the DSH platform. + /// + /// It will check the cache first and check if it is still valid. + /// If not it will fetch a new [`RestToken`] + /// + /// # Arguments + /// - `request` - The [`RequestRestToken`] to fetch the token. + /// + /// # Returns + /// The [`RestToken`] fetched from the Cache or from the DSH platform. + /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::rest_token::RequestRestToken; + /// # use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// # use dsh_sdk::Platform; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// # let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// // Create a token request + /// let request = RequestRestToken::new("example-tenant"); + /// + /// // Get or fetch token + /// let token = token_fetcher.get_or_fetch_rest_token(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn get_or_fetch_rest_token( + &self, + request: RequestRestToken, + ) -> Result { + // Get the tenant name from the request + let tenant = request.tenant(); + // Get the client_id from the request, if not present use the tenant name as client_id + let client_id = request + .claims() + .and_then(|claim| claim.mqtt_token_claim().id()) + .unwrap_or_else(|| tenant); + + let key = generate_cache_key((tenant, client_id)); + + // Check if a valid token is already in the cache with a read lock + if let Some(token) = self.get_valid_cached_rest_token(key).await { + return Ok(token); + } + + let mut cache_write_lock = self.cache_rest_tokens.write().await; + + // Get an entry in the cache + let token = cache_write_lock.entry(key).or_insert(RestToken::init()); + + // Check if the token is valid (for if another thread already fetched a new token) + if !token.is_valid() { + *token = self.fetch_rest_token(request).await?; + }; + + Ok(token.clone()) + } + + /// Fetches a new [`DataAccessToken`] from the DSH platform. + /// + /// # Arguments + /// - `request` - The [`RequestDataAccessToken`] to fetch the token. + /// + /// # Returns + /// The [`DataAccessToken`] fetched from the Cache or from the DSH platform. 
+ /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::data_access_token::RequestDataAccessToken; + /// # use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// # use dsh_sdk::Platform; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// // Create a token fetcher with the API key and platform + /// let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// + /// // Create a token request + /// let request = RequestDataAccessToken::new("example-tenant", "external-client-id"); + /// + /// // Fetch token + /// let token = token_fetcher.fetch_data_access_token(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn fetch_data_access_token( + &self, + request: RequestDataAccessToken, + ) -> Result { + let rest_token = self + .get_or_fetch_rest_token(RequestRestToken::new(request.tenant())) + .await?; + let token = request.send(&self.reqwest_client, rest_token).await?; + log::trace!( + "Fetched new token for tenant '{}' client '{}': {:?}", + request.tenant(), + request.id(), + token, + ); + Ok(token) + } + + /// Fetches a new [`DataAccessToken`] from the DSH platform and caches it. + /// + /// It will check the cache first and check if it is still valid. + /// If not it will fetch a new [`DataAccessToken`] + /// + /// # Arguments + /// - `request` - The [`RequestDataAccessToken`] to fetch the token. + /// + /// # Returns + /// The [`DataAccessToken`] fetched from the Cache or from the DSH platform. + /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::data_access_token::RequestDataAccessToken; + /// # use dsh_sdk::protocol_adapters::token::api_client_token_fetcher::ApiClientTokenFetcher; + /// # use dsh_sdk::Platform; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let api_key = std::env::var("API_KEY").expect("API_KEY env variable is not set"); + /// // Create a token fetcher with the API key and platform + /// let token_fetcher = ApiClientTokenFetcher::new(api_key, Platform::NpLz); + /// + /// // Create a token request + /// let request = RequestDataAccessToken::new("example-tenant", "external-client-id"); + /// + /// // Get or fetch token + /// let token = token_fetcher.get_or_fetch_data_access_token(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn get_or_fetch_data_access_token( + &self, + request: RequestDataAccessToken, + ) -> Result { + let key = generate_cache_key(&request); + + if let Some(token) = self.get_valid_cached_data_access_token(key).await { + return Ok(token); + } + + let mut cache_write_lock = self.cache_data_access_tokens.write().await; + + // Get a reference to the token + let token = cache_write_lock + .entry(key) + .or_insert(DataAccessToken::init()); + + // Check if the token is valid (for if another thread already fetched a new token) + if !token.is_valid() { + *token = self.fetch_data_access_token(request).await?; + }; + Ok(token.clone()) + } + + /// Attempts to retrieve a valid cached token with a read lock. 
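+    ///
+    /// Together with the re-check under the write lock in the `get_or_fetch_*` methods, this
+    /// forms a double-checked pattern: the cheap read lock serves the common cache-hit case,
+    /// and the write-lock re-check covers the race where another task refreshed the token first.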
+ async fn get_valid_cached_rest_token(&self, key: u64) -> Option { + if let Some(token) = self.cache_rest_tokens.read().await.get(&key) { + if token.is_valid() { + log::trace!( + "Valid RestToken found in cache for tenant '{}' client '{}': {:?}", + token.tenant_id(), + token.client_id().unwrap_or(token.tenant_id()), + token + ); + Some(token.clone()) + } else { + log::trace!( + "Invalid RestToken found in cache for tenant '{}' client '{}': {:?}", + token.tenant_id(), + token.client_id().unwrap_or(token.tenant_id()), + token + ); + None + } + } else { + log::trace!("No RestToken found in cache"); + None + } + } + + /// Attempts to retrieve a valid cached token with a read lock. + async fn get_valid_cached_data_access_token(&self, key: u64) -> Option { + if let Some(token) = self.cache_data_access_tokens.read().await.get(&key) { + if token.is_valid() { + log::trace!( + "Valid DataAccessToken found in cache for client '{}': {:?}", + token.client_id(), + token + ); + return Some(token.clone()); + } else { + log::trace!( + "Invalid DataAccessToken found in cache for client '{}': {:?}", + token.client_id(), + token + ); + } + } else { + log::trace!("No DataAccessToken found in cache"); + } + None + } +} + +impl std::fmt::Debug for ApiClientTokenFetcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApiClientTokenFetcher") + .field("api_key", &"xxxxxx") + .field("auth_url", &self.auth_url) + .finish() + } +} + +/// Hashes the key and returns the hash. +fn generate_cache_key(key: impl Hash) -> u64 { + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + hasher.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + use super::{DataAccessToken, RequestDataAccessToken, RequestRestToken, RestToken}; + + // create a valid fetcher with dummy tokens + fn create_valid_fetcher() -> ApiClientTokenFetcher { + let rest_token = RestToken::parse("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoiZHVtbXlfZW5kcG9pbnQiLCJpc3MiOiJTdHJpbmciLCJjbGFpbXMiOlt7InJlc291cmNlIjoiZHVtbXkiLCJhY3Rpb24iOiJwdXNoIn1dLCJleHAiOjIxNDc0ODM2NDcsImNsaWVudC1pZCI6ImR1bW15X3RlbmFudCIsImlhdCI6MCwidGVuYW50LWlkIjoiZHVtbXlfdGVuYW50In0.SbePw_EmLrkiSfk5XykLosqOoFb0xC_QE4A8283rFfY").unwrap(); + let data_access_token = DataAccessToken::parse("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiIxIiwiZ2VuIjoxLCJleHAiOjIxNDc0ODM2NDcsImlhdCI6MjE0NzQ4MzY0NywiZW5kcG9pbnQiOiJkdW1teV9lbmRwb2ludCIsInBvcnRzIjp7Im1xdHRzIjpbODg4M10sIm1xdHR3c3MiOls0NDMsODQ0M119LCJ0ZW5hbnQtaWQiOiJkdW1teV90ZW5hbnQiLCJjbGllbnQtaWQiOiJEdW1teS1jbGllbnQtaWQiLCJjbGFpbXMiOlt7ImFjdGlvbiI6InN1YnNjcmliZSIsInJlc291cmNlIjp7InR5cGUiOiJ0b3BpYyIsInByZWZpeCI6Ii90dCIsInN0cmVhbSI6ImR1bW1teSIsInRvcGljIjoiL2R1bW15LyMifX1dfQ.651w8PULFURETQoaKVyKSTE6qghxKfLSm_oODzFU1mM").unwrap(); + + let mut rest_cache = HashMap::new(); + rest_cache.insert( + generate_cache_key(("dummy_tenant", "dummy_tenant")), + rest_token, + ); + let mut data_cache = HashMap::new(); + data_cache.insert( + generate_cache_key(RequestDataAccessToken::new( + "dummy_tenant", + "Dummy-client-id", + )), + data_access_token, + ); + ApiClientTokenFetcher { + api_key: "abc123".to_string(), + auth_url: "dummy_auth_url".to_string(), + cache_rest_tokens: RwLock::new(rest_cache), + cache_data_access_tokens: RwLock::new(data_cache), + reqwest_client: reqwest::Client::new(), + } + } + + #[tokio::test] + async fn test_api_client_token_fetcher_new() { + let rest_api_key = "test_api_key".to_string(); + let platform = Platform::NpLz; + + let fetcher = 
ApiClientTokenFetcher::new(rest_api_key, platform); + + assert!(fetcher.cache_rest_tokens.read().await.is_empty()); + assert!(fetcher.cache_data_access_tokens.read().await.is_empty()); + assert_eq!(fetcher.api_key, "test_api_key".to_string()); + assert_eq!( + fetcher.auth_url, + Platform::NpLz.endpoint_protocol_rest_token() + ); + } + + #[tokio::test] + async fn test_api_client_token_fetcher_new_with_client() { + let rest_api_key = "test_api_key".to_string(); + let platform = Platform::NpLz; + + let client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); + let fetcher = ApiClientTokenFetcher::new_with_client(rest_api_key, platform, client); + + assert!(fetcher.cache_rest_tokens.read().await.is_empty()); + assert!(fetcher.cache_data_access_tokens.read().await.is_empty()); + assert_eq!(fetcher.api_key, "test_api_key".to_string()); + assert_eq!( + fetcher.auth_url, + Platform::NpLz.endpoint_protocol_rest_token() + ); + } + #[tokio::test] + async fn test_fetch_new_rest_token() { + let mut mockito_server = mockito::Server::new_async().await; + let raw_rest_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"; + let _m = mockito_server + .mock("POST", "/rest_auth_url") + .with_status(200) + .with_body(raw_rest_token) + .create_async() + .await; + println!("server url: {}", mockito_server.url()); + let client = reqwest::Client::new(); + + let fetcher = ApiClientTokenFetcher { + api_key: "test_api_key".to_string(), + auth_url: format!("{}{}", mockito_server.url(), "/rest_auth_url"), + cache_data_access_tokens: RwLock::new(HashMap::new()), + cache_rest_tokens: RwLock::new(HashMap::new()), + reqwest_client: client, + }; + + let request = RequestRestToken::new("test_tenant"); + + let result = fetcher.fetch_rest_token(request).await; + println!("{:?}", result); + assert!(result.is_ok()); + let rest_token = result.unwrap(); + assert_eq!(rest_token.exp(), 1); + assert_eq!(rest_token.gen(), 1); + assert_eq!(rest_token.endpoint(), "test_endpoint"); + assert_eq!(rest_token.iss(), "String"); + assert_eq!(rest_token.raw_token(), raw_rest_token); + } + + #[tokio::test] + async fn test_fetch_new_data_access_token() { + let mut opt: mockito::ServerOpts = mockito::ServerOpts::default(); + opt.port = 7999; + let mut mockito_server = mockito::Server::new_with_opts_async(opt).await; + let raw_rest_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6Imh0dHA6Ly8xMjcuMC4wLjE6Nzk5OSIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19.j5ekqMiWyBhJyRQE_aARFS9mQJiN7S2rpKTsn3rZ5lQ"; + let raw_access_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywiaWF0IjoyMTQ3NDgzNjQ3LCJlbmRwb2ludCI6InRlc3RfZW5kcG9pbnQiLCJwb3J0cyI6eyJtcXR0cyI6Wzg4ODNdLCJtcXR0d3NzIjpbNDQzLDg0NDNdfSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImNsYWltcyI6W3siYWN0aW9uIjoic3Vic2NyaWJlIiwicmVzb3VyY2UiOnsidHlwZSI6InRvcGljIiwicHJlZml4IjoiL3R0Iiwic3RyZWFtIjoidGVzdCIsInRvcGljIjoiL3Rlc3QvIyJ9fV19.LwYIMIX39J502TDqpEqH5T2Rlj-HczeT3WLfs5Do3B0"; + let _m = mockito_server + .mock("POST", "/rest_auth_url") + .with_status(200) + .with_body(raw_rest_token) + .create_async() + .await; + let _m2 = mockito_server + .mock("POST", 
"/datastreams/v0/mqtt/token") + .with_status(200) + .with_body(raw_access_token) + .create(); + + println!("server url: {}", mockito_server.url()); + let client = reqwest::Client::new(); + + let fetcher = ApiClientTokenFetcher { + api_key: "test_api_key".to_string(), + auth_url: format!("{}{}", mockito_server.url(), "/rest_auth_url"), + cache_data_access_tokens: RwLock::new(HashMap::new()), + cache_rest_tokens: RwLock::new(HashMap::new()), + reqwest_client: client, + }; + + let request = RequestDataAccessToken::new("test_tenant", "test_client"); + let result = fetcher.fetch_data_access_token(request).await; + + println!("{:?}", result); + assert!(result.is_ok()); + let acces_token = result.unwrap(); + assert_eq!(acces_token.exp(), 2147483647); + assert_eq!(acces_token.gen(), 1); + assert_eq!(acces_token.endpoint(), "test_endpoint"); + assert_eq!(acces_token.iss(), "String"); + assert_eq!(acces_token.client_id(), "test_client"); + assert_eq!(acces_token.tenant_id(), "test_tenant"); + assert_eq!(acces_token.raw_token(), raw_access_token); + } + + #[tokio::test] + async fn test_get_cached_rest_token() { + let fetcher = create_valid_fetcher(); + + let request = RequestRestToken::new("dummy_tenant"); + let token = fetcher.get_or_fetch_rest_token(request).await.unwrap(); + assert_eq!(token.tenant_id(), "dummy_tenant"); + assert_eq!(token.client_id(), None); + + let request = RequestRestToken::new("not_in_cache"); + let token = fetcher.get_or_fetch_rest_token(request).await; + assert!(token.is_err()); + } + + #[tokio::test] + async fn test_get_cached_data_access_token() { + let fetcher = create_valid_fetcher(); + + let request = RequestDataAccessToken::new("dummy_tenant", "Dummy-client-id"); + let token = fetcher + .get_or_fetch_data_access_token(request) + .await + .unwrap(); + assert_eq!(token.tenant_id(), "dummy_tenant"); + assert_eq!(token.client_id(), "Dummy-client-id"); + + let request = RequestDataAccessToken::new("dummy_tenant", "not_in_cache"); + let token = fetcher.get_or_fetch_data_access_token(request).await; + assert!(token.is_err()); + } + + #[tokio::test] + async fn test_clear_cache_rest_tokens() { + let fetcher = create_valid_fetcher(); + assert!(!fetcher.cache_rest_tokens.read().await.is_empty()); + fetcher.clear_cache_rest_tokens().await; + assert!(fetcher.cache_rest_tokens.read().await.is_empty()); + } + + #[tokio::test] + async fn test_clear_cache_data_access_tokens() { + let fetcher = create_valid_fetcher(); + assert!(!fetcher.cache_data_access_tokens.read().await.is_empty()); + fetcher.clear_cache_data_access_tokens().await; + assert!(fetcher.cache_data_access_tokens.read().await.is_empty()); + } + + #[tokio::test] + async fn test_clear_cache() { + let fetcher = create_valid_fetcher(); + assert!(!fetcher.cache_rest_tokens.read().await.is_empty()); + assert!(!fetcher.cache_data_access_tokens.read().await.is_empty()); + fetcher.clear_cache().await; + assert!(fetcher.cache_rest_tokens.read().await.is_empty()); + assert!(fetcher.cache_data_access_tokens.read().await.is_empty()); + } + + #[test] + fn test_debug_client_fetcher() { + let client = create_valid_fetcher(); + let debug = format!("{:?}", client); + assert_eq!( + debug, + "ApiClientTokenFetcher { api_key: \"xxxxxx\", auth_url: \"dummy_auth_url\" }" + ); + } + + #[test] + fn test_generate_cache_key() { + let key = generate_cache_key(("test_tenant", "test_client")); + assert_eq!(key, 17569805883005093029); + } +} diff --git a/dsh_sdk/src/protocol_adapters/token/data_access_token/claims.rs 
b/dsh_sdk/src/protocol_adapters/token/data_access_token/claims.rs
new file mode 100644
index 0000000..187c797
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/data_access_token/claims.rs
@@ -0,0 +1,152 @@
+use serde::{Deserialize, Serialize};
+
+/// Permissions per topic for the [`DataAccessToken`](super::DataAccessToken).
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash)]
+#[serde(rename_all = "kebab-case")]
+pub struct TopicPermission {
+    /// Publish or Subscribe
+    action: Action,
+    /// The resource to define what the client can access in terms of stream, prefix, topic, and type.
+    resource: Resource,
+}
+
+/// `publish` or `subscribe` permission for [`TopicPermission`].
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy, Hash)]
+pub enum Action {
+    #[serde(alias = "publish")]
+    Publish,
+    #[serde(alias = "subscribe")]
+    Subscribe,
+}
+
+/// Represents a resource/datastream in the [`TopicPermission`] claim.
+///
+/// The resource defines what the client can access in terms of stream, prefix, topic, and type.
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash)]
+#[serde(rename_all = "kebab-case")]
+struct Resource {
+    /// The type of the resource (always "topic").
+    #[serde(rename = "type")]
+    resource_type: String,
+    /// data stream name, e.g. weather or ivi
+    stream: String,
+    /// topic prefix, e.g. /tt
+    prefix: String,
+    /// topic pattern, e.g. +/+/+/something/#
+    topic: String,
+}
+
+impl TopicPermission {
+    /// Creates a new [`TopicPermission`] instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `action` - the action to define what the client can do with the resource.
+    /// * `stream` - data stream name, e.g. `weather` or `ivi`
+    /// * `prefix` - topic prefix, e.g. `/tt`
+    /// * `topic_pattern` - topic pattern, e.g. `+/+/+/something/#`
+    ///
+    /// # Returns
+    ///
+    /// Returns a new [`TopicPermission`] instance.
+    pub fn new(
+        action: Action,
+        stream: impl Into<String>,
+        prefix: impl Into<String>,
+        topic_pattern: impl Into<String>,
+    ) -> Self {
+        let resource = Resource::new(stream, prefix, topic_pattern);
+        Self { resource, action }
+    }
+
+    /// Returns the fully qualified topic name of the resource.
+    pub fn full_qualified_topic_name(&self) -> String {
+        format!(
+            "{}/{}/{}",
+            self.resource.prefix, self.resource.stream, self.resource.topic
+        )
+    }
+
+    /// topic prefix, e.g. `/tt`
+    pub fn prefix(&self) -> &str {
+        &self.resource.prefix
+    }
+
+    /// data stream name, e.g. `weather` or `ivi`
+    pub fn stream(&self) -> &str {
+        &self.resource.stream
+    }
+
+    /// topic pattern, e.g. `+/+/+/something/#`
+    pub fn topic_pattern(&self) -> &str {
+        &self.resource.topic
+    }
+
+    /// Returns the [`Action`] to define what the client can do with the resource.
+    pub fn action(&self) -> Action {
+        self.action
+    }
+}
+
+impl Resource {
+    /// Creates a new [`Resource`] instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `stream` - data stream name, e.g. `weather` or `ivi`
+    /// * `prefix` - topic prefix, e.g. `/tt`
+    /// * `topic` - topic pattern, e.g. `+/+/+/something/#`
+    ///
+    /// # Returns
+    ///
+    /// Returns a new [`Resource`] instance.
+ pub fn new( + stream: impl Into, + prefix: impl Into, + topic_pattern: impl Into, + ) -> Self { + Self { + stream: stream.into(), + prefix: prefix.into(), + topic: topic_pattern.into(), + resource_type: "topic".to_string(), // always topic + } + } +} + +impl std::fmt::Display for Action { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Publish => write!(f, "publish"), + Self::Subscribe => write!(f, "subscribe"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_topic_permission_new() { + let topic_permission = TopicPermission::new(Action::Publish, "stream", "prefix", "topic/#"); + assert_eq!(topic_permission.action(), Action::Publish); + assert_eq!(topic_permission.stream(), "stream"); + assert_eq!(topic_permission.prefix(), "prefix"); + assert_eq!(topic_permission.topic_pattern(), "topic/#"); + } + + #[test] + fn test_resource_new() { + let resource = Resource::new("stream", "prefix", "topic/#"); + assert_eq!(resource.stream, "stream"); + assert_eq!(resource.prefix, "prefix"); + assert_eq!(resource.topic, "topic/#"); + } + + #[test] + fn test_action_display() { + let action = Action::Publish; + assert_eq!(action.to_string(), "publish"); + let action = Action::Subscribe; + assert_eq!(action.to_string(), "subscribe"); + } +} diff --git a/dsh_sdk/src/protocol_adapters/token/data_access_token/mod.rs b/dsh_sdk/src/protocol_adapters/token/data_access_token/mod.rs new file mode 100644 index 0000000..60c7822 --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/token/data_access_token/mod.rs @@ -0,0 +1,73 @@ +//! Access Token to authenticate to the DSH Mqtt or Http brokers +use super::ProtocolTokenError; + +mod claims; +mod request; +mod token; + +#[doc(inline)] +pub use claims::{Action, TopicPermission}; +#[doc(inline)] +pub use request::RequestDataAccessToken; +#[doc(inline)] +pub use token::{DataAccessToken, Ports}; + +/// Validates if a string can be used as a client_id +/// +/// DSH Allows the following as a client_id: +/// - A maximum of 64 characters +/// - Can only contain: +/// - Alphanumeric characters (a-z, A-z, 0-9) +/// - @, -, _, . and : +/// +/// it will return an [ProtocolTokenError::InvalidClientId] if the client_id is invalid +/// including the reason why it is invalid +/// +/// # Example +/// ``` +/// # use dsh_sdk::protocol_adapters::token::data_access_token::validate_client_id; +/// // valid client id's +/// assert!(validate_client_id("client-12345").is_ok()); +/// assert!(validate_client_id("ABCDEFasbcdef1234567890@-_.:").is_ok()); +/// +/// // invalid client id's +/// assert!(validate_client_id("client A").is_err()); +/// assert!(validate_client_id("1234567890qwertyuiopasdfghjklzxcvbnmz1234567890qwertyuiopasdfghjklzxcvbnmz").is_err()); +/// ``` +pub fn validate_client_id(id: impl AsRef) -> Result<(), ProtocolTokenError> { + let ref_id = id.as_ref(); + if !ref_id.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '@' || c == '-' || c == '_' || c == '.' || c == ':' + }) { + Err(ProtocolTokenError::InvalidClientId( + ref_id.to_string(), + "client_id: Can only contain: Alphanumeric characters (a-z, A-z, 0-9) @, -, _, . 
and :", + )) + } else if ref_id.len() > 64 { + // Note this works because all valid characters are ASCII and have a single byte + Err(ProtocolTokenError::InvalidClientId( + ref_id.to_string(), + "Exceeded a maximum of 64 characters", + )) + } else { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_client_id() { + assert!(validate_client_id("ABCDEF1234567890@-_.:asbcdef").is_ok()); + assert!(validate_client_id("!").is_err()); + assert!(validate_client_id( + "1234567890qwertyuiopasdfghjklzxcvbnmz1234567890qwertyuiopasdfghjklzxcvbnmz" + ) + .is_err()); + assert!(validate_client_id("client A").is_err()); + assert!(validate_client_id("client\nA").is_err()); + assert!(validate_client_id(r#"client\nA"#).is_err()); + } +} diff --git a/dsh_sdk/src/protocol_adapters/token/data_access_token/request.rs b/dsh_sdk/src/protocol_adapters/token/data_access_token/request.rs new file mode 100644 index 0000000..e81df68 --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/token/data_access_token/request.rs @@ -0,0 +1,207 @@ +use serde::{Deserialize, Serialize}; + +use super::claims::TopicPermission; +use super::token::DataAccessToken; +use crate::protocol_adapters::token::rest_token::RestToken; +use crate::protocol_adapters::token::ProtocolTokenError; +use crate::utils::ensure_https_prefix; + +/// Request for geting a [`DataAccessToken`] which can be used to authenticate to the DSH Mqtt or Http brokers +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RequestDataAccessToken { + /// Tenant name + tenant: String, + /// Unique client ID that must be used when connecting to the broker + id: String, + /// Requested expiration time (in seconds since UNIX epoch) + #[serde(skip_serializing_if = "Option::is_none")] + exp: Option, + /// Optional list of topic permissions + #[serde(skip_serializing_if = "Option::is_none")] + claims: Option>, + /// DSH Client Claims optional field for commiumicating between external clients and DSH + #[serde(skip_serializing_if = "Option::is_none")] + dshclc: Option, +} + +impl RequestDataAccessToken { + /// + /// client_id: Has a maximum of 64 characters + /// Can only contain: + /// haracters (a-z, A-z, 0-9) + /// @, -, _, . and : + pub fn new(tenant: impl Into, client_id: impl Into) -> Self { + Self { + tenant: tenant.into(), + id: client_id.into(), + exp: None, + claims: None, + dshclc: None, + } + } + + /// Returns the set tenant name + pub fn tenant(&self) -> &str { + &self.tenant + } + + /// Returns the set client ID + pub fn id(&self) -> &str { + &self.id + } + + /// Set the requested expiration time for the token. + pub fn set_exp(mut self, exp: i64) -> Self { + self.exp = Some(exp); + self + } + + /// Returns the requested expiration time for the token. + pub fn exp(&self) -> Option { + self.exp + } + + /// Set a list of [`TopicPermission`] for the token. + pub fn set_claims(mut self, claims: Vec) -> Self { + self.claims = Some(claims); + self + } + + /// Extend the list of [`TopicPermission`] for the token. + pub fn extend_claims(mut self, claims: impl Iterator) -> Self { + self.claims.get_or_insert_with(Vec::new).extend(claims); + self + } + + /// Returns the list of [`TopicPermission`] for the token. + pub fn claims(&self) -> Option<&Vec> { + self.claims.as_ref() + } + + /// Set the DSH Client Claims. + /// + /// This field is optional and can be used to communicate between external clients and the API client authentication service. 
+ pub fn set_dshclc(mut self, dshclc: impl Into) -> Self { + self.dshclc = Some(dshclc.into()); + self + } + + /// Returns the DSH Client Claims. + pub fn dshclc(&self) -> Option<&serde_json::Value> { + self.dshclc.as_ref() + } + + /// Send the request to the DSH platform to get a [`DataAccessToken`]. + /// + /// # Arguments + /// - `client` - The reqwest client to use for the request. + /// - `rest_token` - The rest token to use for the request. + /// + /// # Returns + /// The [`DataAccessToken`] if the request was successful. + /// Otherwise a [`ProtocolTokenError`] is returned. + /// + /// # Example + /// ```no_run + /// use dsh_sdk::protocol_adapters::token::data_access_token::RequestDataAccessToken; + /// use dsh_sdk::protocol_adapters::token::rest_token::RestToken; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// let client = reqwest::Client::new(); + /// let rest_token = RestToken::parse("valid.jwt.token")?; + /// let request = RequestDataAccessToken::new("example_tenant", "Example-client-id"); + /// let token = request.send(&client, rest_token).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn send( + &self, + client: &reqwest::Client, + rest_token: RestToken, + ) -> Result { + super::validate_client_id(&self.id)?; + + let auth_url = ensure_https_prefix(format!( + "{}/datastreams/v0/mqtt/token", + rest_token.endpoint(), + )); + log::debug!("Sending request to '{}': {:?}", auth_url, self); + let response = client + .post(&auth_url) + .header( + reqwest::header::AUTHORIZATION, + format!("Bearer {}", rest_token.raw_token()), + ) + .json(self) + .send() + .await?; + let status = response.status(); + let body_text = response.text().await?; + match status { + reqwest::StatusCode::OK => Ok(DataAccessToken::parse(body_text)?), + _ => Err(ProtocolTokenError::DshCall { + url: auth_url, + status_code: status, + error_body: body_text, + }), + } + } +} + +impl PartialEq for RequestDataAccessToken { + fn eq(&self, other: &Self) -> bool { + // Ignore the exp field + self.tenant == other.tenant + && self.id == other.id + && self.claims == other.claims + && self.dshclc == other.dshclc + } +} + +impl std::hash::Hash for RequestDataAccessToken { + fn hash(&self, state: &mut H) { + // Ignore the exp field + self.tenant.hash(state); + self.id.hash(state); + self.claims.hash(state); + self.dshclc.hash(state); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_request_new() { + let request = RequestDataAccessToken::new("test_tenant", "test_id"); + assert_eq!(request.tenant, "test_tenant"); + assert_eq!(request.id, "test_id"); + assert_eq!(request.exp, None); + assert_eq!(request.claims, None); + assert_eq!(request.dshclc, None); + } + + #[tokio::test] + async fn test_send_success() { + let mut opt: mockito::ServerOpts = mockito::ServerOpts::default(); + opt.port = 7998; + let mut mockito_server = mockito::Server::new_with_opts_async(opt).await; + let rest_token = RestToken::parse("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6Imh0dHA6Ly8xMjcuMC4wLjE6Nzk5OCIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19.NsCVyQ8Cmp1N6QmFs1n8EgD0HJDC6zZaOxW_6xu4m10").unwrap(); + let _m = mockito_server + .mock("POST", "/datastreams/v0/mqtt/token") + .match_header("Authorization", "Bearer 
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6Imh0dHA6Ly8xMjcuMC4wLjE6Nzk5OCIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19.NsCVyQ8Cmp1N6QmFs1n8EgD0HJDC6zZaOxW_6xu4m10") + .with_status(200) + .with_body("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywiaWF0IjoyMTQ3NDgzNjQ3LCJlbmRwb2ludCI6InRlc3RfZW5kcG9pbnQiLCJwb3J0cyI6eyJtcXR0cyI6Wzg4ODNdLCJtcXR0d3NzIjpbNDQzLDg0NDNdfSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImNsYWltcyI6W3siYWN0aW9uIjoic3Vic2NyaWJlIiwicmVzb3VyY2UiOnsidHlwZSI6InRvcGljIiwicHJlZml4IjoiL3R0Iiwic3RyZWFtIjoidGVzdCIsInRvcGljIjoiL3Rlc3QvIyJ9fV19.LwYIMIX39J502TDqpEqH5T2Rlj-HczeT3WLfs5Do3B0") + .create(); + + let client = reqwest::Client::new(); + let request = RequestDataAccessToken::new("test_tenant", "test_client"); + let result = request.send(&client, rest_token).await; + + assert!(result.is_ok()); + let token = result.unwrap(); + assert!(token.is_valid()); + } +} diff --git a/dsh_sdk/src/protocol_adapters/token/data_access_token/token.rs b/dsh_sdk/src/protocol_adapters/token/data_access_token/token.rs new file mode 100644 index 0000000..bef392f --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/token/data_access_token/token.rs @@ -0,0 +1,230 @@ +//! Access Token to authenticate to the DSH Mqtt or Http brokers +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use super::claims::TopicPermission; +use crate::protocol_adapters::token::{JwtToken, ProtocolTokenError}; + +/// Access Token to authenticate to the DSH Mqtt or Http brokers +#[derive(Serialize, Deserialize, Clone)] +#[serde(rename_all = "kebab-case")] +pub struct DataAccessToken { + gen: i32, + pub(crate) endpoint: String, + ports: Ports, + iss: String, + claims: Vec, + exp: i64, + client_id: String, + iat: i32, + tenant_id: String, + #[serde(skip)] + raw_token: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Ports { + mqtts: Vec, + mqttwss: Vec, +} + +impl DataAccessToken { + /// Creates a new [`DataAccessToken`] instance based on a raw JWT Token. + pub fn parse(raw_token: impl Into) -> Result { + let raw_token = raw_token.into(); + let jwt_token = JwtToken::parse(&raw_token)?; + + let mut token: Self = serde_json::from_slice(&jwt_token.b64_decode_payload()?)?; + token.raw_token = raw_token; + + Ok(token) + } + + pub(crate) fn init() -> Self { + Self { + gen: 0, + endpoint: "".to_string(), + ports: Ports { + mqtts: vec![], + mqttwss: vec![], + }, + iss: "".to_string(), + claims: Vec::new(), + exp: 0, + client_id: "".to_string(), + iat: 0, + tenant_id: "".to_string(), + raw_token: "".to_string(), + } + } + + /// Returns the generation of the token. + pub fn gen(&self) -> i32 { + self.gen + } + + /// Returns the endpoint which the MQTT client should connect to. + pub fn endpoint(&self) -> &str { + &self.endpoint + } + + /// Returns the endpoint which the MQTT websocket client should connect to. + pub fn endpoint_wss(&self) -> String { + format!("wss://{}/mqtt", self.endpoint) + } + + /// Returns the port number which the MQTT client should connect to for `mqtt` protocol. + pub fn port_mqtt(&self) -> u16 { + *self.ports.mqtts.get(0).unwrap_or(&8883) + } + + /// Returns the port number which the MQTT client should connect to for `websocket` protocol. 
+ pub fn port_wss(&self) -> u16 { + *self.ports.mqttwss.get(0).unwrap_or(&443) + } + + /// Returns the [`Ports`] which the MQTT client can connect to. + pub fn ports(&self) -> &Ports { + &self.ports + } + + /// Returns the iss. + pub fn iss(&self) -> &str { + &self.iss + } + + /// Returns the [`TopicPermission`] of the token + pub fn claims(&self) -> &Vec { + &self.claims + } + + /// Returns the expiration time (in seconds since UNIX epoch). + pub fn exp(&self) -> i64 { + self.exp + } + + /// Returns the client_id + pub fn client_id(&self) -> &str { + &self.client_id + } + + /// Returns the issued at time (in seconds since UNIX epoch). + pub fn iat(&self) -> i32 { + self.iat + } + + /// Returns the tenant name. + pub fn tenant_id(&self) -> &str { + &self.tenant_id + } + + /// Returns the raw JWT token. + pub fn raw_token(&self) -> &str { + &self.raw_token + } + + /// Checks if the token is valid. + pub fn is_valid(&self) -> bool { + let current_unixtime = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("SystemTime before UNIX EPOCH!") + .as_secs() as i64; + self.exp >= current_unixtime + 5 && !self.raw_token.is_empty() + } +} + +impl Ports { + pub fn mqtts(&self) -> &Vec { + &self.mqtts + } + + pub fn mqttwss(&self) -> &Vec { + &self.mqttwss + } +} + +impl std::fmt::Debug for DataAccessToken { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DataAccessToken") + .field("gen", &self.gen) + .field("endpoint", &self.endpoint) + .field("iss", &self.iss) + .field("claims", &self.claims) + .field("exp", &self.exp) + .field("client_id", &self.client_id) + .field("iat", &self.iat) + .field("tenant_id", &self.tenant_id) + .field( + "raw_token", + &self + .raw_token + .split('.') + .take(2) + .collect::>() + .join("."), + ) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_data_access_token() { + let raw_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywiaWF0IjoyMTQ3NDgzNjQ3LCJlbmRwb2ludCI6InRlc3RfZW5kcG9pbnQiLCJwb3J0cyI6eyJtcXR0cyI6Wzg4ODNdLCJtcXR0d3NzIjpbNDQzLDg0NDNdfSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImNsYWltcyI6W3siYWN0aW9uIjoic3Vic2NyaWJlIiwicmVzb3VyY2UiOnsidHlwZSI6InRvcGljIiwicHJlZml4IjoiL3R0Iiwic3RyZWFtIjoidGVzdCIsInRvcGljIjoiL3Rlc3QvIyJ9fV19.LwYIMIX39J502TDqpEqH5T2Rlj-HczeT3WLfs5Do3B0"; + let token = DataAccessToken::parse(raw_token).unwrap(); + assert_eq!(token.gen(), 1); + assert_eq!(token.endpoint(), "test_endpoint"); + assert_eq!(token.port_mqtt(), 8883); + assert_eq!(token.port_wss(), 443); + assert_eq!(token.iss(), "String"); + assert_eq!(token.exp(), 2147483647); + assert_eq!(token.iat(), 2147483647); + assert_eq!(token.client_id(), "test_client"); + assert_eq!(token.tenant_id(), "test_tenant"); + assert_eq!(token.raw_token(), raw_token); + assert!(token.is_valid()); + } + + #[test] + fn test_init_data_access_token() { + let token = DataAccessToken::init(); + assert_eq!(token.gen(), 0); + assert_eq!(token.endpoint(), ""); + assert_eq!(token.port_mqtt(), 8883); + assert_eq!(token.port_wss(), 443); + assert_eq!(token.iss(), ""); + assert_eq!(token.exp(), 0); + assert_eq!(token.iat(), 0); + assert_eq!(token.client_id(), ""); + assert_eq!(token.tenant_id(), ""); + assert_eq!(token.raw_token(), ""); + assert!(!token.is_valid()); + } + + #[test] + fn test_is_valid_data_access_token() { + let mut token = DataAccessToken::init(); + assert!(!token.is_valid()); + token.exp = 1; + assert!(!token.is_valid()); + token.raw_token 
= "test".to_string(); + assert!(!token.is_valid()); + token.exp = 2147483647; + assert!(token.is_valid()); + } + + #[test] + fn test_debug_data_access_token() { + let raw_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywiaWF0IjoyMTQ3NDgzNjQ3LCJlbmRwb2ludCI6InRlc3RfZW5kcG9pbnQiLCJwb3J0cyI6eyJtcXR0cyI6Wzg4ODNdLCJtcXR0d3NzIjpbNDQzLDg0NDNdfSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImNsYWltcyI6W3siYWN0aW9uIjoic3Vic2NyaWJlIiwicmVzb3VyY2UiOnsidHlwZSI6InRvcGljIiwicHJlZml4IjoiL3R0Iiwic3RyZWFtIjoidGVzdCIsInRvcGljIjoiL3Rlc3QvIyJ9fV19.LwYIMIX39J502TDqpEqH5T2Rlj-HczeT3WLfs5Do3B0"; + let token = DataAccessToken::parse(raw_token).unwrap(); + let debug = format!("{:?}", token); + assert_eq!(debug,"DataAccessToken { gen: 1, endpoint: \"test_endpoint\", iss: \"String\", claims: [TopicPermission { action: Subscribe, resource: Resource { resource_type: \"topic\", stream: \"test\", prefix: \"/tt\", topic: \"/test/#\" } }], exp: 2147483647, client_id: \"test_client\", iat: 2147483647, tenant_id: \"test_tenant\", raw_token: \"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywiaWF0IjoyMTQ3NDgzNjQ3LCJlbmRwb2ludCI6InRlc3RfZW5kcG9pbnQiLCJwb3J0cyI6eyJtcXR0cyI6Wzg4ODNdLCJtcXR0d3NzIjpbNDQzLDg0NDNdfSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImNsYWltcyI6W3siYWN0aW9uIjoic3Vic2NyaWJlIiwicmVzb3VyY2UiOnsidHlwZSI6InRvcGljIiwicHJlZml4IjoiL3R0Iiwic3RyZWFtIjoidGVzdCIsInRvcGljIjoiL3Rlc3QvIyJ9fV19\" }"); + let init_token = format!("{:?}", DataAccessToken::init()); + let debug = format!("{:?}", init_token); + assert_eq!(debug, "\"DataAccessToken { gen: 0, endpoint: \\\"\\\", iss: \\\"\\\", claims: [], exp: 0, client_id: \\\"\\\", iat: 0, tenant_id: \\\"\\\", raw_token: \\\"\\\" }\""); + } +} diff --git a/dsh_sdk/src/protocol_adapters/error.rs b/dsh_sdk/src/protocol_adapters/token/error.rs similarity index 71% rename from dsh_sdk/src/protocol_adapters/error.rs rename to dsh_sdk/src/protocol_adapters/token/error.rs index 4c2f797..96ab086 100644 --- a/dsh_sdk/src/protocol_adapters/error.rs +++ b/dsh_sdk/src/protocol_adapters/token/error.rs @@ -1,5 +1,5 @@ #[cfg(feature = "protocol-token-fetcher")] -/// Error type for the protocol adapter token fetcher +/// Error type for the protocol tokens #[derive(Debug, thiserror::Error)] pub enum ProtocolTokenError { #[error("Error calling: {url}, status code: {status_code}, error body: {error_body}")] @@ -8,12 +8,16 @@ pub enum ProtocolTokenError { status_code: reqwest::StatusCode, error_body: String, }, + #[error("JWT Parse error: {0}")] + Jwt(String), + #[error("Invalid client_id: {0} - Reason: {1}")] + InvalidClientId(String, &'static str), #[error("Reqwest: {0}")] Reqwest(#[from] reqwest::Error), #[error("Serde_json error: {0}")] Json(#[from] serde_json::Error), #[error("IO Error: {0}")] Io(#[from] std::io::Error), - #[error("JWT Parse error: {0}")] - Jwt(String), + #[error("Base64 decode error: {0}")] + Base64Decode(#[from] base64::DecodeError), } diff --git a/dsh_sdk/src/protocol_adapters/token/mod.rs b/dsh_sdk/src/protocol_adapters/token/mod.rs new file mode 100644 index 0000000..226c133 --- /dev/null +++ b/dsh_sdk/src/protocol_adapters/token/mod.rs @@ -0,0 +1,74 @@ +//! The [`RestToken`] and [`DataAccessToken`] which can be used to authenticate against the DSH platform. 
+pub mod api_client_token_fetcher;
+pub mod data_access_token;
+mod error;
+pub mod rest_token;
+
+#[doc(inline)]
+pub use data_access_token::{Action, DataAccessToken, RequestDataAccessToken, TopicPermission};
+#[doc(inline)]
+pub use error::ProtocolTokenError;
+#[doc(inline)]
+pub use rest_token::{Claims, DatastreamsMqttTokenClaim, RequestRestToken, RestToken};
+
+#[allow(dead_code)]
+#[derive(Debug, Clone)]
+struct JwtToken {
+    header: String,
+    payload: String,
+    signature: String,
+}
+
+impl JwtToken {
+    /// Extracts the header, payload and signature part of a JWT token.
+    ///
+    /// # Arguments
+    ///
+    /// * `raw_token` - The raw JWT token string.
+    ///
+    /// # Returns
+    ///
+    /// A Result containing the [JwtToken] or a [`ProtocolTokenError`].
+    fn parse(raw_token: &str) -> Result<Self, ProtocolTokenError> {
+        let parts: Vec<&str> = raw_token.split('.').collect();
+        if parts.len() != 3 {
+            return Err(ProtocolTokenError::Jwt(format!(
+                "Invalid JWT token {}",
+                raw_token
+            )));
+        }
+        Ok(JwtToken {
+            header: parts[0].to_string(),
+            payload: parts[1].to_string(),
+            signature: parts[2].to_string(),
+        })
+    }
+
+    fn b64_decode_payload(&self) -> Result<Vec<u8>, ProtocolTokenError> {
+        use base64::engine::general_purpose::STANDARD_NO_PAD;
+        use base64::Engine;
+        Ok(STANDARD_NO_PAD.decode(self.payload.as_bytes())?)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_jwt() {
+        let raw = "header.payload.signature";
+        let result = JwtToken::parse(raw).unwrap();
+        assert_eq!(result.header, "header");
+        assert_eq!(result.payload, "payload");
+        assert_eq!(result.signature, "signature");
+
+        let raw = "header.payload";
+        let result = JwtToken::parse(raw);
+        assert!(result.is_err());
+
+        let raw = "header";
+        let result = JwtToken::parse(raw);
+        assert!(result.is_err());
+    }
+}
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs
new file mode 100644
index 0000000..f151055
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs
@@ -0,0 +1,164 @@
+use serde::{Deserialize, Serialize};
+
+/// Represents the claims for the [`RestToken`](super::RestToken)
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+pub struct Claims {
+    // TODO: investigate if this is complete
+    #[serde(rename = "datastreams/v0/mqtt/token")]
+    mqtt_token_claim: DatastreamsMqttTokenClaim,
+}
+
+impl Default for Claims {
+    fn default() -> Self {
+        Self {
+            mqtt_token_claim: DatastreamsMqttTokenClaim::default(),
+        }
+    }
+}
+
+impl Claims {
+    pub fn set_mqtt_token_claim(mut self, claim: DatastreamsMqttTokenClaim) -> Self {
+        self.mqtt_token_claim = claim;
+        self
+    }
+
+    /// Returns the MQTT token claim
+    pub fn mqtt_token_claim(&self) -> &DatastreamsMqttTokenClaim {
+        &self.mqtt_token_claim
+    }
+}
+
+/// "datastreams/v0/mqtt/token" endpoint claim in the [`RestToken`](super::RestToken)
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+pub struct DatastreamsMqttTokenClaim {
+    /// External Client ID
+    #[serde(skip_serializing_if = "Option::is_none")]
+    id: Option<String>,
+    /// Tenant name
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tenant: Option<String>,
+    /// Maximum lifetime of the to-be-requested [`DataAccessToken`](crate::protocol_adapters::token::DataAccessToken), in seconds relative to the moment of request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    relexp: Option<i32>,
+    /// Requested expiration time (in seconds since UNIX epoch)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    exp: Option<i32>,
+    /// Optional additional claims for the to-be-requested token
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs
new file mode 100644
index 0000000..f151055
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/claims.rs
@@ -0,0 +1,164 @@
+use serde::{Deserialize, Serialize};
+
+/// Represents the claims for the [`RestToken`](super::RestToken)
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+pub struct Claims {
+    // TODO: investigate if this is complete
+    #[serde(rename = "datastreams/v0/mqtt/token")]
+    mqtt_token_claim: DatastreamsMqttTokenClaim,
+}
+
+impl Default for Claims {
+    fn default() -> Self {
+        Self {
+            mqtt_token_claim: DatastreamsMqttTokenClaim::default(),
+        }
+    }
+}
+
+impl Claims {
+    /// Sets the MQTT token claim
+    pub fn set_mqtt_token_claim(mut self, claim: DatastreamsMqttTokenClaim) -> Self {
+        self.mqtt_token_claim = claim;
+        self
+    }
+
+    /// Returns the MQTT token claim
+    pub fn mqtt_token_claim(&self) -> &DatastreamsMqttTokenClaim {
+        &self.mqtt_token_claim
+    }
+}
+
+/// "datastreams/v0/mqtt/token" endpoint claim in the [`RestToken`](super::RestToken)
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+pub struct DatastreamsMqttTokenClaim {
+    /// External Client ID
+    #[serde(skip_serializing_if = "Option::is_none")]
+    id: Option<String>,
+    /// Tenant name
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tenant: Option<String>,
+    /// Maximum lifetime in seconds of the to-be-requested [`DataAccessToken`](super::data_access_token::DataAccessToken)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    relexp: Option<i32>,
+    /// Requested expiration time (in seconds since UNIX epoch)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    exp: Option<i32>,
+    /// Additional claims
+    #[serde(skip_serializing_if = "Option::is_none")]
+    claims: Option<Vec<String>>, // TODO: investigate which claims are possible
+}
+
+impl DatastreamsMqttTokenClaim {
+    /// Creates a new `DatastreamsMqttTokenClaim` instance.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets the external client ID for which the [`RestToken`](super::RestToken) is requested
+    pub fn set_id(mut self, id: impl Into<String>) -> Self {
+        self.id = Some(id.into());
+        self
+    }
+
+    /// Returns the external client ID for which the [`RestToken`](super::RestToken) is requested
+    pub fn id(&self) -> Option<&str> {
+        self.id.as_deref()
+    }
+
+    /// Sets the tenant name
+    pub fn set_tenant(mut self, tenant: impl Into<String>) -> Self {
+        self.tenant = Some(tenant.into());
+        self
+    }
+
+    /// Returns the tenant name
+    pub fn tenant(&self) -> Option<&str> {
+        self.tenant.as_deref()
+    }
+
+    /// Sets the maximum lifetime in seconds of the to-be-requested [`DataAccessToken`](crate::protocol_adapters::token::DataAccessToken) (relative to the moment of request)
+    pub fn set_relexp(mut self, relexp: i32) -> Self {
+        self.relexp = Some(relexp);
+        self
+    }
+
+    /// Returns the maximum lifetime in seconds (relative to the moment of request)
+    pub fn relexp(&self) -> Option<i32> {
+        self.relexp
+    }
+
+    /// Sets the requested expiration time (in seconds since UNIX epoch)
+    pub fn set_exp(mut self, exp: i32) -> Self {
+        self.exp = Some(exp);
+        self
+    }
+
+    /// Returns the requested expiration time (in seconds since UNIX epoch)
+    pub fn exp(&self) -> Option<i32> {
+        self.exp
+    }
+}
+
+impl Default for DatastreamsMqttTokenClaim {
+    fn default() -> Self {
+        Self {
+            id: None,
+            tenant: None,
+            relexp: None,
+            exp: None,
+            claims: None,
+        }
+    }
+}
+
+impl From<DatastreamsMqttTokenClaim> for Claims {
+    fn from(claim: DatastreamsMqttTokenClaim) -> Self {
+        Self {
+            mqtt_token_claim: claim,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_datastreams_mqtt_token_claim() {
+        let claim = DatastreamsMqttTokenClaim::new();
+        assert_eq!(claim.id(), None);
+        assert_eq!(claim.tenant(), None);
+        assert_eq!(claim.relexp(), None);
+        assert_eq!(claim.exp(), None);
+        let claim = claim
+            .set_id("test-id")
+            .set_tenant("test-tenant")
+            .set_relexp(100)
+            .set_exp(200);
+        assert_eq!(claim.id(), Some("test-id"));
+        assert_eq!(claim.tenant(), Some("test-tenant"));
+        assert_eq!(claim.relexp(), Some(100));
+        assert_eq!(claim.exp(), Some(200));
+    }
+
+    #[test]
+    fn test_from_mqtt_claims() {
+        let claim = DatastreamsMqttTokenClaim::new();
+        let claims: Claims = claim.clone().into();
+        assert_eq!(claims.mqtt_token_claim(), &claim);
+        let claim = claim
+            .set_id("test-id")
+            .set_tenant("test-tenant")
+            .set_relexp(100)
+            .set_exp(200);
+        let claims: Claims = claim.clone().into();
+        assert_eq!(claims.mqtt_token_claim(), &claim);
+    }
+
+    #[test]
+    fn test_claims() {
+        let claim = DatastreamsMqttTokenClaim::new();
+        let claims = Claims::default();
+        assert_eq!(claims.mqtt_token_claim(), &claim);
+    }
+}
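As a quick illustration (not part of this patch): the JSON shape these claim structs serialize to. This is a minimal sketch using a trimmed-down re-declaration of the structs above, assuming only `serde` (with the `derive` feature) and `serde_json`; `None` fields are omitted via `skip_serializing_if`, and the claim nests under the `"datastreams/v0/mqtt/token"` key.

```rust
// Editorial sketch: trimmed-down versions of `Claims` and `DatastreamsMqttTokenClaim`
// to show the serialized wire format.
use serde::Serialize;

#[derive(Serialize)]
struct Claims {
    #[serde(rename = "datastreams/v0/mqtt/token")]
    mqtt_token_claim: DatastreamsMqttTokenClaim,
}

#[derive(Serialize, Default)]
struct DatastreamsMqttTokenClaim {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    relexp: Option<i32>,
}

fn main() {
    let claims = Claims {
        mqtt_token_claim: DatastreamsMqttTokenClaim {
            id: Some("external-client-id".to_string()),
            relexp: Some(300),
        },
    };
    // Prints: {"datastreams/v0/mqtt/token":{"id":"external-client-id","relexp":300}}
    println!("{}", serde_json::to_string(&claims).unwrap());
}
```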
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/mod.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/mod.rs
new file mode 100644
index 0000000..a9c96ec
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/mod.rs
@@ -0,0 +1,11 @@
+//! Rest token to be used for fetching a [`DataAccessToken`](crate::protocol_adapters::token::data_access_token::DataAccessToken)
+mod claims;
+mod request;
+mod token;
+
+#[doc(inline)]
+pub use claims::{Claims, DatastreamsMqttTokenClaim};
+#[doc(inline)]
+pub use request::RequestRestToken;
+#[doc(inline)]
+pub use token::RestToken;
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs
new file mode 100644
index 0000000..1b32846
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs
@@ -0,0 +1,166 @@
+use serde::{Deserialize, Serialize};
+
+use super::claims::Claims;
+use super::token::RestToken;
+use crate::protocol_adapters::token::ProtocolTokenError;
+
+/// Request for getting a [`RestToken`] which can be used to get a [`DataAccessToken`](crate::protocol_adapters::token::data_access_token::DataAccessToken).
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(rename_all = "kebab-case")]
+pub struct RequestRestToken {
+    /// Tenant name
+    tenant: String,
+    /// Requested expiration time (in seconds since UNIX epoch)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    exp: Option<i64>,
+    /// Requested claims and permissions that the [`DataAccessToken`](crate::protocol_adapters::token::data_access_token::DataAccessToken) should have
+    #[serde(skip_serializing_if = "Option::is_none")]
+    claims: Option<Claims>,
+}
+
+impl RequestRestToken {
+    /// Creates a new [`RequestRestToken`] instance with a full access request.
+    pub fn new(tenant: impl Into<String>) -> Self {
+        Self {
+            tenant: tenant.into(),
+            exp: None,
+            claims: None,
+        }
+    }
+
+    /// Sends the request to the DSH platform to get a [`RestToken`].
+    ///
+    /// # Arguments
+    /// - `client` - The reqwest client to use for the request.
+    /// - `api_key` - The API key to authenticate to the DSH platform.
+    /// - `auth_url` - The URL of the DSH platform to send the request to (see [Platform::endpoint_protocol_rest_token](crate::Platform::endpoint_protocol_rest_token)).
+    ///
+    /// # Returns
+    /// The [`RestToken`] if the request was successful.
+    /// Otherwise a [`ProtocolTokenError`] is returned.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use dsh_sdk::protocol_adapters::token::RequestRestToken;
+    /// use dsh_sdk::Platform;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let request = RequestRestToken::new("example_tenant");
+    /// let client = reqwest::Client::new();
+    /// let platform = Platform::NpLz;
+    /// let token = request.send(&client, "API_KEY", platform.endpoint_protocol_rest_token()).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn send(
+        &self,
+        client: &reqwest::Client,
+        api_key: &str,
+        auth_url: &str,
+    ) -> Result<RestToken, ProtocolTokenError> {
+        log::debug!("Sending request to '{}': {:?}", auth_url, self);
+        let response = client
+            .post(auth_url)
+            .header("apikey", api_key)
+            .json(self)
+            .send()
+            .await?;
+
+        let status = response.status();
+        let body_text = response.text().await?;
+        match status {
+            reqwest::StatusCode::OK => Ok(RestToken::parse(body_text)?),
+            _ => Err(ProtocolTokenError::DshCall {
+                url: auth_url.to_string(),
+                status_code: status,
+                error_body: body_text,
+            }),
+        }
+    }
+
+    /// Returns the tenant
+    pub fn tenant(&self) -> &str {
+        &self.tenant
+    }
+
+    /// Sets the expiration time (in seconds since UNIX epoch)
+    pub fn set_exp(mut self, exp: i64) -> Self {
+        self.exp = Some(exp);
+        self
+    }
+
+    /// Returns the expiration time (in seconds since UNIX epoch)
+    pub fn exp(&self) -> Option<i64> {
+        self.exp
+    }
+
+    /// Sets the claims
+    pub fn set_claims(mut self, claims: impl Into<Claims>) -> Self {
+        self.claims = Some(claims.into());
+        self
+    }
+
+    /// Returns the claims
+    pub fn claims(&self) -> Option<&Claims> {
+        self.claims.as_ref()
+    }
+
+    /// Returns the client_id if it is set in claims
+    pub fn client_id(&self) -> Option<&str> {
+        self.claims.as_ref().and_then(|c| c.mqtt_token_claim().id())
+    }
+}
+
+impl PartialEq for RequestRestToken {
+    fn eq(&self, other: &Self) -> bool {
+        // Ignore the requested expiration time; it is not relevant for equality
+        self.tenant == other.tenant && self.claims == other.claims
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocol_adapters::token::rest_token::DatastreamsMqttTokenClaim;
+    use mockito::Matcher;
+    use serde_json::json;
+
+    #[test]
+    fn test_rest_token_request() {
+        let request = RequestRestToken::new("test-tenant");
+        assert_eq!(request.tenant(), "test-tenant");
+        assert_eq!(request.exp(), None);
+        assert_eq!(request.claims(), None);
+        let claims: Claims = DatastreamsMqttTokenClaim::new().set_exp(1).into();
+        let request = request.set_exp(100).set_claims(claims.clone());
+        assert_eq!(request.exp(), Some(100));
+        assert_eq!(request.claims(), Some(&claims));
+    }
+
+    #[tokio::test]
+    async fn test_send_success() {
+        let mut mockito_server = mockito::Server::new_async().await;
+        let _m = mockito_server
+            .mock("POST", "/protocol_auth_url")
+            .match_header("apikey", "test_token")
+            .match_body(Matcher::Json(json!({"tenant": "test_tenant"})))
+            .with_status(200)
+            .with_body("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MjE0NzQ4MzY0NywidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6InRlc3RfZW5wb2ludCIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19.Eh2-UBOgame_cQw5iHjc19-hRZXAPxMYlCHVCwcE8CU")
+            .create();
+
+        let client = reqwest::Client::new();
+        let request = RequestRestToken::new("test_tenant");
+        let result = request
+            .send(
+                &client,
+                "test_token",
+                &format!("{}/protocol_auth_url", mockito_server.url()),
+            )
+            .await;
+
+        assert!(result.is_ok());
+        let token = result.unwrap();
+        assert!(token.is_valid());
+    }
+}
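As a quick illustration (not part of this patch): what failure handling around `RequestRestToken::send` can look like. This is a minimal sketch assuming dsh_sdk 0.5 with the `protocol-token` feature enabled; `"API_KEY"` and the tenant name are placeholders.

```rust
// Editorial sketch: `ProtocolTokenError::DshCall` carries the URL, HTTP status and
// response body of a failed call, which makes failures easy to report.
use dsh_sdk::protocol_adapters::token::{ProtocolTokenError, RequestRestToken};
use dsh_sdk::Platform;

#[tokio::main]
async fn main() {
    let client = reqwest::Client::new();
    let request = RequestRestToken::new("example_tenant");
    let platform = Platform::NpLz;

    match request
        .send(&client, "API_KEY", platform.endpoint_protocol_rest_token())
        .await
    {
        Ok(token) => println!("token valid until {} (unix)", token.exp()),
        Err(ProtocolTokenError::DshCall {
            url,
            status_code,
            error_body,
        }) => eprintln!("DSH rejected the request to {url}: {status_code} - {error_body}"),
        Err(e) => eprintln!("failed to fetch RestToken: {e}"),
    }
}
```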
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/token.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/token.rs
new file mode 100644
index 0000000..8d09709
--- /dev/null
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/token.rs
@@ -0,0 +1,181 @@
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use serde::{Deserialize, Serialize};
+
+use super::claims::Claims;
+use crate::protocol_adapters::token::{JwtToken, ProtocolTokenError};
+
+/// Token to request a [`DataAccessToken`](crate::protocol_adapters::token::data_access_token::DataAccessToken).
+#[derive(Serialize, Deserialize, Clone)]
+#[serde(rename_all = "kebab-case")]
+pub struct RestToken {
+    gen: i64,
+    endpoint: String,
+    iss: String,
+    claims: Claims,
+    exp: i64,
+    tenant_id: String,
+    #[serde(skip)]
+    raw_token: String,
+}
+
+impl RestToken {
+    /// Creates a new [`RestToken`] instance based on a JWT token.
+    pub fn parse(raw_token: impl Into<String>) -> Result<Self, ProtocolTokenError> {
+        let raw_token = raw_token.into();
+        let jwt_token = JwtToken::parse(&raw_token)?;
+
+        let mut token: Self = serde_json::from_slice(&jwt_token.b64_decode_payload()?)?;
+        token.raw_token = raw_token;
+
+        Ok(token)
+    }
+
+    pub(crate) fn init() -> Self {
+        Self {
+            gen: 0,
+            endpoint: "".to_string(),
+            iss: "".to_string(),
+            claims: Claims::default(),
+            exp: 0,
+            tenant_id: "".to_string(),
+            raw_token: "".to_string(),
+        }
+    }
+
+    /// Returns the `gen` value of the token
+    pub fn gen(&self) -> i64 {
+        self.gen
+    }
+
+    /// Returns the endpoint which the MQTT client should connect to
+    pub fn endpoint(&self) -> &str {
+        &self.endpoint
+    }
+
+    /// Returns the iss
+    pub fn iss(&self) -> &str {
+        &self.iss
+    }
+
+    /// Returns the claims
+    ///
+    /// The claims are (optional) API endpoints and restrictions that the token can access.
+    /// If no claims are present, the token will have full access to all endpoints.
+    pub fn claims(&self) -> &Claims {
+        &self.claims
+    }
+
+    /// Returns the client_id if it is set in claims
+    pub fn client_id(&self) -> Option<&str> {
+        self.claims.mqtt_token_claim().id()
+    }
+
+    /// Returns the expiration time (in seconds since UNIX epoch)
+    pub fn exp(&self) -> i64 {
+        self.exp
+    }
+
+    /// Returns the tenant id
+    pub fn tenant_id(&self) -> &str {
+        &self.tenant_id
+    }
+
+    /// Returns the raw token
+    pub fn raw_token(&self) -> &str {
+        &self.raw_token
+    }
+
+    /// Checks if the token is valid
+    pub fn is_valid(&self) -> bool {
+        let current_unixtime = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("SystemTime before UNIX EPOCH!")
+            .as_secs() as i64;
+        self.exp >= current_unixtime + 5 && !self.raw_token.is_empty()
+    }
+}
+
+impl std::fmt::Debug for RestToken {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RestToken")
+            .field("gen", &self.gen)
+            .field("endpoint", &self.endpoint)
+            .field("iss", &self.iss)
+            .field("claims", &self.claims)
+            .field("exp", &self.exp)
+            .field("tenant_id", &self.tenant_id)
+            .field(
+                "raw_token",
+                &self
+                    .raw_token
+                    .split('.')
+                    .take(2)
+                    .collect::<Vec<_>>()
+                    .join("."),
+            )
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_rest_token() {
+        let raw_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJzdHJpbmciLCJnZW4iOjEsImV4cCI6MTczOTU0Nzg3OCwidGVuYW50LWlkIjoiZm9vIiwiZW5kcG9pbnQiOiJ0ZXN0X2VuZHBvaW50IiwiY2xhaW1zIjp7ImRhdGFzdHJlYW1zL3YwL21xdHQvdG9rZW4iOnsiaWQiOiJqdXN0LXRoaXMtZGV2aWNlIiwiZXhwIjoxNzM5NTQ3ODc4LCJ0ZW5hbnQiOiJmb28iLCJjbGFpbXMiOltdfX19.signature";
+        let token = RestToken::parse(raw_token.to_string()).unwrap();
+        assert_eq!(token.gen(), 1);
+        assert_eq!(token.endpoint(), "test_endpoint");
+        assert_eq!(token.iss(), "string");
+        assert_eq!(
+            token.claims().mqtt_token_claim().id(),
+            Some("just-this-device")
+        );
+        assert_eq!(token.claims().mqtt_token_claim().exp(), Some(1739547878));
+        assert_eq!(token.claims().mqtt_token_claim().tenant(), Some("foo"));
+        assert_eq!(token.exp(), 1739547878);
+        assert_eq!(token.tenant_id(), "foo");
+        assert_eq!(token.raw_token(), raw_token);
+    }
+
+    #[test]
+    fn test_init_rest_token() {
+        let token = RestToken::init();
+        assert_eq!(token.gen(), 0);
+        assert_eq!(token.endpoint(), "");
+        assert_eq!(token.iss(), "");
+        assert_eq!(token.claims().mqtt_token_claim().id(), None);
+        assert_eq!(token.claims().mqtt_token_claim().exp(), None);
+        assert_eq!(token.claims().mqtt_token_claim().tenant(), None);
+        assert_eq!(token.exp(), 0);
+        assert_eq!(token.tenant_id(), "");
+        assert_eq!(token.raw_token(), "");
+    }
+
+    #[test]
+    fn test_is_valid() {
+        let mut token = RestToken::init();
+        assert!(!token.is_valid());
+        token.exp = 0;
+        assert!(!token.is_valid());
+        token.raw_token = "test".to_string();
+        assert!(!token.is_valid());
+        token.exp = 2147483647;
+        assert!(token.is_valid());
+    }
+
+    #[test]
+    fn test_debug_rest_token() {
+        let raw_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6Imh0dHA6Ly8xMjcuMC4wLjE6Nzk5OSIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19.j5ekqMiWyBhJyRQE_aARFS9mQJiN7S2rpKTsn3rZ5lQ";
+        let token = RestToken::parse(raw_token).unwrap();
+        assert_eq!(
+            format!("{:?}", token),
+            "RestToken { gen: 1, endpoint: \"http://127.0.0.1:7999\", iss: \"String\", claims: Claims { mqtt_token_claim: DatastreamsMqttTokenClaim { id: None, tenant: None, relexp: None, exp: None, claims: None } }, exp: 1, tenant_id: \"test_tenant\", raw_token: \"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJTdHJpbmciLCJnZW4iOjEsImV4cCI6MSwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQiLCJlbmRwb2ludCI6Imh0dHA6Ly8xMjcuMC4wLjE6Nzk5OSIsImNsYWltcyI6eyJkYXRhc3RyZWFtcy92MC9tcXR0L3Rva2VuIjp7fX19\" }"
+        );
+        let token = RestToken::init();
+        assert_eq!(
+            format!("{:?}", token),
+            "RestToken { gen: 0, endpoint: \"\", iss: \"\", claims: Claims { mqtt_token_claim: DatastreamsMqttTokenClaim { id: None, tenant: None, relexp: None, exp: None, claims: None } }, exp: 0, tenant_id: \"\", raw_token: \"\" }"
+        );
+    }
+}
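As a quick illustration (not part of this patch): a simple reuse pattern on top of `RestToken::is_valid`, refreshing only when the cached token is empty or (almost) expired. This is a minimal sketch assuming dsh_sdk 0.5 with the `protocol-token` feature; `"API_KEY"` and the tenant name are placeholders.

```rust
// Editorial sketch: `is_valid` requires a non-empty raw token and at least
// ~5 seconds of remaining lifetime, so a cached token can be reused safely.
use dsh_sdk::protocol_adapters::token::{ProtocolTokenError, RequestRestToken, RestToken};
use dsh_sdk::Platform;

async fn valid_rest_token(
    cached: Option<RestToken>,
    client: &reqwest::Client,
) -> Result<RestToken, ProtocolTokenError> {
    match cached {
        // Reuse the cached token while it is still valid.
        Some(token) if token.is_valid() => Ok(token),
        // Otherwise fetch a fresh one from the platform.
        _ => {
            let platform = Platform::NpLz;
            RequestRestToken::new("example_tenant")
                .send(client, "API_KEY", platform.endpoint_protocol_rest_token())
                .await
        }
    }
}
```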
diff --git a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs b/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
deleted file mode 100644
index 818f5c4..0000000
--- a/dsh_sdk/src/protocol_adapters/token_fetcher/mod.rs
+++ /dev/null
@@ -1,869 +0,0 @@
-//! Protocol Token Fetcher
-//!
-//! `ProtocolTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH.
-use std::collections::{hash_map::Entry, HashMap};
-use std::fmt::{Display, Formatter};
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
-
-use serde::{Deserialize, Serialize};
-use serde_json::json;
-use sha2::{Digest, Sha256};
-use tokio::sync::RwLock;
-
-use super::ProtocolTokenError;
-use crate::Platform;
-
-/// `ProtocolTokenFetcher` is responsible for fetching and managing tokens for the DSH Mqtt and Http protocol adapters.
-///
-/// It ensures that the tokens are valid, and if not, it refreshes them automatically. The struct
-/// is thread-safe and can be shared across multiple threads.
-
-pub struct ProtocolTokenFetcher {
-    tenant_name: String,
-    rest_api_key: String,
-    rest_token: RwLock<RestToken>,
-    rest_auth_url: String,
-    protocol_token: RwLock<HashMap<String, ProtocolToken>>, // Mapping from Client ID to ProtocolToken
-    protocol_auth_url: String,
-    client: reqwest::Client,
-    //token_lifetime: Option<Duration>, // TODO: Implement option of passing token lifetime to request token for specific duration
-    // port: Port or connection_type: Connection // TODO: Platform provides two connection options, current implemetation only provides connecting over SSL, enable WebSocket too
-}
-
-/// Constructs a new `ProtocolTokenFetcher`.
-///
-/// # Arguments
-///
-/// * `tenant_name` - The tenant name in DSH.
-/// * `rest_api_key` - The REST API key used for authentication.
-/// * `platform` - The DSH platform environment
-///
-/// # Returns
-///
-/// Returns a `Result` containing a `ProtocolTokenFetcher` instance or a `ProtocolTokenError`.
-impl ProtocolTokenFetcher {
-    /// Constructs a new `ProtocolTokenFetcher`.
-    ///
-    /// # Arguments
-    ///
-    /// * `tenant_name` - The tenant name of DSH.
-    /// * `api_key` - The realted API key of tenant used for authentication to fetech Token for MQTT.
-    /// * `platform` - The target DSH platform environment.
- /// - /// # Example - /// - /// ```no_run - /// use dsh_sdk::protocol_adapters::ProtocolTokenFetcher; - /// use dsh_sdk::Platform; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let tenant_name = "test_tenant".to_string(); - /// let api_key = "aAbB123".to_string(); - /// let platform = Platform::NpLz; - /// - /// let fetcher = ProtocolTokenFetcher::new(tenant_name, api_key, platform); - /// let token = fetcher.get_token("test_client", None).await.unwrap(); - /// # } - /// ``` - pub fn new(tenant_name: String, api_key: String, platform: Platform) -> Self { - const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); - - let reqwest_client = reqwest::Client::builder() - .timeout(DEFAULT_TIMEOUT) - .http1_only() - .build() - .expect("Failed to build reqwest client"); - Self::new_with_client(tenant_name, api_key, platform, reqwest_client) - } - - /// Constructs a new `ProtocolTokenFetcher` with a custom reqwest client. - /// On this Reqwest client, you can set custom timeouts, headers, Rustls etc. - /// - /// # Arguments - /// - /// * `tenant_name` - The tenant name of DSH. - /// * `api_key` - The realted API key of tenant used for authentication to fetech Token for MQTT. - /// * `platform` - The target DSH platform environment. - /// * `client` - User configured reqwest client to be used for fetching tokens - /// - /// # Example - /// - /// ```no_run - /// use dsh_sdk::protocol_adapters::ProtocolTokenFetcher; - /// use dsh_sdk::Platform; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let tenant_name = "test_tenant".to_string(); - /// let api_key = "aAbB123".to_string(); - /// let platform = Platform::NpLz; - /// let client = reqwest::Client::new(); - /// let fetcher = ProtocolTokenFetcher::new_with_client(tenant_name, api_key, platform, client); - /// let token = fetcher.get_token("test_client", None).await.unwrap(); - /// # } - /// ``` - pub fn new_with_client( - tenant_name: String, - api_key: String, - platform: Platform, - client: reqwest::Client, - ) -> Self { - let rest_token = RestToken::default(); - Self { - tenant_name, - rest_api_key: api_key, - rest_token: RwLock::new(rest_token), - rest_auth_url: platform.endpoint_rest_token().to_string(), - protocol_token: RwLock::new(HashMap::new()), - protocol_auth_url: platform.endpoint_protocol_token().to_string(), - client, - } - } - /// Retrieves an MQTT token for the specified client ID. - /// - /// If the token is expired or does not exist, it fetches a new token. - /// - /// # Arguments - /// - /// * `client_id` - The identifier for the MQTT client. - /// * `claims` - Optional claims for the MQTT token. - /// - /// # Returns - /// - /// Returns a `Result` containing the `ProtocolToken` or a `ProtocolTokenError`. - pub async fn get_token( - &self, - client_id: &str, - claims: Option>, - ) -> Result { - match self - .protocol_token - .write() - .await - .entry(client_id.to_string()) - { - Entry::Occupied(mut entry) => { - let protocol_token = entry.get_mut(); - if !protocol_token.is_valid() { - *protocol_token = self.fetch_new_protocol_token(client_id, claims).await?; - }; - Ok(protocol_token.clone()) - } - Entry::Vacant(entry) => { - let protocol_token = self.fetch_new_protocol_token(client_id, claims).await?; - entry.insert(protocol_token.clone()); - Ok(protocol_token) - } - } - } - - /// Fetches a new MQTT token from the platform. 
- /// - /// This method handles token validation and fetching the token - async fn fetch_new_protocol_token( - &self, - client_id: &str, - claims: Option>, - ) -> Result { - let mut rest_token = self.rest_token.write().await; - - if !rest_token.is_valid() { - *rest_token = RestToken::get( - &self.client, - &self.tenant_name, - &self.rest_api_key, - &self.rest_auth_url, - ) - .await? - } - - let authorization_header = format!("Bearer {}", rest_token.raw_token); - - let protocol_token_request = - ProtocolTokenRequest::new(client_id, &self.tenant_name, claims)?; - let payload = serde_json::to_value(&protocol_token_request)?; - - let response = protocol_token_request - .send( - &self.client, - &self.protocol_auth_url, - &authorization_header, - &payload, - ) - .await?; - - ProtocolToken::new(response) - } -} - -/// Represent Claims information for MQTT request -/// * `action` - can be subscribe or publish -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Claims { - resource: Resource, - action: String, -} - -impl Claims { - pub fn new(resource: Resource, action: Actions) -> Claims { - Claims { - resource, - action: action.to_string(), - } - } -} - -/// Enumeration representing possible actions in MQTT claims. -pub enum Actions { - Publish, - Subscribe, -} - -impl Display for Actions { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match self { - Actions::Publish => write!(f, "publish"), - Actions::Subscribe => write!(f, "subscribe"), - } - } -} - -/// Represents a resource in the MQTT claim. -/// -/// The resource defines what the client can access in terms of stream, prefix, topic, and type. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Resource { - stream: String, - prefix: String, - topic: String, - #[serde(rename = "type")] - type_: Option, -} - -impl Resource { - /// Creates a new `Resource` instance. Please check DSH MQTT Documentation for further explanation of the fields. - /// - /// # Arguments - /// - /// * `stream` - The data stream name. - /// * `prefix` - The prefix of the topic. - /// * `topic` - The topic name. - /// * `type_` - The optional type of the resource. - /// - /// - /// # Returns - /// - /// Returns a new `Resource` instance. - pub fn new(stream: String, prefix: String, topic: String, type_: Option) -> Resource { - Resource { - stream, - prefix, - topic, - type_, - } - } -} - -#[derive(Serialize)] -struct ProtocolTokenRequest { - id: String, - tenant: String, - claims: Option>, -} - -impl ProtocolTokenRequest { - fn new( - client_id: &str, - tenant: &str, - claims: Option>, - ) -> Result { - let mut hasher = Sha256::new(); - hasher.update(client_id); - let result = hasher.finalize(); - let id = format!("{:x}", result); - - Ok(Self { - id, - tenant: tenant.to_string(), - claims, - }) - } - - async fn send( - &self, - reqwest_client: &reqwest::Client, - protocol_auth_url: &str, - authorization_header: &str, - payload: &serde_json::Value, - ) -> Result { - let response = reqwest_client - .post(protocol_auth_url) - .header("Authorization", authorization_header) - .json(payload) - .send() - .await?; - - if response.status().is_success() { - Ok(response.text().await?) - } else { - Err(ProtocolTokenError::DshCall { - url: protocol_auth_url.to_string(), - status_code: response.status(), - error_body: response.text().await?, - }) - } - } -} - -/// Represents attributes associated with a mqtt token. 
-#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "kebab-case")] -struct ProtocolTokenAttributes { - gen: i32, - endpoint: String, - iss: String, - claims: Option>, - exp: i32, - client_id: String, - iat: i32, - tenant_id: String, -} - -/// Represents a token used for MQTT connections. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ProtocolToken { - exp: i32, - raw_token: String, -} - -impl ProtocolToken { - /// Creates a new instance of `ProtocolToken` from a raw token string. - /// - /// # Arguments - /// - /// * `raw_token` - The raw token string. - /// - /// # Returns - /// - /// A Result containing the created ProtocolToken or an error. - pub fn new(raw_token: String) -> Result { - let header_payload = extract_header_and_payload(&raw_token)?; - - let decoded_token = decode_base64(header_payload)?; - - let token_attributes: ProtocolTokenAttributes = serde_json::from_slice(&decoded_token)?; - let token = ProtocolToken { - exp: token_attributes.exp, - raw_token, - }; - - Ok(token) - } - - /// Checks if the MQTT token is still valid. - fn is_valid(&self) -> bool { - let current_unixtime = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - .as_secs() as i32; - self.exp >= current_unixtime + 5 - } -} - -/// Represents attributes associated with a Rest token. -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "kebab-case")] -struct RestTokenAttributes { - gen: i64, - endpoint: String, - iss: String, - claims: RestClaims, - exp: i32, - tenant_id: String, -} - -#[derive(Serialize, Deserialize, Debug)] -struct RestClaims { - #[serde(rename = "datastreams/v0/mqtt/token")] - datastreams_token: DatastreamsData, -} - -#[derive(Serialize, Deserialize, Debug)] -struct DatastreamsData {} - -/// Represents a rest token with its raw value and attributes. -#[derive(Serialize, Deserialize, Debug)] -struct RestToken { - raw_token: String, - exp: i32, -} - -impl RestToken { - /// Retrieves a new REST token from the platform. - /// - /// # Arguments - /// - /// * `tenant` - The tenant name associated with the DSH platform. - /// * `api_key` - The REST API key used for authentication. - /// * `env` - The platform environment (e.g., production, staging). - /// - /// # Returns - /// - /// A Result containing the created `RestToken` or a `ProtocolTokenError`. - async fn get( - client: &reqwest::Client, - tenant: &str, - api_key: &str, - auth_url: &str, - ) -> Result { - let raw_token = Self::fetch_token(client, tenant, api_key, auth_url).await?; - - let header_payload = extract_header_and_payload(&raw_token)?; - - let decoded_token = decode_base64(header_payload)?; - - let token_attributes: RestTokenAttributes = serde_json::from_slice(&decoded_token)?; - let token = RestToken { - raw_token, - exp: token_attributes.exp, - }; - - Ok(token) - } - - // Checks if the REST token is still valid. 
- fn is_valid(&self) -> bool { - let current_unixtime = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - .as_secs() as i32; - self.exp >= current_unixtime + 5 - } - - async fn fetch_token( - client: &reqwest::Client, - tenant: &str, - api_key: &str, - auth_url: &str, - ) -> Result { - let json_body = json!({"tenant": tenant}); - - let response = client - .post(auth_url) - .header("apikey", api_key) - .json(&json_body) - .send() - .await?; - - let status = response.status(); - let body_text = response.text().await?; - match status { - reqwest::StatusCode::OK => Ok(body_text), - _ => Err(ProtocolTokenError::DshCall { - url: auth_url.to_string(), - status_code: status, - error_body: body_text, - }), - } - } -} - -impl Default for RestToken { - fn default() -> Self { - Self { - raw_token: "".to_string(), - exp: 0, - } - } -} - -/// Extracts the header and payload part of a JWT token. -/// -/// # Arguments -/// -/// * `raw_token` - The raw JWT token string. -/// -/// # Returns -/// -/// A Result containing the header and payload part of the JWT token or a `ProtocolTokenError`. -fn extract_header_and_payload(raw_token: &str) -> Result<&str, ProtocolTokenError> { - let parts: Vec<&str> = raw_token.split('.').collect(); - parts.get(1).copied().ok_or_else(|| { - ProtocolTokenError::Jwt("Cannot extract header and payload from raw_token".to_string()) - }) -} - -/// Decodes a Base64-encoded string. -/// -/// # Arguments -/// -/// * `payload` - The Base64-encoded string. -/// -/// # Returns -/// -/// A Result containing the decoded byte vector or a `ProtocolTokenError`. -fn decode_base64(payload: &str) -> Result, ProtocolTokenError> { - use base64::{alphabet, engine, read}; - use std::io::Read; - - let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::NO_PAD); - let mut decoder = read::DecoderReader::new(payload.as_bytes(), &engine); - - let mut decoded_token = Vec::new(); - decoder.read_to_end(&mut decoded_token)?; - - Ok(decoded_token) -} - -#[cfg(test)] -mod tests { - use super::*; - use mockito::Matcher; - - async fn create_valid_fetcher() -> ProtocolTokenFetcher { - let exp_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600; - println!("exp_time: {}", exp_time); - let rest_token: RestToken = RestToken { - exp: exp_time as i32, - raw_token: "valid.token.payload".to_string(), - }; - let protocol_token = ProtocolToken { - exp: exp_time, - raw_token: "valid.token.payload".to_string(), - }; - let protocol_token_map = RwLock::new(HashMap::new()); - protocol_token_map - .write() - .await - .insert("test_client".to_string(), protocol_token.clone()); - ProtocolTokenFetcher { - tenant_name: "test_tenant".to_string(), - rest_api_key: "test_api_key".to_string(), - rest_token: RwLock::new(rest_token), - rest_auth_url: "test_auth_url".to_string(), - protocol_token: protocol_token_map, - client: reqwest::Client::new(), - protocol_auth_url: "test_auth_url".to_string(), - } - } - - #[tokio::test] - async fn test_protocol_token_fetcher_new() { - let tenant_name = "test_tenant".to_string(); - let rest_api_key = "test_api_key".to_string(); - let platform = Platform::NpLz; - - let fetcher = ProtocolTokenFetcher::new(tenant_name, rest_api_key, platform); - - assert!(fetcher.protocol_token.read().await.is_empty()); - } - - #[tokio::test] - async fn test_protocol_token_fetcher_new_with_client() { - let tenant_name = "test_tenant".to_string(); - let rest_api_key = "test_api_key".to_string(); - let 
platform = Platform::NpLz; - - let client = reqwest::Client::builder().use_rustls_tls().build().unwrap(); - let fetcher = - ProtocolTokenFetcher::new_with_client(tenant_name, rest_api_key, platform, client); - - assert!(fetcher.protocol_token.read().await.is_empty()); - } - - #[tokio::test] - async fn test_fetch_new_protocol_token() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server.mock("POST", "/rest_auth_url") - .with_status(200) - .with_body(r#"{"raw_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImNsYWltcyI6W3sicmVzb3VyY2UiOiJ0ZXN0IiwiYWN0aW9uIjoicHVzaCJ9XSwiZXhwIjoxLCJjbGllbnQtaWQiOiJ0ZXN0X2NsaWVudCIsImlhdCI6MCwidGVuYW50LWlkIjoidGVzdF90ZW5hbnQifQ.WCf03qyxV1NwxXpzTYF7SyJYwB3uAkQZ7u-TVrDRJgE"}"#) - .create_async() - .await; - let _m2 = mockito_server.mock("POST", "/protocol_auth_url") - .with_status(200) - .with_body(r#"{"protocol_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJnZW4iOjEsImVuZHBvaW50IjoidGVzdF9lbmRwb2ludCIsImlzcyI6IlN0cmluZyIsImV4cCI6MSwiY2xpZW50LWlkIjoidGVzdF9jbGllbnQiLCJpYXQiOjAsInRlbmFudC1pZCI6InRlc3RfdGVuYW50In0.VwlKomR4OnLtLX-NwI-Fpol8b6t-kmptRS_vPnwNd3A"}"#) - .create(); - - let client = reqwest::Client::new(); - let rest_token = RestToken { - raw_token: "initial_token".to_string(), - exp: 0, - }; - - let fetcher = ProtocolTokenFetcher { - client, - tenant_name: "test_tenant".to_string(), - rest_api_key: "test_api_key".to_string(), - protocol_token: RwLock::new(HashMap::new()), - rest_auth_url: mockito_server.url() + "/rest_auth_url", - protocol_auth_url: mockito_server.url() + "/protocol_auth_url", - rest_token: RwLock::new(rest_token), - }; - - let result = fetcher - .fetch_new_protocol_token("test_client_id", None) - .await; - println!("{:?}", result); - assert!(result.is_ok()); - let protocol_token = result.unwrap(); - assert_eq!(protocol_token.exp, 1); - } - - #[tokio::test] - async fn test_protocol_token_fetcher_get_token() { - let fetcher = create_valid_fetcher().await; - let token = fetcher.get_token("test_client", None).await.unwrap(); - assert_eq!(token.raw_token, "valid.token.payload"); - } - - #[test] - fn test_actions_display() { - let action = Actions::Publish; - assert_eq!(action.to_string(), "publish"); - let action = Actions::Subscribe; - assert_eq!(action.to_string(), "subscribe"); - } - - #[test] - fn test_token_request_new() { - let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); - assert_eq!(request.id.len(), 64); - assert_eq!(request.tenant, "test_tenant"); - } - - #[tokio::test] - async fn test_send_success() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/protocol_auth_url") - .match_header("Authorization", "Bearer test_token") - .match_body(Matcher::Json(json!({"key": "value"}))) - .with_status(200) - .with_body("success_response") - .create(); - - let client = reqwest::Client::new(); - let payload = json!({"key": "value"}); - let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); - let result = request - .send( - &client, - &format!("{}/protocol_auth_url", mockito_server.url()), - "Bearer test_token", - &payload, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "success_response"); - } - - #[tokio::test] - async fn test_send_failure() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/protocol_auth_url") - 
.match_header("Authorization", "Bearer test_token") - .match_body(Matcher::Json(json!({"key": "value"}))) - .with_status(400) - .with_body("error_response") - .create(); - - let client = reqwest::Client::new(); - let payload = json!({"key": "value"}); - let request = ProtocolTokenRequest::new("test_client", "test_tenant", None).unwrap(); - let result = request - .send( - &client, - &format!("{}/protocol_auth_url", mockito_server.url()), - "Bearer test_token", - &payload, - ) - .await; - - assert!(result.is_err()); - if let Err(ProtocolTokenError::DshCall { - url, - status_code, - error_body, - }) = result - { - assert_eq!(url, format!("{}/protocol_auth_url", mockito_server.url())); - assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); - assert_eq!(error_body, "error_response"); - } else { - panic!("Expected DshCallError"); - } - } - - #[test] - fn test_claims_new() { - let resource = Resource::new( - "stream".to_string(), - "prefix".to_string(), - "topic".to_string(), - None, - ); - let action = Actions::Publish; - - let claims = Claims::new(resource.clone(), action); - - assert_eq!(claims.resource.stream, "stream"); - assert_eq!(claims.action, "publish"); - } - - #[test] - fn test_resource_new() { - let resource = Resource::new( - "stream".to_string(), - "prefix".to_string(), - "topic".to_string(), - None, - ); - - assert_eq!(resource.stream, "stream"); - assert_eq!(resource.prefix, "prefix"); - assert_eq!(resource.topic, "topic"); - } - - #[test] - fn test_protocol_token_is_valid() { - let raw_token = "valid.token.payload".to_string(); - let token = ProtocolToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600, - raw_token, - }; - - assert!(token.is_valid()); - } - #[test] - fn test_protocol_token_is_invalid() { - let raw_token = "valid.token.payload".to_string(); - let token = ProtocolToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32, - raw_token, - }; - - assert!(!token.is_valid()); - } - - #[test] - fn test_rest_token_is_valid() { - let token = RestToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32 - + 3600, - raw_token: "valid.token.payload".to_string(), - }; - - assert!(token.is_valid()); - } - - #[test] - fn test_rest_token_is_invalid() { - let token = RestToken { - exp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i32, - raw_token: "valid.token.payload".to_string(), - }; - - assert!(!token.is_valid()); - } - - #[test] - fn test_rest_token_default_is_invalid() { - let token = RestToken::default(); - - assert!(!token.is_valid()); - } - - #[test] - fn test_extract_header_and_payload() { - let raw = "header.payload.signature"; - let result = extract_header_and_payload(raw).unwrap(); - assert_eq!(result, "payload"); - - let raw = "header.payload"; - let result = extract_header_and_payload(raw).unwrap(); - assert_eq!(result, "payload"); - - let raw = "header"; - let result = extract_header_and_payload(raw); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_fetch_token_success() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/auth_url") - .match_header("apikey", "test_api_key") - .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) - .with_status(200) - .with_body("test_token") - .create(); - - let client = reqwest::Client::new(); - let result = RestToken::fetch_token( - &client, - "test_tenant", - "test_api_key", - 
&format!("{}/auth_url", mockito_server.url()), - ) - .await; - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "test_token"); - } - - #[tokio::test] - async fn test_fetch_token_failure() { - let mut mockito_server = mockito::Server::new_async().await; - let _m = mockito_server - .mock("POST", "/auth_url") - .match_header("apikey", "test_api_key") - .match_body(Matcher::Json(json!({"tenant": "test_tenant"}))) - .with_status(400) - .with_body("error_response") - .create(); - - let client = reqwest::Client::new(); - let result = RestToken::fetch_token( - &client, - "test_tenant", - "test_api_key", - &format!("{}/auth_url", mockito_server.url()), - ) - .await; - - assert!(result.is_err()); - if let Err(ProtocolTokenError::DshCall { - url, - status_code, - error_body, - }) = result - { - assert_eq!(url, format!("{}/auth_url", mockito_server.url())); - assert_eq!(status_code, reqwest::StatusCode::BAD_REQUEST); - assert_eq!(error_body, "error_response"); - } else { - panic!("Expected DshCallError"); - } - } -} diff --git a/dsh_sdk/src/rest_api_token_fetcher.rs b/dsh_sdk/src/rest_api_token_fetcher.rs index 12fe063..0917765 100644 --- a/dsh_sdk/src/rest_api_token_fetcher.rs +++ b/dsh_sdk/src/rest_api_token_fetcher.rs @@ -136,7 +136,7 @@ impl RestTokenFetcher { /// let platform = Platform::NpLz; /// let client_id = platform.rest_client_id("my-tenant"); /// let client_secret = "my-secret".to_string(); - /// let token_fetcher = RestTokenFetcher::new(client_id, client_secret, platform.endpoint_rest_access_token().to_string()); + /// let token_fetcher = RestTokenFetcher::new(client_id, client_secret, platform.endpoint_management_api_token().to_string()); /// let token = token_fetcher.get_token().await.unwrap(); /// } /// ``` @@ -162,7 +162,7 @@ impl RestTokenFetcher { /// let client_id = platform.rest_client_id("my-tenant"); /// let client_secret = "my-secret".to_string(); /// let client = reqwest::Client::new(); - /// let token_fetcher = RestTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_rest_access_token().to_string(), client); + /// let token_fetcher = RestTokenFetcher::new_with_client(client_id, client_secret, platform.endpoint_management_api_token().to_string(), client); /// let token = token_fetcher.get_token().await.unwrap(); /// } /// ``` @@ -354,7 +354,7 @@ impl RestTokenFetcherBuilder { let token_fetcher = RestTokenFetcher::new_with_client( client_id, client_secret, - self.platform.endpoint_rest_access_token().to_string(), + self.platform.endpoint_management_api_token().to_string(), client, ); Ok(token_fetcher) @@ -525,7 +525,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } #[test] @@ -543,7 +543,7 @@ mod test { format!("robot:{}:{}", Platform::NpLz.realm(), tenant_name) ); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } #[test] @@ -560,7 +560,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id); assert_eq!(tf.client_secret, client_secret); - assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token()); + assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token()); } #[test] @@ -577,7 +577,7 @@ mod test { .unwrap(); assert_eq!(tf.client_id, client_id_override); 
        assert_eq!(tf.client_secret, client_secret);
-        assert_eq!(tf.auth_url, Platform::NpLz.endpoint_rest_access_token());
+        assert_eq!(tf.auth_url, Platform::NpLz.endpoint_management_api_token());
     }
 
     #[test]
diff --git a/dsh_sdk/src/utils/mod.rs b/dsh_sdk/src/utils/mod.rs
index 0364e56..a3573a1 100644
--- a/dsh_sdk/src/utils/mod.rs
+++ b/dsh_sdk/src/utils/mod.rs
@@ -111,6 +111,15 @@ pub(crate) fn get_env_var(var_name: &'static str) -> Result
     }
 }
 
+/// Helper function to ensure that the host starts with `https://` or `http://`.
+pub(crate) fn ensure_https_prefix(host: impl AsRef<str>) -> String {
+    if host.as_ref().starts_with("http://") || host.as_ref().starts_with("https://") {
+        host.as_ref().to_string()
+    } else {
+        format!("https://{}", host.as_ref())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -157,4 +166,19 @@ mod tests {
         let result = get_env_var("TEST_ENV_VAR").unwrap();
         assert_eq!(result, "test_value");
     }
+
+    #[test]
+    fn test_ensure_https_prefix() {
+        let host = "http://example.com";
+        let result = ensure_https_prefix(host);
+        assert_eq!(result, "http://example.com");
+
+        let host = "https://example.com";
+        let result = ensure_https_prefix(host);
+        assert_eq!(result, "https://example.com");
+
+        let host = "example.com";
+        let result = ensure_https_prefix(host);
+        assert_eq!(result, "https://example.com");
+    }
 }
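As a quick illustration (not part of this patch): the prefixing rule of `ensure_https_prefix`, restated standalone since the helper itself is `pub(crate)` and not reachable from outside the SDK.

```rust
// Editorial sketch: mirrors the pub(crate) helper above. Hosts that already carry
// an explicit scheme are left alone; bare hosts default to TLS.
fn ensure_https_prefix(host: impl AsRef<str>) -> String {
    if host.as_ref().starts_with("http://") || host.as_ref().starts_with("https://") {
        host.as_ref().to_string()
    } else {
        format!("https://{}", host.as_ref())
    }
}

fn main() {
    assert_eq!(ensure_https_prefix("example.com"), "https://example.com");
    assert_eq!(ensure_https_prefix("http://example.com"), "http://example.com");
}
```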
"https://auth.prod.cp.kpn-dsh.com/auth/realms/tt-dsh/protocol/openid-connect/token", + Self::NpLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/dev-lz-dsh/protocol/openid-connect/token", + Self::ProdLz => "https://auth.prod.cp-prod.dsh.prod.aws.kpn.com/auth/realms/prod-lz-dsh/protocol/openid-connect/token", + Self::ProdAz => "https://auth.prod.cp.kpn-dsh.com/auth/realms/prod-azure-dsh/protocol/openid-connect/token", + Self::Poc => "https://auth.prod.cp.kpn-dsh.com/auth/realms/poc-dsh/protocol/openid-connect/token", } } #[deprecated( since = "0.5.0", - note = "Use `dsh_sdk::Platform::endpoint_protocol_token` instead" + note = "Use `dsh_sdk::Platform::endpoint_protocol_access_token` instead" )] /// (Deprecated) Returns the DSH MQTT token endpoint. /// - /// *Prefer using [`endpoint_protocol_token`](Self::endpoint_protocol_token) instead.* + /// *Prefer using [`endpoint_protocol_access_token`](Self::endpoint_protocol_access_token) instead.* pub fn endpoint_mqtt_token(&self) -> &str { - self.endpoint_protocol_token() + self.endpoint_protocol_access_token() } /// Returns the endpoint for fetching DSH protocol tokens (e.g., for MQTT). @@ -194,10 +175,10 @@ impl Platform { /// ``` /// # use dsh_sdk::Platform; /// let platform = Platform::Prod; - /// let protocol_token_url = platform.endpoint_protocol_token(); + /// let protocol_token_url = platform.endpoint_protocol_access_token(); /// assert_eq!(protocol_token_url, "https://api.kpn-dsh.com/datastreams/v0/mqtt/token"); /// ``` - pub fn endpoint_protocol_token(&self) -> &str { + pub fn endpoint_protocol_access_token(&self) -> &str { match self { Self::Prod => "https://api.kpn-dsh.com/datastreams/v0/mqtt/token", Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/datastreams/v0/mqtt/token", @@ -207,6 +188,25 @@ impl Platform { } } + /// Returns the URL endpoint for retrieving DSH REST API OAuth tokens. + /// + /// # Example + /// ``` + /// # use dsh_sdk::Platform; + /// let platform = Platform::NpLz; + /// let token_url = platform.endpoint_protocol_rest_token(); + /// assert_eq!(token_url, "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token"); + /// ``` + pub fn endpoint_protocol_rest_token(&self) -> &str { + match self { + Self::Prod => "https://api.kpn-dsh.com/auth/v0/token", + Self::NpLz => "https://api.dsh-dev.dsh.np.aws.kpn.com/auth/v0/token", + Self::ProdLz => "https://api.dsh-prod.dsh.prod.aws.kpn.com/auth/v0/token", + Self::ProdAz => "https://api.az.kpn-dsh.com/auth/v0/token", + Self::Poc => "https://api.poc.kpn-dsh.com/auth/v0/token", + } + } + /// Returns the Keycloak realm string associated with this platform. /// /// This is used to construct OpenID Connect tokens (e.g., for Kafka or REST API authentication). 
From 0b92924611271ba980093ca6cf05d7c1e7515cf7 Mon Sep 17 00:00:00 2001 From: Frank Hol <96832951+toelo3@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:19:18 +0100 Subject: [PATCH 22/23] Finalize release, minor improvements (#117) * update README and examples * add note upgrade to v0.6 * Add AccessToken to public API * Into instead of AsRef * improve some minor stuff * add fav-icon and kpn logo to docs * cargo fmt --- dsh_rest_api_client/README.md | 8 ++-- dsh_sdk/CHANGELOG.md | 4 +- dsh_sdk/Cargo.toml | 11 +++-- dsh_sdk/README.md | 17 ++++--- dsh_sdk/examples/mqtt_ws_example.rs | 2 +- .../protocol_authentication_full_mediation.rs | 4 +- dsh_sdk/src/dsh.rs | 2 +- dsh_sdk/src/error.rs | 2 +- dsh_sdk/src/lib.rs | 44 +++++++++++-------- dsh_sdk/src/management_api/mod.rs | 2 +- dsh_sdk/src/management_api/token_fetcher.rs | 30 ++++++------- dsh_sdk/src/protocol_adapters/mod.rs | 4 +- dsh_sdk/src/protocol_adapters/token/error.rs | 2 +- dsh_sdk/src/protocol_adapters/token/mod.rs | 2 +- .../token/rest_token/request.rs | 10 ++++- dsh_sdk/src/schema_store/client.rs | 10 ----- .../schema_store/types/subject_strategy.rs | 4 +- 17 files changed, 80 insertions(+), 78 deletions(-) diff --git a/dsh_rest_api_client/README.md b/dsh_rest_api_client/README.md index 0838cc2..b8156e1 100644 --- a/dsh_rest_api_client/README.md +++ b/dsh_rest_api_client/README.md @@ -37,15 +37,15 @@ It is recommended to use the Rest Token Fetcher from the `dsh_sdk` crate. To do ```toml [dependencies] -dsh_rest_api_client = "0.2.0" -dsh_sdk = { version = "0.4", features = ["rest-token-fetcher"], default-features = false } +dsh_rest_api_client = "0.3.0" +dsh_sdk = { version = "0.5", features = ["management-api-token-fetcher"], default-features = false } tokio = { version = "1", features = ["full"] } ``` To use the client in your project: ```rust use dsh_rest_api_client::Client; -use dsh_sdk::{Platform, RestTokenFetcherBuilder}; +use dsh_sdk::{Platform, ManagementApiTokenFetcherBuilder}; const CLIENT_SECRET: &str = ""; const TENANT: &str = "tenant-name"; @@ -55,7 +55,7 @@ async fn main() { let platform = Platform::NpLz; let client = Client::new(platform.endpoint_rest_api()); - let tf = RestTokenFetcherBuilder::new(platform) + let tf = ManagementApiTokenFetcherBuilder::new(platform) .tenant_name(TENANT.to_string()) .client_secret(CLIENT_SECRET.to_string()) .build() diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md index 0f78b5c..ab5be8b 100644 --- a/dsh_sdk/CHANGELOG.md +++ b/dsh_sdk/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
-## [0.5.0] - unreleased +## [0.5.0] - 2025-01-21 ### Added - DSH Kafka Config trait to configure kafka client with RDKafka implementation - DSH Schema store API Client @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Private module `dsh_sdk::dsh::bootstrap` and `dsh_sdk::dsh::pki_config_dir` are now part of `certificates` module - **Breaking change:** Moved `dsh_sdk::mqtt_token_fetcher` to `dsh_sdk::protocol_adapters::token` and renamed to `ApiClientTokenFetcher` - **NOTE** The code is refactored to follow the partial mediation and full mediation pattern - - **NOTE** Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token-fetcher` + - **NOTE** Cargo.toml feature flag `mqtt-token-fetcher` renamed to `protocol-token` - **Breaking change:** Renamed `dsh_sdk::Platform` methods to more meaningful names - **Breaking change:** Moved `dsh_sdk::dlq` to `dsh_sdk::utils::dlq` - **Breaking change:** Moved `dsh_sdk::graceful_shutdown` to `dsh_sdk::utils::graceful_shutdown` diff --git a/dsh_sdk/Cargo.toml b/dsh_sdk/Cargo.toml index 8baaa3a..9c3277b 100644 --- a/dsh_sdk/Cargo.toml +++ b/dsh_sdk/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true name = "dsh_sdk" readme = 'README.md' repository.workspace = true -version = "0.5.0-rc.2" +version = "0.5.0" [package.metadata.docs.rs] all-features = true @@ -48,13 +48,16 @@ rdkafka-config = ["rdkafka", "kafka"] # Impl of config trait only schema-store = ["bootstrap", "reqwest", "serde_json", "apache-avro", "protofish"] graceful-shutdown = ["tokio", "tokio-util"] management-api-token-fetcher = ["reqwest"] -protocol-token-fetcher = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] +protocol-token = ["base64", "reqwest", "serde_json", "sha2", "tokio/sync"] metrics = [ "hyper/server", "hyper/http1" , "hyper-util", "http-body-util", "tokio", "bytes"] dlq = ["tokio", "bootstrap", "rdkafka-config", "rdkafka/cmake-build", "rdkafka/ssl-vendored", "rdkafka/libz", "rdkafka/tokio", "graceful-shutdown"] -# http-protocol-adapter = ["protocol-token-fetcher"] -# mqtt-protocol-adapter = ["protocol-token-fetcher"] +# http-protocol-adapter = ["protocol-token"] +# mqtt-protocol-adapter = ["protocol-token"] # hyper-client = ["hyper", "hyper-util", "hyper-rustls", "rustls", "http", "rustls-pemfile"] +# deprecated! +mqtt-token-fetcher = ["protocol-token"] +rest-token-fetcher = ["management-api-token-fetcher"] [dev-dependencies] diff --git a/dsh_sdk/README.md b/dsh_sdk/README.md index 9876fe6..27b1546 100644 --- a/dsh_sdk/README.md +++ b/dsh_sdk/README.md @@ -8,9 +8,6 @@ A Rust SDK to interact with the DSH Platform. This library provides convenient building blocks for services that need to connect to DSH Kafka, fetch tokens for various protocols, manage Prometheus metrics, and more. -> **Note** -> This library (v0.5.x) is a _release candidate_. It may contain incomplete features and/or bugs. Future updates might introduce breaking changes. Please report any issues you find. - --- ## Table of Contents @@ -31,7 +28,9 @@ A Rust SDK to interact with the DSH Platform. This library provides convenient b ## Migration Guide 0.4.X -> 0.5.X -If you are migrating from `0.4.X` to `0.5.X`, please see the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for details on breaking changes and how to update your code accordingly. 
+If you are migrating from `v0.4.X` to `v0.5.X` (or `v0.6.X`), please see the [migration guide](https://github.com/kpn-dsh/dsh-sdk-platform-rs/wiki/Migration-guide-(v0.4.X-%E2%80%90--v0.5.X)) for details on breaking changes and how to update your code accordingly.
+
+`v0.6.0` will not contain any breaking changes, except for the removal of deprecated code. You can use `v0.5.X` as a stepping stone to `v0.6.0` by following the deprecation warnings the compiler gives.
 
 ---
 
@@ -52,7 +51,7 @@ The `dsh-sdk-platform-rs` library offers:
   - **RDKafka** implementation
 
 - **Common Utilities**
-  - Prometheus metrics (built-in HTTP server, plus re-export of the `metrics` crate).
+  - Lightweight HTTP server for exposing metrics.
   - Tokio-based graceful shutdown handling.
   - Dead Letter Queue (DLQ) functionality.
 
@@ -64,7 +63,7 @@ To get started, add the following to your `Cargo.toml`:
 
 ```toml
 [dependencies]
-dsh_sdk = "0.5.0-rc.2"
+dsh_sdk = "0.5"
 rdkafka = { version = "0.37", features = ["cmake-build", "ssl-vendored"] }
 ```
 
@@ -118,7 +117,7 @@ Below is an overview of the available features:
 | `kafka` | ✓ | Enable `DshKafkaConfig` trait and Config struct to connect to DSH | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) |
 | `rdkafka-config` | ✓ | Enable `DshKafkaConfig` implementation for RDKafka | [Kafka](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_example.rs) / [Kafka Proxy](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/kafka_proxy.rs) |
 | `schema-store` | ✗ | Interact with DSH Schema Store | [Schema Store API](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/schema_store_api.rs) |
-| `protocol-token-fetcher` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Mqtt client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_example.rs.rs) / [Mqtt websocket client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_example.rs.rs) / [token fetcher (full mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_full_mediation.rs) / [token fetcher (partial mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_partial_mediation.rs) |
+| `protocol-token` | ✗ | Fetch tokens to use DSH Protocol adapters (MQTT and HTTP) | [Mqtt client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_example.rs) / [Mqtt websocket client](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/mqtt_ws_example.rs) / [Token fetcher (full mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_full_mediation.rs) / [Token fetcher (partial mediation)](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/protocol_authentication_partial_mediation.rs) |
 | `management-api-token-fetcher` | ✗ | Fetch tokens to use DSH Management API | [Token fetcher](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/management_api_token_fetcher.rs) |
 | `metrics` | ✗ | Enable prometheus metrics including http server | [Expose metrics](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/expose_metrics.rs) |
 | `graceful-shutdown` | ✗ | Tokio based graceful shutdown handler | [Graceful shutdown](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/dsh_sdk/examples/graceful_shutdown.rs) |
@@ -130,7 +129,7 @@ To pick only the features you need, disable the default features and enable specific ones:
 
 ```toml
 [dependencies]
-dsh_sdk = { version = "0.5.0-rc.2", default-features = false, features = ["management-api-token-fetcher"] }
+dsh_sdk = { version = "0.5", default-features = false, features = ["management-api-token-fetcher"] }
 ```
 
@@ -151,7 +150,7 @@ A more complete example is provided in the [`example_dsh_service/`](https://gith
 
 - How to build the Rust project
 - How to package and push it to Harbor
-- An end-to-end setup of a DSH service
+- An end-to-end setup of a DSH service using Kafka
 
 See the [README](https://github.com/kpn-dsh/dsh-sdk-platform-rs/blob/main/example_dsh_service/README.md) in that directory for more information.
diff --git a/dsh_sdk/examples/mqtt_ws_example.rs b/dsh_sdk/examples/mqtt_ws_example.rs
index 58ccdc8..cde76d4 100644
--- a/dsh_sdk/examples/mqtt_ws_example.rs
+++ b/dsh_sdk/examples/mqtt_ws_example.rs
@@ -31,7 +31,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // Start logger to Stdout to show what is happening
     env_logger::builder()
-        .filter_level(log::LevelFilter::Trace)
+        .filter(Some("dsh_sdk"), log::LevelFilter::Trace)
         .target(env_logger::Target::Stdout)
         .init();
 
diff --git a/dsh_sdk/examples/protocol_authentication_full_mediation.rs b/dsh_sdk/examples/protocol_authentication_full_mediation.rs
index 418eec9..e85b7fc 100644
--- a/dsh_sdk/examples/protocol_authentication_full_mediation.rs
+++ b/dsh_sdk/examples/protocol_authentication_full_mediation.rs
@@ -33,9 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Assume the API Authentication service receives a request from an external client.
     // We want to delegate a DataAccessToken with the following properties:
     //   - Valid for 10 minutes
-    //   - Allows fetching another DataAccessToken with:
-    //     - Maximum expiration of 5 minutes
-    //     - Usage restricted to the external client ID "External-client-id"
+    //   - Allows subscribing to the topic "state/app/{tenant_name}" in the "amp" stream
 
     // Instantiate the API Client Token Fetcher
     let token_fetcher = ApiClientTokenFetcher::new(api_key, PLATFORM);
diff --git a/dsh_sdk/src/dsh.rs b/dsh_sdk/src/dsh.rs
index 1c7948f..91cb400 100644
--- a/dsh_sdk/src/dsh.rs
+++ b/dsh_sdk/src/dsh.rs
@@ -40,7 +40,7 @@ use crate::protocol_adapters::kafka_protocol::config::KafkaConfig;
 // TODO: Remove at v0.6.0
 pub use crate::dsh_old::*;
 
-/// Lazily initializes all related components to connect to DSH:
+/// Lazily initializes all related components to connect to DSH and Kafka.
 /// - Information from `datastreams.json`
 /// - Metadata of the running container/task
 /// - Certificates for Kafka and DSH Schema Registry
diff --git a/dsh_sdk/src/error.rs b/dsh_sdk/src/error.rs
index c193e7e..0dc87ad 100644
--- a/dsh_sdk/src/error.rs
+++ b/dsh_sdk/src/error.rs
@@ -5,7 +5,7 @@
 //! includes a helper function, [`report`], for generating a more readable error
 //! trace by iterating over source causes.
-/// The main error type for the DSH SDK.
+/// Errors related to [`Dsh`](super::Dsh).
 ///
 /// This enum wraps more specific errors from different parts of the SDK:
 /// - [`CertificatesError`](crate::certificates::CertificatesError)
diff --git a/dsh_sdk/src/lib.rs b/dsh_sdk/src/lib.rs
index da7acbb..c49c1f8 100644
--- a/dsh_sdk/src/lib.rs
+++ b/dsh_sdk/src/lib.rs
@@ -1,7 +1,12 @@
+#![doc(
+    html_favicon_url = "data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPHN2ZyB2ZXJzaW9uPSIxLjEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeD0iMHB4IiB5PSIwcHgiCiAgICAgdmlld0JveD0iMCAwIDE3MS4zIDE4Mi45IiBzdHlsZT0iZW5hYmxlLWJhY2tncm91bmQ6bmV3IDAgMCAxNzEuMyAxODIuOTsiIHhtbDpzcGFjZT0icHJlc2VydmUiPgogICAgPHN0eWxlPgoJCSNrcG5fbG9nbyB7CgkJCWZpbGw6IGJsYWNrOwoJCX0KCgkJQG1lZGlhIChwcmVmZXJzLWNvbG9yLXNjaGVtZTogZGFyaykgewoJCQkja3BuX2xvZ28gewoJCQkJZmlsbDogd2hpdGU7CgkJCX0KCQl9Cgk8L3N0eWxlPgogICAgPGcgaWQ9Imtwbl9sb2dvIj4KCQk8cGF0aCBkPSJNMTYxLjcsNzIuMWMtNS40LTUuNC0xNS4zLTExLjgtMzIuMi0xMS44Yy0zLjEsMC02LjIsMC4yLTkuMSwwLjZsLTAuOSwwLjFsMC4zLDAuOWMwLjgsMi42LDEuNCw1LjUsMS44LDguNGwwLjEsMC44CgkJCWwwLjgtMC4xYzIuNC0wLjMsNC43LTAuNCw3LTAuNGMxMy40LDAsMjEsNC44LDI1LDguOGM0LjIsNC4yLDYuNSw5LjYsNi41LDE1YzAsNi45LTMuNiwxNS42LTcuMiwyNC4xYy0xLjcsNC4yLTQuOSwxMi4zLTYuNywxOS4yCgkJCWMtMy4zLDEzLjEtOC44LDM1LTIxLjksMzVjLTQuMywwLTkuNC0yLjQtMTUuNS03LjJjLTMuMywxLjktNi44LDMuNC0xMC41LDQuNmM5LjgsOC43LDE4LjEsMTIuOCwyNiwxMi44CgkJCWMyMS4yLDAsMjguMS0yNy44LDMxLjgtNDIuN2MxLjEtNC42LDMuMy0xMC44LDYuMi0xNy43YzMuOS05LjQsOC0xOS4xLDgtMjhDMTcxLjMsODYuMywxNjcuOCw3OC4yLDE2MS43LDcyLjF6Ii8+CgkJPHBhdGggZD0iTTExNiw1Mi4ybDAuOS0wLjJjMi45LTAuNSw1LjktMC44LDkuMS0xYzAuMywwLDAuNiwwLDAuOSwwQzExMi45LDE3LjcsNzcuMiwwLDU2LjcsMEMyOS42LDAsMjAsMjcuNiwyMCw1My40CgkJCWMwLDEyLDQuMSwyNC42LDcuNSwzM2wwLjMsMC44bDAuOC0wLjNjMi40LTEuMSw1LTIuMSw4LTMuMmwwLjgtMC4zTDM3LDgyLjZjLTQuMy0xMC42LTYuOC0yMS4zLTYuOC0yOS4yYzAtMTYuNSw0LTMwLDExLjEtMzcKCQkJYzQuMS00LjEsOS4xLTYuMSwxNS40LTYuMUM3Mi44LDEwLjMsMTAzLDI1LjIsMTE2LDUyLjJ6Ii8+CgkJPHBhdGggZD0iTTk0LjksMTUxLjNsLTAuNC0wLjRsLTAuNSwwLjJjLTUuNSwyLTExLjEsMi45LTE3LjIsMi45Yy0yMCwwLTQxLjgtOC45LTU1LjYtMjIuOGMtNi45LTYuOS0xMC45LTE0LjMtMTAuOS0yMC4yCgkJCWMwLTguMSwzLTE0LjEsOS40LTE5Yy0xLjItMi45LTIuNi02LjMtMy44LTkuOUM1LjIsODkuMiwwLDk4LjcsMCwxMTFjMCw4LjcsNC45LDE4LjUsMTMuOSwyNy41YzEyLjQsMTIuNSwzNS41LDI1LjgsNjIuOSwyNS44CgkJCWM4LjYsMCwxNi44LTEuNywyNC40LTVsMS4xLTAuNWwtMC44LTAuOEM5OS4xLDE1NS43LDk2LjksMTUzLjQsOTQuOSwxNTEuM3oiLz4KCQk8cGF0aCBkPSJNODMuMiw3OS45di05QzgxLDcwLjMsNzguNSw3MCw3NS45LDcwYy0xMC41LDAtMTUuNiw3LjEtMTUuNiwxNC4yYzAsNi44LDIuNSwxMy4zLDExLjksMjcuOWMzLjgtMC41LDcuNi0wLjgsMTEuNC0wLjkKCQkJYy04LjItMTUuMi0xMC4yLTIwLjYtMTAuMi0yNC41YzAtNC41LDIuNi02LjgsNy45LTYuOEM4Miw3OS44LDgyLjYsNzkuOSw4My4yLDc5Ljl6Ii8+CgkJPHBhdGggZD0iTTU0LjcsOTMuMWMtMC44LTItMS4zLTUuMy0xLjYtNy43Yy04LjMtMC4zLTE0LjYsNC41LTE0LjYsMTEuMmMwLDUuNCwyLjgsMTAuMiwxNC4yLDE5LjljMi45LTEsNi44LTIuMSwxMC4xLTIuOAoJCQljLTExLjItMTAuNS0xMy0xMy4zLTEzLTE2LjRDNTAsOTUuMSw1MS42LDkzLjYsNTQuNyw5My4xeiIvPgoJCTxwYXRoIGQ9Ik05MC45LDc5Ljl2LTljMi4xLTAuNiw0LjctMC45LDcuMy0wLjljMTAuNCwwLDE1LjYsNy4xLDE1LjYsMTQuMmMwLDYuOC0yLjUsMTMuMy0xMS45LDI3LjljLTMuOC0wLjUtNy42LTAuOC0xMS40LTAuOQoJCQljOC4yLTE1LjIsMTAuMi0yMC42LDEwLj
ItMjQuNWMwLTQuNS0yLjYtNi44LTcuOS02LjhDOTIsNzkuOCw5MS40LDc5LjksOTAuOSw3OS45eiIvPgoJCTxwYXRoIGQ9Ik0xMTkuMyw5My4xYzAuOC0yLDEuMy01LjMsMS42LTcuN2M4LjMtMC4zLDE0LjYsNC41LDE0LjYsMTEuMmMwLDUuNC0yLjgsMTAuMi0xNC4yLDE5LjljLTIuOS0xLTYuOC0yLjEtMTAuMS0yLjgKCQkJYzExLjItMTAuNSwxMy0xMy4zLDEzLTE2LjRDMTI0LjEsOTUuMSwxMjIuNSw5My42LDExOS4zLDkzLjF6Ii8+CgkJPHBhdGggZD0iTTg3LDEzMC4yYzguNCwwLDE3LDEuMSwyNS45LDMuOGwzLTEwYy0xMC0zLTE5LjgtNC4yLTI5LTQuMmMtOS4yLDAtMTguOSwxLjItMjksNC4ybDMsMTBDNzAsMTMxLjMsNzguNiwxMzAuMiw4NywxMzAuMnoiCgkJLz4KCQk8cmVjdCB4PSI4MC41IiB5PSI0OS4zIiB0cmFuc2Zvcm09Im1hdHJpeCgwLjcwNzIgLTAuNzA3MSAwLjcwNzEgMC43MDcyIC0xMy45OTkyIDc3Ljg3NDQpIiB3aWR0aD0iMTMuMSIKCQkJICBoZWlnaHQ9IjEzLjEiLz4KCTwvZz4KPC9zdmc+Cg==" +)] +#![doc( + html_logo_url = "data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz4KPHN2ZyB2ZXJzaW9uPSIxLjEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgeD0iMHB4IiB5PSIwcHgiCiAgICAgdmlld0JveD0iMCAwIDE3MS4zIDE4Mi45IiBzdHlsZT0iZW5hYmxlLWJhY2tncm91bmQ6bmV3IDAgMCAxNzEuMyAxODIuOTsiIHhtbDpzcGFjZT0icHJlc2VydmUiPgogICAgPHN0eWxlPgoJCSNrcG5fbG9nbyB7CgkJCWZpbGw6IGJsYWNrOwoJCX0KCgkJQG1lZGlhIChwcmVmZXJzLWNvbG9yLXNjaGVtZTogZGFyaykgewoJCQkja3BuX2xvZ28gewoJCQkJZmlsbDogd2hpdGU7CgkJCX0KCQl9Cgk8L3N0eWxlPgogICAgPGcgaWQ9Imtwbl9sb2dvIj4KCQk8cGF0aCBkPSJNMTYxLjcsNzIuMWMtNS40LTUuNC0xNS4zLTExLjgtMzIuMi0xMS44Yy0zLjEsMC02LjIsMC4yLTkuMSwwLjZsLTAuOSwwLjFsMC4zLDAuOWMwLjgsMi42LDEuNCw1LjUsMS44LDguNGwwLjEsMC44CgkJCWwwLjgtMC4xYzIuNC0wLjMsNC43LTAuNCw3LTAuNGMxMy40LDAsMjEsNC44LDI1LDguOGM0LjIsNC4yLDYuNSw5LjYsNi41LDE1YzAsNi45LTMuNiwxNS42LTcuMiwyNC4xYy0xLjcsNC4yLTQuOSwxMi4zLTYuNywxOS4yCgkJCWMtMy4zLDEzLjEtOC44LDM1LTIxLjksMzVjLTQuMywwLTkuNC0yLjQtMTUuNS03LjJjLTMuMywxLjktNi44LDMuNC0xMC41LDQuNmM5LjgsOC43LDE4LjEsMTIuOCwyNiwxMi44CgkJCWMyMS4yLDAsMjguMS0yNy44LDMxLjgtNDIuN2MxLjEtNC42LDMuMy0xMC44LDYuMi0xNy43YzMuOS05LjQsOC0xOS4xLDgtMjhDMTcxLjMsODYuMywxNjcuOCw3OC4yLDE2MS43LDcyLjF6Ii8+CgkJPHBhdGggZD0iTTExNiw1Mi4ybDAuOS0wLjJjMi45LTAuNSw1LjktMC44LDkuMS0xYzAuMywwLDAuNiwwLDAuOSwwQzExMi45LDE3LjcsNzcuMiwwLDU2LjcsMEMyOS42LDAsMjAsMjcuNiwyMCw1My40CgkJCWMwLDEyLDQuMSwyNC42LDcuNSwzM2wwLjMsMC44bDAuOC0wLjNjMi40LTEuMSw1LTIuMSw4LTMuMmwwLjgtMC4zTDM3LDgyLjZjLTQuMy0xMC42LTYuOC0yMS4zLTYuOC0yOS4yYzAtMTYuNSw0LTMwLDExLjEtMzcKCQkJYzQuMS00LjEsOS4xLTYuMSwxNS40LTYuMUM3Mi44LDEwLjMsMTAzLDI1LjIsMTE2LDUyLjJ6Ii8+CgkJPHBhdGggZD0iTTk0LjksMTUxLjNsLTAuNC0wLjRsLTAuNSwwLjJjLTUuNSwyLTExLjEsMi45LTE3LjIsMi45Yy0yMCwwLTQxLjgtOC45LTU1LjYtMjIuOGMtNi45LTYuOS0xMC45LTE0LjMtMTAuOS0yMC4yCgkJCWMwLTguMSwzLTE0LjEsOS40LTE5Yy0xLjItMi45LTIuNi02LjMtMy44LTkuOUM1LjIsODkuMiwwLDk4LjcsMCwxMTFjMCw4LjcsNC45LDE4LjUsMTMuOSwyNy41YzEyLjQsMTIuNSwzNS41LDI1LjgsNjIuOSwyNS44CgkJCWM4LjYsMCwxNi44LTEuNywyNC40LTVsMS4xLTAuNWwtMC44LTAuOEM5OS4xLDE1NS43LDk2LjksMTUzLjQsOTQuOSwxNTEuM3oiLz4KCQk8cGF0aCBkPSJNODMuMiw3OS45di05QzgxLDcwLjMsNzguNSw3MCw3NS45LDcwYy0xMC41LDAtMTUuNiw3LjEtMTUuNiwxNC4yYzAsNi44LDIuNSwxMy4zLDExLjksMjcuOWMzLjgtMC41LDcuNi0wLjgsMTEuNC0wLjkKCQkJYy04LjItMTUuMi0xMC4yLTIwLjYtMTAuMi0yNC41YzAtNC41LDIuNi02LjgsNy45LTYuOEM4Miw3OS44LDgyLjYsNzkuOSw4My4yLDc5Ljl6Ii8+CgkJPHBhdGggZD0iTTU0LjcsOTMuMWMtMC44LTItMS4zLTUuMy0xLjYtNy43Yy04LjMtMC4zLTE0LjYsNC41LTE0LjYsMTEuMmMwLDUuNCwyLjgsMTAuMiwxNC4yLDE5LjljMi45LTEsNi44LTIuMSwxMC4xLTIuOAoJCQljLTExLjItMTAuNS0xMy0xMy4zLTEzLTE2LjRDNTAsOTUuMSw1MS42LDkzLjYsNTQuNyw5My4xeiIvPgoJCTxwYXRoIGQ9Ik05MC45LDc5Ljl2LTljMi4xLTAuNiw0LjctMC45LDcuMy0wLjljMTAuNCwwLDE1LjYsNy4xLDE1LjYsMTQuMmMwLDYuOC0yLjUsMTMuMy0xMS45LDI3LjljLTMuOC0wLjUtNy42LTAuOC0xMS40LTAuOQoJCQljOC4yLTE1LjIsMTAuMi0yMC42LDEwLjItMjQuNWMwLTQuNS0yLjYtNi44LTcuOS02LjhDOTIsNzkuOCw5MS40LDc5LjksOTAuOSw3OS45eiIvPgoJCTxwYXRoIGQ9Ik0xMTkuM
yw5My4xYzAuOC0yLDEuMy01LjMsMS42LTcuN2M4LjMtMC4zLDE0LjYsNC41LDE0LjYsMTEuMmMwLDUuNC0yLjgsMTAuMi0xNC4yLDE5LjljLTIuOS0xLTYuOC0yLjEtMTAuMS0yLjgKCQkJYzExLjItMTAuNSwxMy0xMy4zLDEzLTE2LjRDMTI0LjEsOTUuMSwxMjIuNSw5My42LDExOS4zLDkzLjF6Ii8+CgkJPHBhdGggZD0iTTg3LDEzMC4yYzguNCwwLDE3LDEuMSwyNS45LDMuOGwzLTEwYy0xMC0zLTE5LjgtNC4yLTI5LTQuMmMtOS4yLDAtMTguOSwxLjItMjksNC4ybDMsMTBDNzAsMTMxLjMsNzguNiwxMzAuMiw4NywxMzAuMnoiCgkJLz4KCQk8cmVjdCB4PSI4MC41IiB5PSI0OS4zIiB0cmFuc2Zvcm09Im1hdHJpeCgwLjcwNzIgLTAuNzA3MSAwLjcwNzEgMC43MDcyIC0xMy45OTkyIDc3Ljg3NDQpIiB3aWR0aD0iMTMuMSIKCQkJICBoZWlnaHQ9IjEzLjEiLz4KCTwvZz4KPC9zdmc+Cg==" +)] #![doc = include_str!("../README.md")] #![allow(deprecated)] -// Keep in v0.6.0 for backward compatibility #[cfg(feature = "bootstrap")] pub mod certificates; #[cfg(feature = "bootstrap")] @@ -10,34 +15,26 @@ pub mod datastream; pub mod dsh; #[cfg(feature = "bootstrap")] mod error; - -// Management API token fetcher feature #[cfg(feature = "management-api-token-fetcher")] pub mod management_api; - // Protocol adapters and utilities pub mod protocol_adapters; -pub mod utils; - -// Schema Store feature #[cfg(feature = "schema-store")] pub mod schema_store; +pub mod utils; // Re-exports for convenience -#[cfg(feature = "bootstrap")] +#[cfg(feature = "management-api-token-fetcher")] #[doc(inline)] -pub use {dsh::Dsh, error::DshError}; - +pub use management_api::{ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder}; #[cfg(feature = "kafka")] #[doc(inline)] pub use protocol_adapters::kafka_protocol::DshKafkaConfig; - -#[cfg(feature = "management-api-token-fetcher")] -#[doc(inline)] -pub use management_api::{ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder}; - #[doc(inline)] pub use utils::Platform; +#[cfg(feature = "bootstrap")] +#[doc(inline)] +pub use {dsh::Dsh, error::DshError}; // TODO: to be removed in v0.6.0 #[cfg(feature = "dlq")] @@ -51,6 +48,7 @@ pub mod dlq; `dsh_sdk::certificates` for certificate management; `dsh_sdk::datastream` for \ datastream handling." 
 )]
+#[doc(hidden)]
 pub mod dsh_old;
 #[cfg(feature = "graceful-shutdown")]
@@ -67,20 +65,28 @@ pub mod graceful_shutdown;
 )]
 pub mod metrics;
-#[cfg(all(feature = "protocol-token-fetcher", feature = "bootstrap"))]
+#[cfg(all(
+    any(feature = "protocol-token", feature = "mqtt-token-fetcher"),
+    feature = "bootstrap"
+))]
 #[deprecated(
     since = "0.5.0",
-    note = "`dsh_sdk::mqtt_token_fetcher` is moved to `dsh_sdk::protocol_adapters::token_fetcher`"
+    note = "Use cargo feature `protocol-token` instead of `mqtt-token-fetcher`. \
+    `dsh_sdk::mqtt_token_fetcher` is moved to `dsh_sdk::protocol_adapters::token_fetcher`"
 )]
 pub mod mqtt_token_fetcher;
 #[cfg(feature = "bootstrap")]
 pub use dsh_old::Properties;
-#[cfg(feature = "management-api-token-fetcher")]
+#[cfg(any(
+    feature = "management-api-token-fetcher",
+    feature = "rest-token-fetcher"
+))]
 #[deprecated(
     since = "0.5.0",
-    note = "`RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to \
+    note = "Use cargo feature flag `management-api-token-fetcher` instead of `rest-token-fetcher`. \
+    `RestTokenFetcher` and `RestTokenFetcherBuilder` are renamed to \
     `ManagementApiTokenFetcher` and `ManagementApiTokenFetcherBuilder`"
 )]
 mod rest_api_token_fetcher;
diff --git a/dsh_sdk/src/management_api/mod.rs b/dsh_sdk/src/management_api/mod.rs
index 6ffa0a2..a288103 100644
--- a/dsh_sdk/src/management_api/mod.rs
+++ b/dsh_sdk/src/management_api/mod.rs
@@ -38,4 +38,4 @@ mod token_fetcher;
 pub use error::ManagementApiTokenError;
 #[doc(inline)]
-pub use token_fetcher::{ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder};
+pub use token_fetcher::{AccessToken, ManagementApiTokenFetcher, ManagementApiTokenFetcherBuilder};
diff --git a/dsh_sdk/src/management_api/token_fetcher.rs b/dsh_sdk/src/management_api/token_fetcher.rs
index 19af4fd..9f876fe 100644
--- a/dsh_sdk/src/management_api/token_fetcher.rs
+++ b/dsh_sdk/src/management_api/token_fetcher.rs
@@ -178,9 +178,9 @@ impl ManagementApiTokenFetcher {
     /// # }
     /// ```
     pub fn new(
-        client_id: impl AsRef<str>,
-        client_secret: impl AsRef<str>,
-        auth_url: impl AsRef<str>,
+        client_id: impl Into<String>,
+        client_secret: impl Into<String>,
+        auth_url: impl Into<String>,
     ) -> Self {
         Self::new_with_client(
             client_id,
@@ -220,18 +220,18 @@ impl ManagementApiTokenFetcher {
     /// # }
     /// ```
     pub fn new_with_client(
-        client_id: impl AsRef<str>,
-        client_secret: impl AsRef<str>,
-        auth_url: impl AsRef<str>,
+        client_id: impl Into<String>,
+        client_secret: impl Into<String>,
+        auth_url: impl Into<String>,
         client: reqwest::Client,
     ) -> Self {
         Self {
             access_token: Mutex::new(AccessToken::default()),
             fetched_at: Mutex::new(Instant::now()),
-            client_id: client_id.as_ref().to_string(),
-            client_secret: client_secret.as_ref().to_string(),
+            client_id: client_id.into(),
+            client_secret: client_secret.into(),
             client,
-            auth_url: auth_url.as_ref().to_string(),
+            auth_url: auth_url.into(),
         }
     }
@@ -395,14 +395,14 @@ impl ManagementApiTokenFetcherBuilder {
     /// Sets an explicit client ID for authentication.
     ///
     /// If you also specify `tenant_name`, the client ID here takes precedence.
-    pub fn client_id(mut self, client_id: impl AsRef<str>) -> Self {
-        self.client_id = Some(client_id.as_ref().to_string());
+    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
+        self.client_id = Some(client_id.into());
         self
     }
     /// Sets a client secret required for token fetching.
-    pub fn client_secret(mut self, client_secret: impl AsRef<str>) -> Self {
-        self.client_secret = Some(client_secret.as_ref().to_string());
+    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
+        self.client_secret = Some(client_secret.into());
         self
     }
@@ -410,8 +410,8 @@ impl ManagementApiTokenFetcherBuilder {
     ///
     /// This will use `platform.rest_client_id(tenant_name)` unless `client_id`
     /// is already set.
-    pub fn tenant_name(mut self, tenant_name: impl AsRef<str>) -> Self {
-        self.tenant_name = Some(tenant_name.as_ref().to_string());
+    pub fn tenant_name(mut self, tenant_name: impl Into<String>) -> Self {
+        self.tenant_name = Some(tenant_name.into());
         self
     }
diff --git a/dsh_sdk/src/protocol_adapters/mod.rs b/dsh_sdk/src/protocol_adapters/mod.rs
index d969edb..ab045af 100644
--- a/dsh_sdk/src/protocol_adapters/mod.rs
+++ b/dsh_sdk/src/protocol_adapters/mod.rs
@@ -1,4 +1,4 @@
-//! The DSH Protocol adapter clients (HTTP, Kafka, MQTT)
+//! The DSH Protocol adapters (HTTP, Kafka, MQTT)
 //#[cfg(feature = "http-protocol-adapter")]
 //pub mod http_protocol;
@@ -6,5 +6,5 @@ pub mod kafka_protocol;
 // #[cfg(feature = "mqtt-protocol-adapter")]
 // pub mod mqtt_protocol;
-#[cfg(feature = "protocol-token-fetcher")]
+#[cfg(feature = "protocol-token")]
 pub mod token;
diff --git a/dsh_sdk/src/protocol_adapters/token/error.rs b/dsh_sdk/src/protocol_adapters/token/error.rs
index 96ab086..9f1540f 100644
--- a/dsh_sdk/src/protocol_adapters/token/error.rs
+++ b/dsh_sdk/src/protocol_adapters/token/error.rs
@@ -1,4 +1,4 @@
-#[cfg(feature = "protocol-token-fetcher")]
+#[cfg(feature = "protocol-token")]
 /// Error type for the protocol tokens
 #[derive(Debug, thiserror::Error)]
 pub enum ProtocolTokenError {
diff --git a/dsh_sdk/src/protocol_adapters/token/mod.rs b/dsh_sdk/src/protocol_adapters/token/mod.rs
index 226c133..a2fe8f7 100644
--- a/dsh_sdk/src/protocol_adapters/token/mod.rs
+++ b/dsh_sdk/src/protocol_adapters/token/mod.rs
@@ -1,4 +1,4 @@
-//! The [`RestToken`] and [`DataAccessToken`] which can be used to authenticate against the DSH platform.
+//! The [`RestToken`] and [`DataAccessToken`] which can be used to authenticate against the DSH MQTT and HTTP brokers.
 pub mod api_client_token_fetcher;
 pub mod data_access_token;
 mod error;
diff --git a/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs b/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs
index 1b32846..591646b 100644
--- a/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs
+++ b/dsh_sdk/src/protocol_adapters/token/rest_token/request.rs
@@ -20,6 +20,12 @@ pub struct RequestRestToken {
 impl RequestRestToken {
     /// Creates a new [`RequestRestToken`] instance with full access request.
+    ///
+    /// # Arguments
+    /// - `tenant` - The tenant name or API client name.
+    ///
+    /// # Returns
+    /// A new [`RequestRestToken`] instance with full access.
     pub fn new(tenant: impl Into<String>) -> Self {
         Self {
             tenant: tenant.into(),
@@ -31,7 +37,7 @@ impl RequestRestToken {
     /// Send the request to the DSH platform to get a [`RestToken`].
     ///
     /// # Arguments
-    /// - `client` - The reqwest client to use for the request.
+    /// - `client` - The [reqwest client](reqwest::Client) to use for the request.
     /// - `api_key` - The API key to authenticate to the DSH platform.
     /// - `auth_url` - The URL of the DSH platform to send the request to (See [Platform::endpoint_protocol_rest_token](crate::Platform::endpoint_protocol_rest_token)).
     ///
@@ -79,7 +85,7 @@ impl RequestRestToken {
         }
     }
-    /// Returns the tenant
+    /// Returns the tenant name or API client name.
     pub fn tenant(&self) -> &str {
         &self.tenant
     }
diff --git a/dsh_sdk/src/schema_store/client.rs b/dsh_sdk/src/schema_store/client.rs
index 4e4e2d8..8eeb967 100644
--- a/dsh_sdk/src/schema_store/client.rs
+++ b/dsh_sdk/src/schema_store/client.rs
@@ -300,16 +300,6 @@ where
         Ok(subjects)
     }
-    // Example of a commented-out method that’s not fully implemented yet:
-    // /// Gets all schemas for a given topic.
-    // /// This might differentiate key vs. value schemas.
-    // pub async fn topic_all_schemas<S>(&self, topic: S) -> Result<(Vec, Vec), SchemaStoreError>
-    // where
-    //     S: AsRef<str>,
-    // {
-    //     // Implementation to fetch key_schemas and value_schemas is pending.
-    // }
-
     /// Registers a **new** schema under the given subject.
     ///
     /// - If the subject doesn’t exist, it is created with the provided schema.
diff --git a/dsh_sdk/src/schema_store/types/subject_strategy.rs b/dsh_sdk/src/schema_store/types/subject_strategy.rs
index 88d82f4..8537338 100644
--- a/dsh_sdk/src/schema_store/types/subject_strategy.rs
+++ b/dsh_sdk/src/schema_store/types/subject_strategy.rs
@@ -31,10 +31,10 @@ pub enum SubjectName {
 impl SubjectName {
     pub fn new_topic_name_strategy<S>(topic: S, key: bool) -> Self
     where
-        S: AsRef<str>,
+        S: Into<String>,
     {
         Self::TopicNameStrategy {
-            topic: topic.as_ref().to_string(),
+            topic: topic.into(),
             key,
         }
     }

From 97d7cfb6f19050d7d95f8e3dacd2c1c177715ff5 Mon Sep 17 00:00:00 2001
From: Frank Hol <96832951+toelo3@users.noreply.github.com>
Date: Wed, 22 Jan 2025 14:21:37 +0100
Subject: [PATCH 23/23] Update CHANGELOG.md

---
 dsh_sdk/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsh_sdk/CHANGELOG.md b/dsh_sdk/CHANGELOG.md
index ab5be8b..a0e7715 100644
--- a/dsh_sdk/CHANGELOG.md
+++ b/dsh_sdk/CHANGELOG.md
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## [0.5.0] - 2025-01-21
+## [0.5.0] - 2025-01-22
 ### Added
 - DSH Kafka Config trait to configure kafka client with RDKafka implementation
 - DSH Schema store API Client
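For reference, a minimal sketch of how a consumer would call the renamed fetcher after this patch series. The `build()` and `get_token()` method names and the `Platform::NpLz` variant are assumptions for illustration; they do not appear in the hunks above.

```rust
use dsh_sdk::{ManagementApiTokenFetcherBuilder, Platform};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The builder replaces the deprecated `RestTokenFetcherBuilder`; it is
    // re-exported at the crate root under the `management-api-token-fetcher`
    // feature flag introduced by this patch series.
    let fetcher = ManagementApiTokenFetcherBuilder::new(Platform::NpLz) // platform variant assumed
        .tenant_name("my-tenant")          // takes `impl Into<String>` after this patch
        .client_secret("my-client-secret") // owned `String`s now move in without re-allocating
        .build()?;                         // `build()` returning `Result` is assumed

    // Fetch (or reuse a cached) `AccessToken`; the token type is re-exported
    // from `dsh_sdk::management_api` by this patch. Method name assumed.
    let _token = fetcher.get_token().await?;
    Ok(())
}
```

The switch from `impl AsRef<str>` to `impl Into<String>` in the constructors and builder setters lets callers hand over an owned `String` directly, avoiding the extra allocation that the previous `as_ref().to_string()` forced on every call.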