@@ -17,6 +17,7 @@
Type,
Union,
Callable,
Literal,
)


@@ -141,9 +142,9 @@ class GoogleGenAI(FunctionCallingLLM):
default=None,
description="Google GenAI tool to use for the model to augment responses.",
)
use_file_api: bool = Field(
default=True,
description="Whether or not to use the FileAPI for large files (>20MB).",
file_mode: Literal["inline", "fileapi", "hybrid"] = Field(
default="hybrid",
description="Whether to use inline-only, FileAPI-only or both for handling files.",
)

_max_tokens: int = PrivateAttr()
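For context, here is a minimal usage sketch of the new field. The model name is illustrative, and the exact behaviour of "hybrid" is assumed from the old >20MB File API threshold rather than stated in this hunk:

```python
from llama_index.llms.google_genai import GoogleGenAI

# file_mode replaces the old use_file_api flag:
#   "inline"  - always send file bytes inline with the request
#   "fileapi" - always upload files via the File API
#   "hybrid"  - the default; presumably inline for small files and the
#               File API for large ones (the old flag targeted files >20MB)
llm = GoogleGenAI(model="gemini-2.0-flash", file_mode="fileapi")
```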
@@ -167,7 +168,7 @@ def __init__(
is_function_calling_model: bool = True,
cached_content: Optional[str] = None,
built_in_tool: Optional[types.Tool] = None,
use_file_api: bool = True,
file_mode: Literal["inline", "fileapi", "hybrid"] = "hybrid",
**kwargs: Any,
):
# API keys are optional. The API can be authorised via OAuth (detected
@@ -216,7 +217,7 @@ def __init__(
max_retries=max_retries,
cached_content=cached_content,
built_in_tool=built_in_tool,
use_file_api=use_file_api,
file_mode=file_mode,
**kwargs,
)

@@ -309,20 +310,17 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
next_msg, chat_kwargs = asyncio.run(
next_msg, chat_kwargs, file_api_names = asyncio.run(
prepare_chat_params(
self.model, messages, self.use_file_api, self._client, **params
self.model, messages, self.file_mode, self._client, **params
)
)
chat = self._client.chats.create(**chat_kwargs)
response = chat.send_message(
next_msg.parts if isinstance(next_msg, types.Content) else next_msg
)

if self.use_file_api:
asyncio.run(
delete_uploaded_files([*chat_kwargs["history"], next_msg], self._client)
)
asyncio.run(delete_uploaded_files(file_api_names, self._client))

return chat_from_gemini_response(response)

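The call sites above assume the utils helpers changed shape as well: prepare_chat_params now returns a third value listing any File API uploads, and delete_uploaded_files takes that list of names instead of scanning message contents. A rough sketch of the assumed signatures (hypothetical, not the PR's actual utils code):

```python
from typing import Any, Dict, List, Literal, Sequence, Tuple

from google import genai
from llama_index.core.llms import ChatMessage

FileMode = Literal["inline", "fileapi", "hybrid"]


async def prepare_chat_params(
    model: str,
    messages: Sequence[ChatMessage],
    file_mode: FileMode,
    client: genai.Client,
    **kwargs: Any,
) -> Tuple[Any, Dict[str, Any], List[str]]:
    """Assumed shape: returns (next_msg, chat_kwargs, file_api_names), where
    file_api_names lists the files uploaded via the File API so callers can
    delete them once the request completes."""
    ...


async def delete_uploaded_files(
    file_api_names: List[str], client: genai.Client
) -> None:
    """Assumed shape: delete each previously uploaded file by name."""
    for name in file_api_names:
        # google-genai's async client exposes file deletion here; shown as
        # the likely cleanup call, not confirmed from this diff.
        await client.aio.files.delete(name=name)
```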
@@ -333,18 +331,15 @@ async def _achat(self, messages: Sequence[ChatMessage], **kwargs: Any):
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
next_msg, chat_kwargs = await prepare_chat_params(
self.model, messages, self.use_file_api, self._client, **params
next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
self.model, messages, self.file_mode, self._client, **params
)
chat = self._client.aio.chats.create(**chat_kwargs)
response = await chat.send_message(
next_msg.parts if isinstance(next_msg, types.Content) else next_msg
)

if self.use_file_api:
await delete_uploaded_files(
[*chat_kwargs["history"], next_msg], self._client
)
await delete_uploaded_files(file_api_names, self._client)

return chat_from_gemini_response(response)

@@ -366,9 +361,9 @@ def _stream_chat(
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
next_msg, chat_kwargs = asyncio.run(
next_msg, chat_kwargs, file_api_names = asyncio.run(
prepare_chat_params(
self.model, messages, self.use_file_api, self._client, **params
self.model, messages, self.file_mode, self._client, **params
)
)
chat = self._client.chats.create(**chat_kwargs)
@@ -402,12 +397,8 @@ def gen() -> ChatResponseGen:
llama_resp.message.blocks = [ThinkingBlock(content=thoughts)]
yield llama_resp

if self.use_file_api:
asyncio.run(
delete_uploaded_files(
[*chat_kwargs["history"], next_msg], self._client
)
)
if self.file_mode in ("fileapi", "hybrid"):
asyncio.run(delete_uploaded_files(file_api_names, self._client))

return gen()

@@ -425,8 +416,8 @@ async def _astream_chat(
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
next_msg, chat_kwargs = await prepare_chat_params(
self.model, messages, self.use_file_api, self._client, **params
next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
self.model, messages, self.file_mode, self._client, **params
)
chat = self._client.aio.chats.create(**chat_kwargs)

@@ -463,10 +454,7 @@ async def gen() -> ChatResponseAsyncGen:
]
yield llama_resp

if self.use_file_api:
await delete_uploaded_files(
[*chat_kwargs["history"], next_msg], self._client
)
await delete_uploaded_files(file_api_names, self._client)

return gen()

@@ -589,12 +577,13 @@ def structured_predict_without_function_calling(
llm_kwargs = llm_kwargs or {}

messages = prompt.format_messages(**prompt_args)
contents = [
asyncio.run(
chat_message_to_gemini(message, self.use_file_api, self._client)
)
contents_and_names = [
asyncio.run(chat_message_to_gemini(message, self.file_mode, self._client))
for message in messages
]
contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

response = self._client.models.generate_content(
model=self.model,
contents=contents,
@@ -609,8 +598,7 @@ def structured_predict_without_function_calling(
},
)

if self.use_file_api:
asyncio.run(delete_uploaded_files(contents, self._client))
asyncio.run(delete_uploaded_files(file_api_names, self._client))
@logan-markewich (Collaborator) commented on Nov 7, 2025:

Generally, programs can only have a single asyncio.run() call; beyond that, you need nest_asyncio or similar.

Not really a fan of having any of these inside the sync endpoints in the first place. I wonder if we can clean this up? I actually hit this issue the other day.
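To make the failure mode concrete: asyncio.run() raises a RuntimeError when it is called while an event loop is already running in the same thread (for example in a notebook, or from a caller that is itself async). A minimal sketch of the kind of guard the comment alludes to follows; the helper name run_coro is hypothetical and not part of this PR:

```python
import asyncio
import concurrent.futures
from typing import Any, Coroutine


def run_coro(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run a coroutine from sync code even if an event loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() is safe.
        return asyncio.run(coro)
    # A loop is already running; asyncio.run() would raise, so hand the
    # coroutine to a fresh loop in a worker thread and block for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```

llama-index's core async utilities ship a similar guard (asyncio_run), assuming the targeted core version provides it; using that, or moving the file handling out of the sync paths entirely, would address the concern.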


if isinstance(response.parsed, BaseModel):
return response.parsed
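The contents_and_names unpacking above, and in the structured-predict hunks that follow, assumes chat_message_to_gemini now returns a pair per message: the converted content plus the names of any File API uploads it performed. A sketch of the assumed signature (hypothetical, not the PR's utils code):

```python
from typing import List, Literal, Tuple

from google import genai
from google.genai import types
from llama_index.core.llms import ChatMessage


async def chat_message_to_gemini(
    message: ChatMessage,
    file_mode: Literal["inline", "fileapi", "hybrid"],
    client: genai.Client,
) -> Tuple[types.Content, List[str]]:
    """Assumed shape: convert one ChatMessage to a genai Content and report
    the names of any files uploaded via the File API along the way."""
    ...
```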
@@ -639,20 +627,22 @@ def structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
contents = [
contents_and_names = [
asyncio.run(
chat_message_to_gemini(message, self.use_file_api, self._client)
chat_message_to_gemini(message, self.file_mode, self._client)
)
for message in messages
]
contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

response = self._client.models.generate_content(
model=self.model,
contents=contents,
config=generation_config,
)

if self.use_file_api:
asyncio.run(delete_uploaded_files(contents, self._client))
asyncio.run(delete_uploaded_files(file_api_names, self._client))

if isinstance(response.parsed, BaseModel):
return response.parsed
@@ -686,20 +676,22 @@ async def astructured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
contents = await asyncio.gather(
contents_and_names = await asyncio.gather(
*[
chat_message_to_gemini(message, self.use_file_api, self._client)
chat_message_to_gemini(message, self.file_mode, self._client)
for message in messages
]
)
contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

response = await self._client.aio.models.generate_content(
model=self.model,
contents=contents,
config=generation_config,
)

if self.use_file_api:
await delete_uploaded_files(contents, self._client)
await delete_uploaded_files(file_api_names, self._client)

if isinstance(response.parsed, BaseModel):
return response.parsed
@@ -733,12 +725,14 @@ def stream_structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
contents = [
contents_and_names = [
asyncio.run(
chat_message_to_gemini(message, self.use_file_api, self._client)
chat_message_to_gemini(message, self.file_mode, self._client)
)
for message in messages
]
contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
flexible_model = create_flexible_model(output_cls)
@@ -762,8 +756,7 @@ def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
if streaming_model:
yield streaming_model

if self.use_file_api:
asyncio.run(delete_uploaded_files(contents, self._client))
asyncio.run(delete_uploaded_files(file_api_names, self._client))

return gen()
else:
@@ -793,12 +786,14 @@ async def astream_structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
contents = await asyncio.gather(
contents_and_names = await asyncio.gather(
*[
chat_message_to_gemini(message, self.use_file_api, self._client)
chat_message_to_gemini(message, self.file_mode, self._client)
for message in messages
]
)
contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
flexible_model = create_flexible_model(output_cls)
@@ -822,8 +817,7 @@ async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
if streaming_model:
yield streaming_model

if self.use_file_api:
await delete_uploaded_files(contents, self._client)
await delete_uploaded_files(file_api_names, self._client)

return gen()
else: