diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 50b15261a..ee6a81b69 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -4,12 +4,9 @@ from collections.abc import Iterable import textwrap -from typing import Any, Union, overload +from typing import overload import reprlib -# pylint: disable=bad-continuation, line-too-long - - import google.api_core.exceptions from google.generativeai import protos from google.generativeai import client @@ -509,6 +506,52 @@ def __init__( self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling + def count_tokens( + self, + content: content_types.ContentType | None = None, + *, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ): + history = self.history[:] + + if content is not None: + content = content_types.to_content(content) + if not content.role: + content.role = self._USER_ROLE + history.append(content) + + return self.model.count_tokens( + contents=history, + tools=tools, + tool_config=tool_config, + request_options=request_options, + ) + + async def count_tokens_async( + self, + content: content_types.ContentType | None = None, + *, + tools: content_types.FunctionLibraryType | None = None, + tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, + ): + history = self.history[:] + + if content is not None: + content = content_types.to_content(content) + if not content.role: + content.role = self._USER_ROLE + history.append(content) + + return await self.model.count_tokens( + contents=history, + tools=tools, + tool_config=tool_config, + request_options=request_options, + ) + def send_message( self, content: content_types.ContentType, diff --git a/tests/test_async_code_match.py b/tests/test_async_code_match.py index 0ec4550d4..2e45a9973 100644 --- a/tests/test_async_code_match.py +++ b/tests/test_async_code_match.py @@ -62,7 +62,7 @@ def _inspect_decorator_exemption(self, node, fpath) -> bool: return False - def _execute_code_match(self, source, asource): + def _execute_code_match(self, source, asource, fpath): asource = ( asource.replace("anext", "next") .replace("aiter", "iter") @@ -73,7 +73,7 @@ def _execute_code_match(self, source, asource): .replace("ASYNC_", "") ) asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource) + self.assertEqual(source, asource, f"Matching {fpath}") def test_code_match_for_async_methods(self): for fpath in (pathlib.Path(__file__).parent.parent / "google").rglob("*.py"): @@ -101,7 +101,7 @@ def test_code_match_for_async_methods(self): ) func_source = self._maybe_trim_docstring(snode) func_asource = self._maybe_trim_docstring(anode) - self._execute_code_match(func_source, func_asource) + self._execute_code_match(func_source, func_asource, fpath) # print(f"Matched {node.name}") else: code_match_funcs[node.name] = node