Skip to content
Open
10 changes: 10 additions & 0 deletions atoma-auth/src/sui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ impl Sui {
Ok(signature.encode_base64())
}

/// Borrows the keystore held in this wallet context's configuration.
///
/// # Returns
///
/// A shared reference to the [`Keystore`] found at `wallet_ctx.config.keystore`.
/// The reference is tied to the borrow of `self`; nothing is cloned.
#[must_use]
pub fn get_keystore(&self) -> &Keystore {
    let config = &self.wallet_ctx.config;
    &config.keystore
}

/// Sign a hash using the wallet's private key
///
/// # Arguments
Expand Down
13 changes: 10 additions & 3 deletions atoma-proxy/src/server/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use super::metrics::{
};
use super::request_model::RequestModel;
use super::{
handle_status_code_error, update_state_manager, verify_response_hash_and_signature,
handle_status_code_error, update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY,
RESPONSE_HASH_KEY,
};
use crate::server::{Result, DEFAULT_MAX_TOKENS, MAX_COMPLETION_TOKENS, MAX_TOKENS, MODEL};
Expand Down Expand Up @@ -525,6 +525,7 @@ pub fn confidential_chat_completions_create_stream(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
async fn handle_non_streaming_response(
state: &ProxyState,
node_address: &String,
Expand Down Expand Up @@ -561,7 +562,7 @@ async fn handle_non_streaming_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response = response
let mut response = response
.json::<Value>()
.await
.map_err(|err| AtomaProxyError::InternalError {
Expand Down Expand Up @@ -597,7 +598,12 @@ async fn handle_non_streaming_response(
);

let verify_hash = endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH;
verify_response_hash_and_signature(&response.0, verify_hash)?;

let guard: tokio::sync::RwLockReadGuard<'_, atoma_auth::Sui> = state.sui.read().await;
let keystore = guard.get_keystore();
let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

state
.state_manager_sender
Expand Down Expand Up @@ -760,6 +766,7 @@ async fn handle_streaming_response(
state.state_manager_sender.clone(),
selected_stack_small_id,
estimated_total_tokens,
state.sui.clone(),
start,
node_id,
model_name,
Expand Down
12 changes: 9 additions & 3 deletions atoma-proxy/src/server/handlers/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use super::{
TOTAL_FAILED_REQUESTS, TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS,
},
request_model::RequestModel,
update_state_manager, verify_response_hash_and_signature, RESPONSE_HASH_KEY,
update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY,
};
use crate::server::Result;

Expand Down Expand Up @@ -357,6 +357,7 @@ pub async fn confidential_embeddings_create(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
async fn handle_embeddings_response(
state: &ProxyState,
node_address: String,
Expand Down Expand Up @@ -397,7 +398,7 @@ async fn handle_embeddings_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response =
let mut response =
response
.json::<Value>()
.await
Expand All @@ -407,8 +408,13 @@ async fn handle_embeddings_response(
endpoint: endpoint.to_string(),
})?;

let guard = state.sui.blocking_read();
let keystore = guard.get_keystore();
let verify_hash = endpoint != CONFIDENTIAL_EMBEDDINGS_PATH;
verify_response_hash_and_signature(&response, verify_hash)?;

let proxy_signature = verify_and_sign_response(&response, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

let num_input_compute_units = if endpoint == CONFIDENTIAL_EMBEDDINGS_PATH {
response
Expand Down
17 changes: 13 additions & 4 deletions atoma-proxy/src/server/handlers/image_generations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ use crate::server::error::AtomaProxyError;
use crate::server::types::{ConfidentialComputeRequest, ConfidentialComputeResponse};
use crate::server::{http_server::ProxyState, middleware::RequestMetadataExtension};

use super::handle_status_code_error;
use super::metrics::{
IMAGE_GEN_LATENCY_METRICS, IMAGE_GEN_NUM_REQUESTS, TOTAL_COMPLETED_REQUESTS,
TOTAL_FAILED_IMAGE_GENERATION_REQUESTS, TOTAL_FAILED_REQUESTS,
};
use super::{handle_status_code_error, verify_response_hash_and_signature};
use super::{request_model::RequestModel, update_state_manager, RESPONSE_HASH_KEY};
use super::{
request_model::RequestModel, update_state_manager, verify_and_sign_response,
PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY,
};
use crate::server::{Result, MODEL};

/// Path for the confidential image generations endpoint.
Expand Down Expand Up @@ -336,6 +339,7 @@ pub async fn confidential_image_generations_create(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
async fn handle_image_generation_response(
state: &ProxyState,
node_address: String,
Expand Down Expand Up @@ -374,7 +378,7 @@ async fn handle_image_generation_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response = response
let mut response = response
.json::<Value>()
.await
.map_err(|err| AtomaProxyError::InternalError {
Expand All @@ -384,8 +388,13 @@ async fn handle_image_generation_response(
})
.map(Json)?;

let guard = state.sui.read().await;
let keystore = guard.get_keystore();
let verify_hash = endpoint != CONFIDENTIAL_IMAGE_GENERATIONS_PATH;
verify_response_hash_and_signature(&response.0, verify_hash)?;

let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

// Update the node throughput performance
state
Expand Down
59 changes: 47 additions & 12 deletions atoma-proxy/src/server/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::str::FromStr;

use atoma_state::types::AtomaAtomaStateManagerEvent;
use base64::engine::{general_purpose::STANDARD, Engine};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use blake2::Digest;
use fastcrypto::{
ed25519::{Ed25519PublicKey, Ed25519Signature},
Expand All @@ -11,6 +11,7 @@ use fastcrypto::{
};
use flume::Sender;
use reqwest::StatusCode;
use sui_keys::keystore::{AccountKeystore, Keystore};
use sui_sdk::types::crypto::{PublicKey, Signature, SignatureScheme, SuiSignature};
use tracing::instrument;

Expand All @@ -31,6 +32,9 @@ pub const RESPONSE_HASH_KEY: &str = "response_hash";
/// Key for the signature in the payload
pub const SIGNATURE_KEY: &str = "signature";

/// Key for the proxy signature in the payload
pub const PROXY_SIGNATURE_KEY: &str = "proxy_signature";

/// Updates the state manager with token usage and hash information for a stack.
///
/// This function performs two main operations:
Expand Down Expand Up @@ -82,43 +86,56 @@ pub fn update_state_manager(
Ok(())
}

/// Verifies a Sui signature from a handler response
/// Verifies a Sui signature and creates a new signature using the proxy's key
///
/// # Arguments
///
/// * `payload` - JSON payload containing the response hash and its signature
/// * `node_public_key` - Public key of the node that signed the response
/// * `proxy_keystore` - Keystore containing the proxy's signing key
///
/// # Returns
///
/// Returns `Ok(())` if verification succeeds,
/// Returns `Ok(String)` with the new signature if verification succeeds,
/// or an error if verification fails or signing fails
///
/// # Errors
///
/// This function will return an error if:
/// - The payload format is invalid
/// - The signature verification fails
/// - Creating the new signature fails
#[instrument(level = "debug", skip_all)]
pub fn verify_response_hash_and_signature(
pub fn verify_and_sign_response(
payload: &serde_json::Value,
verify_hash: bool,
) -> Result<()> {
keystore: &Keystore,
) -> Result<String> {
// Extract response hash and signature from payload
let response_hash =
let response_hash_str =
payload[RESPONSE_HASH_KEY]
.as_str()
.ok_or_else(|| AtomaProxyError::InternalError {
message: "Missing response_hash in payload".to_string(),
client_message: Some("Invalid response from inference service".to_string()),
endpoint: "verify_signature".to_string(),
})?;
let response_hash = STANDARD.decode(response_hash).unwrap();

if verify_hash {
verify_response_hash(payload, &response_hash)?;
// Decode base64 string to bytes for verification
let response_hash_bytes =
BASE64
.decode(response_hash_str)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to decode response hash: {e}"),
client_message: Some("Invalid response from inference service".to_string()),
endpoint: "verify_signature".to_string(),
})?;
verify_response_hash(payload, &response_hash_bytes)?;
}

let response_hash_bytes = BASE64.decode(response_hash_str).unwrap();

let node_signature =
payload[SIGNATURE_KEY]
.as_str()
Expand Down Expand Up @@ -163,7 +180,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to verify ed25519 signature: {e}"),
client_message: Some("Invalid response from inference service".to_string()),
Expand All @@ -187,7 +204,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|_| AtomaProxyError::InternalError {
message: "Failed to verify secp256k1 signature".to_string(),
client_message: Some("Invalid response from inference service".to_string()),
Expand All @@ -211,7 +228,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|_| AtomaProxyError::InternalError {
message: "Failed to verify secp256r1 signature".to_string(),
client_message: Some("Invalid response from inference service".to_string()),
Expand All @@ -227,7 +244,25 @@ pub fn verify_response_hash_and_signature(
}
}

Ok(())
// Sign with proxy's key
let proxy_signature = match keystore {
Keystore::File(keystore) => keystore
.sign_hashed(&keystore.addresses()[0], &response_hash_bytes)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to create proxy signature: {e}"),
client_message: Some("Invalid response from inference service".to_string()),
endpoint: "verify_signature".to_string(),
})?,
Keystore::InMem(keystore) => keystore
.sign_hashed(&keystore.addresses()[0], &response_hash_bytes)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to create proxy signature: {e}"),
client_message: Some("Invalid response from inference service".to_string()),
endpoint: "verify_signature".to_string(),
})?,
};
// Convert signature to base64
Ok(BASE64.encode(proxy_signature.as_ref()))
}

/// Verifies that a response hash matches the computed hash of the payload
Expand Down
28 changes: 16 additions & 12 deletions atoma-proxy/src/server/streamer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::cognitive_complexity)]
#![allow(clippy::too_many_lines)]

use atoma_auth::Sui;
use atoma_state::types::AtomaAtomaStateManagerEvent;
use axum::body::Bytes;
use axum::{response::sse::Event, Error};
Expand All @@ -10,21 +11,24 @@ use opentelemetry::KeyValue;
use reqwest;
use serde_json::Value;
use sqlx::types::chrono::{DateTime, Utc};
use tokio::sync::RwLock;

use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager};

use super::handlers::verify_and_sign_response;
use std::sync::Arc;
use std::{
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use tracing::{error, info, instrument, warn};

use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager};

use super::handlers::chat_completions::CONFIDENTIAL_CHAT_COMPLETIONS_PATH;
use super::handlers::metrics::{
CHAT_COMPLETIONS_INTER_TOKEN_GENERATION_TIME, CHAT_COMPLETIONS_TIME_TO_FIRST_TOKEN,
CHAT_COMPLETIONS_TOTAL_TOKENS,
};
use super::handlers::verify_response_hash_and_signature;

/// The chunk that indicates the end of a streaming response
const DONE_CHUNK: &str = "[DONE]";
Expand Down Expand Up @@ -52,6 +56,8 @@ pub struct Streamer {
status: StreamStatus,
/// Estimated total tokens for the stream
estimated_total_tokens: i64,
/// Shared handle to the Sui client; its keystore is used to sign the final chunk
sui: Arc<RwLock<Sui>>,
/// Stack small id
stack_small_id: i64,
/// State manager sender
Expand Down Expand Up @@ -99,6 +105,7 @@ impl Streamer {
state_manager_sender: Sender<AtomaAtomaStateManagerEvent>,
stack_small_id: i64,
estimated_total_tokens: i64,
sui: Arc<RwLock<Sui>>,
start: Instant,
node_id: i64,
model_name: String,
Expand All @@ -108,6 +115,7 @@ impl Streamer {
stream: Box::pin(stream),
status: StreamStatus::NotStarted,
estimated_total_tokens,
sui,
stack_small_id,
state_manager_sender,
start,
Expand Down Expand Up @@ -450,15 +458,6 @@ impl Stream for Streamer {
// is not running within a secure enclave. Otherwise, the fact that the node can process requests
// with confidential data is proof of data integrity.
let verify_hash = self.endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH;
verify_response_hash_and_signature(&chunk, verify_hash).map_err(|e| {
error!(
target = "atoma-service-streamer",
level = "error",
"Error verifying response: {}",
e
);
Error::new(format!("Error verifying and signing response: {e:?}"))
})?;

if self.start_decode.is_none() {
self.start_decode = Some(Instant::now());
Expand Down Expand Up @@ -501,6 +500,11 @@ impl Stream for Streamer {
}
} else if let Some(usage) = chunk.get(USAGE) {
self.status = StreamStatus::Completed;
let _ = {
let guard = self.sui.blocking_read();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How safe is it to call `blocking_read` inside an async stream?

Another possibility is to return `Poll::Pending` and add logic to the `poll_next` method to handle the case where a signature is ready, but that's a bit too much of a hassle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. We could actually try caching the keystore when creating the `Streamer` to avoid locking during stream processing. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense, but it is a bit of a pain

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, that might introduce unnecessary complexity here. Let's revisit this after profiling.

verify_and_sign_response(&chunk, verify_hash, guard.get_keystore())
.map_err(|e| Error::new(e.to_string()))?
}; // guard is dropped immediately after signature is created
self.handle_final_chunk(usage)?;
}

Expand Down