Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add possibility to register functions in a type-safe manner #787

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;

import org.springframework.ai.chat.client.resolver.BeanNameResolver;
import org.springframework.ai.chat.client.resolver.SimpleNameResolver;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.Media;
Expand Down Expand Up @@ -64,21 +69,27 @@
*/
public interface ChatClient {

static ChatClient create(ChatModel chatModel) {
return builder(chatModel).build();
static ChatClient create(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return builder(chatModel, beanNameResolver).build();
}

// QUESTION: Should this constructor be removed in favour of the one requiring the
// beanNameResolver?
static Builder builder(ChatModel chatModel) {
return new Builder(chatModel);
return new Builder(chatModel, new SimpleNameResolver());
}

static Builder builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return new Builder(chatModel, beanNameResolver);
}

ChatClientRequest prompt();

ChatClientPromptRequest prompt(Prompt prompt);

/**
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
* settings are replicated from the default {@link ChatClientRequest} of this client.
* Return a {@link Builder} to create a new {@link ChatClient} whose settings are
* replicated from the default {@link ChatClientRequest} of this client.
*/
Builder mutate();

Expand Down Expand Up @@ -263,16 +274,20 @@ class ChatClientRequest {

private final Map<String, Object> advisorParams = new HashMap<>();

private final BeanNameResolver beanNameResolver;

/* copy constructor */
ChatClientRequest(ChatClientRequest ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.beanNameResolver);
}

public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams,
BeanNameResolver beanNameResolver) {

this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
Expand All @@ -288,14 +303,15 @@ public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Objec
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.beanNameResolver = beanNameResolver;
}

/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
public Builder mutate() {
Builder builder = ChatClient.builder(chatModel)
Builder builder = ChatClient.builder(chatModel, beanNameResolver)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
Expand Down Expand Up @@ -349,6 +365,11 @@ public <T extends ChatOptions> ChatClientRequest options(T options) {
return this;
}

public <I, O> ChatClientRequest function(FunctionCallbackWrapper<I, O> functionWrapper) {
this.functionCallbacks.add(functionWrapper);
return this;
}

public <I, O> ChatClientRequest function(String name, String description,
java.util.function.Function<I, O> function) {

Expand All @@ -365,6 +386,11 @@ public <I, O> ChatClientRequest function(String name, String description,
return this;
}

public <T> ChatClientRequest functionBean(Class<T> functionBeanType) {
this.functionNames.add(beanNameResolver.resolveName(functionBeanType));
return this;
}

public ChatClientRequest functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
this.functionNames.addAll(List.of(functionBeanNames));
Expand Down Expand Up @@ -528,7 +554,7 @@ private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest,
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
adviseRequest.advisorParams(), inputRequest.beanNameResolver);
}

return advisedRequest;
Expand Down Expand Up @@ -735,11 +761,11 @@ class Builder {

private final ChatModel chatModel;

Builder(ChatModel chatModel) {
Builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of());
List.of(), List.of(), null, List.of(), Map.of(), beanNameResolver);
}

public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.springframework.ai.chat.client;

import org.springframework.context.annotation.Description;
import org.springframework.core.annotation.AliasFor;
import org.springframework.stereotype.Indexed;

import java.lang.annotation.*;

@Target({ ElementType.METHOD, ElementType.TYPE, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Indexed
@Description("")
public @interface Tool {

@AliasFor(annotation = Description.class, attribute = "value")
String description();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springframework.ai.chat.client.resolver;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.Assert;

public class ApplicationContextBeanNameResolver implements BeanNameResolver {

private final GenericApplicationContext context;

public ApplicationContextBeanNameResolver(GenericApplicationContext context) {
this.context = context;
}

@Override
public <T> String resolveName(Class<T> beanType) {
String[] namesForType = context.getBeanNamesForType(beanType);
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton");

/*
* The following snippet could be used in other places to resolve the description
* from a function registered as a bean BeanDefinition beanDefinition =
* context.getBeanDefinition(namesForType[0]); String description =
* beanDefinition.getDescription();
*/

return namesForType[0];
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.springframework.ai.chat.client.resolver;

public interface BeanNameResolver {

<T> String resolveName(Class<T> beanType);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.springframework.ai.chat.client.resolver;

public class SimpleNameResolver implements BeanNameResolver {

@Override
public <T> String resolveName(Class<T> beanType) {
return beanType.getSimpleName().substring(0, 1).toLowerCase() + beanType.getSimpleName().substring(1);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.springframework.ai.chat.client.tool;

import org.springframework.ai.chat.client.ChatClient;

import java.util.function.Function;

public class ToolPlayground {

record Request(String name) {
}

record Response(String title) {
}

static class ToolBean implements Function<Request, Response> {

@Override
public Response apply(Request request) {
return null;
}
}

/*
* TODO: It appears that we could make this work with the existing implementation by
* changing the tool methods to return a FunctionCallWrapper.
* I think it might be worth renaming the FunctionCallbackWrapper
*/
public void playground(ChatClient client) {

client.prompt()
.function("Name", "description", (Request request) -> new Response(""));

/*
* This seems like a pretty good interface for using the name based API
*/
client.prompt()
.function(Tools.getByName("somefunction"));

/*
* This seems like a pretty good interface for using the bean based API
*/
client.prompt()
.function(Tools.getByBean(ToolBean.class));

/*
* To get proper type inference we need to add the generics here. If we do not add
* this then we do not have type information on the input or the output To me this
* looks kinda ugly and makes the caller write less elegant code
*/
client.prompt()
.function(
Tools.<Request, Response>getByLambda(
"somefunction",
"description",
request -> new Response("")
));
/*
* To a limited extent it is possible to address this ugly syntax by defining an
* explicit type on the input type. And it is fair to notice that this issue also
* exists with the current implementation
*/
client.prompt()
.function(
Tools.getByLambda(
"somefunction",
"description",
(Request reqest) -> new Response("")
));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.springframework.ai.chat.client.tool;

import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.Assert;

import java.util.function.Function;

/**
* TODO: In the current implementation the application context is required to define a tool by name and by type
* At runtime this should not be an issue, however accessing a non-static attribute from static methods has obvious issues.
*
* Proposal:
* I think it would make sense to redesign how functions / tools are handled in general. Functions created via the
* FunctionCallbackWrappers do a lot of processing when they are defined in the ChatClient. However, the functions created
* via the bean name are only handled much later (if my reading of the codebase is correct this happens during execution of the call to the LLM).
* I believe it would be better to handle all functions / tools in a more uniform way, so either process them as they are defined
* or when the LLM call is being executed. This would not only make the code more consistent but also make it easier to extend the
* mechanisms by which functions / tools can be defined.
*/
public class Tools {

private static GenericApplicationContext context;

public Tools(GenericApplicationContext context) {
Tools.context = context;
}


public static <Req, Res> FunctionCallbackWrapper<Object, Object> getByName(String name) {
// TODO we need to get to the application context.
// Get the bean by name and then create the tool itself...
String description = context.getBeanDefinition(name).getDescription();
return FunctionCallbackWrapper.builder(Function.identity())
.withName(name)
.withDescription(description)
.build();
}

public static <Req, Res, T> FunctionCallbackWrapper<Object, Object> getByBean(Class<T> beanType) {
String[] namesForType = context.getBeanNamesForType(beanType);
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton");
String name = namesForType[0];
String description = context.getBeanDefinition(name).getDescription();
return FunctionCallbackWrapper.builder(Function.identity())
.withName(name)
.withDescription(description)
.build();
}

public static <Req, Res> FunctionCallbackWrapper<Req, Res> getByLambda(String name, String description,
Function<Req, Res> func) {
return FunctionCallbackWrapper.builder(func).withDescription("description").withName(name).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.chat.client;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.resolver.ApplicationContextBeanNameResolver;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -28,6 +29,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import org.springframework.context.support.GenericApplicationContext;

/**
* {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}.
Expand Down Expand Up @@ -60,8 +62,9 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClien
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel) {
ChatClient.Builder builder = ChatClient.builder(chatModel);
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel,
GenericApplicationContext context) {
ChatClient.Builder builder = ChatClient.builder(chatModel, new ApplicationContextBeanNameResolver(context));
return chatClientBuilderConfigurer.configure(builder);
}

Expand Down