Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to rust-tongsuo, supporting SM2 and SM4 algorithms. #60

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ RustyVault's RESTful API is designed to be fully compatible with Hashicorp Vault
"""
repository = "https://github.com/Tongsuo-Project/RustyVault"
documentation = "https://docs.rs/rusty_vault/latest/rusty_vault/"
build = "build.rs"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -24,8 +25,8 @@ serde_json = "^1.0"
serde_bytes = "0.11"
go-defer = "^0.1"
rand = "^0.8"
openssl = "0.10"
openssl-sys = "0.9.92"
openssl = { version = "0.10" }
openssl-sys = { version = "0.9" }
derivative = "2.2.0"
enum-map = "2.6.1"
strum = { version = "0.25", features = ["derive"] }
Expand Down Expand Up @@ -60,6 +61,10 @@ serde_asn1_der = "0.8"
base64 = "0.22"
ipnetwork = "0.20"

[patch.crates-io]
openssl = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" }
openssl-sys = { git = "https://github.com/Tongsuo-Project/rust-tongsuo.git" }

[features]
storage_mysql = ["diesel", "r2d2", "r2d2-diesel"]

Expand Down
7 changes: 7 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use std::env;

fn main() {
if let Ok(_) = env::var("DEP_OPENSSL_TONGSUO") {
println!("cargo:rustc-cfg=tongsuo");
}
}
2 changes: 1 addition & 1 deletion src/modules/pki/path_config_ca.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ For security reasons, you can only view the certificate when reading this endpoi
impl PkiBackendInner {
pub fn write_path_ca(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().unwrap();
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let items = pem::parse_many(pem_bundle)?;
let mut key_found = false;
Expand Down
2 changes: 1 addition & 1 deletion src/modules/pki/path_fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl PkiBackendInner {

pub fn read_path_fetch_cert(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let serial_number_value = req.get_data("serial")?;
let serial_number = serial_number_value.as_str().unwrap();
let serial_number = serial_number_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let serial_number_hex = serial_number.replace(":", "-").to_lowercase();
let cert = self.fetch_cert(req, &serial_number_hex)?;
let ca_bundle = self.fetch_ca_bundle(req)?;
Expand Down
9 changes: 4 additions & 5 deletions src/modules/pki/path_issue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,12 @@ requested common name is allowed by the role policy.

impl PkiBackendInner {
pub fn issue_cert(&self, backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let role_value = req.get_data("role")?;
let role_name = role_value.as_str().unwrap();
//let role_name = req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut common_names = Vec::new();

let common_name_value = req.get_data("common_name")?;
let common_name = common_name_value.as_str().unwrap();
let common_name = common_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if common_name != "" {
common_names.push(common_name.to_string());
}
Expand All @@ -87,7 +86,7 @@ impl PkiBackendInner {
}
}

let role = self.get_role(req, &role_name)?;
let role = self.get_role(req, req.get_data("role")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
if role.is_none() {
return Err(RvError::ErrPkiRoleNotFound);
}
Expand All @@ -111,7 +110,7 @@ impl PkiBackendInner {
let mut not_after = not_before + parse_duration("30d").unwrap();

let ttl_value = req.get_data("ttl")?;
let ttl = ttl_value.as_str().unwrap();
let ttl = ttl_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if ttl != "" {
let ttl_dur = parse_duration(ttl)?;
let req_ttl_not_after_dur = SystemTime::now() + ttl_dur;
Expand Down
79 changes: 38 additions & 41 deletions src/modules/pki/path_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,10 @@ used for sign,verify,encrypt,decrypt.
impl PkiBackendInner {
pub fn generate_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().unwrap();
let key_bits_value = req.get_data("key_bits")?;
let key_bits = key_bits_value.as_u64().unwrap();
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;

let mut export_private_key = false;
if req.path.ends_with("/exported") {
Expand Down Expand Up @@ -245,7 +244,7 @@ impl PkiBackendInner {

if export_private_key {
match key_type {
"rsa" | "ec" => {
"rsa" | "ec" | "sm2" => {
resp_data.insert(
"private_key".to_string(),
Value::String(String::from_utf8_lossy(&key_bundle.key).to_string()),
Expand All @@ -266,13 +265,13 @@ impl PkiBackendInner {

pub fn import_key(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let key_name = key_name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().unwrap();
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let pem_bundle_value = req.get_data("pem_bundle")?;
let pem_bundle = pem_bundle_value.as_str().unwrap();
let pem_bundle = pem_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let hex_bundle_value = req.get_data("hex_bundle")?;
let hex_bundle = hex_bundle_value.as_str().unwrap();
let hex_bundle = hex_bundle_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

if pem_bundle.len() == 0 && hex_bundle.len() == 0 {
return Err(RvError::ErrRequestFieldNotFound);
Expand All @@ -292,7 +291,7 @@ impl PkiBackendInner {
let rsa = Rsa::private_key_from_pem(&key_bundle.key)?;
key_bundle.bits = rsa.size() * 8;
},
"ec" => {
"ec" | "sm2" => {
let ec_key = EcKey::private_key_from_pem(&key_bundle.key)?;
key_bundle.bits = ec_key.group().degree();
},
Expand All @@ -312,19 +311,25 @@ impl PkiBackendInner {
}
};
let iv_value = req.get_data("iv")?;
match key_type {
"aes-gcm" | "aes-cbc" => {
if let Some(iv) = iv_value.as_str() {
key_bundle.iv = hex::decode(&iv)?;
} else {
return Err(RvError::ErrRequestFieldNotFound);
}
},
"aes-ecb" => {},
_ => {
return Err(RvError::ErrPkiKeyTypeInvalid);
let is_iv_required = matches!(key_type, "aes-gcm" | "aes-cbc" | "sm4-gcm" | "sm4-ccm");
#[cfg(tongsuo)]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb" | "sm4-gcm" | "sm4-ccm");
#[cfg(not(tongsuo))]
let is_valid_key_type = matches!(key_type, "aes-gcm" | "aes-cbc" | "aes-ecb");

// Check if the key type is valid, if not return an error.
if !is_valid_key_type {
return Err(RvError::ErrPkiKeyTypeInvalid);
}

// Proceed to check IV only if required by the key type.
if is_iv_required {
if let Some(iv) = iv_value.as_str() {
key_bundle.iv = hex::decode(&iv)?;
} else {
return Err(RvError::ErrRequestFieldNotFound);
}
};
}
}

self.write_key(req, &key_bundle)?;
Expand All @@ -343,12 +348,10 @@ impl PkiBackendInner {
}

pub fn key_sign(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.sign(&decoded_data)?;
Expand All @@ -364,14 +367,12 @@ impl PkiBackendInner {
}

pub fn key_verify(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let signature_value = req.get_data("signature")?;
let signature = signature_value.as_str().unwrap();
let signature = signature_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let decoded_signature = hex::decode(signature.as_bytes())?;
Expand All @@ -388,14 +389,12 @@ impl PkiBackendInner {
}

pub fn key_encrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad = aad_value.as_str().unwrap();
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.encrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?;
Expand All @@ -411,14 +410,12 @@ impl PkiBackendInner {
}

pub fn key_decrypt(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let key_name_value = req.get_data("key_name")?;
let key_name = key_name_value.as_str().unwrap();
let data_value = req.get_data("data")?;
let data = data_value.as_str().unwrap();
let data = data_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let aad_value = req.get_data("aad")?;
let aad = aad_value.as_str().unwrap();
let aad = aad_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;

let key_bundle = self.fetch_key(req, key_name)?;
let key_bundle = self.fetch_key(req, req.get_data("key_name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;

let decoded_data = hex::decode(data.as_bytes())?;
let result = key_bundle.decrypt(&decoded_data, Some(EncryptExtraData::Aad(aad.as_bytes())))?;
Expand Down
95 changes: 32 additions & 63 deletions src/modules/pki/path_roles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,30 +318,19 @@ impl PkiBackendInner {
}

pub fn read_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let role_entry = self.get_role(req, name)?;
let role_entry = self.get_role(req, req.get_data("name")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let data = serde_json::to_value(&role_entry)?;
Ok(Some(Response::data_response(Some(data.as_object().unwrap().clone()))))
}

pub fn create_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let ttl_vale = req.get_data("ttl")?;
let ttl = {
let ttl_str = ttl_vale.as_str().unwrap();
parse_duration(ttl_str)?
};
let max_ttl_vale = req.get_data("max_ttl")?;
let max_ttl = {
let max_ttl_str = max_ttl_vale.as_str().unwrap();
parse_duration(max_ttl_str)?
};
let key_type_vale = req.get_data("key_type")?;
let key_type = key_type_vale.as_str().unwrap();
let key_bits_vale = req.get_data("key_bits")?;
let mut key_bits = key_bits_vale.as_u64().unwrap();
let name_value = req.get_data("name")?;
let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let ttl = parse_duration(req.get_data("ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let max_ttl = parse_duration(req.get_data("max_ttl")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?)?;
let key_type_value = req.get_data("key_type")?;
let key_type = key_type_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
let mut key_bits = req.get_data("key_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
match key_type {
"rsa" => {
if key_bits == 0 {
Expand All @@ -366,48 +355,28 @@ impl PkiBackendInner {
}
}

let signature_bits_vale = req.get_data("signature_bits")?;
let signature_bits = signature_bits_vale.as_u64().unwrap();
let allow_localhost_vale = req.get_data("allow_localhost")?;
let allow_localhost = allow_localhost_vale.as_bool().unwrap();
let allow_bare_domain_vale = req.get_data("allow_bare_domains")?;
let allow_bare_domains = allow_bare_domain_vale.as_bool().unwrap();
let allow_subdomains_vale = req.get_data("allow_subdomains")?;
let allow_subdomains = allow_subdomains_vale.as_bool().unwrap();
let allow_any_name_vale = req.get_data("allow_any_name")?;
let allow_any_name = allow_any_name_vale.as_bool().unwrap();
let allow_ip_sans_vale = req.get_data("allow_ip_sans")?;
let allow_ip_sans = allow_ip_sans_vale.as_bool().unwrap();
let server_flag_vale = req.get_data("server_flag")?;
let server_flag = server_flag_vale.as_bool().unwrap();
let client_flag_vale = req.get_data("client_flag")?;
let client_flag = client_flag_vale.as_bool().unwrap();
let use_csr_sans_vale = req.get_data("use_csr_sans")?;
let use_csr_sans = use_csr_sans_vale.as_bool().unwrap();
let use_csr_common_name_vale = req.get_data("use_csr_common_name")?;
let use_csr_common_name = use_csr_common_name_vale.as_bool().unwrap();
let country_vale = req.get_data("country")?;
let country = country_vale.as_str().unwrap().to_string();
let province_vale = req.get_data("province")?;
let province = province_vale.as_str().unwrap().to_string();
let locality_vale = req.get_data("locality")?;
let locality = locality_vale.as_str().unwrap().to_string();
let organization_vale = req.get_data("organization")?;
let organization = organization_vale.as_str().unwrap().to_string();
let ou_vale = req.get_data("ou")?;
let ou = ou_vale.as_str().unwrap().to_string();
let street_address_vale = req.get_data("street_address")?;
let street_address = street_address_vale.as_str().unwrap().to_string();
let postal_code_vale = req.get_data("postal_code")?;
let postal_code = postal_code_vale.as_str().unwrap().to_string();
let no_store_vale = req.get_data("no_store")?;
let no_store = no_store_vale.as_bool().unwrap();
let generate_lease_vale = req.get_data("generate_lease")?;
let generate_lease = generate_lease_vale.as_bool().unwrap();
let not_after_vale = req.get_data("not_after")?;
let not_after = not_after_vale.as_str().unwrap().to_string();
let not_before_duration_vale = req.get_data("not_before_duration")?;
let not_before_duration = Duration::from_secs(not_before_duration_vale.as_u64().unwrap());
let signature_bits = req.get_data("signature_bits")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_localhost = req.get_data("allow_localhost")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_bare_domains = req.get_data("allow_bare_domains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_subdomains = req.get_data("allow_subdomains")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_any_name = req.get_data("allow_any_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let allow_ip_sans = req.get_data("allow_ip_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let server_flag = req.get_data("server_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let client_flag = req.get_data("client_flag")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let use_csr_sans = req.get_data("use_csr_sans")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let use_csr_common_name = req.get_data("use_csr_common_name")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let country = req.get_data("country")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let province = req.get_data("province")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let locality = req.get_data("locality")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let organization = req.get_data("organization")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let ou = req.get_data("ou")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let street_address = req.get_data("street_address")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let postal_code = req.get_data("postal_code")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let no_store = req.get_data("no_store")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let generate_lease = req.get_data("generate_lease")?.as_bool().ok_or(RvError::ErrRequestFieldInvalid)?;
let not_after = req.get_data("not_after")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)?.to_string();
let not_before_duration_u64 = req.get_data("not_before_duration")?.as_u64().ok_or(RvError::ErrRequestFieldInvalid)?;
let not_before_duration = Duration::from_secs(not_before_duration_u64);

let role_entry = RoleEntry {
ttl,
Expand Down Expand Up @@ -446,8 +415,8 @@ impl PkiBackendInner {
}

pub fn delete_path_role(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let name_vale = req.get_data("name")?;
let name = name_vale.as_str().unwrap();
let name_value = req.get_data("name")?;
let name = name_value.as_str().ok_or(RvError::ErrRequestFieldInvalid)?;
if name == "" {
return Err(RvError::ErrRequestNoDataField);
}
Expand Down
4 changes: 1 addition & 3 deletions src/modules/pki/path_root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ impl PkiBackend {
impl PkiBackendInner {
pub fn generate_root(&self, _backend: &dyn Backend, req: &mut Request) -> Result<Option<Response>, RvError> {
let mut export_private_key = false;
let exported_vale = req.get_data("exported")?;
let exported = exported_vale.as_str().unwrap();
if exported == "exported" {
if req.get_data("exported")?.as_str().ok_or(RvError::ErrRequestFieldInvalid)? == "exported" {
export_private_key = true;
}

Expand Down
Loading
Loading