Skip to content

Commit

Permalink
Disable retry for streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
tzolov committed Aug 8, 2024
1 parent 9ef51d0 commit 7cfef1f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import io.micrometer.observation.Observation;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -75,10 +73,11 @@
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
Expand Down Expand Up @@ -278,8 +277,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
.execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
getAdditionalHttpHeaders(prompt));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with the same ID share the same role.
Expand Down Expand Up @@ -354,8 +353,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
})
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on
return new MessageAggregator().aggregate(flux, cr -> {
observationContext.setResponse(cr);

return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
observationContext.setResponse(mergedChatResponse);
});

});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Optional;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -163,6 +164,7 @@ public void openAiChatNonTransientError() {
}

@Test
@Disabled("Currently stream() does not implement retry")
public void openAiChatStreamTransientError() {

var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
Expand All @@ -184,6 +186,7 @@ public void openAiChatStreamTransientError() {
}

@Test
@Disabled("Currently stream() does not implement retry")
public void openAiChatStreamNonTransientError() {
when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
.thenThrow(new RuntimeException("Non Transient Error"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ public class MessageAggregator {
public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
Consumer<ChatResponse> onAggregationComplete) {

AtomicReference<StringBuilder> stringBufferRef = new AtomicReference<>(new StringBuilder());
AtomicReference<Map<String, Object>> mapRef = new AtomicReference<>();
// Assistant Message
AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(new StringBuilder());
AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();

// ChatGeneration Metadata
AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
ChatGenerationMetadata.NULL);

// Usage
AtomicReference<Long> metadataUsagePromptTokensRef = new AtomicReference<>(0L);
AtomicReference<Long> metadataUsageGenerationTokensRef = new AtomicReference<>(0L);
AtomicReference<Long> metadataUsageTotalTokensRef = new AtomicReference<>(0L);
Expand All @@ -66,8 +69,8 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
AtomicReference<String> metadataModelRef = new AtomicReference<>("");

return fluxChatResponse.doOnSubscribe(subscription -> {
stringBufferRef.set(new StringBuilder());
mapRef.set(new HashMap<>());
messageTextContentRef.set(new StringBuilder());
messageMetadataMapRef.set(new HashMap<>());
metadataIdRef.set("");
metadataModelRef.set("");
metadataUsagePromptTokensRef.set(0L);
Expand All @@ -84,10 +87,10 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
generationMetadataRef.set(chatResponse.getResult().getMetadata());
}
if (chatResponse.getResult().getOutput().getContent() != null) {
stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent());
messageTextContentRef.get().append(chatResponse.getResult().getOutput().getContent());
}
if (chatResponse.getResult().getOutput().getMetadata() != null) {
mapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
}
}
if (chatResponse.getMetadata() != null) {
Expand Down Expand Up @@ -128,13 +131,12 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
.withPromptMetadata(metadataPromptMetadataRef.get())
.build();

onAggregationComplete.accept(new ChatResponse(
List.of(new Generation(new AssistantMessage(stringBufferRef.get().toString(), mapRef.get()),
generationMetadataRef.get())),
chatResponseMetadata));
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()),
generationMetadataRef.get())), chatResponseMetadata));

stringBufferRef.set(new StringBuilder());
mapRef.set(new HashMap<>());
messageTextContentRef.set(new StringBuilder());
messageMetadataMapRef.set(new HashMap<>());
metadataIdRef.set("");
metadataModelRef.set("");
metadataUsagePromptTokensRef.set(0L);
Expand Down

0 comments on commit 7cfef1f

Please sign in to comment.