Skip to content

Commit

Permalink
Merge branch 'main' into aws-bedrock-converse
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjiang153 committed Jun 21, 2024
2 parents 0fffb5a + 067a33d commit 0527909
Show file tree
Hide file tree
Showing 298 changed files with 10,817 additions and 2,553 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
root = true

[*.{adoc,bat,groovy,html,java,js,jsp,kt,kts,md,properties,py,rb,sh,sql,svg,txt,xml,xsd}]
charset = utf-8

[*.{groovy,java,kt,kts,xml,xsd}]
indent_style = tab
indent_size = 4
continuation_indent_size = 8
end_of_line = lf
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Spring AI supports many AI models. For an overview see here. Specific models c
* OpenAI
* Azure OpenAI
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral)
* HuggingFace
* Hugging Face
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
* Stability AI
Expand Down Expand Up @@ -163,7 +163,7 @@ Though the `DocumentWriter` interface isn't exclusively for Vector Database writ

**Vector Stores:** Vector Databases are instrumental in incorporating your data with AI models.
They ascertain which document sections the AI should use for generating responses.
Examples of Vector Databases include Chroma, Postgres, Pinecone, Qdrant, Weaviate, Mongo Atlas, and Redis. Spring AI's `VectorStore` abstraction permits effortless transitions between database implementations.
Examples of Vector Databases include Chroma, Oracle, Postgres, Pinecone, Qdrant, Weaviate, Mongo Atlas, and Redis. Spring AI's `VectorStore` abstraction permits effortless transitions between database implementations.



Expand Down
2 changes: 1 addition & 1 deletion document-readers/tika-reader/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
</scm>

<properties>
<tika.version>2.9.0</tika.version>
<tika.version>2.9.2</tika.version>
</properties>

<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
Expand All @@ -59,11 +58,12 @@
* The {@link ChatModel} implementation for the Anthropic service.
*
* @author Christian Tzolov
* @author luocongqiu
* @since 1.0.0
*/
public class AnthropicChatModel extends
AbstractFunctionCallSupport<AnthropicApi.RequestMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletion>>
implements ChatModel, StreamingChatModel {
implements ChatModel {

private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);

Expand All @@ -81,7 +81,7 @@ public class AnthropicChatModel extends
/**
* The default options used for the chat completion requests.
*/
private AnthropicChatOptions defaultOptions;
private final AnthropicChatOptions defaultOptions;

/**
* The retry template used to retry the OpenAI API calls.
Expand Down Expand Up @@ -280,20 +280,14 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, AnthropicChatOptions.class);
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, AnthropicChatOptions.class);

Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}

if (this.defaultOptions != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
Expand Down Expand Up @@ -90,6 +91,11 @@ public Builder withModel(String model) {
return this;
}

public Builder withModel(AnthropicApi.ChatModel model) {
this.options.model = model.getValue();
return this;
}

public Builder withMaxTokens(Integer maxTokens) {
this.options.maxTokens = maxTokens;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,11 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.")
.withDescription(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.build()))
.build();

Expand All @@ -214,9 +215,7 @@ void functionCallTest() {
logger.info("Response: {}", response);

Generation generation = response.getResult();
assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30");
assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10");
assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15");
assertThat(generation.getOutput().getContent()).contains("30", "10", "15");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsToolCall;
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
import com.azure.ai.openai.models.ChatMessageContentItem;
import com.azure.ai.openai.models.ChatMessageImageContentItem;
import com.azure.ai.openai.models.ChatMessageImageUrl;
import com.azure.ai.openai.models.ChatMessageTextContentItem;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
Expand All @@ -32,23 +39,18 @@
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
Expand All @@ -58,6 +60,7 @@
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
Expand All @@ -74,12 +77,13 @@
* @author John Blum
* @author Christian Tzolov
* @author Grogdunn
* @author Benoit Moussaud
* @author luocongqiu
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
*/
public class AzureOpenAiChatModel
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
implements ChatModel, StreamingChatModel {
public class AzureOpenAiChatModel extends
AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions> implements ChatModel {

private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";

Expand All @@ -88,14 +92,14 @@ public class AzureOpenAiChatModel
private final Logger logger = LoggerFactory.getLogger(getClass());

/**
* The configuration information for a chat completions request.
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
*/
private AzureOpenAiChatOptions defaultOptions;
private final OpenAIClient openAIClient;

/**
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
* The configuration information for a chat completions request.
*/
private final OpenAIClient openAIClient;
private AzureOpenAiChatOptions defaultOptions;

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
this(microsoftOpenAiClient,
Expand Down Expand Up @@ -143,8 +147,7 @@ public ChatResponse call(Prompt prompt) {
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = chatCompletions.getChoices()
.stream()
List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();
Expand Down Expand Up @@ -229,24 +232,17 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, AzureOpenAiChatOptions.class);
// JSON merge doesn't work due to an Azure OpenAI service bug:
// https://github.com/Azure/azure-sdk-for-java/issues/38183
// options = ModelOptionsUtils.merge(runtimeOptions, options,
// ChatCompletionsOptions.class);
options = merge(updatedRuntimeOptions, options);

Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, AzureOpenAiChatOptions.class);
// JSON merge doesn't work due to an Azure OpenAI service bug:
// https://github.com/Azure/azure-sdk-for-java/issues/38183
// options = ModelOptionsUtils.merge(runtimeOptions, options,
// ChatCompletionsOptions.class);
options = merge(updatedRuntimeOptions, options);

}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
+ prompt.getOptions().getClass().getSimpleName());
}
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
}

// Add the enabled functions definitions to the request's tools parameter.
Expand Down Expand Up @@ -278,7 +274,17 @@ private ChatRequestMessage fromSpringAiMessage(Message message) {

switch (message.getMessageType()) {
case USER:
return new ChatRequestUserMessage(message.getContent());
// https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/openai/azure-ai-openai/README.md#text-completions-with-images
List<ChatMessageContentItem> items = new ArrayList<>();
items.add(new ChatMessageTextContentItem(message.getContent()));
if (!CollectionUtils.isEmpty(message.getMedia())) {
items.addAll(message.getMedia()
.stream()
.map(media -> new ChatMessageImageContentItem(
new ChatMessageImageUrl(media.getData().toString())))
.toList());
}
return new ChatRequestUserMessage(items);
case SYSTEM:
return new ChatRequestSystemMessage(message.getContent());
case ASSISTANT:
Expand Down
Loading

0 comments on commit 0527909

Please sign in to comment.