Handling stream advisor responses #1293

Closed
SimplePersistentVectorStoreIT.java
@@ -44,8 +44,11 @@ public class SimplePersistentVectorStoreIT {
@Autowired
private EmbeddingModel embeddingModel;

@TempDir(cleanup = CleanupMode.ON_SUCCESS)
Path workingDir;

@Test
void persist(@TempDir(cleanup = CleanupMode.ON_SUCCESS) Path workingDir) {
void persist() {
JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name",
"shortDescription", "description", "tags");
List<Document> documents = jsonReader.get();

Large diffs are not rendered by default.

RequestResponseAdvisor.java
@@ -18,11 +18,13 @@

import java.util.Map;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.util.StringUtils;

import reactor.core.publisher.Flux;

/**
* Advisor called before and after the {@link ChatModel#call(Prompt)} and
@@ -34,6 +36,35 @@
*/
public interface RequestResponseAdvisor {

public enum StreamResponseMode {

/**
* The sync advisor will be called for each response chunk (e.g. on each Flux
* item).
*/
PER_CHUNK,
/**
* The sync advisor is called only on chunks that contain a finish reason. Usually
* the last chunk in the stream.
*/
ON_FINISH_REASON,
/**
* The sync advisor is called only once after the stream is completed and an
* aggregated response is computed. Note that at that stage the advisor can not
* modify the response, but only observe it and react on the aggregated response.
*/
AGGREGATE,
/**
* Delegates to the stream advisor implementation.
*/
CUSTOM;

}

default StreamResponseMode getStreamResponseMode() {
return StreamResponseMode.CUSTOM;
}

/**
* @return the advisor name.
*/
@@ -73,6 +104,31 @@ default ChatResponse adviseResponse(ChatResponse response, Map<String, Object> c
* @return the advised {@link ChatResponse} flux.
*/
default Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {

if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context));
}
else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
this.adviseResponse(chatResponse, context);
});
}
else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
return fluxResponse.map(chatResponse -> {
boolean withFinishReason = chatResponse.getResults()
.stream()
.filter(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
.findFirst()
.isPresent();

if (withFinishReason) {
return this.adviseResponse(chatResponse, context);
}
return chatResponse;
});
}

return fluxResponse;
}

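For context, a minimal sketch (not part of this change) of a custom advisor that opts into the new PER_CHUNK mode and relies on the default adviseResponse(Flux, Map) shown above to fan the sync callback out over the stream. The class name and body are hypothetical; only the interface methods visible in this hunk are assumed.

import java.util.Map;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;

// Hypothetical advisor, for illustration only.
public class ChunkLoggingAdvisor implements RequestResponseAdvisor {

	@Override
	public String getName() {
		return "ChunkLoggingAdvisor";
	}

	@Override
	public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
		return request; // pass the request through untouched
	}

	@Override
	public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
		// With PER_CHUNK this is invoked for every Flux item of a streaming call.
		System.out.println("chunk: " + response);
		return response;
	}

	@Override
	public StreamResponseMode getStreamResponseMode() {
		// No need to override adviseResponse(Flux, Map); the default dispatch handles it.
		return StreamResponseMode.PER_CHUNK;
	}

}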
AbstractChatMemoryAdvisor.java
@@ -59,6 +59,11 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
}

@Override
public StreamResponseMode getStreamResponseMode() {
return StreamResponseMode.AGGREGATE;
}

protected T getChatMemoryStore() {
return this.chatMemoryStore;
}
MessageChatMemoryAdvisor.java
@@ -20,14 +20,11 @@
import java.util.List;
import java.util.Map;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;

/**
* Memory is retrieved added as a collection of messages to the prompt
@@ -79,17 +76,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
return chatResponse;
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {

return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
List<Message> assistantMessages = chatResponse.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();

this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages);
});
}

}
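A usage sketch (not part of this diff) of what the memory advisors now do on streaming calls: because AbstractChatMemoryAdvisor reports AGGREGATE, the assistant messages are written to memory once, from the aggregated response, after the Flux completes. The ChatClient fluent calls and InMemoryChatMemory are assumed to match the milestone API of this period.

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;

import reactor.core.publisher.Flux;

class StreamingMemoryExample {

	Flux<String> streamWithMemory(ChatModel chatModel) {
		ChatClient chatClient = ChatClient.builder(chatModel)
				.defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory()))
				.build();

		// The advisor no longer overrides adviseResponse(Flux, Map); AGGREGATE mode
		// invokes the sync adviseResponse once, with the aggregated response.
		return chatClient.prompt()
				.user("Tell me a joke")
				.stream()
				.content();
	}

}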
PromptChatMemoryAdvisor.java
@@ -21,16 +21,13 @@
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.model.Content;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.model.Content;

/**
* Memory is retrieved added into the prompt's system text.
@@ -109,17 +106,4 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
return chatResponse;
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {

return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
List<Message> assistantMessages = chatResponse.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();

this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages);
});
}

}
QuestionAnswerAdvisor.java
@@ -33,8 +33,6 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import reactor.core.publisher.Flux;

/**
* Context for the question is retrieved from a Vector Store and added to the prompt's
* user text.
@@ -132,15 +130,6 @@ public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> co
return chatResponseBuilder.build();
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
return fluxResponse.map(cr -> {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return chatResponseBuilder.build();
});
}

protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

if (!context.containsKey(FILTER_EXPRESSION)
@@ -151,4 +140,9 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

}

@Override
public StreamResponseMode getStreamResponseMode() {
return StreamResponseMode.ON_FINISH_REASON;

Review comment (Member): why not AGGREGATE?

}

}
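A sketch of the consumer side of the ON_FINISH_REASON choice above (not part of this diff): during streaming, only the chunk that carries a finish reason gets the RETRIEVED_DOCUMENTS metadata attached, so a caller would inspect the final ChatResponse. The fluent ChatClient calls and the map-style metadata access are assumptions, not shown in this PR.

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import reactor.core.publisher.Flux;

class StreamingRagExample {

	void printRetrievedDocuments(ChatClient chatClient, VectorStore vectorStore) {
		Flux<ChatResponse> stream = chatClient.prompt()
				.user("What bikes do you sell?")
				.advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()))
				.stream()
				.chatResponse();

		// Only the final chunk (the one with a finish reason) carries the metadata,
		// assuming ChatResponseMetadata exposes map-style get(String) access.
		stream.last()
				.subscribe(last -> System.out.println("retrieved documents: "
						+ last.getMetadata().get(QuestionAnswerAdvisor.RETRIEVED_DOCUMENTS)));
	}

}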
SimpleLoggerAdvisor.java
@@ -23,11 +23,8 @@
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.model.ModelOptionsUtils;

import reactor.core.publisher.Flux;

/**
* A simple logger advisor that logs the request and response messages.
*
@@ -65,12 +62,6 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
return request;
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
return new MessageAggregator().aggregate(fluxChatResponse,
chatResponse -> logger.debug("stream response: {}", this.responseToString.apply(chatResponse)));
}

@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
logger.debug("response: {}", this.responseToString.apply(response));
@@ -82,4 +73,9 @@ public String toString() {
return SimpleLoggerAdvisor.class.getSimpleName();
}

@Override
public StreamResponseMode getStreamResponseMode() {
return StreamResponseMode.AGGREGATE;
}

}
VectorStoreChatMemoryAdvisor.java
@@ -21,15 +21,12 @@
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.chat.messages.AssistantMessage;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
import org.springframework.ai.vectorstore.SearchRequest;
@@ -120,19 +117,6 @@ public ChatResponse adviseResponse(ChatResponse chatResponse, Map<String, Object
return chatResponse;
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {

return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
List<Message> assistantMessages = chatResponse.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();

this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context)));
});
}

private List<Document> toDocuments(List<Message> messages, String conversationId) {

List<Document> docs = messages.stream()
AdvisorObservableHelper.java (new file)
@@ -0,0 +1,102 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.observation;

import java.util.Map;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.RequestResponseAdvisor.StreamResponseMode;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.util.StringUtils;

import io.micrometer.observation.Observation;
import reactor.core.publisher.Flux;

/**
* @author Christian Tzolov
* @since 1.0.0
*/
public class AdvisorObservableHelper {

private static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();

public static AdvisedRequest adviseRequest(Observation parentObservation, RequestResponseAdvisor advisor,
AdvisedRequest advisedRequest, Map<String, Object> advisorContext) {

var observationContext = AdvisorObservationContext.builder()
.withAdvisorName(advisor.getName())
.withAdvisorType(AdvisorObservationContext.Type.BEFORE)
.withAdvisedRequest(advisedRequest)
.withAdvisorRequestContext(advisorContext)
.build();

return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
parentObservation.getObservationRegistry())
.parentObservation(parentObservation)
.observe(() -> advisor.adviseRequest(advisedRequest, advisorContext));
}

public static ChatResponse adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor,
ChatResponse response, Map<String, Object> advisorContext) {

var observationContext = AdvisorObservationContext.builder()
.withAdvisorName(advisor.getName())
.withAdvisorType(AdvisorObservationContext.Type.AFTER)
.withAdvisorRequestContext(advisorContext)
.build();

return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
parentObservation.getObservationRegistry())
.parentObservation(parentObservation)
.observe(() -> advisor.adviseResponse(response, advisorContext));
}

public static Flux<ChatResponse> adviseResponse(Observation parentObservation, RequestResponseAdvisor advisor,
Flux<ChatResponse> fluxResponse, Map<String, Object> advisorContext) {

if (advisor.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
return fluxResponse
.map(chatResponse -> adviseResponse(parentObservation, advisor, chatResponse, advisorContext));
}
else if (advisor.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
adviseResponse(parentObservation, advisor, chatResponse, advisorContext);
});
}
else if (advisor.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
return fluxResponse.map(chatResponse -> {
boolean withFinishReason = chatResponse.getResults()
.stream()
.filter(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
.findFirst()
.isPresent();

if (withFinishReason) {
return adviseResponse(parentObservation, advisor, chatResponse, advisorContext);
}
return chatResponse;
});
}

return advisor.adviseResponse(fluxResponse, advisorContext);
}

}
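Finally, a small sketch (not part of this change) of the aggregation contract that both AGGREGATE branches rely on: MessageAggregator.aggregate passes the original chunks through to downstream subscribers and invokes the callback once, with the aggregated ChatResponse, when the stream completes. The ChatModel#stream source here is only illustrative.

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.prompt.Prompt;

import reactor.core.publisher.Flux;

class AggregationContractExample {

	Flux<ChatResponse> observeAggregate(ChatModel chatModel, Prompt prompt) {
		Flux<ChatResponse> chunks = chatModel.stream(prompt); // illustrative streaming source

		// Chunks are forwarded downstream unchanged; the callback sees a single aggregated response.
		return new MessageAggregator().aggregate(chunks,
				aggregated -> System.out.println("aggregated: " + aggregated.getResult().getOutput()));
	}

}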