From 44899023bb3a7429e5ff7ffce51d23903d385a66 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sat, 10 Aug 2024 13:20:29 +0200 Subject: [PATCH] Add additinal chatclinet observation convention and context tests --- .../ai/chat/client/DefaultChatClient.java | 8 +- .../ChatClientObservationDocumentation.java | 12 +- ...efaultChatClientObservationConvention.java | 36 +++- .../conventions/AiOperationType.java | 3 +- .../observation/conventions/AiProvider.java | 3 +- ...entInputContentObservationFilterTests.java | 4 +- .../ChatClientObservationContextTests.java | 64 +++++++ ...tChatClientObservationConventionTests.java | 174 ++++++++++++++++++ 8 files changed, 291 insertions(+), 13 deletions(-) create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 13b60dec246..e20407f3612 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -49,6 +49,8 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.observation.AiOperationMetadata; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.util.Assert; @@ -327,7 +329,8 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in String formatParam) { ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, - new AiOperationMetadata("framework", "spring_ai"), formatParam, false); + new AiOperationMetadata(AiOperationType.FRAMEWORK.value(), AiProvider.SPRING_AI.value()), + formatParam, false); return ChatClientObservationDocumentation.AI_CHAT_CLIENT .observation(inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, @@ -425,7 +428,8 @@ public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSp private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { return Flux.deferContextual(contextView -> { ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest, - new AiOperationMetadata("framework", "spring_ai"), "", true); + new AiOperationMetadata(AiOperationType.FRAMEWORK.value(), AiProvider.SPRING_AI.value()), "", + true); Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java index 537c156b017..2af2ae45ec9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java @@ -87,12 +87,22 @@ public enum HighCardinalityKeyNames implements KeyName { /** * Enabled tool function names. */ - TOOL_FUNCTION_NAMES { + CHAT_CLIENT_TOOL_FUNCTION_NAMES { @Override public String asString() { return "spring.ai.chat.client.tool.function.names"; } }, + /** + * List of configured chat client function callbacks. + */ + CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS { + @Override + public String asString() { + return "spring.ai.chat.client.tool.functioncallbacks"; + } + }, + /** * List of configured chat client advisors. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java index 215d007b9f0..5d873bad18d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java @@ -36,8 +36,13 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva private static final KeyValue STATUS_NONE = KeyValue.of(LowCardinalityKeyNames.STATUS, KeyValue.NONE_VALUE); - private static final KeyValue TOOL_FUNCTION_NAMES_NONE = KeyValue - .of(ChatClientObservationDocumentation.HighCardinalityKeyNames.TOOL_FUNCTION_NAMES, KeyValue.NONE_VALUE); + private static final KeyValue CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS_NONE = KeyValue.of( + ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS, + KeyValue.NONE_VALUE); + + private static final KeyValue CHAT_CLIENT_TOOL_FUNCTION_NAMES_NONE = KeyValue.of( + ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES, + KeyValue.NONE_VALUE); private static final KeyValue CHAT_CLIENT_ADVISOR_NONE = KeyValue .of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR, KeyValue.NONE_VALUE); @@ -70,7 +75,13 @@ public String getContextualName(ChatClientObservationContext context) { @Override public KeyValues getLowCardinalityKeyValues(ChatClientObservationContext context) { return KeyValues.of(springAiKind(context), aiOperationType(context), aiProvider(context), stream(context), - status(context), toolFunctionNames(context), chatClientAvisor(context), chatClientAvisorParam(context)); + status(context)); + } + + @Override + public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext context) { + return KeyValues.of(toolFunctionNames(context), toolFunctionCallbacks(context), chatClientAvisor(context), + chatClientAvisorParam(context)); } protected KeyValue springAiKind(ChatClientObservationContext context) { @@ -97,12 +108,25 @@ protected KeyValue aiProvider(ChatClientObservationContext context) { protected KeyValue toolFunctionNames(ChatClientObservationContext context) { if (CollectionUtils.isEmpty(context.getRequest().getFunctionNames())) { - return TOOL_FUNCTION_NAMES_NONE; + return CHAT_CLIENT_TOOL_FUNCTION_NAMES_NONE; } - return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.TOOL_FUNCTION_NAMES, + return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES, context.getRequest().getFunctionNames().stream().collect(Collectors.joining(","))); } + protected KeyValue toolFunctionCallbacks(ChatClientObservationContext context) { + if (CollectionUtils.isEmpty(context.getRequest().getFunctionCallbacks())) { + return CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS_NONE; + } + return KeyValue.of( + ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS, + context.getRequest() + .getFunctionCallbacks() + .stream() + .map(fc -> fc.getName()) + .collect(Collectors.joining(","))); + } + protected KeyValue chatClientAvisor(ChatClientObservationContext context) { if (CollectionUtils.isEmpty(context.getRequest().getAdvisors())) { return CHAT_CLIENT_ADVISOR_NONE; @@ -115,7 +139,7 @@ protected KeyValue chatClientAvisorParam(ChatClientObservationContext context) { if (CollectionUtils.isEmpty(context.getRequest().getAdvisorParams())) { return CHAT_CLIENT_ADVISOR_PARAM_NONE; } - return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR, + return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAM, context.getRequest() .getAdvisorParams() .entrySet() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java index 45ea85671d0..4ea92d880bc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java @@ -32,7 +32,8 @@ public enum AiOperationType { CHAT("chat"), EMBEDDING("embedding"), IMAGE("image"), - TEXT_COMPLETION("text_completion"); + TEXT_COMPLETION("text_completion"), + FRAMEWORK("framework"); private final String value; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 01f678f02ea..a93fe168acf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -33,7 +33,8 @@ public enum AiProvider { MISTRAL_AI("mistral_ai"), OLLAMA("ollama"), OPENAI("openai"), - VERTEX_AI("vertex_ai"); + VERTEX_AI("vertex_ai"), + SPRING_AI("spring_ai"); private final String value; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 51d62e7ff2b..fe142227e34 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -98,8 +98,8 @@ void whenWithTextThenAugmentContext() { private AiOperationMetadata generateOperationMetadata() { return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OLLAMA.value()) + .operationType(AiOperationType.FRAMEWORK.value()) + .provider(AiProvider.SPRING_AI.value()) .build(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java new file mode 100644 index 00000000000..38761697eaa --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.client.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.observation.AiOperationMetadata; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; + +import io.micrometer.observation.ObservationRegistry; + +/** + * Unit tests for {@link ChatClientObservationContext}. + * + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +class ChatClientObservationContextTests { + + @Mock + ChatModel chatModel; + + @Test + void whenMandatoryRequestOptionsThenReturn() { + + var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null); + + var observationContext = new ChatClientObservationContext(request, generateOperationMetadata(), "", true); + + assertThat(observationContext).isNotNull(); + } + + private AiOperationMetadata generateOperationMetadata() { + return AiOperationMetadata.builder() + .operationType(AiOperationType.FRAMEWORK.value()) + .provider(AiProvider.SPRING_AI.value()) + .build(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java new file mode 100644 index 00000000000..86a3a47bed7 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.client.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.observation.AiOperationMetadata; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; + +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; + +/** + * Unit tests for {@link DefaultChatClientObservationConvention}. + * + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +class DefaultChatClientObservationConventionTests { + + @Mock + ChatModel chatModel; + + private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention(); + + DefaultChatClientRequestSpec request; + + @BeforeEach + public void beforeEach() { + request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null); + } + + @Test + void shouldHaveName() { + assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME); + } + + @Test + void shouldHaveContextualName() { + ChatClientObservationContext observationContext = new ChatClientObservationContext(request, + generateOperationMetadata(), "", true); + + assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("framework spring_ai"); + } + + @Test + void supportsOnlyChatClientObservationContext() { + ChatClientObservationContext observationContext = new ChatClientObservationContext(request, + generateOperationMetadata(), "", true); + + assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); + assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); + } + + @Test + void shouldHaveRequiredKeyValues() { + ChatClientObservationContext observationContext = new ChatClientObservationContext(request, + generateOperationMetadata(), "", true); + + assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( + KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat.client"), + KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true")); + } + + static RequestResponseAdvisor dummyAdvisor(String name) { + return new RequestResponseAdvisor() { + @Override + public String getName() { + return name; + } + }; + } + + static FunctionCallback dummyFunction(String name) { + return new FunctionCallback() { + @Override + public String getName() { + return name; + } + + @Override + public String getDescription() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getDescription'"); + } + + @Override + public String getInputTypeSchema() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getInputTypeSchema'"); + } + + @Override + public String call(String functionInput) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'call'"); + } + }; + } + + @Test + void shouldHaveOptionalKeyValues() { + + var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), + List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), + List.of("function1", "function2"), List.of(), null, + List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), + ObservationRegistry.NOOP, null); + + ChatClientObservationContext observationContext = new ChatClientObservationContext(request, + generateOperationMetadata(), "json", true); + + assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR.asString(), "advisor1,advisor2"), + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAM.asString(), + "advParam1:advisorParam1Value"), + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), "function1,function2"), + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(), + "functionCallback1,functionCallback2")); + } + + private AiOperationMetadata generateOperationMetadata() { + return AiOperationMetadata.builder() + .operationType(AiOperationType.FRAMEWORK.value()) + .provider(AiProvider.SPRING_AI.value()) + .build(); + } + + static class TestUsage implements Usage { + + @Override + public Long getPromptTokens() { + return 1000L; + } + + @Override + public Long getGenerationTokens() { + return 500L; + } + + } + +}