Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the get_data_or_default interface. #64

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/logical/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,15 @@ mod test {
},
{op: Operation::Write, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result<Option<Response>, RvError> {
let array_val = req.get_data("array")?;
let array_default_val = req.get_data("array_default")?;
let array_default_val = req.get_data_or_default("array_default")?;
let bool_val = req.get_data("bool")?;
let bool_default_val = req.get_data("bool_default")?;
let bool_default_val = req.get_data_or_default("bool_default")?;
let comma_val = req.get_data("comma")?;
let comma_default_val = req.get_data("comma_default")?;
let comma_default_val = req.get_data_or_default("comma_default")?;
let map_val = req.get_data("map")?;
let map_default_val = req.get_data("map_default")?;
let map_default_val = req.get_data_or_default("map_default")?;
let duration_val = req.get_data("duration")?;
let duration_default_val = req.get_data("duration_default")?;
let duration_default_val = req.get_data_or_default("duration_default")?;
let data = json!({
"array": array_val,
"array_default": array_default_val,
Expand Down
69 changes: 57 additions & 12 deletions src/logical/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,7 @@ impl Request {
Self { operation: Operation::Renew, path: path.to_string(), auth, data, ..Default::default() }
}

pub fn get_data(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

fn get_data_raw(&self, key: &str, default: bool) -> Result<Value, RvError> {
let field = self.match_path.as_ref().unwrap().get_field(key);
if field.is_none() {
return Err(RvError::ErrRequestNoDataField);
Expand All @@ -95,11 +87,64 @@ impl Request {
}
}

if field.required {
return Err(RvError::ErrRequestFieldNotFound);
if default {
if field.required {
return Err(RvError::ErrRequestFieldNotFound);
}

return field.get_default();
}

return Err(RvError::ErrRequestFieldNotFound);
}

pub fn get_data(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

self.get_data_raw(key, false)
}

pub fn get_data_or_default(&self, key: &str) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

self.get_data_raw(key, true)
}

pub fn get_data_or_next(&self, keys: &[&str]) -> Result<Value, RvError> {
if self.storage.is_none() || self.match_path.is_none() {
return Err(RvError::ErrRequestNotReady);
}

if self.data.is_none() && self.body.is_none() {
return Err(RvError::ErrRequestNoData);
}

for &key in keys.iter() {
match self.get_data_raw(key, false) {
Ok(raw) => {
return Ok(raw);
},
Err(e) => {
if e != RvError::ErrRequestFieldNotFound {
return Err(e);
}
}
}
}

return field.get_default();
return Err(RvError::ErrRequestFieldNotFound);
}

//TODO: the sensitive data is still in the memory. Need to totally resolve this in `serde_json` someday.
Expand Down
36 changes: 17 additions & 19 deletions src/modules/credential/userpass/path_users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl UserPassBackendInner {

pub fn read_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap().to_lowercase();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase();

let entry = self.get_user(req, &username)?;
if entry.is_none() {
Expand All @@ -165,43 +165,41 @@ impl UserPassBackendInner {

pub fn write_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_lowercase();

let mut user_entry = UserEntry::default();

let entry = self.get_user(req, username)?;
if entry.is_some() {
user_entry = entry.unwrap();
if let Some(entry) = self.get_user(req, &username)? {
user_entry = entry;
}

let password_value = req.get_data("password")?;
let password = password_value.as_str().unwrap();
if password != "" {
let password_hash = self.gen_password_hash(password)?;

user_entry.password_hash = password_hash;
if let Ok(password_value) = req.get_data("password") {
let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if password != "" {
user_entry.password_hash = self.gen_password_hash(password)?;
}
}

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_u64().unwrap();
let ttl_value = req.get_data_or_default("ttl")?;
let ttl = ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl > 0 {
user_entry.ttl = Duration::from_secs(ttl);
}

let max_ttl_value = req.get_data("max_ttl")?;
let max_ttl = max_ttl_value.as_u64().unwrap();
let max_ttl_value = req.get_data_or_default("max_ttl")?;
let max_ttl = max_ttl_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
if max_ttl > 0 {
user_entry.max_ttl = Duration::from_secs(max_ttl);
}

self.set_user(req, username, &user_entry)?;
self.set_user(req, &username, &user_entry)?;

Ok(None)
}

pub fn delete_user(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if username == "" {
return Err(RvError::ErrRequestNoDataField);
}
Expand All @@ -218,7 +216,7 @@ impl UserPassBackendInner {

pub fn write_user_password(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let username_value = req.get_data("username")?;
let username = username_value.as_str().unwrap();
let username = username_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut user_entry = UserEntry::default();

Expand All @@ -228,7 +226,7 @@ impl UserPassBackendInner {
}

let password_value = req.get_data("password")?;
let password = password_value.as_str().unwrap();
let password = password_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let password_hash = self.gen_password_hash(password)?;

Expand Down
1 change: 0 additions & 1 deletion src/modules/pki/path_fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ Using "ca" or "crl" as the value fetches the appropriate information in DER enco
fields: {
"serial": {
field_type: FieldType::Str,
default: "72h",
description: "Certificate serial number, in colon- or hyphen-separated octal"
}
},
Expand Down
21 changes: 7 additions & 14 deletions src/modules/pki/path_issue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,16 @@ requested common name is allowed by the role policy.

impl PkiBackendInner {
pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
//let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut common_names = Vec::new();

let common_name_value = req.get_data("common_name")?;
let common_name_value = req.get_data_or_default("common_name")?;
let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if common_name != "" {
common_names.push(common_name.to_string());
}

let alt_names_value = req.get_data("alt_names");
if alt_names_value.is_ok() {
let alt_names_val = alt_names_value.unwrap();
let alt_names = alt_names_val.as_str().unwrap();
if let Ok(alt_names_value) = req.get_data("alt_names") {
let alt_names = alt_names_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if alt_names != "" {
for v in alt_names.split(',') {
common_names.push(v.to_string());
Expand All @@ -94,10 +90,8 @@ impl PkiBackendInner {
let role_entry = role.unwrap();

let mut ip_sans = Vec::new();
let ip_sans_value = req.get_data("ip_sans");
if ip_sans_value.is_ok() {
let ip_sans_val = ip_sans_value.unwrap();
let ip_sans_str = ip_sans_val.as_str().unwrap();
if let Ok(ip_sans_value) = req.get_data("ip_sans") {
let ip_sans_str = ip_sans_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ip_sans_str != "" {
for v in ip_sans_str.split(',') {
ip_sans.push(v.to_string());
Expand All @@ -109,9 +103,8 @@ impl PkiBackendInner {
let not_before = SystemTime::now() - Duration::from_secs(10);
let mut not_after = not_before + parse_duration("30d").unwrap();

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl != "" {
if let Ok(ttl_value) = req.get_data("ttl") {
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let ttl_dur = parse_duration(ttl)?;
let req_ttl_not_after_dur = SystemTime::now() + ttl_dur;
let req_ttl_not_after =
Expand Down
20 changes: 11 additions & 9 deletions src/modules/pki/path_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ impl PkiBackend {
pattern: r"keys/generate/(exported|internal)",
fields: {
"key_name": {
required: true,
field_type: FieldType::Str,
description: "key name"
},
"key_bits": {
required: true,
field_type: FieldType::Int,
default: 0,
description: r#"
The number of bits to use. Allowed values are 0 (universal default); with rsa
key_type: 2048 (default), 3072, or 4096; with ec key_type: 224, 256 (default),
Expand Down Expand Up @@ -213,9 +214,10 @@ impl PkiBackendInner {
pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type_value = req.get_data_or_default("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits_value = req.get_data_or_default("key_bits")?;
let key_bits = key_bits_value.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut export_private_key = false;
if req.path.ends_with("/exported") {
Expand Down Expand Up @@ -266,11 +268,11 @@ impl PkiBackendInner {
pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type_value = req.get_data_or_default("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle_value = req.get_data_or_default("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let hex_bundle_value = req.get_data("hex_bundle")?;
let hex_bundle_value = req.get_data_or_default("hex_bundle")?;
let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

if pem_bundle.len() == 0 && hex_bundle.len() == 0 {
Expand Down Expand Up @@ -310,7 +312,7 @@ impl PkiBackendInner {
return Err(RvError::ErrPkiKeyBitsInvalid);
}
};
let iv_value = req.get_data("iv")?;
let iv_value = req.get_data_or_default("iv")?;
let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm");
#[cfg(tongsuo)]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm");
Expand Down Expand Up @@ -391,7 +393,7 @@ impl PkiBackendInner {
pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let data_value = req.get_data("data")?;
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad_value = req.get_data_or_default("aad")?;
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
Expand All @@ -412,7 +414,7 @@ impl PkiBackendInner {
pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let data_value = req.get_data("data")?;
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad_value = req.get_data_or_default("aad")?;
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
Expand Down
Loading
Loading