Skip to content

Commit

Permalink
fix(utils/llm_tool_call): support anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
idiotWu committed Dec 27, 2024
1 parent 322fac6 commit 7c4b106
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 6 additions & 1 deletion npiai/types/function_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

ToolFunction = Callable[..., Awaitable[Any]]

__EMPTY_PARAMS__ = {"type": "object", "properties": {}}


@dataclass(frozen=True)
class FunctionRegistration:
Expand Down Expand Up @@ -44,11 +46,14 @@ def get_tool_param(self, strict: bool = True) -> ChatCompletionToolParam:
"function": {
"name": self.name,
"description": self.description,
"parameters": (
self.schema if self.schema is not None else __EMPTY_PARAMS__
),
},
}

if self.schema is not None and strict:
tool["function"]["strict"] = True
tool["function"]["parameters"] = self.schema
# tool["function"]["parameters"] = self.schema

return tool
4 changes: 2 additions & 2 deletions npiai/utils/llm_tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from litellm.types.completion import ChatCompletionMessageParam
from pydantic import BaseModel

from npiai.llm import LLM
from npiai.llm import LLM, Anthropic
from .parse_npi_function import parse_npi_function


Expand All @@ -20,7 +20,7 @@ async def llm_tool_call(

response = await llm.completion(
messages=messages,
tools=[fn_reg.get_tool_param()],
tools=[fn_reg.get_tool_param(strict=not isinstance(llm, Anthropic))],
max_tokens=4096,
tool_choice={"type": "function", "function": {"name": fn_reg.name}},
)
Expand Down

0 comments on commit 7c4b106

Please sign in to comment.