Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 218 additions & 20 deletions keep-frost-net/src/descriptor_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ use sha2::{Digest, Sha256};

use crate::error::{FrostNetError, Result};
use crate::protocol::{
KeySlot, WalletPolicy, DESCRIPTOR_SESSION_TIMEOUT_SECS, MAX_FINGERPRINT_LENGTH, MAX_XPUB_LENGTH,
KeySlot, WalletPolicy, DESCRIPTOR_ACK_PHASE_TIMEOUT_SECS, DESCRIPTOR_CONTRIBUTION_TIMEOUT_SECS,
DESCRIPTOR_FINALIZE_TIMEOUT_SECS, DESCRIPTOR_SESSION_TIMEOUT_SECS, MAX_FINGERPRINT_LENGTH,
MAX_XPUB_LENGTH,
};

const MAX_SESSIONS: usize = 64;
const REAP_GRACE_SECS: u64 = 60;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DescriptorSessionState {
Expand Down Expand Up @@ -51,7 +54,12 @@ pub struct DescriptorSession {
expected_acks: HashSet<u16>,
state: DescriptorSessionState,
created_at: Instant,
contributions_complete_at: Option<Instant>,
finalized_at: Option<Instant>,
timeout: Duration,
contribution_timeout: Duration,
finalize_timeout: Duration,
ack_phase_timeout: Duration,
}

impl DescriptorSession {
Expand All @@ -78,7 +86,12 @@ impl DescriptorSession {
expected_acks,
state: DescriptorSessionState::Proposed,
created_at: Instant::now(),
contributions_complete_at: None,
finalized_at: None,
timeout,
contribution_timeout: Duration::from_secs(DESCRIPTOR_CONTRIBUTION_TIMEOUT_SECS),
finalize_timeout: Duration::from_secs(DESCRIPTOR_FINALIZE_TIMEOUT_SECS),
ack_phase_timeout: Duration::from_secs(DESCRIPTOR_ACK_PHASE_TIMEOUT_SECS),
}
}

Expand All @@ -103,7 +116,9 @@ impl DescriptorSession {
}

pub fn set_initiator(&mut self, initiator: PublicKey) {
self.initiator = Some(initiator);
if self.initiator.is_none() {
self.initiator = Some(initiator);
}
}

pub fn state(&self) -> &DescriptorSessionState {
Expand Down Expand Up @@ -133,17 +148,23 @@ impl DescriptorSession {
if xpub.len() > MAX_XPUB_LENGTH {
return Err(FrostNetError::Session("xpub exceeds maximum length".into()));
}
if fingerprint.len() > MAX_FINGERPRINT_LENGTH {
if fingerprint.len() != MAX_FINGERPRINT_LENGTH
|| !fingerprint.chars().all(|c| c.is_ascii_hexdigit())
{
return Err(FrostNetError::Session(
"fingerprint exceeds maximum length".into(),
"fingerprint must be exactly 8 hex characters".into(),
));
}
let is_mainnet = self.network == "bitcoin";
let valid_prefix = if is_mainnet { "xpub" } else { "tpub" };
if !xpub.starts_with(valid_prefix) {
let valid_prefixes: &[&str] = if is_mainnet {
&["xpub"]
} else {
&["tpub", "Vpub", "Upub"]
};
if !valid_prefixes.iter().any(|p| xpub.starts_with(p)) {
return Err(FrostNetError::Session(format!(
"xpub must start with '{valid_prefix}' for network '{}'",
self.network
"xpub must start with one of {:?} for network '{}'",
valid_prefixes, self.network
)));
}

Expand All @@ -161,6 +182,10 @@ impl DescriptorSession {
},
);

if self.has_all_contributions() {
self.contributions_complete_at = Some(Instant::now());
}

Ok(())
}

Expand Down Expand Up @@ -195,6 +220,7 @@ impl DescriptorSession {

self.descriptor = Some(descriptor);
self.state = DescriptorSessionState::Finalized;
self.finalized_at = Some(Instant::now());
Ok(())
}

Expand Down Expand Up @@ -268,7 +294,43 @@ impl DescriptorSession {
}

pub fn is_expired(&self) -> bool {
self.created_at.elapsed() > self.timeout
self.expired_phase().is_some()
}

pub fn expired_phase(&self) -> Option<&'static str> {
match self.state {
DescriptorSessionState::Complete | DescriptorSessionState::Failed(_) => {
if self.created_at.elapsed() > self.timeout + Duration::from_secs(REAP_GRACE_SECS) {
return Some("reap");
}
None
}
DescriptorSessionState::Proposed => {
if self.created_at.elapsed() > self.timeout {
return Some("session");
}
if let Some(complete_at) = self.contributions_complete_at {
if complete_at.elapsed() > self.finalize_timeout {
return Some("finalize");
}
} else if self.created_at.elapsed() > self.contribution_timeout {
return Some("contribution");
}
None
}
DescriptorSessionState::Finalized => {
if self.created_at.elapsed() > self.timeout {
return Some("session");
}
let Some(fin_at) = self.finalized_at else {
return Some("session");
};
if fin_at.elapsed() > self.ack_phase_timeout {
return Some("ack");
}
None
}
}
}

pub fn is_participant(&self, share_index: u16) -> bool {
Expand All @@ -289,7 +351,9 @@ impl DescriptorSession {
}

pub fn fail(&mut self, reason: String) {
self.state = DescriptorSessionState::Failed(reason);
if !self.is_complete() && !self.is_failed() {
self.state = DescriptorSessionState::Failed(reason);
}
}
}

Expand Down Expand Up @@ -336,7 +400,7 @@ impl DescriptorSessionManager {
self.sessions.remove(&session_id);
}

self.cleanup_expired();
let _ = self.cleanup_expired();

if self.sessions.len() >= MAX_SESSIONS {
return Err(FrostNetError::Session(
Expand All @@ -355,7 +419,9 @@ impl DescriptorSessionManager {
);

self.sessions.insert(session_id, session);
Ok(self.sessions.get_mut(&session_id).unwrap())
self.sessions
.get_mut(&session_id)
.ok_or_else(|| FrostNetError::Session("Failed to retrieve created session".into()))
}

pub fn get_session(&self, session_id: &[u8; 32]) -> Option<&DescriptorSession> {
Expand All @@ -372,8 +438,17 @@ impl DescriptorSessionManager {
self.sessions.remove(session_id);
}

pub fn cleanup_expired(&mut self) {
self.sessions.retain(|_, session| !session.is_expired());
pub fn cleanup_expired(&mut self) -> Vec<([u8; 32], String)> {
let mut expired = Vec::new();
self.sessions.retain(|id, session| {
if let Some(phase) = session.expired_phase() {
expired.push((*id, format!("timeout:{phase}")));
false
} else {
true
}
});
expired
}
}

Expand Down Expand Up @@ -835,13 +910,13 @@ mod tests {
let mut session = test_session();

session
.add_contribution(1, "tpub1".into(), "aa".into())
.add_contribution(1, "tpub1".into(), "aabb0011".into())
.unwrap();
session
.add_contribution(2, "tpub2".into(), "bb".into())
.add_contribution(2, "tpub2".into(), "bbcc2233".into())
.unwrap();
session
.add_contribution(3, "tpub3".into(), "cc".into())
.add_contribution(3, "tpub3".into(), "ccdd4455".into())
.unwrap();

let finalized = FinalizedDescriptor {
Expand All @@ -851,7 +926,7 @@ mod tests {
};
session.set_finalized(finalized).unwrap();

let result = session.add_contribution(1, "tpub1newzzzzzzzzzzzzz".into(), "aa".into());
let result = session.add_contribution(1, "tpub1newzzzzzzzzzzzzz".into(), "aabb0011".into());
assert!(result.is_err());
}

Expand Down Expand Up @@ -949,7 +1024,8 @@ mod tests {
.unwrap();

std::thread::sleep(Duration::from_millis(10));
manager.cleanup_expired();
let expired = manager.cleanup_expired();
assert!(!expired.is_empty());
assert!(manager.get_session(&[1u8; 32]).is_none());
}

Expand Down Expand Up @@ -1001,7 +1077,7 @@ mod tests {

let session = manager.get_session_mut(&[1u8; 32]).unwrap();
session
.add_contribution(1, "tpub1".into(), "aa".into())
.add_contribution(1, "tpub1".into(), "aabb0011".into())
.unwrap();

let session = manager.get_session(&[1u8; 32]).unwrap();
Expand All @@ -1014,6 +1090,128 @@ mod tests {
assert!(manager.get_session(&[0u8; 32]).is_none());
}

#[test]
fn test_contribution_phase_timeout() {
let policy = test_policy();
let contributors: HashSet<u16> = [1, 2, 3].into();
let acks: HashSet<u16> = [1, 2, 3].into();
let mut session = DescriptorSession::new(
[1u8; 32],
[2u8; 32],
policy,
"signet".into(),
contributors,
acks,
Duration::from_secs(600),
);
session.contribution_timeout = Duration::from_millis(1);

session
.add_contribution(1, "tpub1zzzzzzzzzzzzzzz".into(), "aabbccdd".into())
.unwrap();

std::thread::sleep(Duration::from_millis(10));
assert!(session.is_expired());
assert_eq!(session.expired_phase(), Some("contribution"));
}

#[test]
fn test_finalize_phase_timeout() {
let policy = test_policy();
let contributors: HashSet<u16> = [1, 2, 3].into();
let acks: HashSet<u16> = [1, 2, 3].into();
let mut session = DescriptorSession::new(
[1u8; 32],
[2u8; 32],
policy,
"signet".into(),
contributors,
acks,
Duration::from_secs(600),
);
session.finalize_timeout = Duration::from_millis(1);

session
.add_contribution(1, "tpub1zzzzzzzzzzzzzzz".into(), "aabbccdd".into())
.unwrap();
session
.add_contribution(2, "tpub2zzzzzzzzzzzzzzz".into(), "11223344".into())
.unwrap();
session
.add_contribution(3, "tpub3zzzzzzzzzzzzzzz".into(), "55667788".into())
.unwrap();
assert!(session.contributions_complete_at.is_some());

std::thread::sleep(Duration::from_millis(10));
assert!(session.is_expired());
assert_eq!(session.expired_phase(), Some("finalize"));
}

#[test]
fn test_ack_phase_timeout() {
let policy = test_policy();
let contributors: HashSet<u16> = [1, 2, 3].into();
let acks: HashSet<u16> = [1, 2, 3].into();
let mut session = DescriptorSession::new(
[1u8; 32],
[2u8; 32],
policy,
"signet".into(),
contributors,
acks,
Duration::from_secs(600),
);
session.ack_phase_timeout = Duration::from_millis(1);

session
.add_contribution(1, "tpub1zzzzzzzzzzzzzzz".into(), "aabbccdd".into())
.unwrap();
session
.add_contribution(2, "tpub2zzzzzzzzzzzzzzz".into(), "11223344".into())
.unwrap();
session
.add_contribution(3, "tpub3zzzzzzzzzzzzzzz".into(), "55667788".into())
.unwrap();

let finalized = FinalizedDescriptor {
external: "tr(frost_key)".into(),
internal: "tr(frost_key)/1".into(),
policy_hash: [0; 32],
};
session.set_finalized(finalized).unwrap();
assert!(session.finalized_at.is_some());

std::thread::sleep(Duration::from_millis(10));
assert!(session.is_expired());
assert_eq!(session.expired_phase(), Some("ack"));
}

#[test]
fn test_cleanup_returns_phase_reasons() {
let mut manager = DescriptorSessionManager::with_timeout(Duration::from_secs(600));
let policy = test_policy();

{
let session = manager
.create_session(
[1u8; 32],
[2u8; 32],
policy,
"signet".into(),
[1, 2, 3].into(),
[1, 2, 3].into(),
)
.unwrap();
session.contribution_timeout = Duration::from_millis(1);
}

std::thread::sleep(Duration::from_millis(10));
let expired = manager.cleanup_expired();
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].0, [1u8; 32]);
assert_eq!(expired[0].1, "timeout:contribution");
}

#[test]
fn test_duplicate_xpub_across_participants_rejected() {
let mut session = test_session();
Expand Down
Loading