From 63d60965d80a150769dd4a1e6e0d9979463b9af2 Mon Sep 17 00:00:00 2001 From: Jin Jiu Date: Sat, 6 Apr 2024 22:57:57 +0800 Subject: [PATCH] Enhancements to RvError and optimization of the logical field type. --- Cargo.toml | 5 + src/errors.rs | 19 +- src/http/logical.rs | 10 +- src/http/mod.rs | 13 +- src/lib.rs | 1 - src/logical/auth.rs | 2 + src/logical/backend.rs | 284 ++++++++++++++++++++++++++++- src/logical/connection.rs | 14 +- src/logical/field.rs | 313 +++++++++++++++++++++++++++++--- src/logical/mod.rs | 1 + src/logical/path.rs | 7 +- src/logical/request.rs | 9 +- src/modules/auth/token_store.rs | 1 + src/modules/credential/mod.rs | 1 + src/modules/pki/field.rs | 18 +- src/modules/pki/mod.rs | 2 +- src/modules/pki/path_keys.rs | 93 +++++----- src/modules/system/mod.rs | 4 +- src/utils/mod.rs | 11 ++ 19 files changed, 709 insertions(+), 99 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8a36d6b..298044f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,11 @@ diesel = { version = "2.1.4", features = ["mysql", "r2d2"] } r2d2 = "0.8.9" r2d2-diesel = "1.0.0" bcrypt = "0.15" +url = "2.5" +ureq = "2.9" +glob = "0.3" +serde_asn1_der = "0.8" +base64 = "0.22" [target.'cfg(unix)'.dependencies] daemonize = "0.5" diff --git a/src/errors.rs b/src/errors.rs index 446e0ff..077acb8 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -89,6 +89,8 @@ pub enum RvError { ErrRequestClientTokenMissing, #[error("Request field is not found.")] ErrRequestFieldNotFound, + #[error("Request field is invalid.")] + ErrRequestFieldInvalid, #[error("Handler is default.")] ErrHandlerDefault, #[error("Module kv data field is missing.")] @@ -143,6 +145,8 @@ pub enum RvError { ErrPkiInternal, #[error("Credentail is invalid.")] ErrCredentailInvalid, + #[error("Credentail is not config.")] + ErrCredentailNotConfig, #[error("Some IO error happened, {:?}", .source)] IO { #[from] @@ -208,11 +212,16 @@ pub enum RvError { #[from] source: bcrypt::BcryptError, }, + #[error("Some ureq error happened, {:?}", .source)] + UreqError { + #[from] + source: ureq::Error, + }, #[error("RwLock was poisoned (reading)")] ErrRwLockReadPoison, #[error("RwLock was poisoned (writing)")] ErrRwLockWritePoison, - + /// Database Errors Begin /// #[error("Database type is not support now. Please try postgressql or mysql again.")] @@ -234,6 +243,12 @@ pub enum RvError { #[error(transparent)] ErrOther(#[from] anyhow::Error), + #[error("Some error happend, text: {0}")] + ErrText(String), + #[error("Some error happend, response text: {0}")] + ErrResponse(String), + #[error("Some error happend, status: {0}, response text: {1}")] + ErrResponseStatus(u16, String), #[error("Unknown error.")] ErrUnknown, } @@ -278,6 +293,7 @@ impl PartialEq for RvError { | (RvError::ErrRequestInvalid, RvError::ErrRequestInvalid) | (RvError::ErrRequestClientTokenMissing, RvError::ErrRequestClientTokenMissing) | (RvError::ErrRequestFieldNotFound, RvError::ErrRequestFieldNotFound) + | (RvError::ErrRequestFieldInvalid, RvError::ErrRequestFieldInvalid) | (RvError::ErrHandlerDefault, RvError::ErrHandlerDefault) | (RvError::ErrModuleKvDataFieldMissing, RvError::ErrModuleKvDataFieldMissing) | (RvError::ErrRustDowncastFailed, RvError::ErrRustDowncastFailed) @@ -311,6 +327,7 @@ impl PartialEq for RvError { | (RvError::ErrPkiRoleNotFound, RvError::ErrPkiRoleNotFound) | (RvError::ErrPkiInternal, RvError::ErrPkiInternal) | (RvError::ErrCredentailInvalid, RvError::ErrCredentailInvalid) + | (RvError::ErrCredentailNotConfig, RvError::ErrCredentailNotConfig) | (RvError::ErrUnknown, RvError::ErrUnknown) => true, _ => false, } diff --git a/src/http/logical.rs b/src/http/logical.rs index 92c2cd8..71763a8 100644 --- a/src/http/logical.rs +++ b/src/http/logical.rs @@ -18,7 +18,7 @@ use crate::{ core::Core, errors::RvError, http::{request_auth, response_error, response_json_ok, response_ok, Connection}, - logical::{Operation, Response}, + logical::{Operation, Connection as ReqConnection, Response}, }; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -55,8 +55,16 @@ async fn logical_request_handler( let conn = req.conn_data::().unwrap(); log::debug!("logical request, connection info: {:?}, method: {:?}, path: {:?}", conn, method, path); + let mut req_conn = ReqConnection::default(); + req_conn.peer_addr = conn.peer.to_string(); + if conn.tls.is_some() { + let tls_client_info = conn.tls.as_ref().unwrap(); + req_conn.peer_tls_cert = tls_client_info.client_cert_chain.clone(); + } + let mut r = request_auth(&req); r.path = path.into_inner().clone(); + r.connection = Some(req_conn); match method { Method::GET => { diff --git a/src/http/mod.rs b/src/http/mod.rs index a9626ee..234088e 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -90,8 +90,17 @@ pub fn init_service(cfg: &mut web::ServiceConfig) { impl ResponseError for RvError { // builds the actual response to send back when an error occurs fn error_response(&self) -> HttpResponse { - let err_json = json!({ "error": self.to_string() }); - HttpResponse::InternalServerError().json(err_json) + let mut status = StatusCode::INTERNAL_SERVER_ERROR; + let text: String; + if let RvError::ErrResponse(resp_text) = self { + text = resp_text.clone(); + } else if let RvError::ErrResponseStatus(status_code, resp_text) = self { + status = StatusCode::from_u16(status_code.clone()).unwrap(); + text = resp_text.clone(); + } else { + text = self.to_string(); + } + HttpResponse::build(status).json(json!({ "error": text })) } } diff --git a/src/lib.rs b/src/lib.rs index d782aa9..0c0d205 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ -#[macro_use] extern crate diesel; pub mod cli; diff --git a/src/logical/auth.rs b/src/logical/auth.rs index 87c80be..f96bad4 100644 --- a/src/logical/auth.rs +++ b/src/logical/auth.rs @@ -13,6 +13,7 @@ pub struct Auth { pub client_token: String, pub display_name: String, pub policies: Vec, + pub internal_data: HashMap, pub metadata: HashMap, } @@ -23,6 +24,7 @@ impl Default for Auth { client_token: String::new(), display_name: String::new(), policies: Vec::new(), + internal_data: HashMap::new(), metadata: HashMap::new(), } } diff --git a/src/logical/backend.rs b/src/logical/backend.rs index bbf67a4..8d574eb 100644 --- a/src/logical/backend.rs +++ b/src/logical/backend.rs @@ -251,7 +251,7 @@ mod test { use super::*; use crate::{ - logical::{Field, FieldType, PathOperation}, + logical::{Field, field::FieldTrait, FieldType, PathOperation}, new_fields, new_fields_internal, new_path, new_path_internal, new_secret, new_secret_internal, storage::{barrier_aes_gcm::AESGCMBarrier, physical}, }; @@ -456,4 +456,286 @@ mod test { assert!(logical_backend.secret("kv").unwrap().renew(&logical_backend, &mut req).is_ok()); assert!(logical_backend.secret("kv").unwrap().revoke(&logical_backend, &mut req).is_ok()); } + + #[test] + fn test_logical_path_field() { + let dir = env::temp_dir().join("rusty_vault_test_logical_path_field"); + assert!(fs::create_dir(&dir).is_ok()); + defer! ( + assert!(fs::remove_dir_all(&dir).is_ok()); + ); + + let mut conf: HashMap = HashMap::new(); + conf.insert("path".to_string(), Value::String(dir.to_string_lossy().into_owned())); + + let backend = physical::new_backend("file", &conf).unwrap(); + + let barrier = AESGCMBarrier::new(Arc::clone(&backend)); + + let mut logical_backend = new_logical_backend!({ + paths: [ + { + pattern: "/1/(?P[^/.]+?)", + fields: { + "mytype": { + field_type: FieldType::Int, + description: "haha" + }, + "mypath": { + field_type: FieldType::Str, + description: "hehe" + } + }, + operations: [ + {op: Operation::Write, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result, RvError> { + let _bar = req.get_data("bar")?; + Ok(None) + } + } + ] + }, + { + pattern: "/2/(?P.+?)/(?P.+)", + fields: { + "myflag": { + field_type: FieldType::Bool, + description: "hoho" + }, + "foo": { + field_type: FieldType::Str, + description: "foo" + }, + "goo": { + field_type: FieldType::Int, + description: "goo" + }, + "array": { + field_type: FieldType::Array, + required: true, + description: "array" + }, + "array_default": { + field_type: FieldType::Array, + default: "[]", + description: "array default" + }, + "bool": { + field_type: FieldType::Bool, + description: "boolean" + }, + "bool_default": { + field_type: FieldType::Bool, + default: true, + description: "boolean default" + }, + "comma": { + field_type: FieldType::CommaStringSlice, + description: "comma string slice" + }, + "comma_default": { + field_type: FieldType::CommaStringSlice, + default: "", + description: "comma string slice" + }, + "map": { + field_type: FieldType::Map, + description: "map" + }, + "map_default": { + field_type: FieldType::Map, + default: {}, + description: "map" + }, + "duration": { + field_type: FieldType::DurationSecond, + description: "duration" + }, + "duration_default": { + field_type: FieldType::DurationSecond, + default: 50, + description: "duration" + } + }, + operations: [ + {op: Operation::Read, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result, RvError> { + let _foo = req.get_data("foo")?; + let _goo = req.get_data("goo")?; + Ok(None) + } + }, + {op: Operation::Write, raw_handler: |_backend: &dyn Backend, req: &mut Request| -> Result, RvError> { + let array_val = req.get_data("array")?; + let array_default_val = req.get_data("array_default")?; + let bool_val = req.get_data("bool")?; + let bool_default_val = req.get_data("bool_default")?; + let comma_val = req.get_data("comma")?; + let comma_default_val = req.get_data("comma_default")?; + let map_val = req.get_data("map")?; + let map_default_val = req.get_data("map_default")?; + let duration_val = req.get_data("duration")?; + let duration_default_val = req.get_data("duration_default")?; + let data = json!({ + "array": array_val, + "array_default": array_default_val, + "bool": bool_val, + "bool_default": bool_default_val, + "comma": comma_val.as_comma_string_slice().unwrap(), + "comma_default": comma_default_val.as_comma_string_slice().unwrap(), + "map": map_val, + "map_default": map_default_val, + "duration": duration_val.as_duration().unwrap().as_secs(), + "duration_default": duration_default_val.as_duration().unwrap().as_secs(), + }) + .as_object() + .unwrap() + .clone(); + Ok(Some(Response::data_response(Some(data)))) + } + } + ] + } + ], + help: "help content", + }); + + assert!(logical_backend.init().is_ok()); + + let mut req = Request::new("/1/bar"); + req.operation = Operation::Read; + req.storage = Some(Arc::new(barrier)); + assert!(logical_backend.handle_request(&mut req).is_err()); + + req.path = "/2/foo/goo".to_string(); + assert!(logical_backend.handle_request(&mut req).is_err()); + + req.path = "/2/foo/22".to_string(); + assert!(logical_backend.handle_request(&mut req).is_ok()); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": true, + "comma": "aa,bb,cc", + "map": {"aa":"bb"}, + "duration": 100, + }); + req.operation = Operation::Write; + req.body = Some(req_body.as_object().unwrap().clone()); + let resp = logical_backend.handle_request(&mut req); + println!("resp: {:?}", resp); + assert!(resp.is_ok()); + let data = resp.unwrap().unwrap().data; + assert!(data.is_some()); + let resp_body = data.unwrap(); + assert_eq!(req_body["array"], resp_body["array"]); + assert_eq!(req_body["bool"], resp_body["bool"]); + let comma = json!(req_body["comma"]); + let comma_slice = comma.as_comma_string_slice(); + assert!(comma_slice.is_some()); + let req_comma = json!(comma_slice.unwrap()); + assert_eq!(req_comma, resp_body["comma"]); + assert_eq!(req_body["map"], resp_body["map"]); + assert_eq!(req_body["duration"], resp_body["duration"]); + assert_eq!(resp_body["array_default"], json!([])); + assert_eq!(resp_body["bool_default"], json!(true)); + assert_eq!(resp_body["comma_default"], json!([])); + assert_eq!(resp_body["map_default"], json!({})); + assert_eq!(resp_body["duration_default"], json!(50)); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": "true", + "comma": "aa,bb,cc", + "map": {"aa":"bb"}, + "duration": 100, + }); + req.body = Some(req_body.as_object().unwrap().clone()); + assert!(logical_backend.handle_request(&mut req).is_err()); + + let req_body = json!({ + "array": "[1, 2, 3]", + "bool": true, + "comma": "aa,bb,cc", + "map": {"aa":"bb"}, + "duration": 100, + }); + req.body = Some(req_body.as_object().unwrap().clone()); + assert!(logical_backend.handle_request(&mut req).is_err()); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": true, + "comma": true, + "map": {"aa":"bb"}, + "duration": 100, + }); + req.body = Some(req_body.as_object().unwrap().clone()); + assert!(logical_backend.handle_request(&mut req).is_err()); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": true, + "comma": "aa,bb,cc", + "map": 11, + "duration": 100, + }); + req.body = Some(req_body.as_object().unwrap().clone()); + assert!(logical_backend.handle_request(&mut req).is_err()); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": true, + "comma": "aa,bb,cc", + "map": {"aa":"bb"}, + "duration": "1000", + }); + req.body = Some(req_body.as_object().unwrap().clone()); + let resp = logical_backend.handle_request(&mut req); + assert!(resp.is_ok()); + let data = resp.unwrap().unwrap().data; + assert!(data.is_some()); + let resp_body = data.unwrap(); + assert_eq!(resp_body["duration"], json!(1000)); + + let req_body = json!({ + "array": [1, 2, 3], + "bool": true, + "comma": [11, 22, 33], + "map": {"aa":"bb"}, + "duration": "1000", + }); + req.body = Some(req_body.as_object().unwrap().clone()); + let resp = logical_backend.handle_request(&mut req); + assert!(resp.is_ok()); + let data = resp.unwrap().unwrap().data; + assert!(data.is_some()); + let resp_body = data.unwrap(); + assert_eq!(resp_body["duration"], json!(1000)); + assert_eq!(resp_body["comma"], json!(["11", "22", "33"])); + + let req_body = json!({ + "array": [1, 2, 3], + "array_default": [1, 2, 3, 4], + "bool": true, + "bool_default": false, + "comma": [11, 22, 33], + "comma_default": [11, 22, 33, 44], + "map": {"aa":"bb"}, + "map_default": {"aa": "bb", "cc": "dd"}, + "duration": "1000", + "duration_default": "2000", + }); + req.body = Some(req_body.as_object().unwrap().clone()); + let resp = logical_backend.handle_request(&mut req); + assert!(resp.is_ok()); + let data = resp.unwrap().unwrap().data; + assert!(data.is_some()); + let resp_body = data.unwrap(); + assert_eq!(resp_body["duration"], json!(1000)); + assert_eq!(resp_body["comma"], json!(["11", "22", "33"])); + assert_eq!(req_body["array_default"], resp_body["array_default"]); + assert_eq!(req_body["bool_default"], resp_body["bool_default"]); + assert_eq!(resp_body["comma_default"], json!(["11", "22", "33", "44"])); + assert_eq!(req_body["map_default"], resp_body["map_default"]); + assert_eq!(resp_body["duration_default"], json!(2000)); + } } diff --git a/src/logical/connection.rs b/src/logical/connection.rs index fff291f..0941373 100644 --- a/src/logical/connection.rs +++ b/src/logical/connection.rs @@ -1,3 +1,15 @@ +use openssl::x509::X509; + pub struct Connection { - pub remote_addr: String, + pub peer_addr: String, + pub peer_tls_cert: Option>, +} + +impl Default for Connection { + fn default() -> Self { + Self { + peer_addr: String::new(), + peer_tls_cert: None, + } + } } diff --git a/src/logical/field.rs b/src/logical/field.rs index 6c14301..f2bd91f 100644 --- a/src/logical/field.rs +++ b/src/logical/field.rs @@ -1,8 +1,8 @@ -use std::{any::Any, fmt, sync::Arc}; +use std::{fmt, time::Duration}; use enum_map::Enum; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{json, Value}; use strum::{Display, EnumString}; use crate::errors::RvError; @@ -19,41 +19,285 @@ pub enum FieldType { Bool, #[strum(to_string = "map")] Map, + #[strum(to_string = "array")] + Array, + #[strum(to_string = "duration_second")] + DurationSecond, + #[strum(to_string = "comma_string_slice")] + CommaStringSlice, } #[derive(Clone)] pub struct Field { pub required: bool, pub field_type: FieldType, - pub default: Arc, + pub default: Value, pub description: String, } +pub trait FieldTrait { + fn is_int(&self) -> bool; + fn is_duration(&self) -> bool; + fn is_comma_string_slice(&self) -> bool; + fn as_int(&self) -> Option; + fn as_duration(&self) -> Option; + fn as_comma_string_slice(&self) -> Option>; +} + +impl FieldTrait for Value { + fn is_int(&self) -> bool { + if self.is_i64() { + return true; + } + + let int_str = self.as_str(); + if int_str.is_none() { + return false; + } + + let int = int_str.unwrap().parse::().ok(); + if int.is_none() { + return false; + } + + true + } + + fn is_duration(&self) -> bool { + if self.is_i64() { + return true; + } + + let secs_str = self.as_str(); + if secs_str.is_none() { + return false; + } + + let secs = secs_str.unwrap().parse::().ok(); + if secs.is_none() { + return false; + } + + true + } + + fn is_comma_string_slice(&self) -> bool { + let arr = self.as_array(); + if arr.is_some() { + let arr_val = arr.unwrap(); + for item in arr_val.iter() { + let item_val = item.as_str(); + if item_val.is_some() { + continue; + } + + let item_val = item.as_i64(); + if item_val.is_some() { + continue; + } + + return false; + } + + return true; + } + + let value = self.as_i64(); + if value.is_some() { + return true; + } + + let value = self.as_str(); + if value.is_some() { + return true; + } + + false + } + + fn as_int(&self) -> Option { + let mut int = self.as_i64(); + if int.is_none() { + let int_str = self.as_str(); + if int_str.is_none() { + return None; + } + + int = int_str.unwrap().parse::().ok(); + if int.is_none() { + return None; + } + } + + int + } + + fn as_duration(&self) -> Option { + let mut secs = self.as_u64(); + if secs.is_none() { + let secs_str = self.as_str(); + if secs_str.is_none() { + return None; + } + + secs = secs_str.unwrap().parse::().ok(); + if secs.is_none() { + return None; + } + } + Some(Duration::from_secs(secs.unwrap())) + } + + fn as_comma_string_slice(&self) -> Option> { + let mut ret = Vec::new(); + let arr = self.as_array(); + if arr.is_some() { + let arr_val = arr.unwrap(); + for item in arr_val.iter() { + let item_val = item.as_str(); + if item_val.is_some() { + ret.push(item_val.unwrap().trim().to_string()); + continue; + } + + let item_val = item.as_i64(); + if item_val.is_some() { + ret.push(item_val.unwrap().to_string()); + continue; + } + + return None; + } + + return Some(ret); + } + + let value = self.as_i64(); + if value.is_some() { + ret.push(value.unwrap().to_string()); + return Some(ret); + } + + let value = self.as_str(); + if value.is_some() { + return Some(value.unwrap().split(',').map(|s| s.trim().to_string()).filter(|s| !s.is_empty()).collect()); + } + + None + } +} + impl Field { pub fn new() -> Self { Self { - required: true, + required: false, field_type: FieldType::Str, - default: Arc::new(String::new()), + default: json!(null), description: String::new(), } } - pub fn get_default(&self) -> Result { + pub fn check_data_type(&self, data: &Value) -> bool { match &self.field_type { - FieldType::Str => self.cast_value::(), - FieldType::SecretStr => self.cast_value::(), - FieldType::Int => self.cast_value::(), - FieldType::Bool => self.cast_value::(), - FieldType::Map => self.cast_value::(), + FieldType::SecretStr | FieldType::Str => data.is_string(), + FieldType::Int => data.is_int(), + FieldType::Bool => data.is_boolean(), + FieldType::Array => data.is_array(), + FieldType::Map => data.is_object(), + FieldType::DurationSecond => data.is_duration(), + FieldType::CommaStringSlice => data.is_comma_string_slice(), } } - fn cast_value(&self) -> Result { - if let Some(value) = self.default.downcast_ref::() { - Ok(serde_json::to_value(value).map_err(|_| RvError::ErrRustDowncastFailed)?) - } else { - Err(RvError::ErrRustDowncastFailed) + pub fn get_default(&self) -> Result { + if self.default.is_null() { + match &self.field_type { + FieldType::SecretStr | FieldType::Str => { + return Ok(json!("")); + }, + FieldType::Int => { + return Ok(json!(0)); + }, + FieldType::Bool => { + return Ok(json!(false)); + }, + FieldType::Array => { + return Ok(json!([])); + }, + FieldType::Map => { + return Ok(serde_json::from_str("{}")?); + }, + FieldType::DurationSecond => { + return Ok(json!(0)); + }, + FieldType::CommaStringSlice => { + return Ok(json!([])); + } + } + } + + match &self.field_type { + FieldType::SecretStr | FieldType::Str => { + if self.default.is_string() { + return Ok(self.default.clone()); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::Int => { + if self.default.is_i64() { + return Ok(self.default.clone()); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::Bool => { + if self.default.is_boolean() { + return Ok(self.default.clone()); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::Array => { + if self.default.is_array() { + return Ok(self.default.clone()); + } else if self.default.is_string() { + let arr_str = self.default.as_str(); + if arr_str.is_none() { + return Err(RvError::ErrRustDowncastFailed); + } + return Ok(serde_json::from_str(arr_str.unwrap())?); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::Map => { + if self.default.is_object() { + return Ok(self.default.clone()); + } else if self.default.is_string() { + let arr_str = self.default.as_str(); + if arr_str.is_none() { + return Err(RvError::ErrRustDowncastFailed); + } + return Ok(serde_json::from_str(arr_str.unwrap())?); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::DurationSecond => { + if self.default.is_duration() { + return Ok(self.default.clone()); + } + + return Err(RvError::ErrRustDowncastFailed); + }, + FieldType::CommaStringSlice => { + if self.default.is_comma_string_slice() { + return Ok(self.default.clone()); + } + + return Err(RvError::ErrRustDowncastFailed); + } } } } @@ -71,8 +315,6 @@ impl fmt::Debug for Field { #[cfg(test)] mod test { - use std::sync::Arc; - use serde_json::{json, Number, Value}; use super::*; @@ -80,20 +322,20 @@ mod test { #[test] fn test_field_get_default() { let mut field = Field::new(); - field.default = Arc::new("foo".to_string()); + field.default = json!("foo"); assert!(field.get_default().is_ok()); assert_eq!(field.get_default().unwrap(), Value::String("foo".to_string())); field.field_type = FieldType::Int; assert!(field.get_default().is_err()); - field.default = Arc::new(443); + field.default = json!(443); assert!(field.get_default().is_ok()); assert_eq!(field.get_default().unwrap(), Value::Number(Number::from(443))); field.field_type = FieldType::Bool; assert!(field.get_default().is_err()); - field.default = Arc::new(false); + field.default = json!(false); assert!(field.get_default().is_ok()); assert_eq!(field.get_default().unwrap(), Value::Bool(false)); - field.default = Arc::new(true); + field.default = json!(true); assert!(field.get_default().is_ok()); assert_eq!(field.get_default().unwrap(), Value::Bool(true)); field.field_type = FieldType::Map; @@ -107,8 +349,33 @@ mod test { }, "arr": [1, 2, 3], }); - field.default = Arc::new(value.clone()); + field.default = value.clone(); assert!(field.get_default().is_ok()); assert_eq!(field.get_default().unwrap(), value); + field.field_type = FieldType::Array; + field.default = json!([1, 2, 3]); + assert!(field.get_default().is_ok()); + let val = json!([1, 2, 3]); + assert_eq!(field.get_default().unwrap(), val); + field.field_type = FieldType::DurationSecond; + field.default = json!("10"); + println!("{:?}", field.get_default()); + assert!(field.get_default().is_ok()); + assert_eq!(field.get_default().unwrap().as_duration().unwrap(), Duration::from_secs(10)); + field.field_type = FieldType::CommaStringSlice; + field.default = json!([1, 2, 3]); + assert!(field.get_default().is_ok()); + let val_int = json!([1, 2, 3]); + let val_str = vec!["1", "2", "3"]; + let val = field.get_default().unwrap(); + assert_eq!(val.as_comma_string_slice(), Some(val_str.iter().map(|&s| s.to_string()).collect::>())); + assert_eq!(val, val_int); + field.default = json!("a,b,c"); + let val_str = vec!["a", "b", "c"]; + let val = field.get_default().unwrap(); + assert_eq!(val.as_comma_string_slice(), Some(val_str.iter().map(|&s| s.to_string()).collect::>())); + field.default = json!("a ,, b , c,"); + let val = field.get_default().unwrap(); + assert_eq!(val.as_comma_string_slice(), Some(val_str.iter().map(|&s| s.to_string()).collect::>())); } } diff --git a/src/logical/mod.rs b/src/logical/mod.rs index 58dfb5f..aa931cb 100644 --- a/src/logical/mod.rs +++ b/src/logical/mod.rs @@ -24,6 +24,7 @@ pub use path::{Path, PathOperation}; pub use request::Request; pub use response::Response; pub use secret::{Secret, SecretData}; +pub use connection::Connection; #[derive(Eq, PartialEq, Copy, Clone, Debug, EnumString, Display, Enum, Serialize, Deserialize)] pub enum Operation { diff --git a/src/logical/path.rs b/src/logical/path.rs index 5528b40..7846702 100644 --- a/src/logical/path.rs +++ b/src/logical/path.rs @@ -67,11 +67,8 @@ macro_rules! new_fields_internal { $object.required = $required; }; (@object $object:ident default: $default:expr) => { - if $object.field_type == FieldType::Str { - $object.default = Arc::new($default.to_string()); - } else { - $object.default = Arc::new($default); - } + let val = serde_json::json!($default); + $object.default = val; $object.required = false; }; (@object $object:ident description: $description:expr) => { diff --git a/src/logical/request.rs b/src/logical/request.rs index cf8232d..e130932 100644 --- a/src/logical/request.rs +++ b/src/logical/request.rs @@ -75,21 +75,26 @@ impl Request { if field.is_none() { return Err(RvError::ErrRequestNoDataField); } + let field = field.unwrap(); if self.data.is_some() { if let Some(data) = self.data.as_ref().unwrap().get(key) { + if !field.check_data_type(&data) { + return Err(RvError::ErrRequestFieldInvalid); + } return Ok(data.clone()); } } if self.body.is_some() { if let Some(data) = self.body.as_ref().unwrap().get(key) { + if !field.check_data_type(&data) { + return Err(RvError::ErrRequestFieldInvalid); + } return Ok(data.clone()); } } - let field = field.unwrap(); - if field.required { return Err(RvError::ErrRequestFieldNotFound); } diff --git a/src/modules/auth/token_store.rs b/src/modules/auth/token_store.rs index 4a23108..885e2bd 100644 --- a/src/modules/auth/token_store.rs +++ b/src/modules/auth/token_store.rs @@ -514,6 +514,7 @@ impl TokenStoreInner { display_name: te.display_name.clone(), policies: te.policies.clone(), metadata: te.meta.clone(), + ..Default::default() }; let resp = Response { auth: Some(auth), ..Response::default() }; diff --git a/src/modules/credential/mod.rs b/src/modules/credential/mod.rs index 7134d2e..6a266e8 100644 --- a/src/modules/credential/mod.rs +++ b/src/modules/credential/mod.rs @@ -1 +1,2 @@ pub mod userpass; +//pub mod cert; diff --git a/src/modules/pki/field.rs b/src/modules/pki/field.rs index de6e947..d05278c 100644 --- a/src/modules/pki/field.rs +++ b/src/modules/pki/field.rs @@ -9,7 +9,7 @@ pub fn ca_common_fields() -> HashMap> { let fields = new_fields!({ "alt_names": { field_type: FieldType::Str, - required: false, + default: "", description: r#"The requested Subject Alternative Names, if any, in a comma-delimited list. May contain both DNS names and email addresses."# }, @@ -45,43 +45,43 @@ Set the not after field of the certificate with specified date value. The value format should be given in UTC format YYYY-MM-ddTHH:MM:SSZ."# }, "ou": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, OU (OrganizationalUnit) will be set to this value."# }, "organization": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, O (Organization) will be set to this value."# }, "country": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, Country will be set to this value."# }, "locality": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, Locality will be set to this value in certificates issued by this role."# }, "province": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, Province will be set to this value."# }, "street_address": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, Street Address will be set to this value."# }, "postal_code": { - required: false, field_type: FieldType::Str, + default: "", description: r#"If set, Postal Code will be set to this value."# }, "serial_number": { - required: false, field_type: FieldType::Str, + default: "", description: r#"The Subject's requested serial number, if any. See RFC 4519 Section 2.31 'serialNumber' for a description of this field. If you want more than one, specify alternative names in the alt_names diff --git a/src/modules/pki/mod.rs b/src/modules/pki/mod.rs index 73f20a0..7d6f872 100644 --- a/src/modules/pki/mod.rs +++ b/src/modules/pki/mod.rs @@ -232,7 +232,7 @@ x/+V28hUf8m8P2NxP5ALaDZagdaMfzjGZo3O3wDv33Cds0P5GMGQYnRXDxcZN/2L req.body = data; let resp = core.handle_request(&mut req); - println!("path: {}, resp: {:?}", path, resp); + println!("path: {}, req.body: {:?}, resp: {:?}", path, req.body, resp); assert_eq!(resp.is_ok(), is_ok); resp } diff --git a/src/modules/pki/path_keys.rs b/src/modules/pki/path_keys.rs index 6123267..1a77faa 100644 --- a/src/modules/pki/path_keys.rs +++ b/src/modules/pki/path_keys.rs @@ -163,8 +163,8 @@ used for sign,verify,encrypt,decrypt. description: "Data that needs to be encrypted" }, "aad": { - required: false, field_type: FieldType::Str, + default: "", description: "Additional Authenticated Data can be provided for aes-gcm/cbc encryption" } }, @@ -194,8 +194,8 @@ used for sign,verify,encrypt,decrypt. description: "Data that needs to be decrypted" }, "aad": { - required: false, field_type: FieldType::Str, + default: "", description: "Additional Authenticated Data can be provided for aes-gcm/cbc decryption" } }, @@ -269,9 +269,12 @@ impl PkiBackendInner { let key_name = key_name_value.as_str().unwrap(); let key_type_value = req.get_data("key_type")?; let key_type = key_type_value.as_str().unwrap(); - let pem_bundle_value = req.get_data("pem_bundle"); - let hex_bundle_value = req.get_data("hex_bundle"); - if pem_bundle_value.is_err() && hex_bundle_value.is_err() { + let pem_bundle_value = req.get_data("pem_bundle")?; + let pem_bundle = pem_bundle_value.as_str().unwrap(); + let hex_bundle_value = req.get_data("hex_bundle")?; + let hex_bundle = hex_bundle_value.as_str().unwrap(); + + if pem_bundle.len() == 0 && hex_bundle.len() == 0 { return Err(RvError::ErrRequestFieldNotFound); } @@ -282,56 +285,46 @@ impl PkiBackendInner { let mut key_bundle = KeyBundle::new(key_name, key_type.to_lowercase().as_str(), 0); - match pem_bundle_value { - Ok(pem_bundle_val) => { - if let Some(pem_bundle) = pem_bundle_val.as_str() { - key_bundle.key = pem_bundle.as_bytes().to_vec(); - match key_type { - "rsa" => { - let rsa = Rsa::private_key_from_pem(&key_bundle.key)?; - key_bundle.bits = rsa.size() * 8; - } - "ec" => { - let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?; - key_bundle.bits = ec_key.group().degree(); - } - _ => { - return Err(RvError::ErrPkiKeyTypeInvalid); - } - } + if pem_bundle.len() != 0 { + key_bundle.key = pem_bundle.as_bytes().to_vec(); + match key_type { + "rsa" => { + let rsa = Rsa::private_key_from_pem(&key_bundle.key)?; + key_bundle.bits = rsa.size() * 8; + }, + "ec" => { + let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?; + key_bundle.bits = ec_key.group().degree(); + }, + _ => { + return Err(RvError::ErrPkiKeyTypeInvalid); } - } - _ => {} + }; } - match hex_bundle_value { - Ok(hex_bundle_val) => { - if let Some(hex_bundle) = hex_bundle_val.as_str() { - key_bundle.key = hex::decode(&hex_bundle)?; - key_bundle.bits = (key_bundle.key.len() as u32) * 8; - match key_bundle.bits { - 128 | 192 | 256 => {} - _ => { - return Err(RvError::ErrPkiKeyBitsInvalid); - } - } - let iv_value = req.get_data("iv")?; - match key_type { - "aes-gcm" | "aes-cbc" => { - if let Some(iv) = iv_value.as_str() { - key_bundle.iv = hex::decode(&iv)?; - } else { - return Err(RvError::ErrRequestFieldNotFound); - } - } - "aes-ecb" => {} - _ => { - return Err(RvError::ErrPkiKeyTypeInvalid); - } + if hex_bundle.len() != 0 { + key_bundle.key = hex::decode(&hex_bundle)?; + key_bundle.bits = (key_bundle.key.len() as u32) * 8; + match key_bundle.bits { + 128 | 192 | 256 => {}, + _ => { + return Err(RvError::ErrPkiKeyBitsInvalid); + } + }; + let iv_value = req.get_data("iv")?; + match key_type { + "aes-gcm" | "aes-cbc" => { + if let Some(iv) = iv_value.as_str() { + key_bundle.iv = hex::decode(&iv)?; + } else { + return Err(RvError::ErrRequestFieldNotFound); } + }, + "aes-ecb" => {}, + _ => { + return Err(RvError::ErrPkiKeyTypeInvalid); } - } - _ => {} + }; } self.write_key(req, &key_bundle)?; diff --git a/src/modules/system/mod.rs b/src/modules/system/mod.rs index d6c2622..5caf139 100644 --- a/src/modules/system/mod.rs +++ b/src/modules/system/mod.rs @@ -91,8 +91,8 @@ impl SystemBackend { description: r#"The type of the backend. Example: "kv""# }, "description": { - required: false, field_type: FieldType::Str, + default: "", description: r#"User-friendly description for this mount."# } }, @@ -173,8 +173,8 @@ impl SystemBackend { description: r#"The type of the backend. Example: "userpass""# }, "description": { - required: false, field_type: FieldType::Str, + default: "", description: r#"User-friendly description for this crential backend."# } }, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9527247..f23150c 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -94,3 +94,14 @@ pub fn asn1time_to_timestamp(time_str: &str) -> Result { Ok(timestamp) } + +pub fn hex_encode_with_colon(bytes: &[u8]) -> String { + let hex_str = hex::encode(bytes); + let split_hex: Vec = hex_str + .as_bytes() + .chunks(2) + .map(|chunk| String::from_utf8(chunk.to_vec()).unwrap()) + .collect(); + + split_hex.join(":") +}