
Commit

Merge remote-tracking branch 'upstream/main' into tc-modules-autoconfiguration
eddumelendez committed May 19, 2024
2 parents 9abb3d2 + 09e122d commit d326945
Showing 182 changed files with 2,444 additions and 1,180 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -8,6 +8,20 @@ Let's make your `@Beans` intelligent!

For further information go to our [Spring AI reference documentation](https://docs.spring.io/spring-ai/reference/).

### Breaking changes
(15.05.2024)
On our march to the 1.0 M1 release we have made several breaking changes. Apologies, it is for the best!

Renamed POM artifacts:
- spring-ai-qdrant -> spring-ai-qdrant-store
- spring-ai-cassandra -> spring-ai-cassandra-store
- spring-ai-pinecone -> spring-ai-pinecone-store
- spring-ai-redis -> spring-ai-redis-store
- spring-ai-gemfire -> spring-ai-gemfire-store
- spring-ai-azure-vector-store-spring-boot-starter -> spring-ai-azure-store-spring-boot-starter
- spring-ai-redis-spring-boot-starter -> spring-ai-redis-store-spring-boot-starter

## Project Links

* [Documentation](https://docs.spring.io/spring-ai/reference/)
@@ -450,4 +450,11 @@ protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> response) {
return response.getBody().content().stream().anyMatch(content -> content.type() == MediaContent.Type.TOOL_USE);
}

@Override
protected Flux<ResponseEntity<ChatCompletion>> doChatCompletionStream(ChatCompletionRequest request) {
// https://docs.anthropic.com/en/docs/tool-use
throw new UnsupportedOperationException(
"Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version.");
}

}
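Until streaming support lands in the Anthropic client, a caller that needs a reactive signature can fall back to wrapping the blocking call. A minimal sketch, not part of the commit; `chatClient` and `messages` are assumed to be set up as in the tests further down this page:

```java
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

// Caller-side fallback while stream=true is unsupported: wrap the blocking
// call() in a single-element Flux and move it off the calling thread.
// 'chatClient' and 'messages' are assumed to exist in the surrounding code.
Flux<ChatResponse> pseudoStream = Flux
        .defer(() -> Flux.just(chatClient.call(new Prompt(messages))))
        .subscribeOn(Schedulers.boundedElastic());
```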
@@ -29,6 +29,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
@@ -100,7 +101,13 @@ public AnthropicApi(String baseUrl, String anthropicApiKey, String anthropicVersion
.defaultStatusHandler(responseErrorHandler)
.build();

this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
this.webClient = WebClient.builder()
.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
.defaultStatusHandler(HttpStatusCode::isError,
resp -> Mono.just(new RuntimeException("Response exception, Status: [" + resp.statusCode()
+ "], Body:[" + resp.bodyToMono(java.lang.String.class) + "]")))
.build();
}
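One thing worth noting about the handler above: `resp.bodyToMono(String.class)` returns a `Mono`, so concatenating it directly puts the `Mono`'s `toString()` into the message rather than the response body. A minimal sketch of resolving the body first with the same Spring `WebClient` builder API; `baseUrl` and `jsonContentHeaders` are assumed to be in scope as in the constructor above:

```java
import org.springframework.http.HttpStatusCode;
import org.springframework.web.reactive.function.client.WebClient;

// Sketch only: resolve the error body before building the exception so the
// message contains the actual payload.
WebClient webClient = WebClient.builder()
        .baseUrl(baseUrl)
        .defaultHeaders(jsonContentHeaders)
        .defaultStatusHandler(HttpStatusCode::isError,
                resp -> resp.bodyToMono(String.class)
                        .defaultIfEmpty("")
                        .map(body -> new RuntimeException("Response exception, Status: ["
                                + resp.statusCode() + "], Body:[" + body + "]")))
        .build();
```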

/**
@@ -182,12 +182,12 @@ void beanStreamOutputConverterRecords() {
@Test
void multiModalityTest() throws IOException {

byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
var imageData = new ClassPathResource("/test.png");

var userMessage = new UserMessage("Explain what do you see on this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));

ChatResponse response = chatClient.call(new Prompt(List.of(userMessage)));
var response = chatClient.call(new Prompt(List.of(userMessage)));

logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
@@ -205,15 +205,15 @@ void functionCallTest() {
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the weather in location")
.withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.")
.build()))
.build();

ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);

Generation generation = response.getResults().get(0);
Generation generation = response.getResult();
assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30");
assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10");
assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15");
@@ -15,11 +15,6 @@
*/
package org.springframework.ai.azure.openai;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
@@ -33,15 +28,14 @@
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
@@ -59,6 +53,14 @@
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
@@ -68,6 +70,7 @@
* @author Ueibin Kim
* @author John Blum
* @author Christian Tzolov
* @author Grogdunn
* @see ChatClient
* @see com.azure.ai.openai.OpenAIClient
*/
@@ -158,17 +161,42 @@ public Flux<ChatResponse> stream(Prompt prompt) {
IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
.getChatCompletionsStream(options.getModel(), options);

return Flux.fromStream(chatCompletionsStream.stream()
Flux<ChatCompletions> chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream);

final var isFunctionCall = new AtomicBoolean(false);
final var accessibleChatCompletionsFlux = chatCompletionsFlux
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(ChatCompletions::getChoices)
.flatMap(List::stream)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
return chatCompletions;
})
.windowUntil(chatCompletions -> {
if (isFunctionCall.get() && chatCompletions.getChoices()
.get(0)
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
isFunctionCall.set(false);
return true;
}
return false;
}, false)
.concatMapIterable(window -> {
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
return List.of(reduce);
})
.flatMap(mono -> mono);
return accessibleChatCompletionsFlux
.switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options,
Flux.just(accessibleChatCompletions)))
.flatMapIterable(ChatCompletions::getChoices)
.map(choice -> {
var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null;
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
}));
});

}
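The `windowUntil`/`reduce` combination above buffers the streamed chunks of a tool call and merges them into one completion before handing the window to the function-call handler. A self-contained toy sketch of that pattern with plain strings, independent of the Azure types and of `MergeUtils`:

```java
import java.util.concurrent.atomic.AtomicBoolean;

import reactor.core.publisher.Flux;

// Toy version of the window-and-merge pattern: chunks belonging to one logical
// unit are grouped by windowUntil and reduced into a single element.
AtomicBoolean inToolCall = new AtomicBoolean(false);
Flux<String> merged = Flux.just("tool:a", "tool:b", "tool:END", "plain")
        .map(chunk -> {
            inToolCall.set(chunk.startsWith("tool:"));
            return chunk;
        })
        // cutBefore=false keeps the matching element inside the window it closes
        .windowUntil(chunk -> inToolCall.get() && chunk.endsWith("END"), false)
        .concatMap(window -> window.reduce("", String::concat));

merged.subscribe(System.out::println); // prints "tool:atool:btool:END", then "plain"
```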

/**
@@ -522,9 +550,17 @@ protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions requ

@Override
protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
ChatResponseMessage responseMessage = response.getChoices().get(0).getMessage();
final var accessibleChatChoice = response.getChoices().get(0);
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
.orElse(accessibleChatChoice.getDelta());
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
assistantMessage.setToolCalls(responseMessage.getToolCalls());
final var toolCalls = responseMessage.getToolCalls();
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
final var tc1 = (ChatCompletionsFunctionToolCall) tc;
var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(),
new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments()));
return ((ChatCompletionsToolCall) toDowncast);
}).toList());
return assistantMessage;
}

@@ -533,6 +569,11 @@ protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
return this.openAIClient.getChatCompletions(request.getModel(), request);
}

@Override
protected Flux<ChatCompletions> doChatCompletionStream(ChatCompletionsOptions request) {
return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request));
}
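For reference, using `Flux.fromIterable` here (rather than `Flux.fromStream`, as in the older streaming code above) also changes re-subscription behaviour: a `Stream` is single-use, so a second subscription fails, while an `Iterable`-backed `Flux` can be replayed. A small sketch, unrelated to the Azure SDK itself:

```java
import java.util.List;

import reactor.core.publisher.Flux;

// An Iterable-backed Flux can be subscribed to repeatedly.
Flux<Integer> fromIterable = Flux.fromIterable(List.of(1, 2, 3));
fromIterable.subscribe(System.out::println); // 1 2 3
fromIterable.subscribe(System.out::println); // 1 2 3 again

// A Stream is consumed on the first subscription; the second one errors.
Flux<Integer> fromStream = Flux.fromStream(List.of(1, 2, 3).stream());
fromStream.subscribe(System.out::println); // 1 2 3
fromStream.subscribe(System.out::println,
        err -> System.out.println("second subscription fails: " + err));
```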

@Override
protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {

@@ -549,4 +590,4 @@ protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
}

}
}
