diff --git a/Cargo.lock b/Cargo.lock index b6778eb9..0db66d93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4384,6 +4384,7 @@ dependencies = [ "hex", "keep-core", "keep-frost-net", + "nostr-relay-builder", "nostr-sdk", "rustls", "serde", @@ -4396,6 +4397,7 @@ dependencies = [ "tracing-subscriber", "url", "urlencoding", + "zeroize", ] [[package]] diff --git a/keep-agent-py/Cargo.toml b/keep-agent-py/Cargo.toml index 484b1bd5..23d410b7 100644 --- a/keep-agent-py/Cargo.toml +++ b/keep-agent-py/Cargo.toml @@ -19,3 +19,4 @@ serde_json = "1.0" hex = "0.4" nostr-sdk = { version = "0.44", features = ["nip44"] } k256 = { version = "0.13", features = ["schnorr"] } +zeroize = "1" diff --git a/keep-agent-py/src/lib.rs b/keep-agent-py/src/lib.rs index e3c1acdb..25987694 100644 --- a/keep-agent-py/src/lib.rs +++ b/keep-agent-py/src/lib.rs @@ -7,6 +7,7 @@ use pyo3::prelude::*; use pyo3::exceptions::{PyRuntimeError, PyValueError, PyConnectionError}; use std::sync::Arc; use tokio::sync::Mutex; +use zeroize::Zeroizing; use ::keep_agent::{ RateLimitConfig, SessionConfig, SessionManager, SessionMetadata, @@ -155,7 +156,7 @@ pub struct PyAgentSession { manager: SessionManager, token: SessionToken, session_id: String, - secret_key: Option<[u8; 32]>, + secret_key: Option>, } #[pymethods] @@ -169,9 +170,9 @@ impl PyAgentSession { policy: Option, secret_key: Option, ) -> PyResult { - let secret_bytes: Option<[u8; 32]> = if let Some(sk) = secret_key { - let decoded = hex::decode(&sk) - .map_err(|e| PyValueError::new_err(format!("Invalid secret key hex: {}", e)))?; + let secret_bytes: Option> = if let Some(sk) = secret_key { + let decoded = Zeroizing::new(hex::decode(&sk) + .map_err(|e| PyValueError::new_err(format!("Invalid secret key hex: {}", e)))?); if decoded.len() != 32 { return Err(PyValueError::new_err(format!( "Secret key must be 32 bytes, got {}", @@ -180,14 +181,14 @@ impl PyAgentSession { } let mut arr = [0u8; 32]; arr.copy_from_slice(&decoded); - Some(arr) + Some(Zeroizing::new(arr)) } else { None }; let pubkey_bytes: [u8; 32] = if let Some(ref sk) = secret_bytes { use k256::elliptic_curve::sec1::ToEncodedPoint; - let scalar = k256::NonZeroScalar::try_from(sk.as_slice()) + let scalar = k256::NonZeroScalar::try_from(sk.as_ref().as_slice()) .map_err(|_| PyValueError::new_err("Invalid secret key"))?; let pk = k256::PublicKey::from_secret_scalar(&scalar); let point = pk.to_encoded_point(true); @@ -300,14 +301,15 @@ impl PyAgentSession { session.check_operation(&Operation::SignNostrEvent).map_err(to_py_err)?; session.check_event_kind(kind).map_err(to_py_err)?; - let secret = self.secret_key + let secret = self.secret_key.as_ref() .ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?; self.manager.record_request(&self.session_id).map_err(to_py_err)?; use nostr_sdk::prelude::*; - let keys = Keys::parse(&hex::encode(secret)) + let hex = Zeroizing::new(hex::encode(secret.as_ref())); + let keys = Keys::parse(hex.as_str()) .map_err(to_py_value_err)?; let mut nostr_tags: Vec = Vec::new(); @@ -358,8 +360,8 @@ impl PyAgentSession { session.check_operation(&Operation::SignPsbt).map_err(to_py_err)?; - let mut secret = self.secret_key - .ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?; + let mut secret = Zeroizing::new(*self.secret_key.as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?); let network = match network.unwrap_or("testnet") { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, @@ -371,7 +373,7 @@ impl PyAgentSession { let mut psbt = keep_bitcoin::psbt::parse_psbt_base64(psbt_base64) .map_err(|e| PyValueError::new_err(format!("Invalid PSBT: {}", e)))?; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network) .map_err(to_py_err)?; let analysis = signer.analyze_psbt(&psbt).map_err(to_py_err)?; @@ -408,7 +410,7 @@ impl PyAgentSession { } fn get_public_key(&self) -> PyResult { - let _ = self.secret_key + let _ = self.secret_key.as_ref() .ok_or_else(|| PyRuntimeError::new_err("No secret key configured"))?; let session = self.manager @@ -429,8 +431,8 @@ impl PyAgentSession { session.check_operation(&Operation::GetBitcoinAddress).map_err(to_py_err)?; - let mut secret = self.secret_key - .ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?; + let mut secret = Zeroizing::new(*self.secret_key.as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?); let network = match network.unwrap_or("testnet") { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, @@ -439,7 +441,7 @@ impl PyAgentSession { _ => keep_bitcoin::Network::Testnet, }; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network) .map_err(to_py_err)?; signer.get_receive_address(0).map_err(to_py_err) @@ -524,6 +526,14 @@ impl PyRemoteSession { }).map_err(to_py_err) } + fn switch_relays(&self) -> PyResult>> { + let client = self.client.clone(); + self.runtime.block_on(async { + let mut c = client.lock().await; + c.switch_relays().await + }).map_err(to_py_err) + } + fn disconnect(&self) -> PyResult<()> { let client = self.client.clone(); self.runtime.block_on(async { diff --git a/keep-agent-ts/Cargo.toml b/keep-agent-ts/Cargo.toml index 03d5439c..dfc6706d 100644 --- a/keep-agent-ts/Cargo.toml +++ b/keep-agent-ts/Cargo.toml @@ -19,6 +19,7 @@ serde_json = "1.0" nostr-sdk = "0.44" k256 = { version = "0.13", features = ["schnorr"] } hex = "0.4" +zeroize = "1" [build-dependencies] napi-build = "2" diff --git a/keep-agent-ts/src/lib.rs b/keep-agent-ts/src/lib.rs index 8f63c318..2f2b72da 100644 --- a/keep-agent-ts/src/lib.rs +++ b/keep-agent-ts/src/lib.rs @@ -7,6 +7,7 @@ use napi::bindgen_prelude::*; use napi_derive::napi; use std::sync::Arc; use tokio::sync::Mutex; +use zeroize::Zeroizing; use keep_agent::{ AgentClient, ApprovalStatus, Operation, PendingSession as RustPendingSession, @@ -53,7 +54,7 @@ pub struct KeepAgentSession { manager: SessionManager, token: Arc>, session_id: String, - secret_key: Option<[u8; 32]>, + secret_key: Option>, } #[napi] @@ -66,9 +67,9 @@ impl KeepAgentSession { policy: Option, secret_key: Option, ) -> Result { - let secret_bytes: Option<[u8; 32]> = if let Some(ref sk) = secret_key { - let decoded = hex::decode(sk) - .map_err(|e| Error::from_reason(format!("Invalid secret key hex: {}", e)))?; + let secret_bytes: Option> = if let Some(ref sk) = secret_key { + let decoded = Zeroizing::new(hex::decode(sk) + .map_err(|e| Error::from_reason(format!("Invalid secret key hex: {}", e)))?); if decoded.len() != 32 { return Err(Error::from_reason(format!( "Secret key must be 32 bytes, got {}", @@ -77,14 +78,14 @@ impl KeepAgentSession { } let mut arr = [0u8; 32]; arr.copy_from_slice(&decoded); - Some(arr) + Some(Zeroizing::new(arr)) } else { None }; let pubkey: [u8; 32] = if let Some(ref sk) = secret_bytes { use k256::elliptic_curve::sec1::ToEncodedPoint; - let scalar = k256::NonZeroScalar::try_from(sk.as_slice()) + let scalar = k256::NonZeroScalar::try_from(sk.as_ref().as_slice()) .map_err(|_| Error::from_reason("Invalid secret key"))?; let pk = k256::PublicKey::from_secret_scalar(&scalar); let point = pk.to_encoded_point(true); @@ -304,6 +305,7 @@ impl KeepAgentSession { let secret = self .secret_key + .as_ref() .ok_or_else(|| Error::from_reason("No secret key configured"))?; self.manager @@ -312,7 +314,8 @@ impl KeepAgentSession { use nostr_sdk::prelude::{EventBuilder, Keys, Kind, Tag}; - let keys = Keys::parse(&hex::encode(secret)) + let hex = Zeroizing::new(hex::encode(secret.as_ref())); + let keys = Keys::parse(hex.as_str()) .map_err(|e| napi::Error::from_reason(e.to_string()))?; let nostr_tags: Vec = tags @@ -365,9 +368,10 @@ impl KeepAgentSession { .check_operation(&Operation::SignPsbt) .map_err(|e| Error::from_reason(e.to_string()))?; - let mut secret = self + let mut secret = Zeroizing::new(*self .secret_key - .ok_or_else(|| Error::from_reason("No secret key configured"))?; + .as_ref() + .ok_or_else(|| Error::from_reason("No secret key configured"))?); let network = match network.as_deref().unwrap_or("testnet") { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, @@ -379,7 +383,7 @@ impl KeepAgentSession { let mut psbt = keep_bitcoin::psbt::parse_psbt_base64(&psbt_base64) .map_err(|e| Error::from_reason(format!("Invalid PSBT: {}", e)))?; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network) .map_err(|e| Error::from_reason(e.to_string()))?; let analysis = signer @@ -414,9 +418,7 @@ impl KeepAgentSession { .record_request(&self.session_id) .map_err(|e| Error::from_reason(e.to_string()))?; - signer - .sign_psbt(&mut psbt) - .map_err(|e| Error::from_reason(e.to_string()))?; + signer.sign_psbt(&mut psbt).map_err(|e| Error::from_reason(e.to_string()))?; Ok(keep_bitcoin::psbt::serialize_psbt_base64(&psbt)) } @@ -449,9 +451,10 @@ impl KeepAgentSession { .check_operation(&Operation::GetBitcoinAddress) .map_err(|e| Error::from_reason(e.to_string()))?; - let mut secret = self + let mut secret = Zeroizing::new(*self .secret_key - .ok_or_else(|| Error::from_reason("No secret key configured"))?; + .as_ref() + .ok_or_else(|| Error::from_reason("No secret key configured"))?); let network = match network.as_deref().unwrap_or("testnet") { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, @@ -460,7 +463,7 @@ impl KeepAgentSession { _ => keep_bitcoin::Network::Testnet, }; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network) .map_err(|e| Error::from_reason(e.to_string()))?; signer @@ -534,6 +537,15 @@ impl RemoteSession { .map_err(|e| Error::from_reason(e.to_string())) } + #[napi] + pub async fn switch_relays(&self) -> Result>> { + let mut client = self.client.lock().await; + client + .switch_relays() + .await + .map_err(|e| Error::from_reason(e.to_string())) + } + #[napi] pub async fn disconnect(&self) -> Result<()> { let client = self.client.lock().await; diff --git a/keep-agent/src/client.rs b/keep-agent/src/client.rs index e0428b72..02f4d87a 100644 --- a/keep-agent/src/client.rs +++ b/keep-agent/src/client.rs @@ -1,5 +1,6 @@ // SPDX-FileCopyrightText: © 2026 PrivKey LLC // SPDX-License-Identifier: AGPL-3.0-or-later +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use nostr_sdk::prelude::*; @@ -19,6 +20,7 @@ pub struct PendingSession { relay_url: String, client_keys: Keys, client: Client, + connect_sent: AtomicBool, } impl PendingSession { @@ -56,6 +58,7 @@ impl PendingSession { relay_url, client_keys, client, + connect_sent: AtomicBool::new(false), }) } @@ -67,16 +70,17 @@ impl PendingSession { let encoded_relay = urlencoding::encode(&self.relay_url); format!( "nostrconnect://{}?relay={}&metadata={}", - self.client_keys - .public_key() - .to_bech32() - .unwrap_or_default(), + self.client_keys.public_key().to_hex(), encoded_relay, urlencoding::encode("{\"name\":\"Keep Agent\"}") ) } - pub async fn poll(&self, timeout: Duration) -> Result { + async fn send_connect_once(&self) -> Result<()> { + if self.connect_sent.swap(true, Ordering::AcqRel) { + return Ok(()); + } + let request = serde_json::json!({ "id": &self.request_id, "method": "connect", @@ -109,6 +113,12 @@ impl PendingSession { .await .map_err(|e| AgentError::Nostr(e.to_string()))?; + Ok(()) + } + + pub async fn poll(&self, timeout: Duration) -> Result { + self.send_connect_once().await?; + let filter = Filter::new() .kind(Kind::NostrConnect) .author(self.signer_pubkey) @@ -171,19 +181,19 @@ impl PendingSession { while start.elapsed() < timeout { match self.poll(poll_interval).await? { ApprovalStatus::Approved => { - return Ok(AgentClient { + let mut client = AgentClient { signer_pubkey: self.signer_pubkey, relay_url: self.relay_url.clone(), client_keys: self.client_keys.clone(), client: self.client.clone(), - }); + }; + let _ = client.switch_relays().await; + return Ok(client); } ApprovalStatus::Denied => { return Err(AgentError::AuthFailed("Session request denied".into())); } - ApprovalStatus::Pending => { - tokio::time::sleep(poll_interval).await; - } + ApprovalStatus::Pending => {} } } @@ -197,7 +207,6 @@ impl PendingSession { pub struct AgentClient { signer_pubkey: PublicKey, - #[allow(dead_code)] relay_url: String, client_keys: Keys, client: Client, @@ -230,7 +239,7 @@ impl AgentClient { .await .map_err(|_| AgentError::Connection("Relay connection timeout".into()))?; - let agent_client = Self { + let mut agent_client = Self { signer_pubkey, relay_url, client_keys, @@ -238,6 +247,7 @@ impl AgentClient { }; agent_client.send_connect(secret.as_deref()).await?; + let _ = agent_client.switch_relays().await; Ok(agent_client) } @@ -279,6 +289,9 @@ impl AgentClient { AgentError::Connection("Missing relay parameter in bunker URL".into()) })?; + keep_core::relay::validate_relay_url(&relay_url) + .map_err(|e| AgentError::Connection(format!("Invalid relay URL: {e}")))?; + Ok((signer_pubkey, relay_url, secret)) } @@ -373,7 +386,86 @@ impl AgentClient { Ok(false) } + pub async fn switch_relays(&mut self) -> Result>> { + let request = serde_json::json!({ + "id": generate_uuid(), + "method": "switch_relays", + "params": [] + }); + + let response = self.send_request(&request.to_string()).await?; + let parsed: serde_json::Value = serde_json::from_str(&response) + .map_err(|e| AgentError::Serialization(e.to_string()))?; + + if let Some(error) = parsed.get("error") { + if !error.is_null() { + return Err(AgentError::Nostr(error.to_string())); + } + } + + let result = parsed + .get("result") + .ok_or_else(|| AgentError::Serialization("Missing result field".into()))?; + + if result.is_null() || (result.is_string() && result.as_str() == Some("null")) { + return Ok(None); + } + + let relays: Vec = if result.is_string() { + serde_json::from_str(result.as_str().unwrap()) + .map_err(|e| AgentError::Serialization(format!("Invalid relay list: {e}")))? + } else if result.is_array() { + serde_json::from_value(result.clone()) + .map_err(|e| AgentError::Serialization(format!("Invalid relay list: {e}")))? + } else { + return Err(AgentError::Serialization( + "Unexpected switch_relays result format".into(), + )); + }; + + if relays.is_empty() { + return Ok(None); + } + + let valid_relays: Vec = relays + .into_iter() + .filter(|r| keep_core::relay::validate_relay_url(r).is_ok()) + .collect(); + + if valid_relays.is_empty() { + return Ok(None); + } + + self.client.disconnect().await; + self.client.remove_all_relays().await; + let mut added = Vec::new(); + for relay in &valid_relays { + if self.client.add_relay(relay).await.is_ok() { + added.push(relay.clone()); + } + } + if added.is_empty() { + return Err(AgentError::Connection( + "Failed to add any relay during switch".into(), + )); + } + self.client.connect().await; + + self.relay_url = added[0].clone(); + + Ok(Some(added)) + } + async fn send_request(&self, content: &str) -> Result { + let request_id = { + let parsed: serde_json::Value = serde_json::from_str(content) + .map_err(|e| AgentError::Serialization(e.to_string()))?; + parsed + .get("id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }; + let encrypted = nip44::encrypt( self.client_keys.secret_key(), &self.signer_pubkey, @@ -424,6 +516,16 @@ impl AgentClient { &self.signer_pubkey, &event.content, ) { + if let Some(ref expected_id) = request_id { + if let Ok(resp) = + serde_json::from_str::(&decrypted) + { + if resp.get("id").and_then(|v| v.as_str()) != Some(expected_id) + { + continue; + } + } + } return Ok(decrypted); } } @@ -491,4 +593,35 @@ mod tests { assert_ne!(uuid1, uuid2); assert_eq!(uuid1.len(), 36); } + + #[test] + fn test_switch_relays_parse() { + let null_response = r#"{"id":"abc","result":null}"#; + let parsed: serde_json::Value = serde_json::from_str(null_response).unwrap(); + let result = parsed.get("result").unwrap(); + assert!(result.is_null()); + + let null_string_response = r#"{"id":"abc","result":"null"}"#; + let parsed: serde_json::Value = serde_json::from_str(null_string_response).unwrap(); + let result = parsed.get("result").unwrap(); + assert!(result.is_string() && result.as_str() == Some("null")); + + let array_response = + r#"{"id":"abc","result":["wss://relay1.example.com","wss://relay2.example.com"]}"#; + let parsed: serde_json::Value = serde_json::from_str(array_response).unwrap(); + let result = parsed.get("result").unwrap(); + assert!(result.is_array()); + let relays: Vec = serde_json::from_value(result.clone()).unwrap(); + assert_eq!(relays.len(), 2); + assert_eq!(relays[0], "wss://relay1.example.com"); + assert_eq!(relays[1], "wss://relay2.example.com"); + + let string_array_response = r#"{"id":"abc","result":"[\"wss://relay1.example.com\",\"wss://relay2.example.com\"]"}"#; + let parsed: serde_json::Value = serde_json::from_str(string_array_response).unwrap(); + let result = parsed.get("result").unwrap(); + assert!(result.is_string()); + let relays: Vec = serde_json::from_str(result.as_str().unwrap()).unwrap(); + assert_eq!(relays.len(), 2); + assert_eq!(relays[0], "wss://relay1.example.com"); + } } diff --git a/keep-agent/src/manager.rs b/keep-agent/src/manager.rs index 663f8499..cc3c1ff9 100644 --- a/keep-agent/src/manager.rs +++ b/keep-agent/src/manager.rs @@ -7,6 +7,8 @@ use crate::error::{AgentError, Result}; use crate::scope::Operation; use crate::session::{AgentSession, SessionConfig, SessionMetadata, SessionToken}; +const MAX_SESSIONS: usize = 128; + pub struct SessionManager { sessions: Arc>>, pubkey: [u8; 32], @@ -34,6 +36,14 @@ impl SessionManager { .write() .map_err(|_| AgentError::Other("Failed to acquire session lock".into()))?; + sessions.retain(|_, s| !s.is_expired()); + + if sessions.len() >= MAX_SESSIONS { + return Err(AgentError::Other(format!( + "Maximum session count ({MAX_SESSIONS}) reached" + ))); + } + sessions.insert(session_id.clone(), session); Ok((token, session_id)) } diff --git a/keep-agent/src/mcp/server.rs b/keep-agent/src/mcp/server.rs index 9dc11d6f..a26f2a54 100644 --- a/keep-agent/src/mcp/server.rs +++ b/keep-agent/src/mcp/server.rs @@ -8,6 +8,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::sync::RwLock; +use zeroize::Zeroizing; + use crate::error::{AgentError, Result}; use crate::manager::SessionManager; use crate::scope::Operation; @@ -45,7 +47,7 @@ pub struct McpServer { version: String, session_manager: Arc>>, manager: SessionManager, - secret_key: Option<[u8; 32]>, + secret_key: Option>, } impl McpServer { @@ -65,7 +67,7 @@ impl McpServer { version: env!("CARGO_PKG_VERSION").into(), session_manager: Arc::new(RwLock::new(None)), manager: SessionManager::new(pubkey), - secret_key: Some(secret), + secret_key: Some(Zeroizing::new(secret)), } } @@ -265,10 +267,11 @@ impl McpServer { .and_then(|v| serde_json::from_value(v.clone()).ok()) .unwrap_or_default(); - if let Some(secret) = self.secret_key { + if let Some(ref secret) = self.secret_key { use nostr_sdk::prelude::*; - let keys = Keys::parse(&hex::encode(secret)) + let hex_secret = Zeroizing::new(hex::encode(**secret)); + let keys = Keys::parse(hex_secret.as_str()) .map_err(|e| AgentError::Other(e.to_string()))?; let nostr_tags: Vec = tags @@ -322,7 +325,8 @@ impl McpServer { .and_then(|v| v.as_str()) .unwrap_or("testnet"); - if let Some(mut secret) = self.secret_key { + if let Some(ref secret) = self.secret_key { + let mut secret_copy = Zeroizing::new(**secret); let network = match network_str { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, "signet" => keep_bitcoin::Network::Signet, @@ -333,7 +337,7 @@ impl McpServer { let mut psbt = keep_bitcoin::psbt::parse_psbt_base64(psbt_base64) .map_err(|e| AgentError::Other(format!("Invalid PSBT: {e}")))?; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut secret_copy, network) .map_err(|e| AgentError::Other(e.to_string()))?; let analysis = signer @@ -416,7 +420,8 @@ impl McpServer { .and_then(|v| v.as_str()) .unwrap_or("testnet"); - if let Some(mut secret) = self.secret_key { + if let Some(ref secret) = self.secret_key { + let mut secret_copy = Zeroizing::new(**secret); let network = match network_str { "mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin, "signet" => keep_bitcoin::Network::Signet, @@ -424,7 +429,7 @@ impl McpServer { _ => keep_bitcoin::Network::Testnet, }; - let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network) + let signer = keep_bitcoin::BitcoinSigner::new(&mut secret_copy, network) .map_err(|e| AgentError::Other(e.to_string()))?; let address = signer @@ -449,7 +454,9 @@ impl McpServer { _ => ToolResult::error(format!("Unknown tool: {name}")), }; - self.manager.record_request(&session_id)?; + if result.success { + self.manager.record_request(&session_id)?; + } let text = serde_json::to_string(&result.content) .map_err(|e| AgentError::Serialization(e.to_string()))?; diff --git a/keep-desktop/src/app.rs b/keep-desktop/src/app.rs index ab747444..b94e1ce8 100644 --- a/keep-desktop/src/app.rs +++ b/keep-desktop/src/app.rs @@ -83,6 +83,7 @@ pub(crate) const RECONNECT_MAX_MS: u64 = 30_000; pub(crate) const RECONNECT_MAX_ATTEMPTS: u32 = 10; pub(crate) const BUNKER_APPROVAL_TIMEOUT: Duration = Duration::from_secs(60); pub(crate) const MAX_BUNKER_LOG_ENTRIES: usize = 1000; +pub(crate) const MAX_ACTIVE_COORDINATIONS: usize = 64; const DEFAULT_BUNKER_RELAYS: &[&str] = &["wss://relay.damus.io", "wss://relay.nsec.app"]; @@ -108,6 +109,7 @@ pub struct Toast { pub(crate) struct ActiveCoordination { pub group_pubkey: [u8; 32], pub network: String, + pub is_initiator: bool, } pub struct App { @@ -1306,20 +1308,40 @@ impl App { Message::WalletBeginCoordination => self.begin_descriptor_coordination(), Message::WalletSessionStarted(result) => { match result { - Ok((session_id, group_pubkey, network)) => { - if let Screen::Wallet(WalletScreen { setup: Some(s), .. }) = - &mut self.screen - { + Ok((session_id, group_pubkey, network, _expected_participants)) => { + let on_wallet_screen = matches!( + self.screen, + Screen::Wallet(WalletScreen { setup: Some(_), .. }) + ); + if !on_wallet_screen { + if let Some(node) = self.get_frost_node() { + node.cancel_descriptor_session(&session_id); + } + } else if self.active_coordinations.len() >= MAX_ACTIVE_COORDINATIONS { + if let Some(node) = self.get_frost_node() { + node.cancel_descriptor_session(&session_id); + } + if let Screen::Wallet(WalletScreen { setup: Some(s), .. }) = + &mut self.screen + { + s.phase = SetupPhase::Coordinating(DescriptorProgress::Failed( + "Too many active coordinations".to_string(), + )); + } + } else { self.active_coordinations.insert( session_id, ActiveCoordination { group_pubkey, network, + is_initiator: true, }, ); - s.session_id = Some(session_id); - } else if let Some(node) = self.get_frost_node() { - node.cancel_descriptor_session(&session_id); + if let Screen::Wallet(WalletScreen { setup: Some(s), .. }) = + &mut self.screen + { + s.session_id = Some(session_id); + } } } Err(e) => { @@ -1472,7 +1494,12 @@ impl App { .await .map_err(|e| format!("{e}"))?; - Ok::<([u8; 32], [u8; 32], String), String>((session_id, group_pubkey, net)) + Ok::<([u8; 32], [u8; 32], String, usize), String>(( + session_id, + group_pubkey, + net, + expected_total, + )) }, Message::WalletSessionStarted, ) diff --git a/keep-desktop/src/frost.rs b/keep-desktop/src/frost.rs index 36f63af4..e830d889 100644 --- a/keep-desktop/src/frost.rs +++ b/keep-desktop/src/frost.rs @@ -16,9 +16,9 @@ use keep_core::Keep; use crate::app::{ friendly_err, lock_keep, with_keep_blocking, ActiveCoordination, App, ToastKind, - MAX_PENDING_REQUESTS, MAX_REQUESTS_PER_PEER, RATE_LIMIT_GLOBAL, RATE_LIMIT_PER_PEER, - RATE_LIMIT_WINDOW_SECS, RECONNECT_BASE_MS, RECONNECT_MAX_ATTEMPTS, RECONNECT_MAX_MS, - SIGNING_RESPONSE_TIMEOUT, + MAX_ACTIVE_COORDINATIONS, MAX_PENDING_REQUESTS, MAX_REQUESTS_PER_PEER, RATE_LIMIT_GLOBAL, + RATE_LIMIT_PER_PEER, RATE_LIMIT_WINDOW_SECS, RECONNECT_BASE_MS, RECONNECT_MAX_ATTEMPTS, + RECONNECT_MAX_MS, SIGNING_RESPONSE_TIMEOUT, }; use crate::message::{ConnectionStatus, FrostNodeMsg, Message, PeerEntry, PendingSignRequest}; use crate::screen::relay::RelayScreen; @@ -705,34 +705,45 @@ impl App { ); } FrostNodeMsg::DescriptorReady { session_id } => { + let is_initiator = self + .active_coordinations + .get(&session_id) + .is_some_and(|c| c.is_initiator); + if !is_initiator { + return iced::Task::none(); + } self.update_wallet_setup(&session_id, |setup| { setup.phase = SetupPhase::Coordinating(DescriptorProgress::Finalizing); }); - if self.active_coordinations.contains_key(&session_id) { - let Some(node) = self.get_frost_node() else { - return iced::Task::none(); - }; - return iced::Task::perform( - async move { - node.build_and_finalize_descriptor(session_id) - .await - .map_err(|e| format!("{e}")) - }, - move |result| match result { - Ok(expected_acks) => Message::WalletDescriptorProgress( - DescriptorProgress::WaitingAcks { - received: 0, - expected: expected_acks, - }, - Some(session_id), - ), - Err(e) => Message::WalletDescriptorProgress( - DescriptorProgress::Failed(e), - Some(session_id), - ), - }, - ); - } + let Some(node) = self.get_frost_node() else { + self.active_coordinations.remove(&session_id); + self.update_wallet_setup(&session_id, |setup| { + setup.phase = SetupPhase::Coordinating(DescriptorProgress::Failed( + "Node unavailable".to_string(), + )); + }); + return iced::Task::none(); + }; + return iced::Task::perform( + async move { + node.build_and_finalize_descriptor(session_id) + .await + .map_err(|e| format!("{e}")) + }, + move |result| match result { + Ok(expected_acks) => Message::WalletDescriptorProgress( + DescriptorProgress::WaitingAcks { + received: 0, + expected: expected_acks, + }, + Some(session_id), + ), + Err(e) => Message::WalletDescriptorProgress( + DescriptorProgress::Failed(e), + Some(session_id), + ), + }, + ); } FrostNodeMsg::DescriptorContributed { session_id, .. } => { self.update_wallet_setup(&session_id, |setup| { @@ -819,11 +830,17 @@ impl App { return iced::Task::none(); } + if self.active_coordinations.len() >= MAX_ACTIVE_COORDINATIONS { + tracing::warn!("Dropping descriptor contribution: too many active coordinations"); + return iced::Task::none(); + } + self.active_coordinations.insert( session_id, ActiveCoordination { group_pubkey: share.group_pubkey, network: network.clone(), + is_initiator: false, }, ); @@ -926,6 +943,7 @@ impl App { self.frost_status = ConnectionStatus::Disconnected; self.frost_peers.clear(); self.pending_sign_display.clear(); + self.active_coordinations.clear(); self.frost_reconnect_attempts = 0; self.frost_reconnect_at = None; if let Ok(mut guard) = self.frost_node.lock() { @@ -936,7 +954,6 @@ impl App { let _ = entry.response_tx.try_send(false); } } - self.active_coordinations.clear(); if let Some(s) = self.relay_screen_mut() { s.status = ConnectionStatus::Disconnected; s.peers.clear(); diff --git a/keep-desktop/src/message.rs b/keep-desktop/src/message.rs index e418edd4..9997e933 100644 --- a/keep-desktop/src/message.rs +++ b/keep-desktop/src/message.rs @@ -3,6 +3,7 @@ use std::fmt; +use keep_frost_net::AnnouncedXpub; use zeroize::Zeroizing; use crate::screen::shares::ShareEntry; @@ -198,7 +199,7 @@ pub enum Message { WalletRemoveTier(usize), WalletBeginCoordination, WalletCancelSetup, - WalletSessionStarted(Result<([u8; 32], [u8; 32], String), String>), + WalletSessionStarted(Result<([u8; 32], [u8; 32], String, usize), String>), WalletDescriptorProgress(DescriptorProgress, Option<[u8; 32]>), // Relay / FROST RelayUrlChanged(String), @@ -320,7 +321,7 @@ pub enum FrostNodeMsg { }, XpubAnnounced { share_index: u16, - recovery_xpubs: Vec, + recovery_xpubs: Vec, }, } diff --git a/keep-desktop/src/screen/wallet.rs b/keep-desktop/src/screen/wallet.rs index aab3811f..ab82fb9c 100644 --- a/keep-desktop/src/screen/wallet.rs +++ b/keep-desktop/src/screen/wallet.rs @@ -449,7 +449,9 @@ impl WalletScreen { .size(theme::size::TINY) .color(theme::color::TEXT_DIM); - let created = DateTime::::from_timestamp(entry.created_at as i64, 0) + let created = i64::try_from(entry.created_at) + .ok() + .and_then(|ts| DateTime::::from_timestamp(ts, 0)) .map(|dt| dt.format("%Y-%m-%d %H:%M UTC").to_string()) .unwrap_or_else(|| entry.created_at.to_string()); let created_text = text(format!("Created: {created}")) diff --git a/keep-frost-net/src/descriptor_session.rs b/keep-frost-net/src/descriptor_session.rs index 5201e626..dafddd54 100644 --- a/keep-frost-net/src/descriptor_session.rs +++ b/keep-frost-net/src/descriptor_session.rs @@ -8,6 +8,7 @@ use keep_bitcoin::recovery::{RecoveryConfig, RecoveryTier as BitcoinRecoveryTier use keep_bitcoin::{xpub_to_x_only, DescriptorExport, Network}; use nostr_sdk::PublicKey; use sha2::{Digest, Sha256}; +use subtle::ConstantTimeEq; use crate::error::{FrostNetError, Result}; use crate::protocol::{ @@ -241,12 +242,13 @@ impl DescriptorSession { Ok(()) } + /// Returns `Ok(true)` when the ack was new, `Ok(false)` for duplicates. pub fn add_ack( &mut self, share_index: u16, descriptor_hash: [u8; 32], key_proof_psbt: &[u8], - ) -> Result<()> { + ) -> Result { if self.state != DescriptorSessionState::Finalized { return Err(FrostNetError::Session("Not accepting ACKs".into())); } @@ -258,7 +260,7 @@ impl DescriptorSession { } if self.acks.contains(&share_index) { - return Ok(()); + return Ok(false); } let finalized = self @@ -272,7 +274,7 @@ impl DescriptorSession { hasher.update(finalized.policy_hash); let expected_hash: [u8; 32] = hasher.finalize().into(); - if descriptor_hash != expected_hash { + if !bool::from(descriptor_hash.ct_eq(&expected_hash)) { return Err(FrostNetError::Session("Descriptor hash mismatch".into())); } @@ -295,7 +297,7 @@ impl DescriptorSession { self.state = DescriptorSessionState::Complete; } - Ok(()) + Ok(true) } pub fn has_all_acks(&self) -> bool { diff --git a/keep-frost-net/src/node/descriptor.rs b/keep-frost-net/src/node/descriptor.rs index 84d6525b..49337959 100644 --- a/keep-frost-net/src/node/descriptor.rs +++ b/keep-frost-net/src/node/descriptor.rs @@ -821,18 +821,19 @@ impl KfpNode { .ok_or_else(|| FrostNetError::UntrustedPeer(sender.to_string()))? }; - let (is_complete, ack_count, expected_acks) = { + let (is_new, is_complete, ack_count, expected_acks) = { let mut sessions = self.descriptor_sessions.write(); let session = sessions .get_session_mut(&payload.session_id) .ok_or_else(|| FrostNetError::Session("unknown descriptor session".into()))?; - session.add_ack( + let is_new = session.add_ack( share_index, payload.descriptor_hash, &payload.key_proof_psbt, )?; ( + is_new, session.is_complete(), session.ack_count(), session.expected_ack_count(), @@ -848,14 +849,16 @@ impl KfpNode { "Received descriptor ACK" ); - let _ = self.event_tx.send(KfpNodeEvent::DescriptorAcked { - session_id: payload.session_id, - share_index, - ack_count, - expected_acks, - }); + if is_new { + let _ = self.event_tx.send(KfpNodeEvent::DescriptorAcked { + session_id: payload.session_id, + share_index, + ack_count, + expected_acks, + }); + } - if is_complete { + if is_new && is_complete { let sessions = self.descriptor_sessions.read(); if let Some(session) = sessions.get_session(&payload.session_id) { if let Some(desc) = session.descriptor() { @@ -928,16 +931,21 @@ impl KfpNode { let digest: [u8; 32] = hasher.finalize().into(); let dedup_key = (payload.share_index, payload.created_at, digest); let mut seen = self.seen_xpub_announces.write(); - if seen.contains(&dedup_key) { + if !seen.insert(dedup_key) { return Ok(()); } - if seen.len() >= 10_000 { - tracing::warn!("seen_xpub_announces at capacity, evicting oldest entry"); - if let Some(&oldest) = seen.iter().min_by_key(|&(_, ts, _)| ts) { - seen.remove(&oldest); + const MAX_SEEN_XPUB_ANNOUNCES: usize = 10_000; + if seen.len() > MAX_SEEN_XPUB_ANNOUNCES { + let now = chrono::Utc::now().timestamp().max(0) as u64; + let window = self + .replay_window_secs + .saturating_add(super::MAX_FUTURE_SKEW_SECS); + seen.retain(|&(_, ts, _)| now.saturating_sub(window) <= ts); + if seen.len() > MAX_SEEN_XPUB_ANNOUNCES { + seen.clear(); + seen.insert(dedup_key); } } - seen.insert(dedup_key); } { diff --git a/keep-nip46/Cargo.toml b/keep-nip46/Cargo.toml index b72a12ef..d29ff012 100644 --- a/keep-nip46/Cargo.toml +++ b/keep-nip46/Cargo.toml @@ -22,7 +22,9 @@ tracing.workspace = true thiserror.workspace = true subtle.workspace = true sha2 = "0.10" +zeroize.workspace = true [dev-dependencies] +nostr-relay-builder = "0.44" rustls = { version = "0.23", features = ["ring"] } tracing-subscriber = "0.3" diff --git a/keep-nip46/src/audit.rs b/keep-nip46/src/audit.rs index 4e8ea1aa..cbf5603e 100644 --- a/keep-nip46/src/audit.rs +++ b/keep-nip46/src/audit.rs @@ -21,6 +21,7 @@ pub enum AuditAction { Nip44Encrypt, Nip44Decrypt, PermissionChanged, + SwitchRelays, PermissionDenied, UserRejected, } @@ -36,6 +37,7 @@ impl std::fmt::Display for AuditAction { Self::Nip04Decrypt => write!(f, "nip04_decrypt"), Self::Nip44Encrypt => write!(f, "nip44_encrypt"), Self::Nip44Decrypt => write!(f, "nip44_decrypt"), + Self::SwitchRelays => write!(f, "switch_relays"), Self::PermissionChanged => write!(f, "permission_changed"), Self::PermissionDenied => write!(f, "permission_denied"), Self::UserRejected => write!(f, "user_rejected"), diff --git a/keep-nip46/src/handler.rs b/keep-nip46/src/handler.rs index 5fcb2e56..f48e6eb0 100644 --- a/keep-nip46/src/handler.rs +++ b/keep-nip46/src/handler.rs @@ -10,6 +10,7 @@ use sha2::{Digest, Sha256}; use subtle::ConstantTimeEq; use tokio::sync::Mutex; use tracing::{debug, info, warn}; +use zeroize::Zeroizing; use keep_core::error::{CryptoError, KeepError, Result}; use keep_core::keyring::Keyring; @@ -96,7 +97,7 @@ pub struct SignerHandler { rate_limiters: Mutex>, rate_limit_config: Option, new_conn_timestamps: Mutex>, - expected_secret: Option, + expected_secret: Option>, auto_approve: bool, relay_urls: Vec, kill_switch: Arc, @@ -127,7 +128,7 @@ impl SignerHandler { } pub fn with_expected_secret(mut self, secret: String) -> Self { - self.expected_secret = Some(secret); + self.expected_secret = Some(Zeroizing::new(secret)); self } @@ -482,6 +483,8 @@ impl SignerHandler { self.check_rate_limit(&app_pubkey).await?; self.require_permission(&app_pubkey, Permission::NIP44_ENCRYPT) .await?; + self.require_approval(app_pubkey, "nip44_encrypt").await?; + self.check_kill_switch()?; let secret = self.primary_secret_key().await?; let ciphertext = nip44::encrypt(&secret, &recipient, plaintext, nip44::Version::V2) .map_err(|e| CryptoError::encryption(format!("NIP-44: {e}")))?; @@ -528,6 +531,8 @@ impl SignerHandler { self.check_rate_limit(&app_pubkey).await?; self.require_permission(&app_pubkey, Permission::NIP04_ENCRYPT) .await?; + self.require_approval(app_pubkey, "nip04_encrypt").await?; + self.check_kill_switch()?; let secret = self.primary_secret_key().await?; let ciphertext = nip04::encrypt(&secret, &recipient, plaintext) .map_err(|e| CryptoError::encryption(format!("NIP-04: {e}")))?; @@ -565,13 +570,15 @@ impl SignerHandler { } pub async fn handle_switch_relays(&self, app_pubkey: PublicKey) -> Result>> { + self.check_kill_switch()?; self.check_rate_limit(&app_pubkey).await?; self.require_permission(&app_pubkey, Permission::GET_PUBLIC_KEY) .await?; - self.audit.lock().await.log( - AuditEntry::new(AuditAction::GetPublicKey, app_pubkey).with_reason("switch_relays"), - ); + self.audit + .lock() + .await + .log(AuditEntry::new(AuditAction::SwitchRelays, app_pubkey)); let relays = (!self.relay_urls.is_empty()).then(|| self.relay_urls.clone()); Ok(relays) diff --git a/keep-nip46/src/server.rs b/keep-nip46/src/server.rs index 32819f50..fa130340 100644 --- a/keep-nip46/src/server.rs +++ b/keep-nip46/src/server.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use nostr_sdk::prelude::*; use tokio::sync::Mutex; use tracing::{debug, error, info, warn}; +use zeroize::Zeroizing; use keep_core::error::{CryptoError, KeepError, NetworkError, StorageError}; use keep_core::keyring::Keyring; @@ -51,7 +52,7 @@ pub struct Server { running: bool, callbacks: Option>, config: ServerConfig, - bunker_secret: Option, + bunker_secret: Option>, } async fn add_relays(client: &Client, relay_urls: &[String]) -> Result<()> { @@ -75,7 +76,7 @@ fn finalize_handler( mut handler: SignerHandler, config: &ServerConfig, relay_urls: &[String], -) -> (SignerHandler, Option) { +) -> (SignerHandler, Option>) { handler = handler.with_relay_urls(relay_urls.to_vec()); if let Some(ref rl_config) = config.rate_limit { handler = handler.with_rate_limit(rl_config.clone()); @@ -87,7 +88,7 @@ fn finalize_handler( let secret = hex::encode(keep_core::crypto::random_bytes::<16>()); warn!("headless mode: bunker secret required for authentication"); handler = handler.with_expected_secret(secret.clone()); - Some(secret) + Some(Zeroizing::new(secret)) } else if let Some(ref secret) = config.expected_secret { handler = handler.with_expected_secret(secret.clone()); None @@ -105,7 +106,7 @@ impl Server { handler: SignerHandler, callbacks: Option>, config: ServerConfig, - bunker_secret: Option, + bunker_secret: Option>, ) -> Self { Self { keys, @@ -339,7 +340,7 @@ impl Server { generate_bunker_url( &self.keys.public_key(), &self.relay_urls, - self.bunker_secret.as_deref(), + self.bunker_secret.as_ref().map(|s| s.as_str()), ) } @@ -462,6 +463,11 @@ impl Server { return Err(KeepError::InvalidInput("invalid request ID".into())); } + const MAX_NIP46_PARAMS: usize = 10; + if request.params.len() > MAX_NIP46_PARAMS { + return Err(KeepError::InvalidInput("too many request params".into())); + } + debug!(method = %request.method, app_id, "NIP-46 request"); let method = request.method.clone(); diff --git a/keep-nip46/tests/bunker_integration.rs b/keep-nip46/tests/bunker_integration.rs index ed92a0d8..c186d082 100644 --- a/keep-nip46/tests/bunker_integration.rs +++ b/keep-nip46/tests/bunker_integration.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::Duration; +use nostr_relay_builder::prelude::*; use nostr_sdk::prelude::*; use tokio::sync::Mutex; @@ -37,7 +38,8 @@ async fn test_bunker_e2e_connect_and_sign() { .install_default() .ok(); - let relay_url = "wss://relay.damus.io".to_string(); + let mock_relay = MockRelay::run().await.expect("Failed to start mock relay"); + let relay_url = mock_relay.url().await.to_string(); let (keyring, signer_pubkey) = setup_keyring(); let mut server = Server::new_with_config( @@ -71,14 +73,14 @@ async fn test_bunker_e2e_connect_and_sign() { let _ = server.run().await; }); - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(1)).await; let client_keys = Keys::generate(); let client = Client::new(client_keys.clone()); client.add_relay(relay_url).await.unwrap(); client.connect().await; - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(1)).await; // Send connect request with bunker secret let connect_req = serde_json::json!({ @@ -103,7 +105,7 @@ async fn test_bunker_e2e_connect_and_sign() { client.send_event(&connect_event).await.unwrap(); - tokio::time::sleep(Duration::from_secs(3)).await; + tokio::time::sleep(Duration::from_secs(2)).await; // Verify the client connected let clients = server_handler.list_clients().await; @@ -164,31 +166,9 @@ async fn test_bunker_e2e_connect_and_sign() { client.send_event(&sign_event).await.unwrap(); - // Wait for relay to deliver responses - retry a few times - let mut request_count = 0; - for _ in 0..10 { - tokio::time::sleep(Duration::from_secs(1)).await; - let apps = server_handler.list_clients().await; - if let Some(app) = apps.first() { - request_count = app.request_count; - if request_count > 0 { - break; - } - } - } - // Cleanup server_handle.abort(); client.disconnect().await; - - println!("NIP-46 bunker integration test passed!"); - println!(" - Server started with bunker URL: {bunker_url}"); - println!( - " - Client connected successfully ({} clients)", - clients.len() - ); - println!(" - get_public_key and sign_event requests sent"); - println!(" - {request_count} request(s) recorded via relay round-trip"); } #[tokio::test] @@ -197,7 +177,8 @@ async fn test_bunker_rejects_without_auto_approve() { .install_default() .ok(); - let relay_url = "wss://relay.damus.io".to_string(); + let mock_relay = MockRelay::run().await.expect("Failed to start mock relay"); + let relay_url = mock_relay.url().await.to_string(); let (keyring, _signer_pubkey) = setup_keyring(); // auto_approve: false, no callbacks = rejects by default @@ -226,8 +207,6 @@ async fn test_bunker_rejects_without_auto_approve() { result.is_err(), "connect should be rejected without callbacks or auto_approve" ); - - println!("Rejection test passed: connect properly denied without auto_approve"); } #[tokio::test] @@ -236,7 +215,8 @@ async fn test_bunker_permission_scoping() { .install_default() .ok(); - let relay_url = "wss://relay.damus.io".to_string(); + let mock_relay = MockRelay::run().await.expect("Failed to start mock relay"); + let relay_url = mock_relay.url().await.to_string(); let (keyring, signer_pubkey) = setup_keyring(); let server = Server::new_with_config( @@ -286,8 +266,4 @@ async fn test_bunker_permission_scoping() { sign_result.is_err(), "sign_event should be denied when only get_public_key requested" ); - - println!("Permission scoping test passed:"); - println!(" - get_public_key: allowed"); - println!(" - sign_event: correctly denied"); }