diff --git a/atoma-proxy/src/server/streamer.rs b/atoma-proxy/src/server/streamer.rs index 8e1ec21f..b44cd83f 100644 --- a/atoma-proxy/src/server/streamer.rs +++ b/atoma-proxy/src/server/streamer.rs @@ -270,176 +270,205 @@ impl Stream for Streamer { } match self.stream.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - if self.status != StreamStatus::Started { - self.status = StreamStatus::Started; - } - - if chunk.as_ref() == KEEP_ALIVE_CHUNK { - return Poll::Pending; - } - - let chunk_str = match std::str::from_utf8(&chunk) { - Ok(v) => v, - Err(e) => { + Poll::Ready(Some(Ok(chunk))) => match self.handle_stream_chunk(chunk) { + Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(event))), + Poll::Ready(Some(Err(e))) => { + self.status = StreamStatus::Failed(e.to_string()); + if let Err(e) = update_state_manager( + &self.state_manager_sender, + self.stack_small_id, + self.estimated_total_tokens, + self.estimated_total_tokens, + &self.endpoint, + ) { error!( - target = "atoma-service", + target = "atoma-service-streamer", level = "error", - "Invalid UTF-8 sequence: {}", + "Error updating stack num tokens: {}", e ); - return Poll::Ready(Some(Err(Error::new(format!( - "Invalid UTF-8 sequence: {}", - e - ))))); } - }; + Poll::Ready(Some(Err(e))) + } + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + }, + Poll::Ready(Some(Err(e))) => { + self.status = StreamStatus::Failed(e.to_string()); + Poll::Ready(None) + } + Poll::Ready(None) => { + if !self.chunk_buffer.is_empty() { + error!( + target = "atoma-service-streamer", + level = "error", + "Stream ended, but the chunk buffer is not empty, this should not happen: {}", + self.chunk_buffer + ); + } + self.status = StreamStatus::Completed; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl Streamer { + #[instrument( + level = "info", + skip(self, chunk), + fields(endpoint = "handle_stream_chunk",) + )] + fn handle_stream_chunk(&mut self, chunk: Bytes) -> 
Poll<Option<Result<Event, Error>>> {
+        if self.status != StreamStatus::Started {
+            self.status = StreamStatus::Started;
+        }
 
-        let chunk_str = chunk_str.strip_prefix(DATA_PREFIX).unwrap_or(chunk_str);
+        if chunk.as_ref() == KEEP_ALIVE_CHUNK {
+            return Poll::Pending;
+        }
 
-        if chunk_str.starts_with(DONE_CHUNK) {
-            // This is the last chunk, meaning the inference streaming is complete
-            self.status = StreamStatus::Completed;
-            return Poll::Ready(None);
+        let chunk_str = match std::str::from_utf8(&chunk) {
+            Ok(v) => v,
+            Err(e) => {
+                error!(
+                    target = "atoma-service",
+                    level = "error",
+                    "Invalid UTF-8 sequence: {}",
+                    e
+                );
+                return Poll::Ready(Some(Err(Error::new(format!(
+                    "Invalid UTF-8 sequence: {}",
+                    e
+                )))));
+            }
+        };
+
+        let chunk_str = chunk_str.strip_prefix(DATA_PREFIX).unwrap_or(chunk_str);
+
+        if chunk_str.starts_with(DONE_CHUNK) {
+            // This is the last chunk, meaning the inference streaming is complete
+            self.status = StreamStatus::Completed;
+            return Poll::Ready(None);
+        }
+
+        let chunk = match serde_json::from_str::<Value>(chunk_str) {
+            Ok(chunk) => {
+                if !self.chunk_buffer.is_empty() {
+                    error!(
+                        target = "atoma-service-streamer",
+                        level = "error",
+                        "Error parsing previous chunk(s), as chunk buffer is not empty: {}",
+                        self.chunk_buffer
+                    );
+                    self.chunk_buffer.clear();
+                }
+                chunk
+            }
+            Err(e) => {
+                if e.is_eof() {
+                    info!(
+                        target = "atoma-service-streamer",
+                        parse_chunk = "eof_chunk",
+                        "EOF reached, pushing chunk to buffer: {}",
+                        chunk_str
+                    );
+                    self.chunk_buffer.push_str(chunk_str);
+                    return Poll::Pending;
+                }
+
+                if self.chunk_buffer.is_empty() {
+                    error!(
+                        target = "atoma-service-streamer",
+                        level = "error",
+                        "Error parsing chunk {chunk_str}: {}",
+                        e
+                    );
+                    return Poll::Ready(Some(Err(Error::new(format!(
+                        "Error parsing chunk: {}",
+                        e
+                    )))));
                 }
 
-                let chunk = match serde_json::from_str::<Value>(chunk_str) {
+                self.chunk_buffer.push_str(chunk_str);
+                match serde_json::from_str::<Value>(&self.chunk_buffer) {
                     Ok(chunk) => {
-                        if !self.chunk_buffer.is_empty() {
-                            error!(
-                                
target = "atoma-service-streamer",
-                                level = "error",
-                                "Error parsing previous chunk(s), as chunk buffer is not empty: {}",
-                                self.chunk_buffer
-                            );
-                            self.chunk_buffer.clear();
-                        }
+                        info!(
+                            target = "atoma-service-streamer",
+                            parse_chunk = "eof_chunk",
+                            "Chunk parsed successfully, clearing buffer: {}",
+                            self.chunk_buffer
+                        );
+                        self.chunk_buffer.clear();
                         chunk
                     }
                     Err(e) => {
                         if e.is_eof() {
-                            info!(
-                                target = "atoma-service-streamer",
-                                parse_chunk = "eof_chunk",
-                                "EOF reached, pushing chunk to buffer: {}",
-                                chunk_str
-                            );
-                            self.chunk_buffer.push_str(chunk_str);
+                            // NOTE: We don't need to push the chunk to the buffer, as it was pushed already
                             return Poll::Pending;
                         }
-
-                        if self.chunk_buffer.is_empty() {
-                            error!(
-                                target = "atoma-service-streamer",
-                                level = "error",
-                                "Error parsing chunk {chunk_str}: {}",
-                                e
-                            );
-                            return Poll::Ready(Some(Err(Error::new(format!(
-                                "Error parsing chunk: {}",
-                                e
-                            )))));
-                        }
-
-                        self.chunk_buffer.push_str(chunk_str);
-                        match serde_json::from_str::<Value>(&self.chunk_buffer) {
-                            Ok(chunk) => {
-                                info!(
-                                    target = "atoma-service-streamer",
-                                    parse_chunk = "eof_chunk",
-                                    "Chunk parsed successfully, clearing buffer: {}",
-                                    self.chunk_buffer
-                                );
-                                self.chunk_buffer.clear();
-                                chunk
-                            }
-                            Err(e) => {
-                                if e.is_eof() {
-                                    // NOTE: We don't need to push the chunk to the buffer, as it was pushed already
-                                    return Poll::Pending;
-                                }
+                            error!(
+                                target = "atoma-service-streamer",
+                                level = "error",
+                                "Error parsing chunk {}: {}",
+                                self.chunk_buffer,
+                                e
+                            );
+                            self.chunk_buffer.clear();
+                            return Poll::Ready(Some(Err(Error::new(format!(
+                                "Error parsing chunk: {}",
+                                e
+                            )))));
                         }
-                                error!(
-                                    target = "atoma-service-streamer",
-                                    level = "error",
-                                    "Error parsing chunk {}: {}",
-                                    self.chunk_buffer,
-                                    e
-                                );
-                                self.chunk_buffer.clear();
-                                return Poll::Ready(Some(Err(Error::new(format!(
-                                    "Error parsing chunk: {}",
-                                    e
-                                )))));
-                            }
-                        }
                     }
-                };
-
-                if self.start_decode.is_none() {
-                    self.start_decode = Some(Instant::now());
-                    let latency = 
self.start.elapsed().as_secs_f64();
-                    self.state_manager_sender
-                        .send(AtomaAtomaStateManagerEvent::UpdateNodeLatencyPerformance {
-                            timestamp: DateTime::<Utc>::from(std::time::SystemTime::now()), // Convert to chrono::DateTime<Utc>
-                            node_small_id: self.node_id,
-                            latency,
-                        })
-                        .map_err(|e| {
-                            error!(
-                                target = "atoma-service-streamer",
-                                level = "error",
-                                "Error updating node latency performance: {}",
-                                e
-                            );
-                            Error::new(format!("Error updating node latency performance: {}", e))
-                        })?;
                 }
 
-                if self.endpoint == CHAT_COMPLETIONS_PATH {
-                    let choices = match chunk.get(CHOICES).and_then(|choices| choices.as_array()) {
-                        Some(choices) => choices,
-                        None => {
-                            error!(
-                                target = "atoma-service-streamer",
-                                level = "error",
-                                "Error getting choices from chunk"
-                            );
-                            return Poll::Ready(Some(Err(Error::new(
-                                "Error getting choices from chunk",
-                            ))));
-                        }
-                    };
-
-                    if choices.is_empty() {
-                        if let Some(usage) = chunk.get(USAGE) {
-                            self.status = StreamStatus::Completed;
-                            self.handle_final_chunk(usage)?;
-                        }
-                    }
-                } else if let Some(usage) = chunk.get(USAGE) {
-                    self.status = StreamStatus::Completed;
-                    self.handle_final_chunk(usage)?;
-                }
+            }
+        };
 
+        if self.start_decode.is_none() {
+            self.start_decode = Some(Instant::now());
+            let latency = self.start.elapsed().as_secs_f64();
+            self.state_manager_sender
+                .send(AtomaAtomaStateManagerEvent::UpdateNodeLatencyPerformance {
+                    timestamp: DateTime::<Utc>::from(std::time::SystemTime::now()), // Convert to chrono::DateTime<Utc>
+                    node_small_id: self.node_id,
+                    latency,
+                })
+                .map_err(|e| {
+                    error!(
+                        target = "atoma-service-streamer",
+                        level = "error",
+                        "Error updating node latency performance: {}",
+                        e
+                    );
+                    Error::new(format!("Error updating node latency performance: {}", e))
+                })?;
+        }
 
-                Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
-            }
-            Poll::Ready(Some(Err(e))) => {
-                self.status = StreamStatus::Failed(e.to_string());
-                Poll::Ready(None)
-            }
-            Poll::Ready(None) => {
-                if !self.chunk_buffer.is_empty() {
+        if 
self.endpoint == CHAT_COMPLETIONS_PATH { + let choices = match chunk.get(CHOICES).and_then(|choices| choices.as_array()) { + Some(choices) => choices, + None => { error!( target = "atoma-service-streamer", level = "error", - "Stream ended, but the chunk buffer is not empty, this should not happen: {}", - self.chunk_buffer + "Error getting choices from chunk" ); + return Poll::Ready(Some(Err(Error::new("Error getting choices from chunk")))); + } + }; + + if choices.is_empty() { + if let Some(usage) = chunk.get(USAGE) { + self.status = StreamStatus::Completed; + self.handle_final_chunk(usage)?; } - self.status = StreamStatus::Completed; - Poll::Ready(None) } - Poll::Pending => Poll::Pending, + } else if let Some(usage) = chunk.get(USAGE) { + self.status = StreamStatus::Completed; + self.handle_final_chunk(usage)?; } + + Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?))) } }