Add advisor strategy for ON_FINISH_REASON streaming response. Used by the Q&A advisor
tzolov committed Sep 2, 2024
1 parent 131cbd4 commit ac82c5e
Showing 4 changed files with 61 additions and 22 deletions.
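The practical effect for streaming: with the Q&A advisor switched to ON_FINISH_REASON (see the QuestionAnswerAdvisor change below), its sync adviseResponse now runs on the chunk that carries a finish reason instead of on a separately aggregated response. A minimal usage sketch, assuming the ChatClient fluent API and the QuestionAnswerAdvisor(VectorStore, SearchRequest) constructor as they exist at this point in the tree; the class and method names in the sketch itself are illustrative:

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import reactor.core.publisher.Flux;

// Sketch only: streams a RAG answer through the Q&A advisor. After this commit
// the advisor's sync adviseResponse is applied to the finish-reason chunk
// (typically the last one); all other chunks pass through unchanged.
public class StreamingQaSketch {

    public static Flux<String> streamAnswer(ChatModel chatModel, VectorStore vectorStore, String question) {
        return ChatClient.builder(chatModel)
                .build()
                .prompt()
                .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()))
                .user(question)
                .stream()
                .content();
    }
}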
@@ -467,8 +467,8 @@ private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec in
rwc.advisorContext);
}))
.single()
- .flatMapMany(r -> {
- DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(r.request,
+ .flatMapMany(rwc -> {
+ DefaultChatClientRequestSpec advisedRequest = toDefaultChatClientRequestSpec(rwc.request,
inputRequest.getObservationRegistry(), inputRequest.getCustomObservationConvention());

var prompt = toPrompt(advisedRequest, null);
@@ -483,8 +483,8 @@ private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec in
advisedResponse = advisor.adviseResponse(advisedResponse, advisorContext);
}
}

return advisedResponse;

});
}

@@ -18,12 +18,13 @@

import java.util.Map;

+ import reactor.core.publisher.Flux;
+
+ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.MessageAggregator;
- import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
+ import org.springframework.util.StringUtils;

- import reactor.core.publisher.Flux;
-
/**
* Advisor called before and after the {@link ChatModel#call(Prompt)} and
@@ -37,7 +38,26 @@ public interface RequestResponseAdvisor {

public enum StreamResponseMode {

- CHUNK, AGGREGATE, CUSTOM;
+ /**
+ * The sync advisor will be called for each response chunk (i.e. on each Flux
+ * item).
+ */
+ PER_CHUNK,
+ /**
+ * The sync advisor is called only on chunks that contain a finish reason,
+ * usually the last chunk in the stream.
+ */
+ ON_FINISH_REASON,
+ /**
+ * The sync advisor is called only once, after the stream completes and an
+ * aggregated response has been computed. At this stage the advisor cannot
+ * modify the response; it can only observe and react to the aggregated response.
+ */
+ AGGREGATE,
+ /**
+ * Delegates to the stream advisor implementation.
+ */
+ CUSTOM;

}
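A custom advisor opts into the new constant by overriding getStreamResponseMode(), the same way the Q&A advisor does later in this commit. A minimal sketch, assuming the remaining RequestResponseAdvisor methods keep the defaults shown in this file; the class name and the logging are hypothetical:

import java.util.Map;

import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;

// Hypothetical advisor that only wants to see the end of a streamed response.
public class FinishReasonLoggingAdvisor implements RequestResponseAdvisor {

    @Override
    public StreamResponseMode getStreamResponseMode() {
        // The default Flux adviseResponse routes only finish-reason chunks here.
        return StreamResponseMode.ON_FINISH_REASON;
    }

    @Override
    public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
        // Observe (or enrich) the final chunk; intermediate chunks are not seen.
        System.out.println("Stream finished: " + response.getResults());
        return response;
    }
}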

@@ -85,14 +105,29 @@ default ChatResponse adviseResponse(ChatResponse response, Map<String, Object> c
*/
default Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {

- if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) {
+ if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, context));
}
else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
this.adviseResponse(chatResponse, context);
});
}
+ else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
+ return fluxResponse.map(chatResponse -> {
+ boolean withFinishReason = chatResponse.getResults()
+ .stream()
+ .filter(result -> result != null && result.getMetadata() != null
+ && StringUtils.hasText(result.getMetadata().getFinishReason()))
+ .findFirst()
+ .isPresent();
+
+ if (withFinishReason) {
+ return this.adviseResponse(chatResponse, context);
+ }
+ return chatResponse;
+ });
+ }

return fluxResponse;
}
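For reference, the finish-reason test added above (and repeated in the wrapper advisor below) reduces to the predicate sketched here; anyMatch is an equivalent, slightly terser form of the filter().findFirst().isPresent() chain used in the diff. The holder class is purely illustrative:

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.util.StringUtils;

// Illustrative extraction of the ON_FINISH_REASON check: a chunk "closes" the
// stream when any of its generations carries a non-empty finish reason.
final class FinishReasonPredicate {

    private FinishReasonPredicate() {
    }

    static boolean hasFinishReason(ChatResponse chatResponse) {
        return chatResponse.getResults()
                .stream()
                .anyMatch(result -> result != null && result.getMetadata() != null
                        && StringUtils.hasText(result.getMetadata().getFinishReason()));
    }
}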
@@ -24,7 +24,6 @@
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
- import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
import org.springframework.ai.vectorstore.SearchRequest;
@@ -34,8 +33,6 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

- import reactor.core.publisher.Flux;
-
/**
* Context for the question is retrieved from a Vector Store and added to the prompt's
* user text.
@@ -133,14 +130,6 @@ public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> co
return chatResponseBuilder.build();
}

- @Override
- public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxChatResponse, Map<String, Object> context) {
- // return fluxResponse.map(cr -> adviseResponse(cr, context));
- return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> {
- this.adviseResponse(chatResponse, context);
- });
- }
-
protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

if (!context.containsKey(FILTER_EXPRESSION)
@@ -153,8 +142,7 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

@Override
public StreamResponseMode getStreamResponseMode() {
- // return StreamResponseMode.CHUNK;
- return StreamResponseMode.AGGREGATE;
+ return StreamResponseMode.ON_FINISH_REASON;
}

}
@@ -23,6 +23,7 @@
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
+ import org.springframework.util.StringUtils;

import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Flux;
@@ -93,14 +94,29 @@ public StreamResponseMode getStreamResponseMode() {
@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> advisorContext) {

- if (this.getStreamResponseMode() == StreamResponseMode.CHUNK) {
+ if (this.getStreamResponseMode() == StreamResponseMode.PER_CHUNK) {
return fluxResponse.map(chatResponse -> this.adviseResponse(chatResponse, advisorContext));
}
else if (this.getStreamResponseMode() == StreamResponseMode.AGGREGATE) {
return new MessageAggregator().aggregate(fluxResponse, chatResponse -> {
this.adviseResponse(chatResponse, advisorContext);
});
}
+ else if (this.getStreamResponseMode() == StreamResponseMode.ON_FINISH_REASON) {
+ return fluxResponse.map(chatResponse -> {
+ boolean withFinishReason = chatResponse.getResults()
+ .stream()
+ .filter(result -> result != null && result.getMetadata() != null
+ && StringUtils.hasText(result.getMetadata().getFinishReason()))
+ .findFirst()
+ .isPresent();
+
+ if (withFinishReason) {
+ return this.adviseResponse(chatResponse, advisorContext);
+ }
+ return chatResponse;
+ });
+ }

return this.targetAdvisor.adviseResponse(fluxResponse, advisorContext);
}
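When none of the built-in modes match, the wrapper falls through to the target advisor's own Flux override, which is what the CUSTOM mode is for. A hypothetical CUSTOM-mode advisor could look like the sketch below; the class name and the chunk-counting behavior are purely illustrative:

import java.util.Map;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;

// Hypothetical CUSTOM-mode advisor: it takes full control of the streaming path
// instead of relying on the default per-chunk / aggregate / finish-reason dispatch.
public class ChunkCountingAdvisor implements RequestResponseAdvisor {

    @Override
    public StreamResponseMode getStreamResponseMode() {
        return StreamResponseMode.CUSTOM;
    }

    @Override
    public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
        // Number each chunk as it flows through, then forward it unchanged.
        return fluxResponse.index()
                .doOnNext(indexed -> System.out.println("chunk #" + indexed.getT1()))
                .map(indexed -> indexed.getT2());
    }
}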