Add OpenAiChatModel stream observability
Integrated Micrometer's Observation into the OpenAiChatModel#stream reactive chain.

Included changes:
 - Added ability to aggregate streaming responses for use in Observation metadata.
 - Improved error handling and logging for chat response processing.
 - Updated unit tests to include new observation logic and subscribe to Flux responses.
 - Refined validation of observations in both normal and streaming chat operations.
 - Disabled retry for streaming, which used RetryTemplate; the next step is to switch to the .retryWhen operator (see the sketch below).
 - Added an integration test.

Resolves #1190
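As a sketch of that follow-up, retries could be re-attached on the streaming path with Reactor's retryWhen operator. The backoff values and the TransientAiException filter below are illustrative assumptions about the eventual fix, not part of this commit:

    // Sketch only: replace the blocking RetryTemplate with Reactor's retryWhen.
    // Requires reactor.util.retry.Retry and java.time.Duration; TransientAiException
    // is Spring AI's marker for retryable errors (assumed to fit the streaming case).
    Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi
        .chatCompletionStream(request, getAdditionalHttpHeaders(prompt))
        .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
            .filter(ex -> ex instanceof TransientAiException));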

Co-authored-by: Christian Tzolov <ctzolov@vmware.com>
chemicL authored and tzolov committed Aug 8, 2024
1 parent 86348e4 commit 478f180
Showing 7 changed files with 252 additions and 71 deletions.
@@ -39,6 +39,7 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -72,7 +73,9 @@
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@@ -271,64 +274,90 @@ public ChatResponse call(Prompt prompt) {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

	return Flux.deferContextual(contextView -> {
		ChatCompletionRequest request = createRequest(prompt, true);

		Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
				getAdditionalHttpHeaders(prompt));

		// For chunked responses, only the first chunk contains the choice role.
		// The rest of the chunks with same ID share the same role.
		ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

		final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
			.prompt(prompt)
			.operationMetadata(buildOperationMetadata())
			.requestOptions(buildRequestOptions(request))
			.build();

		Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
				this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
				this.observationRegistry);

		observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

		// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
		// the function call handling logic.
		Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
			.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
				try {
					@SuppressWarnings("null")
					String id = chatCompletion2.id();

					// @formatter:off
					List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
						if (choice.message().role() != null) {
							roleMap.putIfAbsent(id, choice.message().role().name());
						}
						Map<String, Object> metadata = Map.of(
								"id", chatCompletion2.id(),
								"role", roleMap.getOrDefault(id, ""),
								"index", choice.index(),
								"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
						return buildGeneration(choice, metadata);
					}).toList();
					// @formatter:on

					if (chatCompletion2.usage() != null) {
						return new ChatResponse(generations, from(chatCompletion2, null));
					}
					else {
						return new ChatResponse(generations);
					}
				}
				catch (Exception e) {
					logger.error("Error processing chat completion", e);
					return new ChatResponse(List.of());
				}
			}));

		// @formatter:off
		Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
			if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
					OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
				var toolCallConversation = handleToolCalls(prompt, response);
				// Recursively call the stream method with the tool call message
				// conversation that contains the call responses.
				return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
			}
			else {
				return Flux.just(response);
			}
		})
		.doOnError(observation::error)
		.doFinally(s -> {
			// TODO: Consider a custom ObservationContext and include additional metadata
			// if (s == SignalType.CANCEL) {
			//     observationContext.setAborted(true);
			// }
			observation.stop();
		})
		.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
		// @formatter:on

		return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
			observationContext.setResponse(mergedChatResponse);
		});
	});
}
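Since the observation is started inside Flux.deferContextual and looked up via ObservationThreadLocalAccessor.KEY, a caller can make its own observation the parent of the chat span by writing it into the Reactor context. A minimal caller-side sketch (the observation name "my.request" is illustrative):

    // Sketch: place an existing observation in the subscriber context so the
    // getOrDefault(ObservationThreadLocalAccessor.KEY, ...) lookup above finds it.
    Observation parent = Observation.start("my.request", observationRegistry);
    chatModel.stream(prompt)
        .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, parent))
        .doFinally(signal -> parent.stop())
        .blockLast();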

@@ -104,7 +104,7 @@ public void streamUserMessageSimpleContentType() {

when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);

chatModel.stream(new Prompt(List.of(new UserMessage("test message"))));
chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe();

validateStringContent(pomptCaptor.getValue());
assertThat(headersCaptor.getValue()).isEmpty();
@@ -137,8 +137,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);

URL mediaUrl = new URL("http://test");
chatModel.stream(new Prompt(
List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))));
chatModel
.stream(new Prompt(
List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))))
.subscribe();

validateComplexContent(pomptCaptor.getValue());
}
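With stream() now wrapped in Flux.deferContextual, no request is sent until subscription, hence the added .subscribe() calls. A StepVerifier variant (a sketch, assuming reactor-test is on the test classpath) would additionally assert completion:

    // Sketch: StepVerifier subscribes and verifies termination
    // instead of firing a bare subscribe().
    StepVerifier.create(chatModel.stream(new Prompt(List.of(new UserMessage("test message")))))
        .thenConsumeWhile(response -> true) // drain all streamed chunks
        .verifyComplete();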
@@ -18,6 +18,9 @@
import io.micrometer.common.KeyValue;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;
import reactor.core.publisher.Flux;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -37,6 +40,7 @@
import org.springframework.retry.support.RetryTemplate;

import java.util.List;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
@@ -57,8 +61,14 @@ public class OpenAiChatModelObservationIT {
@Autowired
OpenAiChatModel chatModel;

@BeforeEach
void beforeEach() {
observationRegistry.clear();
}

@Test
void observationForEmbeddingOperation() {
void observationForChatOperation() {

var options = OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
.withFrequencyPenalty(0f)
@@ -77,6 +87,45 @@ void observationForEmbeddingOperation() {
ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
assertThat(responseMetadata).isNotNull();

validate(responseMetadata);
}

@Test
void observationForStreamingChatOperation() {
var options = OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
.withFrequencyPenalty(0f)
.withMaxTokens(2048)
.withPresencePenalty(0f)
.withStop(List.of("this-is-the-end"))
.withTemperature(0.7f)
.withTopP(1f)
.withStreamUsage(true)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);

Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);

List<ChatResponse> responses = chatResponseFlux.collectList().block();
assertThat(responses).isNotEmpty();
assertThat(responses).hasSizeGreaterThan(10);

String aggregatedResponse = responses.subList(0, responses.size() - 1)
.stream()
.map(r -> r.getResult().getOutput().getContent())
.collect(Collectors.joining());
assertThat(aggregatedResponse).isNotEmpty();

ChatResponse lastChatResponse = responses.get(responses.size() - 1);

ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
assertThat(responseMetadata).isNotNull();

validate(responseMetadata);
}
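The manual Collectors.joining above could also reuse the MessageAggregator the model now applies internally. A sketch (the AtomicReference holder is illustrative, not part of the test):

    // Sketch: collect the merged response through MessageAggregator's callback,
    // the same mechanism the model uses to populate the observation context.
    AtomicReference<ChatResponse> merged = new AtomicReference<>();
    new MessageAggregator().aggregate(chatModel.stream(prompt), merged::set).blockLast();
    assertThat(merged.get().getResult().getOutput().getContent()).isNotEmpty();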

private void validate(ChatResponseMetadata responseMetadata) {
TestObservationRegistryAssert.assertThat(observationRegistry)
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
@@ -53,8 +53,7 @@
* @author Christian Tzolov
*/
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
public class OpenAiPaymentTransactionIT {

private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class);
@@ -19,6 +19,7 @@
import java.util.Optional;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
@@ -163,6 +164,7 @@ public void openAiChatNonTransientError() {
}

@Test
@Disabled("Currently stream() does not implmement retry")
public void openAiChatStreamTransientError() {

var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
@@ -184,10 +186,11 @@ public void openAiChatStreamTransientError() {
}

@Test
@Disabled("Currently stream() does not implmement retry")
public void openAiChatStreamNonTransientError() {
when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe());
}

@Test
@@ -27,7 +27,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.DefaultChatClient;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;