Skip to content

Commit

Permalink
Add additinal chatclinet observation convention and context tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tzolov committed Aug 10, 2024
1 parent ad0344c commit 4489902
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -327,7 +329,8 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in
String formatParam) {

ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest,
new AiOperationMetadata("framework", "spring_ai"), formatParam, false);
new AiOperationMetadata(AiOperationType.FRAMEWORK.value(), AiProvider.SPRING_AI.value()),
formatParam, false);

return ChatClientObservationDocumentation.AI_CHAT_CLIENT
.observation(inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
Expand Down Expand Up @@ -425,7 +428,8 @@ public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSp
private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
return Flux.deferContextual(contextView -> {
ChatClientObservationContext observationContext = new ChatClientObservationContext(inputRequest,
new AiOperationMetadata("framework", "spring_ai"), "", true);
new AiOperationMetadata(AiOperationType.FRAMEWORK.value(), AiProvider.SPRING_AI.value()), "",
true);

Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
inputRequest.customObservationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,22 @@ public enum HighCardinalityKeyNames implements KeyName {
/**
* Enabled tool function names.
*/
TOOL_FUNCTION_NAMES {
CHAT_CLIENT_TOOL_FUNCTION_NAMES {
@Override
public String asString() {
return "spring.ai.chat.client.tool.function.names";
}
},
/**
* List of configured chat client function callbacks.
*/
CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS {
@Override
public String asString() {
return "spring.ai.chat.client.tool.functioncallbacks";
}
},

/**
* List of configured chat client advisors.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva

private static final KeyValue STATUS_NONE = KeyValue.of(LowCardinalityKeyNames.STATUS, KeyValue.NONE_VALUE);

private static final KeyValue TOOL_FUNCTION_NAMES_NONE = KeyValue
.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.TOOL_FUNCTION_NAMES, KeyValue.NONE_VALUE);
private static final KeyValue CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS_NONE = KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS,
KeyValue.NONE_VALUE);

private static final KeyValue CHAT_CLIENT_TOOL_FUNCTION_NAMES_NONE = KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES,
KeyValue.NONE_VALUE);

private static final KeyValue CHAT_CLIENT_ADVISOR_NONE = KeyValue
.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR, KeyValue.NONE_VALUE);
Expand Down Expand Up @@ -70,7 +75,13 @@ public String getContextualName(ChatClientObservationContext context) {
@Override
public KeyValues getLowCardinalityKeyValues(ChatClientObservationContext context) {
return KeyValues.of(springAiKind(context), aiOperationType(context), aiProvider(context), stream(context),
status(context), toolFunctionNames(context), chatClientAvisor(context), chatClientAvisorParam(context));
status(context));
}

@Override
public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext context) {
return KeyValues.of(toolFunctionNames(context), toolFunctionCallbacks(context), chatClientAvisor(context),
chatClientAvisorParam(context));
}

protected KeyValue springAiKind(ChatClientObservationContext context) {
Expand All @@ -97,12 +108,25 @@ protected KeyValue aiProvider(ChatClientObservationContext context) {

protected KeyValue toolFunctionNames(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getFunctionNames())) {
return TOOL_FUNCTION_NAMES_NONE;
return CHAT_CLIENT_TOOL_FUNCTION_NAMES_NONE;
}
return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.TOOL_FUNCTION_NAMES,
return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES,
context.getRequest().getFunctionNames().stream().collect(Collectors.joining(",")));
}

protected KeyValue toolFunctionCallbacks(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getFunctionCallbacks())) {
return CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS_NONE;
}
return KeyValue.of(
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS,
context.getRequest()
.getFunctionCallbacks()
.stream()
.map(fc -> fc.getName())
.collect(Collectors.joining(",")));
}

protected KeyValue chatClientAvisor(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getAdvisors())) {
return CHAT_CLIENT_ADVISOR_NONE;
Expand All @@ -115,7 +139,7 @@ protected KeyValue chatClientAvisorParam(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getAdvisorParams())) {
return CHAT_CLIENT_ADVISOR_PARAM_NONE;
}
return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR,
return KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAM,
context.getRequest()
.getAdvisorParams()
.entrySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public enum AiOperationType {
CHAT("chat"),
EMBEDDING("embedding"),
IMAGE("image"),
TEXT_COMPLETION("text_completion");
TEXT_COMPLETION("text_completion"),
FRAMEWORK("framework");

private final String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public enum AiProvider {
MISTRAL_AI("mistral_ai"),
OLLAMA("ollama"),
OPENAI("openai"),
VERTEX_AI("vertex_ai");
VERTEX_AI("vertex_ai"),
SPRING_AI("spring_ai");

private final String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ void whenWithTextThenAugmentContext() {

private AiOperationMetadata generateOperationMetadata() {
return AiOperationMetadata.builder()
.operationType(AiOperationType.CHAT.value())
.provider(AiProvider.OLLAMA.value())
.operationType(AiOperationType.FRAMEWORK.value())
.provider(AiProvider.SPRING_AI.value())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.observation;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;

import io.micrometer.observation.ObservationRegistry;

/**
* Unit tests for {@link ChatClientObservationContext}.
*
* @author Christian Tzolov
*/
@ExtendWith(MockitoExtension.class)
class ChatClientObservationContextTests {

@Mock
ChatModel chatModel;

@Test
void whenMandatoryRequestOptionsThenReturn() {

var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);

var observationContext = new ChatClientObservationContext(request, generateOperationMetadata(), "", true);

assertThat(observationContext).isNotNull();
}

private AiOperationMetadata generateOperationMetadata() {
return AiOperationMetadata.builder()
.operationType(AiOperationType.FRAMEWORK.value())
.provider(AiProvider.SPRING_AI.value())
.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.observation;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;

import io.micrometer.common.KeyValue;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;

/**
* Unit tests for {@link DefaultChatClientObservationConvention}.
*
* @author Christian Tzolov
*/
@ExtendWith(MockitoExtension.class)
class DefaultChatClientObservationConventionTests {

@Mock
ChatModel chatModel;

private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention();

DefaultChatClientRequestSpec request;

@BeforeEach
public void beforeEach() {
request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);
}

@Test
void shouldHaveName() {
assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME);
}

@Test
void shouldHaveContextualName() {
ChatClientObservationContext observationContext = new ChatClientObservationContext(request,
generateOperationMetadata(), "", true);

assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("framework spring_ai");
}

@Test
void supportsOnlyChatClientObservationContext() {
ChatClientObservationContext observationContext = new ChatClientObservationContext(request,
generateOperationMetadata(), "", true);

assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse();
}

@Test
void shouldHaveRequiredKeyValues() {
ChatClientObservationContext observationContext = new ChatClientObservationContext(request,
generateOperationMetadata(), "", true);

assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat.client"),
KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true"));
}

static RequestResponseAdvisor dummyAdvisor(String name) {
return new RequestResponseAdvisor() {
@Override
public String getName() {
return name;
}
};
}

static FunctionCallback dummyFunction(String name) {
return new FunctionCallback() {
@Override
public String getName() {
return name;
}

@Override
public String getDescription() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getDescription'");
}

@Override
public String getInputTypeSchema() {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'getInputTypeSchema'");
}

@Override
public String call(String functionInput) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'call'");
}
};
}

@Test
void shouldHaveOptionalKeyValues() {

var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(),
List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(),
List.of("function1", "function2"), List.of(), null,
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"),
ObservationRegistry.NOOP, null);

ChatClientObservationContext observationContext = new ChatClientObservationContext(request,
generateOperationMetadata(), "json", true);

assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains(
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR.asString(), "advisor1,advisor2"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAM.asString(),
"advParam1:advisorParam1Value"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), "function1,function2"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(),
"functionCallback1,functionCallback2"));
}

private AiOperationMetadata generateOperationMetadata() {
return AiOperationMetadata.builder()
.operationType(AiOperationType.FRAMEWORK.value())
.provider(AiProvider.SPRING_AI.value())
.build();
}

static class TestUsage implements Usage {

@Override
public Long getPromptTokens() {
return 1000L;
}

@Override
public Long getGenerationTokens() {
return 500L;
}

}

}

0 comments on commit 4489902

Please sign in to comment.