diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java index 32ba4572cb..21ca5bc49b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java @@ -44,8 +44,11 @@ public class SimplePersistentVectorStoreIT { @Autowired private EmbeddingModel embeddingModel; + @TempDir(cleanup = CleanupMode.ON_SUCCESS) + Path workingDir; + @Test - void persist(@TempDir(cleanup = CleanupMode.ON_SUCCESS) Path workingDir) { + void persist() { JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index c1cd207446..18f1e4b448 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -28,8 +28,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention; -import org.springframework.ai.chat.client.advisor.observation.ObservableRequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; @@ -60,6 +59,7 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; /** * The default implementation of {@link ChatClient} as created by the @@ -329,71 +329,49 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, formatParam, false); - return ChatClientObservationDocumentation.AI_CHAT_CLIENT - .observation(inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, - () -> observationContext, inputRequest.observationRegistry) - .observe(() -> { - ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam); - return chatResponse; - }); + var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( + inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, + () -> observationContext, inputRequest.observationRegistry); + return observation.observe(() -> { + ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam, observation); + return chatResponse; + }); } - private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) { + private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, + Observation parentObservation) { Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.getAdvisorParams()); - DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, - context); + context.putAll(inputRequestSpec.getAdvisorParams()); - var processedUserText = StringUtils.hasText(formatParam) - ? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.getUserText(); + DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; + if (!CollectionUtils.isEmpty(inputRequestSpec.advisors)) { - Map userParams = new HashMap<>(advisedRequest.getUserParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); - var messages = new ArrayList(advisedRequest.getMessages()); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.getSystemText())); - if (textsAreValid) { - if (StringUtils.hasText(advisedRequest.getSystemText()) - || !advisedRequest.getSystemParams().isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) - .render()); - messages.add(systemMessage); - } - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.getMedia()); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); + // apply the advisors onRequest + var currentAdvisors = new ArrayList<>(inputRequestSpec.advisors); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest, + context); } - messages.add(userMessage); + advisedRequestSpec = toDefaultChatClientRequestSpec(advisedRequest, + inputRequestSpec.getObservationRegistry(), inputRequestSpec.getCustomObservationConvention()); } - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); - } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); - } - } - var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + var prompt = toPrompt(advisedRequestSpec, formatParam); + var chatResponse = this.chatModel.call(prompt); ChatResponse advisedResponse = chatResponse; // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequestSpec.getAdvisors()); for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse, context); + // advisedResponse = advisor.adviseResponse(advisedResponse, context); } } @@ -410,6 +388,47 @@ public String content() { } + public static Prompt toPrompt(DefaultChatClientRequestSpec advisedRequest, String formatParam) { + + var messages = new ArrayList(advisedRequest.getMessages()); + + String processedSystemText = advisedRequest.getSystemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(advisedRequest.getSystemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.getSystemParams()) + .render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + var processedUserText = StringUtils.hasText(formatParam) + ? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" + : advisedRequest.getUserText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(advisedRequest.getUserParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, advisedRequest.getMedia())); + } + + if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.getFunctionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + } + if (!advisedRequest.getFunctionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + } + } + + return new Prompt(messages, advisedRequest.getChatOptions()); + } + public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; @@ -421,8 +440,9 @@ public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSp this.request = request; } - private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { + private Flux doGetObservableFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { return Flux.deferContextual(contextView -> { + ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, "", true); @@ -433,8 +453,8 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); - // @formatter:off - return doGetFluxChatResponse2(inputRequest) + // @formatter:off + return doGetFluxChatResponse(inputRequest, observation) .doOnError(observation::error) .doFinally(s -> { observation.stop(); @@ -444,68 +464,54 @@ private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec in }); } - private Flux doGetFluxChatResponse2(DefaultChatClientRequestSpec inputRequest) { + record AdvisedRequestWithContext(AdvisedRequest request, Map advisorContext) { + } - Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.getAdvisorParams()); - DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, - context); + private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest, + Observation parentObservation) { - String processedUserText = advisedRequest.getUserText(); - Map userParams = new HashMap<>(advisedRequest.getUserParams()); + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - var messages = new ArrayList(advisedRequest.getMessages()); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.getSystemText())); - if (textsAreValid) { - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.getMedia()); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); - } - if (StringUtils.hasText(advisedRequest.getSystemText()) - || !advisedRequest.getSystemParams().isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) - .render()); - messages.add(systemMessage); - } - messages.add(userMessage); - } + var reqWithContext = new AdvisedRequestWithContext(toAdvisedRequest(inputRequest), advisorContext); - if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.getFunctionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); - } - if (!advisedRequest.getFunctionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); - } - } - var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + return Flux.fromIterable(inputRequest.advisors) + .transformDeferredContextual((f, ctx) -> f + // This allows us to call blocking code in reduce + .publishOn(Schedulers.boundedElastic()) + .reduce(reqWithContext, (rwc, advisor) -> { + AdvisedRequest advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, + advisor, rwc.request, rwc.advisorContext); + return new AdvisedRequestWithContext(advisedRequest, rwc.advisorContext); + })) + .single() + .flatMapMany(rwc -> { + DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(rwc.request, + inputRequest.getObservationRegistry(), inputRequest.getCustomObservationConvention()); - var fluxChatResponse = this.chatModel.stream(prompt); + var prompt = toPrompt(advisedRequest, null); - Flux advisedResponse = fluxChatResponse; - // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); - } - } + Flux fluxChatResponse = this.chatModel.stream(prompt); - return advisedResponse; + Flux advisedResponse = fluxChatResponse; + // apply the advisors on response + if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse, advisorContext); + } + } + return advisedResponse; + + }); } public Flux chatResponse() { - return doGetFluxChatResponse(this.request); + return doGetObservableFluxChatResponse(this.request); } public Flux content() { - return doGetFluxChatResponse(this.request).map(r -> { + return doGetObservableFluxChatResponse(this.request).map(r -> { if (r.getResult() == null || r.getResult().getOutput() == null || r.getResult().getOutput().getContent() == null) { return ""; @@ -656,35 +662,22 @@ public ChatClientRequestSpec advisors(Consumer consumer) var as = new DefaultAdvisorSpec(); consumer.accept(as); this.advisorParams.putAll(as.getParams()); - this.advisors.addAll(toObservableAdvisors(as.getAdvisors(), this.observationRegistry, null)); + this.advisors.addAll(as.getAdvisors()); return this; } public ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors) { Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(toObservableAdvisors(List.of(advisors), this.observationRegistry, null)); + this.advisors.addAll(Arrays.asList(advisors)); return this; } public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(toObservableAdvisors(advisors, this.observationRegistry, null)); + this.advisors.addAll(advisors); return this; } - private List toObservableAdvisors(List advisors, - ObservationRegistry observationRegistry, AdvisorObservationConvention customObservationConvention) { - if (CollectionUtils.isEmpty(advisors)) { - return advisors; - } - List observableAdvisors = new ArrayList<>(); - for (RequestResponseAdvisor advisor : advisors) { - observableAdvisors.add(new ObservableRequestResponseAdvisor(advisor, observationRegistry, - customObservationConvention)); - } - return observableAdvisors; - } - public ChatClientRequestSpec messages(Message... messages) { Assert.notNull(messages, "the messages must be non-null"); this.messages.addAll(List.of(messages)); @@ -812,35 +805,23 @@ public StreamResponseSpec stream() { return new DefaultStreamResponseSpec(chatModel, this); } - public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest, - Map context) { - - DefaultChatClientRequestSpec advisedRequest = inputRequest; - - if (!CollectionUtils.isEmpty(inputRequest.advisors)) { - AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, - inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, - inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, - inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, - inputRequest.advisorParams); - - // apply the advisors onRequest - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - adviseRequest = advisor.adviseRequest(adviseRequest, context); - } + } - advisedRequest = new DefaultChatClientRequestSpec(adviseRequest.chatModel(), adviseRequest.userText(), - adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(), - adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(), - adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(), - adviseRequest.advisorParams(), inputRequest.getObservationRegistry(), - inputRequest.getCustomObservationConvention()); - } + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { + return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, + inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, + inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams); + } - return advisedRequest; - } + public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, + ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { + return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), + advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), + advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), + advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), + advisedRequest.advisorParams(), observationRegistry, customObservationConvention); } // Prompt diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index a696c66c6a..498287f477 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -18,11 +18,13 @@ import java.util.Map; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.util.StringUtils; + +import reactor.core.publisher.Flux; /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and @@ -34,6 +36,35 @@ */ public interface RequestResponseAdvisor { + public enum StreamResponseMode { + + /** + * The sync advisor will be called for each response chunk (e.g. on each Flux + * item). + */ + PER_CHUNK, + /** + * The sync advisor is called only on chunks that contain a finish reason. Usually + * the last chunk in the stream. + */ + ON_FINISH_REASON, + /** + * The sync advisor is called only once after the stream is completed and an + * aggregated response is computed. Note that at that stage the advisor can not + * modify the response, but only observe it and react on the aggregated response. + */ + AGGREGATE, + /** + * Delegates to the stream advisor implementation. + */ + CUSTOM; + + } + + default StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.CUSTOM; + } + /** * @return the advisor name. */ @@ -73,6 +104,31 @@ default ChatResponse adviseResponse(ChatResponse response, Map c * @return the advised {@link ChatResponse} flux. */ default Flux adviseResponse(Flux fluxResponse, Map context) { + + if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { + return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context)); + } + else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { + return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { + this.adviseResponse(chatResponse, context); + }); + } + else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { + return fluxResponse.map(chatResponse -> { + boolean withFinishReason = chatResponse.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + return this.adviseResponse(chatResponse, context); + } + return chatResponse; + }); + } + return fluxResponse; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 06a48e7a04..85b354ce04 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -59,6 +59,11 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize; } + @Override + public StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.AGGREGATE; + } + protected T getChatMemoryStore() { return this.chatMemoryStore; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index bf22231ceb..5f125f0bd3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -20,14 +20,11 @@ import java.util.List; import java.util.Map; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; /** * Memory is retrieved added as a collection of messages to the prompt @@ -79,17 +76,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); - }); - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index 400fb89f63..19af746d60 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -21,16 +21,13 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.model.Content; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.model.Content; /** * Memory is retrieved added into the prompt's system text. @@ -109,17 +106,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); - }); - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index dc8747fff5..2aeacd0fe3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -33,8 +33,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * Context for the question is retrieved from a Vector Store and added to the prompt's * user text. @@ -132,15 +130,6 @@ public ChatResponse adviseResponse(ChatResponse response, Map co return chatResponseBuilder.build(); } - @Override - public Flux adviseResponse(Flux fluxResponse, Map context) { - return fluxResponse.map(cr -> { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr); - chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return chatResponseBuilder.build(); - }); - } - protected Filter.Expression doGetFilterExpression(Map context) { if (!context.containsKey(FILTER_EXPRESSION) @@ -151,4 +140,9 @@ protected Filter.Expression doGetFilterExpression(Map context) { } + @Override + public StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.ON_FINISH_REASON; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index d09141720d..dc21c3fcc7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -23,11 +23,8 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; -import reactor.core.publisher.Flux; - /** * A simple logger advisor that logs the request and response messages. * @@ -65,12 +62,6 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map return request; } - @Override - public Flux adviseResponse(Flux fluxChatResponse, Map context) { - return new MessageAggregator().aggregate(fluxChatResponse, - chatResponse -> logger.debug("stream response: {}", this.responseToString.apply(chatResponse))); - } - @Override public ChatResponse adviseResponse(ChatResponse response, Map context) { logger.debug("response: {}", this.responseToString.apply(response)); @@ -82,4 +73,9 @@ public String toString() { return SimpleLoggerAdvisor.class.getSimpleName(); } + @Override + public StreamResponseMode getStreamResponseMode() { + return StreamResponseMode.AGGREGATE; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index 52e30ecdf7..42f14534d4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -21,15 +21,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.messages.AssistantMessage; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; @@ -120,19 +117,6 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map adviseResponse(Flux fluxChatResponse, Map context) { - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - List assistantMessages = chatResponse.getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context))); - }); - } - private List toDocuments(List messages, String conversationId) { List docs = messages.stream() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java new file mode 100644 index 0000000000..2f94da6ac5 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java @@ -0,0 +1,102 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.observation; + +import java.util.Map; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.RequestResponseAdvisor.StreamResponseMode; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.util.StringUtils; + +import io.micrometer.observation.Observation; +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class AdvisorObservableHelper { + + private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); + + public static AdvisedRequest adviseRequest(Observation parentObservation, RequestResponseAdvisor advisor, + AdvisedRequest advisedRequest, Map advisorContext) { + + var observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.BEFORE) + .withAdvisedRequest(advisedRequest) + .withAdvisorRequestContext(advisorContext) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + parentObservation.getObservationRegistry()) + .parentObservation(parentObservation) + .observe(() -> advisor.adviseRequest(advisedRequest, advisorContext)); + } + + public static ChatResponse adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor, + ChatResponse response, Map advisorContext) { + + var observationContext = AdvisorObservationContext.builder() + .withAdvisorName(advisor.getName()) + .withAdvisorType(AdvisorObservationContext.Type.AFTER) + .withAdvisorRequestContext(advisorContext) + .build(); + + return AdvisorObservationDocumentation.AI_ADVISOR + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + parentObservation.getObservationRegistry()) + .parentObservation(parentObservation) + .observe(() -> advisor.adviseResponse(response, advisorContext)); + } + + public static Flux adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor, + Flux fluxResponse, Map advisorContext) { + + if (advisor.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) { + return fluxResponse + .map(chatResponse -> adviseResponse(parentObservation, advisor, chatResponse, advisorContext)); + } + else if (advisor.getStreamResponseMode() == StreamResponseMode.AGGREGATE) { + return new MessageAggregator().aggregate(fluxResponse, chatResponse -> { + adviseResponse(parentObservation, advisor, chatResponse, advisorContext); + }); + } + else if (advisor.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) { + return fluxResponse.map(chatResponse -> { + boolean withFinishReason = chatResponse.getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + return adviseResponse(parentObservation, advisor, chatResponse, advisorContext); + } + return chatResponse; + }); + } + + return advisor.adviseResponse(fluxResponse, advisorContext); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java deleted file mode 100644 index cbc57f6f18..0000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/ObservableRequestResponseAdvisor.java +++ /dev/null @@ -1,103 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.chat.client.advisor.observation; - -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.RequestResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; - -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class ObservableRequestResponseAdvisor implements RequestResponseAdvisor { - - private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); - - private final RequestResponseAdvisor targetAdvisor; - - private final ObservationRegistry observationRegistry; - - private final AdvisorObservationConvention customObservationConvention; - - public ObservableRequestResponseAdvisor(RequestResponseAdvisor targetAdvisor, - ObservationRegistry observationRegistry, - @Nullable AdvisorObservationConvention customObservationConvention) { - - Assert.notNull(targetAdvisor, "TargetAdvisor must not be null"); - Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); - - this.targetAdvisor = targetAdvisor; - this.observationRegistry = observationRegistry; - this.customObservationConvention = customObservationConvention; - } - - @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map advisorRequestContext) { - - var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.BEFORE) - .withAdvisedRequest(request) - .withAdvisorRequestContext(advisorRequestContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseRequest(request, advisorRequestContext)); - } - - @Override - public ChatResponse adviseResponse(ChatResponse response, Map advisorResponseContext) { - - var observationContext = this.doCreateObservationContextBuilder(AdvisorObservationContext.Type.AFTER) - .withAdvisorRequestContext(advisorResponseContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> this.targetAdvisor.adviseResponse(response, advisorResponseContext)); - } - - @Override - public Flux adviseResponse(Flux fluxResponse, Map context) { - - // NOTE: The reactive observation support is not available yet. - return this.targetAdvisor.adviseResponse(fluxResponse, context); - } - - /** - * Create the AdvisorObservationContext.Builder for the given advisorType. Can be - * overridden by the concrete advisor to provide additional context information. - * @param advisorType the advisor type. - * @return the AdvisorObservationContext.Builder instance. - */ - public AdvisorObservationContext.Builder doCreateObservationContextBuilder( - AdvisorObservationContext.Type advisorType) { - - return AdvisorObservationContext.builder() - .withAdvisorName(this.targetAdvisor.getName()) - .withAdvisorType(advisorType); - } - -} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 86bd33fa05..abd86be21d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -471,16 +471,11 @@ public void simpleSystemPrompt() throws MalformedURLException { assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(2); + assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); Message systemMessage = promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System prompt"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - - // Is this expected? - Message userMessage = promptCaptor.getValue().getInstructions().get(1); - assertThat(userMessage.getContent()).isEqualTo(""); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test