Skip to content

Commit

Permalink
Add Parallel Tool mode for Vertex AI
Browse files Browse the repository at this point in the history
  • Loading branch information
devjn committed Nov 26, 2024
1 parent 58eef74 commit 41c1a78
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 15 deletions.
42 changes: 30 additions & 12 deletions instructor/client_vertexai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

from typing import Any
from typing import Any, Type, Union, get_origin

Check failure on line 3 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP035)

instructor/client_vertexai.py:3:1: UP035 `typing.Type` is deprecated, use `type` instead

from vertexai.preview.generative_models import ToolConfig # type: ignore
import vertexai.generative_models as gm # type: ignore
from pydantic import BaseModel
import instructor
from instructor.dsl.parallel import get_types_array
import jsonref


def _create_gemini_json_schema(model: BaseModel):
# Add type check to ensure we have a concrete model class
if get_origin(model) is not None:
raise TypeError(f"Expected concrete model class, got type hint {model}")

schema = model.model_json_schema()
schema_without_refs: dict[str, Any] = jsonref.replace_refs(schema) # type: ignore
gemini_schema: dict[Any, Any] = {
Expand All @@ -22,16 +27,28 @@ def _create_gemini_json_schema(model: BaseModel):
return gemini_schema


def _create_vertexai_tool(model: BaseModel) -> gm.Tool:
parameters = _create_gemini_json_schema(model)

declaration = gm.FunctionDeclaration(
name=model.__name__, description=model.__doc__, parameters=parameters
)

tool = gm.Tool(function_declarations=[declaration])
def _create_vertexai_tool(models: Union[BaseModel, list[BaseModel], Type]) -> gm.Tool:

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP007)

instructor/client_vertexai.py:30:35: UP007 Use `X | Y` for type annotations

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP006)

instructor/client_vertexai.py:30:69: UP006 Use `type` instead of `Type` for type annotation

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of parameter "models" is partially unknown   Parameter type is "BaseModel | list[BaseModel] | Type[Unknown]" (reportUnknownParameterType)

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Expected type arguments for generic class "Type" (reportMissingTypeArgument)

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of parameter "models" is partially unknown   Parameter type is "BaseModel | list[BaseModel] | Type[Unknown]" (reportUnknownParameterType)

Check failure on line 30 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Expected type arguments for generic class "Type" (reportMissingTypeArgument)
"""Creates a tool with function declarations for single model or list of models"""
# Handle Iterable case first
if get_origin(models) is not None:
model_list = list(get_types_array(models))

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of "model_list" is partially unknown   Type of "model_list" is "list[Unknown]" (reportUnknownVariableType)

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__init__"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Argument of type "BaseModel | list[BaseModel] | Type[Unknown]" cannot be assigned to parameter "typehint" of type "type[Iterable[T@get_types_array]]" in function "get_types_array"   Type "BaseModel | list[BaseModel] | Type[Unknown]" is incompatible with type "type[Iterable[T@get_types_array]]"     Type "BaseModel" is incompatible with type "type[Iterable[T@get_types_array]]" (reportArgumentType)

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "model_list" is partially unknown   Type of "model_list" is "list[Unknown]" (reportUnknownVariableType)

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Argument type is partially unknown   Argument corresponds to parameter "iterable" in function "__init__"   Argument type is "tuple[Unknown, ...]" (reportUnknownArgumentType)

Check failure on line 34 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Argument of type "BaseModel | list[BaseModel] | Type[Unknown]" cannot be assigned to parameter "typehint" of type "type[Iterable[T@get_types_array]]" in function "get_types_array"   Type "BaseModel | list[BaseModel] | Type[Unknown]" is incompatible with type "type[Iterable[T@get_types_array]]"     Type "BaseModel" is incompatible with type "type[Iterable[T@get_types_array]]" (reportArgumentType)
else:
# Handle both single model and list of models
model_list = models if isinstance(models, list) else [models]

Check failure on line 37 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of "model_list" is partially unknown   Type of "model_list" is "list[BaseModel] | list[BaseModel | Type[Unknown]]" (reportUnknownVariableType)

Check failure on line 37 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "model_list" is partially unknown   Type of "model_list" is "list[BaseModel] | list[BaseModel | Type[Unknown]]" (reportUnknownVariableType)

print(f"Debug - Model list: {[model.__name__ for model in model_list]}")

Check failure on line 39 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of "__name__" is partially unknown   Type of "__name__" is "Unknown | str" (reportUnknownMemberType)

Check failure on line 39 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of "model" is partially unknown   Type of "model" is "Unknown | BaseModel | Type[Unknown]" (reportUnknownVariableType)

Check failure on line 39 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "__name__" is partially unknown   Type of "__name__" is "Unknown | str" (reportUnknownMemberType)

Check failure on line 39 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "model" is partially unknown   Type of "model" is "Unknown | BaseModel | Type[Unknown]" (reportUnknownVariableType)

declarations = []
for model in model_list:

Check failure on line 42 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Type of "model" is partially unknown   Type of "model" is "Unknown | BaseModel | Type[Unknown]" (reportUnknownVariableType)

Check failure on line 42 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "model" is partially unknown   Type of "model" is "Unknown | BaseModel | Type[Unknown]" (reportUnknownVariableType)
parameters = _create_gemini_json_schema(model)

Check failure on line 43 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.10)

Argument of type "Unknown | BaseModel | Type[Unknown]" cannot be assigned to parameter "model" of type "BaseModel" in function "_create_gemini_json_schema"   Type "Unknown | BaseModel | Type[Unknown]" is incompatible with type "BaseModel"     "Type[Unknown]" is incompatible with "BaseModel" (reportArgumentType)

Check failure on line 43 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Argument of type "Unknown | BaseModel | Type[Unknown]" cannot be assigned to parameter "model" of type "BaseModel" in function "_create_gemini_json_schema"   Type "Unknown | BaseModel | Type[Unknown]" is incompatible with type "BaseModel"     "Type[Unknown]" is incompatible with "BaseModel" (reportArgumentType)
declaration = gm.FunctionDeclaration(
name=model.__name__,
description=model.__doc__,
parameters=parameters
)
declarations.append(declaration)

return tool
return gm.Tool(function_declarations=declarations)


def vertexai_message_parser(
Expand Down Expand Up @@ -84,11 +101,11 @@ def vertexai_function_response_parser(
)


def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel):
def vertexai_process_response(_kwargs: dict[str, Any], model: Union[BaseModel, list[BaseModel], Type]):

Check failure on line 104 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP007)

instructor/client_vertexai.py:104:63: UP007 Use `X | Y` for type annotations

Check failure on line 104 in instructor/client_vertexai.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP006)

instructor/client_vertexai.py:104:97: UP006 Use `type` instead of `Type` for type annotation
messages: list[dict[str, str]] = _kwargs.pop("messages")
contents = _vertexai_message_list_parser(messages) # type: ignore

tool = _create_vertexai_tool(model=model)
tool = _create_vertexai_tool(models=model)

tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
Expand Down Expand Up @@ -122,6 +139,7 @@ def from_vertexai(
**kwargs: Any,
) -> instructor.Instructor:
assert mode in {
instructor.Mode.VERTEXAI_PARALLEL_TOOLS,
instructor.Mode.VERTEXAI_TOOLS,
instructor.Mode.VERTEXAI_JSON,
}, "Mode must be instructor.Mode.VERTEXAI_TOOLS"
Expand Down
37 changes: 37 additions & 0 deletions instructor/dsl/parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import json
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -45,6 +46,38 @@ def from_response(
)


class VertexAIParallelBase(ParallelBase):
def from_response(
self,
response: Any,
mode: Mode,
validation_context: Optional[Any] = None,
strict: Optional[bool] = None,
) -> Generator[BaseModel, None, None]:
assert mode == Mode.VERTEXAI_PARALLEL_TOOLS, "Mode must be VERTEXAI_PARALLEL_TOOLS"

if not response or not response.candidates:
return

for candidate in response.candidates:
if not candidate.content or not candidate.content.parts:
continue

for part in candidate.content.parts:
if (hasattr(part, 'function_call') and
part.function_call is not None):

name = part.function_call.name
arguments = part.function_call.args

if name in self.registry:
# Convert dict to JSON string before validation
json_str = json.dumps(arguments)
yield self.registry[name].model_validate_json(
json_str, context=validation_context, strict=strict
)


if sys.version_info >= (3, 10):
from types import UnionType

Expand Down Expand Up @@ -82,3 +115,7 @@ def handle_parallel_model(typehint: type[Iterable[T]]) -> list[dict[str, Any]]:
def ParallelModel(typehint: type[Iterable[T]]) -> ParallelBase:
the_types = get_types_array(typehint)
return ParallelBase(*[model for model in the_types])

def VertexAIParallelModel(typehint: type[Iterable[T]]) -> VertexAIParallelBase:
the_types = get_types_array(typehint)
return VertexAIParallelBase(*[model for model in the_types])
1 change: 1 addition & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Mode(enum.Enum):
COHERE_TOOLS = "cohere_tools"
VERTEXAI_TOOLS = "vertexai_tools"
VERTEXAI_JSON = "vertexai_json"
VERTEXAI_PARALLEL_TOOLS = "vertexai_parallel_tools"
GEMINI_JSON = "gemini_json"
GEMINI_TOOLS = "gemini_tools"
COHERE_JSON_SCHEMA = "json_object"
Expand Down
36 changes: 33 additions & 3 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@

from instructor.mode import Mode
from instructor.dsl.iterable import IterableBase, IterableModel
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.parallel import (
ParallelBase,
ParallelModel,
handle_parallel_model,
get_types_array,
VertexAIParallelBase,
VertexAIParallelModel

Check failure on line 25 in instructor/process_response.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (F401)

instructor/process_response.py:25:5: F401 `instructor.dsl.parallel.VertexAIParallelModel` imported but unused
)
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
from instructor.function_calls import OpenAISchema, openai_schema
Expand Down Expand Up @@ -112,7 +119,7 @@ def process_response(
validation_context: dict[str, Any] | None = None,
strict=None,
mode: Mode = Mode.TOOLS,
):
) -> T_Model | list[T_Model] | VertexAIParallelBase | None:
"""
Process the response from the API call and convert it to the specified response model.
Expand Down Expand Up @@ -485,6 +492,27 @@ def handle_gemini_tools(
return response_model, new_kwargs


def handle_vertexai_parallel_tools(
response_model: type[Iterable[T]], new_kwargs: dict[str, Any]
) -> tuple[VertexAIParallelBase, dict[str, Any]]:
assert (
new_kwargs.get("stream", False) is False
), "stream=True is not supported when using PARALLEL_TOOLS mode"

from instructor.client_vertexai import vertexai_process_response
from instructor.dsl.parallel import VertexAIParallelModel

# Extract concrete types before passing to vertexai_process_response
model_types = list(get_types_array(response_model))
contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)

new_kwargs["contents"] = contents
new_kwargs["tools"] = tools
new_kwargs["tool_config"] = tool_config

return VertexAIParallelModel(typehint=response_model), new_kwargs


def handle_vertexai_tools(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
Expand Down Expand Up @@ -646,7 +674,7 @@ def prepare_response_model(response_model: type[T] | None) -> type[T] | None:

def handle_response_model(
response_model: type[T] | None, mode: Mode = Mode.TOOLS, **kwargs: Any
) -> tuple[type[T] | None, dict[str, Any]]:
) -> tuple[type[T] | VertexAIParallelBase | None, dict[str, Any]]:
"""
Handles the response model based on the specified mode and prepares the kwargs for the API call.
Expand Down Expand Up @@ -690,6 +718,8 @@ def handle_response_model(

if mode in {Mode.PARALLEL_TOOLS}:
return handle_parallel_tools(response_model, new_kwargs)
elif mode in {Mode.VERTEXAI_PARALLEL_TOOLS}:
return handle_vertexai_parallel_tools(response_model, new_kwargs)

response_model = prepare_response_model(response_model)

Expand Down

0 comments on commit 41c1a78

Please sign in to comment.