Add anthropic model support #86
Merged
Commits (10)
3740508  add anthropic model support, testing needed (Char15Xu)
eb1f549  fixed conflict error in libem/core/model/__init__.py (Char15Xu)
78a6423  Merge branch 'main' into feat-claude (Char15Xu)
963c3b1  merge changes (Char15Xu)
7d2d091  add claude key configure support in libem.cli and syntax revisions (Char15Xu)
24ff013  Merge branch 'main' into feat-claude (Char15Xu)
d2e5c9e  claude successfully work on examples/model/claude.py (Char15Xu)
53d75b6  Merge branch 'main' into feat-claude (Char15Xu)
9b4c0a5  delete print statement in cli/libem (Char15Xu)
82b30fb  Merge branch 'main' into feat-claude (Char15Xu)
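
The wrapper added by this PR reads its API key from the CLAUDE_API_KEY environment variable, falling back to the libem config file. A minimal sketch of supplying the key before running the example below; the key value is a placeholder:

import os

# Placeholder key; obtain a real one from your Anthropic account, or use
# the key-configure support added to libem.cli in commit 7d2d091.
os.environ["CLAUDE_API_KEY"] = "sk-ant-..."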

examples/model/claude.py
@@ -0,0 +1,34 @@
import libem

from libem.match.prompt import rules


def positive():
    e1 = "Dyson Hot+Cool AM09 Jet Focus heater and fan, White/Silver"
    e2 = "Dyson AM09 Hot + Cool Jet Focus Fan Heater - W/S"

    is_match = libem.match(e1, e2)

    print("Entity 1:", e1)
    print("Entity 2:", e2)
    print("Match:", is_match['answer'])


def negative():
    e1 = "Dyson Hot+Cool AM09 Jet Focus heater and fan, White/Silver"
    e2 = "Dyson AM09 Hot + Cool Jet Focus Fan Heater - Black japan"

    rules.add("Color differentiates entities.")
    is_match = libem.match(e1, e2)

    print("Entity 1:", e1)
    print("Entity 2:", e2)
    print("Match:", is_match['answer'])


def main():
    libem.calibrate({
        "libem.match.parameter.model": "claude-3-5-sonnet-20240620",
    }, verbose=True)
    positive()
    negative()


if __name__ == '__main__':
    main()
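
The example drives libem.match end to end: libem.calibrate switches the matching model to claude-3-5-sonnet-20240620, and rules.add injects a domain rule ("Color differentiates entities.") so the color-mismatched pair in negative() should not match.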

libem/core/model/claude.py (path inferred from the commit history; new Claude model wrapper)
@@ -0,0 +1,247 @@
import os
import httpx
import importlib
import inspect
from anthropic import (
    AsyncAnthropic, APITimeoutError, NOT_GIVEN
)

import libem
from libem.core import exec

# Prefer the environment variable; fall back to the key stored
# in the libem config file.
os.environ.setdefault(
    "CLAUDE_API_KEY",
    libem.LIBEM_CONFIG.get("CLAUDE_API_KEY", "")
)

_client = None


def get_client():
    global _client

    if not os.environ.get("CLAUDE_API_KEY"):
        raise EnvironmentError("CLAUDE_API_KEY is not set.")

    if not _client:
        # Lazily construct a single shared async client with
        # generous connection pooling for concurrent matching.
        _client = AsyncAnthropic(
            api_key=os.environ["CLAUDE_API_KEY"],
            http_client=httpx.AsyncClient(
                limits=httpx.Limits(
                    max_connections=1000,
                    max_keepalive_connections=100
                )
            )
        )
    return _client


def call(*args, **kwargs) -> dict:
    return exec.run_async_task(
        async_call(*args, **kwargs)
    )


# Model call with multiple rounds of tool use
async def async_call(
        prompt: str | list | dict,
        tools: list[str] = None,
        context: list = None,
        model: str = "claude-3-5-sonnet-20240620",
        temperature: float = 0.0,
        seed: int = None,  # kept for interface parity; the Anthropic API has no seed parameter
        max_model_call: int = 3,
) -> dict:
    client = get_client()

    context = context or []

    # format the prompt to messages
    system_message = None
    user_messages = []

    match prompt:
        case list():
            for msg in prompt:
                if msg["role"] == "system":
                    system_message = msg["content"]
                else:
                    user_messages.append(msg)
        case dict():
            for role, content in prompt.items():
                if role == "system":
                    system_message = content
                else:
                    user_messages.append({"role": role, "content": content})
        case str():
            user_messages = [{"role": "user", "content": prompt}]
        case _:
            raise ValueError(f"Invalid prompt type: {type(prompt)}")
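
    # Illustrative examples of the accepted prompt shapes (the texts
    # here are hypothetical, not taken from the library):
    #   "are A and B the same product?"
    #   {"system": "answer yes or no", "user": "are A and B the same?"}
    #   [{"role": "system", "content": "..."},
    #    {"role": "user", "content": "..."}]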

    # Merge context: fold system entries into the system message,
    # prepend everything else to the conversation.
    for msg in context:
        if msg["role"] == "system":
            if system_message is None:
                system_message = msg["content"]
            else:
                system_message += "\n" + msg["content"]
        else:
            user_messages.insert(0, msg)

    messages = user_messages

    # trace variables
    num_model_calls = 0
    num_input_tokens, num_output_tokens = 0, 0
    tool_usages, tool_outputs = [], []

    """Start call"""

    if not tools:
        try:
            response = await client.messages.create(
                messages=messages,
                system=system_message or NOT_GIVEN,
                model=model,
                temperature=temperature,
                max_tokens=1000,
            )
        except APITimeoutError as e:  # catch timeout error
            raise libem.ModelTimedoutException(e)

        response_message = response.content[0].text
        num_model_calls += 1
        num_input_tokens += response.usage.input_tokens
        num_output_tokens += response.usage.output_tokens
    else:
        # Load the tool modules
        tools = [importlib.import_module(tool) for tool in tools]

        # Get the functions from the tools and
        # prefer async functions if available
        available_functions = {
            tool.name: getattr(tool, 'async_func', tool.func)
            for tool in tools
        }

        # Get the schema from the tools
        tools = [tool.schema for tool in tools]

        # Call model
        try:
            response = await client.messages.create(
                messages=messages,
                system=system_message or NOT_GIVEN,
                tools=tools,
                tool_choice={"type": "auto"},
                model=model,
                temperature=temperature,
                max_tokens=1000,
            )
        except APITimeoutError as e:  # catch timeout error
            raise libem.ModelTimedoutException(e)

        # Anthropic returns a list of content blocks: text blocks carry
        # the reply, tool_use blocks carry requested tool invocations.
        response_message = "".join(
            block.text for block in response.content
            if block.type == "text"
        )
        tool_uses = [
            block for block in response.content
            if block.type == "tool_use"
        ]

        num_model_calls += 1
        num_input_tokens += response.usage.input_tokens
        num_output_tokens += response.usage.output_tokens

        # Call tools
        while tool_uses:
            # Echo the assistant turn (including its tool_use blocks)
            # back into the conversation before sending tool results.
            messages.append({
                "role": "assistant",
                "content": response.content,
            })

            for tool_use in tool_uses:
                function_name = tool_use.name
                function_to_call = available_functions[function_name]
                function_args = tool_use.input  # already parsed into a dict

                libem.debug(f"[{function_name}] {function_args}")

                if inspect.iscoroutinefunction(function_to_call):
                    function_response = await function_to_call(**function_args)
                else:
                    function_response = function_to_call(**function_args)

                # Tool results go back as a user message containing a
                # tool_result block keyed by the originating tool_use id.
                messages.append(
                    {
                        "role": "user",
                        "content": [{
                            "type": "tool_result",
                            "tool_use_id": tool_use.id,
                            "content": str(function_response),
                        }],
                    }
                )

                tool_usages.append({
                    "id": tool_use.id,
                    "name": function_name,
                    "arguments": function_args,
                    "response": function_response,
                })

                tool_outputs.append({
                    function_name: function_response,
                })

            tool_uses = []

            if num_model_calls < max_model_call:
                # Call the model again with the tool outcomes
                try:
                    response = await client.messages.create(
                        messages=messages,
                        system=system_message or NOT_GIVEN,
                        tools=tools,
                        tool_choice={"type": "auto"},
                        model=model,
                        temperature=temperature,
                        max_tokens=1000,
                    )
                except APITimeoutError as e:  # catch timeout error
                    raise libem.ModelTimedoutException(e)

                response_message = "".join(
                    block.text for block in response.content
                    if block.type == "text"
                )
                tool_uses = [
                    block for block in response.content
                    if block.type == "tool_use"
                ]

                num_model_calls += 1
                num_input_tokens += response.usage.input_tokens
                num_output_tokens += response.usage.output_tokens

            if num_model_calls == max_model_call:
                libem.debug(f"[model] max call reached: "
                            f"{messages}\n{response_message}")

    """End call"""

    messages.append({
        "role": "assistant",
        "content": response_message,
    })

    libem.trace.add({
        "model": {
            "messages": messages,
            "tool_usages": tool_usages,
            "num_model_calls": num_model_calls,
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens,
        }
    })

    return {
        "output": response_message,
        "tool_outputs": tool_outputs,
        "messages": messages,
        "stats": {
            "num_model_calls": num_model_calls,
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens,
        }
    }


def reset():
    global _client
    _client = None
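
For reference, a minimal sketch of driving this wrapper directly; the import path and prompt text are illustrative assumptions, not taken from the PR:

# Hypothetical usage; adjust the import to wherever this module lives
# (libem/core/model/ per the commit history).
from libem.core.model import claude

result = claude.call(
    prompt="Are 'Dyson AM09' and 'Dyson Hot+Cool AM09' the same product?",
    model="claude-3-5-sonnet-20240620",
    temperature=0.0,
)
print(result["output"])
print(result["stats"])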