diff --git a/.github/workflows/test_pr.yaml b/.github/workflows/test_pr.yaml index 94f9d2221..93838bf77 100644 --- a/.github/workflows/test_pr.yaml +++ b/.github/workflows/test_pr.yaml @@ -23,8 +23,8 @@ jobs: - name: Run tests run: | python --version - pip install -q -e .[dev] - python -m unittest discover --pattern '*test*.py' + pip install .[dev] + python -m unittest test3_10: name: Test Py3.10 runs-on: ubuntu-latest @@ -36,8 +36,8 @@ jobs: - name: Run tests run: | python --version - pip install -q -e .[dev] - python -m unittest discover --pattern '*test*.py' + pip install -q .[dev] + python -m unittest test3_9: name: Test Py3.9 runs-on: ubuntu-latest @@ -49,8 +49,8 @@ jobs: - name: Run tests run: | python --version - pip install -q -e .[dev] - python -m unittest discover --pattern '*test*.py' + pip install .[dev] + python -m unittest pytype3_10: name: pytype 3.10 runs-on: ubuntu-latest @@ -62,7 +62,7 @@ jobs: - name: Run pytype run: | python --version - pip install -q -e .[dev] + pip install .[dev] pip install -q gspread ipython pytype format: @@ -76,7 +76,7 @@ jobs: - name: Check format run: | python --version - pip install -q -e . + pip install -q . pip install -q black black . --check diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 74026804a..dda76a2a1 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -42,9 +42,10 @@ Use the `palm.chat` function to have a discussion with a model: ``` -response = palm.chat(messages=["Hello."]) -print(response.last) # 'Hello! What can I help you with?' -response.reply("Can you tell me a joke?") +chat = palm.chat(messages=["Hello."]) +print(chat.last) # 'Hello! What can I help you with?' +chat = chat.reply("Can you tell me a joke?") +print(chat.last) # 'Why did the chicken cross the road?' 
``` ## Models @@ -68,13 +69,20 @@ """ from __future__ import annotations -from google.generativeai import types from google.generativeai import version +from google.generativeai import types +from google.generativeai.types import GenerationConfig + + from google.generativeai.discuss import chat from google.generativeai.discuss import chat_async from google.generativeai.discuss import count_message_tokens +from google.generativeai.embedding import embed_content + +from google.generativeai.generative_models import GenerativeModel + from google.generativeai.text import generate_text from google.generativeai.text import generate_embeddings from google.generativeai.text import count_text_tokens diff --git a/google/generativeai/client.py b/google/generativeai/client.py index dead136c9..2a8b15a20 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -1,17 +1,3 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
from __future__ import annotations import os @@ -27,7 +13,12 @@ from google.api_core import gapic_v1 from google.api_core import operations_v1 -from google.generativeai import version +try: + from google.generativeai import version + + __version__ = version.__version__ +except ImportError: + __version__ = "0.0.0" USER_AGENT = "genai-py" @@ -36,11 +27,10 @@ class _ClientManager: client_config: dict[str, Any] = dataclasses.field(default_factory=dict) default_metadata: Sequence[tuple[str, str]] = () + discuss_client: glm.DiscussServiceClient | None = None discuss_async_client: glm.DiscussServiceAsyncClient | None = None - model_client: glm.ModelServiceClient | None = None - text_client: glm.TextServiceClient | None = None - operations_client = None + clients: dict[str, Any] = dataclasses.field(default_factory=dict) def configure( self, @@ -54,7 +44,7 @@ def configure( # We could accept a dict since all the `Transport` classes take the same args, # but that seems rare. Users that need it can just switch to the low level API. transport: str | None = None, - client_options: client_options_lib.ClientOptions | dict | None = None, + client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None, client_info: gapic_v1.client_info.ClientInfo | None = None, default_metadata: Sequence[tuple[str, str]] = (), ) -> None: @@ -93,7 +83,7 @@ def configure( client_options.api_key = api_key - user_agent = f"{USER_AGENT}/{version.__version__}" + user_agent = f"{USER_AGENT}/{__version__}" if client_info: # Be respectful of any existing agent setting. 
def get_default_client(self, name):
    """Return the client registered under `name`, creating it on first use.

    The lookup is case-insensitive. The name "operations" is delegated to
    `get_default_operations_client`; every other name is built by
    `make_client` and stored in `self.clients` so later calls reuse it.
    """
    key = name.lower()
    if key == "operations":
        return self.get_default_operations_client()

    cached = self.clients.get(key)
    if cached is not None:
        return cached

    created = self.make_client(key)
    self.clients[key] = created
    return created
def get_default_model_client() -> glm.ModelServiceClient:
    """Return the shared synchronous model-service client.

    The client is looked up under the name "model" on the module-level
    `_client_manager`, which builds it on first use and caches it.
    """
    return _client_manager.get_default_client("model")
_make_generate_message_request( ) -def set_doc(doc): - """A decorator to set the docstring of a function.""" - - def inner(f): - f.__doc__ = doc - return f - - return inner - - DEFAULT_DISCUSS_MODEL = "models/chat-bison-001" @@ -411,7 +401,7 @@ def chat( return _generate_response(client=client, request=request) -@set_doc(chat.__doc__) +@string_utils.set_doc(chat.__doc__) async def chat_async( *, model: model_types.AnyModelNameOptions | None = "models/chat-bison-001", @@ -447,7 +437,7 @@ async def chat_async( @string_utils.prettyprint -@set_doc(discuss_types.ChatResponse.__doc__) +@string_utils.set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) @@ -457,7 +447,7 @@ def __init__(self, **kwargs): setattr(self, key, value) @property - @set_doc(discuss_types.ChatResponse.last.__doc__) + @string_utils.set_doc(discuss_types.ChatResponse.last.__doc__) def last(self) -> str | None: if self.messages[-1]: return self.messages[-1]["content"] @@ -470,7 +460,7 @@ def last(self, message: discuss_types.MessageOptions): message = type(message).to_dict(message) self.messages[-1] = message - @set_doc(discuss_types.ChatResponse.reply.__doc__) + @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceAsyncClient): raise TypeError(f"reply can't be called on an async client, use reply_async instead.") @@ -489,7 +479,7 @@ def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResp request = _make_generate_message_request(**request) return _generate_response(request=request, client=self._client) - @set_doc(discuss_types.ChatResponse.reply.__doc__) + @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) async def reply_async( self, 
message: discuss_types.MessageOptions ) -> discuss_types.ChatResponse: diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py new file mode 100644 index 000000000..42edb7fcf --- /dev/null +++ b/google/generativeai/embedding.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +from collections.abc import Iterable, Sequence, Mapping +import itertools +from typing import Iterable, overload, TypeVar, Union, Mapping + +import google.ai.generativelanguage as glm + +from google.generativeai.client import get_default_generative_client + +from google.generativeai.types import text_types +from google.generativeai.types import model_types +from google.generativeai.types import content_types + +DEFAULT_EMB_MODEL = "models/embedding-001" +EMBEDDING_MAX_BATCH_SIZE = 100 + +EmbeddingTaskType = glm.TaskType + +EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] + +_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = { + EmbeddingTaskType.TASK_TYPE_UNSPECIFIED: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + 0: EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + "task_type_unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + "unspecified": EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, + EmbeddingTaskType.RETRIEVAL_QUERY: EmbeddingTaskType.RETRIEVAL_QUERY, + 1: EmbeddingTaskType.RETRIEVAL_QUERY, + "retrieval_query": 
try:
    # Prefer the standard-library implementation when it exists (Python 3.12+).
    _batched = itertools.batched  # type: ignore
except AttributeError:
    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
        """Yield consecutive lists of at most `n` items taken from `iterable`.

        Raises:
            ValueError: If `n` is smaller than 1 (raised when iteration starts).
        """
        if n < 1:
            raise ValueError(f"Batch size `n` must be >0, got: {n}")
        source = iter(iterable)
        while True:
            chunk = list(itertools.islice(source, n))
            if not chunk:
                return
            yield chunk
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType | Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    client: glm.GenerativeServiceClient | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create embeddings for content passed in.

    Args:
        model: Which model to call, as a string or a `types.Model`.

        content: Content to embed. A `str` or `Mapping` is treated as a single piece of
            content; any other iterable is embedded item by item, in batches of at most
            `EMBEDDING_MAX_BATCH_SIZE` requests.

        task_type: Optional task type for which the embeddings will be used. Can only be
            set for `models/embedding-001`.

        title: An optional title for the text. Only applicable when task_type is
            `RETRIEVAL_DOCUMENT`.

        client: The client to send requests with. Defaults to the result of
            `get_default_generative_client()`.

    Returns:
        For a single piece of content, the response as a dictionary whose "embedding"
        entry is the list of embedding values. For an iterable of contents, a dictionary
        whose "embedding" entry holds one list of values per item.

    Raises:
        ValueError: If `title` is given but `task_type` is not a retrieval document task.
    """
    model = model_types.make_model_name(model)

    if client is None:
        client = get_default_generative_client()

    # A title with no task type at all is rejected with the same ValueError, rather
    # than failing inside `to_task_type(None)` with a KeyError.
    if title and (
        task_type is None
        or to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT
    ):
        raise ValueError(
            "If a title is specified, the task must be a retrieval document type task."
        )

    if task_type:
        task_type = to_task_type(task_type)

    if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
        result = {"embedding": []}
        # The requests are produced lazily and grouped into batches below.
        requests = (
            glm.EmbedContentRequest(
                model=model, content=content_types.to_content(c), task_type=task_type, title=title
            )
            for c in content
        )
        for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
            embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch)
            embedding_response = client.batch_embed_contents(embedding_request)
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
        return result
    else:
        embedding_request = glm.EmbedContentRequest(
            model=model, content=content_types.to_content(content), task_type=task_type, title=title
        )
        embedding_response = client.embed_content(embedding_request)
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        # Flatten {"embedding": {"values": [...]}} to {"embedding": [...]}.
        embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
        return embedding_dict
model. + +This `GenerativeModel.generate_content` method can handle multimodal input, and multiturn +conversations. + +>>> model = genai.GenerativeModel('models/gemini-pro') +>>> result = model.generate_content('Tell me a story about a magic backpack') +>>> response.text + +### Streaming + +This method supports streaming with the `stream=True`. The result has the same type as the non streaming case, +but you can iterate over the response chunks as they become available: + +>>> result = model.generate_content('Tell me a story about a magic backpack', stream=True) +>>> for chunk in response: +... print(chunk.text) + +### Multi-turn + +This method supports multi-turn chats but is **stateless**: the entire conversation history needs to be sent with each +request. This takes some manual management but gives you complete control: + +>>> messages = [{'role':'user', 'parts': ['hello']}] +>>> response = model.generate_content(messages) # "Hello, how can I help" +>>> messages.append(response.candidates[0].content) +>>> messages.append({'role':'user', 'parts': ['How does quantum physics work?']}) +>>> response = model.generate_content(messages) + +For a simpler multi-turn interface see `GenerativeModel.start_chat`. + +### Input type flexibility + +While the underlying API strictly expects a `list[glm.Content]` objects, this method +will convert the user input into the correct type. The hierarchy of types that can be +converted is below. Any of these objects can be passed as an equivalent `dict`. + +* `Iterable[glm.Content]` +* `glm.Content` +* `Iterable[glm.Part]` +* `glm.Part` +* `str`, `Image`, or `glm.Blob` + +In an `Iterable[glm.Content]` each `content` is a separate message. +But note that an `Iterable[glm.Part]` is taken as the parts of a single message. + +Arguments: + contents: The contents serving as the model's prompt. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. 
+ stream: If True, yield response chunks as they are generated. +""" + +_SEND_MESSAGE_ASYNC_DOC = """The async version of `ChatSession.send_message`.""" + +_SEND_MESSAGE_DOC = """Sends the conversation history with the added message and returns the model's response. + +Appends the request and response to the conversation history. + +>>> model = genai.GenerativeModel(model="gemini-pro") +>>> chat = model.start_chat() +>>> response = chat.send_message("Hello") +>>> print(response.text) +"Hello! How can I assist you today?" +>>> len(chat.history) +2 + +Call it with `stream=True` to receive response chunks as they are generated: + +>>> chat = model.start_chat() +>>> response = chat.send_message("Explain quantum physics", stream=True) +>>> for chunk in response: +... print(chunk.text, end='') + +Once iteration over chunks is complete, the `response` and `ChatSession` are in states identical to the +`stream=False` case. Some properties are not available until iteration is complete. + +Like `GenerativeModel.generate_content` this method lets you override the model's `generation_config` and +`safety_settings`. + +Arguments: + content: The message contents. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. + stream: If True, yield response chunks as they are generated. +""" + + +class GenerativeModel: + """ + The `genai.GenerativeModel` class wraps default parameters for calls to + `GenerativeModel.generate_message`, `GenerativeModel.count_tokens`, and + `GenerativeModel.start_chat`. + + This family of functionality is designed to support multi-turn conversations, and multimodal + requests. What media-types are supported for input and output is model-dependant. 
+ + >>> import google.generativeai as genai + >>> import PIL.Image + >>> genai.configure(api_key='YOUR_API_KEY') + >>> model = genai.GenerativeModel('models/gemini-pro') + >>> result = model.generate_content('Tell me a story about a magic backpack') + >>> response.text + "In the quaint little town of Lakeside, there lived a young girl named Lily..." + + Multimodal input: + + >>> model = genai.GenerativeModel('models/gemini-pro') + >>> result = model.generate_content([ + ... "Give me a recipe for these:", PIL.Image.open('scones.jpeg')]) + >>> response.text + "**Blueberry Scones** ..." + + Multi-turn conversation: + + >>> chat = model.start_chat() + >>> response = chat.send_message("Hi, I have some questions for you.") + >>> response.text + "Sure, I'll do my best to answer your questions..." + + To list the compatible model names use: + + >>> for m in genai.list_models(): + ... if 'generateContent' in m.supported_generation_methods: + ... print(m.name) + + Arguments: + model_name: The name of the model to query. To list compatible models use + safety_settings: Sets the default safety filters. This controls which content is blocked + by the api before being returned. + generation_config: A `genai.GenerationConfig` setting the default generation parameters to + use. 
def __str__(self):
    """Return a constructor-style, multi-line summary of the model's settings."""
    # The fields are comma-separated, like the arguments of a constructor call.
    return textwrap.dedent(
        f""" \
        genai.GenerativeModel(
            model_name='{self.model_name}',
            generation_config={self._generation_config},
            safety_settings={self._safety_settings}
        )"""
    )
# fmt: off
def count_tokens(
    self, contents: content_types.ContentsType
) -> glm.CountTokensResponse:
    """Count the tokens in `contents` for the model named by `model_name`.

    Arguments:
        contents: The contents to count; converted with `content_types.to_contents`.

    Returns:
        The `glm.CountTokensResponse` returned by the service.
    """
    # Create the client on first use, the same way `generate_content` does.
    if self._client is None:
        self._client = client.get_default_generative_client()
    contents = content_types.to_contents(contents)
    # Send an explicit request message instead of positional `model`/`contents`.
    return self._client.count_tokens(
        glm.CountTokensRequest(model=self.model_name, contents=contents)
    )

async def count_tokens_async(
    self, contents: content_types.ContentsType
) -> glm.CountTokensResponse:
    """The async version of `count_tokens`; uses the async client."""
    # Create the async client on first use, the same way `generate_content_async` does.
    if self._async_client is None:
        self._async_client = client.get_default_generative_async_client()
    contents = content_types.to_contents(contents)
    return await self._async_client.count_tokens(
        glm.CountTokensRequest(model=self.model_name, contents=contents)
    )
# fmt: on
@string_utils.set_doc(_SEND_MESSAGE_DOC)
def send_message(
    self,
    content: content_types.ContentType,
    *,
    generation_config: generation_types.GenerationConfigType = None,
    safety_settings: safety_types.SafetySettingOptions = None,
    stream: bool = False,
    **kwargs,
) -> generation_types.GenerateContentResponse:
    # The docstring is attached by the `set_doc` decorator above.
    message = content_types.to_content(content)
    if not message.role:
        message.role = self._USER_ROLE

    # Work on a copy of the stored history so the new message is only recorded
    # on the session once the request has succeeded.
    conversation = [*self.history, message]

    generation_config = generation_types.to_generation_config_dict(generation_config)
    if generation_config.get("candidate_count", 1) > 1:
        raise ValueError("Can't chat with `candidate_count > 1`")

    response = self.model.generate_content(
        contents=conversation,
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=stream,
        **kwargs,
    )

    if response.prompt_feedback.block_reason:
        raise generation_types.BlockedPromptException(response.prompt_feedback)

    if not stream:
        # Only these finish reasons leave a candidate that can be kept in the chat.
        reasons = glm.Candidate.FinishReason
        acceptable = (
            reasons.FINISH_REASON_UNSPECIFIED,
            reasons.STOP,
            reasons.MAX_TOKENS,
        )
        first_candidate = response.candidates[0]
        if first_candidate.finish_reason not in acceptable:
            raise generation_types.StopCandidateException(first_candidate)

    # Kept pending here; the `history` property merges them into the history.
    self._last_sent = message
    self._last_received = response

    return response
safety_types.SafetySettingOptions = None, + stream: bool = False, + **kwargs, + ) -> generation_types.AsyncGenerateContentResponse: + content = content_types.to_content(content) + if not content.role: + content.role = self._USER_ROLE + history = self.history[:] + history.append(content) + + generation_config = generation_types.to_generation_config_dict(generation_config) + if generation_config.get("candidate_count", 1) > 1: + raise ValueError("Can't chat with `candidate_count > 1`") + response = await self.model.generate_content_async( + contents=history, + generation_config=generation_config, + safety_settings=safety_settings, + stream=stream, + **kwargs, + ) + + if response.prompt_feedback.block_reason: + raise generation_types.BlockedPromptException(response.prompt_feedback) + + if not stream: + if response.candidates[0].finish_reason not in ( + glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + glm.Candidate.FinishReason.STOP, + glm.Candidate.FinishReason.MAX_TOKENS, + ): + raise generation_types.StopCandidateException(response.candidates[0]) + + self._last_sent = content + self._last_received = response + + return response + + def __copy__(self): + return ChatSession( + model=self.model, + # Be sure the copy doesn't share the history. 
+ history=list(self.history), + ) + + def rewind(self) -> tuple[glm.Content, glm.Content]: + """Removes the last request/response pair from the chat history.""" + if self._last_received is None: + result = self._history.pop(-2), self._history.pop() + return result + else: + result = self._last_sent, self._last_received.candidates[0].content + self._last_sent = None + self._last_received = None + return result + + @property + def last(self) -> generation_types.BaseGenerateContentResponse | None: + """returns the last received `genai.GenerateContentResponse`""" + return self._last_received + + @property + def history(self) -> list[glm.Content]: + """The chat history.""" + last = self._last_received + if last is None: + return self._history + + if last.candidates[0].finish_reason not in ( + glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + glm.Candidate.FinishReason.STOP, + glm.Candidate.FinishReason.MAX_TOKENS, + ): + error = generation_types.StopCandidateException(last.candidates[0]) + last._error = error + + if last._error is not None: + raise generation_types.BrokenResponseError( + "Can not build a coherent char history after a broken " + "streaming response " + "(See the previous Exception fro details). " + "To inspect the last response object, use `chat.last`." + "To remove the last request/response `Content` objects from the chat " + "call `last_send, last_received = chat.rewind()` and continue " + "without it." 
+ ) from last._error + + sent = self._last_sent + received = self._last_received.candidates[0].content + if not received.role: + received.role = self._MODEL_ROLE + self._history.extend([sent, received]) + + self._last_sent = None + self._last_received = None + + return self._history + + @history.setter + def history(self, history): + self._history = content_types.to_contents(history) + self._last_sent = None + self._last_received = None diff --git a/google/generativeai/string_utils.py b/google/generativeai/string_utils.py index 6cda23635..049bb885b 100644 --- a/google/generativeai/string_utils.py +++ b/google/generativeai/string_utils.py @@ -21,6 +21,16 @@ import textwrap +def set_doc(doc): + """A decorator to set the docstring of a function.""" + + def inner(f): + f.__doc__ = doc + return f + + return inner + + def strip_oneof(docstring): lines = docstring.splitlines() lines = [line for line in lines if ".. _oneof:" not in line] diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 18f7b6268..98d49964d 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -108,7 +108,9 @@ def _make_generate_text_request( """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) - safety_settings = safety_types.normalize_safety_settings(safety_settings) + safety_settings = safety_types.normalize_safety_settings( + safety_settings, harm_category_set="old" + ) if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] if stop_sequences: diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index 0bdf3a713..34720a465 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -18,6 +18,8 @@ from google.generativeai.types.model_types import * from google.generativeai.types.text_types import * from google.generativeai.types.citation_types import * +from google.generativeai.types.content_types import * +from
google.generativeai.types.generation_types import * from google.generativeai.types.safety_types import * del discuss_types diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py new file mode 100644 index 000000000..b7c96e310 --- /dev/null +++ b/google/generativeai/types/content_types.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +import io +import mimetypes +import pathlib +import typing +from typing import Any, TypedDict, Union + +from google.ai import generativelanguage as glm + +if typing.TYPE_CHECKING: + import PIL.Image + import IPython.display + + IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) +else: + IMAGE_TYPES = () + try: + import PIL.Image + + IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) + except ImportError: + PIL = None + + try: + import IPython.display + + IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,) + except ImportError: + IPython = None + + +__all__ = [ + "BlobDict", + "BlobType", + "PartDict", + "PartType", + "ContentDict", + "ContentType", + "StrictContentType", + "ContentsType", +] + + +def pil_to_png_bytes(img): + bytesio = io.BytesIO() + img.save(bytesio, format="PNG") + bytesio.seek(0) + return bytesio.read() + + +def image_to_blob(image) -> glm.Blob: + if PIL is not None: + if isinstance(image, PIL.Image.Image): + return glm.Blob(mime_type="image/png", data=pil_to_png_bytes(image)) + + if IPython is not None: + if isinstance(image, IPython.display.Image): + name = image.filename + if name is None: + raise ValueError( + "Can only convert `IPython.display.Image` if " + "it is constructed from a local file (Image(filename=...))." + ) + + mime_type, _ = mimetypes.guess_type(name) + if mime_type is None: + mime_type = "image/unknown" + + return glm.Blob(mime_type=mime_type, data=image.data) + + raise TypeError( + "Could not convert image. 
epected an `Image` type" + "(`PIL.Image.Image` or `IPython.display.Image`).\n" + f"Got a: {type(image)}\n" + f"Value: {image}" + ) + + +class BlobDict(TypedDict): + mime_type: str + data: bytes + + +def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: + if is_content_dict(d): + content = dict(d) + content["parts"] = [to_part(part) for part in content["parts"]] + return glm.Content(content) + elif is_part_dict(d): + part = dict(d) + if "inline_data" in part: + part["inline_data"] = to_blob(part["inline_data"]) + return glm.Part(part) + elif is_blob_dict(d): + blob = d + return glm.Blob(blob) + else: + raise KeyError( + "Could not recognize the intended type of the `dict`. " + "A `Content` should have a 'parts' key. " + "A `Part` should have a 'inline_data' or a 'text' key. " + "A `Blob` should have 'mime_type' and 'data' keys. " + f"Got keys: {list(d.keys())}" + ) + + +def is_blob_dict(d): + return "mime_type" in d and "data" in d + + +if typing.TYPE_CHECKING: + BlobType = Union[ + glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image + ] # Any for the images +else: + BlobType = Union[glm.Blob, BlobDict, Any] + + +def to_blob(blob: BlobType) -> glm.Blob: + if isinstance(blob, Mapping): + blob = _convert_dict(blob) + + if isinstance(blob, glm.Blob): + return blob + elif isinstance(blob, IMAGE_TYPES): + return image_to_blob(blob) + else: + if isinstance(blob, Mapping): + raise KeyError( + "Could not recognize the intended type of the `dict`\n" "A content should have " + ) + raise TypeError( + "Could not create `Blob`, expected `Blob`, `dict` or an `Image` type" + "(`PIL.Image.Image` or `IPython.display.Image`).\n" + f"Got a: {type(blob)}\n" + f"Value: {blob}" + ) + + +class PartDict(TypedDict): + text: str + inline_data: BlobType + + +# When you need a `Part` accept a part object, part-dict, blob or string +PartType = Union[glm.Part, PartDict, BlobType, str] + + +def is_part_dict(d): + return "text" in d or "inline_data" in d + + +def to_part(part: 
PartType): + if isinstance(part, Mapping): + part = _convert_dict(part) + + if isinstance(part, glm.Part): + return part + elif isinstance(part, str): + return glm.Part(text=part) + else: + # Maybe it can be turned into a blob? + return glm.Part(inline_data=to_blob(part)) + + +class ContentDict(TypedDict): + parts: list[PartType] + role: str + + +def is_content_dict(d): + return "parts" in d + + +# When you need a message accept a `Content` object or dict, a list of parts, +# or a single part +ContentType = Union[glm.Content, ContentDict, Iterable[PartType], PartType] + +# For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet. +StrictContentType = Union[glm.Content, ContentDict] + + +def to_content(content: ContentType): + if isinstance(content, Mapping): + content = _convert_dict(content) + + if isinstance(content, glm.Content): + return content + elif isinstance(content, Iterable) and not isinstance(content, str): + return glm.Content(parts=[to_part(part) for part in content]) + else: + # Maybe this is a Part? + return glm.Content(parts=[to_part(content)]) + + +def strict_to_content(content: StrictContentType): + if isinstance(content, Mapping): + content = _convert_dict(content) + + if isinstance(content, glm.Content): + return content + else: + raise TypeError( + "Expected a `glm.Content` or a `dict(parts=...)`.\n" + f"Got type: {type(content)}\n" + f"Value: {content}\n" + ) + + +ContentsType = Union[ContentType, Iterable[StrictContentType], None] + + +def to_contents(contents: ContentsType) -> list[glm.Content]: + if contents is None: + return [] + + if isinstance(contents, Iterable) and not isinstance(contents, (str, Mapping)): + try: + # strict_to_content so [[parts], [parts]] doesn't assume roles. + contents = [strict_to_content(c) for c in contents] + return contents + except TypeError: + # If you get a TypeError here it's probably because that was a list + # of parts, not a list of contents, so fall back to `to_content`. 
+ pass + + contents = [to_content(contents)] + return contents diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py new file mode 100644 index 000000000..3dab8e457 --- /dev/null +++ b/google/generativeai/types/generation_types.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import collections +import contextlib +import sys +from collections.abc import Iterable, AsyncIterable +import dataclasses +import itertools +import textwrap +from typing import TypedDict, Union + +import google.protobuf.json_format +import google.api_core.exceptions + +from google.ai import generativelanguage as glm +from google.generativeai import string_utils + +__all__ = [ + "AsyncGenerateContentResponse", + "BlockedPromptException", + "StopCandidateException", + "IncompleteIterationError", + "BrokenResponseError", + "GenerationConfigDict", + "GenerationConfigType", + "GenerationConfig", + "GenerateContentResponse", +] + + +class BlockedPromptException(Exception): + pass + + +class StopCandidateException(Exception): + pass + + +class IncompleteIterationError(Exception): + pass + + +class BrokenResponseError(Exception): + pass + + +class GenerationConfigDict(TypedDict): + # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/ + candidate_count: int + stop_sequences: Iterable[str] + max_output_tokens: int + temperature: float + + +@dataclasses.dataclass +class GenerationConfig: + candidate_count: int | None = None + stop_sequences: Iterable[str] | None = None + max_output_tokens: int | None = None + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + + +GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] + + +def to_generation_config_dict(generation_config: GenerationConfigType): + if generation_config is None: + return {} + elif isinstance(generation_config, glm.GenerationConfig): + return 
type(generation_config).to_dict(generation_config) # pytype: disable=attribute-error + elif isinstance(generation_config, GenerationConfig): + generation_config = dataclasses.asdict(generation_config) + return {key: value for key, value in generation_config.items() if value is not None} + elif hasattr(generation_config, "keys"): + return dict(generation_config) + else: + raise TypeError( + "Did not understand `generation_config`, expected a `dict` or" + f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:" + f" {generation_config}" + ) + + +def _join_citation_metadatas( + citation_metadatas: Iterable[glm.CitationMetadata], +): + citation_metadatas = list(citation_metadatas) + return citation_metadatas[-1] + + +def _join_safety_ratings_lists( + safety_ratings_lists: Iterable[list[glm.SafetyRating]], +): + ratings = {} + blocked = collections.defaultdict(list) + + for safety_ratings_list in safety_ratings_lists: + for rating in safety_ratings_list: + ratings[rating.category] = rating.probability + blocked[rating.category].append(rating.blocked) + + blocked = {category: any(blocked) for category, blocked in blocked.items()} + + safety_list = [] + for (category, probability), blocked in zip(ratings.items(), blocked.values()): + safety_list.append( + glm.SafetyRating(category=category, probability=probability, blocked=blocked) + ) + + return safety_list + + +def _join_contents(contents: Iterable[glm.Content]): + contents = tuple(contents) + roles = [c.role for c in contents if c.role] + if roles: + role = roles[0] + else: + role = "" + + parts = [] + for content in contents: + parts.extend(content.parts) + + merged_parts = [parts.pop(0)] + for part in parts: + if not merged_parts[-1].text: + merged_parts.append(part) + continue + + if not part.text: + merged_parts.append(part) + continue + + merged_part = glm.Part(merged_parts[-1]) + merged_part.text += part.text + merged_parts[-1] = merged_part + + return glm.Content( + role=role, + parts=merged_parts, + 
) + + +def _join_candidates(candidates: Iterable[glm.Candidate]): + candidates = tuple(candidates) + + index = candidates[0].index # These should all be the same. + + return glm.Candidate( + index=index, + content=_join_contents([c.content for c in candidates]), + finish_reason=candidates[-1].finish_reason, + safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]), + citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]), + ) + + +def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): + # Assuming that is a candidate ends, it is no longer returned in the list of + # candidates and that's why candidates have an index + candidates = collections.defaultdict(list) + for candidate_list in candidate_lists: + for candidate in candidate_list: + candidates[candidate.index].append(candidate) + + new_candidates = [] + for index, candidate_parts in sorted(candidates.items()): + new_candidates.append(_join_candidates(candidate_parts)) + + return new_candidates + + +def _join_prompt_feedbacks( + prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], +): + # Always return the first prompt feedback. 
+ return next(iter(prompt_feedbacks)) + + +def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): + return glm.GenerateContentResponse( + candidates=_join_candidate_lists(c.candidates for c in chunks), + prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), + ) + + +_INCOMPLETE_ITERATION_MESSAGE = """\ +Please let the response complete iteration before accessing the final accumulated +attributes (or call `response.resolve()`)""" + + +class BaseGenerateContentResponse: + def __init__( + self, + done: bool, + iterator: None + | Iterable[glm.GenerateContentResponse] + | AsyncIterable[glm.GenerateContentResponse], + result: glm.GenerateContentResponse, + chunks: Iterable[glm.GenerateContentResponse], + ): + self._done = done + self._iterator = iterator + self._result = result + self._chunks = list(chunks) + if result.prompt_feedback.block_reason: + self._error = BlockedPromptException(result) + else: + self._error = None + + @property + def candidates(self): + """The list of candidate responses. + + Raises: + IncompleteIterationError: With `stream=True` if iteration over the stream was not completed. + """ + if not self._done: + raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE) + return self._result.candidates + + @property + def parts(self): + """A quick accessor equivalent to `self.candidates[0].parts` + + Raises: + ValueError: If the candidate list does not contain exactly one candidate. + """ + candidates = self.candidates + if not candidates: + raise ValueError( + "The `response.parts` quick accessor only works for a single candidate, " + "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked." + ) + if len(candidates) > 1: + raise ValueError( + "The `response.parts` quick accessor only works with a " + "single candidate. 
With multiple candidates use " + "result.candidates[index].text" + ) + parts = candidates[0].content.parts + return parts + + @property + def text(self): + """A quick accessor equivalent to `self.candidates[0].parts[0].text` + + Raises: + ValueError: If the candidate list or parts list does not contain exactly one entry. + """ + parts = self.parts + if len(parts) > 1 or "text" not in parts[0]: + raise ValueError( + "The `response.text` quick accessor only works for " + "simple (single-`Part`) text responses. This response " + "contains multiple `Parts`. Use the `result.parts` " + "accessor or the full " + "`result.candidates[index].content.parts` lookup " + "instead" + ) + return parts[0].text + + @property + def prompt_feedback(self): + return self._result.prompt_feedback + + +@contextlib.contextmanager +def rewrite_stream_error(): + try: + yield + except (google.protobuf.json_format.ParseError, AttributeError) as e: + raise google.api_core.exceptions.BadRequest( + "Unknown error trying to retrieve streaming response. " + "Please retry with `stream=False` for more details." + ) + + +GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content_async` method. + + These are returned by `GenerativeModel.generate_content_async` and `ChatSession.send_message_async`. + This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` + and `candidates` attributes. This class adds several quick accessors for common use cases. + + The same object type is returned for both `stream=True/False`. 
+ + ### Streaming + + When you pass `stream=True` to `GenerativeModel.generate_content_async` or `ChatSession.send_message_async`, + iterate over this object to receive chunks of the response: + + ``` + response = model.generate_content_async(..., stream=True): + async for chunk in response: + print(chunk.text) + ``` + + `AsyncGenerateContentResponse.prompt_feedback` is available immediately but + `AsyncGenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`), + are only available after the iteration is complete. + """ + +ASYNC_GENERATE_CONTENT_RESPONSE_DOC = ( + """This is the async version of `genai.GenerateContentResponse`.""" +) + + +@string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) +class GenerateContentResponse(BaseGenerateContentResponse): + @classmethod + def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): + iterator = iter(iterator) + with rewrite_stream_error(): + response = next(iterator) + + return cls( + done=False, + iterator=iterator, + result=response, + chunks=[response], + ) + + @classmethod + def from_response(cls, response: glm.GenerateContentResponse): + return cls( + done=True, + iterator=None, + result=response, + chunks=[response], + ) + + def __iter__(self): + # This is not thread safe. + if self._done: + for chunk in self._chunks: + yield GenerateContentResponse.from_response(chunk) + return + + # Always have the next chunk available. + if len(self._chunks) == 0: + self._chunks.append(next(self._iterator)) + + for n in itertools.count(): + if self._error: + raise self._error + + if n >= len(self._chunks) - 1: + # Look ahead for a new item, so that you know the stream is done + # when you yield the last item. 
+ if self._done: + return + + try: + item = next(self._iterator) + except StopIteration: + self._done = True + except Exception as e: + self._error = e + self._done = True + else: + self._chunks.append(item) + self._result = _join_chunks([self._result, item]) + + item = self._chunks[n] + + item = GenerateContentResponse.from_response(item) + yield item + + def resolve(self): + if self._done: + return + + for _ in self: + pass + + +@string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) +class AsyncGenerateContentResponse(BaseGenerateContentResponse): + @classmethod + async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): + iterator = aiter(iterator) # type: ignore + with rewrite_stream_error(): + response = await anext(iterator) # type: ignore + + return cls( + done=False, + iterator=iterator, + result=response, + chunks=[response], + ) + + @classmethod + def from_response(cls, response: glm.GenerateContentResponse): + return cls( + done=True, + iterator=None, + result=response, + chunks=[response], + ) + + async def __aiter__(self): + # This is not thread safe. + if self._done: + for chunk in self._chunks: + yield GenerateContentResponse.from_response(chunk) + return + + # Always have the next chunk available. + if len(self._chunks) == 0: + self._chunks.append(await anext(self._iterator)) # type: ignore + + for n in itertools.count(): + if self._error: + raise self._error + + if n >= len(self._chunks) - 1: + # Look ahead for a new item, so that you know the stream is done + # when you yield the last item. 
+ if self._done: + return + + try: + item = await anext(self._iterator) # type: ignore + except StopAsyncIteration: + self._done = True + except Exception as e: + self._error = e + self._done = True + else: + self._chunks.append(item) + self._result = _join_chunks([self._result, item]) + + item = self._chunks[n] + + item = GenerateContentResponse.from_response(item) + yield item + + async def resolve(self): + if self._done: + return + + async for _ in self: + pass diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index bedd65317..893206526 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -1,17 +1,3 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
from __future__ import annotations from collections.abc import Mapping @@ -42,7 +28,7 @@ HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = { +_OLD_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = { HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED, 0: HarmCategory.HARM_CATEGORY_UNSPECIFIED, "harm_category_unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, @@ -83,13 +69,55 @@ "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, "danger": HarmCategory.HARM_CATEGORY_DANGEROUS, } + +_NEW_HARM_CATEGORIES = { + 7: HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": 
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, +} # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> HarmCategory: +def to_old_harm_category(x: HarmCategoryOptions) -> HarmCategory: + if isinstance(x, str): + x = x.lower() + return _OLD_HARM_CATEGORIES[x] + + +def to_new_harm_category(x: HarmCategoryOptions) -> HarmCategory: if isinstance(x, str): x = x.lower() - return _HARM_CATEGORIES[x] + return _NEW_HARM_CATEGORIES[x] + + +def to_harm_category(x, harm_category_set): + if harm_category_set == "old": + return to_old_harm_category(x) + elif harm_category_set == "new": + return to_new_harm_category(x) + else: + raise ValueError("harm_category_set must be 'new' or 'old'") HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] @@ -184,30 +212,48 @@ class LooseSafetySettingDict(TypedDict): EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] +EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] +def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict: + if settings is None: + return {} + elif isinstance(settings, Mapping): + return { + to_harm_category(key, harm_category_set): to_block_threshold(value) + for key, value in settings.items() + } + else: # Iterable + return { + to_harm_category(d["category"], harm_category_set): to_block_threshold(d["threshold"]) + for d in settings + } + + def normalize_safety_settings( settings: SafetySettingOptions, + harm_category_set, ) -> list[SafetySettingDict] | None: if settings is None: return None if isinstance(settings, Mapping): return [ { - "category": to_harm_category(key), + "category": to_harm_category(key, harm_category_set), "threshold": to_block_threshold(value), } for key, value in settings.items() ] - return [ - { - "category": to_harm_category(d["category"]), - "threshold": to_block_threshold(d["threshold"]), - } - for d in 
settings - ] + else: + return [ + { + "category": to_harm_category(d["category"], harm_category_set), + "threshold": to_block_threshold(d["threshold"]), + } + for d in settings + ] def convert_setting_to_enum(setting: dict) -> SafetySettingDict: diff --git a/google/generativeai/version.py b/google/generativeai/version.py index 19a571db2..1ea5708c1 100644 --- a/google/generativeai/version.py +++ b/google/generativeai/version.py @@ -14,4 +14,4 @@ # limitations under the License. from __future__ import annotations -__version__ = "0.2.2" +__version__ = "0.3.0" diff --git a/setup.py b/setup.py index 6cfa220dd..d1568c981 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage==0.3.3", + "google-ai-generativelanguage==0.4.0", "google-auth", "google-api-core", "protobuf", @@ -51,14 +51,7 @@ def get_version(): ] extras_require = { - "dev": [ - "absl-py", - "black", - "nose2", - "pandas", - "pytype", - "pyyaml", - ], + "dev": ["absl-py", "black", "nose2", "pandas", "pytype", "pyyaml", "Pillow", "ipython"], } url = "https://github.com/google/generative-ai-python" diff --git a/tests/test_client.py b/tests/test_client.py index 29c14ea51..34a0f9fc3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,18 +1,3 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import os from unittest import mock @@ -127,6 +112,22 @@ def test_default_metadata(self): text_client.classm() self.assertTrue(ClientTests.DummyClient.called_classm) + def test_same_config(self): + cm1 = client._ClientManager() + cm1.configure(api_key="abc") + + cm2 = client._ClientManager() + cm2.configure(client_options=dict(api_key="abc")) + + self.assertEqual( + cm1.client_config["client_info"].__dict__, cm2.client_config["client_info"].__dict__ + ) + self.assertEqual( + cm1.client_config["client_options"].__dict__, + cm2.client_config["client_options"].__dict__, + ) + self.assertEqual(cm1.default_metadata, cm2.default_metadata) + if __name__ == "__main__": absltest.main() diff --git a/tests/test_content.py b/tests/test_content.py new file mode 100644 index 000000000..62a582d37 --- /dev/null +++ b/tests/test_content.py @@ -0,0 +1,160 @@ +import copy +import pathlib +import unittest.mock + +from absl.testing import absltest +from absl.testing import parameterized +import google.ai.generativelanguage as glm +import google.generativeai as genai +from google.generativeai.types import content_types +from google.generativeai.types import safety_types +import IPython.display +import PIL.Image + + +HERE = pathlib.Path(__file__).parent +TEST_IMAGE_PATH = HERE / "test_img.png" +TEST_IMAGE_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.png" +TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() + + +class UnitTests(parameterized.TestCase): + @parameterized.named_parameters( + ["PIL", PIL.Image.open(TEST_IMAGE_PATH)], + ["IPython", IPython.display.Image(filename=TEST_IMAGE_PATH)], + ) + def test_image_to_blob(self, image): + blob = content_types.image_to_blob(image) + self.assertIsInstance(blob, glm.Blob) + self.assertEqual(blob.mime_type, "image/png") + self.assertStartsWith(blob.data, b"\x89PNG") + + @parameterized.named_parameters( + ["BlobDict", {"mime_type": "image/png", "data": TEST_IMAGE_DATA}], + ["glm.Blob", glm.Blob(mime_type="image/png", 
data=TEST_IMAGE_DATA)], + ["Image", IPython.display.Image(filename=TEST_IMAGE_PATH)], + ) + def test_to_blob(self, example): + blob = content_types.to_blob(example) + self.assertIsInstance(blob, glm.Blob) + self.assertEqual(blob.mime_type, "image/png") + self.assertStartsWith(blob.data, b"\x89PNG") + + @parameterized.named_parameters( + ["dict", {"text": "Hello world!"}], + ["glm.Part", glm.Part(text="Hello world!")], + ["str", "Hello world!"], + ) + def test_to_part(self, example): + part = content_types.to_part(example) + self.assertIsInstance(part, glm.Part) + self.assertEqual(part.text, "Hello world!") + + @parameterized.named_parameters( + ["Image", IPython.display.Image(filename=TEST_IMAGE_PATH)], + ["BlobDict", {"mime_type": "image/png", "data": TEST_IMAGE_DATA}], + [ + "PartDict", + {"inline_data": {"mime_type": "image/png", "data": TEST_IMAGE_DATA}}, + ], + ) + def test_img_to_part(self, example): + blob = content_types.to_part(example).inline_data + self.assertIsInstance(blob, glm.Blob) + self.assertEqual(blob.mime_type, "image/png") + self.assertStartsWith(blob.data, b"\x89PNG") + + @parameterized.named_parameters( + ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["ContentDict", {"parts": [{"text": "Hello world!"}]}], + ["ContentDict-str", {"parts": ["Hello world!"]}], + ["list[parts]", [{"text": "Hello world!"}]], + ["list[str]", ["Hello world!"]], + ["iterator[parts]", iter([{"text": "Hello world!"}])], + ["part", {"text": "Hello world!"}], + ["str", "Hello world!"], + ) + def test_to_content(self, example): + content = content_types.to_content(example) + part = content.parts[0] + + self.assertLen(content.parts, 1) + self.assertIsInstance(part, glm.Part) + self.assertEqual(part.text, "Hello world!") + + @parameterized.named_parameters( + ["ContentDict", {"parts": [PIL.Image.open(TEST_IMAGE_PATH)]}], + ["list[Image]", [PIL.Image.open(TEST_IMAGE_PATH)]], + ["Image", PIL.Image.open(TEST_IMAGE_PATH)], + ) + def test_img_to_content(self, 
example): + content = content_types.to_content(example) + blob = content.parts[0].inline_data + self.assertLen(content.parts, 1) + self.assertIsInstance(blob, glm.Blob) + self.assertEqual(blob.mime_type, "image/png") + self.assertStartsWith(blob.data, b"\x89PNG") + + @parameterized.named_parameters( + ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["ContentDict", {"parts": [{"text": "Hello world!"}]}], + ["ContentDict-str", {"parts": ["Hello world!"]}], + ) + def test_strict_to_content(self, example): + content = content_types.strict_to_content(example) + part = content.parts[0] + + self.assertLen(content.parts, 1) + self.assertIsInstance(part, glm.Part) + self.assertEqual(part.text, "Hello world!") + + @parameterized.named_parameters( + ["list[parts]", [{"text": "Hello world!"}]], + ["list[str]", ["Hello world!"]], + ["iterator[parts]", iter([{"text": "Hello world!"}])], + ["part", {"text": "Hello world!"}], + ["str", "Hello world!"], + ) + def test_strict_to_contents_fails(self, examples): + with self.assertRaises(TypeError): + content_types.strict_to_content(examples) + + @parameterized.named_parameters( + ["glm.Content", [glm.Content(parts=[{"text": "Hello world!"}])]], + ["ContentDict", [{"parts": [{"text": "Hello world!"}]}]], + ["ContentDict-unwraped", [{"parts": ["Hello world!"]}]], + ) + def test_to_contents(self, example): + contents = content_types.to_contents(example) + part = contents[0].parts[0] + + self.assertLen(contents, 1) + self.assertLen(contents[0].parts, 1) + self.assertIsInstance(part, glm.Part) + self.assertEqual(part.text, "Hello world!") + + def test_dict_to_content_fails(self): + with self.assertRaises(KeyError): + content_types.to_content({"bad": "dict"}) + + @parameterized.named_parameters( + [ + "ContentDict", + [{"parts": [{"inline_data": PIL.Image.open(TEST_IMAGE_PATH)}]}], + ], + ["ContentDict-unwraped", [{"parts": [PIL.Image.open(TEST_IMAGE_PATH)]}]], + ["Image", PIL.Image.open(TEST_IMAGE_PATH)], + ) + def 
test_img_to_contents(self, example): + contents = content_types.to_contents(example) + blob = contents[0].parts[0].inline_data + + self.assertLen(contents, 1) + self.assertLen(contents[0].parts, 1) + self.assertIsInstance(blob, glm.Blob) + self.assertEqual(blob.mime_type, "image/png") + self.assertStartsWith(blob.data, b"\x89PNG") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 8021d972a..32128da1d 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -33,7 +33,7 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client._client_manager.discuss_client = self.client + client._client_manager.clients["discuss"] = self.client self.observed_request = None diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 000000000..a67f47f39 --- /dev/null +++ b/tests/test_embedding.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import math +import unittest +import unittest.mock as mock + +import google.ai.generativelanguage as glm + +from google.generativeai import embedding + +from google.generativeai import client +from absl.testing import absltest +from absl.testing import parameterized + +DEFAULT_EMB_MODEL = "models/embedding-001" + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["generative"] = self.client + client._client_manager.clients["model"] = self.client + + self.observed_requests = [] + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def embed_content( + request: glm.EmbedContentRequest, + ) -> glm.EmbedContentResponse: + self.observed_requests.append(request) + return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + + @add_client_method + def batch_embed_contents( + request: glm.BatchEmbedContentsRequest, + ) -> glm.BatchEmbedContentsResponse: + self.observed_requests.append(request) + return glm.BatchEmbedContentsResponse( + embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + ) + + def test_embed_content(self): + text = "What are you?" 
+ emb = embedding.embed_content(model=DEFAULT_EMB_MODEL, content=text) + + self.assertIsInstance(emb, dict) + self.assertEqual( + self.observed_requests[-1], + glm.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + ), + ) + self.assertIsInstance(emb["embedding"][0], float) + + @parameterized.named_parameters( + [ + dict( + testcase_name="even-batch", + batch_size=100, + ), + dict( + testcase_name="even-batch-plus-one", + batch_size=101, + ), + dict(testcase_name="odd-batch", batch_size=237), + ] + ) + def test_batch_embed_contents(self, batch_size): + text = ["What are you?"] + texts = text * batch_size + emb = embedding.embed_content(model=DEFAULT_EMB_MODEL, content=texts) + + self.assertIsInstance(emb, dict) + + # Check that the list has the right length. + self.assertIsInstance(emb["embedding"][0], list) + self.assertLen(emb["embedding"], len(texts)) + + # Check that the right number of requests were sent. + self.assertLen( + self.observed_requests, + math.ceil(len(texts) / embedding.EMBEDDING_MAX_BATCH_SIZE), + ) + + def test_embed_content_title_and_task_1(self): + text = "What are you?" + emb = embedding.embed_content( + model=DEFAULT_EMB_MODEL, + content=text, + task_type="retrieval_document", + title="Exploring AI", + ) + + self.assertEqual( + embedding.to_task_type("retrieval_document"), + embedding.EmbeddingTaskType.RETRIEVAL_DOCUMENT, + ) + + def test_embed_content_title_and_task_2(self): + text = "What are you?" 
+ with self.assertRaises(ValueError): + embedding.embed_content( + model=DEFAULT_EMB_MODEL, content=text, task_type="unspecified", title="Exploring AI" + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generation.py b/tests/test_generation.py new file mode 100644 index 000000000..87d79ea1c --- /dev/null +++ b/tests/test_generation.py @@ -0,0 +1,494 @@ +import inspect +import string + +from absl.testing import absltest +from absl.testing import parameterized +import google.ai.generativelanguage as glm +from google.generativeai.types import generation_types + + +class UnitTests(parameterized.TestCase): + @parameterized.named_parameters( + [ + "glm.GenerationConfig", + glm.GenerationConfig(temperature=0.1, stop_sequences=["end"]), + ], + ["GenerationConfigDict", {"temperature": 0.1, "stop_sequences": ["end"]}], + [ + "GenerationConfig", + generation_types.GenerationConfig(temperature=0.1, stop_sequences=["end"]), + ], + ) + def test_to_generation_config(self, config): + gd = generation_types.to_generation_config_dict(config) + self.assertIsInstance(gd, dict) + self.assertEqual(gd["temperature"], 0.1) + self.assertEqual(gd["stop_sequences"], ["end"]) + + def test_join_citation_metadatas(self): + citations = [ + glm.CitationMetadata( + citation_sources=[ + glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + ] + ), + glm.CitationMetadata( + citation_sources=[ + glm.CitationSource(start_index=3, end_index=33, uri="https://google.com"), + glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + ] + ), + ] + + result = generation_types._join_citation_metadatas(citations) + + expected = { + "citation_sources": [ + {"start_index": 3, "end_index": 33, "uri": "https://google.com"}, + {"start_index": 55, "end_index": 92, "uri": "https://google.com"}, + ] + } + self.assertEqual(expected, type(result).to_dict(result)) + + def test_join_safety_ratings_list(self): + ratings = [ + [ + 
glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), + glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), + ], + [ + glm.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), + glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), + glm.SafetyRating( + category="HARM_CATEGORY_DANGEROUS", + probability="HIGH", + blocked=True, + ), + ], + ] + + result = generation_types._join_safety_ratings_lists(ratings) + + expected = [ + {"category": 6, "probability": 4, "blocked": True}, + {"category": 5, "probability": 2, "blocked": False}, + {"category": 4, "probability": 2, "blocked": False}, + {"category": 1, "probability": 2, "blocked": False}, + ] + self.assertEqual(expected, [type(r).to_dict(r) for r in result]) + + def test_join_contents(self): + contents = [ + glm.Content(role="assistant", parts=[glm.Part(text="Tell me a story about a ")]), + glm.Content( + role="assistant", + parts=[glm.Part(text="magic backpack that looks like this: ")], + ), + glm.Content( + role="assistant", + parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + ), + ] + result = generation_types._join_contents(contents) + expected = { + "parts": [ + {"text": ("Tell me a story about a magic backpack that looks like" " this: ")}, + {"inline_data": {"mime_type": "image/png", "data": "REFUQSE="}}, + ], + "role": "assistant", + } + + self.assertEqual(expected, type(result).to_dict(result)) + + def test_many_join_contents(self): + import string + + contents = [ + glm.Content(role="assistant", parts=[glm.Part(text=a)]) for a in string.ascii_lowercase + ] + + result = generation_types._join_contents(contents) + expected = { + "parts": [{"text": string.ascii_lowercase}], + "role": "assistant", + } + + self.assertEqual(expected, type(result).to_dict(result)) + + def test_join_candidates(self): + candidates = [ + glm.Candidate( + index=0, 
+ content=glm.Content( + role="assistant", + parts=[glm.Part(text="Tell me a story about a ")], + ), + citation_metadata=glm.CitationMetadata( + citation_sources=[ + glm.CitationSource(start_index=55, end_index=85, uri="https://google.com"), + ] + ), + ), + glm.Candidate( + index=0, + content=glm.Content( + role="assistant", + parts=[glm.Part(text="magic backpack that looks like this: ")], + ), + citation_metadata=glm.CitationMetadata( + citation_sources=[ + glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + ] + ), + ), + glm.Candidate( + index=0, + content=glm.Content( + role="assistant", + parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + ), + citation_metadata=glm.CitationMetadata( + citation_sources=[ + glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + ] + ), + finish_reason="STOP", + ), + ] + result = generation_types._join_candidates(candidates) + + expected = { + "content": { + "parts": [ + {"text": ("Tell me a story about a magic backpack that looks like" " this: ")}, + {"text": ""}, + ], + "role": "assistant", + }, + "finish_reason": 1, + "citation_metadata": { + "citation_sources": [ + { + "start_index": 55, + "end_index": 92, + "uri": "https://google.com", + }, + { + "start_index": 3, + "end_index": 21, + "uri": "https://google.com", + }, + ] + }, + "index": 0, + "safety_ratings": [], + "token_count": 0, + } + + self.assertEqual(expected, type(result).to_dict(result)) + + def test_join_prompt_feedbacks(self): + feedbacks = [ + glm.GenerateContentResponse.PromptFeedback( + block_reason="SAFETY", + safety_ratings=[ + glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + ], + ), + glm.GenerateContentResponse.PromptFeedback(), + glm.GenerateContentResponse.PromptFeedback(), + 
glm.GenerateContentResponse.PromptFeedback( + safety_ratings=[ + glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), + ] + ), + ] + result = generation_types._join_prompt_feedbacks(feedbacks) + expected = feedbacks[0] + self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(result)) + + CANDIDATE_LISTS = [ + [ + { + "content": { + "parts": [{"text": "Here is a photo of a magic backpack:"}], + "role": "assistant", + }, + "index": 0, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + }, + { + "content": { + "parts": [{"text": "Tell me a story about a magic backpack"}], + "role": "assistant", + }, + "index": 1, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + }, + { + "content": { + "parts": [{"text": "Tell me a story about a "}], + "role": "assistant", + }, + "index": 2, + "citation_metadata": {"citation_sources": []}, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + }, + ], + [ + { + "content": { + "parts": [{"text": "magic backpack that looks like this: "}], + "role": "assistant", + }, + "index": 2, + "citation_metadata": { + "citation_sources": [ + { + "start_index": 3, + "end_index": 21, + "uri": "https://google.com", + } + ] + }, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + }, + { + "content": { + "parts": [ + { + "inline_data": { + "mime_type": "image/png", + "data": "REFUQSE=", + } + } + ], + "role": "assistant", + }, + "index": 0, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + }, + ], + [ + { + "content": { + "parts": [ + { + "inline_data": { + "mime_type": "image/png", + "data": "REFUQSE=", + } + } + ], + "role": "assistant", + }, + "index": 2, + "citation_metadata": { + "citation_sources": [ + { + "start_index": 3, + "end_index": 21, + "uri": "https://google.com", + } + ] + }, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + } + ], + ] + MERGED_CANDIDATES = [ + { + "content": { + "parts": [ + {"text": 
"Here is a photo of a magic backpack:"}, + { + "inline_data": { + "mime_type": "image/png", + "data": "REFUQSE=", + } + }, + ], + "role": "assistant", + }, + "citation_metadata": {"citation_sources": []}, + "index": 0, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + "grounding_attributions": [], + }, + { + "content": { + "parts": [{"text": "Tell me a story about a magic backpack"}], + "role": "assistant", + }, + "index": 1, + "citation_metadata": {"citation_sources": []}, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + "grounding_attributions": [], + }, + { + "content": { + "parts": [ + {"text": ("Tell me a story about a magic backpack" " that looks like this: ")}, + { + "inline_data": { + "mime_type": "image/png", + "data": "REFUQSE=", + } + }, + ], + "role": "assistant", + }, + "index": 2, + "citation_metadata": { + "citation_sources": [ + { + "start_index": 3, + "end_index": 21, + "uri": "https://google.com", + }, + ] + }, + "finish_reason": 0, + "safety_ratings": [], + "token_count": 0, + "grounding_attributions": [], + }, + ] + + def test_join_candidates(self): + candidate_lists = [[glm.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] + result = generation_types._join_candidate_lists(candidate_lists) + self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) + + def test_join_chunks(self): + chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + + chunks[0].prompt_feedback = glm.GenerateContentResponse.PromptFeedback( + block_reason="SAFETY", + safety_ratings=[ + glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + ], + ) + + result = generation_types._join_chunks(chunks) + + expected = glm.GenerateContentResponse( + { + "candidates": self.MERGED_CANDIDATES, + "prompt_feedback": { + "block_reason": 1, + "safety_ratings": [ + { + "category": 6, + "probability": 2, + "blocked": False, + } + ], + }, + }, + ) + + 
self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(result)) + + def test_generate_content_response_iterator_end_to_end(self): + chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + merged = generation_types._join_chunks(chunks) + + response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) + + # Initially property access fails. + with self.assertRaises(generation_types.IncompleteIterationError): + _ = response.candidates + + # It yields the chunks as given. + for c1, c2 in zip(chunks, response): + c2 = c2._result + self.assertEqual(type(c1).to_dict(c1), type(c2).to_dict(c2)) + + # The final result is identical to _join_chunks's output. + self.assertEqual( + type(merged).to_dict(merged), + type(response._result).to_dict(response._result), + ) + + def test_generate_content_response_multiple_iterators(self): + chunks = [ + glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + for a in string.ascii_lowercase + ] + response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) + + # Do a partial iteration. + it1 = iter(response) + for i, chunk, a in zip(range(5), it1, string.ascii_lowercase): + self.assertEqual(a, chunk.candidates[0].content.parts[0].text) + + # Iterate past the first iterator. + it2 = iter(response) + for i, chunk, a in zip(range(10), it2, string.ascii_lowercase): + self.assertEqual(a, chunk.candidates[0].content.parts[0].text) + + # Resume the first iterator. 
+ for i, chunk, a in zip(range(5), it1, string.ascii_lowercase[5:]): + self.assertEqual(a, chunk.candidates[0].content.parts[0].text) + + # Do a full iteration + chunks = list(response) + joined = "".join(chunk.candidates[0].content.parts[0].text for chunk in chunks) + self.assertEqual(joined, string.ascii_lowercase) + + parts = response.candidates[0].content.parts + self.assertLen(parts, 1) + self.assertEqual(parts[0].text, string.ascii_lowercase) + + def test_generate_content_response_resolve(self): + chunks = [ + glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + for a in "abcd" + ] + response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) + + # Initially property access fails. + with self.assertRaises(generation_types.IncompleteIterationError): + _ = response.candidates + + response.resolve() + + self.assertEqual(response.candidates[0].content.parts[0].text, "abcd") + + def test_generate_content_response_from_response(self): + raw_response = glm.GenerateContentResponse( + {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} + ) + response = generation_types.GenerateContentResponse.from_response(raw_response) + + self.assertEqual(response.candidates[0], raw_response.candidates[0]) + self.assertLen(list(response), 1) + + for chunk in response: + self.assertEqual( + type(raw_response).to_dict(raw_response), type(chunk._result).to_dict(chunk._result) + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py new file mode 100644 index 000000000..01608eb97 --- /dev/null +++ b/tests/test_generative_models.py @@ -0,0 +1,611 @@ +import collections +from collections.abc import Iterable +import copy +import pathlib +import unittest.mock +from absl.testing import absltest +from absl.testing import parameterized +import google.ai.generativelanguage as glm +from google.generativeai import client as client_lib +from 
google.generativeai import generative_models +from google.generativeai.types import content_types +from google.generativeai.types import generation_types + +import PIL.Image + +HERE = pathlib.Path(__file__).parent +TEST_IMAGE_PATH = HERE / "test_img.png" +TEST_IMAGE_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.png" +TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() + + +def simple_response(text: str) -> glm.GenerateContentResponse: + return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) + + +class CUJTests(parameterized.TestCase): + """Tests are in order with the design doc.""" + + def setUp(self): + self.client = unittest.mock.MagicMock() + + client_lib._client_manager.clients["generative"] = self.client + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + self.observed_requests = [] + self.responses = collections.defaultdict(list) + + @add_client_method + def generate_content( + request: glm.GenerateContentRequest, + ) -> glm.GenerateContentResponse: + self.assertIsInstance(request, glm.GenerateContentRequest) + self.observed_requests.append(request) + response = self.responses["generate_content"].pop(0) + return response + + @add_client_method + def stream_generate_content( + request: glm.GenerateContentRequest, + ) -> Iterable[glm.GenerateContentResponse]: + self.observed_requests.append(request) + response = self.responses["stream_generate_content"].pop(0) + return response + + def test_hello(self): + # Generate text from text prompt + model = generative_models.GenerativeModel(model_name="gemini-m") + + self.responses["generate_content"].append(simple_response("world!")) + + response = model.generate_content("Hello") + + self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello") + self.assertEqual(response.candidates[0].content.parts[0].text, "world!") + + self.assertEqual(response.text, "world!") + + @parameterized.named_parameters( + [ 
+ "JustImage", + PIL.Image.open(TEST_IMAGE_PATH), + ], + [ + "ImageAndText", + ["What's in this picture?", PIL.Image.open(TEST_IMAGE_PATH)], + ], + ) + def test_image(self, content): + # Generate text from image + model = generative_models.GenerativeModel("gemini-m") + + cat = "It's a cat" + self.responses["generate_content"].append(simple_response(cat)) + + response = model.generate_content(content) + + self.assertEqual( + self.observed_requests[0].contents[0].parts[-1].inline_data.mime_type, + "image/png", + ) + self.assertEqual( + self.observed_requests[0].contents[0].parts[-1].inline_data.data[:4], + b"\x89PNG", + ) + self.assertEqual(response.candidates[0].content.parts[0].text, cat) + + self.assertEqual(response.text, cat) + + @parameterized.named_parameters( + ["dict", {"temperature": 0.0}, {"temperature": 0.5}], + [ + "object", + generation_types.GenerationConfig(temperature=0.0), + generation_types.GenerationConfig(temperature=0.5), + ], + [ + "glm", + glm.GenerationConfig(temperature=0.0), + glm.GenerationConfig(temperature=0.5), + ], + ) + def test_generation_config_overwrite(self, config1, config2): + # Generation config + model = generative_models.GenerativeModel("gemini-m", generation_config=config1) + + self.responses["generate_content"] = [ + simple_response(" world!"), + simple_response(" world!"), + ] + + _ = model.generate_content("hello") + self.assertEqual(self.observed_requests[-1].generation_config.temperature, 0.0) + + _ = model.generate_content("hello", generation_config=config2) + self.assertEqual(self.observed_requests[-1].generation_config.temperature, 0.5) + + @parameterized.named_parameters( + ["dict", {"danger": "low"}, {"danger": "high"}], + [ + "list-dict", + [ + dict( + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ], + [ + dict(category="danger", threshold="high"), + ], + ], + [ + "object", + [ + glm.SafetySetting( + 
category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ], + [ + glm.SafetySetting( + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ], + ], + ) + def test_safety_overwrite(self, safe1, safe2): + # Safety + model = generative_models.GenerativeModel("gemini-m", safety_settings={"danger": "low"}) + + self.responses["generate_content"] = [ + simple_response(" world!"), + simple_response(" world!"), + ] + + _ = model.generate_content("hello") + self.assertEqual( + self.observed_requests[-1].safety_settings[0].category, + glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + ) + self.assertEqual( + self.observed_requests[-1].safety_settings[0].threshold, + glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ) + + _ = model.generate_content("hello", safety_settings={"danger": "high"}) + self.assertEqual( + self.observed_requests[-1].safety_settings[0].category, + glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + ) + self.assertEqual( + self.observed_requests[-1].safety_settings[0].threshold, + glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + ) + + def test_stream_basic(self): + # Streaming + chunks = ["first", " second", " third"] + self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)] + + model = generative_models.GenerativeModel("gemini-m") + response = model.generate_content("Hello", stream=True) + + self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello") + + for n, got in enumerate(response): + self.assertEqual(chunks[n], got.text) + + self.assertEqual(response.text, "".join(chunks)) + + def test_stream_lookahead(self): + chunks = ["first", " second", " third"] + self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)] + + model = generative_models.GenerativeModel("gemini-m") + response = 
model.generate_content("Hello", stream=True) + + self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello") + + for expected, got in zip(chunks, response): + self.assertEqual(expected, got.text) + + self.assertEqual(response.text, "".join(chunks)) + + def test_stream_prompt_feedback_blocked(self): + chunks = [ + glm.GenerateContentResponse( + { + "prompt_feedback": {"block_reason": "SAFETY"}, + } + ) + ] + self.responses["stream_generate_content"] = [(chunk for chunk in chunks)] + + model = generative_models.GenerativeModel("gemini-m") + response = model.generate_content("Bad stuff!", stream=True) + + self.assertEqual( + response.prompt_feedback.block_reason, + glm.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, + ) + + with self.assertRaises(generation_types.BlockedPromptException): + for chunk in response: + pass + + def test_stream_prompt_feedback_not_blocked(self): + chunks = [ + glm.GenerateContentResponse( + { + "prompt_feedback": { + "safety_ratings": [ + { + "category": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "probability": glm.SafetyRating.HarmProbability.NEGLIGIBLE, + } + ] + }, + "candidates": [{"content": {"parts": [{"text": "first"}]}}], + } + ), + glm.GenerateContentResponse( + { + "candidates": [{"content": {"parts": [{"text": " second"}]}}], + } + ), + ] + self.responses["stream_generate_content"] = [(chunk for chunk in chunks)] + + model = generative_models.GenerativeModel("gemini-m") + response = model.generate_content("Hello", stream=True) + + self.assertEqual( + response.prompt_feedback.safety_ratings[0].category, + glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + ) + + text = "".join(chunk.text for chunk in response) + self.assertEqual(text, "first second") + + def test_chat(self): + # Multi turn chat + model = generative_models.GenerativeModel("gemini-m") + chat = model.start_chat() + + self.responses["generate_content"] = [ + simple_response("first"), + simple_response("second"), + simple_response("third"), + ] + 
+ msg1 = "I really like fantasy books." + response = chat.send_message(msg1) + self.assertEqual(response.text, "first") + self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, msg1) + + msg2 = "I also like this image." + response = chat.send_message([msg2, PIL.Image.open(TEST_IMAGE_PATH)]) + + self.assertEqual(response.text, "second") + self.assertEqual(self.observed_requests[1].contents[0].parts[0].text, msg1) + self.assertEqual(self.observed_requests[1].contents[1].parts[0].text, "first") + self.assertEqual(self.observed_requests[1].contents[2].parts[0].text, msg2) + self.assertEqual( + self.observed_requests[1].contents[2].parts[1].inline_data.data[:4], + b"\x89PNG", + ) + + msg3 = "What things do I like?." + response = chat.send_message(msg3) + self.assertEqual(response.text, "third") + self.assertLen(chat.history, 6) + + def test_chat_roles(self): + self.responses["generate_content"] = [simple_response("hello!")] + + model = generative_models.GenerativeModel("gemini-pro") + chat = model.start_chat() + response = chat.send_message("hello?") + history = chat.history + self.assertEqual(history[0].role, "user") + self.assertEqual(history[1].role, "model") + + def test_chat_streaming_basic(self): + # Chat streaming + self.responses["stream_generate_content"] = [ + iter([simple_response("a"), simple_response("b"), simple_response("c")]), + iter([simple_response("1"), simple_response("2"), simple_response("3")]), + iter([simple_response("x"), simple_response("y"), simple_response("z")]), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + response = chat.send_message("letters?", stream=True) + + self.assertEqual("".join(chunk.text for chunk in response), "abc") + + response = chat.send_message("numbers?", stream=True) + + self.assertEqual("".join(chunk.text for chunk in response), "123") + + response = chat.send_message("more letters?", stream=True) + + self.assertEqual("".join(chunk.text for chunk in 
response), "xyz") + + def test_chat_incomplete_streaming_errors(self): + # Chat streaming + self.responses["stream_generate_content"] = [ + iter([simple_response("a"), simple_response("b"), simple_response("c")]), + iter([simple_response("1"), simple_response("2"), simple_response("3")]), + iter([simple_response("x"), simple_response("y"), simple_response("z")]), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + response = chat.send_message("letters?", stream=True) + + with self.assertRaises(generation_types.IncompleteIterationError): + chat.history + + with self.assertRaises(generation_types.IncompleteIterationError): + chat.send_message("numbers?", stream=True) + + for chunk in response: + pass + self.assertLen(chat.history, 2) + + response = chat.send_message("numbers?", stream=True) + self.assertEqual("".join(chunk.text for chunk in response), "123") + + def test_edit_history(self): + self.responses["generate_content"] = [ + simple_response("first"), + simple_response("second"), + simple_response("third"), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + response = chat.send_message("hello") + self.assertEqual(response.text, "first") + self.assertLen(chat.history, 2) + + response = chat.send_message("hello") + self.assertEqual(response.text, "second") + self.assertLen(chat.history, 4) + + chat.history[-1] = content_types.to_content("edited") + response = chat.send_message("hello") + self.assertEqual(response.text, "third") + self.assertLen(chat.history, 6) + + self.assertEqual(chat.history[3], content_types.to_content("edited")) + self.assertEqual(self.observed_requests[-1].contents[3].parts[0].text, "edited") + + def test_replace_history(self): + self.responses["generate_content"] = [ + simple_response("first"), + simple_response("second"), + simple_response("third"), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + 
chat.send_message("hello1") + chat.send_message("hello2") + + self.assertLen(chat.history, 4) + chat.history = [{"parts": ["Goodbye"]}, {"parts": ["Later gater"]}] + self.assertLen(chat.history, 2) + + response = chat.send_message("hello3") + self.assertEqual(response.text, "third") + self.assertLen(chat.history, 4) + + self.assertEqual(self.observed_requests[-1].contents[0].parts[0].text, "Goodbye") + + def test_copy_history(self): + self.responses["generate_content"] = [ + simple_response("first"), + simple_response("second"), + simple_response("third"), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat1 = model.start_chat() + chat1.send_message("hello1") + + chat2 = copy.deepcopy(chat1) + chat2.send_message("hello2") + + chat1.send_message("hello3") + + self.assertLen(chat1.history, 4) + expected = [ + {"role": "user", "parts": ["hello1"]}, + {"role": "model", "parts": ["first"]}, + {"role": "user", "parts": ["hello3"]}, + {"role": "model", "parts": ["third"]}, + ] + for content, ex in zip(chat1.history, expected): + self.assertEqual(content, content_types.to_content(ex)) + + self.assertLen(chat2.history, 4) + expected = [ + {"role": "user", "parts": ["hello1"]}, + {"role": "model", "parts": ["first"]}, + {"role": "user", "parts": ["hello2"]}, + {"role": "model", "parts": ["second"]}, + ] + for content, ex in zip(chat2.history, expected): + self.assertEqual(content, content_types.to_content(ex)) + + def test_chat_error_in_stream(self): + def throw(): + for c in "123": + yield simple_response(c) + raise ValueError() + + def no_throw(): + for c in "abc": + yield simple_response(c) + + self.responses["stream_generate_content"] = [ + no_throw(), + throw(), + no_throw(), + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + # Send a message, the response is okay.. 
+ chat.send_message("hello1", stream=True).resolve() + + # Send a second message, it fails + response = chat.send_message("hello2", stream=True) + with self.assertRaises(ValueError): + # Iteration fails + for chunk in response: + pass + + # Since the response broke, we can't access the history + with self.assertRaises(generation_types.BrokenResponseError): + chat.history + + # or send another message. + with self.assertRaises(generation_types.BrokenResponseError): + chat.send_message("hello") + + # Rewind a step to before the error + chat.rewind() + + self.assertLen(chat.history, 2) + self.assertEqual(chat.history[0].parts[0].text, "hello1") + self.assertEqual(chat.history[1].parts[0].text, "abc") + + # And continue + chat.send_message("hello3", stream=True).resolve() + self.assertLen(chat.history, 4) + self.assertEqual(chat.history[2].parts[0].text, "hello3") + self.assertEqual(chat.history[3].parts[0].text, "abc") + + def test_chat_prompt_blocked(self): + self.responses["generate_content"] = [ + glm.GenerateContentResponse( + { + "prompt_feedback": {"block_reason": "SAFETY"}, + } + ) + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + with self.assertRaises(generation_types.BlockedPromptException): + chat.send_message("hello") + + self.assertLen(chat.history, 0) + + def test_chat_candidate_blocked(self): + # I feel like chat needs a .last so you can look at the partial results. 
+ self.responses["generate_content"] = [ + glm.GenerateContentResponse( + { + "candidates": [{"finish_reason": "SAFETY"}], + } + ) + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + with self.assertRaises(generation_types.StopCandidateException): + chat.send_message("hello") + + def test_chat_streaming_unexpected_stop(self): + self.responses["stream_generate_content"] = [ + iter( + [ + simple_response("a"), + simple_response("b"), + simple_response("c"), + glm.GenerateContentResponse( + { + "candidates": [{"finish_reason": "SAFETY"}], + } + ), + ] + ) + ] + + model = generative_models.GenerativeModel("gemini-mm-m") + chat = model.start_chat() + + response = chat.send_message("hello", stream=True) + for chunk in response: + # The result doesn't know it's a chat result so it can't throw. + # Unless we give it some way to know? + pass + + with self.assertRaises(generation_types.BrokenResponseError): + # But when preparing the next message, we can throw: + response = chat.send_message("hello2", stream=True) + + # It's a little bad that here it's only on send message that you find out + # about the problem. "hello2" is never added, the `rewind` removes `hello1`. 
+ chat.rewind() + self.assertLen(chat.history, 0) + + @parameterized.named_parameters( + [ + "GenerateContentResponse", + generation_types.GenerateContentResponse, + generation_types.AsyncGenerateContentResponse, + ], + [ + "GenerativeModel.generate_response", + generative_models.GenerativeModel.generate_content, + generative_models.GenerativeModel.generate_content_async, + ], + [ + "GenerativeModel.count_tokens", + generative_models.GenerativeModel.count_tokens, + generative_models.GenerativeModel.count_tokens_async, + ], + [ + "ChatSession.send_message", + generative_models.ChatSession.send_message, + generative_models.ChatSession.send_message_async, + ], + ) + def test_async_code_match(self, obj, aobj): + import inspect + import re + + source = inspect.getsource(obj) + asource = inspect.getsource(aobj) + + asource = ( + asource.replace("anext", "next") + .replace("aiter", "iter") + .replace("_async", "") + .replace("async ", "") + .replace("await ", "") + .replace("Async", "") + .replace("ASYNC_", "") + ) + + asource = re.sub(" *?# type: ignore", "", asource) + self.assertEqual(source, asource) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py new file mode 100644 index 000000000..1c48f3476 --- /dev/null +++ b/tests/test_generative_models_async.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +import sys +from collections.abc import Iterable +import os +import unittest + +from google.generativeai import client as client_lib +from google.generativeai import generative_models +import google.ai.generativelanguage as glm + +from absl.testing import absltest +from absl.testing import parameterized + + +def simple_response(text: str) -> glm.GenerateContentResponse: + return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) + + +class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client_lib._client_manager.clients["generative_async"] = self.client + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + self.observed_requests = [] + self.responses = collections.defaultdict(list) + + @add_client_method + async def generate_content( + request: glm.GenerateContentRequest, + ) -> glm.GenerateContentResponse: + self.assertIsInstance(request, glm.GenerateContentRequest) + self.observed_requests.append(request) + response = self.responses["generate_content"].pop(0) + return response + + @add_client_method + async def stream_generate_content( + request: glm.GetModelRequest, + ) -> Iterable[glm.GenerateContentResponse]: + self.observed_requests.append(request) + response = self.responses["stream_generate_content"].pop(0) + return response + + async def test_basic(self): + # Generate text from text prompt + model = generative_models.GenerativeModel(model_name="gemini-m") + + self.responses["generate_content"] = [simple_response("world!")] + + response = await model.generate_content_async("Hello") + + self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello") + self.assertEqual(response.candidates[0].content.parts[0].text, "world!") + + self.assertEqual(response.text, "world!") + + @unittest.skipIf( + sys.version_info.major == 3 and sys.version_info.minor < 10, + 
"streaming async requires python 3.10+", + ) + async def test_streaming(self): + # Generate text from text prompt + model = generative_models.GenerativeModel(model_name="gemini-m") + + async def responses(): + for c in "world!": + yield simple_response(c) + + self.responses["stream_generate_content"] = [responses()] + + response = await model.generate_content_async("Hello", stream=True) + + it = iter("world!") + async for chunk in response: + c = next(it) + self.assertEqual(chunk.text, c) + + self.assertEqual(response.text, "world!") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_img.png b/tests/test_img.png new file mode 100644 index 000000000..12ac5d224 Binary files /dev/null and b/tests/test_img.png differ diff --git a/tests/test_models.py b/tests/test_models.py index 693aeba39..60bf3a615 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -41,7 +41,7 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client._client_manager.model_client = self.client + client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help? 
diff --git a/tests/test_operations.py b/tests/test_operations.py index 148b8a2f8..80262db88 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -29,9 +29,9 @@ class OperationsTests(parameterized.TestCase): metadata_type = ( - "type.googleapis.com/google.ai.generativelanguage.v1beta3.CreateTunedModelMetadata" + "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata" ) - result_type = "type.googleapis.com/google.ai.generativelanguage.v1beta3.TunedModel" + result_type = "type.googleapis.com/google.ai.generativelanguage.v1beta.TunedModel" def test_end_to_end(self): name = "my-model" diff --git a/tests/test_text.py b/tests/test_text.py index c9ab0e45f..2fd6bd8c9 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -31,8 +31,8 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client._client_manager.text_client = self.client - client._client_manager.model_client = self.client + client._client_manager.clients["text"] = self.client + client._client_manager.clients["model"] = self.client self.observed_requests = []