Automatic function calling. (#201)
* Starting automatic function calling

* Working on AFC

* Fix typos

* Add tools overrides for generate_content and send_message

* Add initial AFC loop.

* Basic debugging, streaming's probably broken.

* Add error with stream=True

* format

* add pydantic

* fix tests

* replace __init__

* Fix pytype

* Remove property

* format

* working on it

* working on it

* working on it

* format

* Add test for schema gen

* Split test

* Fix type anno & classmethod

* fixup: black

* Fix mutable defaults.

* Fix mutable defaults

---------

Co-authored-by: Mark McDonald <macd@google.com>
MarkDaoust and markmcd authored Feb 22, 2024
1 parent b28f2d1 commit d894807
Showing 6 changed files with 747 additions and 63 deletions.
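For context, a minimal usage sketch of the feature this commit introduces. The `multiply` function and the prompt are hypothetical; `GenerativeModel`, `start_chat`, and `enable_automatic_function_calling` come from the diff below, and passing plain Python callables as `tools` is assumed from the schema-generation tests added in this change:

    import google.generativeai as genai

    def multiply(a: float, b: float) -> float:
        """Returns a * b."""
        return a * b

    model = genai.GenerativeModel(model_name="gemini-pro", tools=[multiply])

    # With AFC enabled, send_message runs the loop added below: each
    # function_call part returned by the model is executed locally and its
    # FunctionResponse is sent back, until the model returns plain text.
    chat = model.start_chat(enable_automatic_function_calling=True)
    response = chat.send_message("What is 234551 times 325552?")
    print(response.text)  # Final answer, computed via multiply().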
186 changes: 158 additions & 28 deletions google/generativeai/generative_models.py
@@ -71,7 +71,7 @@ def __init__(
         model_name: str = "gemini-pro",
         safety_settings: safety_types.SafetySettingOptions | None = None,
         generation_config: generation_types.GenerationConfigType | None = None,
-        tools: content_types.ToolsType = None,
+        tools: content_types.FunctionLibraryType | None = None,
     ):
         if "/" not in model_name:
             model_name = "models/" + model_name
@@ -80,7 +80,7 @@
             safety_settings, harm_category_set="new"
         )
         self._generation_config = generation_types.to_generation_config_dict(generation_config)
-        self._tools = content_types.to_tools(tools)
+        self._tools = content_types.to_function_library(tools)
 
         self._client = None
         self._async_client = None
@@ -94,8 +94,9 @@ def __str__(self):
             f"""\
             genai.GenerativeModel(
                 model_name='{self.model_name}',
-                generation_config={self._generation_config}.
-                safety_settings={self._safety_settings}
+                generation_config={self._generation_config},
+                safety_settings={self._safety_settings},
+                tools={self._tools},
             )"""
         )
 
@@ -107,12 +108,16 @@ def _prepare_request(
         contents: content_types.ContentsType,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None,
     ) -> glm.GenerateContentRequest:
         """Creates a `glm.GenerateContentRequest` from raw inputs."""
         if not contents:
             raise TypeError("contents must not be empty")
 
+        tools_lib = self._get_tools_lib(tools)
+        if tools_lib is not None:
+            tools_lib = tools_lib.to_proto()
+
         contents = content_types.to_contents(contents)
 
         generation_config = generation_types.to_generation_config_dict(generation_config)
@@ -129,19 +134,26 @@
             contents=contents,
             generation_config=merged_gc,
             safety_settings=merged_ss,
-            tools=self._tools,
-            **kwargs,
+            tools=tools_lib,
         )
 
+    def _get_tools_lib(
+        self, tools: content_types.FunctionLibraryType
+    ) -> content_types.FunctionLibrary | None:
+        if tools is None:
+            return self._tools
+        else:
+            return content_types.to_function_library(tools)
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
     ) -> generation_types.GenerateContentResponse:
         """A multipurpose function to generate responses from the model.
@@ -201,7 +213,7 @@ def generate_content(
             contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
         )
         if self._client is None:
             self._client = client.get_default_generative_client()
@@ -230,15 +242,15 @@ async def generate_content_async(
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `GenerativeModel.generate_content`."""
         request = self._prepare_request(
             contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
         )
         if self._async_client is None:
             self._async_client = client.get_default_generative_async_client()
@@ -299,6 +311,7 @@ def start_chat(
         self,
         *,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ) -> ChatSession:
         """Returns a `genai.ChatSession` attached to this model.
@@ -314,6 +327,7 @@
         return ChatSession(
             model=self,
             history=history,
+            enable_automatic_function_calling=enable_automatic_function_calling,
         )
 
 
@@ -341,11 +355,13 @@ def __init__(
         self,
         model: GenerativeModel,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ):
         self.model: GenerativeModel = model
         self._history: list[glm.Content] = content_types.to_contents(history)
         self._last_sent: glm.Content | None = None
         self._last_received: generation_types.BaseGenerateContentResponse | None = None
+        self.enable_automatic_function_calling = enable_automatic_function_calling
 
     def send_message(
         self,
@@ -354,7 +370,7 @@ def send_message(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.GenerateContentResponse:
         """Sends the conversation history with the added message and returns the model's response.
@@ -387,23 +403,52 @@ def send_message(
             safety_settings: Overrides for the model's safety settings.
             stream: If True, yield response chunks as they are generated.
         """
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)
 
         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
+
         response = self.model.generate_content(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
        )
 
+        self._check_response(response=response, stream=stream)
+
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = self._handle_afc(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )
+
+        self._last_sent = content
+        self._last_received = response
+
+        return response
+
+    def _check_response(self, *, response, stream):
         if response.prompt_feedback.block_reason:
             raise generation_types.BlockedPromptException(response.prompt_feedback)
 
@@ -415,10 +460,49 @@ def send_message(
         ):
             raise generation_types.StopCandidateException(response.candidates[0])
 
-        self._last_sent = content
-        self._last_received = response
+    def _get_function_calls(self, response) -> list[glm.FunctionCall]:
+        candidates = response.candidates
+        if len(candidates) != 1:
+            raise ValueError(
+                f"Automatic function calling only works with 1 candidate, got: {len(candidates)}"
+            )
+        parts = candidates[0].content.parts
+        function_calls = [part.function_call for part in parts if part and "function_call" in part]
+        return function_calls
+
+    def _handle_afc(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration"
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)
 
-        return response
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = self.model.generate_content(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response
 
     async def send_message_async(
         self,
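The AFC loop above leans on two behaviors of `content_types.FunctionLibrary` that this diff does not show: indexing with a `glm.FunctionCall` returns the matching declaration (which may or may not be callable), and calling the library with a `glm.FunctionCall` runs the matching Python function and returns the result wrapped in a `glm.Part`. A rough sketch of that assumed contract, not the actual `content_types` implementation:

    class FunctionLibrary:  # hypothetical sketch of the contract _handle_afc assumes
        def __getitem__(self, fc: glm.FunctionCall):
            # Declaration registered under the called function's name; a plain
            # (non-callable) FunctionDeclaration here makes the AFC loop break.
            return self._declarations[fc.name]

        def __call__(self, fc: glm.FunctionCall) -> glm.Part | None:
            declaration = self[fc]
            if not callable(declaration):
                return None
            result = declaration.function(**fc.args)
            return glm.Part(
                function_response=glm.FunctionResponse(
                    name=fc.name, response={"result": result}
                )
            )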
Expand All @@ -427,42 +511,88 @@ async def send_message_async(
generation_config: generation_types.GenerationConfigType = None,
safety_settings: safety_types.SafetySettingOptions = None,
stream: bool = False,
**kwargs,
tools: content_types.FunctionLibraryType | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `ChatSession.send_message`."""
if self.enable_automatic_function_calling and stream:
raise NotImplementedError(
"The `google.generativeai` SDK does not yet support `stream=True` with "
"`enable_automatic_function_calling=True`"
)

tools_lib = self.model._get_tools_lib(tools)

content = content_types.to_content(content)

if not content.role:
content.role = self._USER_ROLE

history = self.history[:]
history.append(content)

generation_config = generation_types.to_generation_config_dict(generation_config)
if generation_config.get("candidate_count", 1) > 1:
raise ValueError("Can't chat with `candidate_count > 1`")
response = await self.model.generate_content_async(

response = await self.model.generate_content(
contents=history,
generation_config=generation_config,
safety_settings=safety_settings,
stream=stream,
**kwargs,
tools=tools_lib,
)

if response.prompt_feedback.block_reason:
raise generation_types.BlockedPromptException(response.prompt_feedback)
self._check_response(response=response, stream=stream)

if not stream:
if response.candidates[0].finish_reason not in (
glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
glm.Candidate.FinishReason.STOP,
glm.Candidate.FinishReason.MAX_TOKENS,
):
raise generation_types.StopCandidateException(response.candidates[0])
if self.enable_automatic_function_calling and tools_lib is not None:
self.history, content, response = await self._handle_afc_async(
response=response,
history=history,
generation_config=generation_config,
safety_settings=safety_settings,
stream=stream,
tools_lib=tools_lib,
)

self._last_sent = content
self._last_received = response

return response

async def _handle_afc_async(
self, *, response, history, generation_config, safety_settings, stream, tools_lib
) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:

while function_calls := self._get_function_calls(response):
if not all(callable(tools_lib[fc]) for fc in function_calls):
break
history.append(response.candidates[0].content)

function_response_parts: list[glm.Part] = []
for fc in function_calls:
fr = tools_lib(fc)
assert fr is not None, (
"This should never happen, it should only return None if the declaration"
"is not callable, and that's guarded against above."
)
function_response_parts.append(fr)

send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
history.append(send)

response = await self.model.generate_content_async(
contents=history,
generation_config=generation_config,
safety_settings=safety_settings,
stream=stream,
tools=tools_lib,
)

self._check_response(response=response, stream=stream)

*history, content = history
return history, content, response

def __copy__(self):
return ChatSession(
model=self.model,
Expand Down
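Net effect on a chat: each pass through the AFC loop appends two turns to the history before the final reply is returned. Roughly, for the hypothetical `multiply` example above (values illustrative):

    # chat.send_message("What is 234551 times 325552?") with AFC enabled:
    #   history[0]  user   "What is 234551 times 325552?"
    #   history[1]  model  Part(function_call=FunctionCall(name="multiply",
    #                          args={"a": 234551, "b": 325552}))
    #   history[2]  user   Part(function_response=FunctionResponse(
    #                          name="multiply", response={"result": 76358547152.0}))
    # The returned response holds the model's final text turn, and
    # `*history, content = history` re-splits the last user turn so the
    # ChatSession records it as `_last_sent`.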