diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index f85f99c1c0..b4d479a91e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -25,8 +25,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import io.micrometer.observation.Observation; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; @@ -75,10 +73,11 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; +import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.SignalType; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} @@ -278,8 +277,8 @@ public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); - Flux completionChunks = this.retryTemplate - .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt))); + Flux completionChunks = this.openAiApi.chatCompletionStream(request, + getAdditionalHttpHeaders(prompt)); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. @@ -354,8 +353,9 @@ public Flux stream(Prompt prompt) { }) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(flux, cr -> { - observationContext.setResponse(cr); + + return new MessageAggregator().aggregate(flux, mergedChatResponse -> { + observationContext.setResponse(mergedChatResponse); }); }); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index c48ade0c68..0f486d4c5b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -19,6 +19,7 @@ import java.util.Optional; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -163,6 +164,7 @@ public void openAiChatNonTransientError() { } @Test + @Disabled("Currently stream() does not implmement retry") public void openAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, @@ -184,6 +186,7 @@ public void openAiChatStreamTransientError() { } @Test + @Disabled("Currently stream() does not implmement retry") public void openAiChatStreamNonTransientError() { when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new RuntimeException("Non Transient Error")); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index ad02dddf29..6aef10ed77 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -49,12 +49,15 @@ public class MessageAggregator { public Flux aggregate(Flux fluxChatResponse, Consumer onAggregationComplete) { - AtomicReference stringBufferRef = new AtomicReference<>(new StringBuilder()); - AtomicReference> mapRef = new AtomicReference<>(); + // Assistant Message + AtomicReference messageTextContentRef = new AtomicReference<>(new StringBuilder()); + AtomicReference> messageMetadataMapRef = new AtomicReference<>(); + // ChatGeneration Metadata AtomicReference generationMetadataRef = new AtomicReference<>( ChatGenerationMetadata.NULL); + // Usage AtomicReference metadataUsagePromptTokensRef = new AtomicReference<>(0L); AtomicReference metadataUsageGenerationTokensRef = new AtomicReference<>(0L); AtomicReference metadataUsageTotalTokensRef = new AtomicReference<>(0L); @@ -66,8 +69,8 @@ public Flux aggregate(Flux fluxChatResponse, AtomicReference metadataModelRef = new AtomicReference<>(""); return fluxChatResponse.doOnSubscribe(subscription -> { - stringBufferRef.set(new StringBuilder()); - mapRef.set(new HashMap<>()); + messageTextContentRef.set(new StringBuilder()); + messageMetadataMapRef.set(new HashMap<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0L); @@ -84,10 +87,10 @@ public Flux aggregate(Flux fluxChatResponse, generationMetadataRef.set(chatResponse.getResult().getMetadata()); } if (chatResponse.getResult().getOutput().getContent() != null) { - stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent()); + messageTextContentRef.get().append(chatResponse.getResult().getOutput().getContent()); } if (chatResponse.getResult().getOutput().getMetadata() != null) { - mapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata()); + messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata()); } } if (chatResponse.getMetadata() != null) { @@ -128,13 +131,12 @@ public Flux aggregate(Flux fluxChatResponse, .withPromptMetadata(metadataPromptMetadataRef.get()) .build(); - onAggregationComplete.accept(new ChatResponse( - List.of(new Generation(new AssistantMessage(stringBufferRef.get().toString(), mapRef.get()), - generationMetadataRef.get())), - chatResponseMetadata)); + onAggregationComplete.accept(new ChatResponse(List.of(new Generation( + new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()), + generationMetadataRef.get())), chatResponseMetadata)); - stringBufferRef.set(new StringBuilder()); - mapRef.set(new HashMap<>()); + messageTextContentRef.set(new StringBuilder()); + messageMetadataMapRef.set(new HashMap<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0L);