diff --git a/Cargo.lock b/Cargo.lock index a16de6f82..2472ca389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7107,7 +7107,7 @@ checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" dependencies = [ "either", "env_home", - "rustix 1.1.3", + "rustix", "winsafe", ] diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index fbedd4835..44f41dd62 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -3608,6 +3608,7 @@ Let me check the result."#; None, // no identity config None, // no bootstrap_max_chars true, // native_tools + crate::config::SkillsPromptInjectionMode::Full, ); // Must contain zero XML protocol artifacts diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 97890d89f..ce65291f6 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -7,6 +7,8 @@ //! - Request timeouts (30s) to prevent slow-loris attacks //! - Header sanitization (handled by axum/hyper) +mod openai_compat; + use crate::channels::{Channel, LinqChannel, NextcloudTalkChannel, SendMessage, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; @@ -240,7 +242,7 @@ fn forwarded_client_ip(headers: &HeaderMap) -> Option { .and_then(parse_client_ip) } -fn client_key_from_request( +pub(crate) fn client_key_from_request( peer_addr: Option, headers: &HeaderMap, trust_forwarded_headers: bool, @@ -527,6 +529,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { if nextcloud_talk_channel.is_some() { println!(" POST /nextcloud-talk — Nextcloud Talk bot webhook"); } + println!(" POST /v1/chat/completions — OpenAI-compatible chat"); + println!(" GET /v1/models — list available models"); println!(" GET /health — health check"); println!(" GET /metrics — Prometheus metrics"); if let Some(code) = pairing.pairing_code() { @@ -570,7 +574,20 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { observer, }; - // Build router with middleware + // Build router with middleware. 
+ // + // The OpenAI-compatible endpoints use a larger body limit (512KB) because + // chat histories can be much bigger than the default 64KB webhook limit. + // They get their own nested router with a separate body limit layer. + let openai_compat_routes = Router::new() + .route( + "/v1/chat/completions", + post(openai_compat::handle_v1_chat_completions), + ) + .layer(RequestBodyLimitLayer::new( + openai_compat::CHAT_COMPLETIONS_MAX_BODY_SIZE, + )); + let app = Router::new() .route("/health", get(handle_health)) .route("/metrics", get(handle_metrics)) @@ -580,6 +597,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .route("/whatsapp", post(handle_whatsapp_message)) .route("/linq", post(handle_linq_webhook)) .route("/nextcloud-talk", post(handle_nextcloud_talk_webhook)) + .route("/v1/models", get(openai_compat::handle_v1_models)) + .merge(openai_compat_routes) .with_state(state) .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE)) .layer(TimeoutLayer::with_status_code( diff --git a/src/gateway/openai_compat.rs b/src/gateway/openai_compat.rs new file mode 100644 index 000000000..8d424475b --- /dev/null +++ b/src/gateway/openai_compat.rs @@ -0,0 +1,702 @@ +//! OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints. +//! +//! These endpoints allow ZeroClaw to act as a drop-in replacement for the +//! OpenAI API, enabling any OpenAI-compatible client (e.g., `openai` Python +//! library, `curl`, Aura) to send chat requests through the gateway. + +use super::AppState; +use crate::providers::traits::{ChatMessage, StreamOptions}; +use axum::{ + body::Body, + extract::{ConnectInfo, State}, + http::{header, HeaderMap, StatusCode}, + response::{IntoResponse, Json}, +}; +use futures_util::StreamExt; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use std::time::Instant; +use uuid::Uuid; + +/// Maximum body size for chat completions requests (512KB). 
+/// Chat histories with many messages can be much larger than the default 64KB gateway limit.
+pub const CHAT_COMPLETIONS_MAX_BODY_SIZE: usize = 512 * 1024;
+
+// ══════════════════════════════════════════════════════════════════════════════
+// REQUEST / RESPONSE TYPES
+// ══════════════════════════════════════════════════════════════════════════════
+
+/// JSON body accepted by `POST /v1/chat/completions`.
+///
+/// Only `messages` is required; every other field is optional and the
+/// handler substitutes the gateway's configured value when it is absent.
+#[derive(Debug, Deserialize)]
+pub struct ChatCompletionsRequest {
+    /// Model ID (e.g. "anthropic/claude-sonnet-4"). Falls back to gateway default.
+    #[serde(default)]
+    pub model: Option<String>,
+    /// Conversation messages.
+    pub messages: Vec<ChatCompletionsMessage>,
+    /// Sampling temperature. Falls back to gateway default.
+    #[serde(default)]
+    pub temperature: Option<f64>,
+    /// Whether to stream the response as SSE events.
+    #[serde(default)]
+    pub stream: Option<bool>,
+}
+
+/// A single conversation turn supplied by the client.
+#[derive(Debug, Deserialize)]
+pub struct ChatCompletionsMessage {
+    /// Speaker role; passed through to the provider unchanged.
+    pub role: String,
+    /// Plain-text message content.
+    pub content: String,
+}
+
+/// JSON body returned for a non-streaming completion.
+#[derive(Debug, Serialize)]
+pub struct ChatCompletionsResponse {
+    pub id: String,
+    pub object: &'static str,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatCompletionsChoice>,
+    pub usage: ChatCompletionsUsage,
+}
+
+/// One completion alternative inside [`ChatCompletionsResponse`].
+#[derive(Debug, Serialize)]
+pub struct ChatCompletionsChoice {
+    pub index: u32,
+    pub message: ChatCompletionsResponseMessage,
+    pub finish_reason: &'static str,
+}
+
+/// The message carried by a [`ChatCompletionsChoice`].
+#[derive(Debug, Serialize)]
+pub struct ChatCompletionsResponseMessage {
+    pub role: &'static str,
+    pub content: String,
+}
+
+/// Token accounting reported alongside a completion.
+#[derive(Debug, Serialize)]
+pub struct ChatCompletionsUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+/// SSE streaming chunk format.
+#[derive(Debug, Serialize)]
+struct ChatCompletionsChunk {
+    id: String,
+    object: &'static str,
+    created: u64,
+    model: String,
+    choices: Vec<ChunkChoice>,
+}
+
+/// One choice inside a streaming chunk.
+#[derive(Debug, Serialize)]
+struct ChunkChoice {
+    index: u32,
+    delta: ChunkDelta,
+    finish_reason: Option<&'static str>,
+}
+
+/// Incremental message fragment; fields that are `None` are omitted from the JSON.
+#[derive(Debug, Serialize)]
+struct ChunkDelta {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<&'static str>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    content: Option<String>,
+}
+
+/// JSON body returned by `GET /v1/models`.
+#[derive(Debug, Serialize)]
+pub struct ModelsResponse {
+    pub object: &'static str,
+    pub data: Vec<ModelObject>,
+}
+
+/// One entry of [`ModelsResponse::data`].
+#[derive(Debug, Serialize)]
+pub struct ModelObject {
+    pub id: String,
+    pub object: &'static str,
+    pub created: u64,
+    pub owned_by: String,
+}
+
+// ══════════════════════════════════════════════════════════════════════════════
+// HANDLERS
+// ══════════════════════════════════════════════════════════════════════════════
+
+/// Build an error response using the `{"error": {"message", "type", "code"}}`
+/// envelope that every failure path of the chat endpoint returns.
+fn openai_error(
+    status: StatusCode,
+    message: &str,
+    kind: &str,
+    code: &str,
+) -> axum::response::Response {
+    let body = serde_json::json!({
+        "error": {
+            "message": message,
+            "type": kind,
+            "code": code
+        }
+    });
+    (status, Json(body)).into_response()
+}
+
+/// POST /v1/chat/completions — OpenAI-compatible chat endpoint.
+///
+/// Checks run in this order, each short-circuiting with a JSON error:
+/// rate limit (429), bearer-token pairing when pairing is required (401),
+/// body size (413), JSON parsing (400), non-empty `messages` (400).
+/// A valid request is then dispatched to the streaming or non-streaming
+/// path depending on the request's `stream` flag (default `false`).
+pub async fn handle_v1_chat_completions(
+    State(state): State<AppState>,
+    ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
+    headers: HeaderMap,
+    body: axum::body::Bytes,
+) -> impl IntoResponse {
+    // ── Rate limit ──
+    let rate_key =
+        super::client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
+    if !state.rate_limiter.allow_webhook(&rate_key) {
+        tracing::warn!("/v1/chat/completions rate limit exceeded");
+        return openai_error(
+            StatusCode::TOO_MANY_REQUESTS,
+            "Rate limit exceeded. Please retry later.",
+            "rate_limit_error",
+            "rate_limit_exceeded",
+        );
+    }
+
+    // ── Bearer token auth (pairing) ──
+    if state.pairing.require_pairing() {
+        // A missing, non-UTF-8 or non-Bearer header yields the empty token,
+        // which is then submitted to the pairing check like any other value.
+        let token = headers
+            .get(header::AUTHORIZATION)
+            .and_then(|v| v.to_str().ok())
+            .and_then(|auth| auth.strip_prefix("Bearer "))
+            .unwrap_or("");
+        if !state.pairing.is_authenticated(token) {
+            tracing::warn!("/v1/chat/completions: rejected — not paired / invalid bearer token");
+            return openai_error(
+                StatusCode::UNAUTHORIZED,
+                "Invalid API key. Pair first via POST /pair, then use Authorization: Bearer <token>",
+                "invalid_request_error",
+                "invalid_api_key",
+            );
+        }
+    }
+
+    // ── Enforce body size limit (since this route uses a separate limit) ──
+    if body.len() > CHAT_COMPLETIONS_MAX_BODY_SIZE {
+        return openai_error(
+            StatusCode::PAYLOAD_TOO_LARGE,
+            &format!(
+                "Request body too large ({} bytes, max {})",
+                body.len(),
+                CHAT_COMPLETIONS_MAX_BODY_SIZE
+            ),
+            "invalid_request_error",
+            "request_too_large",
+        );
+    }
+
+    // ── Parse body ──
+    let request: ChatCompletionsRequest = match serde_json::from_slice(&body) {
+        Ok(req) => req,
+        Err(e) => {
+            tracing::warn!("/v1/chat/completions JSON parse error: {e}");
+            return openai_error(
+                StatusCode::BAD_REQUEST,
+                &format!("Invalid JSON body: {e}"),
+                "invalid_request_error",
+                "invalid_json",
+            );
+        }
+    };
+
+    if request.messages.is_empty() {
+        return openai_error(
+            StatusCode::BAD_REQUEST,
+            "messages array must not be empty",
+            "invalid_request_error",
+            "invalid_messages",
+        );
+    }
+
+    // An absent or empty model name falls back to the gateway's model.
+    let model = request
+        .model
+        .as_deref()
+        .filter(|m| !m.is_empty())
+        .unwrap_or(&state.model)
+        .to_string();
+    let temperature = request.temperature.unwrap_or(state.temperature);
+    let stream = request.stream.unwrap_or(false);
+
+    // Convert messages to provider format
+    let messages: Vec<ChatMessage> = request
+        .messages
+        .iter()
+        .map(|m| ChatMessage {
+            role: m.role.clone(),
+            content: m.content.clone(),
+        })
+        .collect();
+
+    let provider_label = state
+        .config
+        .lock()
+        .default_provider
+        .clone()
+        .unwrap_or_else(|| "unknown".to_string());
+    let started_at = Instant::now();
+
+    state
+        .observer
+        .record_event(&crate::observability::ObserverEvent::LlmRequest {
+            provider: provider_label.clone(),
+            model: model.clone(),
+            messages_count: messages.len(),
+        });
+
+    if stream {
+        handle_streaming(state, messages, model, temperature, provider_label, started_at)
+            .into_response()
+    } else {
+        handle_non_streaming(state, messages, model, temperature, provider_label, started_at)
+            .await
+            .into_response()
+    }
+}
+
+/// Non-streaming chat completions.
+///
+/// Runs the whole conversation through the provider and wraps the reply in
+/// a `chat.completion` object. On provider failure the sanitized error is
+/// logged and recorded, while the client only receives a generic 500.
+///
+/// The `usage` numbers are a byte-length / 4 estimate, not provider-reported
+/// token counts.
+async fn handle_non_streaming(
+    state: AppState,
+    messages: Vec<ChatMessage>,
+    model: String,
+    temperature: f64,
+    provider_label: String,
+    started_at: Instant,
+) -> impl IntoResponse {
+    match state
+        .provider
+        .chat_with_history(&messages, &model, temperature)
+        .await
+    {
+        Ok(response_text) => {
+            record_success(&state, &provider_label, &model, started_at.elapsed());
+
+            // Saturate instead of truncating so an oversized estimate can
+            // never wrap around to a small number.
+            let completion_tokens =
+                u32::try_from(response_text.len() / 4).unwrap_or(u32::MAX);
+            let prompt_tokens = u32::try_from(
+                messages
+                    .iter()
+                    .map(|m| m.content.len() / 4)
+                    .sum::<usize>(),
+            )
+            .unwrap_or(u32::MAX);
+
+            let response = ChatCompletionsResponse {
+                id: format!("chatcmpl-{}", Uuid::new_v4()),
+                object: "chat.completion",
+                created: unix_timestamp(),
+                model,
+                choices: vec![ChatCompletionsChoice {
+                    index: 0,
+                    message: ChatCompletionsResponseMessage {
+                        role: "assistant",
+                        content: response_text,
+                    },
+                    finish_reason: "stop",
+                }],
+                usage: ChatCompletionsUsage {
+                    prompt_tokens,
+                    completion_tokens,
+                    total_tokens: prompt_tokens.saturating_add(completion_tokens),
+                },
+            };
+
+            // `Json` serializes the struct directly; no intermediate
+            // `serde_json::Value` and no `unwrap()` panic path.
+            (StatusCode::OK, Json(response)).into_response()
+        }
+        Err(e) => {
+            let sanitized = crate::providers::sanitize_api_error(&e.to_string());
+            record_failure(
+                &state,
+                &provider_label,
+                &model,
+                started_at.elapsed(),
+                &sanitized,
+            );
+
+            tracing::error!("/v1/chat/completions provider error: {sanitized}");
+            let err = serde_json::json!({
+                "error": {
+                    "message": "LLM request failed",
+                    "type": "server_error",
+                    "code": "provider_error"
+                }
+            });
+            (StatusCode::INTERNAL_SERVER_ERROR, Json(err)).into_response()
+        }
+    }
+}
+
+/// Streaming chat completions via SSE.
+///
+/// Two paths produce the event stream:
+///
+/// * Provider without native streaming: the full completion is awaited and
+///   sent as one chunk carrying `finish_reason: "stop"`, followed by
+///   `data: [DONE]`.
+/// * Provider with native streaming: every provider delta becomes one
+///   `chat.completion.chunk` event (the first one carries the `assistant`
+///   role). The provider's final chunk is translated into a
+///   `finish_reason: "stop"` chunk plus `data: [DONE]`.
+///
+/// Provider errors are passed through `sanitize_api_error` on both paths
+/// before they are recorded or written to the client.
+fn handle_streaming(
+    state: AppState,
+    messages: Vec<ChatMessage>,
+    model: String,
+    temperature: f64,
+    provider_label: String,
+    started_at: Instant,
+) -> impl IntoResponse {
+    let request_id = format!("chatcmpl-{}", Uuid::new_v4());
+    let created = unix_timestamp();
+
+    if !state.provider.supports_streaming() {
+        // Provider doesn't support streaming — fall back to a single-chunk response
+        let single = futures_util::stream::once(async move {
+            let output = match state
+                .provider
+                .chat_with_history(&messages, &model, temperature)
+                .await
+            {
+                Ok(text) => {
+                    record_success(&state, &provider_label, &model, started_at.elapsed());
+
+                    let chunk = ChatCompletionsChunk {
+                        id: request_id,
+                        object: "chat.completion.chunk",
+                        created,
+                        model,
+                        choices: vec![ChunkChoice {
+                            index: 0,
+                            delta: ChunkDelta {
+                                role: Some("assistant"),
+                                content: Some(text),
+                            },
+                            finish_reason: Some("stop"),
+                        }],
+                    };
+                    let json =
+                        serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_string());
+                    format!("data: {json}\n\ndata: [DONE]\n\n")
+                }
+                Err(e) => {
+                    let sanitized = crate::providers::sanitize_api_error(&e.to_string());
+                    record_failure(
+                        &state,
+                        &provider_label,
+                        &model,
+                        started_at.elapsed(),
+                        &sanitized,
+                    );
+
+                    let error_json = serde_json::json!({"error": sanitized});
+                    format!("data: {error_json}\n\ndata: [DONE]\n\n")
+                }
+            };
+            Ok::<_, std::io::Error>(axum::body::Bytes::from(output))
+        });
+
+        return sse_response(Body::from_stream(single));
+    }
+
+    // Provider supports native streaming
+    let provider_stream = state.provider.stream_chat_with_history(
+        &messages,
+        &model,
+        temperature,
+        StreamOptions::new(true),
+    );
+
+    let mut first_chunk = true;
+    let mut errored = false;
+
+    let sse_stream = provider_stream.map(move |result| match result {
+        Ok(chunk) if chunk.is_final => {
+            let mut output = String::new();
+            // After an error the request was already recorded as failed, so
+            // only the terminator is sent and no success is reported.
+            if !errored {
+                record_success(&state, &provider_label, &model, started_at.elapsed());
+
+                // Mirror the fallback path: tell the client why generation
+                // ended before closing the stream.
+                let stop_chunk = ChatCompletionsChunk {
+                    id: request_id.clone(),
+                    object: "chat.completion.chunk",
+                    created,
+                    model: model.clone(),
+                    choices: vec![ChunkChoice {
+                        index: 0,
+                        delta: ChunkDelta {
+                            role: None,
+                            content: None,
+                        },
+                        finish_reason: Some("stop"),
+                    }],
+                };
+                let json =
+                    serde_json::to_string(&stop_chunk).unwrap_or_else(|_| "{}".to_string());
+                output.push_str(&format!("data: {json}\n\n"));
+            }
+            output.push_str("data: [DONE]\n\n");
+            Ok::<_, std::io::Error>(axum::body::Bytes::from(output))
+        }
+        Ok(chunk) => {
+            // Only the first delta announces the speaker role.
+            let role = if first_chunk {
+                first_chunk = false;
+                Some("assistant")
+            } else {
+                None
+            };
+
+            let sse_chunk = ChatCompletionsChunk {
+                id: request_id.clone(),
+                object: "chat.completion.chunk",
+                created,
+                model: model.clone(),
+                choices: vec![ChunkChoice {
+                    index: 0,
+                    delta: ChunkDelta {
+                        role,
+                        content: if chunk.delta.is_empty() {
+                            None
+                        } else {
+                            Some(chunk.delta)
+                        },
+                    },
+                    finish_reason: None,
+                }],
+            };
+            let json = serde_json::to_string(&sse_chunk).unwrap_or_else(|_| "{}".to_string());
+            Ok(axum::body::Bytes::from(format!("data: {json}\n\n")))
+        }
+        Err(e) => {
+            errored = true;
+            // Never forward the raw provider error: it is sanitized here
+            // exactly as on the non-streaming and fallback paths.
+            let sanitized = crate::providers::sanitize_api_error(&e.to_string());
+            record_failure(
+                &state,
+                &provider_label,
+                &model,
+                started_at.elapsed(),
+                &sanitized,
+            );
+            let error_json = serde_json::json!({"error": sanitized});
+            Ok(axum::body::Bytes::from(format!("data: {error_json}\n\n")))
+        }
+    });
+
+    sse_response(Body::from_stream(sse_stream))
+}
+
+/// Wrap `body` in a 200 response carrying the SSE headers.
+///
+/// Built from a header array instead of `Response::builder()` so there is
+/// no `Result` to unwrap.
+fn sse_response(body: Body) -> axum::response::Response {
+    (
+        [
+            (header::CONTENT_TYPE, "text/event-stream"),
+            (header::CACHE_CONTROL, "no-cache"),
+            (header::CONNECTION, "keep-alive"),
+        ],
+        body,
+    )
+        .into_response()
+}
+
+/// GET /v1/models — List available models.
+///
+/// Requires a valid bearer token when pairing is required. The list always
+/// contains exactly one entry: the gateway's configured model.
+pub async fn handle_v1_models(
+    State(state): State<AppState>,
+    headers: HeaderMap,
+) -> impl IntoResponse {
+    // ── Bearer token auth (pairing) ──
+    if state.pairing.require_pairing() {
+        let token = headers
+            .get(header::AUTHORIZATION)
+            .and_then(|v| v.to_str().ok())
+            .and_then(|auth| auth.strip_prefix("Bearer "))
+            .unwrap_or("");
+        if !state.pairing.is_authenticated(token) {
+            let err = serde_json::json!({
+                "error": {
+                    "message": "Invalid API key",
+                    "type": "invalid_request_error",
+                    "code": "invalid_api_key"
+                }
+            });
+            return (StatusCode::UNAUTHORIZED, Json(err)).into_response();
+        }
+    }
+
+    let response = ModelsResponse {
+        object: "list",
+        data: vec![ModelObject {
+            id: state.model.clone(),
+            object: "model",
+            created: unix_timestamp(),
+            owned_by: "zeroclaw".to_string(),
+        }],
+    };
+
+    // Serialize the struct directly; no `to_value(..).unwrap()` panic path.
+    (StatusCode::OK, Json(response)).into_response()
+}
+
+// ══════════════════════════════════════════════════════════════════════════════
+// HELPERS
+// ══════════════════════════════════════════════════════════════════════════════
+
+/// Seconds since the Unix epoch, or 0 if the system clock reads earlier than the epoch.
+fn unix_timestamp() -> u64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_secs()
+}
+
+/// Report a successful LLM call: one `LlmResponse` event plus a latency metric.
+fn record_success(
+    state: &AppState,
+    provider_label: &str,
+    model: &str,
+    duration: std::time::Duration,
+) {
+    state
+        .observer
+        .record_event(&crate::observability::ObserverEvent::LlmResponse {
+            provider: provider_label.to_string(),
+            model: model.to_string(),
+            duration,
+            success: true,
+            error_message: None,
+        });
+    state
+        .observer
+        .record_metric(&crate::observability::traits::ObserverMetric::RequestLatency(
+            duration,
+        ));
+}
+
+/// Report a failed LLM call: a failed `LlmResponse` event, a latency metric,
+/// and an `Error` event attributed to the `gateway` component.
+fn record_failure(
+    state: &AppState,
+    provider_label: &str,
+    model: &str,
+    duration: std::time::Duration,
+    error_message: &str,
+) {
+    state
+        .observer
+        .record_event(&crate::observability::ObserverEvent::LlmResponse {
+            provider: provider_label.to_string(),
+            model: model.to_string(),
+            duration,
+            success: false,
+            error_message: Some(error_message.to_string()),
+        });
+    state
+        .observer
+        .record_metric(&crate::observability::traits::ObserverMetric::RequestLatency(
+            duration,
+        ));
+    state
+        .observer
+        .record_event(&crate::observability::ObserverEvent::Error {
+            component: "gateway".to_string(),
+            message: error_message.to_string(),
+        });
+}
+
+// ══════════════════════════════════════════════════════════════════════════════
+// TESTS
+// ══════════════════════════════════════════════════════════════════════════════
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn chat_completions_request_deserializes_minimal() {
+        let json = r#"{"messages": [{"role": "user", "content": "Hello"}]}"#;
+        let req: ChatCompletionsRequest = serde_json::from_str(json).unwrap();
+        assert!(req.model.is_none());
+        assert!(req.temperature.is_none());
+        assert!(req.stream.is_none());
+        assert_eq!(req.messages.len(), 1);
+        assert_eq!(req.messages[0].role, "user");
+        assert_eq!(req.messages[0].content, "Hello");
+    }
+
+    #[test]
+    fn chat_completions_request_deserializes_full() {
+        let json = r#"{
+            "model": "anthropic/claude-sonnet-4",
+            "messages": [
+                {"role": "system", "content": "You are helpful"},
+                {"role": "user", "content": "Hi"}
+            ],
+            "temperature": 0.5,
+            "stream": true
+        }"#;
+        let req: ChatCompletionsRequest = serde_json::from_str(json).unwrap();
+        assert_eq!(req.model.as_deref(), Some("anthropic/claude-sonnet-4"));
+        assert_eq!(req.temperature, Some(0.5));
+        assert_eq!(req.stream, Some(true));
+        assert_eq!(req.messages.len(), 2);
+    }
+
+    #[test]
+    fn chat_completions_response_serializes() {
+        let response = ChatCompletionsResponse {
+            id: "chatcmpl-test".to_string(),
+            object: "chat.completion",
+            created: 1234567890,
+            model: "test-model".to_string(),
+            choices: vec![ChatCompletionsChoice {
+                index: 0,
+                message: ChatCompletionsResponseMessage {
+                    role: "assistant",
+                    content: "Hello!".to_string(),
+                },
+                finish_reason: "stop",
+            }],
+            usage: ChatCompletionsUsage {
+                prompt_tokens: 10,
+                completion_tokens: 5,
+                total_tokens: 15,
+            },
+        };
+        let json = serde_json::to_string(&response).unwrap();
+        assert!(json.contains("chatcmpl-test"));
+        assert!(json.contains("chat.completion"));
+        assert!(json.contains("Hello!"));
+        assert!(json.contains("stop"));
+    }
+
+    #[test]
+    fn models_response_serializes() {
+        let response = ModelsResponse {
+            object: "list",
+            data: vec![ModelObject {
+                id: "anthropic/claude-sonnet-4".to_string(),
+                object: "model",
+                created: 1234567890,
+                owned_by: "zeroclaw".to_string(),
+            }],
+        };
+        let json = serde_json::to_string(&response).unwrap();
+        assert!(json.contains("\"object\":\"list\""));
+        assert!(json.contains("anthropic/claude-sonnet-4"));
+        assert!(json.contains("zeroclaw"));
+    }
+
+    #[test]
+    fn streaming_chunk_serializes() {
+        let chunk = ChatCompletionsChunk {
+            id: "chatcmpl-test".to_string(),
+            object: "chat.completion.chunk",
+            created: 1234567890,
+            model: "test-model".to_string(),
+            choices: vec![ChunkChoice {
+                index: 0,
+                delta: ChunkDelta {
+                    role: Some("assistant"),
+                    content: Some("Hello".to_string()),
+                },
+                finish_reason: None,
+            }],
+        };
+        let json = serde_json::to_string(&chunk).unwrap();
+        assert!(json.contains("chat.completion.chunk"));
+        assert!(json.contains("Hello"));
+        assert!(json.contains("assistant"));
+    }
+
+    #[test]
+    fn streaming_chunk_omits_none_fields() {
+        let chunk = ChatCompletionsChunk {
+            id: "chatcmpl-test".to_string(),
+            object: "chat.completion.chunk",
+            created: 1234567890,
+            model: "test-model".to_string(),
+            choices: vec![ChunkChoice {
+                index: 0,
+                delta: ChunkDelta {
+                    role: None,
+                    content: None,
+                },
+                finish_reason: None,
+            }],
+        };
+        let json = serde_json::to_string(&chunk).unwrap();
+        assert!(!json.contains("role"));
+        assert!(!json.contains("content"));
+    }
+
+    #[test]
+    fn unix_timestamp_is_reasonable() {
+        let ts = unix_timestamp();
+        // Should be after 2024-01-01. No upper bound: a hard-coded future
+        // date would turn this test into a failure once that date passes.
+        assert!(ts > 1_704_067_200);
+    }
+
+    #[test]
+    fn body_size_limit_is_512kb() {
+        assert_eq!(CHAT_COMPLETIONS_MAX_BODY_SIZE, 524_288);
+    }
+}
diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs
index 7f8e9fcfe..4f20483b8 100644
--- a/src/providers/bedrock.rs
+++ b/src/providers/bedrock.rs
@@ -6,10 +6,12 @@
 
 use crate::providers::traits::{
     ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
-    Provider, ProviderCapabilities, ToolCall as ProviderToolCall, ToolsPayload,
+    Provider, ProviderCapabilities, StreamChunk, StreamError, StreamOptions, StreamResult,
+    ToolCall as ProviderToolCall, ToolsPayload,
 };
 use crate::tools::ToolSpec;
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use hmac::{Hmac, Mac};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
@@ -382,6 +384,11 @@ impl BedrockProvider {
         format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
     }
 
+    /// Build the streaming request URL (converse-stream endpoint).
+    fn stream_endpoint_url(region: &str, model_id: &str) -> String {
+        format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse-stream")
+    }
+
     /// Build the canonical URI for SigV4 signing. Must URI-encode the path
     /// per SigV4 spec: colons become `%3A`. AWS verifies the signature against
     /// the encoded form even though the wire request uses raw colons.
@@ -390,6 +397,12 @@ impl BedrockProvider {
         format!("/model/{encoded}/converse")
     }
 
+    /// Canonical URI for the streaming endpoint.
+    fn stream_canonical_uri(model_id: &str) -> String {
+        let encoded = Self::encode_model_path(model_id);
+        format!("/model/{encoded}/converse-stream")
+    }
+
     fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> {
         self.credentials.as_ref().ok_or_else(|| {
             anyhow::anyhow!(
@@ -735,6 +748,244 @@ impl BedrockProvider {
         let converse_response: ConverseResponse = response.json().await?;
         Ok(converse_response)
     }
+
+    /// Build the SigV4-signed POST for the ConverseStream endpoint.
+    ///
+    /// Shared by [`Self::send_converse_stream_request`] and the spawned task
+    /// in `stream_chat_with_system`, so the signing logic exists only once.
+    /// Takes no `&self` so it can be called from a `'static` task.
+    fn signed_stream_request(
+        client: &Client,
+        credentials: &AwsCredentials,
+        model: &str,
+        payload: Vec<u8>,
+    ) -> reqwest::RequestBuilder {
+        let url = Self::stream_endpoint_url(&credentials.region, model);
+        let canonical_uri = Self::stream_canonical_uri(model);
+        let now = chrono::Utc::now();
+        let host = credentials.host();
+        let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
+
+        let mut headers_to_sign = vec![
+            ("content-type".to_string(), "application/json".to_string()),
+            ("host".to_string(), host),
+            ("x-amz-date".to_string(), amz_date.clone()),
+        ];
+        if let Some(ref token) = credentials.session_token {
+            headers_to_sign.push(("x-amz-security-token".to_string(), token.clone()));
+        }
+        // The signed header list is sorted by header name before signing.
+        headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
+
+        let authorization = build_authorization_header(
+            credentials,
+            "POST",
+            &canonical_uri,
+            "",
+            &headers_to_sign,
+            &payload,
+            &now,
+        );
+
+        let mut request = client
+            .post(&url)
+            .header("content-type", "application/json")
+            .header("x-amz-date", &amz_date)
+            .header("authorization", &authorization);
+
+        if let Some(ref token) = credentials.session_token {
+            request = request.header("x-amz-security-token", token);
+        }
+
+        request.body(payload)
+    }
+
+    /// Send a signed request to the ConverseStream endpoint and return the raw
+    /// response for event-stream parsing.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the body cannot be serialized, the HTTP request fails, or
+    /// Bedrock answers with a non-success status.
+    async fn send_converse_stream_request(
+        &self,
+        credentials: &AwsCredentials,
+        model: &str,
+        request_body: &ConverseRequest,
+    ) -> anyhow::Result<reqwest::Response> {
+        let payload = serde_json::to_vec(request_body)?;
+        let response =
+            Self::signed_stream_request(&self.http_client(), credentials, model, payload)
+                .send()
+                .await?;
+
+        if !response.status().is_success() {
+            return Err(super::api_error("Bedrock", response).await);
+        }
+
+        Ok(response)
+    }
+}
+
+// ── AWS Event-Stream Binary Parser ──────────────────────────────
+//
+// Bedrock ConverseStream returns `application/vnd.amazon.eventstream`
+// binary format. Each message is:
+//   [total_byte_length: u32 BE]
+//   [headers_byte_length: u32 BE]
+//   [prelude_crc: u32 BE]
+//   [headers: variable]
+//   [payload: variable]
+//   [message_crc: u32 BE]
+//
+// We skip CRC validation since the connection is already TLS-protected.
+
+/// Scan an event-stream header section for the string-typed `:event-type` header.
+///
+/// Each header is `[name_len: u8][name][value_type: u8][value]`. Value types
+/// 0-9 are skipped using their encoded sizes, so `:event-type` is found even
+/// when it is not the first header. Returns `None` when the header is
+/// absent, a header is truncated, or an unknown value type is met.
+fn event_type_from_headers(mut headers: &[u8]) -> Option<String> {
+    loop {
+        let (&name_len, rest) = headers.split_first()?;
+        let name_len = usize::from(name_len);
+        if rest.len() < name_len {
+            return None;
+        }
+        let (name, rest) = rest.split_at(name_len);
+        let (&value_type, rest) = rest.split_first()?;
+
+        // Encoded size of the value, including any length prefix.
+        let value_len = match value_type {
+            0 | 1 => 0,  // boolean true / false: no value bytes
+            2 => 1,      // byte
+            3 => 2,      // short
+            4 => 4,      // integer
+            5 | 8 => 8,  // long / timestamp
+            9 => 16,     // uuid
+            6 | 7 => {
+                // byte array / string: u16 BE length prefix + data
+                if rest.len() < 2 {
+                    return None;
+                }
+                2 + usize::from(u16::from_be_bytes([rest[0], rest[1]]))
+            }
+            _ => return None,
+        };
+        if rest.len() < value_len {
+            return None;
+        }
+        let (value, rest) = rest.split_at(value_len);
+
+        if value_type == 7 && name == &b":event-type"[..] {
+            return Some(String::from_utf8_lossy(&value[2..]).into_owned());
+        }
+        headers = rest;
+    }
+}
+
+/// Parse a single event-stream message from a byte buffer.
+/// Returns `(event_type, payload_bytes, total_consumed)` or None if not enough data.
+///
+/// `event_type` is empty when the message has no `:event-type` header.
+///
+/// Malformed framing never returns `total_consumed == 0`, so a caller that
+/// drains `total_consumed` bytes in a loop always makes progress:
+/// * a declared total length below the 16-byte minimum cannot be
+///   resynchronised, so the whole buffer is reported as consumed;
+/// * a header length that does not fit inside the frame skips that frame.
+///
+/// Both cases yield an empty event type and an empty payload.
+fn parse_event_stream_message(buf: &[u8]) -> Option<(String, Vec<u8>, usize)> {
+    // Minimum message: 4 (total_len) + 4 (header_len) + 4 (prelude_crc) + 4 (message_crc) = 16
+    const MIN_MESSAGE_LEN: usize = 16;
+    // Headers start right after the 12-byte prelude (two lengths + prelude CRC).
+    const HEADERS_START: usize = 12;
+
+    if buf.len() < MIN_MESSAGE_LEN {
+        return None;
+    }
+
+    let total_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
+    let headers_len = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
+
+    if total_len < MIN_MESSAGE_LEN {
+        // Previously `total_len - 4` underflowed here, and a zero length
+        // made the caller drain nothing and spin forever.
+        return Some((String::new(), Vec::new(), buf.len()));
+    }
+    if buf.len() < total_len {
+        return None;
+    }
+
+    let payload_end = total_len - 4; // 4 bytes for message_crc
+    let headers_end = HEADERS_START.saturating_add(headers_len);
+    if headers_end > payload_end {
+        return Some((String::new(), Vec::new(), total_len));
+    }
+
+    // The header walk is confined to the header section, so it can never
+    // read into the payload or into the next message.
+    let event_type = event_type_from_headers(&buf[HEADERS_START..headers_end]).unwrap_or_default();
+    let payload = buf[headers_end..payload_end].to_vec();
+
+    Some((event_type, payload, total_len))
+}
+
+/// Bedrock converse-stream event payloads.
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ContentBlockDelta {
+    // NOTE(review): the integer type was lost in extraction; `u32` is a
+    // reconstruction — confirm against the original source.
+    #[allow(dead_code)]
+    content_block_index: Option<u32>,
+    delta: DeltaContent,
+}
+
+/// The `delta` object of a `contentBlockDelta` event; only `text` is read.
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct DeltaContent {
+    #[serde(default)]
+    text: Option<String>,
+}
+
+/// Convert a Bedrock converse-stream byte response into a stream of `StreamChunk`s.
+///
+/// A spawned task reads the body, reassembles event-stream messages and
+/// forwards one chunk per non-empty `contentBlockDelta` text. An event whose
+/// type contains `Exception` or `Error` ends the stream with a
+/// `StreamError::Provider` and no final chunk. Otherwise a final chunk is
+/// sent when the body ends, including after an HTTP read error.
+fn bedrock_event_stream_to_chunks(
+    response: reqwest::Response,
+    count_tokens: bool,
+) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+    let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+    tokio::spawn(async move {
+        let mut buffer = Vec::new();
+        let mut bytes_stream = response.bytes_stream();
+
+        while let Some(item) = bytes_stream.next().await {
+            let bytes = match item {
+                Ok(bytes) => bytes,
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    break;
+                }
+            };
+            buffer.extend_from_slice(&bytes);
+
+            // Drain every complete message currently in the buffer
+            while let Some((event_type, payload, consumed)) = parse_event_stream_message(&buffer)
+            {
+                buffer.drain(..consumed);
+
+                match event_type.as_str() {
+                    "contentBlockDelta" => {
+                        // Payloads that fail to parse or carry no text are skipped.
+                        let text = serde_json::from_slice::<ContentBlockDelta>(&payload)
+                            .ok()
+                            .and_then(|delta| delta.delta.text)
+                            .filter(|text| !text.is_empty());
+                        if let Some(text) = text {
+                            let mut chunk = StreamChunk::delta(text);
+                            if count_tokens {
+                                chunk = chunk.with_token_estimate();
+                            }
+                            if tx.send(Ok(chunk)).await.is_err() {
+                                // Receiver dropped: nobody is listening.
+                                return;
+                            }
+                        }
+                    }
+                    other if other.contains("Exception") || other.contains("Error") => {
+                        let msg = String::from_utf8_lossy(&payload).to_string();
+                        let _ = tx
+                            .send(Err(StreamError::Provider(format!(
+                                "Bedrock stream error ({other}): {msg}"
+                            ))))
+                            .await;
+                        return;
+                    }
+                    // messageStart, contentBlockStart, contentBlockStop,
+                    // messageStop, metadata and unknown events carry nothing
+                    // to forward; the final chunk is sent after the loop.
+                    _ => {}
+                }
+            }
+        }
+
+        // Send final chunk
+        let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
+    });
+
+    stream::unfold(rx, |mut rx| async { rx.recv().await.map(|chunk| (chunk, rx)) }).boxed()
 }
 
 // ── Provider trait implementation ───────────────────────────────
@@ -864,6 +1115,208 @@ impl Provider for BedrockProvider {
         Ok(Self::parse_converse_response(response))
     }
 
+    fn supports_streaming(&self) -> bool {
+        true
+    }
+
+    /// Stream a single-turn completion from the ConverseStream endpoint.
+    ///
+    /// Missing credentials yield a one-item error stream. Otherwise the
+    /// signed request is sent from a spawned task; transport failures and
+    /// non-success statuses are reported as a single stream error, and a
+    /// successful response is decoded by `bedrock_event_stream_to_chunks`.
+    fn stream_chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        let credentials = match self.require_credentials() {
+            Ok(c) => c,
+            Err(_) => {
+                return stream::once(async {
+                    Err(StreamError::Provider(
+                        "AWS Bedrock credentials not set".to_string(),
+                    ))
+                })
+                .boxed();
+            }
+        };
+
+        let system = system_prompt.map(|text| {
+            let mut blocks = vec![SystemBlock::Text(TextBlock {
+                text: text.to_string(),
+            })];
+            if Self::should_cache_system(text) {
+                blocks.push(SystemBlock::CachePoint(CachePointWrapper {
+                    cache_point: CachePoint::default_cache(),
+                }));
+            }
+            blocks
+        });
+
+        let request = ConverseRequest {
+            system,
+            messages: vec![ConverseMessage {
+                role: "user".to_string(),
+                content: Self::parse_user_content_blocks(message),
+            }],
+            inference_config: Some(InferenceConfig {
+                max_tokens: DEFAULT_MAX_TOKENS,
+                temperature,
+            }),
+            tool_config: None,
+        };
+
+        // Clone what we need for the async block
+        let credentials = AwsCredentials {
+            access_key_id: credentials.access_key_id.clone(),
+            secret_access_key: credentials.secret_access_key.clone(),
+            session_token: credentials.session_token.clone(),
+            region: credentials.region.clone(),
+        };
+        let model = model.to_string();
+        let count_tokens = options.count_tokens;
+        let client = self.http_client();
+
+        // We need to send the request asynchronously, then convert the response to a stream.
+        // Use a channel to bridge the async setup with the streaming response.
+        let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+        tokio::spawn(async move {
+            let payload = match serde_json::to_vec(&request) {
+                Ok(p) => p,
+                Err(e) => {
+                    let _ = tx
+                        .send(Err(StreamError::Provider(format!(
+                            "Failed to serialize request: {e}"
+                        ))))
+                        .await;
+                    return;
+                }
+            };
+
+            let sent =
+                BedrockProvider::signed_stream_request(&client, &credentials, &model, payload)
+                    .send()
+                    .await;
+            let response = match sent {
+                Ok(r) => r,
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    return;
+                }
+            };
+
+            if !response.status().is_success() {
+                let status = response.status();
+                let body = response
+                    .text()
+                    .await
+                    .unwrap_or_else(|_| "unknown error".to_string());
+                let _ = tx
+                    .send(Err(StreamError::Provider(format!(
+                        "Bedrock stream request failed ({status}): {body}"
+                    ))))
+                    .await;
+                return;
+            }
+
+            // Decode with the shared event-stream decoder instead of a second
+            // copy of its loop, and forward every item (deltas, errors, the
+            // final chunk) unchanged.
+            let mut chunks = bedrock_event_stream_to_chunks(response, count_tokens);
+            while let Some(item) = chunks.next().await {
+                if tx.send(item).await.is_err() {
+                    // Receiver dropped; dropping `chunks` stops the decoder.
+                    return;
+                }
+            }
+        });
+
+        stream::unfold(rx, |mut rx| async { rx.recv().await.map(|chunk| (chunk, rx)) }).boxed()
+    }
+
     async fn warmup(&self) -> anyhow::Result<()> {
         if let Some(ref creds) = self.credentials {
             let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region);
@@ -1334,4 +1787,155 @@ mod tests {
         let caps = provider.capabilities();
         assert!(caps.native_tool_calling);
     }
+
+    // ── Streaming tests ──────────────────────────────────────────
+
+    #[test]
+    fn supports_streaming_returns_true() {
+        let provider = BedrockProvider { credentials: None };
+        assert!(provider.supports_streaming());
+    }
+
+    #[test]
+    fn stream_endpoint_url_formats_correctly() {
+        let url = BedrockProvider::stream_endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6");
+        assert_eq!(
+            url,
+            "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse-stream"
+        );
+    }
+
+    #[test]
+    fn stream_canonical_uri_encodes_colon() {
+        let uri =
+            BedrockProvider::stream_canonical_uri("anthropic.claude-3-5-haiku-20241022-v1:0");
+        assert_eq!(
+            uri,
+            "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
+        );
+    }
+
+    #[test]
+    fn stream_canonical_uri_no_colon() {
+        let uri = BedrockProvider::stream_canonical_uri("anthropic.claude-sonnet-4-6");
+        assert_eq!(uri, "/model/anthropic.claude-sonnet-4-6/converse-stream");
+    }
+
+    // ── Event-stream parser tests ────────────────────────────────
+
+    #[test]
+    fn parse_event_stream_message_consumes_undersized_frame() {
+        // A declared total length of 0 is below the 16-byte framing minimum.
+        // The parser must report progress instead of underflowing or
+        // returning a zero-length consumption.
+        let buf = [0u8; 16];
+        let (event_type, payload, consumed) =
+            parse_event_stream_message(&buf).expect("malformed frame must be consumed");
+        assert!(event_type.is_empty());
+        assert!(payload.is_empty());
+        assert_eq!(consumed, buf.len());
+    }
+
+    /// Helper: build a minimal AWS event-stream message with a string `:event-type` header.
+ fn build_event_stream_message(event_type: &str, payload: &[u8]) -> Vec { + // Header: `:event-type` as string (type 7) + let header_name = b":event-type"; + let header_name_len = header_name.len() as u8; + let event_type_bytes = event_type.as_bytes(); + let event_type_len = event_type_bytes.len() as u16; + + // Header bytes: 1 (name_len) + name + 1 (type=7) + 2 (val_len) + val + let headers_len = 1 + header_name.len() + 1 + 2 + event_type_bytes.len(); + // Total: 4 (total_len) + 4 (headers_len) + 4 (prelude_crc) + headers + payload + 4 (message_crc) + let total_len = 12 + headers_len + payload.len() + 4; + + let mut msg = Vec::with_capacity(total_len); + msg.extend_from_slice(&(total_len as u32).to_be_bytes()); + msg.extend_from_slice(&(headers_len as u32).to_be_bytes()); + msg.extend_from_slice(&0u32.to_be_bytes()); // prelude_crc (skipped) + + // Write header + msg.push(header_name_len); + msg.extend_from_slice(header_name); + msg.push(7); // string type + msg.extend_from_slice(&event_type_len.to_be_bytes()); + msg.extend_from_slice(event_type_bytes); + + // Write payload + msg.extend_from_slice(payload); + + // Write message CRC (skipped, just zeros) + msg.extend_from_slice(&0u32.to_be_bytes()); + + msg + } + + #[test] + fn parse_event_stream_message_content_block_delta() { + let payload = br#"{"contentBlockIndex":0,"delta":{"text":"Hello"}}"#; + let msg = build_event_stream_message("contentBlockDelta", payload); + + let result = parse_event_stream_message(&msg); + assert!(result.is_some()); + let (event_type, parsed_payload, consumed) = result.unwrap(); + assert_eq!(event_type, "contentBlockDelta"); + assert_eq!(consumed, msg.len()); + + let delta: ContentBlockDelta = serde_json::from_slice(&parsed_payload).unwrap(); + assert_eq!(delta.delta.text.as_deref(), Some("Hello")); + } + + #[test] + fn parse_event_stream_message_stop() { + let payload = br#"{"stopReason":"end_turn"}"#; + let msg = build_event_stream_message("messageStop", payload); + + let result = 
parse_event_stream_message(&msg); + assert!(result.is_some()); + let (event_type, _, _) = result.unwrap(); + assert_eq!(event_type, "messageStop"); + } + + #[test] + fn parse_event_stream_message_insufficient_data() { + // Only 10 bytes — not enough for even the minimum 16-byte message + let buf = vec![0u8; 10]; + assert!(parse_event_stream_message(&buf).is_none()); + } + + #[test] + fn parse_event_stream_message_incomplete_message() { + let payload = br#"{"text":"Hi"}"#; + let msg = build_event_stream_message("contentBlockDelta", payload); + + // Truncate to simulate incomplete data + let truncated = &msg[..msg.len() - 5]; + assert!(parse_event_stream_message(truncated).is_none()); + } + + #[test] + fn parse_event_stream_multiple_messages() { + let payload1 = br#"{"contentBlockIndex":0,"delta":{"text":"Hello"}}"#; + let payload2 = br#"{"contentBlockIndex":0,"delta":{"text":" World"}}"#; + let msg1 = build_event_stream_message("contentBlockDelta", payload1); + let msg2 = build_event_stream_message("contentBlockDelta", payload2); + + let mut buf = Vec::new(); + buf.extend_from_slice(&msg1); + buf.extend_from_slice(&msg2); + + // Parse first message + let (event_type1, p1, consumed1) = parse_event_stream_message(&buf).unwrap(); + assert_eq!(event_type1, "contentBlockDelta"); + let delta1: ContentBlockDelta = serde_json::from_slice(&p1).unwrap(); + assert_eq!(delta1.delta.text.as_deref(), Some("Hello")); + + // Parse second message from remainder + let (event_type2, p2, _) = parse_event_stream_message(&buf[consumed1..]).unwrap(); + assert_eq!(event_type2, "contentBlockDelta"); + let delta2: ContentBlockDelta = serde_json::from_slice(&p2).unwrap(); + assert_eq!(delta2.delta.text.as_deref(), Some(" World")); + } + + #[test] + fn content_block_delta_deserializes() { + let json = r#"{"contentBlockIndex":0,"delta":{"text":"Hello from Bedrock"}}"#; + let delta: ContentBlockDelta = serde_json::from_str(json).unwrap(); + assert_eq!(delta.content_block_index, Some(0)); + 
assert_eq!(delta.delta.text.as_deref(), Some("Hello from Bedrock")); + } + + #[test] + fn content_block_delta_empty_text() { + let json = r#"{"contentBlockIndex":0,"delta":{}}"#; + let delta: ContentBlockDelta = serde_json::from_str(json).unwrap(); + assert!(delta.delta.text.is_none()); + } }