From 1e127c44010f4ee80328b880f1e9ae8d5dbc46a9 Mon Sep 17 00:00:00 2001 From: Yasir Shariff Date: Thu, 9 Nov 2023 17:27:45 +0300 Subject: [PATCH] Initial implementation of auth caching --- Cargo.toml | 18 ++++++----- src/client.rs | 61 +++++++++++++++++++++---------------- src/lib.rs | 1 + tests/mpesa-rust/helpers.rs | 3 +- 4 files changed, 48 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 400f2eccb..7ae0cb585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,14 @@ readme = "./README.md" license = "MIT" [dependencies] -chrono = {version = "0.4", optional = true, default-features = false, features = ["clock", "serde"] } -openssl = {version = "0.10", optional = true} -reqwest = {version = "0.11", features = ["json"]} -serde = {version="1.0", features= ["derive"]} +cached = { version = "0.46", features = ["wasm", "async", "proc_macro"] } +chrono = { version = "0.4", optional = true, default-features = false, features = [ + "clock", + "serde", +] } +openssl = { version = "0.10", optional = true } +reqwest = { version = "0.11", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_repr = "0.1" thiserror = "1.0.37" @@ -21,7 +25,7 @@ wiremock = "0.5" [dev-dependencies] dotenv = "0.15" -tokio = {version = "1", features = ["rt", "rt-multi-thread", "macros"]} +tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros"] } wiremock = "0.5" [features] @@ -34,7 +38,7 @@ default = [ "c2b_simulate", "express_request", "transaction_reversal", - "transaction_status" + "transaction_status", ] account_balance = ["dep:openssl"] b2b = ["dep:openssl"] @@ -44,4 +48,4 @@ c2b_register = [] c2b_simulate = [] express_request = ["dep:chrono"] transaction_reversal = ["dep:openssl"] -transaction_status= ["dep:openssl"] +transaction_status = ["dep:openssl"] diff --git a/src/client.rs b/src/client.rs index 65c87ee7b..918409b95 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,4 @@ +use crate::auth::AUTH; use crate::environment::ApiEnvironment; use crate::services::{ AccountBalanceBuilder, B2bBuilder, B2cBuilder, BulkInvoiceBuilder, C2bRegisterBuilder, @@ -5,12 +6,13 @@ use crate::services::{ OnboardModifyBuilder, ReconciliationBuilder, SingleInvoiceBuilder, TransactionReversalBuilder, TransactionStatusBuilder, }; -use crate::{ApiError, MpesaError}; +use crate::{auth, MpesaError}; +use cached::Cached; use openssl::base64; use openssl::rsa::Padding; use openssl::x509::X509; use reqwest::Client as HttpClient; -use serde_json::Value; + use std::cell::RefCell; /// Source: [test credentials](https://developer.safaricom.co.ke/test_credentials) @@ -68,6 +70,16 @@ impl<'mpesa, Env: ApiEnvironment> Mpesa { p.to_owned() } + /// Get the client key + pub fn client_key(&self) -> &str { + &self.client_key + } + + /// Get the client secret + pub fn client_secret(&self) -> &str { + &self.client_secret + } + /// Optional in development but required for production, you will need to call this method and set your production initiator password. /// If in development, default initiator password is already pre-set /// ```ignore @@ -102,33 +114,28 @@ impl<'mpesa, Env: ApiEnvironment> Mpesa { /// /// # Errors /// Returns a `MpesaError` on failure - #[allow(clippy::single_char_pattern)] pub(crate) async fn auth(&self) -> MpesaResult { - let url = format!( - "{}/oauth/v1/generate?grant_type=client_credentials", - self.environment.base_url() - ); - let response = self - .http_client - .get(&url) - .basic_auth(&self.client_key, Some(&self.client_secret)) - .send() - .await?; - if response.status().is_success() { - let value = response.json::().await?; - let access_token = value - .get("access_token") - .ok_or_else(|| String::from("Failed to extract token from the response")) - .unwrap(); - let access_token = access_token - .as_str() - .ok_or_else(|| String::from("Error converting access token to string")) - .unwrap(); - - return Ok(access_token.to_string()); + if let Some(token) = AUTH.lock().await.cache_get(&self.client_key) { + return Ok(token.clone()); } - let error = response.json::().await?; - Err(MpesaError::AuthenticationError(error)) + + // Generate a new access token + let new_token = match auth::auth_prime_cache(self).await { + Ok(token) => token, + Err(e) => return Err(e), + }; + + // Double-check if the access token is cached by another thread + if let Some(token) = AUTH.lock().await.cache_get(&self.client_key) { + return Ok(token.clone()); + } + + // Cache the new token + AUTH.lock() + .await + .cache_set(self.client_key.clone(), new_token.clone()); + + Ok(new_token) } /// **B2C Builder** diff --git a/src/lib.rs b/src/lib.rs index 7a9e84138..5932fe87c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod constants; pub mod environment; mod errors; pub mod services; +mod auth; pub use client::{Mpesa, MpesaResult}; pub use constants::{ diff --git a/tests/mpesa-rust/helpers.rs b/tests/mpesa-rust/helpers.rs index a6f13cf60..466f53bfc 100644 --- a/tests/mpesa-rust/helpers.rs +++ b/tests/mpesa-rust/helpers.rs @@ -49,7 +49,8 @@ macro_rules! get_mpesa_client { .and(path("/oauth/v1/generate")) .and(query_param("grant_type", "client_credentials")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ - "access_token": "dummy_access_token" + "access_token": "dummy_access_token", + "expiry_in": 3600 }))) .expect($expected_requests) .mount(&server)