Skip to content

Commit c6e7270

Browse files
committed
feat: add TransactionContext::remove_keypackages_for
1 parent 5122909 commit c6e7270

File tree

3 files changed

+105
-17
lines changed

3 files changed

+105
-17
lines changed

crypto/src/mls/credential/credential_ref/find.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ use core_crypto_keystore::{
33
entities::{EntityFindParams, StoredCredential},
44
};
55
use mls_crypto_provider::Database;
6-
use openmls::prelude::{Credential as MlsCredential};
6+
use openmls::prelude::Credential as MlsCredential;
77
use tls_codec::Deserialize as _;
88

99
use super::{Error, Result};
10-
use crate::{ClientId, Ciphersuite, CredentialRef, CredentialType, KeystoreError, mls::session::id::ClientIdRef};
10+
use crate::{Ciphersuite, ClientId, CredentialRef, CredentialType, KeystoreError, mls::session::id::ClientIdRef};
1111

1212
/// Filters to narrow down the set of credentials returned from various credential-finding methods.
1313
///

crypto/src/mls/session/key_package.rs

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use core_crypto_keystore::{
77
connection::FetchFromDatabase,
88
entities::{EntityFindParams, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage},
99
};
10+
use futures_util::{StreamExt, TryStreamExt, stream::FuturesUnordered};
1011
use mls_crypto_provider::{Database, MlsCryptoProvider};
1112
use openmls::prelude::{
1213
Credential as MlsCredential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime,
@@ -31,6 +32,23 @@ pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;
3132
pub const KEYPACKAGE_DEFAULT_LIFETIME: Duration = Duration::from_secs(60 * 60 * 24 * 28 * 3); // ~3 months
3233

3334
impl Session {
35+
/// Get an unambiguous credential from the provided ref.
36+
async fn credential_from_ref(&self, credential_ref: &CredentialRef) -> Result<Credential> {
37+
let mut credentials = credential_ref
38+
.load(&self.crypto_provider.keystore())
39+
.await
40+
.map_err(RecursiveError::mls_credential_ref("loading credential from reference"))?;
41+
let credential = credentials.pop().ok_or(Error::CredentialNotFound(
42+
credential_ref.r#type(),
43+
credential_ref.signature_scheme(),
44+
))?;
45+
if !credentials.is_empty() {
46+
return Err(Error::AmbiguousCredentialRef);
47+
}
48+
49+
Ok(credential)
50+
}
51+
3452
/// Generate a [KeyPackage] from the referenced credential.
3553
///
3654
/// Makes no attempt to look up or prune existing keypackges.
@@ -47,18 +65,7 @@ impl Session {
4765
lifetime: Option<Duration>,
4866
) -> Result<KeyPackage> {
4967
let lifetime = Lifetime::new(lifetime.unwrap_or(KEYPACKAGE_DEFAULT_LIFETIME).as_secs());
50-
51-
let mut credentials = credential_ref
52-
.load(&self.crypto_provider.keystore())
53-
.await
54-
.map_err(RecursiveError::mls_credential_ref("loading credential from reference"))?;
55-
let credential = credentials.pop().ok_or(Error::CredentialNotFound(
56-
credential_ref.r#type(),
57-
credential_ref.signature_scheme(),
58-
))?;
59-
if !credentials.is_empty() {
60-
return Err(Error::AmbiguousCredentialRef);
61-
}
68+
let credential = self.credential_from_ref(credential_ref).await?;
6269

6370
let config = CryptoConfig {
6471
ciphersuite: credential.ciphersuite.into(),
@@ -79,7 +86,7 @@ impl Session {
7986
}
8087

8188
/// Get all [`KeyPackageRef`]s in the database.
82-
pub async fn get_keypackages(&self) -> Result<Vec<KeyPackageRef>> {
89+
pub async fn get_keypackage_refs(&self) -> Result<Vec<KeyPackageRef>> {
8390
let stored_keypackages: Vec<StoredKeypackage> = self
8491
.crypto_provider
8592
.keystore()
@@ -108,6 +115,29 @@ impl Session {
108115
.map_err(Into::into)
109116
}
110117

118+
/// Get all [`KeyPackage`]s in the database.
119+
///
120+
/// This is moderately complicated because mapping with an asynchronous function is intrinsically
121+
/// a bit complicated, unfortunately.
122+
pub(crate) async fn get_keypackages(&self) -> Result<Vec<KeyPackage>> {
123+
let keypackage_refs = self.get_keypackage_refs().await?;
124+
let keypackages = keypackage_refs
125+
.iter()
126+
.map(|kp_ref| self.load_keypackage(kp_ref))
127+
.collect::<FuturesUnordered<_>>()
128+
// if any ref from loading all fails to load now, skip it
129+
// strictly we could panic, but this is safer--maybe someone removed it concurrently
130+
.filter_map(async |kp| kp.transpose())
131+
// it is weirdly difficult to get the "collect into a result" behavior we're accustomed to in sync-land
132+
// when what we have is a bunch of futures, but this seems to accomplish that
133+
.try_fold(Vec::new(), async |mut acc, keypackage| {
134+
acc.push(keypackage);
135+
Ok(acc)
136+
})
137+
.await?;
138+
Ok(keypackages)
139+
}
140+
111141
/// Remove one [`KeyPackage`] from the database.
112142
///
113143
/// Succeeds silently if the keypackage does not exist in the database.
@@ -133,6 +163,52 @@ impl Session {
133163
Ok(())
134164
}
135165

166+
/// Remove all keypackages associated with this credential.
167+
///
168+
/// This is fairly expensive as it must first load all keypackages, then delete those matching the credential.
169+
///
170+
/// Implementation note: once it makes it as far as having a list of keypackages, does _not_ short-circuit
171+
/// if removing one returns an error. In that case, only the first produced error is returned.
172+
/// This helps ensure that as many keypackages for the given credential ref are removed as possible.
173+
pub async fn remove_keypackages_for(&self, credential_ref: &CredentialRef) -> Result<()> {
174+
let credential = self.credential_from_ref(credential_ref).await?;
175+
let mls_credential = credential.mls_credential();
176+
177+
let mut first_err = None;
178+
macro_rules! try_retain_err {
179+
($e:expr) => {
180+
match $e {
181+
Err(err) => {
182+
if first_err.is_none() {
183+
first_err = Some(Error::from(err));
184+
}
185+
continue;
186+
}
187+
Ok(val) => val,
188+
}
189+
};
190+
}
191+
192+
for keypackage in self
193+
.get_keypackages()
194+
.await?
195+
.into_iter()
196+
.filter(|keypackage| keypackage.leaf_node().credential() == mls_credential)
197+
{
198+
let kp_ref = try_retain_err!(
199+
keypackage
200+
.hash_ref(self.crypto_provider.crypto())
201+
.map_err(MlsError::wrap("getting keypackage ref in remove_keypackages_for"))
202+
);
203+
try_retain_err!(self.remove_keypackage(&kp_ref).await);
204+
}
205+
206+
match first_err {
207+
None => Ok(()),
208+
Some(err) => Err(err),
209+
}
210+
}
211+
136212
/// Generates a single new keypackage
137213
///
138214
/// # Arguments

crypto/src/transaction_context/key_package.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ impl TransactionContext {
8686
}
8787

8888
/// Get all [`KeyPackageRef`]s known to the keystore.
89-
pub async fn get_keypackages(&self) -> Result<Vec<KeyPackageRef>> {
89+
pub async fn get_keypackage_refs(&self) -> Result<Vec<KeyPackageRef>> {
9090
let session = self.session().await?;
9191
session
92-
.get_keypackages()
92+
.get_keypackage_refs()
9393
.await
9494
.map_err(RecursiveError::mls_client(
9595
"getting all key package refs for transaction",
@@ -106,4 +106,16 @@ impl TransactionContext {
106106
.map_err(RecursiveError::mls_client("removing a keypackage for transaction"))
107107
.map_err(Into::into)
108108
}
109+
110+
/// Remove all [`KeyPackage`]s associated with this ref.
111+
pub async fn remove_keypackages_for(&self, credential_ref: &CredentialRef) -> Result<()> {
112+
let session = self.session().await?;
113+
session
114+
.remove_keypackages_for(credential_ref)
115+
.await
116+
.map_err(RecursiveError::mls_client(
117+
"removing all keypackages for credential ref for transaction",
118+
))
119+
.map_err(Into::into)
120+
}
109121
}

0 commit comments

Comments
 (0)