Skip to content

Commit

Permalink
Test case optimization and bug fixing.
Browse files Browse the repository at this point in the history
1. Add test framework TestHttpServer for testing HTTP interfaces to better test
   restful API interfaces.
2. Fix the bug of rustls parsing failure caused by incorrect version field
   in issued certificate.
3. Fix the bug of unable to obtain client certificate through HTTPS interface.
  • Loading branch information
wa5i committed Sep 18, 2024
1 parent 4bce448 commit 5737fdd
Show file tree
Hide file tree
Showing 33 changed files with 938 additions and 133 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,22 @@ delay_timer = "0.11.6"
as-any = "0.3.1"
pem = "3.0"
chrono = "0.4"
zeroize = { version = "1.7.0", features= ["zeroize_derive"] }
zeroize = { version = "1.7.0", features = ["zeroize_derive"] }
diesel = { version = "2.1.4", features = ["mysql", "r2d2"], optional = true }
r2d2 = { version = "0.8.9", optional = true }
r2d2-diesel = { version = "1.0.0", optional = true }
bcrypt = "0.15"
url = "2.5"
ureq = "2.9"
ureq = { version = "2.10", features = ["json"] }
rustls = "0.23"
rustls-pemfile = "2.1"
glob = "0.3"
serde_asn1_der = "0.8"
base64 = "0.22"
ipnetwork = "0.20"
blake2b_simd = "1.0"
derive_more = "0.99.17"
dashmap = "5.5"
tokio = "1.38"
tokio = { version = "1.40", features = ["rt-multi-thread", "macros"] }
ctor = "0.2.8"
better_default = "1.0.5"

Expand Down
4 changes: 4 additions & 0 deletions src/cli/command/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ pub fn main(config_path: &str) -> Result<(), RvError> {
builder.set_ciphersuites("TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256")?;
}

if !listener.tls_disable_client_certs {
builder.set_verify_callback(SslVerifyMode::PEER, |_, _| true);
}

if listener.tls_require_and_verify_client_cert {
builder.set_verify_callback(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT, move |p, _x| {
return p;
Expand Down
1 change: 0 additions & 1 deletion src/cli/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ mod test {
use std::{env, fs, io::prelude::*};

use super::*;

use crate::test_utils::TEST_DIR;

fn write_file(path: &str, config: &str) -> Result<(), RvError> {
Expand Down
37 changes: 37 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
};

use thiserror::Error;
use actix_web::http::StatusCode;

#[derive(Error, Debug)]
pub enum RvError {
Expand Down Expand Up @@ -266,6 +267,21 @@ pub enum RvError {
source: url::ParseError,
},

#[error("Some rustls error happened, {:?}", .source)]
RustlsError {
#[from]
source: rustls::Error,
},

#[error("Some rustls_pemfile error happened")]
RustlsPemFileError(rustls_pemfile::Error),

#[error("Some string utf8 error happened, {:?}", .source)]
StringUtf8Error {
#[from]
source: std::string::FromUtf8Error,
},

/// Database Errors Begin
///
#[error("Database type is not support now. Please try postgressql or mysql again.")]
Expand Down Expand Up @@ -297,6 +313,21 @@ pub enum RvError {
ErrUnknown,
}

impl RvError {
pub fn response_status(&self) -> StatusCode {
match self {
RvError::ErrRequestNoData
| RvError::ErrRequestNoDataField
| RvError::ErrRequestInvalid
| RvError::ErrRequestClientTokenMissing
| RvError::ErrRequestFieldNotFound
| RvError::ErrRequestFieldInvalid => StatusCode::BAD_REQUEST,
RvError::ErrPermissionDenied => StatusCode::FORBIDDEN,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}

impl PartialEq for RvError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
Expand Down Expand Up @@ -398,6 +429,12 @@ impl<T> From<PoisonError<RwLockReadGuard<'_, T>>> for RvError {
}
}

impl From<rustls_pemfile::Error> for RvError {
fn from(err: rustls_pemfile::Error) -> Self {
RvError::RustlsPemFileError(err)
}
}

#[macro_export]
macro_rules! rv_error_response {
($message:expr) => {
Expand Down
19 changes: 9 additions & 10 deletions src/http/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,16 @@ async fn logical_request_handler(
}

let core = core.read()?;
let resp = core.handle_request(&mut r)?;

if r.operation == Operation::Read && resp.is_none() {
return Ok(response_error(StatusCode::NOT_FOUND, ""));
}

if resp.is_none() {
return Ok(response_ok(None, None));
let res = core.handle_request(&mut r)?;
match res {
Some(resp) => response_logical(&resp, &r.path),
None => {
if matches!(r.operation, Operation::Read | Operation::List) {
return Ok(response_error(StatusCode::NOT_FOUND, ""));
}
Ok(response_ok(None, None))
}
}

response_logical(&resp.unwrap(), &r.path)
}

fn response_logical(resp: &Response, path: &str) -> Result<HttpResponse, RvError> {
Expand Down
13 changes: 11 additions & 2 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,19 @@ pub fn request_on_connect_handler(conn: &dyn Any, ext: &mut Extensions) {
return;
}

if let Some(cert_stack) = tls_stream.ssl().verified_chain() {
if let Some(cert_stack) = tls_stream.ssl().peer_cert_chain() {
let certs: Vec<X509> = cert_stack.iter().map(X509Ref::to_owned).collect();
cert_chain = Some(certs);
}

if let Some(cert) = tls_stream.ssl().peer_certificate() {
if let Some(ref mut chain) = cert_chain {
chain.push(cert);
} else {
cert_chain = Some(vec![cert]);
}
}

ext.insert(Connection {
bind: socket.local_addr().unwrap(),
peer: peer_addr.unwrap(),
Expand Down Expand Up @@ -106,9 +114,10 @@ pub fn init_service(cfg: &mut web::ServiceConfig) {
impl ResponseError for RvError {
// builds the actual response to send back when an error occurs
fn error_response(&self) -> HttpResponse {
let mut status = StatusCode::INTERNAL_SERVER_ERROR;
let mut status = self.response_status();
let text: String;
if let RvError::ErrResponse(resp_text) = self {
status = StatusCode::from_u16(400).unwrap();
text = resp_text.clone();
} else if let RvError::ErrResponseStatus(status_code, resp_text) = self {
status = StatusCode::from_u16(status_code.clone()).unwrap();
Expand Down
26 changes: 23 additions & 3 deletions src/logical/auth.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::{
collections::HashMap,
};
use std::collections::HashMap;

use derive_more::{Deref, DerefMut};
use serde::{Deserialize, Serialize};
Expand All @@ -12,9 +10,31 @@ pub struct Auth {
#[deref]
#[deref_mut]
pub lease: Lease,

// ClientToken is the token that is generated for the authentication.
// This will be filled in by Vault core when an auth structure is returned.
// Setting this manually will have no effect.
pub client_token: String,

// DisplayName is a non-security sensitive identifier that is applicable to this Auth.
// It is used for logging and prefixing of dynamic secrets. For example,
// DisplayName may be "armon" for the github credential backend. If the client token
// is used to generate a SQL credential, the user may be "github-armon-uuid".
// This is to help identify the source without using audit tables.
pub display_name: String,

// Policies is the list of policies that the authenticated user is associated with.
pub policies: Vec<String>,

// Indicates that the default policy should not be added by core when creating a token.
// The default policy will still be added if it's explicitly defined.
pub no_default_policy: bool,

// InternalData is JSON-encodable data that is stored with the auth struct.
// This will be sent back during a Renew/Revoke for storing internal data used for those operations.
pub internal_data: HashMap<String, String>,

// Metadata is used to attach arbitrary string-type metadata to an authenticated user.
// This metadata will be outputted into the audit log.
pub metadata: HashMap<String, String>,
}
2 changes: 1 addition & 1 deletion src/logical/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ mod test {

use super::*;
use crate::{
test_utils::test_backend,
logical::{field::FieldTrait, Field, FieldType, PathOperation},
new_fields, new_fields_internal, new_path, new_path_internal, new_secret, new_secret_internal, storage,
test_utils::test_backend,
};

struct MyTest;
Expand Down
2 changes: 1 addition & 1 deletion src/logical/lease.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::time::{Duration, SystemTime};

use serde::{Deserialize, Serialize};
use better_default::Default;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Lease {
Expand Down
2 changes: 1 addition & 1 deletion src/logical/request.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{collections::HashMap, sync::Arc};

use better_default::Default;
use serde_json::{Map, Value};
use tokio::task::JoinHandle;
use better_default::Default;

use super::{Operation, Path};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/logical/response.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::HashMap;

use better_default::Default;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use better_default::Default;

use crate::{
errors::RvError,
Expand Down
7 changes: 2 additions & 5 deletions src/logical/secret.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::{
sync::Arc,
time::Duration,
};
use std::{sync::Arc, time::Duration};

use derive_more::{Deref, DerefMut};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use derive_more::{Deref, DerefMut};

use super::{lease::Lease, Backend, Request, Response};
use crate::errors::RvError;
Expand Down
2 changes: 1 addition & 1 deletion src/modules/auth/expiration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::{
time::{Duration, SystemTime},
};

use better_default::Default;
use delay_timer::prelude::*;
use derive_more::Deref;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use better_default::Default;

use super::TokenStore;
use crate::{
Expand Down
11 changes: 9 additions & 2 deletions src/modules/auth/token_store.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{collections::HashMap, sync::Arc, time::Duration};

use derive_more::Deref;
use humantime::parse_duration;
use lazy_static::lazy_static;
use regex::Regex;
use derive_more::Deref;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

Expand All @@ -19,10 +19,11 @@ use crate::{
logical::{
Auth, Backend, Field, FieldType, Lease, LogicalBackend, Operation, Path, PathOperation, Request, Response,
},
rv_error_response,
new_fields, new_fields_internal, new_logical_backend, new_logical_backend_internal, new_path, new_path_internal,
router::Router,
storage::{Storage, StorageEntry},
utils::{generate_uuid, is_str_subset, sha1},
utils::{generate_uuid, is_str_subset, sha1, policy::sanitize_policies},
};

const TOKEN_LOOKUP_PREFIX: &str = "id/";
Expand Down Expand Up @@ -662,6 +663,12 @@ impl Handler for TokenStore {
auth.ttl = MAX_LEASE_DURATION_SECS;
}

sanitize_policies(&mut auth.policies, !auth.no_default_policy);

if auth.policies.contains(&"root".to_string()) {
return Err(rv_error_response!("auth methods cannot create root tokens"));
}

let mut te = TokenEntry {
path: req.path.clone(),
meta: auth.metadata.clone(),
Expand Down
27 changes: 6 additions & 21 deletions src/modules/credential/approle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,10 @@ mod test {
use crate::{
core::Core,
logical::{field::FieldTrait, Operation, Request},
test_utils::{test_rusty_vault_init, test_read_api, test_write_api, test_delete_api, test_mount_auth_api},
test_utils::{test_delete_api, test_mount_auth_api, test_read_api, test_rusty_vault_init, test_write_api},
};

pub fn test_read_role(
core: &Core,
token: &str,
path: &str,
role_name: &str,
) -> Result<Option<Response>, RvError> {
pub fn test_read_role(core: &Core, token: &str, path: &str, role_name: &str) -> Result<Option<Response>, RvError> {
let resp = test_read_api(core, token, format!("auth/{}/role/{}", path, role_name).as_str(), true);
assert!(resp.is_ok());
resp
Expand Down Expand Up @@ -470,13 +465,8 @@ mod test {
.as_object()
.unwrap()
.clone();
let resp = test_write_api(
core,
token,
format!("auth/{}/role/{}", path, role_name).as_str(),
true,
Some(data.clone()),
);
let resp =
test_write_api(core, token, format!("auth/{}/role/{}", path, role_name).as_str(), true, Some(data.clone()));
assert!(resp.is_ok());

// Get the role field
Expand All @@ -503,13 +493,8 @@ mod test {
// Update the role
data["token_num_uses"] = Value::from(0);
data["token_type"] = Value::from("batch");
let resp = test_write_api(
core,
token,
format!("auth/{}/role/{}", path, role_name).as_str(),
true,
Some(data.clone()),
);
let resp =
test_write_api(core, token, format!("auth/{}/role/{}", path, role_name).as_str(), true, Some(data.clone()));
assert!(resp.is_ok());

// Get the role field
Expand Down
Loading

0 comments on commit 5737fdd

Please sign in to comment.