diff --git a/payday_core/src/aggregate/lightning_aggregate.rs b/payday_core/src/aggregate/lightning_aggregate.rs index 14fbfac..715bc52 100644 --- a/payday_core/src/aggregate/lightning_aggregate.rs +++ b/payday_core/src/aggregate/lightning_aggregate.rs @@ -105,6 +105,10 @@ impl Aggregate for LightningInvoice { amount, invoice, } => { + if !self.invoice_id.is_empty() { + return Err(Error::InvoiceAlreadyExists(invoice_id)); + } + if amount.currency != Currency::Btc { return Err(Error::InvalidCurrency( amount.currency.to_string(), @@ -124,10 +128,13 @@ impl Aggregate for LightningInvoice { }]) } LightningInvoiceCommand::SettleInvoice { received_amount } => { + if self.paid { + return Ok(vec![]); + } Ok(vec![LightningInvoiceEvent::InvoiceSettled { received_amount, - overpaid: received_amount.amount > self.amount.amount, - paid: received_amount.amount >= self.amount.amount, + overpaid: received_amount.cent_amount > self.amount.cent_amount, + paid: received_amount.cent_amount >= self.amount.cent_amount, }]) } } @@ -160,3 +167,130 @@ impl Aggregate for LightningInvoice { } } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use cqrs_es::test::TestFramework; + + use super::*; + + type LightningInvoiceTestFramework = TestFramework; + + #[test] + fn test_create_lightning_invoice() { + let expected_event = mock_created_event(); + LightningInvoiceTestFramework::with(()) + .given_no_previous_events() + .when(LightningInvoiceCommand::CreateInvoice { + invoice_id: "123".to_string(), + node_id: "node1".to_string(), + amount: Amount::sats(100_000), + invoice: get_invoice(), + }) + .then_expect_events(vec![expected_event]); + } + + #[test] + fn test_settle_lightning_invoice() { + let expected_event = mock_settled_event(Amount::sats(100_000), false, true); + LightningInvoiceTestFramework::with(()) + .given(vec![mock_created_event()]) + .when(LightningInvoiceCommand::SettleInvoice { + received_amount: Amount::sats(100_000), + }) + .then_expect_events(vec![expected_event]); + } + + #[test] + fn test_create_lightning_invoice_invalid_currency() { + let expected_error = Error::InvalidCurrency("USD".to_string(), "BTC".to_string()); + LightningInvoiceTestFramework::with(()) + .given_no_previous_events() + .when(LightningInvoiceCommand::CreateInvoice { + invoice_id: "123".to_string(), + node_id: "node1".to_string(), + amount: Amount::new(Currency::Usd, 100_000), + invoice: get_invoice(), + }) + .then_expect_error(expected_error); + } + + #[test] + fn test_settle_lightning_invoice_overpaid() { + let expected_event = mock_settled_event(Amount::sats(200_000), true, true); + LightningInvoiceTestFramework::with(()) + .given(vec![mock_created_event()]) + .when(LightningInvoiceCommand::SettleInvoice { + received_amount: Amount::sats(200_000), + }) + .then_expect_events(vec![expected_event]); + } + + #[test] + fn test_settle_lightning_invoice_underpaid() { + let expected_event = mock_settled_event(Amount::sats(50_000), false, false); + LightningInvoiceTestFramework::with(()) + .given(vec![mock_created_event()]) + .when(LightningInvoiceCommand::SettleInvoice { + received_amount: Amount::sats(50_000), + }) + .then_expect_events(vec![expected_event]); + } + #[test] + fn test_create_lightning_invoice_already_exists() { + let expected_error = Error::InvoiceAlreadyExists("123".to_string()); + LightningInvoiceTestFramework::with(()) + .given(vec![mock_created_event()]) + .when(LightningInvoiceCommand::CreateInvoice { + invoice_id: "123".to_string(), + node_id: "node1".to_string(), + amount: Amount::sats(100_000), + invoice: get_invoice(), + }) + .then_expect_error(expected_error); + } + + #[test] + fn test_set_confirmed_lightning_invoice_already_confirmed() { + LightningInvoiceTestFramework::with(()) + .given(vec![ + mock_created_event(), + mock_settled_event(Amount::sats(100_000), false, true), + ]) + .when(LightningInvoiceCommand::SettleInvoice { + received_amount: Amount::sats(100_000), + }) + .then_expect_events(vec![]); + } + + fn mock_created_event() -> LightningInvoiceEvent { + let invoice = get_invoice(); + LightningInvoiceEvent::InvoiceCreated { + invoice_id: "123".to_string(), + node_id: "node1".to_string(), + amount: Amount::sats(100_000), + invoice: invoice.to_string(), + r_hash: invoice.payment_hash().to_string(), + } + } + + fn mock_settled_event( + received_amount: Amount, + overpaid: bool, + paid: bool, + ) -> LightningInvoiceEvent { + LightningInvoiceEvent::InvoiceSettled { + received_amount, + overpaid, + paid, + } + } + + fn get_invoice() -> Bolt11Invoice { + Bolt11Invoice::from_str( + "lntbs3m1pnf36h3pp5dm63f7meus5thxd3h23uqkfuydw340nrf6v8y398ga7tqjfrpnfsdq5w3jhxapqd9h8vmmfvdjscqzzsxq97ztucsp5yle6azm0tpy7h3dh0d6kmpzzzpyvzqkck476l96z5p5leqaraumq9qyyssqghpt4k54rrutwumlq6hav5wdjghlrxnyxe5dde37e5t4wwz4kkq3r5284l3rcnyzzqvry6xz4s8mq42npq8fzr7j9tvvuyh32xmh97gq0h8hdp" + ).expect("valid invoice") + } +} diff --git a/payday_core/src/aggregate/on_chain_aggregate.rs b/payday_core/src/aggregate/on_chain_aggregate.rs index f70860c..796245e 100644 --- a/payday_core/src/aggregate/on_chain_aggregate.rs +++ b/payday_core/src/aggregate/on_chain_aggregate.rs @@ -169,6 +169,7 @@ impl Aggregate for OnChainInvoice { amount, address, } => { + // invalid currency if amount.currency != Currency::Btc { return Err(Error::InvalidCurrency( amount.currency.to_string(), @@ -176,6 +177,11 @@ impl Aggregate for OnChainInvoice { )); } + // invoice already exists + if !self.invoice_id.is_empty() { + return Err(Error::InvoiceAlreadyExists(invoice_id)); + } + Ok(vec![OnChainInvoiceEvent::InvoiceCreated { invoice_id, node_id, @@ -184,23 +190,35 @@ impl Aggregate for OnChainInvoice { }]) } OnChainInvoiceCommand::SetPending { amount } => { + // already payd or pending + if self.received_amount.cent_amount > 0 { + return Ok(Vec::new()); + } + Ok(vec![OnChainInvoiceEvent::PaymentPending { received_amount: amount, - underpayment: amount.amount < self.amount.amount, - overpayment: amount.amount > self.amount.amount, + underpayment: amount.cent_amount < self.amount.cent_amount, + overpayment: amount.cent_amount > self.amount.cent_amount, }]) } OnChainInvoiceCommand::SetConfirmed { confirmations, amount, transaction_id, - } => Ok(vec![OnChainInvoiceEvent::PaymentConfirmed { - received_amount: amount, - underpayment: amount.amount < self.amount.amount, - overpayment: amount.amount > self.amount.amount, - confirmations, - transaction_id, - }]), + } => { + // already confirmed + if self.confirmations > 0 { + return Ok(Vec::new()); + } + + Ok(vec![OnChainInvoiceEvent::PaymentConfirmed { + received_amount: amount, + underpayment: amount.cent_amount < self.amount.cent_amount, + overpayment: amount.cent_amount > self.amount.cent_amount, + confirmations, + transaction_id, + }]) + } } } @@ -261,7 +279,7 @@ mod aggregate_tests { .when(OnChainInvoiceCommand::CreateInvoice { invoice_id: "123".to_string(), node_id: "node1".to_string(), - amount: amount_fn(100_000), + amount: Amount::sats(100_000), address: "tb1q6xm2qgh5r83lvmmu0v7c3d4wrd9k2uxu3sgcr4".to_string(), }) .then_expect_events(vec![expected]) @@ -269,8 +287,8 @@ mod aggregate_tests { #[test] fn test_set_pending() { - let amount = amount_fn(100_000); - let expected = mock_pending_event(amount.amount, false, false); + let amount = Amount::sats(100_000); + let expected = mock_pending_event(amount.cent_amount, false, false); OnChainInvoiceTestFramework::with(()) .given(vec![mock_created_event(100_000)]) .when(OnChainInvoiceCommand::SetPending { amount }) @@ -279,8 +297,8 @@ mod aggregate_tests { #[test] fn test_pending_overpayment() { - let amount = amount_fn(100_001); - let expected = mock_pending_event(amount.amount, false, true); + let amount = Amount::sats(100_001); + let expected = mock_pending_event(amount.cent_amount, false, true); OnChainInvoiceTestFramework::with(()) .given(vec![mock_created_event(100_000)]) .when(OnChainInvoiceCommand::SetPending { amount }) @@ -289,8 +307,8 @@ mod aggregate_tests { #[test] fn test_pending_underpayment() { - let amount = amount_fn(99_999); - let expected = mock_pending_event(amount.amount, true, false); + let amount = Amount::sats(99_999); + let expected = mock_pending_event(amount.cent_amount, true, false); OnChainInvoiceTestFramework::with(()) .given(vec![mock_created_event(100_000)]) .when(OnChainInvoiceCommand::SetPending { amount }) @@ -300,7 +318,7 @@ mod aggregate_tests { #[test] fn test_set_confirmed() { let expected = OnChainInvoiceEvent::PaymentConfirmed { - received_amount: Amount::new(Currency::Btc, 100_000), + received_amount: Amount::sats(100_000), underpayment: false, overpayment: false, confirmations: 1, @@ -310,14 +328,55 @@ mod aggregate_tests { .given(vec![mock_created_event(100_000)]) .when(OnChainInvoiceCommand::SetConfirmed { confirmations: 1, - amount: Amount::new(Currency::Btc, 100_000), + amount: Amount::sats(100_000), transaction_id: "txid".to_string(), }) .then_expect_events(vec![expected]) } - fn amount_fn(amount: u64) -> Amount { - Amount::new(Currency::Btc, amount) + #[test] + fn test_create_invoice_already_exists() { + let expected_error = Error::InvoiceAlreadyExists("123".to_string()); + OnChainInvoiceTestFramework::with(()) + .given(vec![mock_created_event(100_000)]) + .when(OnChainInvoiceCommand::CreateInvoice { + invoice_id: "123".to_string(), + node_id: "node1".to_string(), + amount: Amount::sats(100_000), + address: "tb1q6xm2qgh5r83lvmmu0v7c3d4wrd9k2uxu3sgcr4".to_string(), + }) + .then_expect_error(expected_error); + } + + #[test] + fn test_set_pending_already_pending() { + let amount = Amount::sats(100_000); + OnChainInvoiceTestFramework::with(()) + .given(vec![ + mock_created_event(100_000), + mock_pending_event(100_000, false, false), + ]) + .when(OnChainInvoiceCommand::SetPending { amount }) + .then_expect_events(vec![]); + } + + #[test] + fn test_set_confirmed_already_confirmed() { + let expected = OnChainInvoiceEvent::PaymentConfirmed { + received_amount: Amount::new(Currency::Btc, 100_000), + underpayment: false, + overpayment: false, + confirmations: 1, + transaction_id: "txid".to_string(), + }; + OnChainInvoiceTestFramework::with(()) + .given(vec![mock_created_event(100_000), expected.clone()]) + .when(OnChainInvoiceCommand::SetConfirmed { + confirmations: 1, + amount: Amount::new(Currency::Btc, 100_000), + transaction_id: "txid".to_string(), + }) + .then_expect_events(vec![]); } fn mock_pending_event( @@ -326,7 +385,7 @@ mod aggregate_tests { overpayment: bool, ) -> OnChainInvoiceEvent { OnChainInvoiceEvent::PaymentPending { - received_amount: amount_fn(amount), + received_amount: Amount::sats(amount), underpayment, overpayment, } @@ -336,7 +395,7 @@ mod aggregate_tests { OnChainInvoiceEvent::InvoiceCreated { invoice_id: "123".to_string(), node_id: "node1".to_string(), - amount: amount_fn(amount), + amount: Amount::sats(amount), address: "tb1q6xm2qgh5r83lvmmu0v7c3d4wrd9k2uxu3sgcr4".to_string(), } } diff --git a/payday_core/src/api/lightining_api.rs b/payday_core/src/api/lightining_api.rs index d856a0f..3a84477 100644 --- a/payday_core/src/api/lightining_api.rs +++ b/payday_core/src/api/lightining_api.rs @@ -49,6 +49,19 @@ pub trait LightningTransactionStreamApi: Send + Sync { ) -> Result>; } +#[async_trait] +pub trait LightningTransactionEventProcessorApi: Send + Sync { + fn node_id(&self) -> String; + async fn get_offset(&self) -> Result; + async fn set_offset(&self, settle_index: u64) -> Result<()>; + async fn process_event(&self, event: LightningTransactionEvent) -> Result<()>; +} + +#[async_trait] +pub trait LightningTransactionEventHandler: Send + Sync { + async fn process_event(&self, event: LightningTransactionEvent) -> Result<()>; +} + #[derive(Debug, Clone)] pub struct LnInvoice { pub invoice: String, @@ -84,6 +97,14 @@ pub enum LightningTransactionEvent { Settled(LightningTransaction), } +impl LightningTransactionEvent { + pub fn settle_index(&self) -> Option { + match self { + LightningTransactionEvent::Settled(tx) => Some(tx.settle_index), + } + } +} + #[derive(Debug, Clone, PartialEq)] pub enum InvoiceState { OPEN, diff --git a/payday_core/src/api/on_chain_api.rs b/payday_core/src/api/on_chain_api.rs index ebc8623..3308d4c 100644 --- a/payday_core/src/api/on_chain_api.rs +++ b/payday_core/src/api/on_chain_api.rs @@ -59,8 +59,8 @@ pub trait OnChainTransactionApi: Send + Sync { #[async_trait] pub trait OnChainTransactionEventProcessorApi: Send + Sync { fn node_id(&self) -> String; - async fn get_offset(&self) -> Result; - async fn set_block_height(&self, block_height: i32) -> Result<()>; + async fn get_offset(&self) -> Result; + async fn set_block_height(&self, block_height: u64) -> Result<()>; async fn process_event(&self, event: OnChainTransactionEvent) -> Result<()>; } diff --git a/payday_core/src/payment/amount.rs b/payday_core/src/payment/amount.rs index 11a5855..c926ddb 100644 --- a/payday_core/src/payment/amount.rs +++ b/payday_core/src/payment/amount.rs @@ -7,25 +7,28 @@ use crate::payment::currency::Currency; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub struct Amount { pub currency: Currency, - pub amount: u64, + pub cent_amount: u64, } impl Amount { - pub fn new(currency: Currency, amount: u64) -> Self { - Self { currency, amount } + pub fn new(currency: Currency, cent_amount: u64) -> Self { + Self { + currency, + cent_amount, + } } pub fn zero(currency: Currency) -> Self { Self { currency, - amount: 0, + cent_amount: 0, } } pub fn sats(sats: u64) -> Self { Self { currency: Currency::Btc, - amount: sats, + cent_amount: sats, } } } @@ -34,13 +37,13 @@ impl Default for Amount { fn default() -> Self { Self { currency: Currency::Btc, - amount: 0, + cent_amount: 0, } } } impl Display for Amount { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {}", self.amount, self.currency) + write!(f, "{} {}", self.cent_amount, self.currency) } } diff --git a/payday_core/src/payment/invoice.rs b/payday_core/src/payment/invoice.rs index b257c88..f263823 100644 --- a/payday_core/src/payment/invoice.rs +++ b/payday_core/src/payment/invoice.rs @@ -2,19 +2,18 @@ use std::fmt::{Display, Formatter}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use serde_json::Value; use crate::payment::amount::Amount; pub type InvoiceId = String; -pub type PaymentType = String; pub type Result = std::result::Result; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum Error { InvalidAmount(Amount), InvalidCurrency(String, String), ServiceError(String), + InvoiceAlreadyExists(String), } impl std::error::Error for Error {} @@ -29,17 +28,79 @@ impl Display for Error { required, received ), Error::ServiceError(err) => write!(f, "Invoice service error: {}", err), + Error::InvoiceAlreadyExists(id) => write!(f, "Invoice already exists: {}", id), } } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum PaymentType { + BitcoinOnChain, + BitcoinLightning, + BitcoinUnified, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Invoice { - pub service_name: String, pub invoice_id: InvoiceId, - pub amount: Amount, + pub node_id: String, pub payment_type: PaymentType, - pub payment_info: Value, + pub invoice_amount: Amount, + pub received_amount: Amount, + pub underpayment: bool, + pub overpayment: bool, + pub paid: bool, + pub details: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PaymentEvent { + PaymentUnconfirmed(PaymentReceivedEventPayload), + PaymentReceived(PaymentReceivedEventPayload), + UnexpectedPaymentReceived(UnexpectedPaymentReceivedEventPayload), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PaymentDetails { + OnChain(OnChainPaymentDetais), + Lightning(LightningPaymentDetails), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PaymentReceivedEventPayload { + pub invoice_id: InvoiceId, + pub node_id: String, + pub payment_type: PaymentType, + pub invoice_amount: Amount, + pub received_amount: Amount, + pub underpayment: bool, + pub overpayment: bool, + pub paid: bool, + pub details: PaymentDetails, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UnexpectedPaymentReceivedEventPayload { + pub node_id: String, + pub payment_type: PaymentType, + pub received_amount: Amount, + pub paid: bool, + pub details: PaymentDetails, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OnChainPaymentDetais { + pub address: String, + pub confirmations: u32, + pub transaction_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LightningPaymentDetails { + pub invoice: String, + pub r_hash: String, } #[async_trait] diff --git a/payday_core/src/processor/lightning_processor.rs b/payday_core/src/processor/lightning_processor.rs new file mode 100644 index 0000000..f98e29b --- /dev/null +++ b/payday_core/src/processor/lightning_processor.rs @@ -0,0 +1,62 @@ +use crate::{ + api::lightining_api::{ + LightningTransactionEvent, LightningTransactionEventHandler, + LightningTransactionEventProcessorApi, + }, + persistence::offset::OffsetStoreApi, + Result, +}; +use async_trait::async_trait; + +pub struct LightningTransactionProcessor { + node_id: String, + settle_index_store: Box, + handler: Box, +} + +impl LightningTransactionProcessor { + pub fn new( + node_id: &str, + settle_index_store: Box, + handler: Box, + ) -> Self { + Self { + node_id: node_id.to_string(), + settle_index_store, + handler, + } + } +} + +#[async_trait] +impl LightningTransactionEventProcessorApi for LightningTransactionProcessor { + fn node_id(&self) -> String { + self.node_id.to_string() + } + async fn get_offset(&self) -> Result { + self.settle_index_store.get_offset().await.map(|o| o.offset) + } + + async fn set_offset(&self, block_height: u64) -> Result<()> { + self.settle_index_store.set_offset(block_height).await + } + + async fn process_event(&self, event: LightningTransactionEvent) -> Result<()> { + let index = event.settle_index(); + self.handler.process_event(event).await?; + if let Some(idx) = index { + self.set_offset(idx).await?; + } + Ok(()) + } +} + +pub struct LightningTransactionPrintHandler; + +#[async_trait] +impl LightningTransactionEventHandler for LightningTransactionPrintHandler { + async fn process_event(&self, event: LightningTransactionEvent) -> Result<()> { + println!("LightningTransactionEvent: {:?}", event); + Ok(()) + } +} diff --git a/payday_core/src/processor/mod.rs b/payday_core/src/processor/mod.rs index c878fca..d0df7ff 100644 --- a/payday_core/src/processor/mod.rs +++ b/payday_core/src/processor/mod.rs @@ -1 +1,2 @@ +pub mod lightning_processor; pub mod on_chain_processor; diff --git a/payday_core/src/processor/on_chain_processor.rs b/payday_core/src/processor/on_chain_processor.rs index 8b31f85..32b4112 100644 --- a/payday_core/src/processor/on_chain_processor.rs +++ b/payday_core/src/processor/on_chain_processor.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::{ api::on_chain_api::{ OnChainTransactionEvent, OnChainTransactionEventHandler, @@ -9,13 +7,11 @@ use crate::{ Result, }; use async_trait::async_trait; -use tokio::sync::Mutex; pub struct OnChainTransactionProcessor { node_id: String, block_height_store: Box, handler: Box, - current_block_height: Arc>, } impl OnChainTransactionProcessor { @@ -28,7 +24,6 @@ impl OnChainTransactionProcessor { node_id: node_id.to_string(), block_height_store, handler, - current_block_height: Arc::new(Mutex::new(-1)), } } } @@ -38,28 +33,20 @@ impl OnChainTransactionEventProcessorApi for OnChainTransactionProcessor { fn node_id(&self) -> String { self.node_id.to_string() } - async fn get_offset(&self) -> Result { - let mut current_block_height = self.current_block_height.lock().await; - if *current_block_height < 0 { - *current_block_height = self.block_height_store.get_offset().await?.offset as i32; - } - Ok(*current_block_height) + + async fn get_offset(&self) -> Result { + self.block_height_store.get_offset().await.map(|o| o.offset) } - async fn set_block_height(&self, block_height: i32) -> Result<()> { - let mut current_block_height = self.current_block_height.lock().await; - if *current_block_height < block_height { - self.block_height_store - .set_offset(block_height as u64) - .await?; - *current_block_height = block_height; - } - Ok(()) + + async fn set_block_height(&self, block_height: u64) -> Result<()> { + self.block_height_store.set_offset(block_height).await } + async fn process_event(&self, event: OnChainTransactionEvent) -> Result<()> { let block_height = event.block_height(); self.handler.process_event(event).await?; if let Some(bh) = block_height { - self.set_block_height(bh).await?; + self.set_block_height(bh as u64).await?; } Ok(()) } diff --git a/payday_node_lnd/src/lnd.rs b/payday_node_lnd/src/lnd.rs index f6be075..af2ff45 100644 --- a/payday_node_lnd/src/lnd.rs +++ b/payday_node_lnd/src/lnd.rs @@ -99,7 +99,7 @@ impl LightningInvoiceApi for Lnd { memo: Option, ttl: Option, ) -> Result { - let amount = bitcoin::Amount::from_sat(amount.amount); + let amount = bitcoin::Amount::from_sat(amount.cent_amount); let invoice = self.client.create_invoice(amount, memo, ttl).await?; Ok(invoice) } @@ -170,7 +170,7 @@ impl OnChainPaymentApi for Lnd { #[async_trait] impl LightningPaymentApi for Lnd { async fn pay_to_node_pub_key(&self, pub_key: String, amount: Amount) -> Result<()> { - let amt = bitcoin::Amount::from_sat(amount.amount); + let amt = bitcoin::Amount::from_sat(amount.cent_amount); self.client .pay_to_node_id(pub_key.parse()?, amt, None) .await?; diff --git a/payday_postgres/src/offset.rs b/payday_postgres/src/offset.rs index ce14b61..c5bd3fa 100644 --- a/payday_postgres/src/offset.rs +++ b/payday_postgres/src/offset.rs @@ -4,10 +4,12 @@ use payday_core::{ Error, Result, }; use sqlx::{Pool, Postgres, Row}; +use tokio::sync::Mutex; pub struct OffsetStore { db: Pool, id: String, + current_offset: Box>>, } impl OffsetStore { @@ -15,37 +17,42 @@ impl OffsetStore { Self { db, id: id.to_string(), + current_offset: Box::new(Mutex::new(None)), } } + async fn get_cached(&self) -> Option { + let cached = self.current_offset.lock().await; + *cached + } + + async fn set_cached(&self, offset: u64) { + let mut cached = self.current_offset.lock().await; + *cached = Some(offset); + } + async fn get_offset_internal(&self) -> Result> { + let cached = self.get_cached().await; + if let Some(cached) = cached { + return Ok(Some(cached)); + } let res: Option = sqlx::query("SELECT current_offset FROM offsets WHERE id = $1") .bind(&self.id) .fetch_optional(&self.db) .await .map_err(|e| Error::DbError(e.to_string()))? .map(|r| r.get("current_offset")); - Ok(res.and_then(|r| u64::try_from(r).ok())) - } -} -#[async_trait] -impl OffsetStoreApi for OffsetStore { - async fn get_offset(&self) -> Result { - let offset: Option = self.get_offset_internal().await?; - match offset { - Some(offset) => Ok(Offset { - id: self.id.to_owned(), - offset, - }), - None => Ok(Offset { - id: self.id.to_owned(), - offset: 0, - }), + match res.and_then(|r| u64::try_from(r).ok()) { + Some(offset) => { + self.set_cached(offset).await; + Ok(Some(offset)) + } + _ => Ok(None), } } - async fn set_offset(&self, offset: u64) -> Result<()> { + async fn set_offset_internal(&self, offset: u64) -> Result<()> { let existing: Option = self.get_offset_internal().await?; if existing.is_some() { sqlx::query("UPDATE offsets SET current_offset = $1 WHERE id = $2") @@ -62,11 +69,32 @@ impl OffsetStoreApi for OffsetStore { .await .map_err(|e| Error::DbError(e.to_string()))?; } - + self.set_cached(offset).await; Ok(()) } } +#[async_trait] +impl OffsetStoreApi for OffsetStore { + async fn get_offset(&self) -> Result { + let offset: Option = self.get_offset_internal().await?; + match offset { + Some(offset) => Ok(Offset { + id: self.id.to_owned(), + offset, + }), + None => Ok(Offset { + id: self.id.to_owned(), + offset: 0, + }), + } + } + + async fn set_offset(&self, offset: u64) -> Result<()> { + self.set_offset_internal(offset).await + } +} + #[cfg(test)] mod tests { use super::*; @@ -91,6 +119,11 @@ mod tests { .set_offset(10) .await .expect("Query executed successfully"); + + assert!(store.current_offset.lock().await.is_some()); + assert!(store.get_cached().await.is_some()); + assert!(store.get_cached().await.unwrap().eq(&10)); + let result = store .get_offset() .await