Skip to content

Commit

Permalink
Model observability for Mistral
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and markpollack committed Aug 20, 2024
1 parent 46893b0 commit 80fe5e4
Show file tree
Hide file tree
Showing 10 changed files with 607 additions and 130 deletions.
6 changes: 6 additions & 0 deletions models/spring-ai-mistral-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -29,11 +32,13 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
Expand All @@ -57,17 +62,21 @@
import reactor.core.publisher.Mono;

/**
* Represents a Mistral AI Chat Model.
*
* @author Ricken Bazolo
* @author Christian Tzolov
* @author Grogdunn
* @author Thomas Vitale
* @author luocongqiu
* @since 0.8.1
* @since 1.0.0
*/
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

/**
* The default options used for the chat completion requests.
*/
Expand All @@ -80,6 +89,16 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM

private final RetryTemplate retryTemplate;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public MistralAiChatModel(MistralAiApi mistralAiApi) {
this(mistralAiApi,
MistralAiChatOptions.builder()
Expand All @@ -102,118 +121,160 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackContext, toolFunctionCallbacks, retryTemplate,
ObservationRegistry.NOOP);
}

public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.mistralAiApi = mistralAiApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public ChatResponse call(Prompt prompt) {

var request = createRequest(prompt, false);
MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false);

ResponseEntity<ChatCompletion> completionEntity = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionEntity(request));
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MistralAiApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

ChatCompletion chatCompletion = completionEntity.getBody();
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {

if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
ResponseEntity<ChatCompletion> completionEntity = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionEntity(request));

ChatCompletion chatCompletion = completionEntity.getBody();

if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"index", choice.index(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"index", choice.index(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));

// // Non function calling.
// RateLimit rateLimit =
// OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
observationContext.setResponse(chatResponse);

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
return chatResponse;
});

if (isToolCall(chatResponse, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
if (response != null && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

return chatResponse;
return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
var request = createRequest(prompt, true);

Flux<ChatCompletionChunk> completionChunks = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String id = chatCompletion2.id();

// @formatter:off
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return Flux.deferContextual(contextView -> {
var request = createRequest(prompt, true);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MistralAiApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

Flux<ChatCompletionChunk> completionChunks = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String id = chatCompletion2.id();

// @formatter:off
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
// @formatter:on
}).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
}
}
else {
return new ChatResponse(generations);
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}
}));

// @formatter:off
Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
else {
return Flux.just(response);
}

}));

return chatResponse.flatMap(response -> {

if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
})
.doOnError(observation::error)
.doFinally(s -> {
observation.stop();
})
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on;

return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
});

}

private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
Expand Down Expand Up @@ -333,9 +394,28 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
}).toList();
}

private ChatOptions buildRequestOptions(MistralAiApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withMaxTokens(request.maxTokens())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
.build();
}

@Override
public ChatOptions getDefaultOptions() {
return MistralAiChatOptions.fromOptions(this.defaultOptions);
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Loading

0 comments on commit 80fe5e4

Please sign in to comment.