Skip to content

Commit

Permalink
Add ObservationRegistry support to ChatClient
Browse files Browse the repository at this point in the history
 - Implement observable chat responses in DefaultChatClient
 - Add ChatClientObservationContext and related classes for metrics
 - Update ChatClient and builder methods to support ObservationRegistry
 - Enhance RequestResponseAdvisor with getName() method
 - Add ChatClient streaming observability support
 - Introduce ChatClientObservationDocumentation for metric key names
 - Create DefaultChatClientObservationConvention for implementing conventions
 - Add ChatClientInputContentObservationFilter for optional input content logging
 - Update ChatClientAutoConfiguration to include new observation components
 - Extend ChatClientBuilderProperties with observation configuration options
 - Add unit tests for new observation classes and configurations
 - Update AiOperationType and AiProvider enums with new values
 - Implement safeguards and warnings for sensitive data in observations

 Resolves #1206
  • Loading branch information
tzolov committed Aug 11, 2024
1 parent bf84d59 commit 66e4b88
Show file tree
Hide file tree
Showing 17 changed files with 1,035 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.ai.model.Media;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.Media;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeType;

import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Flux;

/**
Expand All @@ -48,11 +50,25 @@
public interface ChatClient {

static ChatClient create(ChatModel chatModel) {
return builder(chatModel).build();
return create(chatModel, ObservationRegistry.NOOP);
}

/**
 * Creates a {@link ChatClient} for the given model and observation registry.
 * <p>
 * Delegates to the three-argument {@code create} overload, passing {@code null} as
 * the observation convention.
 * @param chatModel the chat model handed to the underlying builder
 * @param observationRegistry the observation registry handed to the underlying builder
 * @return the built chat client
 */
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry) {
    // No custom convention is supplied by this overload.
    ChatClientObservationConvention noCustomConvention = null;
    return create(chatModel, observationRegistry, noCustomConvention);
}

/**
 * Creates a {@link ChatClient} by obtaining a builder for the given arguments and
 * building it immediately, without any further customization.
 * @param chatModel the chat model handed to the builder
 * @param observationRegistry the observation registry handed to the builder
 * @param observationConvention the observation convention handed to the builder; other
 * overloads in this interface pass {@code null} here
 * @return the built chat client
 */
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry,
        ChatClientObservationConvention observationConvention) {
    Builder clientBuilder = builder(chatModel, observationRegistry, observationConvention);
    return clientBuilder.build();
}

static Builder builder(ChatModel chatModel) {
return new DefaultChatClientBuilder(chatModel);
return builder(chatModel, ObservationRegistry.NOOP, null);
}

/**
 * Returns a new {@link DefaultChatClientBuilder} configured with the given model,
 * observation registry and observation convention.
 * @param chatModel the chat model passed to the builder constructor
 * @param observationRegistry the observation registry passed to the builder constructor
 * @param customObservationConvention the observation convention passed to the builder
 * constructor; other overloads in this interface pass {@code null} here
 * @return the new builder
 */
static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry,
        ChatClientObservationConvention customObservationConvention) {
    DefaultChatClientBuilder clientBuilder = new DefaultChatClientBuilder(chatModel, observationRegistry,
            customObservationConvention);
    return clientBuilder;
}

ChatClientRequestSpec prompt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

import reactor.core.publisher.Flux;

import org.springframework.ai.model.Media;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
Expand All @@ -42,6 +43,7 @@
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.model.function.FunctionCallingOptions;
Expand All @@ -52,6 +54,11 @@
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

/**
* The default implementation of {@link ChatClient} as created by the
* {@link Builder#build()} method.
Expand All @@ -65,6 +72,8 @@
*/
public class DefaultChatClient implements ChatClient {

private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention();

private final ChatModel chatModel;

private final DefaultChatClientRequestSpec defaultChatClientRequest;
Expand Down Expand Up @@ -281,7 +290,7 @@ public <T> ResponseEntity<ChatResponse, T> responseEntity(
}

protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var chatResponse = doGetObservableChatResponse(this.request, boc.getFormat());
var responseContent = chatResponse.getResult().getOutput().getContent();
T entity = boc.convert(responseContent);

Expand All @@ -297,7 +306,7 @@ public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
}

private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var chatResponse = doGetObservableChatResponse(this.request, boc.getFormat());
var stringResponse = chatResponse.getResult().getOutput().getContent();
return boc.convert(stringResponse);
}
Expand All @@ -309,7 +318,23 @@ public <T> T entity(Class<T> type) {
}

private ChatResponse doGetChatResponse() {
return this.doGetChatResponse(this.request, "");
return this.doGetObservableChatResponse(this.request, "");
}

/**
 * Performs the chat call for the given request wrapped in a chat-client observation.
 * <p>
 * The observation is created from the request's custom convention (falling back to
 * {@code DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION}) and the request's observation
 * registry; the actual work is delegated to
 * {@code doGetChatResponse(inputRequest, formatParam)}.
 * @param inputRequest the request spec supplying the call data and observation settings
 * @param formatParam the format string forwarded to the context and the chat call
 * @return the response produced by the observed chat call
 */
private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec inputRequest,
        String formatParam) {

    // Third argument is false here and true in the Flux-based path -- presumably a
    // "streaming" flag; confirm against ChatClientObservationContext.
    var context = new ChatClientObservationContext(inputRequest, formatParam, false);

    Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
            inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
            () -> context, inputRequest.observationRegistry);

    return observation.observe(() -> doGetChatResponse(inputRequest, formatParam));
}

private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) {
Expand Down Expand Up @@ -395,6 +420,29 @@ public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSp
}

/**
 * Returns a {@link Flux} of chat responses whose subscription is wrapped in a
 * chat-client observation.
 * <p>
 * For every subscription a new observation is created, parented to the observation
 * found in the Reactor context under {@link ObservationThreadLocalAccessor#KEY} (if
 * any), and started. The observation is then written into the context of the
 * returned sequence, has errors recorded on it, and is stopped when the sequence
 * terminates or is cancelled.
 * @param inputRequest the request spec supplying the call data and observation settings
 * @return the observed response sequence
 */
private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
    return Flux.deferContextual(contextView -> {
        ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, "",
                true);

        Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
                inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
                () -> observationContext, inputRequest.observationRegistry);

        observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
            .start();

        // The observation is already started at this point. If building the response
        // sequence fails synchronously, the doOnError/doFinally hooks below are never
        // attached, so record the failure and stop the observation here before
        // propagating the exception.
        final Flux<ChatResponse> responses;
        try {
            responses = doGetFluxChatResponse2(inputRequest);
        }
        catch (RuntimeException ex) {
            observation.error(ex);
            observation.stop();
            throw ex;
        }

        // @formatter:off
        return responses
            .doOnError(observation::error)
            .doFinally(signalType -> observation.stop())
            .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
        // @formatter:on
    });
}

private Flux<ChatResponse> doGetFluxChatResponse2(DefaultChatClientRequestSpec inputRequest) {

Map<String, Object> context = new ConcurrentHashMap<>();
context.putAll(inputRequest.getAdvisorParams());
Expand Down Expand Up @@ -426,9 +474,7 @@ private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec in
messages.add(userMessage);
}

if (advisedRequest.getChatOptions() instanceof

FunctionCallingOptions functionCallingOptions) {
if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
if (!advisedRequest.getFunctionNames().isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames()));
}
Expand Down Expand Up @@ -470,6 +516,10 @@ public Flux<String> content() {

public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {

private final ObservationRegistry observationRegistry;

private final ChatClientObservationConvention customObservationConvention;

private final ChatModel chatModel;

private String userText = "";
Expand All @@ -494,6 +544,14 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe

private final Map<String, Object> advisorParams = new HashMap<>();

/**
 * Returns the observation registry this request spec was constructed with.
 * @return the observation registry field value
 */
private ObservationRegistry getObservationRegistry() {
    return this.observationRegistry;
}

/**
 * Returns the custom observation convention this request spec was constructed with.
 * @return the custom observation convention field value
 */
private ChatClientObservationConvention getCustomObservationConvention() {
    return this.customObservationConvention;
}

public String getUserText() {
return userText;
}
Expand Down Expand Up @@ -541,13 +599,15 @@ public List<FunctionCallback> getFunctionCallbacks() {
/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.observationRegistry, ccr.customObservationConvention);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams,
ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) {

this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions.copy()
Expand All @@ -564,14 +624,17 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<St
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.customObservationConvention = customObservationConvention;
}

/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequestSpec}.
*/
public Builder mutate() {
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient.builder(chatModel)
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
.builder(chatModel, this.observationRegistry, this.customObservationConvention)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
Expand Down Expand Up @@ -756,7 +819,8 @@ public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequ
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
adviseRequest.advisorParams(), inputRequest.getObservationRegistry(),
inputRequest.getCustomObservationConvention());
}

return advisedRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec;
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;

import io.micrometer.observation.ObservationRegistry;

/**
* DefaultChatClientBuilder is a builder class for creating a ChatClient.
*
Expand All @@ -48,11 +51,18 @@ public class DefaultChatClientBuilder implements Builder {

private final ChatModel chatModel;

public DefaultChatClientBuilder(ChatModel chatModel) {
DefaultChatClientBuilder(ChatModel chatModel) {
this(chatModel, ObservationRegistry.NOOP, null);
}

public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry,
ChatClientObservationConvention customObservationConvention) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(),
List.of(), List.of(), List.of(), null, List.of(), Map.of());
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
customObservationConvention);
}

public ChatClient build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
*/
public interface RequestResponseAdvisor {

/**
 * Returns the name of this advisor.
 * <p>
 * The default implementation uses the simple name of the implementing class.
 * @return the advisor name.
 */
default String getName() {
    Class<?> advisorType = getClass();
    return advisorType.getSimpleName();
}

/**
* @param request the {@link AdvisedRequest} data to be advised. Represents the raw
* {@link ChatClient.ChatClientRequestSpec} data before it is sealed into a {@link Prompt}.
Expand Down
Loading

0 comments on commit 66e4b88

Please sign in to comment.