Skip to content

Commit

Permalink
Model observability for Anthropic
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and markpollack committed Aug 20, 2024
1 parent 3fa102e commit 3b7522b
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 35 deletions.
6 changes: 6 additions & 0 deletions models/spring-ai-anthropic/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@
<scope>test</scope>
</dependency>

<!-- Micrometer observation test module; test scope only, so it is not on the runtime classpath. -->
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-xml</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import java.util.Set;
import java.util.stream.Collectors;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
Expand All @@ -39,11 +42,13 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
Expand All @@ -64,12 +69,15 @@
* @author Christian Tzolov
* @author luocongqiu
* @author Mariusz Bernacki
* @author Thomas Vitale
* @since 1.0.0
*/
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {

private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();

public static final Integer DEFAULT_MAX_TOKENS = 500;
Expand All @@ -91,6 +99,16 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
*/
public final RetryTemplate retryTemplate;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
Expand Down Expand Up @@ -151,54 +169,108 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext,
List<FunctionCallback> toolFunctionCallbacks) {
// Delegate to the registry-aware constructor, supplying ObservationRegistry.NOOP
// for callers that do not provide an observation registry of their own.
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackContext, toolFunctionCallbacks,
ObservationRegistry.NOOP);
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service. Must not be
* null.
* @param defaultOptions the default options used for the chat completion requests.
* Must not be null.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* Must not be null.
* @param functionCallbackContext the function callback context used to store the
* state of the function calls.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
* @param observationRegistry the observation registry used for instrumentation.
* Must not be null.
*/
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext,
List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {

// Tool-call support is initialised by the superclass before any local validation runs.
super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);

Assert.notNull(anthropicApi, "AnthropicApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");

this.anthropicApi = anthropicApi;
this.defaultOptions = defaultOptions;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);

ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {

ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());

if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
observationContext.setResponse(chatResponse);

return chatResponse;
});

if (response != null && this.isToolCall(response, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, response);
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

return chatResponse;
return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

ChatCompletionRequest request = createRequest(prompt, true);
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

Flux<ChatCompletionResponse> response = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionStream(request));
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry);

return response.switchMap(chatCompletionResponse -> {
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);

if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
// @formatter:off
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);

if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}

return Mono.just(chatResponse);
return Mono.just(chatResponse);
})
.doOnError(observation::error)
.doFinally(s -> {
observation.stop();
})
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on

return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
});
}

Expand Down Expand Up @@ -366,9 +438,29 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
}).toList();
}

/**
 * Derives the chat options reported on the observation context from the values
 * carried by the outgoing Anthropic request.
 * @param request the chat completion request that is about to be sent
 * @return options holding the request's model, max tokens, stop sequences,
 * temperature, top-K and top-P
 */
private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) {
    // Model identity and output limits.
    var limits = ChatOptionsBuilder.builder()
        .withModel(request.model())
        .withMaxTokens(request.maxTokens())
        .withStopSequences(request.stopSequences());

    // Sampling parameters.
    var sampling = limits.withTemperature(request.temperature())
        .withTopK(request.topK())
        .withTopP(request.topP());

    return sampling.build();
}

/**
 * Returns the default chat options, obtained by passing this model's configured
 * defaults through {@code AnthropicChatOptions.fromOptions}.
 * @return the options produced from the configured defaults
 */
@Override
public ChatOptions getDefaultOptions() {
    ChatOptions options = AnthropicChatOptions.fromOptions(this.defaultOptions);
    return options;
}

/**
* Replace the convention used for reporting observation data for this model.
* @param observationConvention the convention to use; must not be null (checked
* with {@code Assert.notNull}).
*/
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
import java.util.function.Consumer;
import java.util.function.Predicate;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
Expand All @@ -51,11 +50,12 @@
/**
* @author Christian Tzolov
* @author Mariusz Bernacki
* @author Thomas Vitale
* @since 1.0.0
*/
public class AnthropicApi {

private static final Logger logger = LoggerFactory.getLogger(AnthropicApi.class);
public static final String PROVIDER_NAME = AiProvider.ANTHROPIC.value();

private static final String HEADER_X_API_KEY = "x-api-key";

Expand Down
Loading

0 comments on commit 3b7522b

Please sign in to comment.