From d894807aa33aaff8434200bd23aad83e28f47051 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 22 Feb 2024 11:13:45 -0800 Subject: [PATCH] Automatic function calling. (#201) * Starting automatic function calling * Working on AFC * Fix typos * Add tools overrides for generate_content and send_message * Add initial AFC loop. * Basic debugging, streaming's probably broken. * Add error with stream=True * format * add pydantic * fix tests * replace __init__ * Fix pytype * Remove property * format * working on it * working on it * working on it * format * Add test for schema gen * Split test * Fix type anno & classmethod * fixup: black * Fix mutable defaults. * Fix mutable defaults --------- Co-authored-by: Mark McDonald --- google/generativeai/generative_models.py | 186 +++++++-- google/generativeai/types/content_types.py | 382 +++++++++++++++++- google/generativeai/types/generation_types.py | 2 +- setup.py | 3 +- tests/test_content.py | 217 +++++++++- tests/test_generative_models.py | 20 +- 6 files changed, 747 insertions(+), 63 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index ee11669e1..b13ada142 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -71,7 +71,7 @@ def __init__( model_name: str = "gemini-pro", safety_settings: safety_types.SafetySettingOptions | None = None, generation_config: generation_types.GenerationConfigType | None = None, - tools: content_types.ToolsType = None, + tools: content_types.FunctionLibraryType | None = None, ): if "/" not in model_name: model_name = "models/" + model_name @@ -80,7 +80,7 @@ def __init__( safety_settings, harm_category_set="new" ) self._generation_config = generation_types.to_generation_config_dict(generation_config) - self._tools = content_types.to_tools(tools) + self._tools = content_types.to_function_library(tools) self._client = None self._async_client = None @@ -94,8 +94,9 @@ def __str__(self): f"""\ genai.GenerativeModel( model_name='{self.model_name}', - generation_config={self._generation_config}. 
- safety_settings={self._safety_settings} + generation_config={self._generation_config}, + safety_settings={self._safety_settings}, + tools={self._tools}, )""" ) @@ -107,12 +108,16 @@ def _prepare_request( contents: content_types.ContentsType, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, - **kwargs, + tools: content_types.FunctionLibraryType | None, ) -> glm.GenerateContentRequest: """Creates a `glm.GenerateContentRequest` from raw inputs.""" if not contents: raise TypeError("contents must not be empty") + tools_lib = self._get_tools_lib(tools) + if tools_lib is not None: + tools_lib = tools_lib.to_proto() + contents = content_types.to_contents(contents) generation_config = generation_types.to_generation_config_dict(generation_config) @@ -129,10 +134,17 @@ def _prepare_request( contents=contents, generation_config=merged_gc, safety_settings=merged_ss, - tools=self._tools, - **kwargs, + tools=tools_lib, ) + def _get_tools_lib( + self, tools: content_types.FunctionLibraryType + ) -> content_types.FunctionLibrary | None: + if tools is None: + return self._tools + else: + return content_types.to_function_library(tools) + def generate_content( self, contents: content_types.ContentsType, @@ -140,8 +152,8 @@ def generate_content( generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, request_options: dict[str, Any] | None = None, - **kwargs, ) -> generation_types.GenerateContentResponse: """A multipurpose function to generate responses from the model. @@ -201,7 +213,7 @@ def generate_content( contents=contents, generation_config=generation_config, safety_settings=safety_settings, - **kwargs, + tools=tools, ) if self._client is None: self._client = client.get_default_generative_client() @@ -230,15 +242,15 @@ async def generate_content_async( generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, stream: bool = False, + tools: content_types.FunctionLibraryType | None = None, request_options: dict[str, Any] | None = None, - **kwargs, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" request = self._prepare_request( contents=contents, generation_config=generation_config, safety_settings=safety_settings, - **kwargs, + tools=tools, ) if self._async_client is None: self._async_client = client.get_default_generative_async_client() @@ -299,6 +311,7 @@ def start_chat( self, *, history: Iterable[content_types.StrictContentType] | None = None, + enable_automatic_function_calling: bool = False, ) -> ChatSession: """Returns a `genai.ChatSession` attached to this model. 
@@ -314,6 +327,7 @@ def start_chat( return ChatSession( model=self, history=history, + enable_automatic_function_calling=enable_automatic_function_calling, ) @@ -341,11 +355,13 @@ def __init__( self, model: GenerativeModel, history: Iterable[content_types.StrictContentType] | None = None, + enable_automatic_function_calling: bool = False, ): self.model: GenerativeModel = model self._history: list[glm.Content] = content_types.to_contents(history) self._last_sent: glm.Content | None = None self._last_received: generation_types.BaseGenerateContentResponse | None = None + self.enable_automatic_function_calling = enable_automatic_function_calling def send_message( self, @@ -354,7 +370,7 @@ def send_message( generation_config: generation_types.GenerationConfigType = None, safety_settings: safety_types.SafetySettingOptions = None, stream: bool = False, - **kwargs, + tools: content_types.FunctionLibraryType | None = None, ) -> generation_types.GenerateContentResponse: """Sends the conversation history with the added message and returns the model's response. @@ -387,23 +403,52 @@ def send_message( safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. """ + if self.enable_automatic_function_calling and stream: + raise NotImplementedError( + "The `google.generativeai` SDK does not yet support `stream=True` with " + "`enable_automatic_function_calling=True`" + ) + + tools_lib = self.model._get_tools_lib(tools) + content = content_types.to_content(content) + if not content.role: content.role = self._USER_ROLE + history = self.history[:] history.append(content) generation_config = generation_types.to_generation_config_dict(generation_config) if generation_config.get("candidate_count", 1) > 1: raise ValueError("Can't chat with `candidate_count > 1`") + response = self.model.generate_content( contents=history, generation_config=generation_config, safety_settings=safety_settings, stream=stream, - **kwargs, + tools=tools_lib, ) + self._check_response(response=response, stream=stream) + + if self.enable_automatic_function_calling and tools_lib is not None: + self.history, content, response = self._handle_afc( + response=response, + history=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools_lib=tools_lib, + ) + + self._last_sent = content + self._last_received = response + + return response + + def _check_response(self, *, response, stream): if response.prompt_feedback.block_reason: raise generation_types.BlockedPromptException(response.prompt_feedback) @@ -415,10 +460,49 @@ def send_message( ): raise generation_types.StopCandidateException(response.candidates[0]) - self._last_sent = content - self._last_received = response + def _get_function_calls(self, response) -> list[glm.FunctionCall]: + candidates = response.candidates + if len(candidates) != 1: + raise ValueError( + f"Automatic function calling only works with 1 candidate, got: {len(candidates)}" + ) + parts = candidates[0].content.parts + function_calls = [part.function_call for part in parts if part and "function_call" in part] + return function_calls + + def _handle_afc( + self, *, response, history, generation_config, safety_settings, stream, tools_lib + ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + + while function_calls := self._get_function_calls(response): + if not all(callable(tools_lib[fc]) for fc in function_calls): + break + history.append(response.candidates[0].content) 
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)
 
-        return response
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = self.model.generate_content(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response
 
     async def send_message_async(
         self,
@@ -427,42 +511,88 @@ async def send_message_async(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `ChatSession.send_message`."""
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)
 
         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
-        response = await self.model.generate_content_async(
+
+        response = await self.model.generate_content_async(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
         )
 
-        if response.prompt_feedback.block_reason:
-            raise generation_types.BlockedPromptException(response.prompt_feedback)
+        self._check_response(response=response, stream=stream)
 
-        if not stream:
-            if response.candidates[0].finish_reason not in (
-                glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
-                glm.Candidate.FinishReason.STOP,
-                glm.Candidate.FinishReason.MAX_TOKENS,
-            ):
-                raise generation_types.StopCandidateException(response.candidates[0])
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = await self._handle_afc_async(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )
 
         self._last_sent = content
         self._last_received = response
 
         return response
 
+    async def _handle_afc_async(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+ ) + function_response_parts.append(fr) + + send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + history.append(send) + + response = await self.model.generate_content_async( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + tools=tools_lib, + ) + + self._check_response(response=response, stream=stream) + + *history, content = history + return history, content, response + def __copy__(self): return ChatSession( model=self.model, diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 0011619d6..72d4860b0 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -1,11 +1,13 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Mapping, Sequence import io +import inspect import mimetypes -import pathlib import typing -from typing import Any, TypedDict, Union +from typing import Any, Callable, TypedDict, Union + +import pydantic from google.ai import generativelanguage as glm @@ -40,7 +42,14 @@ "ContentType", "StrictContentType", "ContentsType", + "FunctionDeclaration", + "CallableFunctionDeclaration", + "FunctionDeclarationType", + "Tool", + "ToolDict", "ToolsType", + "FunctionLibrary", + "FunctionLibraryType", ] @@ -242,15 +251,364 @@ def to_contents(contents: ContentsType) -> list[glm.Content]: return contents -ToolsType = Union[Iterable[glm.Tool], glm.Tool, dict[str, Any], None] +def _generate_schema( + f: Callable[..., Any], + *, + descriptions: Mapping[str, str] | None = None, + required: Sequence[str] | None = None, +) -> dict[str, Any]: + """Generates the OpenAPI Schema for a python function. + + Args: + f: The function to generate an OpenAPI Schema for. + descriptions: Optional. A `{name: description}` mapping for annotating input + arguments of the function with user-provided descriptions. It + defaults to an empty dictionary (i.e. there will not be any + description for any of the inputs). + required: Optional. For the user to specify the set of required arguments in + function calls to `f`. If unspecified, it will be automatically + inferred from `f`. + + Returns: + dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format. + """ + if descriptions is None: + descriptions = {} + if required is None: + required = [] + defaults = dict(inspect.signature(f).parameters) + fields_dict = { + name: ( + # 1. We infer the argument type here: use Any rather than None so + # it will not try to auto-infer the type based on the default value. + (param.annotation if param.annotation != inspect.Parameter.empty else Any), + pydantic.Field( + # 2. We do not support default values for now. + # default=( + # param.default if param.default != inspect.Parameter.empty + # else None + # ), + # 3. We support user-provided descriptions. + description=descriptions.get(name, None), + ), + ) + for name, param in defaults.items() + # We do not support *args or **kwargs + if param.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ) + } + parameters = pydantic.create_model(f.__name__, **fields_dict).schema() + # Postprocessing + # 4. 
Suppress unnecessary title generation:
+    #    * https://github.com/pydantic/pydantic/issues/1051
+    #    * http://cl/586221780
+    parameters.pop("title", None)
+    for name, function_arg in parameters.get("properties", {}).items():
+        function_arg.pop("title", None)
+        annotation = defaults[name].annotation
+        # 5. Nullable fields:
+        #     * https://github.com/pydantic/pydantic/issues/1270
+        #     * https://stackoverflow.com/a/58841311
+        #     * https://github.com/pydantic/pydantic/discussions/4872
+        if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args(
+            annotation
+        ):
+            function_arg["nullable"] = True
+    # 6. Annotate required fields.
+    if required:
+        # We use the user-provided "required" fields if specified.
+        parameters["required"] = required
+    else:
+        # Otherwise we infer it from the function signature.
+        parameters["required"] = [
+            k
+            for k in defaults
+            if (
+                defaults[k].default == inspect.Parameter.empty
+                and defaults[k].kind
+                in (
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    inspect.Parameter.KEYWORD_ONLY,
+                    inspect.Parameter.POSITIONAL_ONLY,
+                )
+            )
+        ]
+    schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
+    return schema
 
-def to_tools(tools: ToolsType) -> list[glm.Tool]:
-    if tools is None:
-        return []
-    elif isinstance(tools, Mapping):
-        return [glm.Tool(tools)]
-    elif isinstance(tools, Iterable):
-        return [glm.Tool(t) for t in tools]
+def _rename_schema_fields(schema):
+    if schema is None:
+        return schema
+
+    schema = schema.copy()
+
+    type_ = schema.pop("type", None)
+    if type_ is not None:
+        schema["type_"] = type_.upper()
+
+    format_ = schema.pop("format", None)
+    if format_ is not None:
+        schema["format_"] = format_
+
+    items = schema.pop("items", None)
+    if items is not None:
+        schema["items"] = _rename_schema_fields(items)
+
+    properties = schema.pop("properties", None)
+    if properties is not None:
+        schema["properties"] = {k: _rename_schema_fields(v) for k, v in properties.items()}
+
+    return schema
+
+
+class FunctionDeclaration:
+    def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None):
+        """A class wrapping a `glm.FunctionDeclaration`; describes a function for `genai.GenerativeModel`'s `tools`."""
+        self._proto = glm.FunctionDeclaration(
+            name=name, description=description, parameters=_rename_schema_fields(parameters)
+        )
+
+    @property
+    def name(self) -> str:
+        return self._proto.name
+
+    @property
+    def description(self) -> str:
+        return self._proto.description
+
+    @property
+    def parameters(self) -> glm.Schema:
+        return self._proto.parameters
+
+    @classmethod
+    def from_proto(cls, proto) -> FunctionDeclaration:
+        self = cls(name="", description="", parameters={})
+        self._proto = proto
+        return self
+
+    def to_proto(self) -> glm.FunctionDeclaration:
+        return self._proto
+
+    @staticmethod
+    def from_function(function: Callable[..., Any], descriptions: dict[str, str] | None = None):
+        """Builds a `CallableFunctionDeclaration` from a Python function.
+
+        The function should have type annotations.
+
+        This method is able to generate the schema for arguments annotated with types:
+
+        `AllowedTypes = float | int | str | list[AllowedTypes] | dict`
+
+        This method does not yet build a schema for `TypedDict`, which would allow you to
+        specify the dictionary contents. But you can build these manually.
+ """ + + if descriptions is None: + descriptions = {} + + schema = _generate_schema(function, descriptions=descriptions) + + return CallableFunctionDeclaration(**schema, function=function) + + +StructType = dict[str, "ValueType"] +ValueType = Union[float, str, bool, StructType, list["ValueType"], None] + + +class CallableFunctionDeclaration(FunctionDeclaration): + """An extension of `FunctionDeclaration` that can be built from a python function, and is callable. + + Note: The python function must have type annotations. + """ + + def __init__( + self, + *, + name: str, + description: str, + parameters: dict[str, Any] | None = None, + function: Callable[..., Any], + ): + super().__init__(name=name, description=description, parameters=parameters) + self.function = function + + def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + result = self.function(**fc.args) + if not isinstance(result, dict): + result = {"result": result} + return glm.FunctionResponse(name=fc.name, response=result) + + +FunctionDeclarationType = Union[ + FunctionDeclaration, + glm.FunctionDeclaration, + dict[str, Any], + Callable[..., Any], +] + + +def _make_function_declaration( + fun: FunctionDeclarationType, +) -> FunctionDeclaration | glm.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): + return fun + elif isinstance(fun, dict): + if "function" in fun: + return CallableFunctionDeclaration(**fun) + else: + return FunctionDeclaration(**fun) + elif callable(fun): + return CallableFunctionDeclaration.from_function(fun) + else: + raise TypeError( + "Expected an instance of `genai.FunctionDeclaraionType`. Got a:\n" f" {type(fun)=}\n", + fun, + ) + + +def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: + if isinstance(fd, glm.FunctionDeclaration): + return fd + + return fd.to_proto() + + +class Tool: + """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + + def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): + # The main path doesn't use this but is seems useful. 
+        self._function_declarations = [_make_function_declaration(f) for f in function_declarations]
+        self._index = {}
+        for fd in self._function_declarations:
+            name = fd.name
+            if name in self._index:
+                raise ValueError(
+                    f"A `FunctionDeclaration` named {name} is already defined. "
+                    "Each `FunctionDeclaration` must be uniquely named."
+                )
+            self._index[fd.name] = fd
+
+        self._proto = glm.Tool(
+            function_declarations=[_encode_fd(fd) for fd in self._function_declarations]
+        )
+
+    @property
+    def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]:
+        return self._function_declarations
+
+    def __getitem__(
+        self, name: str | glm.FunctionCall
+    ) -> FunctionDeclaration | glm.FunctionDeclaration:
+        if not isinstance(name, str):
+            name = name.name
+
+        return self._index[name]
+
+    def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None:
+        declaration = self[fc]
+        if not callable(declaration):
+            return None
+
+        return declaration(fc)
+
+    def to_proto(self):
+        return self._proto
+
+
+class ToolDict(TypedDict):
+    function_declarations: list[FunctionDeclarationType]
+
+
+ToolType = Union[
+    Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
+]
+
+
+def _make_tool(tool: ToolType) -> Tool:
+    if isinstance(tool, Tool):
+        return tool
+    elif isinstance(tool, glm.Tool):
+        return Tool(function_declarations=tool.function_declarations)
+    elif isinstance(tool, dict):
+        if "function_declarations" in tool:
+            return Tool(**tool)
+        else:
+            fd = tool
+            return Tool(function_declarations=[glm.FunctionDeclaration(**fd)])
+    elif isinstance(tool, Iterable):
+        return Tool(function_declarations=tool)
+    else:
+        try:
+            return Tool(function_declarations=[tool])
+        except Exception as e:
+            raise TypeError(
+                "Expected an instance of `genai.ToolType`. Got a:\n" f"  {type(tool)=}",
+                tool,
+            ) from e
+
+
+class FunctionLibrary:
+    """A container for a set of `Tool` objects; manages lookup and execution of their functions."""
+
+    def __init__(self, tools: Iterable[ToolType]):
+        tools = _make_tools(tools)
+        self._tools = list(tools)
+        self._index = {}
+        for tool in self._tools:
+            for declaration in tool.function_declarations:
+                name = declaration.name
+                if name in self._index:
+                    raise ValueError(
+                        f"A `FunctionDeclaration` named {name} is already defined. "
+                        "Each `FunctionDeclaration` must be uniquely named."
+                    )
+                self._index[declaration.name] = declaration
+
+    def __getitem__(
+        self, name: str | glm.FunctionCall
+    ) -> FunctionDeclaration | glm.FunctionDeclaration:
+        if not isinstance(name, str):
+            name = name.name
+
+        return self._index[name]
+
+    def __call__(self, fc: glm.FunctionCall) -> glm.Part | None:
+        declaration = self[fc]
+        if not callable(declaration):
+            return None
+
+        response = declaration(fc)
+        return glm.Part(function_response=response)
+
+    def to_proto(self):
+        return [tool.to_proto() for tool in self._tools]
+
+
+ToolsType = Union[Iterable[ToolType], ToolType]
+
+
+def _make_tools(tools: ToolsType) -> list[Tool]:
+    if isinstance(tools, Iterable) and not isinstance(tools, Mapping):
+        tools = [_make_tool(t) for t in tools]
+        if len(tools) > 1 and all(len(t.function_declarations) == 1 for t in tools):
+            # Flatten into a single tool.
+ tools = [_make_tool([t.function_declarations[0] for t in tools])] + return tools + else: + tool = tools + return [_make_tool(tool)] + + +FunctionLibraryType = Union[FunctionLibrary, ToolsType] + + +def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | None: + if lib is None: + return lib + elif isinstance(lib, FunctionLibrary): + return lib else: - return [glm.Tool(tools)] + return FunctionLibrary(tools=lib) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 0e061d5e3..d7da4cbe4 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -7,7 +7,7 @@ import dataclasses import itertools import textwrap -from typing import List, Tuple, TypedDict, Union +from typing import TypedDict, Union import google.protobuf.json_format import google.api_core.exceptions diff --git a/setup.py b/setup.py index 9f777f454..a4e75272d 100644 --- a/setup.py +++ b/setup.py @@ -45,9 +45,10 @@ def get_version(): "google-ai-generativelanguage==0.4.0", "google-auth>=2.15.0", # 2.15 adds API key auth support "google-api-core", - "typing-extensions", "protobuf", + "pydantic", "tqdm", + "typing-extensions", ] extras_require = { diff --git a/tests/test_content.py b/tests/test_content.py index 2c1253b02..6d3333956 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -1,13 +1,24 @@ -import copy +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pathlib -import unittest.mock +from typing import Any from absl.testing import absltest from absl.testing import parameterized import google.ai.generativelanguage as glm -import google.generativeai as genai from google.generativeai.types import content_types -from google.generativeai.types import safety_types import IPython.display import PIL.Image @@ -22,6 +33,11 @@ TEST_JPG_DATA = TEST_JPG_PATH.read_bytes() +# simple test function +def datetime(): + "Returns the current UTC date and time." + + class UnitTests(parameterized.TestCase): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], @@ -173,7 +189,87 @@ def test_img_to_contents(self, example): @parameterized.named_parameters( [ - "OneTool", + "FunctionLibrary", + content_types.FunctionLibrary( + tools=glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." + ) + ] + ) + ), + ], + [ + "IterableTool-Tool", + [ + content_types.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." 
+ ) + ] + ) + ], + ], + [ + "IterableTool-glm.Tool", + [ + glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + ) + ] + ) + ], + ], + [ + "IterableTool-ToolDict", + [ + dict( + function_declarations=[ + dict( + name="datetime", + description="Returns the current UTC date and time.", + ) + ] + ) + ], + ], + [ + "IterableTool-IterableFD", + [ + [ + glm.FunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + ) + ] + ], + ], + [ + "IterableTool-FD", + [ + glm.FunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + ) + ], + ], + [ + "Tool", + content_types.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." + ) + ] + ), + ], + [ + "glm.Tool", glm.Tool( function_declarations=[ glm.FunctionDeclaration( @@ -191,27 +287,118 @@ def test_img_to_contents(self, example): ), ], [ - "ListOfTools", + "IterableFD-FD", [ - glm.Tool( - function_declarations=[ - glm.FunctionDeclaration( - name="datetime", - description="Returns the current UTC date and time.", - ) - ] + content_types.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." + ) + ], + ], + [ + "IterableFD-CFD", + [ + content_types.CallableFunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + function=datetime, ) ], ], + [ + "IterableFD-dict", + [dict(name="datetime", description="Returns the current UTC date and time.")], + ], + ["IterableFD-Callable", [datetime]], + [ + "FD", + content_types.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." + ), + ], + [ + "CFD", + content_types.CallableFunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + function=datetime, + ), + ], + [ + "glm.FD", + glm.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." 
+ ), + ], + ["dict", dict(name="datetime", description="Returns the current UTC date and time.")], + ["Callable", datetime], ) def test_to_tools(self, tools): - tools = content_types.to_tools(tools) + function_library = content_types.to_function_library(tools) + if function_library is None: + raise ValueError("This shouldn't happen") + tools = function_library.to_proto() + + tools = type(tools[0]).to_dict(tools[0]) + tools["function_declarations"][0].pop("parameters", None) + expected = dict( function_declarations=[ dict(name="datetime", description="Returns the current UTC date and time.") ] ) - self.assertEqual(type(tools[0]).to_dict(tools[0]), expected) + + self.assertEqual(tools, expected) + + def test_two_fun_is_one_tool(self): + def a(): + pass + + def b(): + pass + + function_library = content_types.to_function_library([a, b]) + if function_library is None: + raise ValueError("This shouldn't happen") + tools = function_library.to_proto() + + self.assertLen(tools, 1) + self.assertLen(tools[0].function_declarations, 2) + + @parameterized.named_parameters( + ["int", int, glm.Schema(type=glm.Type.INTEGER)], + ["float", float, glm.Schema(type=glm.Type.NUMBER)], + ["str", str, glm.Schema(type=glm.Type.STRING)], + [ + "list", + list[str], + glm.Schema( + type=glm.Type.ARRAY, + items=glm.Schema(type=glm.Type.STRING), + ), + ], + [ + "list-list-int", + list[list[int]], + glm.Schema( + type=glm.Type.ARRAY, + items=glm.Schema( + glm.Schema( + type=glm.Type.ARRAY, + items=glm.Schema(type=glm.Type.INTEGER), + ), + ), + ), + ], + ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], + ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ) + def test_auto_schema(self, annotation, expected): + def fun(a: annotation): + pass + + cfd = content_types.FunctionDeclaration.from_function(fun) + got = cfd.parameters.properties["a"] + self.assertEqual(got, expected) if __name__ == "__main__": diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 78b25690d..4a63a8767 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -641,6 +641,11 @@ def test_count_tokens_smoke(self, contents): generative_models.ChatSession.send_message, generative_models.ChatSession.send_message_async, ], + [ + "ChatSession._handle_afc", + generative_models.ChatSession._handle_afc, + generative_models.ChatSession._handle_afc_async, + ], ) def test_async_code_match(self, obj, aobj): import inspect @@ -879,8 +884,9 @@ def test_repr_for_multi_turn_chat(self): ChatSession( model=genai.GenerativeModel( model_name='models/gemini-pro', - generation_config={}. - safety_settings={} + generation_config={}, + safety_settings={}, + tools=None, ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" @@ -905,8 +911,9 @@ def test_repr_for_incomplete_streaming_chat(self): ChatSession( model=genai.GenerativeModel( model_name='models/gemini-pro', - generation_config={}. 
- safety_settings={} + generation_config={}, + safety_settings={}, + tools=None, ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" @@ -947,8 +954,9 @@ def test_repr_for_broken_streaming_chat(self): ChatSession( model=genai.GenerativeModel( model_name='models/gemini-pro', - generation_config={}. - safety_settings={} + generation_config={}, + safety_settings={}, + tools=None, ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )"""
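
The sketches below illustrate the surface this patch adds; they are not part of the diff itself. First, the end-to-end flow: a minimal sketch assuming a model that supports function calling (the `multiply` function and the prompt are hypothetical examples, not from the patch):

```python
import google.generativeai as genai


def multiply(a: float, b: float) -> float:
    """Returns the product of a and b."""
    return a * b


# `tools` accepts plain Python functions; each is converted to a
# CallableFunctionDeclaration via FunctionDeclaration.from_function.
model = genai.GenerativeModel(model_name="gemini-pro", tools=[multiply])

# With enable_automatic_function_calling=True, ChatSession._handle_afc executes
# the model's function calls locally, appends the glm.FunctionResponse parts to
# the history, and loops until the model returns a plain text response.
chat = model.start_chat(enable_automatic_function_calling=True)
response = chat.send_message("What is 234551 times 325552?")
print(response.text)

# Note: stream=True combined with automatic function calling raises
# NotImplementedError, per the guard at the top of send_message.
```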
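The schema inference in `_generate_schema` can be exercised directly through `FunctionDeclaration.from_function`. A sketch of what the rules above imply (`find_theaters` is a hypothetical function; the exact proto output depends on the installed pydantic version):

```python
from google.generativeai.types import content_types


def find_theaters(location: str, max_results: int = 5):
    """Finds movie theaters near the given location."""


fd = content_types.FunctionDeclaration.from_function(find_theaters)

# The function's name and docstring become the declaration's name/description.
assert fd.name == "find_theaters"

# Annotations map to glm.Schema types (str -> STRING, int -> INTEGER), and
# `required` is inferred from the signature: `location` has no default, so it
# is required; `max_results` has a default, so it is not. (Default values
# themselves are not encoded; see the commented-out block in _generate_schema.)
print(fd.parameters)
```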
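`_rename_schema_fields` exists because `glm.Schema` uses proto field names that avoid Python builtins, so the JSON-schema keys pydantic emits are renamed recursively (`type` becomes an uppercased `type_`, `format` becomes `format_`). A sketch of the mapping, using the private helper directly:

```python
from google.generativeai.types import content_types

schema = {
    "type": "object",
    "properties": {"tags": {"type": "array", "items": {"type": "string"}}},
}
renamed = content_types._rename_schema_fields(schema)
# {'type_': 'OBJECT', 'properties': {'tags': {'type_': 'ARRAY',
#   'items': {'type_': 'STRING'}}}}
```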
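`FunctionLibrary.__call__` is the piece the AFC loop relies on: it looks up the declaration named in a `glm.FunctionCall`, executes it if it is callable, and wraps the result in a `glm.Part`. A sketch (the `power` function is hypothetical):

```python
import google.ai.generativelanguage as glm
from google.generativeai.types import content_types


def power(base: float, exponent: float) -> float:
    """Raises base to the given exponent."""
    return base**exponent


lib = content_types.to_function_library([power])

# Simulate the function call a model response would carry.
fc = glm.FunctionCall(name="power", args={"base": 2, "exponent": 10})
part = lib(fc)

# CallableFunctionDeclaration wraps non-dict return values as {"result": ...}.
print(part.function_response)  # name: "power", response: {"result": 1024.0}
```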
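One subtlety in `_make_tools`, matching `test_two_fun_is_one_tool` above: a list of single-declaration tools is flattened into one `glm.Tool`, on the assumption that related declarations belong in a single tool (`get_time` and `get_date` are hypothetical):

```python
from google.generativeai.types import content_types


def get_time():
    """Returns the current time."""


def get_date():
    """Returns the current date."""


lib = content_types.to_function_library([get_time, get_date])
tools = lib.to_proto()

assert len(tools) == 1  # the two single-declaration tools were merged...
assert len(tools[0].function_declarations) == 2  # ...into a single glm.Tool
```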
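Finally, `generate_content`, `send_message`, and their async variants gain a per-request `tools` argument, resolved through `_get_tools_lib`: passing `tools=None` falls back to the model-level library, and anything else overrides it for that call only. A sketch, reusing `model` and `power` from the examples above:

```python
# Model-level tools are used by default...
response = model.generate_content("What is 2 to the 10th power?")

# ...but a per-call override replaces them for this request only.
response = model.generate_content(
    "What is 2 to the 10th power?",
    tools=[power],
)
```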