diff --git a/instructor/client_vertexai.py b/instructor/client_vertexai.py index ba4775ba9..51833a76e 100644 --- a/instructor/client_vertexai.py +++ b/instructor/client_vertexai.py @@ -1,15 +1,20 @@ from __future__ import annotations -from typing import Any +from typing import Any, Type, Union, get_origin from vertexai.preview.generative_models import ToolConfig # type: ignore import vertexai.generative_models as gm # type: ignore from pydantic import BaseModel import instructor +from instructor.dsl.parallel import get_types_array import jsonref def _create_gemini_json_schema(model: BaseModel): + # Add type check to ensure we have a concrete model class + if get_origin(model) is not None: + raise TypeError(f"Expected concrete model class, got type hint {model}") + schema = model.model_json_schema() schema_without_refs: dict[str, Any] = jsonref.replace_refs(schema) # type: ignore gemini_schema: dict[Any, Any] = { @@ -22,16 +27,28 @@ def _create_gemini_json_schema(model: BaseModel): return gemini_schema -def _create_vertexai_tool(model: BaseModel) -> gm.Tool: - parameters = _create_gemini_json_schema(model) - - declaration = gm.FunctionDeclaration( - name=model.__name__, description=model.__doc__, parameters=parameters - ) - - tool = gm.Tool(function_declarations=[declaration]) +def _create_vertexai_tool(models: Union[BaseModel, list[BaseModel], Type]) -> gm.Tool: + """Creates a tool with function declarations for single model or list of models""" + # Handle Iterable case first + if get_origin(models) is not None: + model_list = list(get_types_array(models)) + else: + # Handle both single model and list of models + model_list = models if isinstance(models, list) else [models] + + print(f"Debug - Model list: {[model.__name__ for model in model_list]}") + + declarations = [] + for model in model_list: + parameters = _create_gemini_json_schema(model) + declaration = gm.FunctionDeclaration( + name=model.__name__, + description=model.__doc__, + parameters=parameters + ) + declarations.append(declaration) - return tool + return gm.Tool(function_declarations=declarations) def vertexai_message_parser( @@ -84,11 +101,11 @@ def vertexai_function_response_parser( ) -def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel): +def vertexai_process_response(_kwargs: dict[str, Any], model: Union[BaseModel, list[BaseModel], Type]): messages: list[dict[str, str]] = _kwargs.pop("messages") contents = _vertexai_message_list_parser(messages) # type: ignore - tool = _create_vertexai_tool(model=model) + tool = _create_vertexai_tool(models=model) tool_config = ToolConfig( function_calling_config=ToolConfig.FunctionCallingConfig( @@ -122,6 +139,7 @@ def from_vertexai( **kwargs: Any, ) -> instructor.Instructor: assert mode in { + instructor.Mode.VERTEXAI_PARALLEL_TOOLS, instructor.Mode.VERTEXAI_TOOLS, instructor.Mode.VERTEXAI_JSON, }, "Mode must be instructor.Mode.VERTEXAI_TOOLS" diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index a42dfa418..5d207f5f8 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -1,4 +1,5 @@ import sys +import json from typing import ( Any, Optional, @@ -45,6 +46,38 @@ def from_response( ) +class VertexAIParallelBase(ParallelBase): + def from_response( + self, + response: Any, + mode: Mode, + validation_context: Optional[Any] = None, + strict: Optional[bool] = None, + ) -> Generator[BaseModel, None, None]: + assert mode == Mode.VERTEXAI_PARALLEL_TOOLS, "Mode must be VERTEXAI_PARALLEL_TOOLS" + + if not response or not response.candidates: + return + + for candidate in response.candidates: + if not candidate.content or not candidate.content.parts: + continue + + for part in candidate.content.parts: + if (hasattr(part, 'function_call') and + part.function_call is not None): + + name = part.function_call.name + arguments = part.function_call.args + + if name in self.registry: + # Convert dict to JSON string before validation + json_str = json.dumps(arguments) + yield self.registry[name].model_validate_json( + json_str, context=validation_context, strict=strict + ) + + if sys.version_info >= (3, 10): from types import UnionType @@ -82,3 +115,7 @@ def handle_parallel_model(typehint: type[Iterable[T]]) -> list[dict[str, Any]]: def ParallelModel(typehint: type[Iterable[T]]) -> ParallelBase: the_types = get_types_array(typehint) return ParallelBase(*[model for model in the_types]) + +def VertexAIParallelModel(typehint: type[Iterable[T]]) -> VertexAIParallelBase: + the_types = get_types_array(typehint) + return VertexAIParallelBase(*[model for model in the_types]) diff --git a/instructor/mode.py b/instructor/mode.py index 66bbfbad3..ebd330b40 100644 --- a/instructor/mode.py +++ b/instructor/mode.py @@ -18,6 +18,7 @@ class Mode(enum.Enum): COHERE_TOOLS = "cohere_tools" VERTEXAI_TOOLS = "vertexai_tools" VERTEXAI_JSON = "vertexai_json" + VERTEXAI_PARALLEL_TOOLS = "vertexai_parallel_tools" GEMINI_JSON = "gemini_json" GEMINI_TOOLS = "gemini_tools" COHERE_JSON_SCHEMA = "json_object" diff --git a/instructor/process_response.py b/instructor/process_response.py index d4a2100eb..7c0722f1a 100644 --- a/instructor/process_response.py +++ b/instructor/process_response.py @@ -16,7 +16,14 @@ from instructor.mode import Mode from instructor.dsl.iterable import IterableBase, IterableModel -from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model +from instructor.dsl.parallel import ( + ParallelBase, + ParallelModel, + handle_parallel_model, + get_types_array, + VertexAIParallelBase, + VertexAIParallelModel +) from instructor.dsl.partial import PartialBase from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type from instructor.function_calls import OpenAISchema, openai_schema @@ -112,7 +119,7 @@ def process_response( validation_context: dict[str, Any] | None = None, strict=None, mode: Mode = Mode.TOOLS, -): +) -> T_Model | list[T_Model] | VertexAIParallelBase | None: """ Process the response from the API call and convert it to the specified response model. @@ -485,6 +492,27 @@ def handle_gemini_tools( return response_model, new_kwargs +def handle_vertexai_parallel_tools( + response_model: type[Iterable[T]], new_kwargs: dict[str, Any] +) -> tuple[VertexAIParallelBase, dict[str, Any]]: + assert ( + new_kwargs.get("stream", False) is False + ), "stream=True is not supported when using PARALLEL_TOOLS mode" + + from instructor.client_vertexai import vertexai_process_response + from instructor.dsl.parallel import VertexAIParallelModel + + # Extract concrete types before passing to vertexai_process_response + model_types = list(get_types_array(response_model)) + contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types) + + new_kwargs["contents"] = contents + new_kwargs["tools"] = tools + new_kwargs["tool_config"] = tool_config + + return VertexAIParallelModel(typehint=response_model), new_kwargs + + def handle_vertexai_tools( response_model: type[T], new_kwargs: dict[str, Any] ) -> tuple[type[T], dict[str, Any]]: @@ -646,7 +674,7 @@ def prepare_response_model(response_model: type[T] | None) -> type[T] | None: def handle_response_model( response_model: type[T] | None, mode: Mode = Mode.TOOLS, **kwargs: Any -) -> tuple[type[T] | None, dict[str, Any]]: +) -> tuple[type[T] | VertexAIParallelBase | None, dict[str, Any]]: """ Handles the response model based on the specified mode and prepares the kwargs for the API call. @@ -690,6 +718,8 @@ def handle_response_model( if mode in {Mode.PARALLEL_TOOLS}: return handle_parallel_tools(response_model, new_kwargs) + elif mode in {Mode.VERTEXAI_PARALLEL_TOOLS}: + return handle_vertexai_parallel_tools(response_model, new_kwargs) response_model = prepare_response_model(response_model)