diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index c6cdb9408f..9826accc90 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -26,7 +26,12 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; +import java.util.function.Function; +import org.springframework.ai.chat.client.resolver.BeanNameResolver; +import org.springframework.ai.chat.client.resolver.SimpleNameResolver; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.context.support.GenericApplicationContext; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.Media; @@ -64,12 +69,18 @@ */ public interface ChatClient { - static ChatClient create(ChatModel chatModel) { - return builder(chatModel).build(); + static ChatClient create(ChatModel chatModel, BeanNameResolver beanNameResolver) { + return builder(chatModel, beanNameResolver).build(); } + // QUESTION: Should this constructor be removed in favour of the one requiring the + // beanNameResolver? static Builder builder(ChatModel chatModel) { - return new Builder(chatModel); + return new Builder(chatModel, new SimpleNameResolver()); + } + + static Builder builder(ChatModel chatModel, BeanNameResolver beanNameResolver) { + return new Builder(chatModel, beanNameResolver); } ChatClientRequest prompt(); @@ -77,8 +88,8 @@ static Builder builder(ChatModel chatModel) { ChatClientPromptRequest prompt(Prompt prompt); /** - * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose - * settings are replicated from the default {@link ChatClientRequest} of this client. + * Return a {@link Builder} to create a new {@link ChatClient} whose settings are + * replicated from the default {@link ChatClientRequest} of this client. */ Builder mutate(); @@ -263,16 +274,20 @@ class ChatClientRequest { private final Map advisorParams = new HashMap<>(); + private final BeanNameResolver beanNameResolver; + /* copy constructor */ ChatClientRequest(ChatClientRequest ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, - ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams); + ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, + ccr.beanNameResolver); } public ChatClientRequest(ChatModel chatModel, String userText, Map userParams, String systemText, Map systemParams, List functionCallbacks, List messages, List functionNames, List media, ChatOptions chatOptions, - List advisors, Map advisorParams) { + List advisors, Map advisorParams, + BeanNameResolver beanNameResolver) { this.chatModel = chatModel; this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions(); @@ -288,6 +303,7 @@ public ChatClientRequest(ChatModel chatModel, String userText, Map s.text(this.systemText).params(this.systemParams)) .defaultUser(u -> u.text(this.userText) .params(this.userParams) @@ -349,6 +365,11 @@ public ChatClientRequest options(T options) { return this; } + public ChatClientRequest function(FunctionCallbackWrapper functionWrapper) { + this.functionCallbacks.add(functionWrapper); + return this; + } + public ChatClientRequest function(String name, String description, java.util.function.Function function) { @@ -365,6 +386,11 @@ public ChatClientRequest function(String name, String description, return this; } + public ChatClientRequest functionBean(Class functionBeanType) { + this.functionNames.add(beanNameResolver.resolveName(functionBeanType)); + return this; + } + public ChatClientRequest functions(String... functionBeanNames) { Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null"); this.functionNames.addAll(List.of(functionBeanNames)); @@ -528,7 +554,7 @@ private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest, adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(), adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(), adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(), - adviseRequest.advisorParams()); + adviseRequest.advisorParams(), inputRequest.beanNameResolver); } return advisedRequest; @@ -735,11 +761,11 @@ class Builder { private final ChatModel chatModel; - Builder(ChatModel chatModel) { + Builder(ChatModel chatModel, BeanNameResolver beanNameResolver) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); this.chatModel = chatModel; this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of()); + List.of(), List.of(), null, List.of(), Map.of(), beanNameResolver); } public Builder defaultAdvisors(RequestResponseAdvisor... advisor) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/Tool.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/Tool.java new file mode 100644 index 0000000000..88aeb111e0 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/Tool.java @@ -0,0 +1,19 @@ +package org.springframework.ai.chat.client; + +import org.springframework.context.annotation.Description; +import org.springframework.core.annotation.AliasFor; +import org.springframework.stereotype.Indexed; + +import java.lang.annotation.*; + +@Target({ ElementType.METHOD, ElementType.TYPE, ElementType.ANNOTATION_TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Indexed +@Description("") +public @interface Tool { + + @AliasFor(annotation = Description.class, attribute = "value") + String description(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/ApplicationContextBeanNameResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/ApplicationContextBeanNameResolver.java new file mode 100644 index 0000000000..2fbe15143c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/ApplicationContextBeanNameResolver.java @@ -0,0 +1,30 @@ +package org.springframework.ai.chat.client.resolver; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.util.Assert; + +public class ApplicationContextBeanNameResolver implements BeanNameResolver { + + private final GenericApplicationContext context; + + public ApplicationContextBeanNameResolver(GenericApplicationContext context) { + this.context = context; + } + + @Override + public String resolveName(Class beanType) { + String[] namesForType = context.getBeanNamesForType(beanType); + Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton"); + + /* + * The following snippet could be used in other places to resolve the description + * from a function registered as a bean BeanDefinition beanDefinition = + * context.getBeanDefinition(namesForType[0]); String description = + * beanDefinition.getDescription(); + */ + + return namesForType[0]; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/BeanNameResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/BeanNameResolver.java new file mode 100644 index 0000000000..5d107dd8fc --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/BeanNameResolver.java @@ -0,0 +1,7 @@ +package org.springframework.ai.chat.client.resolver; + +public interface BeanNameResolver { + + String resolveName(Class beanType); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/SimpleNameResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/SimpleNameResolver.java new file mode 100644 index 0000000000..c72a397278 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/resolver/SimpleNameResolver.java @@ -0,0 +1,10 @@ +package org.springframework.ai.chat.client.resolver; + +public class SimpleNameResolver implements BeanNameResolver { + + @Override + public String resolveName(Class beanType) { + return beanType.getSimpleName().substring(0, 1).toLowerCase() + beanType.getSimpleName().substring(1); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/ToolPlayground.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/ToolPlayground.java new file mode 100644 index 0000000000..9c7168e6e4 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/ToolPlayground.java @@ -0,0 +1,71 @@ +package org.springframework.ai.chat.client.tool; + +import org.springframework.ai.chat.client.ChatClient; + +import java.util.function.Function; + +public class ToolPlayground { + + record Request(String name) { + } + + record Response(String title) { + } + + static class ToolBean implements Function { + + @Override + public Response apply(Request request) { + return null; + } + } + + /* + * TODO: It appears that we could make this work with the existing implementation by + * changing the tool methods to return a FunctionCallWrapper. + * I think it might be worth renaming the FunctionCallbackWrapper + */ + public void playground(ChatClient client) { + + client.prompt() + .function("Name", "description", (Request request) -> new Response("")); + + /* + * This seems like a pretty good interface for using the name based API + */ + client.prompt() + .function(Tools.getByName("somefunction")); + + /* + * This seems like a pretty good interface for using the bean based API + */ + client.prompt() + .function(Tools.getByBean(ToolBean.class)); + + /* + * To get proper type inference we need to add the generics here. If we do not add + * this then we do not have type information on the input or the output To me this + * looks kinda ugly and makes the caller write less elegant code + */ + client.prompt() + .function( + Tools.getByLambda( + "somefunction", + "description", + request -> new Response("") + )); + /* + * To a limited extent it is possible to address this ugly syntax by defining an + * explicit type on the input type. And it is fair to notice that this issue also + * exists with the current implementation + */ + client.prompt() + .function( + Tools.getByLambda( + "somefunction", + "description", + (Request reqest) -> new Response("") + )); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/Tools.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/Tools.java new file mode 100644 index 0000000000..4ec3c5b4ff --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/Tools.java @@ -0,0 +1,56 @@ +package org.springframework.ai.chat.client.tool; + +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.util.Assert; + +import java.util.function.Function; + +/** + * TODO: In the current implementation the application context is required to define a tool by name and by type + * At runtime this should not be an issue, however accessing a non-static attribute from static methods has obvious issues. + * + * Proposal: + * I think it would make sense to redesign how functions / tools are handled in general. Functions created via the + * FunctionCallbackWrappers do a lot of processing when they are defined in the ChatClient. However, the functions created + * via the bean name are only handled much later (if my reading of the codebase is correct this happens during execution of the call to the LLM). + * I believe it would be better to handle all functions / tools in a more uniform way, so either process them as they are defined + * or when the LLM call is being executed. This would not only make the code more consistent but also make it easier to extend the + * mechanisms by which functions / tools can be defined. + */ +public class Tools { + + private static GenericApplicationContext context; + + public Tools(GenericApplicationContext context) { + Tools.context = context; + } + + + public static FunctionCallbackWrapper getByName(String name) { + // TODO we need to get to the application context. + // Get the bean by name and then create the tool itself... + String description = context.getBeanDefinition(name).getDescription(); + return FunctionCallbackWrapper.builder(Function.identity()) + .withName(name) + .withDescription(description) + .build(); + } + + public static FunctionCallbackWrapper getByBean(Class beanType) { + String[] namesForType = context.getBeanNamesForType(beanType); + Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton"); + String name = namesForType[0]; + String description = context.getBeanDefinition(name).getDescription(); + return FunctionCallbackWrapper.builder(Function.identity()) + .withName(name) + .withDescription(description) + .build(); + } + + public static FunctionCallbackWrapper getByLambda(String name, String description, + Function func) { + return FunctionCallbackWrapper.builder(func).withDescription("description").withName(name).build(); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java index ac39626b7b..841233aaa3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.chat.client; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.resolver.ApplicationContextBeanNameResolver; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.client.ChatClientCustomizer; import org.springframework.beans.factory.ObjectProvider; @@ -28,6 +29,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Scope; +import org.springframework.context.support.GenericApplicationContext; /** * {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}. @@ -60,8 +62,9 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider