Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions codex-rs/app-server/src/codex_message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3509,7 +3509,7 @@ impl CodexMessageProcessor {
INVALID_REQUEST_ERROR_CODE,
format!("failed to load rollout `{}`: {err}", rollout_path.display()),
),
CodexErr::InvalidRequest(message) => (INVALID_REQUEST_ERROR_CODE, message),
CodexErr::InvalidRequest(err) => (INVALID_REQUEST_ERROR_CODE, err.message),
_ => (INTERNAL_ERROR_CODE, format!("error forking thread: {err}")),
};
let error = JSONRPCErrorError {
Expand Down Expand Up @@ -5837,6 +5837,14 @@ impl CodexMessageProcessor {
break;
}
};
let event = event.clone();
if matches!(
event.msg,
EventMsg::ResponsesApiRequestId(_)
| EventMsg::ResponsesApiResponseId(_)
) {
continue;
}

// For now, we send a notification for every event,
// JSON-serializing the `Event` as-is, but these should
Expand Down Expand Up @@ -5875,7 +5883,9 @@ impl CodexMessageProcessor {
let subscribed_connection_ids = thread_state_manager
.subscribed_connection_ids(conversation_id)
.await;
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
if let EventMsg::RawResponseItem(_) = &event.msg
&& !raw_events_enabled
{
continue;
}

Expand Down
14 changes: 13 additions & 1 deletion codex-rs/codex-api/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use http::HeaderMap;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
Expand Down Expand Up @@ -54,7 +55,9 @@ pub struct MemorySummarizeOutput {

#[derive(Debug)]
pub enum ResponseEvent {
Created,
Created {
response_id: Option<String>,
},
OutputItemDone(ResponseItem),
OutputItemAdded(ResponseItem),
/// Emitted when the server includes `OpenAI-Model` on the stream response.
Expand Down Expand Up @@ -84,6 +87,14 @@ pub enum ResponseEvent {
ModelsEtag(String),
}

pub(crate) fn extract_request_id(headers: &HeaderMap) -> Option<String> {
headers
.get("x-request-id")
.or_else(|| headers.get("x-oai-request-id"))
.and_then(|value| value.to_str().ok())
.map(str::to_string)
}

#[derive(Debug, Serialize, Clone, PartialEq)]
pub struct Reasoning {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -238,6 +249,7 @@ pub fn create_text_param_for_request(

pub struct ResponseStream {
pub rx_event: mpsc::Receiver<Result<ResponseEvent, ApiError>>,
pub request_id_for_rollout_log: Option<String>,
}

impl Stream for ResponseStream {
Expand Down
8 changes: 7 additions & 1 deletion codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,13 @@ impl ResponsesWebsocketConnection {
}
});

Ok(ResponseStream { rx_event })
Ok(ResponseStream {
rx_event,
// Websocket upgrade response headers are scoped to the connection, not
// individual `stream_request` calls, so they cannot be used as per-turn
// request IDs in rollout `ResponsesApiRequestId` events.
request_id_for_rollout_log: None,
})
}
}

Expand Down
50 changes: 44 additions & 6 deletions codex-rs/codex-api/src/sse/responses.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::extract_request_id;
use crate::error::ApiError;
use crate::rate_limits::parse_all_rate_limits;
use crate::telemetry::SseTelemetry;
Expand Down Expand Up @@ -46,7 +47,10 @@ pub fn stream_from_fixture(
let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None));
Ok(ResponseStream { rx_event })
Ok(ResponseStream {
rx_event,
request_id_for_rollout_log: None,
})
}

pub fn spawn_response_stream(
Expand All @@ -55,6 +59,7 @@ pub fn spawn_response_stream(
telemetry: Option<Arc<dyn SseTelemetry>>,
turn_state: Option<Arc<OnceLock<String>>>,
) -> ResponseStream {
let request_id = extract_request_id(&stream_response.headers);
let rate_limit_snapshots = parse_all_rate_limits(&stream_response.headers);
let models_etag = stream_response
.headers
Expand Down Expand Up @@ -97,7 +102,10 @@ pub fn spawn_response_stream(
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});

ResponseStream { rx_event }
ResponseStream {
rx_event,
request_id_for_rollout_log: request_id,
}
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -263,7 +271,13 @@ pub fn process_responses_event(
}
"response.created" => {
if event.response.is_some() {
return Ok(Some(ResponseEvent::Created {}));
let response_id = event
.response
.as_ref()
.and_then(|response| response.get("id"))
.and_then(Value::as_str)
.map(str::to_string);
return Ok(Some(ResponseEvent::Created { response_id }));
}
}
"response.failed" => {
Expand Down Expand Up @@ -771,7 +785,7 @@ mod tests {
}

fn is_created(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::Created)
matches!(ev, ResponseEvent::Created { .. })
}
fn is_output(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::OutputItemDone(_))
Expand Down Expand Up @@ -869,6 +883,24 @@ mod tests {
}
}

#[tokio::test]
async fn spawn_response_stream_emits_request_id_header() {
let mut headers = HeaderMap::new();
headers.insert("x-request-id", HeaderValue::from_static("req_123"));
let bytes = stream::iter(Vec::<Result<Bytes, TransportError>>::new());
let stream_response = StreamResponse {
status: StatusCode::OK,
headers,
bytes: Box::pin(bytes),
};

let stream = spawn_response_stream(stream_response, idle_timeout(), None, None);
assert_eq!(
stream.request_id_for_rollout_log.as_deref(),
Some("req_123")
);
}

#[tokio::test]
async fn process_sse_ignores_response_model_field_in_payload() {
let events = run_sse(vec![
Expand All @@ -890,7 +922,10 @@ mod tests {
.await;

assert_eq!(events.len(), 2);
assert_matches!(&events[0], ResponseEvent::Created);
assert_matches!(
&events[0],
ResponseEvent::Created { response_id } if response_id.as_deref() == Some("resp-1")
);
assert_matches!(
&events[1],
ResponseEvent::Completed {
Expand Down Expand Up @@ -926,7 +961,10 @@ mod tests {
&events[0],
ResponseEvent::ServerModel(model) if model == CYBER_RESTRICTED_MODEL_FOR_TESTS
);
assert_matches!(&events[1], ResponseEvent::Created);
assert_matches!(
&events[1],
ResponseEvent::Created { response_id } if response_id.as_deref() == Some("resp-1")
);
assert_matches!(
&events[2],
ResponseEvent::Completed {
Expand Down
39 changes: 37 additions & 2 deletions codex-rs/core/src/api_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde::Deserialize;

use crate::auth::CodexAuth;
use crate::error::CodexErr;
use crate::error::InvalidRequestError;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use crate::error::UsageLimitReachedError;
Expand All @@ -31,7 +32,7 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
cf_ray: None,
request_id: None,
}),
ApiError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
ApiError::InvalidRequest { message } => CodexErr::InvalidRequest(message.into()),
ApiError::Transport(transport) => match transport {
TransportError::Http {
status,
Expand Down Expand Up @@ -60,7 +61,10 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
{
CodexErr::InvalidImageRequest()
} else {
CodexErr::InvalidRequest(body_text)
CodexErr::InvalidRequest(InvalidRequestError {
message: body_text,
request_id: extract_request_id(headers.as_ref()),
})
}
} else if status == http::StatusCode::INTERNAL_SERVER_ERROR {
CodexErr::InternalServerError
Expand Down Expand Up @@ -217,6 +221,37 @@ mod tests {
None
);
}

#[test]
fn map_api_error_preserves_request_id_for_bad_request() {
let mut headers = HeaderMap::new();
headers.insert(
REQUEST_ID_HEADER,
http::HeaderValue::from_static("req_bad_123"),
);
let body = serde_json::json!({
"error": {
"message": "bad request"
}
})
.to_string();

let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: Some(headers),
body: Some(body),
}));

let CodexErr::InvalidRequest(invalid_request) = err else {
panic!("expected CodexErr::InvalidRequest, got {err:?}");
};
assert_eq!(
invalid_request.message,
r#"{"error":{"message":"bad request"}}"#
);
assert_eq!(invalid_request.request_id.as_deref(), Some("req_bad_123"));
}
}

fn extract_request_tracking_id(headers: Option<&HeaderMap>) -> Option<String> {
Expand Down
21 changes: 11 additions & 10 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,18 +1094,13 @@ fn build_responses_headers(
headers
}

fn map_response_stream<S>(
api_stream: S,
fn map_response_stream(
api_stream: codex_api::ResponseStream,
otel_manager: OtelManager,
) -> (ResponseStream, oneshot::Receiver<LastResponse>)
where
S: futures::Stream<Item = std::result::Result<ResponseEvent, ApiError>>
+ Unpin
+ Send
+ 'static,
{
) -> (ResponseStream, oneshot::Receiver<LastResponse>) {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
let request_id_for_rollout_log = api_stream.request_id_for_rollout_log.clone();

tokio::spawn(async move {
let mut logged_error = false;
Expand Down Expand Up @@ -1173,7 +1168,13 @@ where
}
});

(ResponseStream { rx_event }, rx_last_response)
(
ResponseStream {
rx_event,
request_id_for_rollout_log,
},
rx_last_response,
)
}

/// Handles a 401 response by optionally refreshing ChatGPT tokens once.
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/client_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ pub(crate) mod tools {

pub struct ResponseStream {
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
pub(crate) request_id_for_rollout_log: Option<String>,
}

impl Stream for ResponseStream {
Expand Down
43 changes: 40 additions & 3 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ use codex_protocol::protocol::HasLegacyEvent;
use codex_protocol::protocol::ItemCompletedEvent;
use codex_protocol::protocol::ItemStartedEvent;
use codex_protocol::protocol::RawResponseItemEvent;
use codex_protocol::protocol::ResponsesApiRequestIdEvent;
use codex_protocol::protocol::ResponsesApiResponseIdEvent;
use codex_protocol::protocol::ReviewRequest;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
Expand Down Expand Up @@ -5986,6 +5988,8 @@ fn realtime_text_for_event(msg: &EventMsg) -> Option<String> {
| EventMsg::ShutdownComplete
| EventMsg::EnteredReviewMode(_)
| EventMsg::ExitedReviewMode(_)
| EventMsg::ResponsesApiRequestId(_)
| EventMsg::ResponsesApiResponseId(_)
| EventMsg::RawResponseItem(_)
| EventMsg::ItemStarted(_)
| EventMsg::AgentMessageContentDelta(_)
Expand Down Expand Up @@ -6308,7 +6312,7 @@ async fn try_run_sampling_request(
auth_mode = sess.services.auth_manager.auth_mode(),
features = sess.features.enabled_features(),
);
let mut stream = client_session
let stream_result = client_session
.stream(
prompt,
&turn_context.model_info,
Expand All @@ -6320,7 +6324,30 @@ async fn try_run_sampling_request(
)
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)
.await??;
.await;
let mut stream = match stream_result {
Ok(Ok(stream)) => stream,
Ok(Err(err)) => {
if let Some(request_id) = err.request_id() {
sess.send_event(
&turn_context,
EventMsg::ResponsesApiRequestId(ResponsesApiRequestIdEvent {
request_id: request_id.to_string(),
}),
)
.await;
}
return Err(err);
}
Err(codex_async_utils::CancelErr::Cancelled) => return Err(CodexErr::TurnAborted),
};
if let Some(request_id) = stream.request_id_for_rollout_log.clone() {
sess.send_event(
&turn_context,
EventMsg::ResponsesApiRequestId(ResponsesApiRequestIdEvent { request_id }),
)
.await;
}

let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Expand Down Expand Up @@ -6372,7 +6399,17 @@ async fn try_run_sampling_request(
.record_responses(&handle_responses, &event);

match event {
ResponseEvent::Created => {}
ResponseEvent::Created { response_id } => {
if let Some(response_id) = response_id {
sess.send_event(
&turn_context,
EventMsg::ResponsesApiResponseId(ResponsesApiResponseIdEvent {
response_id,
}),
)
.await;
}
}
ResponseEvent::OutputItemDone(item) => {
let previously_active_item = active_item.take();
if let Some(previous) = previously_active_item.as_ref()
Expand Down
Loading
Loading