Skip to content

Commit

Permalink
Merge pull request #78 from himmelblau-idm/kanidm/mfa
Browse files Browse the repository at this point in the history
WIP: Use the Kanidm MFA patches
  • Loading branch information
dmulder authored Mar 29, 2024
2 parents 9e3235b + cc1f406 commit 7195357
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 74 deletions.
3 changes: 3 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ updates:
# Kanidm currently depends on these outdated versions of opentelemetry
- dependency-name: "opentelemetry-otlp"
- dependency-name: "opentelemetry_sdk"
- dependency-name: "tracing-opentelemetry"
# This requires an update to compact-jwt first, awaiting that update
- dependency-name: "kanidm-hsm-crypto"
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[submodule "src/kanidm"]
path = src/kanidm
url = https://github.com/dmulder/kanidm.git
branch = dmulder/mfa_capabilities
url = https://github.com/kanidm/kanidm.git
branch = master
shallow = true
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ utoipa = "4.0.0"
utoipa-swagger-ui = "4.0.0"
opentelemetry = { version = "0.20.0" }
opentelemetry_api = { version = "0.20.0", features = ["logs", "metrics"] }
opentelemetry-otlp = { version = "0.15.0", default-features = false, features = [
opentelemetry-otlp = { version = "0.13.0", default-features = false, features = [
"serde",
"logs",
"metrics",
"http-proto",
"grpc-tonic",
] }
opentelemetry_sdk = "0.22.1"
opentelemetry_sdk = "0.20.0"
opentelemetry-stdout = { version = "0.1.0", features = [
"logs",
"metrics",
"trace",
] }
tonic = "0.10.2"
tracing-opentelemetry = "0.23.0"
tracing-opentelemetry = "0.21.0"
compact_jwt = { version = "0.3.5", features = ["hsm-crypto", "msextensions"] }
kanidm-hsm-crypto = { version = "^0.2.0", features = ["msextensions"] }
kanidm-hsm-crypto = { version = "^0.1.6", features = ["msextensions"] }
181 changes: 115 additions & 66 deletions src/common/src/idprovider/himmelblau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;
use std::thread::sleep;
use std::time::Duration;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;

use rand::Rng;
Expand Down Expand Up @@ -181,14 +181,21 @@ impl IdProvider for HimmelblauMultiProvider {
token: Option<&UserToken>,
tpm: &mut tpm::BoxedDynTpm,
machine_key: &tpm::MachineKey,
shutdown_rx: &broadcast::Receiver<()>,
) -> Result<(AuthRequest, AuthCredHandler), IdpError> {
match split_username(account_id) {
Some((_sam, domain)) => {
let providers = self.providers.read().await;
match providers.get(domain) {
Some(provider) => {
provider
.unix_user_online_auth_init(account_id, token, tpm, machine_key)
.unix_user_online_auth_init(
account_id,
token,
tpm,
machine_key,
shutdown_rx,
)
.await
}
None => Err(IdpError::NotFound),
Expand All @@ -209,6 +216,7 @@ impl IdProvider for HimmelblauMultiProvider {
keystore: &mut D,
tpm: &mut tpm::BoxedDynTpm,
machine_key: &tpm::MachineKey,
shutdown_rx: &broadcast::Receiver<()>,
) -> Result<(AuthResult, AuthCacheAction), IdpError> {
match split_username(account_id) {
Some((_sam, domain)) => {
Expand All @@ -223,6 +231,7 @@ impl IdProvider for HimmelblauMultiProvider {
keystore,
tpm,
machine_key,
shutdown_rx,
)
.await
}
Expand Down Expand Up @@ -334,6 +343,14 @@ struct MFAAuthContinueI(MFAAuthContinue);
#[allow(clippy::from_over_into)]
impl Into<Vec<String>> for MFAAuthContinueI {
fn into(self) -> Vec<String> {
let max_poll_attempts = match self.0.max_poll_attempts {
Some(n) => n.to_string(),
None => String::new(),
};
let polling_interval = match self.0.polling_interval {
Some(n) => n.to_string(),
None => String::new(),
};
vec![
self.0.mfa_method,
self.0.msg,
Expand All @@ -344,17 +361,29 @@ impl Into<Vec<String>> for MFAAuthContinueI {
self.0.url_end_auth,
self.0.url_begin_auth,
self.0.url_post,
max_poll_attempts,
polling_interval,
]
}
}

impl From<Vec<String>> for MFAAuthContinueI {
fn from(src: Vec<String>) -> Self {
impl From<&Vec<String>> for MFAAuthContinueI {
fn from(src: &Vec<String>) -> Self {
let max_poll_attempts: Option<u32> = if src[9].is_empty() {
None
} else {
src[9].parse().ok()
};
let polling_interval: Option<u32> = if src[10].is_empty() {
None
} else {
src[10].parse().ok()
};
MFAAuthContinueI(MFAAuthContinue {
mfa_method: src[0].clone(),
msg: src[1].clone(),
max_poll_attempts: None,
polling_interval: None,
max_poll_attempts,
polling_interval,
session_id: src[2].clone(),
flow_token: src[3].clone(),
ctx: src[4].clone(),
Expand Down Expand Up @@ -445,6 +474,7 @@ impl IdProvider for HimmelblauProvider {
_token: Option<&UserToken>,
_tpm: &mut tpm::BoxedDynTpm,
_machine_key: &tpm::MachineKey,
_shutdown_rx: &broadcast::Receiver<()>,
) -> Result<(AuthRequest, AuthCredHandler), IdpError> {
Ok((AuthRequest::Password, AuthCredHandler::Password))
}
Expand All @@ -457,8 +487,10 @@ impl IdProvider for HimmelblauProvider {
keystore: &mut D,
tpm: &mut tpm::BoxedDynTpm,
machine_key: &tpm::MachineKey,
shutdown_rx: &broadcast::Receiver<()>,
) -> Result<(AuthResult, AuthCacheAction), IdpError> {
match (cred_handler, pam_next_req) {
let mut shutdown_rx_cl = shutdown_rx.resubscribe();
match (&cred_handler, pam_next_req) {
(AuthCredHandler::Password, PamAuthRequest::Password { cred }) => {
let mut scopes = vec!["GroupMember.Read.All"];
if !self.is_domain_joined(keystore).await {
Expand Down Expand Up @@ -499,11 +531,12 @@ impl IdProvider for HimmelblauProvider {
};
match resp.mfa_method.as_str() {
"PhoneAppNotification" | "PhoneAppOTP" => {
let msg = resp.msg.clone();
*cred_handler = AuthCredHandler::MFA {
data: MFAAuthContinueI(resp).into(),
};
return Ok((
AuthResult::Next(AuthRequest::MFACode {
msg: resp.msg.clone(),
data: MFAAuthContinueI(resp).into(),
}),
AuthResult::Next(AuthRequest::MFACode { msg }),
/* An MFA auth cannot cache the password. This would
* lead to a potential downgrade to SFA attack (where
* the attacker auths with a stolen password, then
Expand All @@ -512,18 +545,18 @@ impl IdProvider for HimmelblauProvider {
));
}
_ => {
let msg = resp.msg.clone();
let polling_interval = resp.polling_interval.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?;
*cred_handler = AuthCredHandler::MFA {
data: MFAAuthContinueI(resp).into(),
};
return Ok((
AuthResult::Next(AuthRequest::MFAPoll {
msg: resp.msg.clone(),
max_poll_attempts: resp.max_poll_attempts.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?,
polling_interval: resp.polling_interval.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?,
data: MFAAuthContinueI(resp).into(),
msg,
polling_interval,
}),
/* An MFA auth cannot cache the password. This would
* lead to a potential downgrade to SFA attack (where
Expand Down Expand Up @@ -613,11 +646,12 @@ impl IdProvider for HimmelblauProvider {
};
match resp.mfa_method.as_str() {
"PhoneAppNotification" | "PhoneAppOTP" => {
let msg = resp.msg.clone();
*cred_handler = AuthCredHandler::MFA {
data: MFAAuthContinueI(resp).into(),
};
return Ok((
AuthResult::Next(AuthRequest::MFACode {
msg: resp.msg.clone(),
data: MFAAuthContinueI(resp).into(),
}),
AuthResult::Next(AuthRequest::MFACode { msg }),
/* An MFA auth cannot cache the password. This would
* lead to a potential downgrade to SFA attack (where
* the attacker auths with a stolen password, then
Expand All @@ -626,18 +660,18 @@ impl IdProvider for HimmelblauProvider {
));
}
_ => {
let msg = resp.msg.clone();
let polling_interval = resp.polling_interval.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?;
*cred_handler = AuthCredHandler::MFA {
data: MFAAuthContinueI(resp).into(),
};
return Ok((
AuthResult::Next(AuthRequest::MFAPoll {
msg: resp.msg.clone(),
max_poll_attempts: resp.max_poll_attempts.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?,
polling_interval: resp.polling_interval.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?,
data: MFAAuthContinueI(resp).into(),
msg,
polling_interval,
}),
/* An MFA auth cannot cache the password. This would
* lead to a potential downgrade to SFA attack (where
Expand Down Expand Up @@ -737,7 +771,7 @@ impl IdProvider for HimmelblauProvider {
Err(e) => Err(e),
}
}
(_, PamAuthRequest::MFACode { cred, data }) => {
(AuthCredHandler::MFA { data }, PamAuthRequest::MFACode { cred }) => {
let mut token = self
.client
.write()
Expand Down Expand Up @@ -785,36 +819,51 @@ impl IdProvider for HimmelblauProvider {
Err(e) => Err(e),
}
}
(_, PamAuthRequest::MFAPoll { poll_attempt, data }) => {
let mut token = match self
.client
.write()
.await
.acquire_token_by_mfa_flow(
account_id,
None,
Some(poll_attempt),
MFAAuthContinueI::from(data).0,
)
.await
{
Ok(token) => token,
Err(e) => match e {
MsalError::MFAPollContinue => {
return Ok((
AuthResult::Next(AuthRequest::MFAPollWait),
/* An MFA auth cannot cache the password. This would
* lead to a potential downgrade to SFA attack (where
* the attacker auths with a stolen password, then
* disconnects the network to complete the auth). */
AuthCacheAction::None,
));
}
e => {
error!("{:?}", e);
return Err(IdpError::NotFound);
}
},
(AuthCredHandler::MFA { data }, PamAuthRequest::MFAPoll) => {
let flow = MFAAuthContinueI::from(data).0;
let max_poll_attempts = flow.max_poll_attempts.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?;
let polling_interval = flow.polling_interval.ok_or({
error!("Invalid response from the server");
IdpError::BadRequest
})?;
let mut poll_attempt = 1;
let mut token = loop {
if poll_attempt > max_poll_attempts {
error!("MFA polling timed out");
return Err(IdpError::BadRequest);
}
if shutdown_rx_cl.try_recv().ok().is_some() {
debug!("Received a signal to shutdown, bailing MFA poll");
return Err(IdpError::BadRequest);
}
sleep(Duration::from_secs(polling_interval.into()));
match self
.client
.write()
.await
.acquire_token_by_mfa_flow(
account_id,
None,
Some(poll_attempt),
MFAAuthContinueI::from(data).0,
)
.await
{
Ok(token) => break token,
Err(e) => match e {
MsalError::MFAPollContinue => {
poll_attempt += 1;
continue;
}
e => {
error!("{:?}", e);
return Err(IdpError::NotFound);
}
},
}
};
if !self.is_domain_joined(keystore).await {
self.join_domain(tpm, &token, keystore, machine_key)
Expand Down
17 changes: 16 additions & 1 deletion src/daemon/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ async fn handle_client(
let mut reqs = Framed::new(sock, ClientCodec);
let mut pam_auth_session_state = None;

// Setup a broadcast channel so that if we have an unexpected disconnection, we can
// tell consumers to stop work.
let (shutdown_tx, _shutdown_rx) = broadcast::channel(1);

trace!("Waiting for requests ...");
while let Some(Ok(req)) = reqs.next().await {
let resp = match req {
Expand Down Expand Up @@ -296,7 +300,10 @@ async fn handle_client(
}
None => {
match cachelayer
.pam_account_authenticate_init(account_id.as_str())
.pam_account_authenticate_init(
account_id.as_str(),
shutdown_tx.subscribe(),
)
.await
{
Ok((auth_session, pam_auth_response)) => {
Expand Down Expand Up @@ -408,6 +415,14 @@ async fn handle_client(
debug!("flushed response!");
}

// Signal any tasks that they need to stop.
if let Err(shutdown_err) = shutdown_tx.send(()) {
warn!(
?shutdown_err,
"Unable to signal tasks to stop, they will naturally timeout instead."
)
}

// Disconnect them
debug!("Disconnecting client ...");
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/kanidm
Submodule kanidm updated 103 files

0 comments on commit 7195357

Please sign in to comment.