From 4ef7de65c5386c714b5e9561217981329c436c19 Mon Sep 17 00:00:00 2001 From: Chris Ruppelt Date: Fri, 6 Dec 2024 04:18:43 -0500 Subject: [PATCH 1/2] [WIP] Ollama tools --- instructor/__init__.py | 7 +++- instructor/client_ollama.py | 76 ++++++++++++++++++++++++++++++++++ instructor/function_calls.py | 19 ++++++++- instructor/mode.py | 1 + instructor/process_response.py | 7 ++++ instructor/utils.py | 1 + pyproject.toml | 1 + 7 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 instructor/client_ollama.py diff --git a/instructor/__init__.py b/instructor/__init__.py index efd503c22..f61eda5bd 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -95,4 +95,9 @@ if importlib.util.find_spec("writerai") is not None: from .client_writer import from_writer - __all__ += ["from_writer"] \ No newline at end of file + __all__ += ["from_writer"] + +if importlib.util.find_spec("ollama") is not None: + from .client_ollama import from_ollama + + __all__ += ["from_ollama"] diff --git a/instructor/client_ollama.py b/instructor/client_ollama.py new file mode 100644 index 000000000..ff256466d --- /dev/null +++ b/instructor/client_ollama.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, overload + +import ollama + +import instructor + + +@overload +def from_ollama( + client: ( + ollama.Client + ), + mode: instructor.Mode = instructor.Mode.OLLAMA_TOOLS, + **kwargs: Any, +) -> instructor.Instructor: ... + + +@overload +def from_ollama( + client: ( + ollama.AsyncClient + ), + mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS, + **kwargs: Any, +) -> instructor.AsyncInstructor: ... 
+ + +def from_ollama( + client: ( + ollama.Client + | ollama.AsyncClient + ), + mode: instructor.Mode = instructor.Mode.OLLAMA_TOOLS, + **kwargs: Any, +) -> instructor.Instructor | instructor.AsyncInstructor: + assert ( + mode + in { + instructor.Mode.OLLAMA_TOOLS, + } + ), "Mode be one of {instructor.Mode.OLLAMA_TOOLS}" + + assert isinstance( + client, + ( + ollama.Client, + ollama.AsyncClient, + ), + ), "Client must be an instance of {ollama.Client, ollama.AsyncClient}" + + create = client.chat + + if isinstance( + client, + ( + ollama.Client, + ), + ): + return instructor.Instructor( + client=client, + create=instructor.patch(create=create, mode=mode), + provider=instructor.Provider.OLLAMA, + mode=mode, + **kwargs, + ) + + else: + return instructor.AsyncInstructor( + client=client, + create=instructor.patch(create=create, mode=mode), + provider=instructor.Provider.OLLAMA, + mode=mode, + **kwargs, + ) \ No newline at end of file diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 8507c2cd6..8741a0698 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -3,6 +3,7 @@ import logging from functools import wraps from typing import Annotated, Any, Optional, TypeVar, cast + from docstring_parser import parse from openai.types.chat import ChatCompletion from pydantic import ( @@ -21,7 +22,6 @@ map_to_gemini_function_schema, ) - T = TypeVar("T") logger = logging.getLogger("instructor") @@ -138,6 +138,9 @@ def from_response( if mode == Mode.WRITER_TOOLS: return cls.parse_writer_tools(completion, validation_context, strict) + + if mode == Mode.OLLAMA_TOOLS: + return cls.parse_ollama_tools(completion, validation_context, strict) if completion.choices[0].finish_reason == "length": raise IncompleteOutputException(last_completion=completion) @@ -386,6 +389,20 @@ def parse_json( context=validation_context, strict=strict, ) + + @classmethod + def parse_ollama_tools( + cls: type[BaseModel], + completion: ChatCompletion, + 
validation_context: Optional[dict[str, Any]] = None, + strict: Optional[bool] = None, + ): + message = completion.message.content + return cls.model_validate_json( + message, + context=validation_context, + strict=strict, + ) def openai_schema(cls: type[BaseModel]) -> OpenAISchema: diff --git a/instructor/mode.py b/instructor/mode.py index 66bbfbad3..a3c95607c 100644 --- a/instructor/mode.py +++ b/instructor/mode.py @@ -27,6 +27,7 @@ class Mode(enum.Enum): FIREWORKS_TOOLS = "fireworks_tools" FIREWORKS_JSON = "fireworks_json" WRITER_TOOLS = "writer_tools" + OLLAMA_TOOLS = "ollama_tools" @classmethod def warn_mode_functions_deprecation(cls): diff --git a/instructor/process_response.py b/instructor/process_response.py index d4a2100eb..c11ee950b 100644 --- a/instructor/process_response.py +++ b/instructor/process_response.py @@ -596,6 +596,12 @@ def handle_writer_tools( new_kwargs["tool_choice"] = "auto" return response_model, new_kwargs +def handle_ollama_tools( + response_model: type[T], new_kwargs: dict[str, Any] +) -> tuple[type[T], dict[str, Any]]: + new_kwargs["format"] = response_model.model_json_schema() + return response_model, new_kwargs + def is_typed_dict(cls) -> bool: return ( @@ -715,6 +721,7 @@ def handle_response_model( Mode.FIREWORKS_JSON: handle_fireworks_json, Mode.FIREWORKS_TOOLS: handle_fireworks_tools, Mode.WRITER_TOOLS: handle_writer_tools, + Mode.OLLAMA_TOOLS: handle_ollama_tools, } if mode in mode_handlers: diff --git a/instructor/utils.py b/instructor/utils.py index 55d746760..a2d57efc4 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -54,6 +54,7 @@ class Provider(Enum): CEREBRAS = "cerebras" FIREWORKS = "fireworks" WRITER = "writer" + OLLAMA = "ollama" UNKNOWN = "unknown" diff --git a/pyproject.toml b/pyproject.toml index ad7b084f1..6ecc569a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ jsonref = { version = "^1.1.0", optional = true } cerebras_cloud_sdk = { version = "^1.5.0", optional = true } fireworks-ai 
= { version = "^0.15.4", optional = true } writer-sdk = { version = "^1.2.0", optional = true } +ollama-python = { version = "TODO", optional = true } # Ollama hasn't released a version yet [tool.poetry.extras] anthropic = ["anthropic", "xmltodict"] From c7f47a8feaee5f2f59668adc81f2c161c41f6ec0 Mon Sep 17 00:00:00 2001 From: Chris Ruppelt Date: Fri, 6 Dec 2024 04:23:56 -0500 Subject: [PATCH 2/2] typo --- instructor/client_ollama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instructor/client_ollama.py b/instructor/client_ollama.py index ff256466d..eb98336e6 100644 --- a/instructor/client_ollama.py +++ b/instructor/client_ollama.py @@ -22,7 +22,7 @@ def from_ollama( client: ( ollama.AsyncClient ), - mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS, + mode: instructor.Mode = instructor.Mode.OLLAMA_TOOLS, **kwargs: Any, ) -> instructor.AsyncInstructor: ...