diff --git a/Makefile b/Makefile
index 6314069..dbf84c8 100644
--- a/Makefile
+++ b/Makefile
@@ -153,3 +153,7 @@ duckdb:
 	python examples/apps/integration/duckdb_cluster.py
 mongodb:
 	python examples/apps/integration/mongodb_cluster.py
+
+.PHONY: claude
+claude:
+	python examples/model/claude.py
diff --git a/examples/model/claude.py b/examples/model/claude.py
new file mode 100644
index 0000000..940ea22
--- /dev/null
+++ b/examples/model/claude.py
@@ -0,0 +1,35 @@
+import libem
+
+from libem.match.prompt import rules
+
+def positive():
+    e1 = "Dyson Hot+Cool AM09 Jet Focus heater and fan, White/Silver"
+    e2 = "Dyson AM09 Hot + Cool Jet Focus Fan Heater - W/S"
+
+    is_match = libem.match(e1, e2)
+
+    print("Entity 1:", e1)
+    print("Entity 2:", e2)
+    print("Match:", is_match['answer'])
+
+def negative():
+    e1 = "Dyson Hot+Cool AM09 Jet Focus heater and fan, White/Silver"
+    e2 = "Dyson AM09 Hot + Cool Jet Focus Fan Heater - Black japan"
+
+    # the added rule makes the color difference decisive
+    rules.add("Color differentiates entities.")
+    is_match = libem.match(e1, e2)
+
+    print("Entity 1:", e1)
+    print("Entity 2:", e2)
+    print("Match:", is_match['answer'])
+
+def main():
+    libem.calibrate({
+        "libem.match.parameter.model": "claude-3-5-sonnet-20240620",
+    }, verbose=True)
+    positive()
+    negative()
+
+if __name__ == '__main__':
+    main()
diff --git a/libem/core/model/__init__.py b/libem/core/model/__init__.py
index 79f6d6c..51cac24 100644
--- a/libem/core/model/__init__.py
+++ b/libem/core/model/__init__.py
@@ -1,5 +1,5 @@
 from libem.core.model import (
-    openai, llama
+    openai, llama, claude
 )
 from libem.core import exec
 import libem
@@ -15,6 +15,8 @@ async def async_call(*args, **kwargs) -> dict:
         return llama.call(*args, **kwargs)
     elif kwargs.get("model", "") == "llama3.1":
         return llama.call(*args, **kwargs)
+    elif kwargs.get("model", "") == "claude-3-5-sonnet-20240620":
+        return await claude.async_call(*args, **kwargs)
     else:
         return await openai.async_call(*args, **kwargs)
diff --git a/libem/core/model/claude.py b/libem/core/model/claude.py
new file mode 100644
index 0000000..e3d7b4f
--- /dev/null
+++ b/libem/core/model/claude.py
@@ -0,0 +1,269 @@
+import os
+import httpx
+import importlib
+import inspect
+from anthropic import (
+    AsyncAnthropic, APITimeoutError
+)
+
+import libem
+from libem.core import exec
+
+os.environ.setdefault(
+    "CLAUDE_API_KEY",
+    libem.LIBEM_CONFIG.get("CLAUDE_API_KEY", "")
+)
+
+_client = None
+
+
+def get_client():
+    global _client
+
+    if not os.environ.get("CLAUDE_API_KEY"):
+        raise EnvironmentError("CLAUDE_API_KEY is not set.")
+
+    if not _client:
+        _client = AsyncAnthropic(
+            api_key=os.environ["CLAUDE_API_KEY"],
+            http_client=httpx.AsyncClient(
+                limits=httpx.Limits(
+                    max_connections=1000,
+                    max_keepalive_connections=100
+                )
+            )
+        )
+    return _client
+
+
+def call(*args, **kwargs) -> dict:
+    return exec.run_async_task(
+        async_call(*args, **kwargs)
+    )
+
+
+# Model call with multiple rounds of tool use
+async def async_call(
+        prompt: str | list | dict,
+        tools: list[str] = None,
+        context: list = None,
+        model: str = "claude-3-5-sonnet-20240620",
+        temperature: float = 0.0,
+        seed: int = None,  # unused: the Anthropic API does not support seeding
+        max_model_call: int = 3,
+) -> dict:
+    client = get_client()
+
+    context = context or []
+
+    # format the prompt to messages;
+    # Anthropic takes the system prompt separately
+    system_message = None
+    user_messages = []
+
+    match prompt:
+        case list():
+            for msg in prompt:
+                if msg["role"] == "system":
+                    system_message = msg["content"]
+                else:
+                    user_messages.append(msg)
+        case dict():
+            for role, content in prompt.items():
+                if role == "system":
+                    system_message = content
+                else:
+                    user_messages.append({"role": role, "content": content})
+        case str():
+            user_messages = [{"role": "user", "content": prompt}]
+        case _:
+            raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+    # handle context
+    for msg in context:
+        if msg["role"] == "system":
+            if system_message is None:
+                system_message = msg["content"]
+            else:
+                system_message += "\n" + msg["content"]
+        else:
+            user_messages.insert(0, msg)
+
+    messages = user_messages
+
+    # trace variables
+    num_model_calls = 0
+    num_input_tokens, num_output_tokens = 0, 0
+    tool_usages, tool_outputs = [], []
+
+    """Start call"""
+
+    if not tools:
+        try:
+            response = await client.messages.create(
+                messages=messages,
+                system=system_message,
+                model=model,
+                temperature=temperature,
+                max_tokens=1000,
+            )
+        except APITimeoutError as e:  # catch timeout error
+            raise libem.ModelTimedoutException(e)
+
+        response_message = response.content[0].text
+        num_model_calls += 1
+        num_input_tokens += response.usage.input_tokens
+        num_output_tokens += response.usage.output_tokens
+    else:
+        # load the tool modules
+        tools = [importlib.import_module(tool) for tool in tools]
+
+        # get the functions from the tools and
+        # prefer async functions if available
+        available_functions = {
+            tool.name: getattr(tool, 'async_func', tool.func)
+            for tool in tools
+        }
+
+        # get the schema from the tools
+        tools = [tool.schema for tool in tools]
+
+        # call model
+        try:
+            response = await client.messages.create(
+                messages=messages,
+                system=system_message,
+                tools=tools,
+                tool_choice={"type": "auto"},
+                model=model,
+                temperature=temperature,
+                max_tokens=1000,
+            )
+        except APITimeoutError as e:  # catch timeout error
+            raise libem.ModelTimedoutException(e)
+
+        # text and tool-use blocks arrive together in response.content
+        response_message = "".join(
+            block.text for block in response.content
+            if block.type == "text"
+        )
+        tool_uses = [
+            block for block in response.content
+            if block.type == "tool_use"
+        ]
+
+        num_model_calls += 1
+        num_input_tokens += response.usage.input_tokens
+        num_output_tokens += response.usage.output_tokens
+
+        # call tools
+        while tool_uses:
+            messages.append({
+                "role": "assistant",
+                "content": response.content,
+            })
+
+            tool_results = []
+
+            for tool_use in tool_uses:
+                function_name = tool_use.name
+                function_to_call = available_functions[function_name]
+                # tool_use.input is already parsed into a dict
+                function_args = tool_use.input
+
+                libem.debug(f"[{function_name}] {function_args}")
+
+                if inspect.iscoroutinefunction(function_to_call):
+                    function_response = await function_to_call(**function_args)
+                else:
+                    function_response = function_to_call(**function_args)
+
+                # tool results go back to the model in a single user turn
+                tool_results.append({
+                    "type": "tool_result",
+                    "tool_use_id": tool_use.id,
+                    "content": str(function_response),
+                })
+
+                tool_usages.append({
+                    "id": tool_use.id,
+                    "name": function_name,
+                    "arguments": function_args,
+                    "response": function_response,
+                })
+
+                tool_outputs.append({
+                    function_name: function_response,
+                })
+
+            messages.append({
+                "role": "user",
+                "content": tool_results,
+            })
+
+            tool_uses = []
+
+            if num_model_calls < max_model_call:
+                # call the model again with the tool outcomes
+                try:
+                    response = await client.messages.create(
+                        messages=messages,
+                        system=system_message,
+                        tools=tools,
+                        tool_choice={"type": "auto"},
+                        model=model,
+                        temperature=temperature,
+                        max_tokens=1000,
+                    )
+                except APITimeoutError as e:  # catch timeout error
+                    raise libem.ModelTimedoutException(e)
+
+                response_message = "".join(
+                    block.text for block in response.content
+                    if block.type == "text"
+                )
+                tool_uses = [
+                    block for block in response.content
+                    if block.type == "tool_use"
+                ]
+
+                num_model_calls += 1
+                num_input_tokens += response.usage.input_tokens
+                num_output_tokens += response.usage.output_tokens
+
+        if num_model_calls == max_model_call:
+            libem.debug(f"[model] max call reached: "
+                        f"{messages}\n{response_message}")
+
+    """End call"""
+
+    messages.append({
+        "role": "assistant",
+        "content": response_message,
+    })
+
+    libem.trace.add({
+        "model": {
+            "messages": messages,
+            "tool_usages": tool_usages,
+            "num_model_calls": num_model_calls,
+            "num_input_tokens": num_input_tokens,
+            "num_output_tokens": num_output_tokens,
+        }
+    })
+
+    return {
+        "output": response_message,
+        "tool_outputs": tool_outputs,
+        "messages": messages,
+        "stats": {
+            "num_model_calls": num_model_calls,
+            "num_input_tokens": num_input_tokens,
+            "num_output_tokens": num_output_tokens,
+        }
+    }
+
+
+def reset():
+    global _client
+    _client = None
diff --git a/libem/parameter.py b/libem/parameter.py
index 60b50c2..1535013 100644
--- a/libem/parameter.py
+++ b/libem/parameter.py
@@ -4,7 +4,7 @@ default="gpt-4o-2024-08-06",
     options=["gpt-4o","gpt-4o-mini",
              "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo",
-             "llama3", "llama3.1"]
+             "llama3", "llama3.1", "claude-3-5-sonnet-20240620"]
 )
 temperature = Parameter(
     default=0,
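
Reviewer note: a minimal end-to-end smoke test for the new backend (a sketch,
not part of the patch; it assumes CLAUDE_API_KEY is exported and the anthropic
package is installed, and it reuses the calibrate/match flow that
examples/model/claude.py introduces above):

    import libem

    # route libem.match through the new Claude backend
    libem.calibrate({
        "libem.match.parameter.model": "claude-3-5-sonnet-20240620",
    })

    # entity pair taken from the example above; should report a match
    e1 = "Dyson Hot+Cool AM09 Jet Focus heater and fan, White/Silver"
    e2 = "Dyson AM09 Hot + Cool Jet Focus Fan Heater - W/S"
    print(libem.match(e1, e2)["answer"])

Equivalently, run the new Makefile target: make claude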