Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions keep-agent-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ serde_json = "1.0"
hex = "0.4"
nostr-sdk = { version = "0.44", features = ["nip44"] }
k256 = { version = "0.13", features = ["schnorr"] }
zeroize = "1"
40 changes: 25 additions & 15 deletions keep-agent-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pyo3::prelude::*;
use pyo3::exceptions::{PyRuntimeError, PyValueError, PyConnectionError};
use std::sync::Arc;
use tokio::sync::Mutex;
use zeroize::Zeroizing;

use ::keep_agent::{
RateLimitConfig, SessionConfig, SessionManager, SessionMetadata,
Expand Down Expand Up @@ -155,7 +156,7 @@ pub struct PyAgentSession {
manager: SessionManager,
token: SessionToken,
session_id: String,
secret_key: Option<[u8; 32]>,
secret_key: Option<Zeroizing<[u8; 32]>>,
}

#[pymethods]
Expand All @@ -169,9 +170,9 @@ impl PyAgentSession {
policy: Option<String>,
secret_key: Option<String>,
) -> PyResult<Self> {
let secret_bytes: Option<[u8; 32]> = if let Some(sk) = secret_key {
let decoded = hex::decode(&sk)
.map_err(|e| PyValueError::new_err(format!("Invalid secret key hex: {}", e)))?;
let secret_bytes: Option<Zeroizing<[u8; 32]>> = if let Some(sk) = secret_key {
let decoded = Zeroizing::new(hex::decode(&sk)
.map_err(|e| PyValueError::new_err(format!("Invalid secret key hex: {}", e)))?);
if decoded.len() != 32 {
return Err(PyValueError::new_err(format!(
"Secret key must be 32 bytes, got {}",
Expand All @@ -180,14 +181,14 @@ impl PyAgentSession {
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&decoded);
Some(arr)
Some(Zeroizing::new(arr))
} else {
None
};

let pubkey_bytes: [u8; 32] = if let Some(ref sk) = secret_bytes {
use k256::elliptic_curve::sec1::ToEncodedPoint;
let scalar = k256::NonZeroScalar::try_from(sk.as_slice())
let scalar = k256::NonZeroScalar::try_from(sk.as_ref().as_slice())
.map_err(|_| PyValueError::new_err("Invalid secret key"))?;
let pk = k256::PublicKey::from_secret_scalar(&scalar);
let point = pk.to_encoded_point(true);
Expand Down Expand Up @@ -300,14 +301,15 @@ impl PyAgentSession {
session.check_operation(&Operation::SignNostrEvent).map_err(to_py_err)?;
session.check_event_kind(kind).map_err(to_py_err)?;

let secret = self.secret_key
let secret = self.secret_key.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?;

self.manager.record_request(&self.session_id).map_err(to_py_err)?;

use nostr_sdk::prelude::*;

let keys = Keys::parse(&hex::encode(secret))
let hex = Zeroizing::new(hex::encode(secret.as_ref()));
let keys = Keys::parse(hex.as_str())
.map_err(to_py_value_err)?;

let mut nostr_tags: Vec<Tag> = Vec::new();
Expand Down Expand Up @@ -358,8 +360,8 @@ impl PyAgentSession {

session.check_operation(&Operation::SignPsbt).map_err(to_py_err)?;

let mut secret = self.secret_key
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?;
let mut secret = Zeroizing::new(*self.secret_key.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?);

let network = match network.unwrap_or("testnet") {
"mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin,
Expand All @@ -371,7 +373,7 @@ impl PyAgentSession {
let mut psbt = keep_bitcoin::psbt::parse_psbt_base64(psbt_base64)
.map_err(|e| PyValueError::new_err(format!("Invalid PSBT: {}", e)))?;

let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network)
let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network)
.map_err(to_py_err)?;

let analysis = signer.analyze_psbt(&psbt).map_err(to_py_err)?;
Expand Down Expand Up @@ -408,7 +410,7 @@ impl PyAgentSession {
}

fn get_public_key(&self) -> PyResult<String> {
let _ = self.secret_key
let _ = self.secret_key.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured"))?;

let session = self.manager
Expand All @@ -429,8 +431,8 @@ impl PyAgentSession {

session.check_operation(&Operation::GetBitcoinAddress).map_err(to_py_err)?;

let mut secret = self.secret_key
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?;
let mut secret = Zeroizing::new(*self.secret_key.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("No secret key configured. Pass secret_key to constructor."))?);

let network = match network.unwrap_or("testnet") {
"mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin,
Expand All @@ -439,7 +441,7 @@ impl PyAgentSession {
_ => keep_bitcoin::Network::Testnet,
};

let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network)
let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network)
.map_err(to_py_err)?;

signer.get_receive_address(0).map_err(to_py_err)
Expand Down Expand Up @@ -524,6 +526,14 @@ impl PyRemoteSession {
}).map_err(to_py_err)
}

fn switch_relays(&self) -> PyResult<Option<Vec<String>>> {
let client = self.client.clone();
self.runtime.block_on(async {
let mut c = client.lock().await;
c.switch_relays().await
}).map_err(to_py_err)
}

fn disconnect(&self) -> PyResult<()> {
let client = self.client.clone();
self.runtime.block_on(async {
Expand Down
1 change: 1 addition & 0 deletions keep-agent-ts/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ serde_json = "1.0"
nostr-sdk = "0.44"
k256 = { version = "0.13", features = ["schnorr"] }
hex = "0.4"
zeroize = "1"

[build-dependencies]
napi-build = "2"
44 changes: 28 additions & 16 deletions keep-agent-ts/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use napi::bindgen_prelude::*;
use napi_derive::napi;
use std::sync::Arc;
use tokio::sync::Mutex;
use zeroize::Zeroizing;

use keep_agent::{
AgentClient, ApprovalStatus, Operation, PendingSession as RustPendingSession,
Expand Down Expand Up @@ -53,7 +54,7 @@ pub struct KeepAgentSession {
manager: SessionManager,
token: Arc<Mutex<SessionToken>>,
session_id: String,
secret_key: Option<[u8; 32]>,
secret_key: Option<Zeroizing<[u8; 32]>>,
}

#[napi]
Expand All @@ -66,9 +67,9 @@ impl KeepAgentSession {
policy: Option<String>,
secret_key: Option<String>,
) -> Result<Self> {
let secret_bytes: Option<[u8; 32]> = if let Some(ref sk) = secret_key {
let decoded = hex::decode(sk)
.map_err(|e| Error::from_reason(format!("Invalid secret key hex: {}", e)))?;
let secret_bytes: Option<Zeroizing<[u8; 32]>> = if let Some(ref sk) = secret_key {
let decoded = Zeroizing::new(hex::decode(sk)
.map_err(|e| Error::from_reason(format!("Invalid secret key hex: {}", e)))?);
if decoded.len() != 32 {
return Err(Error::from_reason(format!(
"Secret key must be 32 bytes, got {}",
Expand All @@ -77,14 +78,14 @@ impl KeepAgentSession {
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&decoded);
Some(arr)
Some(Zeroizing::new(arr))
} else {
None
};

let pubkey: [u8; 32] = if let Some(ref sk) = secret_bytes {
use k256::elliptic_curve::sec1::ToEncodedPoint;
let scalar = k256::NonZeroScalar::try_from(sk.as_slice())
let scalar = k256::NonZeroScalar::try_from(sk.as_ref().as_slice())
.map_err(|_| Error::from_reason("Invalid secret key"))?;
let pk = k256::PublicKey::from_secret_scalar(&scalar);
let point = pk.to_encoded_point(true);
Expand Down Expand Up @@ -304,6 +305,7 @@ impl KeepAgentSession {

let secret = self
.secret_key
.as_ref()
.ok_or_else(|| Error::from_reason("No secret key configured"))?;

self.manager
Expand All @@ -312,7 +314,8 @@ impl KeepAgentSession {

use nostr_sdk::prelude::{EventBuilder, Keys, Kind, Tag};

let keys = Keys::parse(&hex::encode(secret))
let hex = Zeroizing::new(hex::encode(secret.as_ref()));
let keys = Keys::parse(hex.as_str())
.map_err(|e| napi::Error::from_reason(e.to_string()))?;

let nostr_tags: Vec<Tag> = tags
Expand Down Expand Up @@ -365,9 +368,10 @@ impl KeepAgentSession {
.check_operation(&Operation::SignPsbt)
.map_err(|e| Error::from_reason(e.to_string()))?;

let mut secret = self
let mut secret = Zeroizing::new(*self
.secret_key
.ok_or_else(|| Error::from_reason("No secret key configured"))?;
.as_ref()
.ok_or_else(|| Error::from_reason("No secret key configured"))?);

let network = match network.as_deref().unwrap_or("testnet") {
"mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin,
Expand All @@ -379,7 +383,7 @@ impl KeepAgentSession {
let mut psbt = keep_bitcoin::psbt::parse_psbt_base64(&psbt_base64)
.map_err(|e| Error::from_reason(format!("Invalid PSBT: {}", e)))?;

let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network)
let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network)
.map_err(|e| Error::from_reason(e.to_string()))?;

let analysis = signer
Expand Down Expand Up @@ -414,9 +418,7 @@ impl KeepAgentSession {
.record_request(&self.session_id)
.map_err(|e| Error::from_reason(e.to_string()))?;

signer
.sign_psbt(&mut psbt)
.map_err(|e| Error::from_reason(e.to_string()))?;
signer.sign_psbt(&mut psbt).map_err(|e| Error::from_reason(e.to_string()))?;

Ok(keep_bitcoin::psbt::serialize_psbt_base64(&psbt))
}
Expand Down Expand Up @@ -449,9 +451,10 @@ impl KeepAgentSession {
.check_operation(&Operation::GetBitcoinAddress)
.map_err(|e| Error::from_reason(e.to_string()))?;

let mut secret = self
let mut secret = Zeroizing::new(*self
.secret_key
.ok_or_else(|| Error::from_reason("No secret key configured"))?;
.as_ref()
.ok_or_else(|| Error::from_reason("No secret key configured"))?);

let network = match network.as_deref().unwrap_or("testnet") {
"mainnet" | "bitcoin" => keep_bitcoin::Network::Bitcoin,
Expand All @@ -460,7 +463,7 @@ impl KeepAgentSession {
_ => keep_bitcoin::Network::Testnet,
};

let signer = keep_bitcoin::BitcoinSigner::new(&mut secret, network)
let signer = keep_bitcoin::BitcoinSigner::new(&mut *secret, network)
.map_err(|e| Error::from_reason(e.to_string()))?;

signer
Expand Down Expand Up @@ -534,6 +537,15 @@ impl RemoteSession {
.map_err(|e| Error::from_reason(e.to_string()))
}

#[napi]
pub async fn switch_relays(&self) -> Result<Option<Vec<String>>> {
let mut client = self.client.lock().await;
client
.switch_relays()
.await
.map_err(|e| Error::from_reason(e.to_string()))
}

#[napi]
pub async fn disconnect(&self) -> Result<()> {
let client = self.client.lock().await;
Expand Down
Loading