From 97f443d615644ba02c65e80579cd6df3a77669cb Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Mon, 15 Jul 2024 17:45:08 -0400 Subject: [PATCH] Update to ResponseMetadata design * Remove inheritance from HashMap * No more subclasses per model provider * Builder class for ChatResponse * Fix the AbstractResponseMetadata#AI_METADATA_STRING parameter order * ChatResponseMetadata ignore Null values. --- .../ai/anthropic/AnthropicChatModel.java | 35 +++-- .../AnthropicChatResponseMetadata.java | 103 ------------- .../ai/azure/openai/AzureOpenAiChatModel.java | 73 ++++++---- .../AzureOpenAiChatResponseMetadata.java | 83 ----------- .../AzureOpenAiImageResponseMetadata.java | 3 +- .../ai/mistralai/MistralAiChatModel.java | 33 +++-- .../MistralAiChatResponseMetadata.java | 62 -------- .../ai/ollama/OllamaChatModel.java | 33 +++-- .../metadata/OllamaChatResponseMetadata.java | 57 -------- .../ai/openai/ImageResponseMetadata.java | 2 +- .../ai/openai/OpenAiChatModel.java | 50 +++++-- .../ai/openai/OpenAiImageModel.java | 7 +- .../metadata/OpenAiChatResponseMetadata.java | 103 ------------- .../metadata/OpenAiImageResponseMetadata.java | 62 -------- .../OpenAiAudioSpeechResponseMetadata.java | 3 +- ...nAiAudioTranscriptionResponseMetadata.java | 6 +- .../postgresml/PostgresMlEmbeddingModel.java | 11 +- .../PostgresMlEmbeddingModelIT.java | 52 ++++++- .../ai/stabilityai/StabilityAiImageModel.java | 2 +- .../TransformersEmbeddingModelTests.java | 3 +- .../embedding/VertexAiEmbeddingUsage.java | 28 ++++ .../VertexAiMultimodalEmbeddingModel.java | 20 +-- .../text/VertexAiTextEmbeddingModel.java | 9 +- ...> VertexAiMultimodalEmbeddingModelIT.java} | 33 +++-- .../text/VertexAiTextEmbeddingModelIT.java | 8 +- .../gemini/VertexAiGeminiChatModel.java | 56 ++++---- .../VertexAiChatResponseMetadata.java | 40 ------ .../client/advisor/QuestionAnswerAdvisor.java | 11 +- .../chat/metadata/ChatResponseMetadata.java | 135 +++++++++++++++--- .../ai/chat/model/ChatResponse.java | 44 +++++- 
.../embedding/EmbeddingResponseMetadata.java | 23 ++- .../ai/image/ImageResponse.java | 2 +- .../ai/image/ImageResponseMetadata.java | 15 +- .../ai/model/AbstractResponseMetadata.java | 76 ++++++++++ .../ai/model/MutableResponseMetadata.java | 126 ++++++++++++++++ .../ai/model/ResponseMetadata.java | 73 +++++++++- .../function/AbstractFunctionCallSupport.java | 8 +- .../client/ChatClientResponseEntityTests.java | 6 +- ...TextEmbeddingModelAutoConfigurationIT.java | 4 +- 39 files changed, 773 insertions(+), 727 deletions(-) delete mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicChatResponseMetadata.java delete mode 100644 models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java delete mode 100644 models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiChatResponseMetadata.java delete mode 100644 models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatResponseMetadata.java delete mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java delete mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java create mode 100644 models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java rename models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/{VertexAiMultimodelEmbeddingModelIT.java => VertexAiMultimodalEmbeddingModelIT.java} (87%) delete mode 100644 models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiChatResponseMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/AbstractResponseMetadata.java create mode 100644 
spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 35227909bb..8b8317b7d2 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -15,14 +15,6 @@ */ package org.springframework.ai.anthropic; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -32,12 +24,13 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType; import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata; +import org.springframework.ai.anthropic.metadata.AnthropicUsage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -52,10 +45,17 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; - import 
reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + /** * The {@link ChatModel} implementation for the Anthropic service. * @@ -228,7 +228,20 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { .withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); }).toList(); - return new ChatResponse(generations, AnthropicChatResponseMetadata.from(chatCompletion)); + return new ChatResponse(generations, from(chatCompletion)); + } + + private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { + Assert.notNull(result, "Anthropic ChatCompletionResult must not be null"); + AnthropicUsage usage = AnthropicUsage.from(result.usage()); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("stop-reason", result.stopReason()) + .withKeyValue("stop-sequence", result.stopSequence()) + .withKeyValue("type", result.type()) + .build(); } private String fromMediaData(Object mediaData) { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicChatResponseMetadata.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicChatResponseMetadata.java deleted file mode 100644 index 0ce5d8b905..0000000000 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicChatResponseMetadata.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.anthropic.metadata; - -import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.EmptyRateLimit; -import org.springframework.ai.chat.metadata.EmptyUsage; -import org.springframework.ai.chat.metadata.RateLimit; -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; - -import java.util.HashMap; - -/** - * {@link ChatResponseMetadata} implementation for {@literal AnthropicApi}. 
- * - * @author Christian Tzolov - * @author Thomas Vitale - * @see ChatResponseMetadata - * @see RateLimit - * @see Usage - * @since 1.0.0 - */ -public class AnthropicChatResponseMetadata extends HashMap implements ChatResponseMetadata { - - protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }"; - - public static AnthropicChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { - Assert.notNull(result, "Anthropic ChatCompletionResult must not be null"); - AnthropicUsage usage = AnthropicUsage.from(result.usage()); - return new AnthropicChatResponseMetadata(result.id(), result.model(), usage); - } - - private final String id; - - private final String model; - - @Nullable - private RateLimit rateLimit; - - private final Usage usage; - - protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage) { - this(id, model, usage, null); - } - - protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage, - @Nullable AnthropicRateLimit rateLimit) { - this.id = id; - this.model = model; - this.usage = usage; - this.rateLimit = rateLimit; - } - - @Override - public String getId() { - return this.id; - } - - @Override - public String getModel() { - return this.model; - } - - @Override - @Nullable - public RateLimit getRateLimit() { - RateLimit rl = this.rateLimit; - return rl != null ? rl : new EmptyRateLimit(); - } - - @Override - public Usage getUsage() { - Usage usage = this.usage; - return usage != null ? 
usage : new EmptyUsage(); - } - - public AnthropicChatResponseMetadata withRateLimit(RateLimit rateLimit) { - this.rateLimit = rateLimit; - return this; - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit()); - } - -} diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1cd1afd67b..9c54d0ee8e 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -15,33 +15,6 @@ */ package org.springframework.ai.azure.openai; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; - -import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.AbstractToolCallSupport; -import 
org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; - import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; @@ -68,10 +41,36 @@ import com.azure.ai.openai.models.FunctionDefinition; import com.azure.core.util.BinaryData; import com.azure.core.util.IterableStream; - +import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractToolCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. 
@@ -151,8 +150,22 @@ public ChatResponse call(Prompt prompt) { PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); - return new ChatResponse(generations, - AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); + return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata)); + } + + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); + String id = chatCompletions.getId(); + AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); + ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder() + .withId(id) + .withUsage(usage) + .withModel(chatCompletions.getModel()) + .withPromptMetadata(promptFilterMetadata) + .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) + .build(); + + return chatResponseMetadata; } @Override diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java deleted file mode 100644 index a20c40ae00..0000000000 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.azure.openai.metadata; - -import com.azure.ai.openai.models.ChatCompletions; - -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata; -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.util.Assert; - -import java.util.HashMap; - -/** - * {@link ChatResponseMetadata} implementation for - * {@literal Microsoft Azure OpenAI Service}. - * - * @author John Blum - * @author Thomas Vitale - * @see ChatResponseMetadata - * @since 0.7.1 - */ -public class AzureOpenAiChatResponseMetadata extends HashMap implements ChatResponseMetadata { - - protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; - - @SuppressWarnings("all") - public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions, - PromptMetadata promptFilterMetadata) { - Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); - String id = chatCompletions.getId(); - AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); - AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage, - promptFilterMetadata); - return chatResponseMetadata; - } - - private final String id; - - private final Usage usage; - - private final PromptMetadata promptMetadata; - - protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) { - this.id = id; - this.usage = usage; - this.promptMetadata = promptMetadata; - } - - @Override - public String getId() { - return this.id; - } - - @Override - public Usage getUsage() { - return this.usage; - } - - @Override - public PromptMetadata getPromptMetadata() { - return this.promptMetadata; - } - - @Override - public String toString() { - return 
AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit()); - } - -} diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java index e821913f7c..6d01d5cbb8 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java @@ -2,6 +2,7 @@ import com.azure.ai.openai.models.ImageGenerations; import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.util.Assert; import java.util.HashMap; @@ -15,7 +16,7 @@ * @author Benoit Moussaud * @since 1.0.0 M1 */ -public class AzureOpenAiImageResponseMetadata extends HashMap implements ImageResponseMetadata { +public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata { private final Long created; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index aef847a2df..ada6507351 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -15,14 +15,6 @@ */ package org.springframework.ai.mistralai; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import 
org.springframework.ai.chat.messages.AssistantMessage; @@ -31,6 +23,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -44,7 +37,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; -import org.springframework.ai.mistralai.metadata.MistralAiChatResponseMetadata; +import org.springframework.ai.mistralai.metadata.MistralAiUsage; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractToolCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -53,10 +46,17 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + /** * @author Ricken Bazolo * @author Christian Tzolov @@ -134,10 +134,21 @@ public ChatResponse call(Prompt prompt) { .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null))) .toList(); - return new ChatResponse(generations, MistralAiChatResponseMetadata.from(chatCompletion)); + return new ChatResponse(generations, from(chatCompletion)); }); } + public static ChatResponseMetadata 
from(MistralAiApi.ChatCompletion result) { + Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); + MistralAiUsage usage = MistralAiUsage.from(result.usage()); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("created", result.created()) + .build(); + } + private Map toMap(String id, ChatCompletion.Choice choice) { Map map = new HashMap<>(); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiChatResponseMetadata.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiChatResponseMetadata.java deleted file mode 100644 index 2ca8bfbf6b..0000000000 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiChatResponseMetadata.java +++ /dev/null @@ -1,62 +0,0 @@ -package org.springframework.ai.mistralai.metadata; - -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.EmptyUsage; -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.util.Assert; - -import java.util.HashMap; - -/** - * {@link ChatResponseMetadata} implementation for {@literal Mistral AI}. 
- * - * @author Thomas Vitale - * @see ChatResponseMetadata - * @see Usage - * @since 1.0.0 - */ -public class MistralAiChatResponseMetadata extends HashMap implements ChatResponseMetadata { - - protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s }"; - - public static MistralAiChatResponseMetadata from(MistralAiApi.ChatCompletion result) { - Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); - MistralAiUsage usage = MistralAiUsage.from(result.usage()); - return new MistralAiChatResponseMetadata(result.id(), result.model(), usage); - } - - private final String id; - - private final String model; - - private final Usage usage; - - protected MistralAiChatResponseMetadata(String id, String model, MistralAiUsage usage) { - this.id = id; - this.model = model; - this.usage = usage; - } - - @Override - public String getId() { - return this.id; - } - - @Override - public String getModel() { - return this.model; - } - - @Override - public Usage getUsage() { - Usage usage = this.usage; - return usage != null ? 
usage : new EmptyUsage(); - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage()); - } - -} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 39405c2d62..4e5efdfce6 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -18,25 +18,26 @@ import java.util.Base64; import java.util.List; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.metadata.OllamaUsage; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; + /** * {@link ChatModel} implementation for {@literal Ollama}. 
* @@ -102,7 +103,23 @@ public ChatResponse call(Prompt prompt) { if (response.promptEvalCount() != null && response.evalCount() != null) { generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null)); } - return new ChatResponse(List.of(generator), OllamaChatResponseMetadata.from(response)); + return new ChatResponse(List.of(generator), from(response)); + } + + public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { + Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); + return ChatResponseMetadata.builder() + .withUsage(OllamaUsage.from(response)) + .withModel(response.model()) + .withKeyValue("created-at", response.createdAt()) + .withKeyValue("eval-duration", response.evalDuration()) + .withKeyValue("eval-count", response.evalCount()) + .withKeyValue("load-duration", response.loadDuration()) + .withKeyValue("prompt-eval-duration", response.promptEvalDuration()) + .withKeyValue("prompt-eval-count", response.promptEvalCount()) + .withKeyValue("total-duration", response.totalDuration()) + .withKeyValue("done", response.done()) + .build(); } @Override @@ -116,7 +133,7 @@ public Flux stream(Prompt prompt) { if (Boolean.TRUE.equals(chunk.done())) { generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null)); } - return new ChatResponse(List.of(generation), OllamaChatResponseMetadata.from(chunk)); + return new ChatResponse(List.of(generation), from(chunk)); }); } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatResponseMetadata.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatResponseMetadata.java deleted file mode 100644 index 6f1d213fae..0000000000 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatResponseMetadata.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.ollama.metadata; - -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.util.Assert; - -import java.util.HashMap; - -/** - * {@link ChatResponseMetadata} implementation for {@literal Ollama} - * - * @see ChatResponseMetadata - * @author Fu Cheng - */ -public class OllamaChatResponseMetadata extends HashMap implements ChatResponseMetadata { - - protected static final String AI_METADATA_STRING = "{ @type: %1$s, usage: %2$s, rateLimit: %3$s }"; - - public static OllamaChatResponseMetadata from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - Usage usage = OllamaUsage.from(response); - return new OllamaChatResponseMetadata(usage); - } - - private final Usage usage; - - protected OllamaChatResponseMetadata(Usage usage) { - this.usage = usage; - } - - @Override - public Usage getUsage() { - return this.usage; - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getClass().getTypeName(), getUsage(), getRateLimit()); - } - -} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java index 
f339b4a0bb..3ec1ad510c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java @@ -1,5 +1,5 @@ package org.springframework.ai.openai; -public class ImageResponseMetadata { +public interface ImageResponseMetadata { } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index fca2de7144..9d313cc2ea 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -15,14 +15,6 @@ */ package org.springframework.ai.openai; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; @@ -30,6 +22,7 @@ import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -49,7 +42,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; -import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; +import 
org.springframework.ai.openai.metadata.OpenAiUsage; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -57,10 +50,17 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} * backed by {@link OpenAiApi}. @@ -165,7 +165,7 @@ public ChatResponse call(Prompt prompt) { } // Non function calling. - RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); + RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); List choices = chatCompletion.choices(); if (choices == null) { @@ -186,11 +186,22 @@ public ChatResponse call(Prompt prompt) { }).toList(); - return new ChatResponse(generations, - OpenAiChatResponseMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits)); + return new ChatResponse(generations, from(completionEntity.getBody(), rateLimit)); }); } + public static ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) { + Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withUsage(OpenAiUsage.from(result.usage())) + .withModel(result.model()) + .withRateLimit(rateLimit) + .withKeyValue("created", result.created()) + .withKeyValue("system-fingerprint", result.systemFingerprint()) + .build(); + } + @Override public Flux stream(Prompt prompt) { @@ -237,7 +248,7 @@ public Flux stream(Prompt prompt) { 
}).toList(); if (chatCompletion2.usage() != null) { - return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion2)); + return new ChatResponse(generations, from(chatCompletion2)); } else { return new ChatResponse(generations); @@ -253,6 +264,17 @@ public Flux stream(Prompt prompt) { }); } + private ChatResponseMetadata from(OpenAiApi.ChatCompletion result) { + Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withUsage(OpenAiUsage.from(result.usage())) + .withModel(result.model()) + .withKeyValue("created", result.created()) + .withKeyValue("system-fingerprint", result.systemFingerprint()) + .build(); + } + private List handleToolCallRequests(List previousMessages, ChatCompletion chatCompletion) { ChatCompletionMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index efc4b24ec9..7c4267fbef 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -15,8 +15,6 @@ */ package org.springframework.ai.openai; -import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.image.Image; @@ -29,12 +27,13 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; -import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; 
+import java.util.List; + /** * OpenAiImageModel is a class that implements the ImageModel interface. It provides a * client for calling the OpenAI image generation API. @@ -130,7 +129,7 @@ private ImageResponse convertResponse(ResponseEntity implements ChatResponseMetadata { - - protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }"; - - public static OpenAiChatResponseMetadata from(OpenAiApi.ChatCompletion result) { - Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); - OpenAiUsage usage = OpenAiUsage.from(result.usage()); - return new OpenAiChatResponseMetadata(result.id(), result.model(), usage); - } - - private final String id; - - private final String model; - - @Nullable - private RateLimit rateLimit; - - private final Usage usage; - - protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage) { - this(id, model, usage, null); - } - - protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage, - @Nullable OpenAiRateLimit rateLimit) { - this.id = id; - this.model = model; - this.usage = usage; - this.rateLimit = rateLimit; - } - - @Override - public String getId() { - return this.id; - } - - @Override - public String getModel() { - return this.model; - } - - @Override - @Nullable - public RateLimit getRateLimit() { - RateLimit rateLimit = this.rateLimit; - return rateLimit != null ? rateLimit : new EmptyRateLimit(); - } - - @Override - public Usage getUsage() { - Usage usage = this.usage; - return usage != null ? 
usage : new EmptyUsage(); - } - - public OpenAiChatResponseMetadata withRateLimit(RateLimit rateLimit) { - this.rateLimit = rateLimit; - return this; - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit()); - } - -} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java deleted file mode 100644 index ec9519c827..0000000000 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.openai.metadata; - -import org.springframework.ai.image.ImageResponseMetadata; -import org.springframework.ai.openai.api.OpenAiImageApi; -import org.springframework.util.Assert; - -import java.util.HashMap; -import java.util.Objects; - -public class OpenAiImageResponseMetadata extends HashMap implements ImageResponseMetadata { - - private final Long created; - - public static OpenAiImageResponseMetadata from(OpenAiImageApi.OpenAiImageResponse openAiImageResponse) { - Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null"); - return new OpenAiImageResponseMetadata(openAiImageResponse.created()); - } - - protected OpenAiImageResponseMetadata(Long created) { - this.created = created; - } - - @Override - public Long getCreated() { - return this.created; - } - - @Override - public String toString() { - return "OpenAiImageResponseMetadata{" + "created=" + created + '}'; - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof OpenAiImageResponseMetadata that)) - return false; - return Objects.equals(created, that.created); - } - - @Override - public int hashCode() { - return Objects.hash(created); - } - -} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java index 4f38f3c0d4..efcb6ebca7 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java @@ -18,6 +18,7 @@ import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.model.MutableResponseMetadata; import 
org.springframework.ai.model.ResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.lang.Nullable; @@ -31,7 +32,7 @@ * @author Ahmed Yousri * @see RateLimit */ -public class OpenAiAudioSpeechResponseMetadata extends HashMap implements ResponseMetadata { +public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }"; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java index 5add8aa5b8..f5fa913909 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java @@ -17,14 +17,12 @@ import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; -import org.springframework.ai.model.ResponseMetadata; +import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.metadata.OpenAiRateLimit; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import java.util.HashMap; - /** * Audio transcription metadata implementation for {@literal OpenAI}. 
* @@ -32,7 +30,7 @@ * @since 0.8.1 * @see RateLimit */ -public class OpenAiAudioTranscriptionResponseMetadata extends HashMap implements ResponseMetadata { +public class OpenAiAudioTranscriptionResponseMetadata extends MutableResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java index 2783f6df38..3ba6ab13a7 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; +import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -200,11 +201,11 @@ public EmbeddingResponse call(EmbeddingRequest request) { } } - var metadata = new EmbeddingResponseMetadata( - Map.of("transformer", optionsToUse.getTransformer(), "vector-type", optionsToUse.getVectorType().name(), - "kwargs", ModelOptionsUtils.toJsonString(optionsToUse.getKwargs()))); - - return new EmbeddingResponse(data, metadata); + Map embeddingMetadata = Map.of("transformer", optionsToUse.getTransformer(), "vector-type", + optionsToUse.getVectorType().name(), "kwargs", + ModelOptionsUtils.toJsonString(optionsToUse.getKwargs())); + var embeddingResponseMetadata = new EmbeddingResponseMetadata("unknown", new EmptyUsage(), embeddingMetadata); + return new EmbeddingResponse(data, embeddingResponseMetadata); } /** diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java 
b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java index 55c05d06d9..323554fe9f 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java @@ -30,6 +30,7 @@ import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType; import org.testcontainers.containers.PostgreSQLContainer; @@ -144,8 +145,21 @@ void embedForResponse(String vectorType) { assertThat(embeddingResponse).isNotNull(); assertThat(embeddingResponse.getResults()).hasSize(3); - assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf( - Map.of("transformer", "distilbert-base-uncased", "vector-type", vectorType, "kwargs", "{}")); + + EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata(); + assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys") + .containsExactlyInAnyOrder("transformer", "vector-type", "kwargs"); + + assertThat(metadata.get("transformer").toString()) + .as("Transformer in metadata should be 'distilbert-base-uncased'") + .isEqualTo("distilbert-base-uncased"); + + assertThat(metadata.get("vector-type").toString()) + .as("Vector type in metadata should match expected vector type") + .isEqualTo(vectorType); + + assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}"); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); @@ 
-170,8 +184,22 @@ void embedCallWithRequestOptionsOverride() { assertThat(embeddingResponse).isNotNull(); assertThat(embeddingResponse.getResults()).hasSize(3); - assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer", - "distilbert-base-uncased", "vector-type", VectorType.PG_VECTOR.name(), "kwargs", "{}")); + + EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata(); + + assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys") + .containsExactlyInAnyOrder("transformer", "vector-type", "kwargs"); + + assertThat(metadata.get("transformer").toString()) + .as("Transformer in metadata should be 'distilbert-base-uncased'") + .isEqualTo("distilbert-base-uncased"); + + assertThat(metadata.get("vector-type").toString()) + .as("Vector type in metadata should match expected vector type") + .isEqualTo(VectorType.PG_VECTOR.name()); + + assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}"); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); @@ -192,8 +220,20 @@ void embedCallWithRequestOptionsOverride() { assertThat(embeddingResponse).isNotNull(); assertThat(embeddingResponse.getResults()).hasSize(3); - assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer", - "intfloat/e5-small", "vector-type", VectorType.PG_ARRAY.name(), "kwargs", "{\"device\":\"cpu\"}")); + + metadata = embeddingResponse.getMetadata(); + + assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys") + .containsExactlyInAnyOrder("transformer", "vector-type", "kwargs"); + + assertThat(metadata.get("transformer").toString()).as("Transformer in metadata should be 'intfloat/e5-small'") + .isEqualTo("intfloat/e5-small"); + + 
assertThat(metadata.get("vector-type").toString()).as("Vector type in metadata should be PG_ARRAY") + .isEqualTo(VectorType.PG_ARRAY.name()); + + assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{\"device\":\"cpu\"}'") + .isEqualTo("{\"device\":\"cpu\"}"); assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(384); diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index abb52c9a9d..bf6f41f1a4 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -119,7 +119,7 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed())); }).toList(); - return new ImageResponse(imageGenerationList, ImageResponseMetadata.NULL); + return new ImageResponse(imageGenerationList, new ImageResponseMetadata()); } private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) { diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java index 40f963b5dd..57ee908b3b 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java @@ -24,6 +24,7 @@ import org.springframework.ai.embedding.EmbeddingResponse; import static 
org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Christian Tzolov @@ -76,7 +77,7 @@ void embedForResponse() throws Exception { embeddingModel.afterPropertiesSet(); EmbeddingResponse embed = embeddingModel.embedForResponse(List.of("Hello world", "World is big")); assertThat(embed.getResults()).hasSize(2); - assertThat(embed.getMetadata()).isEmpty(); + assertTrue(embed.getMetadata().isEmpty(), "Expected embed metadata to be empty, but it was not."); assertThat(embed.getResults().get(0).getOutput()).hasSize(384); assertThat(DF.format(embed.getResults().get(0).getOutput().get(0))).isEqualTo(DF.format(-0.19744634628295898)); diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java new file mode 100644 index 0000000000..ef0152c23a --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java @@ -0,0 +1,28 @@ +package org.springframework.ai.vertexai.embedding; + +import org.springframework.ai.chat.metadata.Usage; + +public class VertexAiEmbeddingUsage implements Usage { + + private final Integer totalTokens; + + public VertexAiEmbeddingUsage(Integer totalTokens) { + this.totalTokens = totalTokens; + } + + @Override + public Long getPromptTokens() { + return 0L; + } + + @Override + public Long getGenerationTokens() { + return 0L; + } + + @Override + public Long getTotalTokens() { + return Long.valueOf(this.totalTokens); + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index 
e4c3132ffa..5e99b118bb 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.Media; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingModel; import org.springframework.ai.embedding.DocumentEmbeddingRequest; @@ -35,6 +36,7 @@ import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.MultimodalInstanceBuilder; @@ -230,20 +232,18 @@ else if (media.getMimeType().isCompatibleWith(VIDEO_MIME_TYPE)) { String deploymentModelId = embeddingResponse.getDeployedModelId(); - EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), -1); - - responseMetadata.put("deployment-model-id", + Map metadataToUse = Map.of("deployment-model-id", StringUtils.hasText(deploymentModelId) ? 
deploymentModelId : "unknown"); - - return new EmbeddingResponse(embeddingList, generateResponseMetadata(mergedOptions.getModel(), 0)); + EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), 0, + metadataToUse); + return new EmbeddingResponse(embeddingList, responseMetadata); } - private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) { - EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); - metadata.put("model", model); - metadata.put("total-tokens", tokenCount); - return metadata; + private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens, + Map metadataToUse) { + Usage usage = new VertexAiEmbeddingUsage(totalTokens); + return new EmbeddingResponseMetadata(model, usage, metadataToUse); } @Override diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 70364a4394..aba0dc2ee3 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -20,6 +20,7 @@ import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.protobuf.Value; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; import org.springframework.ai.embedding.Embedding; @@ -32,6 +33,7 @@ import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; 
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -135,10 +137,11 @@ public EmbeddingResponse call(EmbeddingRequest request) { } } - private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) { + private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) { EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); - metadata.put("model", model); - metadata.put("total-tokens", tokenCount); + metadata.setModel(model); + Usage usage = new VertexAiEmbeddingUsage(totalTokens); + metadata.setUsage(usage); return metadata; } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java similarity index 87% rename from models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java rename to models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java index 652f8abee2..044d5380a5 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java @@ -33,10 +33,10 @@ import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -@SpringBootTest(classes = VertexAiMultimodelEmbeddingModelIT.Config.class) +@SpringBootTest(classes = 
VertexAiMultimodalEmbeddingModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") -class VertexAiMultimodelEmbeddingModelIT { +class VertexAiMultimodalEmbeddingModelIT { // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api @@ -68,8 +68,13 @@ void multipleInstancesEmbedding() { .isEqualTo(embeddingRequest.getInstructions().get(1).getId()); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()) + .as("Model in metadata should be 'multimodalembedding@001'") + .isEqualTo("multimodalembedding@001"); + + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()) + .as("Total tokens in metadata should be 0") + .isEqualTo(0L); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @@ -90,8 +95,8 @@ void textContentEmbedding() { .isEqualTo(MimeTypeUtils.TEXT_PLAIN); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @@ -113,8 +118,8 @@ void textMediaEmbedding() { .isEqualTo(MimeTypeUtils.TEXT_PLAIN); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
- assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @@ -139,8 +144,8 @@ void imageEmbedding() { assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @@ -164,8 +169,8 @@ void videoEmbedding() { .isEqualTo(new MimeType("video", "mp4")); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @@ -198,8 +203,8 @@ void textImageAndVideoEmbedding() { .isEqualTo(EmbeddingResultMetadata.ModalityType.VIDEO); assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + 
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java index 43a6c86cc3..4c9a9cdcb0 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -53,8 +53,12 @@ void defaultEmbedding(String modelName) { assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", modelName); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 5); + assertThat(embeddingResponse.getMetadata().getModel()).as("Model name in metadata should match expected model") + .isEqualTo(modelName); + + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()) + .as("Total tokens in metadata should be 5") + .isEqualTo(5L); assertThat(embeddingModel.dimensions()).isEqualTo(768); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 04e75de560..750700175d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ 
b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -15,13 +15,23 @@ */ package org.springframework.ai.vertexai.gemini; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; +import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.FunctionResponse; +import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.generativeai.GenerativeModel; +import com.google.cloud.vertexai.generativeai.PartMaker; +import com.google.cloud.vertexai.generativeai.ResponseStream; +import com.google.protobuf.Struct; +import com.google.protobuf.util.JsonFormat; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Media; import org.springframework.ai.chat.messages.Message; @@ -29,6 +39,7 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -38,35 +49,22 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractToolCallSupport; import 
org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata; import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage; import org.springframework.beans.factory.DisposableBean; import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Content; -import com.google.cloud.vertexai.api.FunctionCall; -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.cloud.vertexai.api.FunctionResponse; -import com.google.cloud.vertexai.api.GenerateContentResponse; -import com.google.cloud.vertexai.api.GenerationConfig; -import com.google.cloud.vertexai.api.Part; -import com.google.cloud.vertexai.api.Schema; -import com.google.cloud.vertexai.api.Tool; -import com.google.cloud.vertexai.generativeai.GenerativeModel; -import com.google.cloud.vertexai.generativeai.PartMaker; -import com.google.cloud.vertexai.generativeai.ResponseStream; -import com.google.protobuf.Struct; -import com.google.protobuf.util.JsonFormat; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + /** * @author Christian Tzolov * @author Grogdunn @@ -244,8 +242,8 @@ public Flux stream(Prompt prompt) { } } - private VertexAiChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) { - return new VertexAiChatResponseMetadata(new VertexAiUsage(response.getUsageMetadata())); + private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) { + return 
ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build(); } @JsonInclude(Include.NON_NULL) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiChatResponseMetadata.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiChatResponseMetadata.java deleted file mode 100644 index 3ef20dcb49..0000000000 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiChatResponseMetadata.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.vertexai.gemini.metadata; - -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.Usage; - -import java.util.HashMap; - -/** - * @author Christian Tzolov - * @since 0.8.1 - */ -public class VertexAiChatResponseMetadata extends HashMap implements ChatResponseMetadata { - - private final VertexAiUsage usage; - - public VertexAiChatResponseMetadata(VertexAiUsage usage) { - this.usage = usage; - } - - @Override - public Usage getUsage() { - return this.usage; - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index afd709cb8d..d9391c420e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -127,15 +128,17 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map @Override public ChatResponse adviseResponse(ChatResponse response, Map context) { - response.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return response; + ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response); + chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); + return chatResponseBuilder.build(); } @Override public Flux adviseResponse(Flux fluxResponse, Map context) { return fluxResponse.map(cr -> { - 
cr.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return cr; + ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr); + chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); + return chatResponseBuilder.build(); }); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java index 17fea86cd0..f58bc5d240 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java @@ -15,40 +15,51 @@ */ package org.springframework.ai.chat.metadata; -import org.springframework.ai.model.ResponseMetadata; +import java.util.Map; +import java.util.Objects; -import java.util.HashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.AbstractResponseMetadata; +import org.springframework.ai.model.ResponseMetadata; /** - * Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI - * response. + * Models common AI provider metadata returned in an AI response. 
* * @author John Blum * @author Thomas Vitale - * @since 0.7.0 + * @author Mark Pollack + * @since 1.0.0 */ -public interface ChatResponseMetadata extends ResponseMetadata { +public class ChatResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata { - class DefaultChatResponseMetadata extends HashMap implements ChatResponseMetadata { + private final static Logger logger = LoggerFactory.getLogger(ChatResponseMetadata.class); - } + private String id = ""; // Set to blank to preserve backward compat with previous + // interface default methods + + private String model = ""; - ChatResponseMetadata NULL = new DefaultChatResponseMetadata(); + private RateLimit rateLimit = new EmptyRateLimit(); + + private Usage usage = new EmptyUsage(); + + private PromptMetadata promptMetadata = PromptMetadata.empty(); /** * A unique identifier for the chat completion operation. * @return unique operation identifier. */ - default String getId() { - return ""; + public String getId() { + return this.id; } /** * The model that handled the request. * @return the model that handled the request. */ - default String getModel() { - return ""; + public String getModel() { + return this.model; } /** @@ -56,8 +67,8 @@ default String getModel() { * @return AI provider specific metadata on rate limits. * @see RateLimit */ - default RateLimit getRateLimit() { - return new EmptyRateLimit(); + public RateLimit getRateLimit() { + return this.rateLimit; } /** @@ -65,12 +76,98 @@ default RateLimit getRateLimit() { * @return AI provider specific metadata on API usage. * @see Usage */ - default Usage getUsage() { - return new EmptyUsage(); + public Usage getUsage() { + return this.usage; + } + + /** + * Returns the prompt metadata gathered by the AI during request processing. + * @return the prompt metadata. 
+ */ + public PromptMetadata getPromptMetadata() { + return this.promptMetadata; + } + + public static class Builder { + + private final ChatResponseMetadata chatResponseMetadata; + + public Builder() { + this.chatResponseMetadata = new ChatResponseMetadata(); + } + + public Builder withMetadata(Map mapToCopy) { + mapToCopy.forEach(this::withKeyValue); // same null handling as withKeyValue + return this; + } + + public Builder withKeyValue(String key, Object value) { + if (key == null) { + throw new IllegalArgumentException("Key must not be null"); + } + if (value != null) { + this.chatResponseMetadata.map.put(key, value); + } + else { + logger.debug("Ignore null value for key [{}]", key); + } + return this; + } + + public Builder withId(String id) { + this.chatResponseMetadata.id = id; + return this; + } + + public Builder withModel(String model) { + this.chatResponseMetadata.model = model; + return this; + } + + public Builder withRateLimit(RateLimit rateLimit) { + this.chatResponseMetadata.rateLimit = rateLimit; + return this; + } + + public Builder withUsage(Usage usage) { + this.chatResponseMetadata.usage = usage; + return this; + } + + public Builder withPromptMetadata(PromptMetadata promptMetadata) { + this.chatResponseMetadata.promptMetadata = promptMetadata; + return this; + } + + public ChatResponseMetadata build() { + return this.chatResponseMetadata; + } + + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ChatResponseMetadata that)) + return false; + return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) + && Objects.equals(this.rateLimit, that.rateLimit) && Objects.equals(this.usage, that.usage) + && Objects.equals(this.promptMetadata, that.promptMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(this.id, this.model, this.rateLimit, this.usage, this.promptMetadata); } - default PromptMetadata
getPromptMetadata() { - return PromptMetadata.empty(); + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit()); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index 8c31776c0e..e1b3172461 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -16,7 +16,9 @@ package org.springframework.ai.chat.model; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; import org.springframework.ai.model.ModelResponse; import org.springframework.util.CollectionUtils; @@ -40,7 +42,7 @@ public class ChatResponse implements ModelResponse { * provider. */ public ChatResponse(List generations) { - this(generations, ChatResponseMetadata.NULL); + this(generations, new ChatResponseMetadata()); } /** @@ -107,4 +109,54 @@ public int hashCode() { return Objects.hash(chatResponseMetadata, generations); } + public static ChatResponse.Builder builder() { + return new ChatResponse.Builder(); + } + + public static class Builder { + + private List generations; + + private ChatResponseMetadata.Builder chatResponseMetadataBuilder; + + private Builder() { + this.chatResponseMetadataBuilder = ChatResponseMetadata.builder(); + } + + /** + * Copies the generations and the complete metadata (id, model, rate limit, usage, + * prompt metadata and key/value entries) of the given response into this builder. + */ + public Builder from(ChatResponse other) { + this.generations = other.generations; + ChatResponseMetadata otherMetadata = other.chatResponseMetadata; + this.chatResponseMetadataBuilder.withId(otherMetadata.getId()) + .withModel(otherMetadata.getModel()) + .withRateLimit(otherMetadata.getRateLimit()) + .withUsage(otherMetadata.getUsage()) + .withPromptMetadata(otherMetadata.getPromptMetadata()); + Set> entries = otherMetadata.entrySet(); + for (Map.Entry entry : entries) { + this.chatResponseMetadataBuilder.withKeyValue(entry.getKey(), entry.getValue()); + } + return this; + } + + public Builder withMetadata(String key, Object value) { + this.chatResponseMetadataBuilder.withKeyValue(key, value); + return this; + } + + public Builder withGenerations(List generations) { + this.generations = generations; + return this;
+ + } + + public ChatResponse build() { + return new ChatResponse(generations, chatResponseMetadataBuilder.build()); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java index 40a3058383..335ac0ae2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java @@ -15,24 +15,20 @@ */ package org.springframework.ai.embedding; -import java.io.Serial; -import java.util.HashMap; -import java.util.Map; - import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.model.AbstractResponseMetadata; import org.springframework.ai.model.ResponseMetadata; +import java.util.Map; + /** * Common AI provider metadata returned in an embedding response. 
* * @author Christian Tzolov * @author Thomas Vitale */ -public class EmbeddingResponseMetadata extends HashMap implements ResponseMetadata { - - @Serial - private static final long serialVersionUID = 1L; +public class EmbeddingResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata { private String model; @@ -42,12 +38,15 @@ public EmbeddingResponseMetadata() { } public EmbeddingResponseMetadata(String model, Usage usage) { - this.model = model; - this.usage = usage; + this(model, usage, Map.of()); } - public EmbeddingResponseMetadata(Map metadata) { - super(metadata); + public EmbeddingResponseMetadata(String model, Usage usage, Map metadata) { + this.model = model; + this.usage = usage; + for (Map.Entry entry : metadata.entrySet()) { + this.map.put(entry.getKey(), entry.getValue()); + } } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java index 70a4b946a8..b6d6c87b88 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -43,7 +43,7 @@ public class ImageResponse implements ModelResponse { * provider. 
*/ public ImageResponse(List generations) { - this(generations, ImageResponseMetadata.NULL); + this(generations, new ImageResponseMetadata()); } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java index a80b31bba7..fe5b78985d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.image; +import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.ai.model.ResponseMetadata; import java.util.HashMap; @@ -28,16 +29,20 @@ * @author Thomas Vitale * @since 1.0.0 */ -public interface ImageResponseMetadata extends ResponseMetadata { +public class ImageResponseMetadata extends MutableResponseMetadata { - class DefaultImageResponseMetadata extends HashMap implements ImageResponseMetadata { + private Long created; + public ImageResponseMetadata() { + this.created = System.currentTimeMillis(); } - ImageResponseMetadata NULL = new DefaultImageResponseMetadata(); + public ImageResponseMetadata(Long created) { + this.created = created; + } - default Long getCreated() { - return System.currentTimeMillis(); + public Long getCreated() { + return this.created; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractResponseMetadata.java new file mode 100644 index 0000000000..42bd8678e8 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/AbstractResponseMetadata.java @@ -0,0 +1,76 @@ +package org.springframework.ai.model; + +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import 
java.util.concurrent.ConcurrentHashMap; + +public class AbstractResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ id: %1$s, usage: %2$s, rateLimit: %3$s }"; + + protected final Map map = new ConcurrentHashMap<>(); + + /** + * Gets an entry from the context. Returns {@code null} when entry is not present. + * @param key key + * @param value type + * @return entry or {@code null} if not present + */ + @Nullable + public T get(String key) { + return (T) this.map.get(key); + } + + /** + * Gets an entry from the context. Throws exception when entry is not present. + * @param key key + * @param value type + * @return entry + * @throws IllegalArgumentException if not present + */ + @NonNull + public T getRequired(Object key) { + T object = (T) this.map.get(key); + if (object == null) { + throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]"); + } + return object; + } + + /** + * Checks if context contains a key. + * @param key key + * @return {@code true} when the context contains the entry with the given key + */ + public boolean containsKey(Object key) { + return this.map.containsKey(key); + } + + /** + * Returns an element or default if not present. 
+ * @param key key + * @param defaultObject default object to return + * @param value type + * @return object or default if not present + */ + public T getOrDefault(Object key, T defaultObject) { + return (T) this.map.getOrDefault(key, defaultObject); + } + + public Set> entrySet() { + return Collections.unmodifiableMap(this.map).entrySet(); + } + + public Set keySet() { + return Collections.unmodifiableSet(this.map.keySet()); + } + + public boolean isEmpty() { + return this.map.isEmpty(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java new file mode 100644 index 0000000000..ac0c9254e7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java @@ -0,0 +1,126 @@ +package org.springframework.ai.model; + +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +public class MutableResponseMetadata implements ResponseMetadata { + + private final Map map = new ConcurrentHashMap<>(); + + /** + * Puts an element to the context. + * @param key key + * @param object value + * @param value type + * @return this for chaining + */ + public MutableResponseMetadata put(String key, T object) { + this.map.put(key, object); + return this; + } + + /** + * Gets an entry from the context. Returns {@code null} when entry is not present. + * @param key key + * @param value type + * @return entry or {@code null} if not present + */ + @Override + @Nullable + public T get(String key) { + return (T) this.map.get(key); + } + + /** + * Removes an entry from the context. 
+ * @param key key by which to remove an entry + * @return the previous value associated with the key, or null if there was no mapping + * for the key + */ + public Object remove(Object key) { + return this.map.remove(key); + } + + /** + * Gets an entry from the context. Throws exception when entry is not present. + * @param key key + * @param value type + * @throws IllegalArgumentException if not present + * @return entry + */ + @Override + @NonNull + public T getRequired(Object key) { + T object = (T) this.map.get(key); + if (object == null) { + throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]"); + } + return object; + } + + /** + * Checks if context contains a key. + * @param key key + * @return {@code true} when the context contains the entry with the given key + */ + @Override + public boolean containsKey(Object key) { + return this.map.containsKey(key); + } + + /** + * Returns an element or default if not present. + * @param key key + * @param defaultObject default object to return + * @param value type + * @return object or default if not present + */ + @Override + public T getOrDefault(Object key, T defaultObject) { + return (T) this.map.getOrDefault(key, defaultObject); + } + + @Override + public Set> entrySet() { + return Collections.unmodifiableMap(this.map).entrySet(); + } + + public Set keySet() { + return Collections.unmodifiableSet(this.map.keySet()); + } + + @Override + public boolean isEmpty() { + return this.map.isEmpty(); + } + + /** + * Returns an element or calls a mapping function if entry not present. The function + * will insert the value to the map. + * @param key key + * @param mappingFunction mapping function + * @param value type + * @return object or one derived from the mapping function if not present + */ + public T computeIfAbsent(String key, Function mappingFunction) { + return (T) this.map.computeIfAbsent(key, mappingFunction); + } + + /** + * Clears the entries from the context. 
+ */ + public void clear() { + this.map.clear(); + } + + public Map getRawMap() { + return map; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java index b1516ec3fb..24e544d4f2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -15,18 +15,77 @@ */ package org.springframework.ai.model; +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; /** - * Interface representing metadata associated with an AI model's response. This interface - * is designed to provide additional information about the generative response from an AI - * model, including processing details and model-specific data. It serves as a value - * object within the core domain, enhancing the understanding and management of AI model - * responses in various applications. + * Interface representing metadata associated with an AI model's response. * * @author Mark Pollack - * @since 0.8.0 + * @since 1.0.0 */ -public interface ResponseMetadata extends Map { +public interface ResponseMetadata { + + /** + * Gets an entry from the context. Returns {@code null} when entry is not present. + * @param key key + * @param value type + * @return entry or {@code null} if not present + */ + @Nullable + T get(String key); + + /** + * Gets an entry from the context. Throws exception when entry is not present. + * @param key key + * @param value type + * @throws IllegalArgumentException if not present + * @return entry + */ + @NonNull + T getRequired(Object key); + + /** + * Checks if context contains a key. 
+ * @param key key + * @return {@code true} when the context contains the entry with the given key + */ + boolean containsKey(Object key); + + /** + * Returns an element or default if not present. + * @param key key + * @param defaultObject default object to return + * @param value type + * @return object or default if not present + */ + T getOrDefault(Object key, T defaultObject); + + /** + * Returns an element or the default obtained from the supplier if not present. + * @param key key + * @param defaultObjectSupplier supplier for default object to return + * @param value type + * @return object or default if not present + * @since 1.0.0 + */ + default T getOrDefault(String key, Supplier defaultObjectSupplier) { + T value = get(key); + return value != null ? value : defaultObjectSupplier.get(); + } + + Set> entrySet(); + + public Set keySet(); + + /** + * Returns {@code true} if this map contains no key-value mappings. + * @return {@code true} if this map contains no key-value mappings + */ + boolean isEmpty(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java index d5be8ef6ca..4a8ce759a2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java @@ -136,9 +136,7 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) { // The chat completion tool call requires the complete conversation // history. Including the initial user message.
- List conversationHistory = new ArrayList<>(); - - conversationHistory.addAll(this.doGetUserMessages(request)); + List conversationHistory = new ArrayList<>(this.doGetUserMessages(request)); Msg responseMessage = this.doGetToolResponseMessage(response); @@ -164,9 +162,7 @@ protected Flux handleFunctionCallOrReturnStream(Req request, Flux re // The chat completion tool call requires the complete conversation // history. Including the initial user message. - List conversationHistory = new ArrayList<>(); - - conversationHistory.addAll(this.doGetUserMessages(request)); + List conversationHistory = new ArrayList<>(this.doGetUserMessages(request)); Msg responseMessage = this.doGetToolResponseMessage(resp); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java index 2fae3d9f25..2e40f9def7 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java @@ -29,7 +29,6 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -58,8 +57,7 @@ record MyBean(String name, int age) { @Test public void responseEntityTest() { - ChatResponseMetadata metadata = new DefaultChatResponseMetadata(); - metadata.put("key1", "value1"); + ChatResponseMetadata metadata = ChatResponseMetadata.builder().withKeyValue("key1", "value1").build(); var chatResponse = new ChatResponse(List.of(new Generation(""" 
{"name":"John", "age":30} @@ -75,7 +73,7 @@ public void responseEntityTest() { .responseEntity(MyBean.class); assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); - assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1"); + assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1"); assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java index ca030ab9e9..b4385d64fc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java @@ -112,8 +112,8 @@ public void multimodalEmbedding() { .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); - assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0L); assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);