diff --git a/presage-cli/src/main.rs b/presage-cli/src/main.rs index 220580708..e98f12155 100644 --- a/presage-cli/src/main.rs +++ b/presage-cli/src/main.rs @@ -229,7 +229,8 @@ async fn main() -> anyhow::Result<()> { args.passphrase, MigrationConflictStrategy::Raise, OnNewIdentity::Trust, - )?; + ) + .await?; run(args.subcommand, config_store).await } @@ -289,7 +290,7 @@ async fn process_incoming_message( notifications: bool, content: &Content, ) { - print_message(manager, notifications, content); + print_message(manager, notifications, content).await; let sender = content.metadata.sender.uuid; if let ContentBody::DataMessage(DataMessage { attachments, .. }) = &content.body { @@ -324,7 +325,7 @@ async fn process_incoming_message( } } -fn print_message( +async fn print_message( manager: &Manager, notifications: bool, content: &Content, @@ -334,66 +335,74 @@ fn print_message( return; }; - let format_data_message = |thread: &Thread, data_message: &DataMessage| match data_message { - DataMessage { - quote: - Some(Quote { - text: Some(quoted_text), - .. - }), - body: Some(body), - .. - } => Some(format!("Answer to message \"{quoted_text}\": {body}")), - DataMessage { - reaction: - Some(Reaction { - target_sent_timestamp: Some(ts), - emoji: Some(emoji), - .. - }), - .. - } => { - let Ok(Some(message)) = manager.store().message(thread, *ts) else { - warn!(%thread, sent_at = ts, "no message found in thread"); - return None; - }; + async fn format_data_message( + thread: &Thread, + data_message: &DataMessage, + manager: &Manager, + ) -> Option { + match data_message { + DataMessage { + quote: + Some(Quote { + text: Some(quoted_text), + .. + }), + body: Some(body), + .. + } => Some(format!("Answer to message \"{quoted_text}\": {body}")), + DataMessage { + reaction: + Some(Reaction { + target_sent_timestamp: Some(ts), + emoji: Some(emoji), + .. + }), + .. + } => { + let Ok(Some(message)) = manager.store().message(thread, *ts).await else { + warn!(%thread, sent_at = ts, "no message found in thread"); + return None; + }; - let ContentBody::DataMessage(DataMessage { - body: Some(body), .. - }) = message.body - else { - warn!("message reacted to has no body"); - return None; - }; + let ContentBody::DataMessage(DataMessage { + body: Some(body), .. + }) = message.body + else { + warn!("message reacted to has no body"); + return None; + }; - Some(format!("Reacted with {emoji} to message: \"{body}\"")) + Some(format!("Reacted with {emoji} to message: \"{body}\"")) + } + DataMessage { + body: Some(body), .. + } => Some(body.to_string()), + _ => Some("Empty data message".to_string()), } - DataMessage { - body: Some(body), .. - } => Some(body.to_string()), - _ => Some("Empty data message".to_string()), - }; + } - let format_contact = |uuid| { + async fn format_contact(uuid: &Uuid, manager: &Manager) -> String { manager .store() .contact_by_id(uuid) + .await .ok() .flatten() .filter(|c| !c.name.is_empty()) .map(|c| format!("{}: {}", c.name, uuid)) .unwrap_or_else(|| uuid.to_string()) - }; + } - let format_group = |key| { + async fn format_group(key: [u8; 32], manager: &Manager) -> String { manager .store() .group(key) + .await .ok() .flatten() .map(|g| g.title) .unwrap_or_else(|| "".to_string()) - }; + } enum Msg<'a> { Received(&'a Thread, String), @@ -406,12 +415,16 @@ fn print_message( "Null message (for example deleted)".to_string(), )), ContentBody::DataMessage(data_message) => { - format_data_message(&thread, data_message).map(|body| Msg::Received(&thread, body)) + format_data_message(&thread, data_message, manager) + .await + .map(|body| Msg::Received(&thread, body)) } ContentBody::EditMessage(EditMessage { data_message: Some(data_message), .. - }) => format_data_message(&thread, data_message).map(|body| Msg::Received(&thread, body)), + }) => format_data_message(&thread, data_message, manager) + .await + .map(|body| Msg::Received(&thread, body)), ContentBody::EditMessage(EditMessage { .. }) => None, ContentBody::SynchronizeMessage(SyncMessage { sent: @@ -420,7 +433,9 @@ fn print_message( .. }), .. - }) => format_data_message(&thread, data_message).map(|body| Msg::Sent(&thread, body)), + }) => format_data_message(&thread, data_message, manager) + .await + .map(|body| Msg::Sent(&thread, body)), ContentBody::SynchronizeMessage(SyncMessage { sent: Some(Sent { @@ -432,7 +447,9 @@ fn print_message( .. }), .. - }) => format_data_message(&thread, data_message).map(|body| Msg::Sent(&thread, body)), + }) => format_data_message(&thread, data_message, manager) + .await + .map(|body| Msg::Sent(&thread, body)), ContentBody::SynchronizeMessage(SyncMessage { .. }) => None, ContentBody::CallMessage(_) => Some(Msg::Received(&thread, "is calling!".into())), ContentBody::TypingMessage(_) => Some(Msg::Received(&thread, "is typing...".into())), @@ -456,20 +473,20 @@ fn print_message( let ts = content.timestamp(); let (prefix, body) = match msg { Msg::Received(Thread::Contact(sender), body) => { - let contact = format_contact(sender); + let contact = format_contact(sender, manager).await; (format!("From {contact} @ {ts}: "), body) } Msg::Sent(Thread::Contact(recipient), body) => { - let contact = format_contact(recipient); + let contact = format_contact(recipient, manager).await; (format!("To {contact} @ {ts}"), body) } Msg::Received(Thread::Group(key), body) => { - let sender = format_contact(&content.metadata.sender.uuid); - let group = format_group(*key); + let sender = format_contact(&content.metadata.sender.uuid, manager).await; + let group = format_group(*key, manager).await; (format!("From {sender} to group {group} @ {ts}: "), body) } Msg::Sent(Thread::Group(key), body) => { - let group = format_group(*key); + let group = format_group(*key, manager).await; (format!("To group {group} @ {ts}"), body) } }; @@ -660,7 +677,8 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { if profile_key.is_none() { for contact in manager .store() - .contacts()? + .contacts() + .await? .filter_map(Result::ok) .filter(|c| c.uuid == uuid) { @@ -681,7 +699,7 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { } Cmd::ListGroups => { let manager = Manager::load_registered(config_store).await?; - for group in manager.store().groups()? { + for group in manager.store().groups().await? { match group { Ok(( group_master_key, @@ -712,14 +730,14 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { uuid, phone_number, .. - } in manager.store().contacts()?.flatten() + } in manager.store().contacts().await?.flatten() { println!("{uuid} / {phone_number:?} / {name}"); } } Cmd::ListStickerPacks => { let manager = Manager::load_registered(config_store).await?; - for sticker_pack in manager.sticker_packs().await? { + for sticker_pack in manager.store().sticker_packs().await? { match sticker_pack { Ok(sticker_pack) => { println!( @@ -748,7 +766,7 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { } Cmd::GetContact { ref uuid } => { let manager = Manager::load_registered(config_store).await?; - match manager.store().contact_by_id(uuid)? { + match manager.store().contact_by_id(uuid).await? { Some(contact) => println!("{contact:#?}"), None => eprintln!("Could not find contact for {uuid}"), } @@ -761,7 +779,8 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { let manager = Manager::load_registered(config_store).await?; for contact in manager .store() - .contacts()? + .contacts() + .await? .filter_map(Result::ok) .filter(|c| uuid.map_or_else(|| true, |u| c.uuid == u)) .filter(|c| c.phone_number == phone_number) @@ -786,10 +805,12 @@ async fn run(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { _ => unreachable!(), }; for msg in manager - .messages(&thread, from.unwrap_or(0)..)? + .store() + .messages(&thread, from.unwrap_or(0)..) + .await? .filter_map(Result::ok) { - print_message(&manager, false, &msg); + print_message(&manager, false, &msg).await; } } Cmd::Stats => { diff --git a/presage-store-sled/src/content.rs b/presage-store-sled/src/content.rs index fb0b64aa3..de986e886 100644 --- a/presage-store-sled/src/content.rs +++ b/presage-store-sled/src/content.rs @@ -39,7 +39,7 @@ impl ContentsStore for SledStore { type MessagesIter = SledMessagesIter; type StickerPacksIter = SledStickerPacksIter; - fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError> { + async fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError> { let db = self.write(); db.drop_tree(SLED_TREE_PROFILES)?; db.drop_tree(SLED_TREE_PROFILE_KEYS)?; @@ -48,7 +48,7 @@ impl ContentsStore for SledStore { Ok(()) } - fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError> { + async fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError> { let db = self.write(); db.drop_tree(SLED_TREE_CONTACTS)?; db.drop_tree(SLED_TREE_GROUPS)?; @@ -65,18 +65,18 @@ impl ContentsStore for SledStore { Ok(()) } - fn clear_contacts(&mut self) -> Result<(), SledStoreError> { + async fn clear_contacts(&mut self) -> Result<(), SledStoreError> { self.write().drop_tree(SLED_TREE_CONTACTS)?; Ok(()) } - fn save_contact(&mut self, contact: &Contact) -> Result<(), SledStoreError> { + async fn save_contact(&mut self, contact: &Contact) -> Result<(), SledStoreError> { self.insert(SLED_TREE_CONTACTS, contact.uuid, contact)?; debug!("saved contact"); Ok(()) } - fn contacts(&self) -> Result { + async fn contacts(&self) -> Result { Ok(SledContactsIter { iter: self.read().open_tree(SLED_TREE_CONTACTS)?.iter(), #[cfg(feature = "encryption")] @@ -84,20 +84,20 @@ impl ContentsStore for SledStore { }) } - fn contact_by_id(&self, id: &Uuid) -> Result, SledStoreError> { + async fn contact_by_id(&self, id: &Uuid) -> Result, SledStoreError> { self.get(SLED_TREE_CONTACTS, id) } /// Groups - fn clear_groups(&mut self) -> Result<(), SledStoreError> { + async fn clear_groups(&mut self) -> Result<(), SledStoreError> { let db = self.write(); db.drop_tree(SLED_TREE_GROUPS)?; db.flush()?; Ok(()) } - fn groups(&self) -> Result { + async fn groups(&self) -> Result { Ok(SledGroupsIter { iter: self.read().open_tree(SLED_TREE_GROUPS)?.iter(), #[cfg(feature = "encryption")] @@ -105,14 +105,14 @@ impl ContentsStore for SledStore { }) } - fn group( + async fn group( &self, master_key_bytes: GroupMasterKeyBytes, ) -> Result, SledStoreError> { self.get(SLED_TREE_GROUPS, master_key_bytes) } - fn save_group( + async fn save_group( &self, master_key: GroupMasterKeyBytes, group: impl Into, @@ -121,14 +121,14 @@ impl ContentsStore for SledStore { Ok(()) } - fn group_avatar( + async fn group_avatar( &self, master_key_bytes: GroupMasterKeyBytes, ) -> Result, SledStoreError> { self.get(SLED_TREE_GROUP_AVATARS, master_key_bytes) } - fn save_group_avatar( + async fn save_group_avatar( &self, master_key: GroupMasterKeyBytes, avatar: &AvatarBytes, @@ -139,7 +139,7 @@ impl ContentsStore for SledStore { /// Messages - fn clear_messages(&mut self) -> Result<(), SledStoreError> { + async fn clear_messages(&mut self) -> Result<(), SledStoreError> { let db = self.write(); for name in db.tree_names() { if name @@ -153,7 +153,7 @@ impl ContentsStore for SledStore { Ok(()) } - fn clear_thread(&mut self, thread: &Thread) -> Result<(), SledStoreError> { + async fn clear_thread(&mut self, thread: &Thread) -> Result<(), SledStoreError> { trace!(%thread, "clearing thread"); let db = self.write(); @@ -163,7 +163,7 @@ impl ContentsStore for SledStore { Ok(()) } - fn save_message(&self, thread: &Thread, message: Content) -> Result<(), SledStoreError> { + async fn save_message(&self, thread: &Thread, message: Content) -> Result<(), SledStoreError> { let ts = message.timestamp(); trace!(%thread, ts, "storing a message with thread"); @@ -178,12 +178,20 @@ impl ContentsStore for SledStore { Ok(()) } - fn delete_message(&mut self, thread: &Thread, timestamp: u64) -> Result { + async fn delete_message( + &mut self, + thread: &Thread, + timestamp: u64, + ) -> Result { let tree = messages_thread_tree_name(thread); self.remove(&tree, timestamp.to_be_bytes()) } - fn message(&self, thread: &Thread, timestamp: u64) -> Result, SledStoreError> { + async fn message( + &self, + thread: &Thread, + timestamp: u64, + ) -> Result, SledStoreError> { // Big-Endian needed, otherwise wrong ordering in sled. let val: Option> = self.get(&messages_thread_tree_name(thread), timestamp.to_be_bytes())?; @@ -197,7 +205,7 @@ impl ContentsStore for SledStore { } } - fn messages( + async fn messages( &self, thread: &Thread, range: impl RangeBounds, @@ -228,15 +236,19 @@ impl ContentsStore for SledStore { }) } - fn upsert_profile_key(&mut self, uuid: &Uuid, key: ProfileKey) -> Result { + async fn upsert_profile_key( + &mut self, + uuid: &Uuid, + key: ProfileKey, + ) -> Result { self.insert(SLED_TREE_PROFILE_KEYS, uuid.as_bytes(), key) } - fn profile_key(&self, uuid: &Uuid) -> Result, SledStoreError> { + async fn profile_key(&self, uuid: &Uuid) -> Result, SledStoreError> { self.get(SLED_TREE_PROFILE_KEYS, uuid.as_bytes()) } - fn save_profile( + async fn save_profile( &mut self, uuid: Uuid, key: ProfileKey, @@ -247,12 +259,16 @@ impl ContentsStore for SledStore { Ok(()) } - fn profile(&self, uuid: Uuid, key: ProfileKey) -> Result, SledStoreError> { + async fn profile( + &self, + uuid: Uuid, + key: ProfileKey, + ) -> Result, SledStoreError> { let key = self.profile_key_for_uuid(uuid, key); self.get(SLED_TREE_PROFILES, key) } - fn save_profile_avatar( + async fn save_profile_avatar( &mut self, uuid: Uuid, key: ProfileKey, @@ -263,7 +279,7 @@ impl ContentsStore for SledStore { Ok(()) } - fn profile_avatar( + async fn profile_avatar( &self, uuid: Uuid, key: ProfileKey, @@ -272,20 +288,20 @@ impl ContentsStore for SledStore { self.get(SLED_TREE_PROFILE_AVATARS, key) } - fn add_sticker_pack(&mut self, pack: &StickerPack) -> Result<(), SledStoreError> { + async fn add_sticker_pack(&mut self, pack: &StickerPack) -> Result<(), SledStoreError> { self.insert(SLED_TREE_STICKER_PACKS, pack.id.clone(), pack)?; Ok(()) } - fn remove_sticker_pack(&mut self, id: &[u8]) -> Result { + async fn remove_sticker_pack(&mut self, id: &[u8]) -> Result { self.remove(SLED_TREE_STICKER_PACKS, id) } - fn sticker_pack(&self, id: &[u8]) -> Result, SledStoreError> { + async fn sticker_pack(&self, id: &[u8]) -> Result, SledStoreError> { self.get(SLED_TREE_STICKER_PACKS, id) } - fn sticker_packs(&self) -> Result { + async fn sticker_packs(&self) -> Result { Ok(SledStickerPacksIter { cipher: self.cipher.clone(), iter: self.read().open_tree(SLED_TREE_STICKER_PACKS)?.iter(), diff --git a/presage-store-sled/src/lib.rs b/presage-store-sled/src/lib.rs index 24b47735e..40187bf15 100644 --- a/presage-store-sled/src/lib.rs +++ b/presage-store-sled/src/lib.rs @@ -134,7 +134,7 @@ impl SledStore { }) } - pub fn open( + pub async fn open( db_path: impl AsRef, migration_conflict_strategy: MigrationConflictStrategy, trust_new_identities: OnNewIdentity, @@ -145,9 +145,10 @@ impl SledStore { migration_conflict_strategy, trust_new_identities, ) + .await } - pub fn open_with_passphrase( + pub async fn open_with_passphrase( db_path: impl AsRef, passphrase: Option>, migration_conflict_strategy: MigrationConflictStrategy, @@ -155,7 +156,7 @@ impl SledStore { ) -> Result { let passphrase = passphrase.as_ref(); - migrate(&db_path, passphrase, migration_conflict_strategy)?; + migrate(&db_path, passphrase, migration_conflict_strategy).await?; Self::new(db_path, passphrase, trust_new_identities) } @@ -312,7 +313,7 @@ impl SledStore { } } -fn migrate( +async fn migrate( db_path: impl AsRef, passphrase: Option>, migration_conflict_strategy: MigrationConflictStrategy, @@ -320,7 +321,7 @@ fn migrate( let db_path = db_path.as_ref(); let passphrase = passphrase.as_ref(); - let run_migrations = move || { + let run_migrations = { let mut store = SledStore::new(db_path, passphrase, OnNewIdentity::Reject)?; let schema_version = store.schema_version(); for step in schema_version.steps() { @@ -337,7 +338,7 @@ fn migrate( let state = serde_json::from_slice(&data).map_err(SledStoreError::from)?; // save it the new school way - store.save_registration_data(&state)?; + store.save_registration_data(&state).await?; // remove old data let db = store.write(); @@ -347,11 +348,11 @@ fn migrate( } SchemaVersion::V3 => { debug!("migrating from schema v2 to v3: dropping encrypted group cache"); - store.clear_groups()?; + store.clear_groups().await?; } SchemaVersion::V4 => { debug!("migrating from schema v3 to v4: dropping profile cache"); - store.clear_profiles()?; + store.clear_profiles().await?; } SchemaVersion::V5 => { debug!("migrating from schema v4 to v5: moving identity key pairs"); @@ -368,27 +369,31 @@ fn migrate( pub(crate) pni_public_key: Option, } - let run_step = || -> Result<(), SledStoreError> { + let run_step: Result<(), SledStoreError> = { let registration_data: Option = store.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)?; if let Some(data) = registration_data { - store.set_aci_identity_key_pair(IdentityKeyPair::new( - data.aci_public_key, - data.aci_private_key, - ))?; + store + .set_aci_identity_key_pair(IdentityKeyPair::new( + data.aci_public_key, + data.aci_private_key, + )) + .await?; if let Some((public_key, private_key)) = data.pni_public_key.zip(data.pni_private_key) { - store.set_pni_identity_key_pair(IdentityKeyPair::new( - public_key, - private_key, - ))?; + store + .set_pni_identity_key_pair(IdentityKeyPair::new( + public_key, + private_key, + )) + .await?; } } Ok(()) }; - if let Err(error) = run_step() { + if let Err(error) = run_step { error!("failed to run v4 -> v5 migration: {error}"); } } @@ -436,7 +441,7 @@ fn migrate( Ok(()) }; - if let Err(SledStoreError::MigrationConflict) = run_migrations() { + if let Err(SledStoreError::MigrationConflict) = run_migrations { match migration_conflict_strategy { MigrationConflictStrategy::BackupAndDrop => { let mut new_db_path = db_path.to_path_buf(); @@ -464,34 +469,40 @@ fn migrate( impl StateStore for SledStore { type StateStoreError = SledStoreError; - fn load_registration_data(&self) -> Result, SledStoreError> { + async fn load_registration_data(&self) -> Result, SledStoreError> { self.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION) } - fn set_aci_identity_key_pair( + async fn set_aci_identity_key_pair( &self, key_pair: IdentityKeyPair, ) -> Result<(), Self::StateStoreError> { self.set_identity_key_pair::(key_pair) } - fn set_pni_identity_key_pair( + async fn set_pni_identity_key_pair( &self, key_pair: IdentityKeyPair, ) -> Result<(), Self::StateStoreError> { self.set_identity_key_pair::(key_pair) } - fn save_registration_data(&mut self, state: &RegistrationData) -> Result<(), SledStoreError> { + async fn save_registration_data( + &mut self, + state: &RegistrationData, + ) -> Result<(), SledStoreError> { self.insert(SLED_TREE_STATE, SLED_KEY_REGISTRATION, state)?; Ok(()) } - fn is_registered(&self) -> bool { - self.load_registration_data().unwrap_or_default().is_some() + async fn is_registered(&self) -> bool { + self.load_registration_data() + .await + .unwrap_or_default() + .is_some() } - fn clear_registration(&mut self) -> Result<(), SledStoreError> { + async fn clear_registration(&mut self) -> Result<(), SledStoreError> { // drop registration data (includes identity keys) { let db = self.write(); @@ -501,7 +512,7 @@ impl StateStore for SledStore { } // drop all saved profile (+avatards) and profile keys - self.clear_profiles()?; + self.clear_profiles().await?; // drop all keys self.aci_protocol_store().clear(true)?; @@ -516,9 +527,9 @@ impl Store for SledStore { type AciStore = SledProtocolStore; type PniStore = SledProtocolStore; - fn clear(&mut self) -> Result<(), SledStoreError> { - self.clear_registration()?; - self.clear_contents()?; + async fn clear(&mut self) -> Result<(), SledStoreError> { + self.clear_registration().await?; + self.clear_contents().await?; Ok(()) } @@ -639,23 +650,35 @@ mod tests { async fn test_store_messages(thread: Thread, content: Content) -> anyhow::Result<()> { let db = SledStore::temporary()?; let thread = thread.0; - db.save_message(&thread, content_with_timestamp(&content, 1678295210))?; - db.save_message(&thread, content_with_timestamp(&content, 1678295220))?; - db.save_message(&thread, content_with_timestamp(&content, 1678295230))?; - db.save_message(&thread, content_with_timestamp(&content, 1678295240))?; - db.save_message(&thread, content_with_timestamp(&content, 1678280000))?; - - assert_eq!(db.messages(&thread, ..).unwrap().count(), 5); - assert_eq!(db.messages(&thread, 0..).unwrap().count(), 5); - assert_eq!(db.messages(&thread, 1678280000..).unwrap().count(), 5); - - assert_eq!(db.messages(&thread, 0..1678280000)?.count(), 0); - assert_eq!(db.messages(&thread, 0..1678295210)?.count(), 1); - assert_eq!(db.messages(&thread, 1678295210..1678295240)?.count(), 3); - assert_eq!(db.messages(&thread, 1678295210..=1678295240)?.count(), 4); + db.save_message(&thread, content_with_timestamp(&content, 1678295210)) + .await?; + db.save_message(&thread, content_with_timestamp(&content, 1678295220)) + .await?; + db.save_message(&thread, content_with_timestamp(&content, 1678295230)) + .await?; + db.save_message(&thread, content_with_timestamp(&content, 1678295240)) + .await?; + db.save_message(&thread, content_with_timestamp(&content, 1678280000)) + .await?; + + assert_eq!(db.messages(&thread, ..).await.unwrap().count(), 5); + assert_eq!(db.messages(&thread, 0..).await.unwrap().count(), 5); + assert_eq!(db.messages(&thread, 1678280000..).await.unwrap().count(), 5); + + assert_eq!(db.messages(&thread, 0..1678280000).await?.count(), 0); + assert_eq!(db.messages(&thread, 0..1678295210).await?.count(), 1); + assert_eq!( + db.messages(&thread, 1678295210..1678295240).await?.count(), + 3 + ); + assert_eq!( + db.messages(&thread, 1678295210..=1678295240).await?.count(), + 4 + ); assert_eq!( - db.messages(&thread, 0..=1678295240)? + db.messages(&thread, 0..=1678295240) + .await? .next() .unwrap()? .metadata @@ -663,7 +686,8 @@ mod tests { 1678280000 ); assert_eq!( - db.messages(&thread, 0..=1678295240)? + db.messages(&thread, 0..=1678295240) + .await? .next_back() .unwrap()? .metadata diff --git a/presage-store-sled/src/protocol.rs b/presage-store-sled/src/protocol.rs index 5036071bf..7ac0e4d95 100644 --- a/presage-store-sled/src/protocol.rs +++ b/presage-store-sled/src/protocol.rs @@ -17,7 +17,7 @@ use presage::{ ServiceAddress, }, proto::verified, - store::{ContentsStore, StateStore}, + store::{save_trusted_identity_message, StateStore}, }; use sled::Batch; use tracing::{error, trace, warn}; @@ -537,7 +537,8 @@ impl IdentityKeyStore for SledProtocolStore { async fn get_local_registration_id(&self) -> Result { let data = self.store - .load_registration_data()? + .load_registration_data() + .await? .ok_or(SignalProtocolError::InvalidState( "failed to load registration ID", "no registration data".into(), @@ -563,7 +564,8 @@ impl IdentityKeyStore for SledProtocolStore { error })?; - self.store.save_trusted_identity_message( + save_trusted_identity_message( + &self.store, address, *identity_key, if existed_before { @@ -571,7 +573,8 @@ impl IdentityKeyStore for SledProtocolStore { } else { verified::State::Default }, - ); + ) + .await?; Ok(true) } diff --git a/presage/Cargo.toml b/presage/Cargo.toml index 6482852ac..96cc96c68 100644 --- a/presage/Cargo.toml +++ b/presage/Cargo.toml @@ -1,7 +1,7 @@ [package] # be a sign or warning of (an imminent event, typically an unwelcome one). name = "presage" -version = "0.6.2" +version = "0.7.0-dev" authors = ["Gabriel FĂ©ron "] edition = "2021" license = "AGPL-3.0-only" diff --git a/presage/src/manager/confirmation.rs b/presage/src/manager/confirmation.rs index 60aaec22a..01b8a89fb 100644 --- a/presage/src/manager/confirmation.rs +++ b/presage/src/manager/confirmation.rs @@ -60,11 +60,8 @@ impl Manager { }; let service_configuration: ServiceConfiguration = signal_servers.into(); - let mut identified_push_service = PushService::new( - service_configuration, - Some(credentials), - crate::USER_AGENT.to_string(), - ); + let mut identified_push_service = + PushService::new(service_configuration, Some(credentials), crate::USER_AGENT); let session = identified_push_service .submit_verification_code(&session_id, confirmation_code.as_ref()) @@ -87,9 +84,11 @@ impl Manager { // generate new identity keys used in `register_account` and below self.store - .set_aci_identity_key_pair(IdentityKeyPair::generate(&mut self.rng))?; + .set_aci_identity_key_pair(IdentityKeyPair::generate(&mut self.rng)) + .await?; self.store - .set_pni_identity_key_pair(IdentityKeyPair::generate(&mut self.rng))?; + .set_pni_identity_key_pair(IdentityKeyPair::generate(&mut self.rng)) + .await?; let skip_device_transfer = true; let mut account_manager = AccountManager::new(identified_push_service, Some(profile_key)); @@ -143,11 +142,14 @@ impl Manager { }), }; - manager.store.save_registration_data(&manager.state.data)?; + manager + .store + .save_registration_data(&manager.state.data) + .await?; if let Err(e) = manager.register_pre_keys().await { // clear the entire store on any error, there's no possible recovery here - manager.store.clear_registration()?; + manager.store.clear_registration().await?; Err(e) } else { Ok(manager) diff --git a/presage/src/manager/linking.rs b/presage/src/manager/linking.rs index 3d26c1253..ed338f74e 100644 --- a/presage/src/manager/linking.rs +++ b/presage/src/manager/linking.rs @@ -64,7 +64,7 @@ impl Manager { ) -> Result, Error> { // clear the database: the moment we start the process, old API credentials are invalidated // and you won't be able to use this client anyways - store.clear_registration()?; + store.clear_registration().await?; // generate a random alphanumeric 24 chars password let mut rng = StdRng::from_entropy(); @@ -75,8 +75,7 @@ impl Manager { rng.fill_bytes(&mut signaling_key); let service_configuration: ServiceConfiguration = signal_servers.into(); - let push_service = - PushService::new(service_configuration, None, crate::USER_AGENT.to_string()); + let push_service = PushService::new(service_configuration, None, crate::USER_AGENT); let (tx, mut rx) = mpsc::channel(1); @@ -138,16 +137,20 @@ impl Manager { profile_key, }; - store.set_aci_identity_key_pair(IdentityKeyPair::new( - aci_public_key, - aci_private_key, - ))?; - store.set_pni_identity_key_pair(IdentityKeyPair::new( - pni_public_key, - pni_private_key, - ))?; + store + .set_aci_identity_key_pair(IdentityKeyPair::new( + aci_public_key, + aci_private_key, + )) + .await?; + store + .set_pni_identity_key_pair(IdentityKeyPair::new( + pni_public_key, + pni_private_key, + )) + .await?; - store.save_registration_data(®istration_data)?; + store.save_registration_data(®istration_data).await?; info!( "successfully registered device {}", ®istration_data.service_ids @@ -162,14 +165,14 @@ impl Manager { // Register pre-keys with the server. If this fails, this can lead to issues // receiving, in that case clear the registration and propagate the error. if let Err(e) = manager.register_pre_keys().await { - store.clear_registration()?; + store.clear_registration().await?; Err(e) } else { Ok(manager) } } Err(e) => { - store.clear_registration()?; + store.clear_registration().await?; Err(e) } } diff --git a/presage/src/manager/registered.rs b/presage/src/manager/registered.rs index ee15dc4b7..40e846b91 100644 --- a/presage/src/manager/registered.rs +++ b/presage/src/manager/registered.rs @@ -1,5 +1,4 @@ use std::fmt; -use std::ops::RangeBounds; use std::pin::pin; use std::sync::{Arc, OnceLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -142,7 +141,8 @@ impl Manager { /// Returns a instance of [Manager] you can use to send & receive messages. pub async fn load_registered(store: S) -> Result> { let registration_data = store - .load_registration_data()? + .load_registration_data() + .await? .ok_or(Error::NotYetRegisteredError)?; let mut manager = Self { @@ -180,7 +180,7 @@ impl Manager { PushService::new( self.state.service_configuration(), self.credentials(), - crate::USER_AGENT.to_string(), + crate::USER_AGENT, ) }) .clone() @@ -193,11 +193,7 @@ impl Manager { self.state .unidentified_push_service .get_or_init(|| { - PushService::new( - self.state.service_configuration(), - None, - crate::USER_AGENT.to_string(), - ) + PushService::new(self.state.service_configuration(), None, crate::USER_AGENT) }) .clone() } @@ -297,7 +293,7 @@ impl Manager { } else { info!("migrating to PNI"); let pni_registration_id = generate_registration_id(&mut StdRng::from_entropy()); - self.store.save_registration_data(&self.state.data)?; + self.store.save_registration_data(&self.state.data).await?; pni_registration_id }; @@ -332,7 +328,7 @@ impl Manager { debug!("fetching PNI UUID and updating state"); let whoami = self.whoami().await?; self.state.data.service_ids.pni = whoami.pni; - self.store.save_registration_data(&self.state.data)?; + self.store.save_registration_data(&self.state.data).await?; } trace!("done setting account attributes"); @@ -462,7 +458,7 @@ impl Manager { // Check if profile is cached. // TODO: Create a migration in the store removing all profiles. // TODO: Is there some way to know if this is outdated? - if let Some(profile) = self.store.profile(uuid, profile_key).ok().flatten() { + if let Some(profile) = self.store.profile(uuid, profile_key).await.ok().flatten() { return Ok(profile); } @@ -473,7 +469,10 @@ impl Manager { .retrieve_profile(ServiceAddress::from_aci(uuid)) .await?; - let _ = self.store.save_profile(uuid, profile_key, profile.clone()); + let _ = self + .store + .save_profile(uuid, profile_key, profile.clone()) + .await; Ok(profile) } @@ -488,7 +487,13 @@ impl Manager { // Check if group avatar is cached. // TODO: Is there some way to know if this is outdated? - if let Some(avatar) = self.store.group_avatar(master_key_bytes).ok().flatten() { + if let Some(avatar) = self + .store + .group_avatar(master_key_bytes) + .await + .ok() + .flatten() + { return Ok(Some(avatar)); } @@ -516,7 +521,7 @@ impl Manager { ) .await?; if let Some(avatar) = &avatar { - let _ = self.store.save_group_avatar(master_key_bytes, avatar); + let _ = self.store.save_group_avatar(master_key_bytes, avatar).await; } Ok(avatar) } @@ -528,15 +533,22 @@ impl Manager { ) -> Result, Error> { // Check if profile avatar is cached. // TODO: Is there some way to know if this is outdated? - if let Some(avatar) = self.store.profile_avatar(uuid, profile_key).ok().flatten() { + if let Some(avatar) = self + .store + .profile_avatar(uuid, profile_key) + .await + .ok() + .flatten() + { return Ok(Some(avatar)); } - let profile = if let Some(profile) = self.store.profile(uuid, profile_key).ok().flatten() { - profile - } else { - self.retrieve_profile_by_uuid(uuid, profile_key).await? - }; + let profile = + if let Some(profile) = self.store.profile(uuid, profile_key).await.ok().flatten() { + profile + } else { + self.retrieve_profile_by_uuid(uuid, profile_key).await? + }; let Some(avatar) = profile.avatar.as_ref() else { return Ok(None); @@ -553,19 +565,13 @@ impl Manager { let cipher = ProfileCipher::from(profile_key); let avatar = cipher.decrypt_avatar(&contents)?; - let _ = self.store.save_profile_avatar(uuid, profile_key, &avatar); + let _ = self + .store + .save_profile_avatar(uuid, profile_key, &avatar) + .await; Ok(Some(avatar)) } - /// Gets an iterator of messages in a thread, optionally starting from a point in time. - pub fn messages( - &self, - thread: &Thread, - range: impl RangeBounds, - ) -> Result> { - Ok(self.store.messages(thread, range)?) - } - async fn receive_messages_encrypted( &mut self, ) -> Result>, Error> { @@ -661,11 +667,11 @@ impl Manager { { match state.message_receiver.retrieve_contacts(contacts).await { Ok(contacts) => { - let _ = state.store.clear_contacts(); + let _ = state.store.clear_contacts().await; info!("saving contacts"); for contact in contacts.filter_map(Result::ok) { if let Err(error) = - state.store.save_contact(&contact.into()) + state.store.save_contact(&contact.into()).await { warn!(%error, "failed to save contacts"); break; @@ -722,6 +728,7 @@ impl Manager { match state .store .remove_sticker_pack(operation.pack_id()) + .await { Ok(was_present) => { debug!(was_present, "removed stick pack") @@ -841,12 +848,13 @@ impl Manager { let thread = Thread::Contact(recipient.uuid); let mut content_body: ContentBody = message.into(); - self.restore_thread_timer(&thread, &mut content_body); + self.restore_thread_timer(&thread, &mut content_body).await; let sender_certificate = self.sender_certificate().await?; let unidentified_access = self.store - .profile_key(&recipient.uuid)? + .profile_key(&recipient.uuid) + .await? .map(|profile_key| UnidentifiedAccess { key: profile_key.derive_access_key().to_vec(), certificate: sender_certificate.clone(), @@ -923,7 +931,7 @@ impl Manager { .expect("Master key bytes to be of size 32."); let thread = Thread::Group(master_key_bytes); - self.restore_thread_timer(&thread, &mut content_body); + self.restore_thread_timer(&thread, &mut content_body).await; let mut sender = self.new_message_sender().await?; @@ -943,7 +951,8 @@ impl Manager { { let unidentified_access = self.store - .profile_key(&member.uuid)? + .profile_key(&member.uuid) + .await? .map(|profile_key| UnidentifiedAccess { key: profile_key.derive_access_key().to_vec(), certificate: sender_certificate.clone(), @@ -995,23 +1004,21 @@ impl Manager { Ok(()) } - fn restore_thread_timer(&mut self, thread: &Thread, content_body: &mut ContentBody) { - let store_expire_timer = self.store.expire_timer(thread).unwrap_or_default(); - - match content_body { - ContentBody::DataMessage(DataMessage { - expire_timer: ref mut timer, - expire_timer_version: ref mut version, - .. - }) => { - if timer.is_none() { - *timer = store_expire_timer.map(|(t, _)| t); - *version = Some(store_expire_timer.map(|(_, v)| v).unwrap_or_default()); - } else { - *version = Some(store_expire_timer.map(|(_, v)| v).unwrap_or_default() + 1); - } + async fn restore_thread_timer(&mut self, thread: &Thread, content_body: &mut ContentBody) { + let store_expire_timer = self.store.expire_timer(thread).await.unwrap_or_default(); + + if let ContentBody::DataMessage(DataMessage { + expire_timer: ref mut timer, + expire_timer_version: ref mut version, + .. + }) = content_body + { + if timer.is_none() { + *timer = store_expire_timer.map(|(t, _)| t); + *version = Some(store_expire_timer.map(|(_, v)| v).unwrap_or_default()); + } else { + *version = Some(store_expire_timer.map(|(_, v)| v).unwrap_or_default() + 1); } - _ => (), } } @@ -1058,26 +1065,13 @@ impl Manager { Ok(ciphertext) } - /// Gets an iterator over installed sticker packs - pub async fn sticker_packs(&self) -> Result> { - Ok(self.store.sticker_packs()?) - } - - /// Gets a sticker pack by id - pub async fn sticker_pack( - &self, - pack_id: &[u8], - ) -> Result, Error> { - Ok(self.store.sticker_pack(pack_id)?) - } - /// Gets the metadata of a sticker pub async fn sticker_metadata( &mut self, pack_id: &[u8], sticker_id: u32, ) -> Result, Error> { - Ok(self.store.sticker_pack(pack_id)?.and_then(|pack| { + Ok(self.store.sticker_pack(pack_id).await?.and_then(|pack| { pack.manifest .stickers .iter() @@ -1150,7 +1144,7 @@ impl Manager { ) .await?; - self.store.remove_sticker_pack(pack_id)?; + self.store.remove_sticker_pack(pack_id).await?; Ok(()) } @@ -1238,7 +1232,7 @@ impl Manager { pub async fn thread_title(&self, thread: &Thread) -> Result> { match thread { Thread::Contact(uuid) => { - let contact = match self.store.contact_by_id(uuid) { + let contact = match self.store.contact_by_id(uuid).await { Ok(contact) => contact, Err(error) => { info!(%error, %uuid, "error getting contact by id"); @@ -1250,7 +1244,7 @@ impl Manager { None => uuid.to_string(), }) } - Thread::Group(id) => match self.store.group(*id)? { + Thread::Group(id) => match self.store.group(*id).await? { Some(group) => Ok(group.title), None => Ok("".to_string()), }, @@ -1314,47 +1308,6 @@ impl Manager { Ok(account_manager.linked_devices(&aci_protocol_store).await?) } - - /// Deprecated methods - - /// Get a single contact by its UUID - /// - /// Note: this only currently works when linked as secondary device (the contacts are sent by the primary device at linking time) - #[deprecated = "use the store handle directly"] - pub fn contact_by_id(&self, id: &Uuid) -> Result, Error> { - Ok(self.store.contact_by_id(id)?) - } - - /// Returns an iterator on contacts stored in the [Store]. - #[deprecated = "use the store handle directly"] - pub fn contacts( - &self, - ) -> Result>>, Error> { - let iter = self.store.contacts()?; - Ok(iter.map(|r| r.map_err(Into::into))) - } - - /// Get a group (either from the local cache, or fetch it remotely) using its master key - #[deprecated = "use the store handle directly"] - pub fn group(&self, master_key_bytes: &[u8]) -> Result, Error> { - Ok(self.store.group(master_key_bytes.try_into()?)?) - } - - /// Returns an iterator on groups stored in the [Store]. - #[deprecated = "use the store handle directly"] - pub fn groups(&self) -> Result> { - Ok(self.store.groups()?) - } - - /// Get a single message in a thread (identified by its server-side sent timestamp) - #[deprecated = "use the store handle directly"] - pub fn message( - &self, - thread: &Thread, - timestamp: u64, - ) -> Result, Error> { - Ok(self.store.message(thread, timestamp)?) - } } /// The mode receiving messages stream @@ -1377,7 +1330,7 @@ async fn upsert_group( master_key_bytes: &[u8], revision: &u32, ) -> Result, Error> { - let upsert_group = match store.group(master_key_bytes.try_into()?) { + let upsert_group = match store.group(master_key_bytes.try_into()?).await { Ok(Some(group)) => { debug!(group_name =% group.title, "loaded group from local db"); group.revision < *revision @@ -1394,7 +1347,7 @@ async fn upsert_group( match groups_manager.fetch_encrypted_group(master_key_bytes).await { Ok(encrypted_group) => { let group = decrypt_group(master_key_bytes, encrypted_group)?; - if let Err(error) = store.save_group(master_key_bytes.try_into()?, group) { + if let Err(error) = store.save_group(master_key_bytes.try_into()?, group).await { error!(%error, "failed to save group"); } } @@ -1404,7 +1357,7 @@ async fn upsert_group( } } - Ok(store.group(master_key_bytes.try_into()?)?) + Ok(store.group(master_key_bytes.try_into()?).await?) } /// Download and decrypt a sticker manifest @@ -1452,7 +1405,7 @@ async fn download_sticker_pack( }; // save everything in store - store.add_sticker_pack(&sticker_pack)?; + store.add_sticker_pack(&sticker_pack).await?; Ok(sticker_pack) } @@ -1524,9 +1477,10 @@ async fn save_message( // - insert a new contact with the profile information // - update the contact if the profile key has changed // TODO: mark this contact as "created by us" maybe to know whether we should update it or not - if store.contact_by_id(&sender_uuid)?.is_none() + if store.contact_by_id(&sender_uuid).await?.is_none() || !store - .profile_key(&sender_uuid)? + .profile_key(&sender_uuid) + .await? .is_some_and(|p| p.bytes == profile_key.bytes) { let encrypted_profile = push_service @@ -1563,15 +1517,17 @@ async fn save_message( }; info!(%sender_uuid, "saved contact on first sight"); - store.save_contact(&contact)?; + store.save_contact(&contact).await?; } - store.upsert_profile_key(&sender_uuid, profile_key)?; + store.upsert_profile_key(&sender_uuid, profile_key).await?; } if let Some(expire_timer) = data_message.expire_timer { let version = data_message.expire_timer_version.unwrap_or(1); - store.update_expire_timer(&thread, expire_timer, version)?; + store + .update_expire_timer(&thread, expire_timer, version) + .await?; } match data_message { @@ -1583,10 +1539,10 @@ async fn save_message( .. } => { // replace an existing message by an empty NullMessage - if let Some(mut existing_msg) = store.message(&thread, *ts)? { + if let Some(mut existing_msg) = store.message(&thread, *ts).await? { existing_msg.metadata.sender.uuid = Uuid::nil(); existing_msg.body = NullMessage::default().into(); - store.save_message(&thread, existing_msg)?; + store.save_message(&thread, existing_msg).await?; debug!(%thread, ts, "message in thread deleted"); None } else { @@ -1613,7 +1569,7 @@ async fn save_message( }), .. }) => { - if let Some(mut existing_msg) = store.message(&thread, ts)? { + if let Some(mut existing_msg) = store.message(&thread, ts).await? { existing_msg.metadata = message.metadata; existing_msg.body = ContentBody::DataMessage(data_message); // TODO: find a way to mark the message as edited (so that it's visible in a client) @@ -1656,7 +1612,7 @@ async fn save_message( }; if let Some(message) = message { - store.save_message(&thread, message)?; + store.save_message(&thread, message).await?; } Ok(()) diff --git a/presage/src/manager/registration.rs b/presage/src/manager/registration.rs index cd1bdf759..c7c5dd4c2 100644 --- a/presage/src/manager/registration.rs +++ b/presage/src/manager/registration.rs @@ -73,19 +73,18 @@ impl Manager { } = registration_options; // check if we are already registered - if !force && store.is_registered() { + if !force && store.is_registered().await { return Err(Error::AlreadyRegisteredError); } - store.clear_registration()?; + store.clear_registration().await?; // generate a random alphanumeric 24 chars password let mut rng = StdRng::from_entropy(); let password = Alphanumeric.sample_string(&mut rng, 24); let service_configuration: ServiceConfiguration = signal_servers.into(); - let mut push_service = - PushService::new(service_configuration, None, crate::USER_AGENT.to_string()); + let mut push_service = PushService::new(service_configuration, None, crate::USER_AGENT); trace!("creating registration verification session"); diff --git a/presage/src/model/groups.rs b/presage/src/model/groups.rs index bf5583ed5..c2b7ef62a 100644 --- a/presage/src/model/groups.rs +++ b/presage/src/model/groups.rs @@ -40,45 +40,45 @@ pub struct RequestingMember { pub timestamp: u64, } -impl Into for libsignal_service::groups_v2::Group { - fn into(self) -> Group { +impl From for Group { + fn from(val: libsignal_service::groups_v2::Group) -> Self { Group { - title: self.title, - avatar: self.avatar, - disappearing_messages_timer: self.disappearing_messages_timer, - access_control: self.access_control, - revision: self.revision, - members: self.members, - pending_members: self.pending_members.into_iter().map(Into::into).collect(), - requesting_members: self + title: val.title, + avatar: val.avatar, + disappearing_messages_timer: val.disappearing_messages_timer, + access_control: val.access_control, + revision: val.revision, + members: val.members, + pending_members: val.pending_members.into_iter().map(Into::into).collect(), + requesting_members: val .requesting_members .into_iter() .map(Into::into) .collect(), - invite_link_password: self.invite_link_password, - description: self.description, + invite_link_password: val.invite_link_password, + description: val.description, } } } -impl Into for libsignal_service::groups_v2::PendingMember { - fn into(self) -> PendingMember { +impl From for PendingMember { + fn from(val: libsignal_service::groups_v2::PendingMember) -> Self { PendingMember { - uuid: self.address.uuid, - service_id_type: self.address.identity.into(), - role: self.role, - added_by_uuid: self.added_by_uuid, - timestamp: self.timestamp, + uuid: val.address.uuid, + service_id_type: val.address.identity.into(), + role: val.role, + added_by_uuid: val.added_by_uuid, + timestamp: val.timestamp, } } } -impl Into for libsignal_service::groups_v2::RequestingMember { - fn into(self) -> RequestingMember { +impl From for RequestingMember { + fn from(val: libsignal_service::groups_v2::RequestingMember) -> Self { RequestingMember { - uuid: self.uuid, - profile_key: self.profile_key, - timestamp: self.timestamp, + uuid: val.uuid, + profile_key: val.profile_key, + timestamp: val.timestamp, } } } diff --git a/presage/src/model/mod.rs b/presage/src/model/mod.rs index 8c2587fe2..7920317b5 100644 --- a/presage/src/model/mod.rs +++ b/presage/src/model/mod.rs @@ -16,9 +16,9 @@ pub enum ServiceIdType { PhoneNumberIdentity, } -impl Into for libsignal_service::ServiceIdType { - fn into(self) -> ServiceIdType { - match self { +impl From for ServiceIdType { + fn from(val: libsignal_service::ServiceIdType) -> Self { + match val { libsignal_service::ServiceIdType::AccountIdentity => ServiceIdType::AccountIdentity, libsignal_service::ServiceIdType::PhoneNumberIdentity => { ServiceIdType::PhoneNumberIdentity diff --git a/presage/src/store.rs b/presage/src/store.rs index 7f9010942..956e95068 100644 --- a/presage/src/store.rs +++ b/presage/src/store.rs @@ -1,6 +1,6 @@ //! Traits that are used by the manager for storing the data. -use std::{fmt, ops::RangeBounds, time::SystemTime}; +use std::{fmt, future::Future, ops::RangeBounds, time::SystemTime}; use libsignal_service::{ content::{ContentBody, Metadata}, @@ -17,7 +17,7 @@ use libsignal_service::{ Profile, ServiceAddress, }; use serde::{Deserialize, Serialize}; -use tracing::{error, trace}; +use tracing::trace; use crate::{ manager::RegistrationData, @@ -33,29 +33,31 @@ pub trait StateStore { type StateStoreError: StoreError; /// Load registered (or linked) state - fn load_registration_data(&self) -> Result, Self::StateStoreError>; + fn load_registration_data( + &self, + ) -> impl Future, Self::StateStoreError>>; fn set_aci_identity_key_pair( &self, key_pair: IdentityKeyPair, - ) -> Result<(), Self::StateStoreError>; + ) -> impl Future>; fn set_pni_identity_key_pair( &self, key_pair: IdentityKeyPair, - ) -> Result<(), Self::StateStoreError>; + ) -> impl Future>; /// Save registered (or linked) state fn save_registration_data( &mut self, state: &RegistrationData, - ) -> Result<(), Self::StateStoreError>; + ) -> impl Future>; /// Returns whether this store contains registration data or not - fn is_registered(&self) -> bool; + fn is_registered(&self) -> impl Future; /// Clear registration data (including keys), but keep received messages, groups and contacts. - fn clear_registration(&mut self) -> Result<(), Self::StateStoreError>; + fn clear_registration(&mut self) -> impl Future>; } /// Stores messages, contacts, groups and profiles @@ -77,70 +79,28 @@ pub trait ContentsStore: Send + Sync { type StickerPacksIter: Iterator>; // Clear all profiles - fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError>; + fn clear_profiles(&mut self) -> impl Future>; // Clear all stored messages - fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError>; + fn clear_contents(&mut self) -> impl Future>; // Messages /// Clear all stored messages. - fn clear_messages(&mut self) -> Result<(), Self::ContentsStoreError>; + fn clear_messages(&mut self) -> impl Future>; /// Clear the messages in a thread. - fn clear_thread(&mut self, thread: &Thread) -> Result<(), Self::ContentsStoreError>; + fn clear_thread( + &mut self, + thread: &Thread, + ) -> impl Future>; /// Save a message in a [Thread] identified by a timestamp. fn save_message( &self, thread: &Thread, message: Content, - ) -> Result<(), Self::ContentsStoreError>; - - /// Saves a message that can show users when the identity of a contact has changed - /// On Signal Android, this is usually displayed as: "Your safety number with XYZ has changed." - fn save_trusted_identity_message( - &self, - protocol_address: &ProtocolAddress, - right_identity_key: IdentityKey, - verified_state: verified::State, - ) { - let Ok(sender) = protocol_address.name().parse() else { - return; - }; - - // TODO: this is a hack to save a message showing that the verification status changed - // It is possibly ok to do it like this, but rebuidling the metadata and content body feels dirty - let thread = Thread::Contact(sender); - let verified_sync_message = Content { - metadata: Metadata { - sender: ServiceAddress::from_aci(sender), - destination: ServiceAddress::from_aci(sender), - sender_device: 0, - server_guid: None, - timestamp: SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as u64, - needs_receipt: false, - unidentified_sender: false, - }, - body: SyncMessage { - verified: Some(Verified { - destination_aci: None, - identity_key: Some(right_identity_key.public_key().serialize().to_vec()), - state: Some(verified_state.into()), - null_message: None, - }), - ..Default::default() - } - .into(), - }; - - if let Err(error) = self.save_message(&thread, verified_sync_message) { - error!(%error, "failed to save the verified session message in thread"); - } - } + ) -> impl Future>; /// Delete a single message, identified by its received timestamp from a thread. /// Useful when you want to delete a message locally only. @@ -148,37 +108,41 @@ pub trait ContentsStore: Send + Sync { &mut self, thread: &Thread, timestamp: u64, - ) -> Result; + ) -> impl Future>; /// Retrieve a message from a [Thread] by its timestamp. fn message( &self, thread: &Thread, timestamp: u64, - ) -> Result, Self::ContentsStoreError>; + ) -> impl Future, Self::ContentsStoreError>>; /// Retrieve all messages from a [Thread] within a range in time fn messages( &self, thread: &Thread, range: impl RangeBounds, - ) -> Result; + ) -> impl Future>; /// Get the expire timer from a [Thread], which corresponds to either [Contact::expire_timer] /// or [Group::disappearing_messages_timer]. fn expire_timer( &self, thread: &Thread, - ) -> Result, Self::ContentsStoreError> { - match thread { - Thread::Contact(uuid) => Ok(self - .contact_by_id(uuid)? - .map(|c| (c.expire_timer, c.expire_timer_version))), - Thread::Group(key) => Ok(self - .group(*key)? - .and_then(|g| g.disappearing_messages_timer) - // TODO: most likely we can have versions here - .map(|t| (t.duration, 1))), // Groups do not have expire_timer_version + ) -> impl Future, Self::ContentsStoreError>> { + async move { + match thread { + Thread::Contact(uuid) => Ok(self + .contact_by_id(uuid) + .await? + .map(|c| (c.expire_timer, c.expire_timer_version))), + Thread::Group(key) => Ok(self + .group(*key) + .await? + .and_then(|g| g.disappearing_messages_timer) + // TODO: most likely we can have versions here + .map(|t| (t.duration, 1))), // Groups do not have expire_timer_version + } } } @@ -189,29 +153,31 @@ pub trait ContentsStore: Send + Sync { thread: &Thread, timer: u32, version: u32, - ) -> Result<(), Self::ContentsStoreError> { - trace!(%thread, timer, version, "updating expire timer"); - match thread { - Thread::Contact(uuid) => { - let contact = self.contact_by_id(uuid)?; - if let Some(mut contact) = contact { - let current_version = contact.expire_timer_version; - if version <= current_version { - return Ok(()); + ) -> impl Future> { + async move { + trace!(%thread, timer, version, "updating expire timer"); + match thread { + Thread::Contact(uuid) => { + let contact = self.contact_by_id(uuid).await?; + if let Some(mut contact) = contact { + let current_version = contact.expire_timer_version; + if version <= current_version { + return Ok(()); + } + contact.expire_timer_version = version; + contact.expire_timer = timer; + self.save_contact(&contact).await?; } - contact.expire_timer_version = version; - contact.expire_timer = timer; - self.save_contact(&contact)?; + Ok(()) } - Ok(()) - } - Thread::Group(key) => { - let group = self.group(*key)?; - if let Some(mut g) = group { - g.disappearing_messages_timer = Some(Timer { duration: timer }); - self.save_group(*key, g)?; + Thread::Group(key) => { + let group = self.group(*key).await?; + if let Some(mut g) = group { + g.disappearing_messages_timer = Some(Timer { duration: timer }); + self.save_group(*key, g).await?; + } + Ok(()) } - Ok(()) } } } @@ -219,48 +185,56 @@ pub trait ContentsStore: Send + Sync { // Contacts /// Clear all saved synchronized contact data - fn clear_contacts(&mut self) -> Result<(), Self::ContentsStoreError>; + fn clear_contacts(&mut self) -> impl Future>; /// Save a contact - fn save_contact(&mut self, contacts: &Contact) -> Result<(), Self::ContentsStoreError>; + fn save_contact( + &mut self, + contacts: &Contact, + ) -> impl Future>; /// Get an iterator on all stored (synchronized) contacts - fn contacts(&self) -> Result; + fn contacts( + &self, + ) -> impl Future>; /// Get contact data for a single user by its [Uuid]. - fn contact_by_id(&self, id: &Uuid) -> Result, Self::ContentsStoreError>; + fn contact_by_id( + &self, + id: &Uuid, + ) -> impl Future, Self::ContentsStoreError>>; /// Delete all cached group data - fn clear_groups(&mut self) -> Result<(), Self::ContentsStoreError>; + fn clear_groups(&mut self) -> impl Future>; /// Save a group in the cache fn save_group( &self, master_key: GroupMasterKeyBytes, group: impl Into, - ) -> Result<(), Self::ContentsStoreError>; + ) -> impl Future>; /// Get an iterator on all cached groups - fn groups(&self) -> Result; + fn groups(&self) -> impl Future>; /// Retrieve a single unencrypted group indexed by its `[GroupMasterKeyBytes]` fn group( &self, master_key: GroupMasterKeyBytes, - ) -> Result, Self::ContentsStoreError>; + ) -> impl Future, Self::ContentsStoreError>>; /// Save a group avatar in the cache fn save_group_avatar( &self, master_key: GroupMasterKeyBytes, avatar: &AvatarBytes, - ) -> Result<(), Self::ContentsStoreError>; + ) -> impl Future>; /// Retrieve a group avatar from the cache. fn group_avatar( &self, master_key: GroupMasterKeyBytes, - ) -> Result, Self::ContentsStoreError>; + ) -> impl Future, Self::ContentsStoreError>>; // Profiles @@ -269,10 +243,13 @@ pub trait ContentsStore: Send + Sync { &mut self, uuid: &Uuid, key: ProfileKey, - ) -> Result; + ) -> impl Future>; /// Get the profile key for a contact - fn profile_key(&self, uuid: &Uuid) -> Result, Self::ContentsStoreError>; + fn profile_key( + &self, + uuid: &Uuid, + ) -> impl Future, Self::ContentsStoreError>>; /// Save a profile by [Uuid] and [ProfileKey]. fn save_profile( @@ -280,14 +257,14 @@ pub trait ContentsStore: Send + Sync { uuid: Uuid, key: ProfileKey, profile: Profile, - ) -> Result<(), Self::ContentsStoreError>; + ) -> impl Future>; /// Retrieve a profile by [Uuid] and [ProfileKey]. fn profile( &self, uuid: Uuid, key: ProfileKey, - ) -> Result, Self::ContentsStoreError>; + ) -> impl Future, Self::ContentsStoreError>>; /// Save a profile avatar by [Uuid] and [ProfileKey]. fn save_profile_avatar( @@ -295,28 +272,39 @@ pub trait ContentsStore: Send + Sync { uuid: Uuid, key: ProfileKey, profile: &AvatarBytes, - ) -> Result<(), Self::ContentsStoreError>; + ) -> impl Future>; /// Retrieve a profile avatar by [Uuid] and [ProfileKey]. fn profile_avatar( &self, uuid: Uuid, key: ProfileKey, - ) -> Result, Self::ContentsStoreError>; + ) -> impl Future, Self::ContentsStoreError>>; /// Stickers /// Add a sticker pack - fn add_sticker_pack(&mut self, pack: &StickerPack) -> Result<(), Self::ContentsStoreError>; + fn add_sticker_pack( + &mut self, + pack: &StickerPack, + ) -> impl Future> + Send; /// Gets a cached sticker pack - fn sticker_pack(&self, id: &[u8]) -> Result, Self::ContentsStoreError>; + fn sticker_pack( + &self, + id: &[u8], + ) -> impl Future, Self::ContentsStoreError>>; /// Removes a sticker pack - fn remove_sticker_pack(&mut self, id: &[u8]) -> Result; + fn remove_sticker_pack( + &mut self, + id: &[u8], + ) -> impl Future>; /// Get an iterator on all installed stickerpacks - fn sticker_packs(&self) -> Result; + fn sticker_packs( + &self, + ) -> impl Future>; } /// The manager store trait combining all other stores into a single one @@ -335,7 +323,9 @@ pub trait Store: /// Clear the entire store /// /// This can be useful when resetting an existing client. - fn clear(&mut self) -> Result<(), ::StateStoreError>; + fn clear( + &mut self, + ) -> impl Future::StateStoreError>> + Send; fn aci_protocol_store(&self) -> Self::AciStore; @@ -531,3 +521,46 @@ impl From for Sticker { } } } + +/// Saves a message that can show users when the identity of a contact has changed +/// On Signal Android, this is usually displayed as: "Your safety number with XYZ has changed." +pub async fn save_trusted_identity_message( + store: &S, + protocol_address: &ProtocolAddress, + right_identity_key: IdentityKey, + verified_state: verified::State, +) -> Result<(), S::Error> { + let Ok(sender) = protocol_address.name().parse() else { + return Ok(()); + }; + + // TODO: this is a hack to save a message showing that the verification status changed + // It is possibly ok to do it like this, but rebuidling the metadata and content body feels dirty + let thread = Thread::Contact(sender); + let verified_sync_message = Content { + metadata: Metadata { + sender: ServiceAddress::from_aci(sender), + destination: ServiceAddress::from_aci(sender), + sender_device: 0, + server_guid: None, + timestamp: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64, + needs_receipt: false, + unidentified_sender: false, + }, + body: SyncMessage { + verified: Some(Verified { + destination_aci: None, + identity_key: Some(right_identity_key.public_key().serialize().to_vec()), + state: Some(verified_state.into()), + null_message: None, + }), + ..Default::default() + } + .into(), + }; + + store.save_message(&thread, verified_sync_message).await +}