diff --git a/atoma-auth/src/sui/mod.rs b/atoma-auth/src/sui/mod.rs index 0d706a60..001ae5fb 100644 --- a/atoma-auth/src/sui/mod.rs +++ b/atoma-auth/src/sui/mod.rs @@ -403,6 +403,16 @@ impl Sui { Ok(signature.encode_base64()) } + /// Get the underlying keystore + /// + /// # Returns + /// + /// Returns the keystore. + #[must_use] + pub fn get_keystore(&self) -> &Keystore { + &self.wallet_ctx.config.keystore + } + /// Sign a hash using the wallet's private key /// /// # Arguments diff --git a/atoma-proxy/src/server/handlers/chat_completions.rs b/atoma-proxy/src/server/handlers/chat_completions.rs index d2fd67ac..c6bb1d83 100644 --- a/atoma-proxy/src/server/handlers/chat_completions.rs +++ b/atoma-proxy/src/server/handlers/chat_completions.rs @@ -44,7 +44,7 @@ use super::metrics::{ }; use super::request_model::RequestModel; use super::{ - handle_status_code_error, update_state_manager, verify_response_hash_and_signature, + handle_status_code_error, update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY, }; use crate::server::{Result, DEFAULT_MAX_TOKENS, MAX_COMPLETION_TOKENS, MAX_TOKENS, MODEL}; @@ -525,6 +525,7 @@ pub fn confidential_chat_completions_create_stream( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_non_streaming_response( state: &ProxyState, node_address: &String, @@ -561,7 +562,7 @@ async fn handle_non_streaming_response( handle_status_code_error(response.status(), &endpoint, error)?; } - let response = response + let mut response = response .json::() .await .map_err(|err| AtomaProxyError::InternalError { @@ -597,7 +598,12 @@ async fn handle_non_streaming_response( ); let verify_hash = endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH; - verify_response_hash_and_signature(&response.0, verify_hash)?; + + let guard: tokio::sync::RwLockReadGuard<'_, atoma_auth::Sui> = state.sui.read().await; + let keystore = guard.get_keystore(); + let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); state .state_manager_sender @@ -760,6 +766,7 @@ async fn handle_streaming_response( state.state_manager_sender.clone(), selected_stack_small_id, estimated_total_tokens, + state.sui.clone(), start, node_id, model_name, diff --git a/atoma-proxy/src/server/handlers/embeddings.rs b/atoma-proxy/src/server/handlers/embeddings.rs index 88e4039e..77c70aea 100644 --- a/atoma-proxy/src/server/handlers/embeddings.rs +++ b/atoma-proxy/src/server/handlers/embeddings.rs @@ -32,7 +32,7 @@ use super::{ TOTAL_FAILED_REQUESTS, TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS, }, request_model::RequestModel, - update_state_manager, verify_response_hash_and_signature, RESPONSE_HASH_KEY, + update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY, }; use crate::server::Result; @@ -357,6 +357,7 @@ pub async fn confidential_embeddings_create( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_embeddings_response( state: &ProxyState, node_address: String, @@ -397,7 +398,7 @@ async fn handle_embeddings_response( handle_status_code_error(response.status(), &endpoint, error)?; } - let response = + let mut response = response .json::() .await @@ -407,8 +408,13 @@ async fn handle_embeddings_response( endpoint: endpoint.to_string(), })?; + let guard = state.sui.blocking_read(); + let keystore = guard.get_keystore(); let verify_hash = endpoint != CONFIDENTIAL_EMBEDDINGS_PATH; - verify_response_hash_and_signature(&response, verify_hash)?; + + let proxy_signature = verify_and_sign_response(&response, verify_hash, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); let num_input_compute_units = if endpoint == CONFIDENTIAL_EMBEDDINGS_PATH { response diff --git a/atoma-proxy/src/server/handlers/image_generations.rs b/atoma-proxy/src/server/handlers/image_generations.rs index e8ba51f7..6ae5f06d 100644 --- a/atoma-proxy/src/server/handlers/image_generations.rs +++ b/atoma-proxy/src/server/handlers/image_generations.rs @@ -18,12 +18,15 @@ use crate::server::error::AtomaProxyError; use crate::server::types::{ConfidentialComputeRequest, ConfidentialComputeResponse}; use crate::server::{http_server::ProxyState, middleware::RequestMetadataExtension}; +use super::handle_status_code_error; use super::metrics::{ IMAGE_GEN_LATENCY_METRICS, IMAGE_GEN_NUM_REQUESTS, TOTAL_COMPLETED_REQUESTS, TOTAL_FAILED_IMAGE_GENERATION_REQUESTS, TOTAL_FAILED_REQUESTS, }; -use super::{handle_status_code_error, verify_response_hash_and_signature}; -use super::{request_model::RequestModel, update_state_manager, RESPONSE_HASH_KEY}; +use super::{ + request_model::RequestModel, update_state_manager, verify_and_sign_response, + PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY, +}; use crate::server::{Result, MODEL}; /// Path for the confidential image generations endpoint. @@ -336,6 +339,7 @@ pub async fn confidential_image_generations_create( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_image_generation_response( state: &ProxyState, node_address: String, @@ -374,7 +378,7 @@ async fn handle_image_generation_response( handle_status_code_error(response.status(), &endpoint, error)?; } - let response = response + let mut response = response .json::() .await .map_err(|err| AtomaProxyError::InternalError { @@ -384,8 +388,13 @@ async fn handle_image_generation_response( }) .map(Json)?; + let guard = state.sui.read().await; + let keystore = guard.get_keystore(); let verify_hash = endpoint != CONFIDENTIAL_IMAGE_GENERATIONS_PATH; - verify_response_hash_and_signature(&response.0, verify_hash)?; + + let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); // Update the node throughput performance state diff --git a/atoma-proxy/src/server/handlers/mod.rs b/atoma-proxy/src/server/handlers/mod.rs index b5897754..1e4e90eb 100644 --- a/atoma-proxy/src/server/handlers/mod.rs +++ b/atoma-proxy/src/server/handlers/mod.rs @@ -1,7 +1,7 @@ use std::str::FromStr; use atoma_state::types::AtomaAtomaStateManagerEvent; -use base64::engine::{general_purpose::STANDARD, Engine}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use blake2::Digest; use fastcrypto::{ ed25519::{Ed25519PublicKey, Ed25519Signature}, @@ -11,6 +11,7 @@ use fastcrypto::{ }; use flume::Sender; use reqwest::StatusCode; +use sui_keys::keystore::{AccountKeystore, Keystore}; use sui_sdk::types::crypto::{PublicKey, Signature, SignatureScheme, SuiSignature}; use tracing::instrument; @@ -31,6 +32,9 @@ pub const RESPONSE_HASH_KEY: &str = "response_hash"; /// Key for the signature in the payload pub const SIGNATURE_KEY: &str = "signature"; +/// Key for the proxy signature in the payload +pub const PROXY_SIGNATURE_KEY: &str = "proxy_signature"; + /// Updates the state manager with token usage and hash information for a stack. /// /// This function performs two main operations: @@ -82,16 +86,17 @@ pub fn update_state_manager( Ok(()) } -/// Verifies a Sui signature from a handler response +/// Verifies a Sui signature and creates a new signature using the proxy's key /// /// # Arguments /// /// * `payload` - JSON payload containing the response hash and its signature /// * `node_public_key` - Public key of the node that signed the response +/// * `proxy_keystore` - Keystore containing the proxy's signing key /// /// # Returns /// -/// Returns `Ok(())` if verification succeeds, +/// Returns `Ok(String)` with the new signature if verification succeeds, /// or an error if verification fails or signing fails /// /// # Errors @@ -99,13 +104,15 @@ pub fn update_state_manager( /// This function will return an error if: /// - The payload format is invalid /// - The signature verification fails +/// - Creating the new signature fails #[instrument(level = "debug", skip_all)] -pub fn verify_response_hash_and_signature( +pub fn verify_and_sign_response( payload: &serde_json::Value, verify_hash: bool, -) -> Result<()> { + keystore: &Keystore, +) -> Result { // Extract response hash and signature from payload - let response_hash = + let response_hash_str = payload[RESPONSE_HASH_KEY] .as_str() .ok_or_else(|| AtomaProxyError::InternalError { @@ -113,12 +120,22 @@ pub fn verify_response_hash_and_signature( client_message: Some("Invalid response from inference service".to_string()), endpoint: "verify_signature".to_string(), })?; - let response_hash = STANDARD.decode(response_hash).unwrap(); if verify_hash { - verify_response_hash(payload, &response_hash)?; + // Decode base64 string to bytes for verification + let response_hash_bytes = + BASE64 + .decode(response_hash_str) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to decode response hash: {e}"), + client_message: Some("Invalid response from inference service".to_string()), + endpoint: "verify_signature".to_string(), + })?; + verify_response_hash(payload, &response_hash_bytes)?; } + let response_hash_bytes = BASE64.decode(response_hash_str).unwrap(); + let node_signature = payload[SIGNATURE_KEY] .as_str() @@ -163,7 +180,7 @@ pub fn verify_response_hash_and_signature( } })?; public_key - .verify(response_hash.as_slice(), &signature) + .verify(response_hash_bytes.as_slice(), &signature) .map_err(|e| AtomaProxyError::InternalError { message: format!("Failed to verify ed25519 signature: {e}"), client_message: Some("Invalid response from inference service".to_string()), @@ -187,7 +204,7 @@ pub fn verify_response_hash_and_signature( } })?; public_key - .verify(response_hash.as_slice(), &signature) + .verify(response_hash_bytes.as_slice(), &signature) .map_err(|_| AtomaProxyError::InternalError { message: "Failed to verify secp256k1 signature".to_string(), client_message: Some("Invalid response from inference service".to_string()), @@ -211,7 +228,7 @@ pub fn verify_response_hash_and_signature( } })?; public_key - .verify(response_hash.as_slice(), &signature) + .verify(response_hash_bytes.as_slice(), &signature) .map_err(|_| AtomaProxyError::InternalError { message: "Failed to verify secp256r1 signature".to_string(), client_message: Some("Invalid response from inference service".to_string()), @@ -227,7 +244,25 @@ pub fn verify_response_hash_and_signature( } } - Ok(()) + // Sign with proxy's key + let proxy_signature = match keystore { + Keystore::File(keystore) => keystore + .sign_hashed(&keystore.addresses()[0], &response_hash_bytes) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to create proxy signature: {e}"), + client_message: Some("Invalid response from inference service".to_string()), + endpoint: "verify_signature".to_string(), + })?, + Keystore::InMem(keystore) => keystore + .sign_hashed(&keystore.addresses()[0], &response_hash_bytes) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to create proxy signature: {e}"), + client_message: Some("Invalid response from inference service".to_string()), + endpoint: "verify_signature".to_string(), + })?, + }; + // Convert signature to base64 + Ok(BASE64.encode(proxy_signature.as_ref())) } /// Verifies that a response hash matches the computed hash of the payload diff --git a/atoma-proxy/src/server/streamer.rs b/atoma-proxy/src/server/streamer.rs index 5b496839..10380f22 100644 --- a/atoma-proxy/src/server/streamer.rs +++ b/atoma-proxy/src/server/streamer.rs @@ -1,6 +1,7 @@ #![allow(clippy::cognitive_complexity)] #![allow(clippy::too_many_lines)] +use atoma_auth::Sui; use atoma_state::types::AtomaAtomaStateManagerEvent; use axum::body::Bytes; use axum::{response::sse::Event, Error}; @@ -10,6 +11,12 @@ use opentelemetry::KeyValue; use reqwest; use serde_json::Value; use sqlx::types::chrono::{DateTime, Utc}; +use tokio::sync::RwLock; + +use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager}; + +use super::handlers::verify_and_sign_response; +use std::sync::Arc; use std::{ pin::Pin, task::{Context, Poll}, @@ -17,14 +24,11 @@ use std::{ }; use tracing::{error, info, instrument, warn}; -use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager}; - use super::handlers::chat_completions::CONFIDENTIAL_CHAT_COMPLETIONS_PATH; use super::handlers::metrics::{ CHAT_COMPLETIONS_INTER_TOKEN_GENERATION_TIME, CHAT_COMPLETIONS_TIME_TO_FIRST_TOKEN, CHAT_COMPLETIONS_TOTAL_TOKENS, }; -use super::handlers::verify_response_hash_and_signature; /// The chunk that indicates the end of a streaming response const DONE_CHUNK: &str = "[DONE]"; @@ -52,6 +56,8 @@ pub struct Streamer { status: StreamStatus, /// Estimated total tokens for the stream estimated_total_tokens: i64, + /// Keystore + sui: Arc>, /// Stack small id stack_small_id: i64, /// State manager sender @@ -99,6 +105,7 @@ impl Streamer { state_manager_sender: Sender, stack_small_id: i64, estimated_total_tokens: i64, + sui: Arc>, start: Instant, node_id: i64, model_name: String, @@ -108,6 +115,7 @@ impl Streamer { stream: Box::pin(stream), status: StreamStatus::NotStarted, estimated_total_tokens, + sui, stack_small_id, state_manager_sender, start, @@ -450,15 +458,6 @@ impl Stream for Streamer { // is not running within a secure enclave. Otherwise, the fact that the node can process requests // with confidential data is proof of data integrity. let verify_hash = self.endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH; - verify_response_hash_and_signature(&chunk, verify_hash).map_err(|e| { - error!( - target = "atoma-service-streamer", - level = "error", - "Error verifying response: {}", - e - ); - Error::new(format!("Error verifying and signing response: {e:?}")) - })?; if self.start_decode.is_none() { self.start_decode = Some(Instant::now()); @@ -501,6 +500,11 @@ impl Stream for Streamer { } } else if let Some(usage) = chunk.get(USAGE) { self.status = StreamStatus::Completed; + let _ = { + let guard = self.sui.blocking_read(); + verify_and_sign_response(&chunk, verify_hash, guard.get_keystore()) + .map_err(|e| Error::new(e.to_string()))? + }; // guard is dropped immediately after signature is created self.handle_final_chunk(usage)?; }