@@ -7,6 +7,7 @@ use core_crypto_keystore::{
77 connection:: FetchFromDatabase ,
88 entities:: { EntityFindParams , StoredEncryptionKeyPair , StoredHpkePrivateKey , StoredKeypackage } ,
99} ;
10+ use futures_util:: { StreamExt , TryStreamExt , stream:: FuturesUnordered } ;
1011use mls_crypto_provider:: { Database , MlsCryptoProvider } ;
1112use openmls:: prelude:: {
1213 Credential as MlsCredential , CredentialWithKey , CryptoConfig , KeyPackage , KeyPackageRef , Lifetime ,
@@ -31,6 +32,23 @@ pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;
3132pub const KEYPACKAGE_DEFAULT_LIFETIME : Duration = Duration :: from_secs ( 60 * 60 * 24 * 28 * 3 ) ; // ~3 months
3233
3334impl Session {
35+ /// Get an unambiguous credential from the provided ref.
36+ async fn credential_from_ref ( & self , credential_ref : & CredentialRef ) -> Result < Credential > {
37+ let mut credentials = credential_ref
38+ . load ( & self . crypto_provider . keystore ( ) )
39+ . await
40+ . map_err ( RecursiveError :: mls_credential_ref ( "loading credential from reference" ) ) ?;
41+ let credential = credentials. pop ( ) . ok_or ( Error :: CredentialNotFound (
42+ credential_ref. r#type ( ) ,
43+ credential_ref. signature_scheme ( ) ,
44+ ) ) ?;
45+ if !credentials. is_empty ( ) {
46+ return Err ( Error :: AmbiguousCredentialRef ) ;
47+ }
48+
49+ Ok ( credential)
50+ }
51+
3452 /// Generate a [KeyPackage] from the referenced credential.
3553 ///
3654 /// Makes no attempt to look up or prune existing keypackges.
@@ -47,18 +65,7 @@ impl Session {
4765 lifetime : Option < Duration > ,
4866 ) -> Result < KeyPackage > {
4967 let lifetime = Lifetime :: new ( lifetime. unwrap_or ( KEYPACKAGE_DEFAULT_LIFETIME ) . as_secs ( ) ) ;
50-
51- let mut credentials = credential_ref
52- . load ( & self . crypto_provider . keystore ( ) )
53- . await
54- . map_err ( RecursiveError :: mls_credential_ref ( "loading credential from reference" ) ) ?;
55- let credential = credentials. pop ( ) . ok_or ( Error :: CredentialNotFound (
56- credential_ref. r#type ( ) ,
57- credential_ref. signature_scheme ( ) ,
58- ) ) ?;
59- if !credentials. is_empty ( ) {
60- return Err ( Error :: AmbiguousCredentialRef ) ;
61- }
68+ let credential = self . credential_from_ref ( credential_ref) . await ?;
6269
6370 let config = CryptoConfig {
6471 ciphersuite : credential. ciphersuite . into ( ) ,
@@ -79,7 +86,7 @@ impl Session {
7986 }
8087
8188 /// Get all [`KeyPackageRef`]s in the database.
82- pub async fn get_keypackages ( & self ) -> Result < Vec < KeyPackageRef > > {
89+ pub async fn get_keypackage_refs ( & self ) -> Result < Vec < KeyPackageRef > > {
8390 let stored_keypackages: Vec < StoredKeypackage > = self
8491 . crypto_provider
8592 . keystore ( )
@@ -108,6 +115,29 @@ impl Session {
108115 . map_err ( Into :: into)
109116 }
110117
118+ /// Get all [`KeyPackage`]s in the database.
119+ ///
120+ /// This is moderately complicated because mapping with an asynchronous function is intrinsically
121+ /// a bit complicated, unfortunately.
122+ pub ( crate ) async fn get_keypackages ( & self ) -> Result < Vec < KeyPackage > > {
123+ let keypackage_refs = self . get_keypackage_refs ( ) . await ?;
124+ let keypackages = keypackage_refs
125+ . iter ( )
126+ . map ( |kp_ref| self . load_keypackage ( kp_ref) )
127+ . collect :: < FuturesUnordered < _ > > ( )
128+ // if any ref from loading all fails to load now, skip it
129+ // strictly we could panic, but this is safer--maybe someone removed it concurrently
130+ . filter_map ( async |kp| kp. transpose ( ) )
131+ // it is weirdly difficult to get the "collect into a result" behavior we're accustomed to in sync-land
132+ // when what we have is a bunch of futures, but this seems to accomplish that
133+ . try_fold ( Vec :: new ( ) , async |mut acc, keypackage| {
134+ acc. push ( keypackage) ;
135+ Ok ( acc)
136+ } )
137+ . await ?;
138+ Ok ( keypackages)
139+ }
140+
111141 /// Remove one [`KeyPackage`] from the database.
112142 ///
113143 /// Succeeds silently if the keypackage does not exist in the database.
@@ -133,6 +163,52 @@ impl Session {
133163 Ok ( ( ) )
134164 }
135165
166+ /// Remove all keypackages associated with this credential.
167+ ///
168+ /// This is fairly expensive as it must first load all keypackages, then delete those matching the credential.
169+ ///
170+ /// Implementation note: once it makes it as far as having a list of keypackages, does _not_ short-circuit
171+ /// if removing one returns an error. In that case, only the first produced error is returned.
172+ /// This helps ensure that as many keypackages for the given credential ref are removed as possible.
173+ pub async fn remove_keypackages_for ( & self , credential_ref : & CredentialRef ) -> Result < ( ) > {
174+ let credential = self . credential_from_ref ( credential_ref) . await ?;
175+ let mls_credential = credential. mls_credential ( ) ;
176+
177+ let mut first_err = None ;
178+ macro_rules! try_retain_err {
179+ ( $e: expr) => {
180+ match $e {
181+ Err ( err) => {
182+ if first_err. is_none( ) {
183+ first_err = Some ( Error :: from( err) ) ;
184+ }
185+ continue ;
186+ }
187+ Ok ( val) => val,
188+ }
189+ } ;
190+ }
191+
192+ for keypackage in self
193+ . get_keypackages ( )
194+ . await ?
195+ . into_iter ( )
196+ . filter ( |keypackage| keypackage. leaf_node ( ) . credential ( ) == mls_credential)
197+ {
198+ let kp_ref = try_retain_err ! (
199+ keypackage
200+ . hash_ref( self . crypto_provider. crypto( ) )
201+ . map_err( MlsError :: wrap( "getting keypackage ref in remove_keypackages_for" ) )
202+ ) ;
203+ try_retain_err ! ( self . remove_keypackage( & kp_ref) . await ) ;
204+ }
205+
206+ match first_err {
207+ None => Ok ( ( ) ) ,
208+ Some ( err) => Err ( err) ,
209+ }
210+ }
211+
136212 /// Generates a single new keypackage
137213 ///
138214 /// # Arguments
0 commit comments