Skip to content

Commit

Permalink
PoC: Create a factory class for defining tools.
Browse files Browse the repository at this point in the history
Created a new factory class that centralises the creation of Tools. It provides and API for creating Tools by
- name
- lambda
- bean type
  • Loading branch information
pgerhard committed Jun 5, 2024
1 parent 87bb4cb commit cb5a2a5
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ public <T extends ChatOptions> ChatClientRequest options(T options) {
return this;
}

public <I, O> ChatClientRequest function(FunctionCallbackWrapper<I, O> functionWrapper) {
this.functionCallbacks.add(functionWrapper);
return this;
}

public <I, O> ChatClientRequest function(String name, String description,
java.util.function.Function<I, O> function) {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.springframework.ai.chat.client.tool;

import org.springframework.ai.chat.client.ChatClient;

import java.util.function.Function;

public class ToolPlayground {

record Request(String name) {
}

record Response(String title) {
}

static class ToolBean implements Function<Request, Response> {

@Override
public Response apply(Request request) {
return null;
}
}

/*
* TODO: It appears that we could make this work with the existing implementation by
* changing the tool methods to return a FunctionCallWrapper.
* I think it might be worth renaming the FunctionCallbackWrapper
*/
public void playground(ChatClient client) {

client.prompt()
.function("Name", "description", (Request request) -> new Response(""));

/*
* This seems like a pretty good interface for using the name based API
*/
client.prompt()
.function(Tools.getByName("somefunction"));

/*
* This seems like a pretty good interface for using the bean based API
*/
client.prompt()
.function(Tools.getByBean(ToolBean.class));

/*
* To get proper type inference we need to add the generics here. If we do not add
* this then we do not have type information on the input or the output To me this
* looks kinda ugly and makes the caller write less elegant code
*/
client.prompt()
.function(
Tools.<Request, Response>getByLambda(
"somefunction",
"description",
request -> new Response("")
));
/*
* To a limited extent it is possible to address this ugly syntax by defining an
* explicit type on the input type. And it is fair to notice that this issue also
* exists with the current implementation
*/
client.prompt()
.function(
Tools.getByLambda(
"somefunction",
"description",
(Request reqest) -> new Response("")
));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.springframework.ai.chat.client.tool;

import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.Assert;

import java.util.function.Function;

/**
* TODO: In the current implementation the application context is required to define a tool by name and by type
* At runtime this should not be an issue, however accessing a non-static attribute from static methods has obvious issues.
*
* Proposal:
* I think it would make sense to redesign how functions / tools are handled in general. Functions created via the
* FunctionCallbackWrappers do a lot of processing when they are defined in the ChatClient. However, the functions created
* via the bean name are only handled much later (if my reading of the codebase is correct this happens during execution of the call to the LLM).
* I believe it would be better to handle all functions / tools in a more uniform way, so either process them as they are defined
* or when the LLM call is being executed. This would not only make the code more consistent but also make it easier to extend the
* mechanisms by which functions / tools can be defined.
*/
public class Tools {

private static GenericApplicationContext context;

public Tools(GenericApplicationContext context) {
Tools.context = context;
}


public static <Req, Res> FunctionCallbackWrapper<Object, Object> getByName(String name) {
// TODO we need to get to the application context.
// Get the bean by name and then create the tool itself...
String description = context.getBeanDefinition(name).getDescription();
return FunctionCallbackWrapper.builder(Function.identity())
.withName(name)
.withDescription(description)
.build();
}

public static <Req, Res, T> FunctionCallbackWrapper<Object, Object> getByBean(Class<T> beanType) {
String[] namesForType = context.getBeanNamesForType(beanType);
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton");
String name = namesForType[0];
String description = context.getBeanDefinition(name).getDescription();
return FunctionCallbackWrapper.builder(Function.identity())
.withName(name)
.withDescription(description)
.build();
}

public static <Req, Res> FunctionCallbackWrapper<Req, Res> getByLambda(String name, String description,
Function<Req, Res> func) {
return FunctionCallbackWrapper.builder(func).withDescription("description").withName(name).build();
}

}

0 comments on commit cb5a2a5

Please sign in to comment.