From 6a961d5e8393ca4640624574a95ec3cc5664a853 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 29 Aug 2024 15:07:48 +0200 Subject: [PATCH 1/6] Handling stream advisor responses --- .../ai/chat/client/DefaultChatClient.java | 178 ++++++++++-------- .../chat/client/RequestResponseAdvisor.java | 21 +++ .../advisor/AbstractChatMemoryAdvisor.java | 5 + .../advisor/MessageChatMemoryAdvisor.java | 16 -- .../advisor/PromptChatMemoryAdvisor.java | 18 +- .../client/advisor/QuestionAnswerAdvisor.java | 16 +- .../client/advisor/SimpleLoggerAdvisor.java | 14 +- .../advisor/VectorStoreChatMemoryAdvisor.java | 18 +- .../ObservableRequestResponseAdvisor.java | 32 +++- 9 files changed, 167 insertions(+), 151 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index c1cd207446..f71fe9beea 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -60,6 +60,7 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; /** * The default implementation of {@link ChatClient} as created by the @@ -386,6 +387,7 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest } } var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + var chatResponse = this.chatModel.call(prompt); ChatResponse advisedResponse = chatResponse; @@ -421,8 +423,9 @@ public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSp this.request = request; } - private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { + private Flux doGetObservableFluxChatResponse(DefaultChatClientRequestSpec 
inputRequest) { return Flux.deferContextual(contextView -> { + ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, "", true); @@ -433,8 +436,8 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); - // @formatter:off - return doGetFluxChatResponse2(inputRequest) + // @formatter:off + return doGetFluxChatResponse(inputRequest) .doOnError(observation::error) .doFinally(s -> { observation.stop(); @@ -444,68 +447,81 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in }); } - private Flux doGetFluxChatResponse2(DefaultChatClientRequestSpec inputRequest) { - - Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.getAdvisorParams()); - DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, - context); - - String processedUserText = advisedRequest.getUserText(); - Map userParams = new HashMap<>(advisedRequest.getUserParams()); - - var messages = new ArrayList(advisedRequest.getMessages()); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.getSystemText())); - if (textsAreValid) { - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.getMedia()); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); - } - if (StringUtils.hasText(advisedRequest.getSystemText()) - || !advisedRequest.getSystemParams().isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) - .render()); - messages.add(systemMessage); - } - messages.add(userMessage); - } - - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions 
functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); - } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); - } - } - var prompt = new Prompt(messages, advisedRequest.getChatOptions()); - - var fluxChatResponse = this.chatModel.stream(prompt); + record AdvisedRequestWithContext(AdvisedRequest request, Map advisorContext) { + } - Flux advisedResponse = fluxChatResponse; - // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); - } - } + private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { - return advisedResponse; + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + + var reqWithContext = new AdvisedRequestWithContext(toAdvisedRequest(inputRequest), advisorContext); + + return Flux.fromIterable(inputRequest.advisors) + .transformDeferredContextual((f, ctx) -> f + // This allows us to call blocking code in reduce + .publishOn(Schedulers.boundedElastic()) + .reduce(reqWithContext, (rwc, advisor) -> { + return new AdvisedRequestWithContext(advisor.adviseRequest(rwc.request, rwc.advisorContext), + rwc.advisorContext); + })) + .single() + // .doOnNext(r -> System.out.println("Request: " + r)) + .flatMapMany(r -> { + DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(r.request, + inputRequest.getObservationRegistry(), inputRequest.getCustomObservationConvention()); + var messages = new ArrayList(advisedRequest.getMessages()); + + String processedSystemText = advisedRequest.getSystemText(); + if (StringUtils.hasText(processedSystemText)) { + if 
(!CollectionUtils.isEmpty(advisedRequest.getSystemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, + advisedRequest.getSystemParams()) + .render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + String processedUserText = advisedRequest.getUserText(); + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(advisedRequest.getUserParams()); + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, advisedRequest.getMedia())); + } + + if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.getFunctionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + } + if (!advisedRequest.getFunctionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + } + } + var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + + Flux fluxChatResponse = this.chatModel.stream(prompt); + + Flux advisedResponse = fluxChatResponse; + // apply the advisors on response + if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedResponse = advisor.adviseResponse(advisedResponse, advisorContext); + } + } + + return advisedResponse; + }); } public Flux chatResponse() { - return doGetFluxChatResponse(this.request); + return doGetObservableFluxChatResponse(this.request); } public Flux content() { - return doGetFluxChatResponse(this.request).map(r -> { + return doGetObservableFluxChatResponse(this.request).map(r -> { if (r.getResult() == null || r.getResult().getOutput() == null || r.getResult().getOutput().getContent() == null) { return ""; @@ -815,34 +831,40 @@ public 
StreamResponseSpec stream() { public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest, Map context) { - DefaultChatClientRequestSpec advisedRequest = inputRequest; - - if (!CollectionUtils.isEmpty(inputRequest.advisors)) { - AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, - inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, - inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, - inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, - inputRequest.advisorParams); + if (CollectionUtils.isEmpty(inputRequest.advisors)) { + return inputRequest; + } - // apply the advisors onRequest - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - adviseRequest = advisor.adviseRequest(adviseRequest, context); - } + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequest); - advisedRequest = new DefaultChatClientRequestSpec(adviseRequest.chatModel(), adviseRequest.userText(), - adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(), - adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(), - adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(), - adviseRequest.advisorParams(), inputRequest.getObservationRegistry(), - inputRequest.getCustomObservationConvention()); + // apply the advisors onRequest + var currentAdvisors = new ArrayList<>(inputRequest.advisors); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedRequest = advisor.adviseRequest(advisedRequest, context); } - - return advisedRequest; + return toDefaultChatClientRequestSpec(advisedRequest, inputRequest.getObservationRegistry(), + inputRequest.getCustomObservationConvention()); } } + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { + 
return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, + inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, + inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams); + } + + private static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, + ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { + + return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), + advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), + advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), + advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), + advisedRequest.advisorParams(), observationRegistry, customObservationConvention); + } + // Prompt public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index a696c66c6a..ccefcf9983 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -21,6 +21,7 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; @@ -34,6 +35,16 @@ */ public interface RequestResponseAdvisor { + public enum StreamResponseMode { + + CHUNK, AGGREGATE, CUSTOM; + + } + + default StreamResponseMode 
getStreamResponseMode() { + return StreamResponseMode.CUSTOM; + } + /** * @return the advisor name. */ @@ -73,6 +84,16 @@ default ChatResponse adviseResponse(ChatResponse response, Map c * @return the advised {@link ChatResponse} flux. */ default Flux adviseResponse(Flux fluxResponse, Map context) { + + if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) { + return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context)); + } + else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { + return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { + this.adviseResponse(chatResponse, context); + }); + } + return fluxResponse; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 06a48e7a04..85b354ce04 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -59,6 +59,11 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize; } + @Override + public StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.AGGREGATE; + } + protected T getChatMemoryStore() { return this.chatMemoryStore; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index bf22231ceb..5f125f0bd3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -20,14 
+20,11 @@ import java.util.List; import java.util.Map; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; /** * Memory is retrieved added as a collection of messages to the prompt @@ -79,17 +76,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); - }); - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index 400fb89f63..19af746d60 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -21,16 +21,13 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.model.Content; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; 
+import org.springframework.ai.model.Content; /** * Memory is retrieved added into the prompt's system text. @@ -109,17 +106,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); - }); - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index dc8747fff5..9dda036704 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -24,6 +24,7 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; @@ -133,11 +134,10 @@ public ChatResponse adviseResponse(ChatResponse response, Map co } @Override - public Flux adviseResponse(Flux fluxResponse, Map context) { - return fluxResponse.map(cr -> { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr); - chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return chatResponseBuilder.build(); + public Flux adviseResponse(Flux fluxChatResponse, Map context) { + // return fluxResponse.map(cr -> adviseResponse(cr, context)); + return new 
MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { + this.adviseResponse(chatResponse, context); }); } @@ -151,4 +151,10 @@ protected Filter.Expression doGetFilterExpression(Map context) { } + @Override + public StreamResponseMode getStreamResponseMode() { + // return StreamResponseMode.CHUNK; + return StreamResponseMode.AGGREGATE; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index d09141720d..dc21c3fcc7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -23,11 +23,8 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; -import reactor.core.publisher.Flux; - /** * A simple logger advisor that logs the request and response messages. 
* @@ -65,12 +62,6 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map return request; } - @Override - public Flux adviseResponse(Flux fluxChatResponse, Map context) { - return new MessageAggregator().aggregate(fluxChatResponse, - chatResponse -> logger.debug("stream response: {}", this.responseToString.apply(chatResponse))); - } - @Override public ChatResponse adviseResponse(ChatResponse response, Map context) { logger.debug("response: {}", this.responseToString.apply(response)); @@ -82,4 +73,9 @@ public String toString() { return SimpleLoggerAdvisor.class.getSimpleName(); } + @Override + public StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.AGGREGATE; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index 52e30ecdf7..42f14534d4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -21,15 +21,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.messages.AssistantMessage; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; @@ -120,19 +117,6 @@ public ChatResponse 
adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context))); - }); - } - private List toDocuments(List messages, String conversationId) { List docs = messages.stream() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java index cbc57f6f18..08f7f9fba9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -53,37 +54,50 @@ public ObservableRequestResponseAdvisor(RequestResponseAdvisor targetAdvisor, } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map advisorRequestContext) { + public AdvisedRequest adviseRequest(AdvisedRequest request, Map advisorContext) { var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.BEFORE) .withAdvisedRequest(request) - .withAdvisorRequestContext(advisorRequestContext) + .withAdvisorRequestContext(advisorContext) .build(); return AdvisorObservationDocumentation.AI_ADVISOR 
.observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseRequest(request, advisorRequestContext)); + .observe(() -> this.targetAdvisor.adviseRequest(request, advisorContext)); } @Override - public ChatResponse adviseResponse(ChatResponse response, Map advisorResponseContext) { + public ChatResponse adviseResponse(ChatResponse response, Map advisorContext) { var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.AFTER) - .withAdvisorRequestContext(advisorResponseContext) + .withAdvisorRequestContext(advisorContext) .build(); return AdvisorObservationDocumentation.AI_ADVISOR .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseResponse(response, advisorResponseContext)); + .observe(() -> this.targetAdvisor.adviseResponse(response, advisorContext)); } @Override - public Flux adviseResponse(Flux fluxResponse, Map context) { + public StreamResponseMode getStreamResponseMode() { + return this.targetAdvisor.getStreamResponseMode(); + } - // NOTE: The reactive observation support is not available yet. 
- return this.targetAdvisor.adviseResponse(fluxResponse, context); + @Override + public Flux adviseResponse(Flux fluxResponse, Map advisorContext) { + + if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) { + return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, advisorContext)); + } + else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { + return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { + this.adviseResponse(chatResponse, advisorContext); + }); + } + + return this.targetAdvisor.adviseResponse(fluxResponse, advisorContext); } /** From fb5737e722267f6a4136efecbb6e754eacdf396e Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 30 Aug 2024 19:54:57 +0200 Subject: [PATCH 2/6] Fix observed advisor name resolution --- .../observation/ObservableRequestResponseAdvisor.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java index 08f7f9fba9..ad2598b848 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java @@ -53,6 +53,11 @@ public ObservableRequestResponseAdvisor(RequestResponseAdvisor targetAdvisor, this.customObservationConvention = customObservationConvention; } + @Override + public String getName() { + return this.targetAdvisor.getName(); + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map advisorContext) { From 131cbd49b8c6c047078301b751de0de703bafe64 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sat, 31 Aug 2024 20:44:44 +0200 Subject: [PATCH 3/6] Streamline repeating code ---
.../SimplePersistentVectorStoreIT.java | 5 +- .../ai/chat/client/DefaultChatClient.java | 114 +++++++----------- .../ai/chat/client/ChatClientTest.java | 7 +- 3 files changed, 48 insertions(+), 78 deletions(-) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java index 32ba4572cb..21ca5bc49b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java @@ -44,8 +44,11 @@ public class SimplePersistentVectorStoreIT { @Autowired private EmbeddingModel embeddingModel; + @TempDir(cleanup = CleanupMode.ON_SUCCESS) + Path workingDir; + @Test - void persist(@TempDir(cleanup = CleanupMode.ON_SUCCESS) Path workingDir) { + void persist() { JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index f71fe9beea..cfae8ee334 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -347,46 +347,7 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, context); - var processedUserText = StringUtils.hasText(formatParam) - ? 
advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.getUserText(); - - Map userParams = new HashMap<>(advisedRequest.getUserParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - - var messages = new ArrayList(advisedRequest.getMessages()); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.getSystemText())); - if (textsAreValid) { - if (StringUtils.hasText(advisedRequest.getSystemText()) - || !advisedRequest.getSystemParams().isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) - .render()); - messages.add(systemMessage); - } - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.getMedia()); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); - } - messages.add(userMessage); - } - - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); - } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); - } - } - var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + var prompt = toPrompt(advisedRequest, formatParam); var chatResponse = this.chatModel.call(prompt); @@ -412,6 +373,47 @@ public String content() { } + private static Prompt toPrompt(DefaultChatClientRequestSpec advisedRequest, String formatParam) { + + var messages = new ArrayList(advisedRequest.getMessages()); + + String processedSystemText = advisedRequest.getSystemText(); + if 
(StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(advisedRequest.getSystemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.getSystemParams()) + .render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + var processedUserText = StringUtils.hasText(formatParam) + ? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" + : advisedRequest.getUserText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(advisedRequest.getUserParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, advisedRequest.getMedia())); + } + + if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.getFunctionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + } + if (!advisedRequest.getFunctionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + } + } + + return new Prompt(messages, advisedRequest.getChatOptions()); + } + public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; @@ -465,41 +467,11 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in rwc.advisorContext); })) .single() - // .doOnNext(r -> System.out.println("Request: " + r)) .flatMapMany(r -> { DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(r.request, inputRequest.getObservationRegistry(), inputRequest.getCustomObservationConvention()); - var messages = new ArrayList(advisedRequest.getMessages()); - - String processedSystemText = 
advisedRequest.getSystemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(advisedRequest.getSystemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, - advisedRequest.getSystemParams()) - .render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - String processedUserText = advisedRequest.getUserText(); - if (StringUtils.hasText(processedUserText)) { - - Map userParams = new HashMap<>(advisedRequest.getUserParams()); - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); - } - messages.add(new UserMessage(processedUserText, advisedRequest.getMedia())); - } - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); - } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); - } - } - var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + var prompt = toPrompt(advisedRequest, null); Flux fluxChatResponse = this.chatModel.stream(prompt); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 86bd33fa05..abd86be21d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -471,16 +471,11 @@ public void simpleSystemPrompt() throws MalformedURLException { assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(2); + assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); Message systemMessage = 
promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System prompt"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - - // Is this expected? - Message userMessage = promptCaptor.getValue().getInstructions().get(1); - assertThat(userMessage.getContent()).isEqualTo(""); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test From ac82c5e84a321ca00e3c07e7732c70655ad45049 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 2 Sep 2024 11:29:37 +0200 Subject: [PATCH 4/6] Add advisor strategy for ON_FINISH_REASON streaming response. Used by the Q&A advisor --- .../ai/chat/client/DefaultChatClient.java | 6 +-- .../chat/client/RequestResponseAdvisor.java | 45 ++++++++++++++++--- .../client/advisor/QuestionAnswerAdvisor.java | 14 +----- .../ObservableRequestResponseAdvisor.java | 18 +++++++- 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index cfae8ee334..5e1c4eaada 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -467,8 +467,8 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in rwc.advisorContext); })) .single() - .flatMapMany(r -> { - DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(r.request, + .flatMapMany(rwc -> { + DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(rwc.request, inputRequest.getObservationRegistry(), inputRequest.getCustomObservationConvention()); var prompt = toPrompt(advisedRequest, null); @@ -483,8 +483,8 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in advisedResponse = advisor.adviseResponse(advisedResponse, 
advisorContext); } } - return advisedResponse; + }); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index ccefcf9983..498287f477 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -18,12 +18,13 @@ import java.util.Map; -import reactor.core.publisher.Flux; - +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.util.StringUtils; + +import reactor.core.publisher.Flux; /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and @@ -37,7 +38,26 @@ public interface RequestResponseAdvisor { public enum StreamResponseMode { - CHUNK, AGGREGATE, CUSTOM; + /** + * The sync advisor will be called for each response chunk (e.g. on each Flux + * item). + */ + PER_CHUNK, + /** + * The sync advisor is called only on chunks that contain a finish reason. Usually + * the last chunk in the stream. + */ + ON_FINISH_REASON, + /** + * The sync advisor is called only once after the stream is completed and an + * aggregated response is computed. Note that at that stage the advisor can not + * modify the response, but only observe it and react on the aggregated response. + */ + AGGREGATE, + /** + * Delegates to the stream advisor implementation. 
+ */ + CUSTOM; } @@ -85,7 +105,7 @@ default ChatResponse adviseResponse(ChatResponse response, Map c */ default Flux adviseResponse(Flux fluxResponse, Map context) { - if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) { + if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context)); } else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { @@ -93,6 +113,21 @@ else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { this.adviseResponse(chatResponse, context); }); } + else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { + return fluxResponse.map(chatResponse -> { + boolean withFinishReason = chatResponse.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + return this.adviseResponse(chatResponse, context); + } + return chatResponse; + }); + } return fluxResponse; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 9dda036704..2aeacd0fe3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -24,7 +24,6 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; @@ -34,8 +33,6 @@ import 
org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * Context for the question is retrieved from a Vector Store and added to the prompt's * user text. @@ -133,14 +130,6 @@ public ChatResponse adviseResponse(ChatResponse response, Map co return chatResponseBuilder.build(); } - @Override - public Flux adviseResponse(Flux fluxChatResponse, Map context) { - // return fluxResponse.map(cr -> adviseResponse(cr, context)); - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - this.adviseResponse(chatResponse, context); - }); - } - protected Filter.Expression doGetFilterExpression(Map context) { if (!context.containsKey(FILTER_EXPRESSION) @@ -153,8 +142,7 @@ protected Filter.Expression doGetFilterExpression(Map context) { @Override public StreamResponseMode getStreamResponseMode() { - // return StreamResponseMode.CHUNK; - return StreamResponseMode.AGGREGATE; + return StreamResponseMode.ON_FINISH_REASON; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java index ad2598b848..6be23f2f2a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import io.micrometer.observation.ObservationRegistry; import reactor.core.publisher.Flux; @@ -93,7 +94,7 @@ public StreamResponseMode getStreamResponseMode() { @Override public Flux adviseResponse(Flux fluxResponse, Map 
advisorContext) { - if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) { + if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, advisorContext)); } else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { @@ -101,6 +102,21 @@ else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { this.adviseResponse(chatResponse, advisorContext); }); } + else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { + return fluxResponse.map(chatResponse -> { + boolean withFinishReason = chatResponse.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + return this.adviseResponse(chatResponse, advisorContext); + } + return chatResponse; + }); + } return this.targetAdvisor.adviseResponse(fluxResponse, advisorContext); } From 1019b17e4d55d7daed5eab2dc38666760f16401b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 3 Sep 2024 14:23:18 +0200 Subject: [PATCH 5/6] code cleaning --- .../ai/chat/client/DefaultChatClient.java | 48 ++++++++----------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5e1c4eaada..520b02c8f1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -340,21 +340,33 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in } - private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) { + private ChatResponse 
doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam) { Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.getAdvisorParams()); - DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, - context); + context.putAll(inputRequestSpec.getAdvisorParams()); - var prompt = toPrompt(advisedRequest, formatParam); + DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; + if (!CollectionUtils.isEmpty(inputRequestSpec.advisors)) { + + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); + + // apply the advisors onRequest + var currentAdvisors = new ArrayList<>(inputRequestSpec.advisors); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedRequest = advisor.adviseRequest(advisedRequest, context); + } + advisedRequestSpec = toDefaultChatClientRequestSpec(advisedRequest, + inputRequestSpec.getObservationRegistry(), inputRequestSpec.getCustomObservationConvention()); + } + + var prompt = toPrompt(advisedRequestSpec, formatParam); var chatResponse = this.chatModel.call(prompt); ChatResponse advisedResponse = chatResponse; // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequestSpec.getAdvisors()); for (RequestResponseAdvisor advisor : currentAdvisors) { advisedResponse = advisor.adviseResponse(advisedResponse, context); } @@ -373,7 +385,7 @@ public String content() { } - private static Prompt toPrompt(DefaultChatClientRequestSpec advisedRequest, String formatParam) { + public static Prompt toPrompt(DefaultChatClientRequestSpec advisedRequest, String formatParam) { var messages = new ArrayList(advisedRequest.getMessages()); @@ -800,24 +812,6 @@ public StreamResponseSpec stream() { return new 
DefaultStreamResponseSpec(chatModel, this); } - public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest, - Map context) { - - if (CollectionUtils.isEmpty(inputRequest.advisors)) { - return inputRequest; - } - - AdvisedRequest advisedRequest = toAdvisedRequest(inputRequest); - - // apply the advisors onRequest - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedRequest = advisor.adviseRequest(advisedRequest, context); - } - return toDefaultChatClientRequestSpec(advisedRequest, inputRequest.getObservationRegistry(), - inputRequest.getCustomObservationConvention()); - } - } private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { @@ -827,7 +821,7 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams); } - private static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, + public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), From 7f025d79bca2626f9c308628d80e7c380a64fe3c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 3 Sep 2024 17:03:50 +0200 Subject: [PATCH 6/6] Pass the parent observation to the advisor observation instrumentation --- .../ai/chat/client/DefaultChatClient.java | 59 ++++---- .../observation/AdvisorObservableHelper.java | 102 +++++++++++++ .../ObservableRequestResponseAdvisor.java | 138 ------------------ 3 files changed, 128 insertions(+), 171 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java delete
mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 520b02c8f1..18f1e4b448 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -28,8 +28,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention; -import org.springframework.ai.chat.client.advisor.observation.ObservableRequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; @@ -330,17 +329,18 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, formatParam, false); - return ChatClientObservationDocumentation.AI_CHAT_CLIENT - .observation(inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.observationRegistry) - .observe(() -> { - ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam); - return chatResponse; - }); + var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( + inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, + () -> observationContext, inputRequest.observationRegistry); + return 
observation.observe(() -> { + ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam, observation); + return chatResponse; + }); } - private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam) { + private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, + Observation parentObservation) { Map context = new ConcurrentHashMap<>(); context.putAll(inputRequestSpec.getAdvisorParams()); @@ -353,7 +353,8 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest // apply the advisors onRequest var currentAdvisors = new ArrayList<>(inputRequestSpec.advisors); for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedRequest = advisor.adviseRequest(advisedRequest, context); + advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest, + context); } advisedRequestSpec = toDefaultChatClientRequestSpec(advisedRequest, inputRequestSpec.getObservationRegistry(), inputRequestSpec.getCustomObservationConvention()); @@ -368,7 +369,9 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { var currentAdvisors = new ArrayList<>(inputRequestSpec.getAdvisors()); for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse, context); + // advisedResponse = advisor.adviseResponse(advisedResponse, context); } } @@ -451,7 +454,7 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ .start(); // @formatter:off - return doGetFluxChatResponse(inputRequest) + return doGetFluxChatResponse(inputRequest, observation) .doOnError(observation::error) .doFinally(s -> { observation.stop(); @@ -464,7 +467,8 @@ private Flux 
doGetObservableFluxChatResponse(DefaultChatClientRequ record AdvisedRequestWithContext(AdvisedRequest request, Map advisorContext) { } - private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { + private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest, + Observation parentObservation) { Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); @@ -475,8 +479,9 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in // This allows us to call blocking code in reduce .publishOn(Schedulers.boundedElastic()) .reduce(reqWithContext, (rwc, advisor) -> { - return new AdvisedRequestWithContext(advisor.adviseRequest(rwc.request, rwc.advisorContext), - rwc.advisorContext); + AdvisedRequest advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, + advisor, rwc.request, rwc.advisorContext); + return new AdvisedRequestWithContext(advisedRequest, rwc.advisorContext); })) .single() .flatMapMany(rwc -> { @@ -492,7 +497,8 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, advisorContext); + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse, advisorContext); } } return advisedResponse; @@ -656,35 +662,22 @@ public ChatClientRequestSpec advisors(Consumer consumer) var as = new DefaultAdvisorSpec(); consumer.accept(as); this.advisorParams.putAll(as.getParams()); - this.advisors.addAll(toObservableAdvisors(as.getAdvisors(), this.observationRegistry, null)); + this.advisors.addAll(as.getAdvisors()); return this; } public ChatClientRequestSpec advisors(RequestResponseAdvisor... 
advisors) { Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(toObservableAdvisors(List.of(advisors), this.observationRegistry, null)); + this.advisors.addAll(Arrays.asList(advisors)); return this; } public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(toObservableAdvisors(advisors, this.observationRegistry, null)); + this.advisors.addAll(advisors); return this; } - private List toObservableAdvisors(List advisors, - ObservationRegistry observationRegistry, AdvisorObservationConvention customObservationConvention) { - if (CollectionUtils.isEmpty(advisors)) { - return advisors; - } - List observableAdvisors = new ArrayList<>(); - for (RequestResponseAdvisor advisor : advisors) { - observableAdvisors.add(new ObservableRequestResponseAdvisor(advisor, observationRegistry, - customObservationConvention)); - } - return observableAdvisors; - } - public ChatClientRequestSpec messages(Message... messages) { Assert.notNull(messages, "the messages must be non-null"); this.messages.addAll(List.of(messages)); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java new file mode 100644 index 0000000000..2f94da6ac5 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java @@ -0,0 +1,102 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.observation; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.RequestResponseAdvisor.StreamResponseMode; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.util.StringUtils; + +import io.micrometer.observation.Observation; +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class AdvisorObservableHelper { + + private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); + + public static AdvisedRequest adviseRequest(Observation parentObservation, RequestResponseAdvisor advisor, + AdvisedRequest advisedRequest, Map advisorContext) { + + var observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.BEFORE) + .withAdvisedRequest(advisedRequest) + .withAdvisorRequestContext(advisorContext) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + parentObservation.getObservationRegistry()) + .parentObservation(parentObservation) + .observe(() -> advisor.adviseRequest(advisedRequest, advisorContext)); + } + + public static ChatResponse adviseResponse(Observation parentObservation, 
RequestResponseAdvisor advisor, + ChatResponse response, Map advisorContext) { + + var observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.AFTER) + .withAdvisorRequestContext(advisorContext) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + parentObservation.getObservationRegistry()) + .parentObservation(parentObservation) + .observe(() -> advisor.adviseResponse(response, advisorContext)); + } + + public static Flux adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor, + Flux fluxResponse, Map advisorContext) { + + if (advisor.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { + return fluxResponse + .map(chatResponse -> adviseResponse(parentObservation, advisor, chatResponse, advisorContext)); + } + else if (advisor.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { + return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { + adviseResponse(parentObservation, advisor, chatResponse, advisorContext); + }); + } + else if (advisor.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { + return fluxResponse.map(chatResponse -> { + boolean withFinishReason = chatResponse.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + return adviseResponse(parentObservation, advisor, chatResponse, advisorContext); + } + return chatResponse; + }); + } + + return advisor.adviseResponse(fluxResponse, advisorContext); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java deleted file mode 100644 index 6be23f2f2a..0000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java +++ /dev/null @@ -1,138 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.chat.client.advisor.observation; - -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.RequestResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; - -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class ObservableRequestResponseAdvisor implements RequestResponseAdvisor { - - private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); - - private final RequestResponseAdvisor targetAdvisor; - - private final ObservationRegistry observationRegistry; - - private final AdvisorObservationConvention customObservationConvention; - - public 
ObservableRequestResponseAdvisor(RequestResponseAdvisor targetAdvisor, - ObservationRegistry observationRegistry, - @Nullable AdvisorObservationConvention customObservationConvention) { - - Assert.notNull(targetAdvisor, "TargetAdvisor must not be null"); - Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); - - this.targetAdvisor = targetAdvisor; - this.observationRegistry = observationRegistry; - this.customObservationConvention = customObservationConvention; - } - - @Override - public String getName() { - return this.targetAdvisor.getName(); - } - - @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map advisorContext) { - - var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.BEFORE) - .withAdvisedRequest(request) - .withAdvisorRequestContext(advisorContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseRequest(request, advisorContext)); - } - - @Override - public ChatResponse adviseResponse(ChatResponse response, Map advisorContext) { - - var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.AFTER) - .withAdvisorRequestContext(advisorContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseResponse(response, advisorContext)); - } - - @Override - public StreamResponseMode getStreamResponseMode() { - return this.targetAdvisor.getStreamResponseMode(); - } - - @Override - public Flux adviseResponse(Flux fluxResponse, Map advisorContext) { - - if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { - return fluxResponse.map(chatResponse -> 
this.adviseResponse(chatResponse, advisorContext)); - } - else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { - return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { - this.adviseResponse(chatResponse, advisorContext); - }); - } - else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { - return fluxResponse.map(chatResponse -> { - boolean withFinishReason = chatResponse.getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - - if (withFinishReason) { - return this.adviseResponse(chatResponse, advisorContext); - } - return chatResponse; - }); - } - - return this.targetAdvisor.adviseResponse(fluxResponse, advisorContext); - } - - /** - * Create the AdvisorObservationContext.Builder for the given advisorType. Can be - * overridden by the concrete advisor to provide additional context information. - * @param advisorType the advisor type. - * @return the AdvisorObservationContext.Builder instance. - */ - public AdvisorObservationContext.Builder doCreateObservationContextBuilder( - AdvisorObservationContext.Type advisorType) { - - return AdvisorObservationContext.builder() - .withAdvisorName(this.targetAdvisor.getName()) - .withAdvisorType(advisorType); - } - -}