Skip to content

Commit

Permalink
Merge pull request #2 from r7b7/add-tool-support
Browse files Browse the repository at this point in the history
Add tool support
  • Loading branch information
r7b7 authored Dec 11, 2024
2 parents 42ca6e5 + 4c48bda commit c79fc63
Show file tree
Hide file tree
Showing 24 changed files with 147 additions and 35 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ A library designed for quick prototyping with LLMs, and fully compatible with pr
3. Pull a PR

## Features
1. Support for following LLM providers:
1. Support for following LLM providers: OpenAI, Anthropic, Groq, Ollama
2. PromptBuilder to build complex prompts
3. Support to customize default client implementations (Flexible approach for integrating with frameworks like SpringBoot)

1. OpenAI, Anthropic, Groq, Ollama
2. PromptBuilder to build complex prompts
3. Support default to custom client implementations without refactoring entire codebases
### Features in Pipeline
1. Image Support
2. Stream Response

## Installation
1. Add jitpack repository in pom file
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/com/r7b7/client/DefaultAnthropicClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Properties;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.r7b7.client.model.AnthroToolResponse;
import com.r7b7.client.model.AnthropicResponse;
import com.r7b7.client.model.Message;
import com.r7b7.config.PropertyConfig;
Expand Down Expand Up @@ -67,7 +68,14 @@ private CompletionResponse extractResponseText(String responseBody) {
try {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, AnthropicResponse.class);
msgs = response.content().stream().map(content -> new Message(content.type(), content.text())).toList();
final String role = response.role();
msgs = response.content().stream().map(content -> {
if(content.type().equalsIgnoreCase("tool_use")){
return new Message(role, content.text(), List.of(new AnthroToolResponse(content.type(), content.id(), content.name(), content.input())));
} else {
return new Message(role, content.text(), null);
}
}).toList();
metadata = Map.of(
"id", response.id(),
"model", response.model(),
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/r7b7/client/DefaultGroqClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ private CompletionResponse extractResponseText(String responseBody) {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OpenAIResponse.class);
msgs = response.choices().stream()
.map(choice -> new Message(choice.message().role(), choice.message().content())).toList();
.map(choice -> new Message(choice.message().role(), choice.message().content(),
choice.message().toolCalls()))
.toList();
metadata = Map.of(
"id", response.id(),
"model", response.model(),
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/r7b7/client/DefaultOpenAIClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ private CompletionResponse extractResponseText(String responseBody) {
ObjectMapper mapper = new ObjectMapper();
response = mapper.readValue(responseBody, OpenAIResponse.class);
msgs = response.choices().stream()
.map(choice -> new Message(choice.message().role(), choice.message().content())).toList();
.map(choice -> new Message(choice.message().role(), choice.message().content(),
choice.message().toolCalls()))
.toList();
metadata = Map.of(
"id", response.id(),
"model", response.model(),
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/com/r7b7/client/model/AnthroToolResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.r7b7.client.model;

public record AnthroToolResponse(String type, String id, String name, Object input){

}
2 changes: 1 addition & 1 deletion src/main/java/com/r7b7/client/model/AnthropicResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(ignoreUnknown = true)
public record AnthropicResponse(String id, String model, List<Content> content,
public record AnthropicResponse(String id, String model, String role, List<Content> content,
@JsonProperty("usage") AnthroUsage usage) {

}
2 changes: 1 addition & 1 deletion src/main/java/com/r7b7/client/model/Content.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

@JsonIgnoreProperties(ignoreUnknown = true)
public record Content(String type, String text) {}
public record Content(String type, String text, String id, String name, Object input) {}
5 changes: 4 additions & 1 deletion src/main/java/com/r7b7/client/model/Message.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.r7b7.client.model;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(ignoreUnknown = true)
public record Message(String role, String content) {}
public record Message(String role, String content, @JsonProperty("tool_calls")List<?> toolCalls) {}
6 changes: 6 additions & 0 deletions src/main/java/com/r7b7/entity/AnthropicTool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.r7b7.entity;

import java.util.Map;

public record AnthropicTool(String name, String description, Map<String, Object> input_schema) {
}
5 changes: 5 additions & 0 deletions src/main/java/com/r7b7/entity/Tool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.r7b7.entity;

public record Tool(String type, ToolFunction function) {

}
6 changes: 6 additions & 0 deletions src/main/java/com/r7b7/entity/ToolFunction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.r7b7.entity;

import java.util.Map;

public record ToolFunction(String name, String description, Map<String, Object> parameters) {
}
22 changes: 19 additions & 3 deletions src/main/java/com/r7b7/model/BaseLLMRequest.java
Original file line number Diff line number Diff line change
@@ -1,26 +1,42 @@
package com.r7b7.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import com.r7b7.entity.Message;
import com.r7b7.entity.ToolFunction;

public class BaseLLMRequest implements ILLMRequest {
private final List<Message> messages;
private final Map<String, Object> parameters;
private List<ToolFunction> functions;
private Object toolChoice;

public BaseLLMRequest(List<Message> messages, Map<String, Object> parameters) {
public BaseLLMRequest(List<Message> messages, Map<String, Object> parameters, List<ToolFunction> functions, Object toolChoice) {
this.messages = messages;
this.parameters = parameters;
this.functions = functions;
this.toolChoice = toolChoice;
}

@Override
public List<Message> getPrompt() {
return messages;
return this.messages;
}

@Override
public Map<String, Object> getParameters() {
return parameters;
return this.parameters;
}

@Override
public List<ToolFunction> getFunctions() {
return this.functions;
}

@Override
public Object getToolChoice() {
return this.toolChoice;
}
}
3 changes: 3 additions & 0 deletions src/main/java/com/r7b7/model/ILLMRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import java.util.Map;

import com.r7b7.entity.Message;
import com.r7b7.entity.ToolFunction;

public interface ILLMRequest {
List<Message> getPrompt();
Map<String, Object> getParameters();
List<ToolFunction> getFunctions();
Object getToolChoice();
}
11 changes: 11 additions & 0 deletions src/main/java/com/r7b7/service/AnthropicService.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import com.r7b7.client.IAnthropicClient;
import com.r7b7.client.factory.LLMClientFactory;
import com.r7b7.entity.AnthropicTool;
import com.r7b7.entity.CompletionRequest;
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
import com.r7b7.entity.Role;
import com.r7b7.entity.Tool;
import com.r7b7.model.ILLMRequest;
import com.r7b7.util.StringUtility;

Expand All @@ -34,6 +36,15 @@ public CompletionResponse generateResponse(ILLMRequest request) {
requestMap.put("system", systemMessage);
}
requestMap.put("messages", request.getPrompt());
if (null != request.getFunctions()) {
List<AnthropicTool> tool = request.getFunctions().stream().map(func -> new AnthropicTool(func.name(), func.description(), func.parameters())).toList();
requestMap.put("tools", tool);
}
if (null != request.getToolChoice() && request.getToolChoice() instanceof String) {
requestMap.put("tool_choice", Map.of("type", request.getToolChoice()));
} else {
requestMap.put("tool_choice", request.getToolChoice());
}
// set mandatory param if not set explicitly
requestMap.put("max_tokens", 1024);
if (null != request.getParameters()) {
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/com/r7b7/service/GroqService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
import com.r7b7.entity.Role;
import com.r7b7.entity.Tool;
import com.r7b7.model.ILLMRequest;

public class GroqService implements ILLMService {
Expand All @@ -29,6 +30,13 @@ public CompletionResponse generateResponse(ILLMRequest request) {
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", this.model);
requestMap.put("messages", request.getPrompt());
if (null != request.getFunctions()) {
List<Tool> tool = request.getFunctions().stream().map(func -> new Tool("function", func)).toList();
requestMap.put("tools", tool);
}
if (null != request.getToolChoice()) {
requestMap.put("tool_choice", request.getToolChoice());
}
if (null != request.getParameters()) {
for (Map.Entry<String, Object> entry : request.getParameters().entrySet()) {
requestMap.put(entry.getKey(), entry.getValue());
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/com/r7b7/service/OllamaService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
import com.r7b7.entity.Role;
import com.r7b7.entity.Tool;
import com.r7b7.model.ILLMRequest;

public class OllamaService implements ILLMService {
Expand All @@ -27,6 +28,13 @@ public CompletionResponse generateResponse(ILLMRequest request) {
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", this.model);
requestMap.put("messages", request.getPrompt());
if (null != request.getFunctions()) {
List<Tool> tool = request.getFunctions().stream().map(func -> new Tool("function", func)).toList();
requestMap.put("tools", tool);
}
if (null != request.getToolChoice()) {
requestMap.put("tool_choice", request.getToolChoice());
}

Map<String, Object> optionsMap = new HashMap<>();
if (null != request.getParameters()) {
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/com/r7b7/service/OpenAIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
import com.r7b7.entity.Role;
import com.r7b7.entity.Tool;
import com.r7b7.model.ILLMRequest;

public class OpenAIService implements ILLMService {
Expand All @@ -29,6 +30,13 @@ public CompletionResponse generateResponse(ILLMRequest request) {
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", this.model);
requestMap.put("messages", request.getPrompt());
if (null != request.getFunctions()) {
List<Tool> tool = request.getFunctions().stream().map(func -> new Tool("function", func)).toList();
requestMap.put("tools", tool);
}
if (null != request.getToolChoice()) {
requestMap.put("tool_choice", request.getToolChoice());
}
if (null != request.getParameters()) {
for (Map.Entry<String, Object> entry : request.getParameters().entrySet()) {
requestMap.put(entry.getKey(), entry.getValue());
Expand Down
23 changes: 18 additions & 5 deletions src/main/java/com/r7b7/service/PromptBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,43 @@
import java.util.Map;

import com.r7b7.entity.Message;
import com.r7b7.entity.ToolFunction;

public class PromptBuilder {
private List<Message> messages = new ArrayList<>();
private Map<String, Object> params = new HashMap<>();
private List<ToolFunction> functions = new ArrayList<>();
private Object toolChoice = "none";

public PromptBuilder addMessage(Message message) {
messages.add(message);
this.messages.add(message);
return this;
}

public PromptBuilder addParam(String key, Object value) {
params.put(key, value);
this.params.put(key, value);
return this;
}

public PromptBuilder addTool(ToolFunction function) {
this.functions.add(function);
return this;
}

public PromptBuilder addToolChoice(String choice) {
this.toolChoice = choice;
return this;
}

public List<Message> getMessages() {
return messages;
return this.messages;
}

public Map<String, Object> getParams() {
return params;
return this.params;
}

public PromptEngine build(ILLMService service) {
return new PromptEngine(service, params, messages);
return new PromptEngine(service, this.params, this.messages, this.functions, this.toolChoice);
}
}
14 changes: 10 additions & 4 deletions src/main/java/com/r7b7/service/PromptEngine.java
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
package com.r7b7.service;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import com.r7b7.entity.CompletionResponse;
import com.r7b7.entity.Message;
import com.r7b7.entity.ToolFunction;
import com.r7b7.model.BaseLLMRequest;
import com.r7b7.model.ILLMRequest;

public class PromptEngine {
private final ILLMService llmService;
private final Map<String, Object> params;
private final List<Message> messages;
private List<ToolFunction> functions = new ArrayList<>();
private Object toolChoice = "none";

public PromptEngine(ILLMService llmService) {
this(llmService, null, null);
this(llmService, null, null, null, null);
}

public PromptEngine(ILLMService llmService, Map<String, Object> params, List<Message> messages) {
public PromptEngine(ILLMService llmService, Map<String, Object> params, List<Message> messages, List<ToolFunction> functions, Object toolChoice) {
this.llmService = llmService;
this.params = params;
this.messages = messages;
this.functions = functions;
this.toolChoice = toolChoice;
}

public CompletionResponse sendQuery() {
ILLMRequest request = new BaseLLMRequest(this.messages, this.params);
ILLMRequest request = new BaseLLMRequest(this.messages, this.params, this.functions, this.toolChoice);
CompletionResponse response = this.llmService.generateResponse(request);
return response;
}

public CompletableFuture<CompletionResponse> sendQueryAsync() {
ILLMRequest request = new BaseLLMRequest(this.messages, this.params);
ILLMRequest request = new BaseLLMRequest(this.messages, this.params, this.functions, this.toolChoice);
return this.llmService.generateResponseAsync(request);
}

Expand Down
6 changes: 3 additions & 3 deletions src/test/java/com/r7b7/service/AnthropicServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ private ILLMRequest createMockLLMRequest(String prompt) {
messages.add(new Message(Role.system, "You are a helpful assistant"));
messages.add(new Message(Role.assistant, "You are a helpful assistant"));
messages.add(new Message(Role.user, prompt));
ILLMRequest request = new BaseLLMRequest(messages, null);
ILLMRequest request = new BaseLLMRequest(messages, null, null, null);
return request;
}

private CompletionResponse createMockCompletionResponse(String content) {
List<com.r7b7.client.model.Message> messages = new ArrayList<>();
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content");
com.r7b7.client.model.Message msg = new com.r7b7.client.model.Message("user", "test content", null);
messages.add(msg);
Map<String, Object> metaData = new HashMap<>();
metaData.put("model", TEST_MODEL);
Expand All @@ -140,7 +140,7 @@ private ILLMRequest createMockLLMRequestWithParams(String prompt) {
"temperature", 0.7,
"max_token", 1000);

ILLMRequest request = new BaseLLMRequest(messages, params);
ILLMRequest request = new BaseLLMRequest(messages, params, null, null);
return request;
}
}
Loading

0 comments on commit c79fc63

Please sign in to comment.