Skip to content

Commit

Permalink
Update to ResponseMetadata design
Browse files Browse the repository at this point in the history
* Remove inheritance from HashMap
* No more subclasses per model provider
* Builder class for ChatResponse
* Fix the AbstractResponseMetadata#AI_METADATA_STRING parameter order
* ChatResponseMetadata ignore Null values.
  • Loading branch information
markpollack authored and tzolov committed Jul 18, 2024
1 parent 17c4423 commit 97f443d
Show file tree
Hide file tree
Showing 39 changed files with 773 additions and 727 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@
*/
package org.springframework.ai.anthropic;

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
Expand All @@ -32,12 +24,13 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand All @@ -52,10 +45,17 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* The {@link ChatModel} implementation for the Anthropic service.
*
Expand Down Expand Up @@ -228,7 +228,20 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
.withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
}).toList();

return new ChatResponse(generations, AnthropicChatResponseMetadata.from(chatCompletion));
return new ChatResponse(generations, from(chatCompletion));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("stop-reason", result.stopReason())
.withKeyValue("stop-sequence", result.stopSequence())
.withKeyValue("type", result.type())
.build();
}

private String fromMediaData(Object mediaData) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,6 @@
*/
package org.springframework.ai.azure.openai;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
Expand All @@ -68,10 +41,36 @@
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
Expand Down Expand Up @@ -151,8 +150,22 @@ public ChatResponse call(Prompt prompt) {

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
.withId(id)
.withUsage(usage)
.withModel(chatCompletions.getModel())
.withPromptMetadata(promptFilterMetadata)
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
.build();

return chatResponseMetadata;
}

@Override
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.azure.ai.openai.models.ImageGenerations;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.MutableResponseMetadata;
import org.springframework.util.Assert;

import java.util.HashMap;
Expand All @@ -15,7 +16,7 @@
* @author Benoit Moussaud
* @since 1.0.0 M1
*/
public class AzureOpenAiImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata {

private final Long created;

Expand Down
Loading

0 comments on commit 97f443d

Please sign in to comment.