diff --git a/src/errors.rs b/src/errors.rs index fc3d01c..94447c6 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -141,6 +141,8 @@ pub enum RvError { ErrPkiCertNotFound, #[error("PKI role is not found.")] ErrPkiRoleNotFound, + #[error("PKI data is invalid.")] + ErrPkiDataInvalid, #[error("PKI internal error.")] ErrPkiInternal, #[error("Credentail is invalid.")] @@ -323,6 +325,7 @@ impl PartialEq for RvError { | (RvError::ErrPkiKeyOperationInvalid, RvError::ErrPkiKeyOperationInvalid) | (RvError::ErrPkiCertNotFound, RvError::ErrPkiCertNotFound) | (RvError::ErrPkiRoleNotFound, RvError::ErrPkiRoleNotFound) + | (RvError::ErrPkiDataInvalid, RvError::ErrPkiDataInvalid) | (RvError::ErrPkiInternal, RvError::ErrPkiInternal) | (RvError::ErrCredentailInvalid, RvError::ErrCredentailInvalid) | (RvError::ErrCredentailNotConfig, RvError::ErrCredentailNotConfig) diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index 1a77faa..a2304b6 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -9,7 +9,7 @@ use crate::{ logical::{Backend, Field, FieldType, Operation, Path, PathOperation, Request, Response}, new_fields, new_fields_internal, new_path, new_path_internal, storage::StorageEntry, - utils::key::KeyBundle, + utils::key::{KeyBundle, EncryptExtraData}, }; const PKI_CONFIG_KEY_PREFIX: &str = "config/key/"; @@ -398,7 +398,7 @@ impl PkiBackendInner { let key_bundle = self.fetch_key(req, key_name)?; let decoded_data = hex::decode(data.as_bytes())?; - let result = key_bundle.encrypt(&decoded_data, Some(aad.as_bytes()))?; + let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; let resp_data = json!({ "result": hex::encode(&result), @@ -421,7 +421,7 @@ impl PkiBackendInner { let key_bundle = self.fetch_key(req, key_name)?; let decoded_data = hex::decode(data.as_bytes())?; - let result = key_bundle.decrypt(&decoded_data, Some(aad.as_bytes()))?; + let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?; let resp_data = json!({ "result": hex::encode(&result), diff --git a/src/utils/key.rs b/src/utils/key.rs index 34eb39d..e6f8b97 100644 --- a/src/utils/key.rs +++ b/src/utils/key.rs @@ -23,6 +23,12 @@ pub struct KeyBundle { pub bits: u32, } +#[derive(Debug, Clone)] +pub enum EncryptExtraData<'a> { + Aad(&'a [u8]), + Flag(bool), +} + impl Default for KeyBundle { fn default() -> Self { KeyBundle { @@ -138,9 +144,13 @@ impl KeyBundle { } } - pub fn encrypt(&self, data: &[u8], aad: Option<&[u8]>) -> Result, RvError> { + pub fn encrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { match self.key_type.as_str() { "aes-gcm" => { + let aad = extra.map_or("".as_bytes(), |ex| match ex { + EncryptExtraData::Aad(aad) => aad, + _ => "".as_bytes(), + }); let cipher = match self.bits { 128 => Cipher::aes_128_gcm(), 192 => Cipher::aes_192_gcm(), @@ -151,7 +161,7 @@ impl KeyBundle { }; let mut tag = vec![0u8; 16]; let mut ciphertext = - encrypt_aead(cipher, &self.key, Some(&self.iv), aad.unwrap_or("".as_bytes()), data, &mut tag)?; + encrypt_aead(cipher, &self.key, Some(&self.iv), aad, data, &mut tag)?; ciphertext.extend_from_slice(&tag); Ok(ciphertext) } @@ -181,13 +191,17 @@ impl KeyBundle { } "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; - if data.len() >= rsa.size() as usize { + if data.len() > rsa.size() as usize { return Err(RvError::ErrPkiInternal); } let mut buf: Vec = vec![0; rsa.size() as usize]; - if aad.unwrap_or("0".as_bytes())[0] != b'1' { + let flag = extra.map_or(false, |ex| match ex { + EncryptExtraData::Flag(flag) => flag, + _ => false, + }); + if !flag { let _ = rsa.private_encrypt(data, &mut buf, Padding::PKCS1)?; } else { let _ = rsa.public_encrypt(data, &mut buf, Padding::PKCS1)?; @@ -201,9 +215,13 @@ impl KeyBundle { } } - pub fn decrypt(&self, data: &[u8], aad: Option<&[u8]>) -> Result, RvError> { + pub fn decrypt(&self, data: &[u8], extra: Option) -> Result, RvError> { match self.key_type.as_str() { "aes-gcm" => { + let aad = extra.map_or("".as_bytes(), |ex| match ex { + EncryptExtraData::Aad(aad) => aad, + _ => "".as_bytes(), + }); let cipher = match self.bits { 128 => Cipher::aes_128_gcm(), 192 => Cipher::aes_192_gcm(), @@ -213,7 +231,7 @@ impl KeyBundle { } }; let (ciphertext, tag) = data.split_at(data.len() - 16); - Ok(decrypt_aead(cipher, &self.key, Some(&self.iv), aad.unwrap_or("".as_bytes()), ciphertext, tag)?) + Ok(decrypt_aead(cipher, &self.key, Some(&self.iv), aad, ciphertext, tag)?) } "aes-cbc" => { let cipher = match self.bits { @@ -242,12 +260,16 @@ impl KeyBundle { "rsa" => { let rsa = Rsa::private_key_from_pem(&self.key)?; if data.len() > rsa.size() as usize { - return Err(RvError::ErrPkiInternal); + return Err(RvError::ErrPkiDataInvalid); } let mut buf: Vec = vec![0; rsa.size() as usize]; - if aad.unwrap_or("0".as_bytes())[0] != b'1' { + let flag = extra.map_or(false, |ex| match ex { + EncryptExtraData::Flag(flag) => flag, + _ => false, + }); + if !flag { let rsa_pub_der = rsa.public_key_to_der()?; let rsa_pub = Rsa::public_key_from_der(&rsa_pub_der)?; let _ = rsa_pub.public_decrypt(data, &mut buf, Padding::PKCS1)?; @@ -284,13 +306,13 @@ mod test { assert!(verify.unwrap()); } - fn test_key_encrypt_decrypt(key_bundle: &mut KeyBundle, aad: Option<&[u8]>) { + fn test_key_encrypt_decrypt(key_bundle: &mut KeyBundle, extra: Option) { assert!(key_bundle.generate().is_ok()); let data = "123456789"; - let result = key_bundle.encrypt(data.as_bytes(), aad); + let result = key_bundle.encrypt(data.as_bytes(), extra.clone()); assert!(result.is_ok()); let encrypted_data = result.unwrap(); - let result = key_bundle.decrypt(&encrypted_data, aad); + let result = key_bundle.decrypt(&encrypted_data, extra); assert!(result.is_ok()); let decrypted_data = result.unwrap(); assert_eq!(std::str::from_utf8(&decrypted_data).unwrap(), data); @@ -301,17 +323,17 @@ mod test { let mut key_bundle = KeyBundle::new("rsa-2048", "rsa", 2048); test_key_sign_verify(&mut key_bundle); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("1".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Flag(true))); let mut key_bundle = KeyBundle::new("rsa-3072", "rsa", 3072); test_key_sign_verify(&mut key_bundle); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("1".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Flag(true))); let mut key_bundle = KeyBundle::new("rsa-4096", "rsa", 4096); test_key_sign_verify(&mut key_bundle); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("1".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Flag(true))); } #[test] @@ -332,16 +354,16 @@ mod test { let mut key_bundle = KeyBundle::new("aes-gcm-128", "aes-gcm", 128); test_key_encrypt_decrypt(&mut key_bundle, None); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("rusty_vault".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); let mut key_bundle = KeyBundle::new("aes-gcm-192", "aes-gcm", 192); test_key_encrypt_decrypt(&mut key_bundle, None); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("rusty_vault".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); let mut key_bundle = KeyBundle::new("aes-gcm-256", "aes-gcm", 256); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("rusty_vault".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); test_key_encrypt_decrypt(&mut key_bundle, None); - test_key_encrypt_decrypt(&mut key_bundle, Some("rusty_vault".as_bytes())); + test_key_encrypt_decrypt(&mut key_bundle, Some(EncryptExtraData::Aad("rusty_vault".as_bytes()))); // test aes-cbc let mut key_bundle = KeyBundle::new("aes-cbc-128", "aes-cbc", 128);