-
Notifications
You must be signed in to change notification settings - Fork 737
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PoC: Create a factory class for defining tools.
Created a new factory class that centralises the creation of Tools. It provides and API for creating Tools by - name - lambda - bean type
- Loading branch information
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/ToolPlayground.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package org.springframework.ai.chat.client.tool; | ||
|
||
import org.springframework.ai.chat.client.ChatClient; | ||
|
||
import java.util.function.Function; | ||
|
||
public class ToolPlayground { | ||
|
||
record Request(String name) { | ||
} | ||
|
||
record Response(String title) { | ||
} | ||
|
||
static class ToolBean implements Function<Request, Response> { | ||
|
||
@Override | ||
public Response apply(Request request) { | ||
return null; | ||
} | ||
} | ||
|
||
/* | ||
* TODO: It appears that we could make this work with the existing implementation by | ||
* changing the tool methods to return a FunctionCallWrapper. | ||
* I think it might be worth renaming the FunctionCallbackWrapper | ||
*/ | ||
public void playground(ChatClient client) { | ||
|
||
client.prompt() | ||
.function("Name", "description", (Request request) -> new Response("")); | ||
|
||
/* | ||
* This seems like a pretty good interface for using the name based API | ||
*/ | ||
client.prompt() | ||
.function(Tools.getByName("somefunction")); | ||
|
||
/* | ||
* This seems like a pretty good interface for using the bean based API | ||
*/ | ||
client.prompt() | ||
.function(Tools.getByBean(ToolBean.class)); | ||
|
||
/* | ||
* To get proper type inference we need to add the generics here. If we do not add | ||
* this then we do not have type information on the input or the output To me this | ||
* looks kinda ugly and makes the caller write less elegant code | ||
*/ | ||
client.prompt() | ||
.function( | ||
Tools.<Request, Response>getByLambda( | ||
"somefunction", | ||
"description", | ||
request -> new Response("") | ||
)); | ||
/* | ||
* To a limited extent it is possible to address this ugly syntax by defining an | ||
* explicit type on the input type. And it is fair to notice that this issue also | ||
* exists with the current implementation | ||
*/ | ||
client.prompt() | ||
.function( | ||
Tools.getByLambda( | ||
"somefunction", | ||
"description", | ||
(Request reqest) -> new Response("") | ||
)); | ||
} | ||
|
||
} |
56 changes: 56 additions & 0 deletions
56
spring-ai-core/src/main/java/org/springframework/ai/chat/client/tool/Tools.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package org.springframework.ai.chat.client.tool; | ||
|
||
import org.springframework.ai.model.function.FunctionCallbackWrapper; | ||
import org.springframework.context.support.GenericApplicationContext; | ||
import org.springframework.util.Assert; | ||
|
||
import java.util.function.Function; | ||
|
||
/** | ||
* TODO: In the current implementation the application context is required to define a tool by name and by type | ||
* At runtime this should not be an issue, however accessing a non-static attribute from static methods has obvious issues. | ||
* | ||
* Proposal: | ||
* I think it would make sense to redesign how functions / tools are handled in general. Functions created via the | ||
* FunctionCallbackWrappers do a lot of processing when they are defined in the ChatClient. However, the functions created | ||
* via the bean name are only handled much later (if my reading of the codebase is correct this happens during execution of the call to the LLM). | ||
* I believe it would be better to handle all functions / tools in a more uniform way, so either process them as they are defined | ||
* or when the LLM call is being executed. This would not only make the code more consistent but also make it easier to extend the | ||
* mechanisms by which functions / tools can be defined. | ||
*/ | ||
public class Tools { | ||
|
||
private static GenericApplicationContext context; | ||
|
||
public Tools(GenericApplicationContext context) { | ||
Tools.context = context; | ||
} | ||
|
||
|
||
public static <Req, Res> FunctionCallbackWrapper<Object, Object> getByName(String name) { | ||
// TODO we need to get to the application context. | ||
// Get the bean by name and then create the tool itself... | ||
String description = context.getBeanDefinition(name).getDescription(); | ||
return FunctionCallbackWrapper.builder(Function.identity()) | ||
.withName(name) | ||
.withDescription(description) | ||
.build(); | ||
} | ||
|
||
public static <Req, Res, T> FunctionCallbackWrapper<Object, Object> getByBean(Class<T> beanType) { | ||
String[] namesForType = context.getBeanNamesForType(beanType); | ||
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton"); | ||
String name = namesForType[0]; | ||
String description = context.getBeanDefinition(name).getDescription(); | ||
return FunctionCallbackWrapper.builder(Function.identity()) | ||
.withName(name) | ||
.withDescription(description) | ||
.build(); | ||
} | ||
|
||
public static <Req, Res> FunctionCallbackWrapper<Req, Res> getByLambda(String name, String description, | ||
Function<Req, Res> func) { | ||
return FunctionCallbackWrapper.builder(func).withDescription("description").withName(name).build(); | ||
} | ||
|
||
} |