diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
index deb3acffa3..3e88f5ff9c 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py
@@ -17,6 +17,7 @@
     Type,
     Union,
     Callable,
+    Literal,
 )
@@ -141,9 +142,9 @@ class GoogleGenAI(FunctionCallingLLM):
         default=None,
         description="Google GenAI tool to use for the model to augment responses.",
     )
-    use_file_api: bool = Field(
-        default=True,
-        description="Whether or not to use the FileAPI for large files (>20MB).",
+    file_mode: Literal["inline", "fileapi", "hybrid"] = Field(
+        default="hybrid",
+        description="Whether to use inline-only, FileAPI-only or both for handling files.",
     )

     _max_tokens: int = PrivateAttr()
@@ -167,7 +168,7 @@ def __init__(
         is_function_calling_model: bool = True,
         cached_content: Optional[str] = None,
         built_in_tool: Optional[types.Tool] = None,
-        use_file_api: bool = True,
+        file_mode: Literal["inline", "fileapi", "hybrid"] = "hybrid",
         **kwargs: Any,
     ):
         # API keys are optional. The API can be authorised via OAuth (detected
@@ -216,7 +217,7 @@ def __init__(
             max_retries=max_retries,
             cached_content=cached_content,
             built_in_tool=built_in_tool,
-            use_file_api=use_file_api,
+            file_mode=file_mode,
             **kwargs,
         )
@@ -309,9 +310,9 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = asyncio.run(
+        next_msg, chat_kwargs, file_api_names = asyncio.run(
             prepare_chat_params(
-                self.model, messages, self.use_file_api, self._client, **params
+                self.model, messages, self.file_mode, self._client, **params
             )
         )
         chat = self._client.chats.create(**chat_kwargs)
@@ -319,10 +320,7 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             next_msg.parts if isinstance(next_msg, types.Content) else next_msg
         )

-        if self.use_file_api:
-            asyncio.run(
-                delete_uploaded_files([*chat_kwargs["history"], next_msg], self._client)
-            )
+        asyncio.run(delete_uploaded_files(file_api_names, self._client))

         return chat_from_gemini_response(response)
@@ -333,18 +331,15 @@ async def _achat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = await prepare_chat_params(
-            self.model, messages, self.use_file_api, self._client, **params
+        next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
+            self.model, messages, self.file_mode, self._client, **params
         )
         chat = self._client.aio.chats.create(**chat_kwargs)
         response = await chat.send_message(
             next_msg.parts if isinstance(next_msg, types.Content) else next_msg
         )

-        if self.use_file_api:
-            await delete_uploaded_files(
-                [*chat_kwargs["history"], next_msg], self._client
-            )
+        await delete_uploaded_files(file_api_names, self._client)

         return chat_from_gemini_response(response)
@@ -366,9 +361,9 @@ def _stream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = asyncio.run(
+        next_msg, chat_kwargs, file_api_names = asyncio.run(
             prepare_chat_params(
-                self.model, messages, self.use_file_api, self._client, **params
+                self.model, messages, self.file_mode, self._client, **params
             )
         )
         chat = self._client.chats.create(**chat_kwargs)
@@ -402,12 +397,8 @@ def gen() -> ChatResponseGen:
                     llama_resp.message.blocks = [ThinkingBlock(content=thoughts)]
                 yield llama_resp

-            if self.use_file_api:
-                asyncio.run(
-                    delete_uploaded_files(
-                        [*chat_kwargs["history"], next_msg], self._client
-                    )
-                )
+            if self.file_mode in ("fileapi", "hybrid"):
+                asyncio.run(delete_uploaded_files(file_api_names, self._client))

         return gen()
@@ -425,8 +416,8 @@ async def _astream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = await prepare_chat_params(
-            self.model, messages, self.use_file_api, self._client, **params
+        next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
+            self.model, messages, self.file_mode, self._client, **params
         )
         chat = self._client.aio.chats.create(**chat_kwargs)
@@ -463,10 +454,7 @@ async def gen() -> ChatResponseAsyncGen:
                     ]
                 yield llama_resp

-            if self.use_file_api:
-                await delete_uploaded_files(
-                    [*chat_kwargs["history"], next_msg], self._client
-                )
+            await delete_uploaded_files(file_api_names, self._client)

         return gen()
@@ -589,12 +577,13 @@ def structured_predict_without_function_calling(
         llm_kwargs = llm_kwargs or {}

         messages = prompt.format_messages(**prompt_args)
-        contents = [
-            asyncio.run(
-                chat_message_to_gemini(message, self.use_file_api, self._client)
-            )
+        contents_and_names = [
+            asyncio.run(chat_message_to_gemini(message, self.file_mode, self._client))
             for message in messages
         ]
+        contents = [it[0] for it in contents_and_names]
+        file_api_names = [name for it in contents_and_names for name in it[1]]
+
         response = self._client.models.generate_content(
             model=self.model,
             contents=contents,
@@ -609,8 +598,7 @@ def structured_predict_without_function_calling(
             },
         )

-        if self.use_file_api:
-            asyncio.run(delete_uploaded_files(contents, self._client))
+        asyncio.run(delete_uploaded_files(file_api_names, self._client))

         if isinstance(response.parsed, BaseModel):
             return response.parsed
@@ -639,20 +627,22 @@ def structured_predict(
             generation_config["response_schema"] = output_cls

             messages = prompt.format_messages(**prompt_args)
-            contents = [
+            contents_and_names = [
                 asyncio.run(
-                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    chat_message_to_gemini(message, self.file_mode, self._client)
                 )
                 for message in messages
             ]
+            contents = [it[0] for it in contents_and_names]
+            file_api_names = [name for it in contents_and_names for name in it[1]]
+
             response = self._client.models.generate_content(
                 model=self.model,
                 contents=contents,
                 config=generation_config,
             )

-            if self.use_file_api:
-                asyncio.run(delete_uploaded_files(contents, self._client))
+            asyncio.run(delete_uploaded_files(file_api_names, self._client))

             if isinstance(response.parsed, BaseModel):
                 return response.parsed
@@ -686,20 +676,22 @@ async def astructured_predict(
             generation_config["response_schema"] = output_cls

             messages = prompt.format_messages(**prompt_args)
-            contents = await asyncio.gather(
+            contents_and_names = await asyncio.gather(
                 *[
-                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    chat_message_to_gemini(message, self.file_mode, self._client)
                     for message in messages
                 ]
             )
+            contents = [it[0] for it in contents_and_names]
+            file_api_names = [name for it in contents_and_names for name in it[1]]
+
             response = await self._client.aio.models.generate_content(
                 model=self.model,
                 contents=contents,
                 config=generation_config,
             )

-            if self.use_file_api:
-                await delete_uploaded_files(contents, self._client)
+            await delete_uploaded_files(file_api_names, self._client)

             if isinstance(response.parsed, BaseModel):
                 return response.parsed
@@ -733,12 +725,14 @@ def stream_structured_predict(
             generation_config["response_schema"] = output_cls

             messages = prompt.format_messages(**prompt_args)
-            contents = [
+            contents_and_names = [
                 asyncio.run(
-                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    chat_message_to_gemini(message, self.file_mode, self._client)
                 )
                 for message in messages
             ]
+            contents = [it[0] for it in contents_and_names]
+            file_api_names = [name for it in contents_and_names for name in it[1]]

             def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                 flexible_model = create_flexible_model(output_cls)
@@ -762,8 +756,7 @@ def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                         if streaming_model:
                             yield streaming_model

-                if self.use_file_api:
-                    asyncio.run(delete_uploaded_files(contents, self._client))
+                asyncio.run(delete_uploaded_files(file_api_names, self._client))

             return gen()
         else:
@@ -793,12 +786,14 @@ async def astream_structured_predict(
             generation_config["response_schema"] = output_cls

             messages = prompt.format_messages(**prompt_args)
-            contents = await asyncio.gather(
+            contents_and_names = await asyncio.gather(
                 *[
-                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    chat_message_to_gemini(message, self.file_mode, self._client)
                     for message in messages
                 ]
             )
+            contents = [it[0] for it in contents_and_names]
+            file_api_names = [name for it in contents_and_names for name in it[1]]

             async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                 flexible_model = create_flexible_model(output_cls)
@@ -822,8 +817,7 @@ async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                         if streaming_model:
                             yield streaming_model

-                if self.use_file_api:
-                    await delete_uploaded_files(contents, self._client)
+                await delete_uploaded_files(file_api_names, self._client)

             return gen()
         else:
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py
index 3b0f3f0a79..8098ccbbbc 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py
@@ -2,8 +2,18 @@
 import json
 import logging
 from collections.abc import Sequence
-from io import BytesIO
-from typing import TYPE_CHECKING, Any, Dict, Union, Optional, Type, Tuple, cast
+from io import IOBase
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Union,
+    Optional,
+    Type,
+    Tuple,
+    Literal,
+    cast,
+)
 import typing

 import google.genai.types as types
@@ -232,25 +242,30 @@ def chat_from_gemini_response(


 async def create_file_part(
-    file_bytes: bytes, mime_type: str, use_file_api: bool, client: Optional[Client]
-) -> types.PartUnion:
+    file_buffer: IOBase,
+    mime_type: str,
+    file_mode: Literal["inline", "fileapi", "hybrid"],
+    client: Optional[Client],
+) -> tuple[types.Part, Optional[str]]:
     """Create a Part or File object for the given file depending on its size."""
-    if (
-        not use_file_api
-        or len(file_bytes)
-        < 20 * 1024 * 1024  # 20MB is the Gemini inline data size limit
-    ):
-        return types.Part.from_bytes(
-            data=file_bytes,
-            mime_type=mime_type,
-        )
+    if file_mode in ("inline", "hybrid"):
+        file_buffer.seek(0, 2)  # Seek to end
+        size = file_buffer.tell()  # Get file size
+        file_buffer.seek(0)  # Reset to beginning
+
+        if size < 20 * 1024 * 1024:  # 20MB is the Gemini inline data size limit
+            return types.Part.from_bytes(
+                data=file_buffer.read(),
+                mime_type=mime_type,
+            ), None
+        elif file_mode == "inline":
+            raise ValueError("Files in inline mode must be smaller than 20MB.")

     if client is None:
         raise ValueError("A Google GenAI client must be provided for use with FileAPI.")

-    buffer = BytesIO(file_bytes)
     file = await client.aio.files.upload(
-        file=buffer, config=types.UploadFileConfig(mime_type=mime_type)
+        file=file_buffer, config=types.UploadFileConfig(mime_type=mime_type)
     )

     # Wait for file processing
@@ -261,35 +276,37 @@ async def create_file_part(
     if file.state.name == "FAILED":
         raise ValueError("Failed to upload the file with FileAPI")

-    return file
+    return types.Part.from_uri(
+        file_uri=file.uri,
+        mime_type=mime_type,
+    ), file.name


-async def delete_uploaded_files(
-    contents: list[Union[types.Content, types.File]], client: Client
-) -> None:
+async def delete_uploaded_files(file_api_names: list[str], client: Client) -> None:
     """Delete files uploaded with File API."""
     await asyncio.gather(
-        *[
-            client.aio.files.delete(name=content.name)
-            for content in contents
-            if isinstance(content, types.File)
-        ]
+        *[client.aio.files.delete(name=name) for name in file_api_names]
     )


 async def chat_message_to_gemini(
-    message: ChatMessage, use_file_api: bool = False, client: Optional[Client] = None
-) -> Union[types.Content, types.File]:
+    message: ChatMessage,
+    file_mode: Literal["inline", "fileapi", "hybrid"] = "hybrid",
+    client: Optional[Client] = None,
+) -> tuple[types.Content, list[str]]:
     """Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
     unique_tool_calls = []
     parts = []
+    file_api_names = []
     part = None
     for index, block in enumerate(message.blocks):
+        file_api_name = None
+
         if isinstance(block, TextBlock):
             if block.text:
                 part = types.Part.from_text(text=block.text)
         elif isinstance(block, ImageBlock):
-            file_bytes = block.resolve_image(as_base64=False).read()
+            file_buffer = block.resolve_image(as_base64=False)

             mime_type = (
                 block.image_mimetype
@@ -297,13 +314,11 @@ async def chat_message_to_gemini(
                 else "image/jpeg"  # TODO: Fail?
             )

-            part = await create_file_part(file_bytes, mime_type, use_file_api, client)
-
-            if isinstance(part, types.File):
-                return part  # Return the file as it is a message content and not a part
+            part, file_api_name = await create_file_part(
+                file_buffer, mime_type, file_mode, client
+            )
         elif isinstance(block, VideoBlock):
             file_buffer = block.resolve_video(as_base64=False)
-            file_bytes = file_buffer.read()

             mime_type = (
                 block.video_mimetype
@@ -311,25 +326,22 @@ async def chat_message_to_gemini(
                 else "video/mp4"  # TODO: Fail?
             )

-            part = await create_file_part(file_bytes, mime_type, use_file_api, client)
-
-            if isinstance(part, types.File):
-                return part  # Return the file as it is a message content and not a part
-
+            part, file_api_name = await create_file_part(
+                file_buffer, mime_type, file_mode, client
+            )
             part.video_metadata = types.VideoMetadata(fps=block.fps)
-
         elif isinstance(block, DocumentBlock):
             file_buffer = block.resolve_document()
-            file_bytes = file_buffer.read()
+
             mime_type = (
                 block.document_mimetype
                 if block.document_mimetype is not None
                 else "application/pdf"
             )

-            part = await create_file_part(file_bytes, mime_type, use_file_api, client)
-            if isinstance(part, types.File):
-                return part  # Return the file as it is a message content and not a part
+            part, file_api_name = await create_file_part(
+                file_buffer, mime_type, file_mode, client
+            )
         elif isinstance(block, ThinkingBlock):
             if block.content:
                 part = types.Part.from_text(text=block.content)
@@ -345,6 +357,10 @@
         else:
             msg = f"Unsupported content block type: {type(block).__name__}"
             raise ValueError(msg)
+
+        if file_api_name is not None:
+            file_api_names.append(file_api_name)
+
         if part is not None:
             if message.role == MessageRole.MODEL:
                 thought_signatures = message.additional_kwargs.get(
@@ -385,12 +401,12 @@
             )
         return types.Content(
             role=ROLES_TO_GEMINI[message.role], parts=[function_response_part]
-        )
+        ), file_api_names

     return types.Content(
         role=ROLES_TO_GEMINI[message.role],
         parts=parts,
-    )
+    ), file_api_names


 def convert_schema_to_function_declaration(
@@ -427,16 +443,16 @@ class ChatParams(typing.TypedDict):
 async def prepare_chat_params(
     model: str,
     messages: Sequence[ChatMessage],
-    use_file_api: bool = False,
+    file_mode: Literal["inline", "fileapi", "hybrid"] = "hybrid",
     client: Optional[Client] = None,
     **kwargs: Any,
-) -> tuple[Union[types.Content, types.File], ChatParams]:
+) -> tuple[types.Content, ChatParams, list[str]]:
     """
     Prepare common parameters for chat creation.

     Args:
         messages: Sequence of chat messages
-        use_file_api: Whether to use File API or not for large files.
+        file_mode: The mode for file uploading
         client: Google Genai client used for uploading large files.
         **kwargs: Additional keyword arguments
@@ -444,6 +460,7 @@
     tuple containing:
         - next_msg: the next message to send
         - chat_kwargs: processed keyword arguments for chat creation
+        - file_api_names: list of file api names to delete after chat call
     """
     # Extract system message if present
@@ -455,12 +472,14 @@
     # Merge messages with the same role
     merged_messages = merge_neighboring_same_role_messages(messages)

-    initial_history = await asyncio.gather(
+    initial_history_and_names = await asyncio.gather(
         *[
-            chat_message_to_gemini(message, use_file_api, client)
+            chat_message_to_gemini(message, file_mode, client)
             for message in merged_messages
         ]
     )
+    initial_history = [it[0] for it in initial_history_and_names]
+    file_api_names = [name for it in initial_history_and_names for name in it[1]]

     # merge tool messages into a single tool message
     # while maintaining the tool names
@@ -527,7 +546,7 @@ async def prepare_chat_params(
     chat_kwargs["config"] = types.GenerateContentConfig(**config)

-    return next_msg, chat_kwargs
+    return next_msg, chat_kwargs, file_api_names


 def handle_streaming_flexible_model(
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
index 3f07a8fce9..e0ea120af0 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml
@@ -27,7 +27,7 @@ dev = [

 [project]
 name = "llama-index-llms-google-genai"
-version = "0.7.1"
+version = "0.8.0"
 description = "llama-index llms google genai integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py
index 10a6d42550..4090eb7980 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py
@@ -771,7 +771,7 @@ async def test_prepare_chat_params_more_than_2_tool_calls():
         ChatMessage(content="Here is a list of puppies.", role=MessageRole.ASSISTANT),
     ]

-    next_msg, chat_kwargs = await prepare_chat_params(
+    next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
         expected_model_name, test_messages
     )
@@ -831,7 +831,9 @@ async def test_prepare_chat_params_with_system_message():
     ]

     # Execute prepare_chat_params
-    next_msg, chat_kwargs = await prepare_chat_params(model_name, messages)
+    next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
+        model_name, messages
+    )

     # Verify system_prompt is forwarded to system_instruction
     cfg = chat_kwargs["config"]
@@ -1061,7 +1063,7 @@ async def test_cached_content_in_chat_params() -> None:
     messages = [ChatMessage(content="Test message", role=MessageRole.USER)]

     # Prepare chat params with the LLM's generation config
-    next_msg, chat_kwargs = await prepare_chat_params(
+    next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
         llm.model, messages, generation_config=llm._generation_config
     )
@@ -1231,7 +1233,7 @@ async def test_built_in_tool_in_chat_params() -> None:
     )

     # Prepare chat params
-    next_msg, chat_kwargs = await prepare_chat_params(
+    next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
         llm.model, messages, generation_config=llm._generation_config
     )
diff --git a/llama-index-integrations/llms/llama-index-llms-google-genai/uv.lock b/llama-index-integrations/llms/llama-index-llms-google-genai/uv.lock
index 7a1c666f04..a3de07fa9e 100644
--- a/llama-index-integrations/llms/llama-index-llms-google-genai/uv.lock
+++ b/llama-index-integrations/llms/llama-index-llms-google-genai/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.9, <4.0"
 resolution-markers = [
     "python_full_version >= '3.11'",
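
A minimal usage sketch of the new `file_mode` option introduced by this patch (illustrative only, not part of the diff: the model name and image path are placeholders, and the block imports assume the public `llama_index.core.llms` exports):

from llama_index.core.llms import ChatMessage, ImageBlock, MessageRole, TextBlock
from llama_index.llms.google_genai import GoogleGenAI

# "hybrid" (the default) sends blocks under the 20MB inline limit as inline data
# and routes larger ones through the File API, deleting the uploads after the call.
# "inline" never uploads and raises ValueError for blocks of 20MB or more.
# "fileapi" always uploads through the File API and cleans up afterwards.
llm = GoogleGenAI(model="gemini-2.0-flash", file_mode="hybrid")

message = ChatMessage(
    role=MessageRole.USER,
    blocks=[
        TextBlock(text="Describe this image."),
        ImageBlock(path="photo.jpg", image_mimetype="image/jpeg"),
    ],
)
print(llm.chat([message]).message.content)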