Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 125 additions & 92 deletions keep-agent/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::time::Duration;

use nostr_sdk::prelude::*;

use keep_core::relay::TIMESTAMP_TWEAK_RANGE;

use crate::error::{AgentError, Result};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand All @@ -31,24 +33,13 @@ impl PendingSession {
let client = Client::new(client_keys.clone());

client
.add_relay(&relay_url)
.pool()
.add_relay(&relay_url, default_relay_opts())
.await
.map_err(|e| AgentError::Connection(e.to_string()))?;

client.connect().await;

tokio::time::timeout(timeout, async {
loop {
if let Ok(relay) = client.relay(&relay_url).await {
if matches!(relay.status(), RelayStatus::Connected) {
break;
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await
.map_err(|_| AgentError::Connection("Relay connection timeout".into()))?;
wait_for_relay_connection(&client, &relay_url, timeout).await?;

let request_id = generate_uuid();

Expand Down Expand Up @@ -98,7 +89,7 @@ impl PendingSession {
let tags = vec![Tag::public_key(self.signer_pubkey)];
let unsigned = UnsignedEvent::new(
self.client_keys.public_key(),
Timestamp::now(),
Timestamp::tweaked(TIMESTAMP_TWEAK_RANGE),
Kind::NostrConnect,
tags,
encrypted,
Expand All @@ -123,51 +114,49 @@ impl PendingSession {
.kind(Kind::NostrConnect)
.author(self.signer_pubkey)
.pubkey(self.client_keys.public_key())
.since(Timestamp::now());
.since(Timestamp::now() - Duration::from_secs(10));

let sub_output = self
let mut stream = self
.client
.subscribe(filter, None)
.pool()
.stream_events(filter, timeout, ReqExitPolicy::WaitForEventsAfterEOSE(1))
.await
.map_err(|e| AgentError::Nostr(e.to_string()))?;

let sub_id = sub_output.id();
let mut notifications = self.client.notifications();

let result = tokio::time::timeout(timeout, async {
while let Ok(notification) = notifications.recv().await {
if let RelayPoolNotification::Event { event, .. } = notification {
if event.kind == Kind::NostrConnect && event.pubkey == self.signer_pubkey {
if let Ok(decrypted) = nip44::decrypt(
self.client_keys.secret_key(),
&self.signer_pubkey,
&event.content,
) {
let parsed: serde_json::Value = serde_json::from_str(&decrypted)
.map_err(|e| AgentError::Serialization(e.to_string()))?;

if let Some(id) = parsed.get("id").and_then(|v| v.as_str()) {
if id == self.request_id {
if let Some(error) = parsed.get("error") {
if !error.is_null() {
return Ok(ApprovalStatus::Denied);
}
}
if parsed.get("result").is_some() {
return Ok(ApprovalStatus::Approved);
}
}
}
}
while let Some(event) = stream.next().await {
if event.kind != Kind::NostrConnect || event.pubkey != self.signer_pubkey {
continue;
}
let Ok(decrypted) = nip44::decrypt(
self.client_keys.secret_key(),
&self.signer_pubkey,
&event.content,
) else {
continue;
};
let parsed: serde_json::Value = serde_json::from_str(&decrypted)
.map_err(|e| AgentError::Serialization(e.to_string()))?;

let Some(id) = parsed.get("id").and_then(|v| v.as_str()) else {
continue;
};
if id != self.request_id {
continue;
}
if let Some(error) = parsed.get("error") {
if !error.is_null() {
return Ok(ApprovalStatus::Denied);
}
}
if parsed.get("result").is_some() {
return Ok(ApprovalStatus::Approved);
}
}
Ok(ApprovalStatus::Pending)
})
.await;

self.client.unsubscribe(sub_id).await;

match result {
Ok(inner) => inner,
Err(_) => Ok(ApprovalStatus::Pending),
Expand Down Expand Up @@ -220,24 +209,13 @@ impl AgentClient {
let client = Client::new(client_keys.clone());

client
.add_relay(&relay_url)
.pool()
.add_relay(&relay_url, default_relay_opts())
.await
.map_err(|e| AgentError::Connection(e.to_string()))?;

client.connect().await;

tokio::time::timeout(timeout, async {
loop {
if let Ok(relay) = client.relay(&relay_url).await {
if matches!(relay.status(), RelayStatus::Connected) {
break;
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await
.map_err(|_| AgentError::Connection("Relay connection timeout".into()))?;
wait_for_relay_connection(&client, &relay_url, timeout).await?;

let mut agent_client = Self {
signer_pubkey,
Expand Down Expand Up @@ -412,7 +390,7 @@ impl AgentClient {
}

let relays: Vec<String> = if result.is_string() {
serde_json::from_str(result.as_str().unwrap())
serde_json::from_str(result.as_str().expect("guarded by is_string check"))
.map_err(|e| AgentError::Serialization(format!("Invalid relay list: {e}")))?
} else if result.is_array() {
serde_json::from_value(result.clone())
Expand All @@ -438,9 +416,16 @@ impl AgentClient {

self.client.disconnect().await;
self.client.remove_all_relays().await;
let relay_opts = default_relay_opts();
let mut added = Vec::new();
for relay in &valid_relays {
if self.client.add_relay(relay).await.is_ok() {
if self
.client
.pool()
.add_relay(relay, relay_opts.clone())
.await
.is_ok()
{
added.push(relay.clone());
}
}
Expand All @@ -450,6 +435,7 @@ impl AgentClient {
));
}
self.client.connect().await;
wait_for_any_relay_connection(&self.client, &added, Duration::from_secs(10)).await?;

self.relay_url = added[0].clone();

Expand Down Expand Up @@ -477,7 +463,7 @@ impl AgentClient {
let tags = vec![Tag::public_key(self.signer_pubkey)];
let unsigned = UnsignedEvent::new(
self.client_keys.public_key(),
Timestamp::now(),
Timestamp::tweaked(TIMESTAMP_TWEAK_RANGE),
Kind::NostrConnect,
tags,
encrypted,
Expand All @@ -496,47 +482,44 @@ impl AgentClient {
.kind(Kind::NostrConnect)
.author(self.signer_pubkey)
.pubkey(self.client_keys.public_key())
.since(Timestamp::now());
.since(Timestamp::now() - Duration::from_secs(10));

let sub_output = self
let mut stream = self
.client
.subscribe(filter, None)
.pool()
.stream_events(
filter,
Duration::from_secs(30),
ReqExitPolicy::WaitForEventsAfterEOSE(5),
)
.await
.map_err(|e| AgentError::Nostr(e.to_string()))?;

let sub_id = sub_output.id();
let mut notifications = self.client.notifications();

let result = tokio::time::timeout(Duration::from_secs(30), async {
while let Ok(notification) = notifications.recv().await {
if let RelayPoolNotification::Event { event, .. } = notification {
if event.kind == Kind::NostrConnect && event.pubkey == self.signer_pubkey {
if let Ok(decrypted) = nip44::decrypt(
self.client_keys.secret_key(),
&self.signer_pubkey,
&event.content,
) {
if let Some(ref expected_id) = request_id {
if let Ok(resp) =
serde_json::from_str::<serde_json::Value>(&decrypted)
{
if resp.get("id").and_then(|v| v.as_str()) != Some(expected_id)
{
continue;
}
}
}
return Ok(decrypted);
while let Some(event) = stream.next().await {
if event.kind != Kind::NostrConnect || event.pubkey != self.signer_pubkey {
continue;
}
let Ok(decrypted) = nip44::decrypt(
self.client_keys.secret_key(),
&self.signer_pubkey,
&event.content,
) else {
continue;
};
if let Some(ref expected_id) = request_id {
if let Ok(resp) = serde_json::from_str::<serde_json::Value>(&decrypted) {
if resp.get("id").and_then(|v| v.as_str()) != Some(expected_id) {
continue;
}
}
}
return Ok(decrypted);
}
Err(AgentError::Connection("No response received".into()))
})
.await;

self.client.unsubscribe(sub_id).await;

match result {
Ok(inner) => inner,
Err(_) => Err(AgentError::Connection("Response timeout".into())),
Expand All @@ -557,7 +540,7 @@ impl AgentClient {
.get("result")
.map(|v| {
if v.is_string() {
v.as_str().unwrap().to_string()
v.as_str().expect("guarded by is_string check").to_string()
} else {
v.to_string()
}
Expand All @@ -578,6 +561,56 @@ impl AgentClient {
}
}

async fn wait_for_relay_connection(
client: &Client,
relay_url: &str,
timeout: Duration,
) -> Result<()> {
tokio::time::timeout(timeout, async {
loop {
if let Ok(relay) = client.relay(relay_url).await {
if matches!(relay.status(), RelayStatus::Connected) {
return;
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await
.map_err(|_| AgentError::Connection("Relay connection timeout".into()))
}

async fn wait_for_any_relay_connection(
client: &Client,
relays: &[String],
timeout: Duration,
) -> Result<()> {
tokio::time::timeout(timeout, async {
loop {
for relay in relays {
if let Ok(r) = client.relay(relay).await {
if matches!(r.status(), RelayStatus::Connected) {
return;
}
}
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
})
.await
.map_err(|_| AgentError::Connection("Relay connection timeout".into()))
}

fn default_relay_opts() -> RelayOptions {
RelayOptions::default()
.reconnect(true)
.ping(true)
.retry_interval(Duration::from_secs(10))
.adjust_retry_interval(true)
.ban_relay_on_mismatch(true)
.max_avg_latency(Some(Duration::from_secs(3)))
}

fn generate_uuid() -> String {
uuid::Uuid::new_v4().to_string()
}
Expand Down
3 changes: 3 additions & 0 deletions keep-core/src/relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub const MAX_RELAYS: usize = 10;
/// Maximum length of a relay URL.
pub const MAX_RELAY_URL_LENGTH: usize = 256;

/// Range of seconds to randomly tweak event timestamps for privacy.
pub const TIMESTAMP_TWEAK_RANGE: std::ops::Range<u64> = 0..5;

/// Relay configuration for a FROST share, keyed by group public key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelayConfig {
Expand Down
1 change: 0 additions & 1 deletion keep-frost-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
hex = "0.4"
base64 = "0.22"
chrono = "0.4"
rustls = "0.23"
rustls-pki-types = "1"

Expand Down
Loading