From 1eab6bc2874616377068216e1675f8713eed4e09 Mon Sep 17 00:00:00 2001 From: Shadman Hossain Date: Sun, 22 Feb 2026 19:25:11 -0500 Subject: [PATCH 1/3] fix: add missing skills_prompt_mode arg in build_system_prompt test The test was not updated when the function signature changed to include the skills_prompt_mode parameter. Co-Authored-By: Claude Opus 4.6 --- src/agent/loop_.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index fbedd4835..44f41dd62 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -3608,6 +3608,7 @@ Let me check the result."#; None, // no identity config None, // no bootstrap_max_chars true, // native_tools + crate::config::SkillsPromptInjectionMode::Full, ); // Must contain zero XML protocol artifacts From 7d4f292cdd59f10a808ed63a2e698c4552889fd4 Mon Sep 17 00:00:00 2001 From: Shadman Hossain Date: Sun, 22 Feb 2026 19:25:20 -0500 Subject: [PATCH 2/3] feat: add OpenAI-compatible /v1/chat/completions and /v1/models endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add an OpenAI-compatible API surface to the gateway so that standard OpenAI client libraries can interact with ZeroClaw directly. Endpoints: - POST /v1/chat/completions — supports both streaming (SSE) and non-streaming responses, bearer token auth, rate limiting - GET /v1/models — returns the gateway's configured model The chat completions endpoint accepts the standard OpenAI request format (model, messages, temperature, stream) and returns responses in the OpenAI envelope format. Streaming uses SSE with delta chunks and a [DONE] sentinel. A 512KB body limit is applied (vs 64KB default) since chat histories can be large. When the underlying provider doesn't support native streaming, the handler falls back to wrapping the non-streaming response in a single SSE chunk for transparent compatibility. Includes 8 unit tests for request/response serialization. Co-Authored-By: Claude Opus 4.6 --- src/gateway/mod.rs | 23 +- src/gateway/openai_compat.rs | 702 +++++++++++++++++++++++++++++++++++ 2 files changed, 723 insertions(+), 2 deletions(-) create mode 100644 src/gateway/openai_compat.rs diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 97890d89f..ce65291f6 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -7,6 +7,8 @@ //! - Request timeouts (30s) to prevent slow-loris attacks //! - Header sanitization (handled by axum/hyper) +mod openai_compat; + use crate::channels::{Channel, LinqChannel, NextcloudTalkChannel, SendMessage, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; @@ -240,7 +242,7 @@ fn forwarded_client_ip(headers: &HeaderMap) -> Option { .and_then(parse_client_ip) } -fn client_key_from_request( +pub(crate) fn client_key_from_request( peer_addr: Option, headers: &HeaderMap, trust_forwarded_headers: bool, @@ -527,6 +529,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { if nextcloud_talk_channel.is_some() { println!(" POST /nextcloud-talk — Nextcloud Talk bot webhook"); } + println!(" POST /v1/chat/completions — OpenAI-compatible chat"); + println!(" GET /v1/models — list available models"); println!(" GET /health — health check"); println!(" GET /metrics — Prometheus metrics"); if let Some(code) = pairing.pairing_code() { @@ -570,7 +574,20 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { observer, }; - // Build router with middleware + // Build router with middleware. 
+ // + // The OpenAI-compatible endpoints use a larger body limit (512KB) because + // chat histories can be much bigger than the default 64KB webhook limit. + // They get their own nested router with a separate body limit layer. + let openai_compat_routes = Router::new() + .route( + "/v1/chat/completions", + post(openai_compat::handle_v1_chat_completions), + ) + .layer(RequestBodyLimitLayer::new( + openai_compat::CHAT_COMPLETIONS_MAX_BODY_SIZE, + )); + let app = Router::new() .route("/health", get(handle_health)) .route("/metrics", get(handle_metrics)) @@ -580,6 +597,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .route("/whatsapp", post(handle_whatsapp_message)) .route("/linq", post(handle_linq_webhook)) .route("/nextcloud-talk", post(handle_nextcloud_talk_webhook)) + .route("/v1/models", get(openai_compat::handle_v1_models)) + .merge(openai_compat_routes) .with_state(state) .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE)) .layer(TimeoutLayer::with_status_code( diff --git a/src/gateway/openai_compat.rs b/src/gateway/openai_compat.rs new file mode 100644 index 000000000..8d424475b --- /dev/null +++ b/src/gateway/openai_compat.rs @@ -0,0 +1,702 @@ +//! OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints. +//! +//! These endpoints allow ZeroClaw to act as a drop-in replacement for the +//! OpenAI API, enabling any OpenAI-compatible client (e.g., `openai` Python +//! library, `curl`, Aura) to send chat requests through the gateway. + +use super::AppState; +use crate::providers::traits::{ChatMessage, StreamOptions}; +use axum::{ + body::Body, + extract::{ConnectInfo, State}, + http::{header, HeaderMap, StatusCode}, + response::{IntoResponse, Json}, +}; +use futures_util::StreamExt; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use std::time::Instant; +use uuid::Uuid; + +/// Maximum body size for chat completions requests (512KB). +/// Chat histories with many messages can be much larger than the default 64KB gateway limit. +pub const CHAT_COMPLETIONS_MAX_BODY_SIZE: usize = 524_288; + +// ══════════════════════════════════════════════════════════════════════════════ +// REQUEST / RESPONSE TYPES +// ══════════════════════════════════════════════════════════════════════════════ + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionsRequest { + /// Model ID (e.g. "anthropic/claude-sonnet-4"). Falls back to gateway default. + #[serde(default)] + pub model: Option, + /// Conversation messages. + pub messages: Vec, + /// Sampling temperature. Falls back to gateway default. + #[serde(default)] + pub temperature: Option, + /// Whether to stream the response as SSE events. + #[serde(default)] + pub stream: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionsMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsResponse { + pub id: String, + pub object: &'static str, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: ChatCompletionsUsage, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsChoice { + pub index: u32, + pub message: ChatCompletionsResponseMessage, + pub finish_reason: &'static str, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsResponseMessage { + pub role: &'static str, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +/// SSE streaming chunk format. 
+#[derive(Debug, Serialize)] +struct ChatCompletionsChunk { + id: String, + object: &'static str, + created: u64, + model: String, + choices: Vec, +} + +#[derive(Debug, Serialize)] +struct ChunkChoice { + index: u32, + delta: ChunkDelta, + finish_reason: Option<&'static str>, +} + +#[derive(Debug, Serialize)] +struct ChunkDelta { + #[serde(skip_serializing_if = "Option::is_none")] + role: Option<&'static str>, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, +} + +#[derive(Debug, Serialize)] +pub struct ModelsResponse { + pub object: &'static str, + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ModelObject { + pub id: String, + pub object: &'static str, + pub created: u64, + pub owned_by: String, +} + +// ══════════════════════════════════════════════════════════════════════════════ +// HANDLERS +// ══════════════════════════════════════════════════════════════════════════════ + +/// POST /v1/chat/completions — OpenAI-compatible chat endpoint. +pub async fn handle_v1_chat_completions( + State(state): State, + ConnectInfo(peer_addr): ConnectInfo, + headers: HeaderMap, + body: axum::body::Bytes, +) -> impl IntoResponse { + // ── Rate limit ── + let rate_key = + super::client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); + if !state.rate_limiter.allow_webhook(&rate_key) { + tracing::warn!("/v1/chat/completions rate limit exceeded"); + let err = serde_json::json!({ + "error": { + "message": "Rate limit exceeded. Please retry later.", + "type": "rate_limit_error", + "code": "rate_limit_exceeded" + } + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)).into_response(); + } + + // ── Bearer token auth (pairing) ── + if state.pairing.require_pairing() { + let auth = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + let token = auth.strip_prefix("Bearer ").unwrap_or(""); + if !state.pairing.is_authenticated(token) { + tracing::warn!("/v1/chat/completions: rejected — not paired / invalid bearer token"); + let err = serde_json::json!({ + "error": { + "message": "Invalid API key. 
Pair first via POST /pair, then use Authorization: Bearer ", + "type": "invalid_request_error", + "code": "invalid_api_key" + } + }); + return (StatusCode::UNAUTHORIZED, Json(err)).into_response(); + } + } + + // ── Enforce body size limit (since this route uses a separate limit) ── + if body.len() > CHAT_COMPLETIONS_MAX_BODY_SIZE { + let err = serde_json::json!({ + "error": { + "message": format!("Request body too large ({} bytes, max {})", body.len(), CHAT_COMPLETIONS_MAX_BODY_SIZE), + "type": "invalid_request_error", + "code": "request_too_large" + } + }); + return (StatusCode::PAYLOAD_TOO_LARGE, Json(err)).into_response(); + } + + // ── Parse body ── + let request: ChatCompletionsRequest = match serde_json::from_slice(&body) { + Ok(req) => req, + Err(e) => { + tracing::warn!("/v1/chat/completions JSON parse error: {e}"); + let err = serde_json::json!({ + "error": { + "message": format!("Invalid JSON body: {e}"), + "type": "invalid_request_error", + "code": "invalid_json" + } + }); + return (StatusCode::BAD_REQUEST, Json(err)).into_response(); + } + }; + + if request.messages.is_empty() { + let err = serde_json::json!({ + "error": { + "message": "messages array must not be empty", + "type": "invalid_request_error", + "code": "invalid_messages" + } + }); + return (StatusCode::BAD_REQUEST, Json(err)).into_response(); + } + + let model = request + .model + .as_deref() + .filter(|m| !m.is_empty()) + .unwrap_or(&state.model) + .to_string(); + let temperature = request.temperature.unwrap_or(state.temperature); + let stream = request.stream.unwrap_or(false); + + // Convert messages to provider format + let messages: Vec = request + .messages + .iter() + .map(|m| ChatMessage { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); + let started_at = Instant::now(); + + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmRequest { + provider: provider_label.clone(), + model: model.clone(), + messages_count: messages.len(), + }); + + if stream { + handle_streaming(state, messages, model, temperature, provider_label, started_at) + .into_response() + } else { + handle_non_streaming(state, messages, model, temperature, provider_label, started_at) + .await + .into_response() + } +} + +/// Non-streaming chat completions. 
+async fn handle_non_streaming( + state: AppState, + messages: Vec, + model: String, + temperature: f64, + provider_label: String, + started_at: Instant, +) -> impl IntoResponse { + match state + .provider + .chat_with_history(&messages, &model, temperature) + .await + { + Ok(response_text) => { + let duration = started_at.elapsed(); + record_success(&state, &provider_label, &model, duration); + + let completion_tokens = (response_text.len() / 4) as u32; + let prompt_tokens = messages.iter().map(|m| m.content.len() / 4).sum::() as u32; + + let response = ChatCompletionsResponse { + id: format!("chatcmpl-{}", Uuid::new_v4()), + object: "chat.completion", + created: unix_timestamp(), + model: model.clone(), + choices: vec![ChatCompletionsChoice { + index: 0, + message: ChatCompletionsResponseMessage { + role: "assistant", + content: response_text, + }, + finish_reason: "stop", + }], + usage: ChatCompletionsUsage { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }, + }; + + (StatusCode::OK, Json(serde_json::to_value(response).unwrap())).into_response() + } + Err(e) => { + let duration = started_at.elapsed(); + let sanitized = crate::providers::sanitize_api_error(&e.to_string()); + record_failure(&state, &provider_label, &model, duration, &sanitized); + + tracing::error!("/v1/chat/completions provider error: {sanitized}"); + let err = serde_json::json!({ + "error": { + "message": "LLM request failed", + "type": "server_error", + "code": "provider_error" + } + }); + (StatusCode::INTERNAL_SERVER_ERROR, Json(err)).into_response() + } + } +} + +/// Streaming chat completions via SSE. +fn handle_streaming( + state: AppState, + messages: Vec, + model: String, + temperature: f64, + provider_label: String, + started_at: Instant, +) -> impl IntoResponse { + let request_id = format!("chatcmpl-{}", Uuid::new_v4()); + let created = unix_timestamp(); + + if !state.provider.supports_streaming() { + // Provider doesn't support streaming — fall back to a single-chunk response + let model_clone = model.clone(); + let id = request_id.clone(); + + let stream = futures_util::stream::once(async move { + match state + .provider + .chat_with_history(&messages, &model_clone, temperature) + .await + { + Ok(text) => { + let duration = started_at.elapsed(); + record_success(&state, &provider_label, &model_clone, duration); + + let chunk = ChatCompletionsChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: model_clone, + choices: vec![ChunkChoice { + index: 0, + delta: ChunkDelta { + role: Some("assistant"), + content: Some(text), + }, + finish_reason: Some("stop"), + }], + }; + let json = + serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_string()); + let mut output = format!("data: {json}\n\n"); + output.push_str("data: [DONE]\n\n"); + Ok::<_, std::io::Error>(axum::body::Bytes::from(output)) + } + Err(e) => { + let duration = started_at.elapsed(); + let sanitized = crate::providers::sanitize_api_error(&e.to_string()); + record_failure(&state, &provider_label, &model_clone, duration, &sanitized); + + let error_json = serde_json::json!({"error": sanitized}); + let output = format!("data: {error_json}\n\ndata: [DONE]\n\n"); + Ok(axum::body::Bytes::from(output)) + } + } + }); + + return axum::response::Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "text/event-stream") + .header(header::CACHE_CONTROL, "no-cache") + .header(header::CONNECTION, "keep-alive") + .body(Body::from_stream(stream)) + .unwrap() + 
.into_response(); + } + + // Provider supports native streaming + let provider_stream = state.provider.stream_chat_with_history( + &messages, + &model, + temperature, + StreamOptions::new(true), + ); + + let model_for_stream = model.clone(); + let state_for_stream = state.clone(); + let provider_label_for_stream = provider_label.clone(); + let mut first_chunk = true; + let mut errored = false; + + let sse_stream = provider_stream + .map(move |result| { + match result { + Ok(chunk) if chunk.is_final => { + if !errored { + let duration = started_at.elapsed(); + record_success( + &state_for_stream, + &provider_label_for_stream, + &model_for_stream, + duration, + ); + } + Ok::<_, std::io::Error>(axum::body::Bytes::from("data: [DONE]\n\n")) + } + Ok(chunk) => { + let role = if first_chunk { + first_chunk = false; + Some("assistant") + } else { + None + }; + + let sse_chunk = ChatCompletionsChunk { + id: request_id.clone(), + object: "chat.completion.chunk", + created, + model: model_for_stream.clone(), + choices: vec![ChunkChoice { + index: 0, + delta: ChunkDelta { + role, + content: if chunk.delta.is_empty() { + None + } else { + Some(chunk.delta) + }, + }, + finish_reason: None, + }], + }; + let json = serde_json::to_string(&sse_chunk) + .unwrap_or_else(|_| "{}".to_string()); + Ok(axum::body::Bytes::from(format!("data: {json}\n\n"))) + } + Err(e) => { + errored = true; + let duration = started_at.elapsed(); + let msg = e.to_string(); + record_failure( + &state_for_stream, + &provider_label_for_stream, + &model_for_stream, + duration, + &msg, + ); + let error_json = serde_json::json!({"error": msg}); + Ok(axum::body::Bytes::from(format!("data: {error_json}\n\n"))) + } + } + }); + + axum::response::Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "text/event-stream") + .header(header::CACHE_CONTROL, "no-cache") + .header(header::CONNECTION, "keep-alive") + .body(Body::from_stream(sse_stream)) + .unwrap() + .into_response() +} + +/// GET /v1/models — List available models. 
+pub async fn handle_v1_models( + State(state): State, + headers: HeaderMap, +) -> impl IntoResponse { + // ── Bearer token auth (pairing) ── + if state.pairing.require_pairing() { + let auth = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + let token = auth.strip_prefix("Bearer ").unwrap_or(""); + if !state.pairing.is_authenticated(token) { + let err = serde_json::json!({ + "error": { + "message": "Invalid API key", + "type": "invalid_request_error", + "code": "invalid_api_key" + } + }); + return (StatusCode::UNAUTHORIZED, Json(err)); + } + } + + let response = ModelsResponse { + object: "list", + data: vec![ModelObject { + id: state.model.clone(), + object: "model", + created: unix_timestamp(), + owned_by: "zeroclaw".to_string(), + }], + }; + + (StatusCode::OK, Json(serde_json::to_value(response).unwrap())) +} + +// ══════════════════════════════════════════════════════════════════════════════ +// HELPERS +// ══════════════════════════════════════════════════════════════════════════════ + +fn unix_timestamp() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn record_success( + state: &AppState, + provider_label: &str, + model: &str, + duration: std::time::Duration, +) { + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmResponse { + provider: provider_label.to_string(), + model: model.to_string(), + duration, + success: true, + error_message: None, + }); + state + .observer + .record_metric(&crate::observability::traits::ObserverMetric::RequestLatency( + duration, + )); +} + +fn record_failure( + state: &AppState, + provider_label: &str, + model: &str, + duration: std::time::Duration, + error_message: &str, +) { + state + .observer + .record_event(&crate::observability::ObserverEvent::LlmResponse { + provider: provider_label.to_string(), + model: model.to_string(), + duration, + success: false, + error_message: Some(error_message.to_string()), + }); + state + .observer + .record_metric(&crate::observability::traits::ObserverMetric::RequestLatency( + duration, + )); + state + .observer + .record_event(&crate::observability::ObserverEvent::Error { + component: "gateway".to_string(), + message: error_message.to_string(), + }); +} + +// ══════════════════════════════════════════════════════════════════════════════ +// TESTS +// ══════════════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chat_completions_request_deserializes_minimal() { + let json = r#"{"messages": [{"role": "user", "content": "Hello"}]}"#; + let req: ChatCompletionsRequest = serde_json::from_str(json).unwrap(); + assert!(req.model.is_none()); + assert!(req.temperature.is_none()); + assert!(req.stream.is_none()); + assert_eq!(req.messages.len(), 1); + assert_eq!(req.messages[0].role, "user"); + assert_eq!(req.messages[0].content, "Hello"); + } + + #[test] + fn chat_completions_request_deserializes_full() { + let json = r#"{ + "model": "anthropic/claude-sonnet-4", + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"} + ], + "temperature": 0.5, + "stream": true + }"#; + let req: ChatCompletionsRequest = serde_json::from_str(json).unwrap(); + assert_eq!(req.model.as_deref(), Some("anthropic/claude-sonnet-4")); + assert_eq!(req.temperature, Some(0.5)); + assert_eq!(req.stream, Some(true)); + assert_eq!(req.messages.len(), 2); + } + + #[test] + fn 
chat_completions_response_serializes() { + let response = ChatCompletionsResponse { + id: "chatcmpl-test".to_string(), + object: "chat.completion", + created: 1234567890, + model: "test-model".to_string(), + choices: vec![ChatCompletionsChoice { + index: 0, + message: ChatCompletionsResponseMessage { + role: "assistant", + content: "Hello!".to_string(), + }, + finish_reason: "stop", + }], + usage: ChatCompletionsUsage { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }; + let json = serde_json::to_string(&response).unwrap(); + assert!(json.contains("chatcmpl-test")); + assert!(json.contains("chat.completion")); + assert!(json.contains("Hello!")); + assert!(json.contains("stop")); + } + + #[test] + fn models_response_serializes() { + let response = ModelsResponse { + object: "list", + data: vec![ModelObject { + id: "anthropic/claude-sonnet-4".to_string(), + object: "model", + created: 1234567890, + owned_by: "zeroclaw".to_string(), + }], + }; + let json = serde_json::to_string(&response).unwrap(); + assert!(json.contains("\"object\":\"list\"")); + assert!(json.contains("anthropic/claude-sonnet-4")); + assert!(json.contains("zeroclaw")); + } + + #[test] + fn streaming_chunk_serializes() { + let chunk = ChatCompletionsChunk { + id: "chatcmpl-test".to_string(), + object: "chat.completion.chunk", + created: 1234567890, + model: "test-model".to_string(), + choices: vec![ChunkChoice { + index: 0, + delta: ChunkDelta { + role: Some("assistant"), + content: Some("Hello".to_string()), + }, + finish_reason: None, + }], + }; + let json = serde_json::to_string(&chunk).unwrap(); + assert!(json.contains("chat.completion.chunk")); + assert!(json.contains("Hello")); + assert!(json.contains("assistant")); + } + + #[test] + fn streaming_chunk_omits_none_fields() { + let chunk = ChatCompletionsChunk { + id: "chatcmpl-test".to_string(), + object: "chat.completion.chunk", + created: 1234567890, + model: "test-model".to_string(), + choices: vec![ChunkChoice { + index: 0, + delta: ChunkDelta { + role: None, + content: None, + }, + finish_reason: None, + }], + }; + let json = serde_json::to_string(&chunk).unwrap(); + assert!(!json.contains("role")); + assert!(!json.contains("content")); + } + + #[test] + fn unix_timestamp_is_reasonable() { + let ts = unix_timestamp(); + // Should be after 2024-01-01 and before 2030-01-01 + assert!(ts > 1_704_067_200); + assert!(ts < 1_893_456_000); + } + + #[test] + fn body_size_limit_is_512kb() { + assert_eq!(CHAT_COMPLETIONS_MAX_BODY_SIZE, 524_288); + } +} From 6a92919b3882f71640b726c7bda6597cd5547436 Mon Sep 17 00:00:00 2001 From: Shadman Hossain Date: Sun, 22 Feb 2026 19:25:29 -0500 Subject: [PATCH 3/3] feat: add streaming support for AWS Bedrock ConverseStream API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement the streaming provider trait methods for Bedrock, enabling real-time token-by-token responses via the ConverseStream endpoint. Key implementation details: - Uses /model/{id}/converse-stream endpoint with SigV4 signing - Parses AWS binary event-stream format (application/vnd.amazon.eventstream) with a minimal parser (~60 lines) — no new crate dependencies needed - Handles contentBlockDelta events for text extraction, plus error and exception events - Uses mpsc channel + stream::unfold pattern (matching compatible.rs) - Clones credentials for async task ownership The binary event-stream parser extracts frame lengths, header sections (looking for :event-type), and payload bytes. 
CRC validation is skipped since TLS already provides integrity guarantees. Includes 10 new tests for URL formatting, binary parsing, and deserialization. Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 2 +- src/providers/bedrock.rs | 606 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 606 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a16de6f82..2472ca389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7107,7 +7107,7 @@ checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" dependencies = [ "either", "env_home", - "rustix 1.1.3", + "rustix", "winsafe", ] diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs index 7f8e9fcfe..4f20483b8 100644 --- a/src/providers/bedrock.rs +++ b/src/providers/bedrock.rs @@ -6,10 +6,12 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ProviderCapabilities, ToolCall as ProviderToolCall, ToolsPayload, + Provider, ProviderCapabilities, StreamChunk, StreamError, StreamOptions, StreamResult, + ToolCall as ProviderToolCall, ToolsPayload, }; use crate::tools::ToolSpec; use async_trait::async_trait; +use futures_util::{stream, StreamExt}; use hmac::{Hmac, Mac}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -382,6 +384,11 @@ impl BedrockProvider { format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse") } + /// Build the streaming request URL (converse-stream endpoint). + fn stream_endpoint_url(region: &str, model_id: &str) -> String { + format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse-stream") + } + /// Build the canonical URI for SigV4 signing. Must URI-encode the path /// per SigV4 spec: colons become `%3A`. AWS verifies the signature against /// the encoded form even though the wire request uses raw colons. @@ -390,6 +397,12 @@ impl BedrockProvider { format!("/model/{encoded}/converse") } + /// Canonical URI for the streaming endpoint. + fn stream_canonical_uri(model_id: &str) -> String { + let encoded = Self::encode_model_path(model_id); + format!("/model/{encoded}/converse-stream") + } + fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> { self.credentials.as_ref().ok_or_else(|| { anyhow::anyhow!( @@ -735,6 +748,244 @@ impl BedrockProvider { let converse_response: ConverseResponse = response.json().await?; Ok(converse_response) } + + /// Send a signed request to the ConverseStream endpoint and return the raw + /// response for event-stream parsing. 
+ async fn send_converse_stream_request( + &self, + credentials: &AwsCredentials, + model: &str, + request_body: &ConverseRequest, + ) -> anyhow::Result { + let payload = serde_json::to_vec(request_body)?; + let url = Self::stream_endpoint_url(&credentials.region, model); + let canonical_uri = Self::stream_canonical_uri(model); + let now = chrono::Utc::now(); + let host = credentials.host(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + + let mut headers_to_sign = vec![ + ("content-type".to_string(), "application/json".to_string()), + ("host".to_string(), host), + ("x-amz-date".to_string(), amz_date.clone()), + ]; + if let Some(ref token) = credentials.session_token { + headers_to_sign.push(("x-amz-security-token".to_string(), token.clone())); + } + headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0)); + + let authorization = build_authorization_header( + credentials, + "POST", + &canonical_uri, + "", + &headers_to_sign, + &payload, + &now, + ); + + let mut request = self + .http_client() + .post(&url) + .header("content-type", "application/json") + .header("x-amz-date", &amz_date) + .header("authorization", &authorization); + + if let Some(ref token) = credentials.session_token { + request = request.header("x-amz-security-token", token); + } + + let response = request.body(payload).send().await?; + + if !response.status().is_success() { + return Err(super::api_error("Bedrock", response).await); + } + + Ok(response) + } +} + +// ── AWS Event-Stream Binary Parser ────────────────────────────── +// +// Bedrock ConverseStream returns `application/vnd.amazon.eventstream` +// binary format. Each message is: +// [total_byte_length: u32 BE] +// [headers_byte_length: u32 BE] +// [prelude_crc: u32 BE] +// [headers: variable] +// [payload: variable] +// [message_crc: u32 BE] +// +// We skip CRC validation since the connection is already TLS-protected. + +/// Parse a single event-stream message from a byte buffer. +/// Returns `(event_type, payload_bytes, total_consumed)` or None if not enough data. +fn parse_event_stream_message(buf: &[u8]) -> Option<(String, Vec, usize)> { + // Minimum message: 4 (total_len) + 4 (header_len) + 4 (prelude_crc) + 4 (message_crc) = 16 + if buf.len() < 16 { + return None; + } + + let total_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < total_len { + return None; + } + + let headers_len = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize; + // prelude_crc is at bytes 8..12, skip it + let headers_start = 12; + let headers_end = headers_start + headers_len; + let payload_start = headers_end; + let payload_end = total_len - 4; // 4 bytes for message_crc + + // Parse headers to find :event-type + let mut event_type = String::new(); + let mut pos = headers_start; + while pos < headers_end { + if pos >= buf.len() { + break; + } + let name_len = buf[pos] as usize; + pos += 1; + if pos + name_len > buf.len() { + break; + } + let name = String::from_utf8_lossy(&buf[pos..pos + name_len]).to_string(); + pos += name_len; + if pos >= buf.len() { + break; + } + let value_type = buf[pos]; + pos += 1; + match value_type { + 7 => { + // String type + if pos + 2 > buf.len() { + break; + } + let val_len = u16::from_be_bytes([buf[pos], buf[pos + 1]]) as usize; + pos += 2; + if pos + val_len > buf.len() { + break; + } + let value = String::from_utf8_lossy(&buf[pos..pos + val_len]).to_string(); + pos += val_len; + if name == ":event-type" { + event_type = value; + } + } + _ => { + // Skip other header types. 
Most are fixed-size or have length prefixes. + // For safety, just break if we hit an unknown type. + break; + } + } + } + + let payload = if payload_start < payload_end && payload_end <= buf.len() { + buf[payload_start..payload_end].to_vec() + } else { + Vec::new() + }; + + Some((event_type, payload, total_len)) +} + +/// Bedrock converse-stream event payloads. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ContentBlockDelta { + #[allow(dead_code)] + content_block_index: Option, + delta: DeltaContent, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct DeltaContent { + #[serde(default)] + text: Option, +} + +/// Convert a Bedrock converse-stream byte response into a stream of `StreamChunk`s. +fn bedrock_event_stream_to_chunks( + response: reqwest::Response, + count_tokens: bool, +) -> stream::BoxStream<'static, StreamResult> { + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + + tokio::spawn(async move { + let mut buffer = Vec::new(); + let mut bytes_stream = response.bytes_stream(); + + while let Some(item) = bytes_stream.next().await { + match item { + Ok(bytes) => { + buffer.extend_from_slice(&bytes); + + // Try to parse complete messages from the buffer + loop { + match parse_event_stream_message(&buffer) { + Some((event_type, payload, consumed)) => { + buffer.drain(..consumed); + + match event_type.as_str() { + "contentBlockDelta" => { + if let Ok(delta) = + serde_json::from_slice::(&payload) + { + if let Some(text) = delta.delta.text { + if !text.is_empty() { + let mut chunk = StreamChunk::delta(text); + if count_tokens { + chunk = chunk.with_token_estimate(); + } + if tx.send(Ok(chunk)).await.is_err() { + return; + } + } + } + } + } + "messageStop" | "metadata" => { + // Will send final chunk after the loop + } + "messageStart" | "contentBlockStart" + | "contentBlockStop" => { + // Informational, skip + } + other if other.contains("Exception") + || other.contains("Error") => + { + let msg = String::from_utf8_lossy(&payload).to_string(); + let _ = tx + .send(Err(StreamError::Provider(format!( + "Bedrock stream error ({other}): {msg}" + )))) + .await; + return; + } + _ => { + // Unknown event type, skip + } + } + } + None => break, // Not enough data yet + } + } + } + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + break; + } + } + } + + // Send final chunk + let _ = tx.send(Ok(StreamChunk::final_chunk())).await; + }); + + stream::unfold(rx, |mut rx| async { rx.recv().await.map(|chunk| (chunk, rx)) }).boxed() } // ── Provider trait implementation ─────────────────────────────── @@ -864,6 +1115,208 @@ impl Provider for BedrockProvider { Ok(Self::parse_converse_response(response)) } + fn supports_streaming(&self) -> bool { + true + } + + fn stream_chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + options: StreamOptions, + ) -> stream::BoxStream<'static, StreamResult> { + let credentials = match self.require_credentials() { + Ok(c) => c, + Err(_) => { + return stream::once(async { + Err(StreamError::Provider( + "AWS Bedrock credentials not set".to_string(), + )) + }) + .boxed(); + } + }; + + let system = system_prompt.map(|text| { + let mut blocks = vec![SystemBlock::Text(TextBlock { + text: text.to_string(), + })]; + if Self::should_cache_system(text) { + blocks.push(SystemBlock::CachePoint(CachePointWrapper { + cache_point: CachePoint::default_cache(), + })); + } + blocks + }); + + let request = ConverseRequest { + system, + messages: 
vec![ConverseMessage { + role: "user".to_string(), + content: Self::parse_user_content_blocks(message), + }], + inference_config: Some(InferenceConfig { + max_tokens: DEFAULT_MAX_TOKENS, + temperature, + }), + tool_config: None, + }; + + // Clone what we need for the async block + let credentials = AwsCredentials { + access_key_id: credentials.access_key_id.clone(), + secret_access_key: credentials.secret_access_key.clone(), + session_token: credentials.session_token.clone(), + region: credentials.region.clone(), + }; + let model = model.to_string(); + let count_tokens = options.count_tokens; + let client = self.http_client(); + + // We need to send the request asynchronously, then convert the response to a stream. + // Use a channel to bridge the async setup with the streaming response. + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + + tokio::spawn(async move { + let payload = match serde_json::to_vec(&request) { + Ok(p) => p, + Err(e) => { + let _ = tx + .send(Err(StreamError::Provider(format!( + "Failed to serialize request: {e}" + )))) + .await; + return; + } + }; + + let url = BedrockProvider::stream_endpoint_url(&credentials.region, &model); + let canonical_uri = BedrockProvider::stream_canonical_uri(&model); + let now = chrono::Utc::now(); + let host = credentials.host(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + + let mut headers_to_sign = vec![ + ("content-type".to_string(), "application/json".to_string()), + ("host".to_string(), host), + ("x-amz-date".to_string(), amz_date.clone()), + ]; + if let Some(ref token) = credentials.session_token { + headers_to_sign.push(("x-amz-security-token".to_string(), token.clone())); + } + headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0)); + + let authorization = build_authorization_header( + &credentials, + "POST", + &canonical_uri, + "", + &headers_to_sign, + &payload, + &now, + ); + + let mut req = client + .post(&url) + .header("content-type", "application/json") + .header("x-amz-date", &amz_date) + .header("authorization", &authorization); + + if let Some(ref token) = credentials.session_token { + req = req.header("x-amz-security-token", token); + } + + let response = match req.body(payload).send().await { + Ok(r) => r, + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let body = response + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + let _ = tx + .send(Err(StreamError::Provider(format!( + "Bedrock stream request failed ({status}): {body}" + )))) + .await; + return; + } + + // Parse the binary event stream + let mut buffer = Vec::new(); + let mut bytes_stream = response.bytes_stream(); + + while let Some(item) = bytes_stream.next().await { + match item { + Ok(bytes) => { + buffer.extend_from_slice(&bytes); + + loop { + match parse_event_stream_message(&buffer) { + Some((event_type, payload_bytes, consumed)) => { + buffer.drain(..consumed); + + match event_type.as_str() { + "contentBlockDelta" => { + if let Ok(delta) = + serde_json::from_slice::( + &payload_bytes, + ) + { + if let Some(text) = delta.delta.text { + if !text.is_empty() { + let mut chunk = StreamChunk::delta(text); + if count_tokens { + chunk = chunk.with_token_estimate(); + } + if tx.send(Ok(chunk)).await.is_err() { + return; + } + } + } + } + } + other + if other.contains("Exception") + || other.contains("Error") => + { + let msg = + String::from_utf8_lossy(&payload_bytes).to_string(); + let _ = tx + 
.send(Err(StreamError::Provider(format!( + "Bedrock stream error ({other}): {msg}" + )))) + .await; + return; + } + _ => {} // messageStart, contentBlockStart, contentBlockStop, messageStop, metadata — skip + } + } + None => break, + } + } + } + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + break; + } + } + } + + let _ = tx.send(Ok(StreamChunk::final_chunk())).await; + }); + + stream::unfold(rx, |mut rx| async { rx.recv().await.map(|chunk| (chunk, rx)) }).boxed() + } + async fn warmup(&self) -> anyhow::Result<()> { if let Some(ref creds) = self.credentials { let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region); @@ -1334,4 +1787,155 @@ mod tests { let caps = provider.capabilities(); assert!(caps.native_tool_calling); } + + // ── Streaming tests ────────────────────────────────────────── + + #[test] + fn supports_streaming_returns_true() { + let provider = BedrockProvider { credentials: None }; + assert!(provider.supports_streaming()); + } + + #[test] + fn stream_endpoint_url_formats_correctly() { + let url = BedrockProvider::stream_endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6"); + assert_eq!( + url, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse-stream" + ); + } + + #[test] + fn stream_canonical_uri_encodes_colon() { + let uri = + BedrockProvider::stream_canonical_uri("anthropic.claude-3-5-haiku-20241022-v1:0"); + assert_eq!( + uri, + "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream" + ); + } + + #[test] + fn stream_canonical_uri_no_colon() { + let uri = BedrockProvider::stream_canonical_uri("anthropic.claude-sonnet-4-6"); + assert_eq!(uri, "/model/anthropic.claude-sonnet-4-6/converse-stream"); + } + + // ── Event-stream parser tests ──────────────────────────────── + + /// Helper: build a minimal AWS event-stream message with a string `:event-type` header. 
+ fn build_event_stream_message(event_type: &str, payload: &[u8]) -> Vec { + // Header: `:event-type` as string (type 7) + let header_name = b":event-type"; + let header_name_len = header_name.len() as u8; + let event_type_bytes = event_type.as_bytes(); + let event_type_len = event_type_bytes.len() as u16; + + // Header bytes: 1 (name_len) + name + 1 (type=7) + 2 (val_len) + val + let headers_len = 1 + header_name.len() + 1 + 2 + event_type_bytes.len(); + // Total: 4 (total_len) + 4 (headers_len) + 4 (prelude_crc) + headers + payload + 4 (message_crc) + let total_len = 12 + headers_len + payload.len() + 4; + + let mut msg = Vec::with_capacity(total_len); + msg.extend_from_slice(&(total_len as u32).to_be_bytes()); + msg.extend_from_slice(&(headers_len as u32).to_be_bytes()); + msg.extend_from_slice(&0u32.to_be_bytes()); // prelude_crc (skipped) + + // Write header + msg.push(header_name_len); + msg.extend_from_slice(header_name); + msg.push(7); // string type + msg.extend_from_slice(&event_type_len.to_be_bytes()); + msg.extend_from_slice(event_type_bytes); + + // Write payload + msg.extend_from_slice(payload); + + // Write message CRC (skipped, just zeros) + msg.extend_from_slice(&0u32.to_be_bytes()); + + msg + } + + #[test] + fn parse_event_stream_message_content_block_delta() { + let payload = br#"{"contentBlockIndex":0,"delta":{"text":"Hello"}}"#; + let msg = build_event_stream_message("contentBlockDelta", payload); + + let result = parse_event_stream_message(&msg); + assert!(result.is_some()); + let (event_type, parsed_payload, consumed) = result.unwrap(); + assert_eq!(event_type, "contentBlockDelta"); + assert_eq!(consumed, msg.len()); + + let delta: ContentBlockDelta = serde_json::from_slice(&parsed_payload).unwrap(); + assert_eq!(delta.delta.text.as_deref(), Some("Hello")); + } + + #[test] + fn parse_event_stream_message_stop() { + let payload = br#"{"stopReason":"end_turn"}"#; + let msg = build_event_stream_message("messageStop", payload); + + let result = parse_event_stream_message(&msg); + assert!(result.is_some()); + let (event_type, _, _) = result.unwrap(); + assert_eq!(event_type, "messageStop"); + } + + #[test] + fn parse_event_stream_message_insufficient_data() { + // Only 10 bytes — not enough for even the minimum 16-byte message + let buf = vec![0u8; 10]; + assert!(parse_event_stream_message(&buf).is_none()); + } + + #[test] + fn parse_event_stream_message_incomplete_message() { + let payload = br#"{"text":"Hi"}"#; + let msg = build_event_stream_message("contentBlockDelta", payload); + + // Truncate to simulate incomplete data + let truncated = &msg[..msg.len() - 5]; + assert!(parse_event_stream_message(truncated).is_none()); + } + + #[test] + fn parse_event_stream_multiple_messages() { + let payload1 = br#"{"contentBlockIndex":0,"delta":{"text":"Hello"}}"#; + let payload2 = br#"{"contentBlockIndex":0,"delta":{"text":" World"}}"#; + let msg1 = build_event_stream_message("contentBlockDelta", payload1); + let msg2 = build_event_stream_message("contentBlockDelta", payload2); + + let mut buf = Vec::new(); + buf.extend_from_slice(&msg1); + buf.extend_from_slice(&msg2); + + // Parse first message + let (event_type1, p1, consumed1) = parse_event_stream_message(&buf).unwrap(); + assert_eq!(event_type1, "contentBlockDelta"); + let delta1: ContentBlockDelta = serde_json::from_slice(&p1).unwrap(); + assert_eq!(delta1.delta.text.as_deref(), Some("Hello")); + + // Parse second message from remainder + let (event_type2, p2, _) = parse_event_stream_message(&buf[consumed1..]).unwrap(); + 
assert_eq!(event_type2, "contentBlockDelta"); + let delta2: ContentBlockDelta = serde_json::from_slice(&p2).unwrap(); + assert_eq!(delta2.delta.text.as_deref(), Some(" World")); + } + + #[test] + fn content_block_delta_deserializes() { + let json = r#"{"contentBlockIndex":0,"delta":{"text":"Hello from Bedrock"}}"#; + let delta: ContentBlockDelta = serde_json::from_str(json).unwrap(); + assert_eq!(delta.content_block_index, Some(0)); + assert_eq!(delta.delta.text.as_deref(), Some("Hello from Bedrock")); + } + + #[test] + fn content_block_delta_empty_text() { + let json = r#"{"contentBlockIndex":0,"delta":{}}"#; + let delta: ContentBlockDelta = serde_json::from_str(json).unwrap(); + assert!(delta.delta.text.is_none()); + } }
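
Illustrative sketch only — not part of the patches above. A minimal client-side check of the /v1/chat/completions endpoint added in patch 2 could look like the following; the host, port, and bearer token are placeholders, since the real values depend on gateway configuration and pairing, which this series does not define.

// Hypothetical standalone client for manual verification of the new endpoint.
// Assumes reqwest (with the "json" feature), tokio, anyhow, and serde_json.
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = reqwest::Client::new();

    // Standard OpenAI-style request body, matching ChatCompletionsRequest.
    let request = json!({
        "model": "anthropic/claude-sonnet-4",
        "messages": [{ "role": "user", "content": "Hello from an OpenAI-compatible client" }],
        "stream": false
    });

    // Placeholder address and pairing token — adjust to the running gateway.
    let response = client
        .post("http://127.0.0.1:8080/v1/chat/completions")
        .bearer_auth("PAIRING-TOKEN-PLACEHOLDER")
        .json(&request)
        .send()
        .await?;

    println!("status: {}", response.status());
    println!("{}", response.text().await?);
    Ok(())
}

With "stream": true the same endpoint returns text/event-stream data: chunks terminated by a [DONE] sentinel, which the standard OpenAI client libraries consume directly.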