Skip to content

Commit 5c99bb6

Browse files
authored
vertexai[patch], genai[patch]: use tool name explicitly in with struc… (#341)
1 parent 42b92c9 commit 5c99bb6

File tree

7 files changed

+27
-15
lines changed

7 files changed

+27
-15
lines changed

libs/genai/langchain_google_genai/_function_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,17 +304,17 @@ def _tool_choice_to_tool_config(
304304
) -> _ToolConfigDict:
305305
allowed_function_names: Optional[List[str]] = None
306306
if tool_choice is True or tool_choice == "any":
307-
mode = "any"
307+
mode = "ANY"
308308
allowed_function_names = all_names
309309
elif tool_choice == "auto":
310-
mode = "auto"
310+
mode = "AUTO"
311311
elif tool_choice == "none":
312-
mode = "none"
312+
mode = "NONE"
313313
elif isinstance(tool_choice, str):
314-
mode = "any"
314+
mode = "ANY"
315315
allowed_function_names = [tool_choice]
316316
elif isinstance(tool_choice, list):
317-
mode = "any"
317+
mode = "ANY"
318318
allowed_function_names = tool_choice
319319
elif isinstance(tool_choice, dict):
320320
if "mode" in tool_choice:
@@ -334,7 +334,7 @@ def _tool_choice_to_tool_config(
334334
raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
335335
return _ToolConfigDict(
336336
function_calling_config={
337-
"mode": mode,
337+
"mode": mode.upper(),
338338
"allowed_function_names": allowed_function_names,
339339
}
340340
)

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ def bind_tools(
11891189

11901190
@property
11911191
def _supports_tool_choice(self) -> bool:
1192-
return "gemini-1.5" in self.model
1192+
return "gemini-1.5-pro" in self.model
11931193

11941194

11951195
def _get_tool_name(

libs/genai/tests/integration_tests/test_chat_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,20 @@ def my_tool(name: str, age: int) -> None:
430430
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
431431

432432

433-
def test_chat_google_genai_function_calling_with_structured_output() -> None:
433+
# Test with model that supports tool choice (gemini 1.5) and one that doesn't
434+
# (gemini 1).
435+
@pytest.mark.parametrize("model_name", [_MODEL, "models/gemini-1.5-pro-001"])
436+
def test_chat_google_genai_function_calling_with_structured_output(
437+
model_name: str,
438+
) -> None:
434439
class MyModel(BaseModel):
435440
name: str
436441
age: int
437442

438443
safety = {
439444
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
440445
}
441-
llm = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety)
446+
llm = ChatGoogleGenerativeAI(model=model_name, safety_settings=safety)
442447
model = llm.with_structured_output(MyModel)
443448
message = HumanMessage(content="My name is Erick and I am 27 years old")
444449

libs/genai/tests/unit_tests/test_function_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_format_dict_to_genai_function() -> None:
101101
def test__tool_choice_to_tool_config(choice: Any) -> None:
102102
expected = _ToolConfigDict(
103103
function_calling_config={
104-
"mode": "any",
104+
"mode": "ANY",
105105
"allowed_function_names": ["foo"],
106106
},
107107
)

libs/vertexai/langchain_google_vertexai/chat_models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
_ToolChoiceType,
116116
_ToolsType,
117117
_format_to_gapic_tool,
118+
_ToolType,
118119
)
119120

120121
logger = logging.getLogger(__name__)
@@ -1601,7 +1602,8 @@ class AnswerWithJustification(BaseModel):
16011602
parser = JsonOutputToolsParser(first_tool_only=True) | RunnableGenerator(
16021603
_yield_args
16031604
)
1604-
llm = self.bind_tools([schema], tool_choice=self._is_gemini_advanced)
1605+
tool_choice = _get_tool_name(schema) if self._is_gemini_advanced else None
1606+
llm = self.bind_tools([schema], tool_choice=tool_choice)
16051607
if include_raw:
16061608
parser_with_fallback = RunnablePassthrough.assign(
16071609
parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
@@ -1751,3 +1753,8 @@ def _get_usage_metadata_non_gemini(raw_metadata: dict) -> Optional[UsageMetadata
17511753
output_tokens=output_tokens,
17521754
total_tokens=input_tokens + output_tokens,
17531755
)
1756+
1757+
1758+
def _get_tool_name(tool: _ToolType) -> str:
1759+
vertexai_tool = _format_to_gapic_tool([tool])
1760+
return [f.name for f in vertexai_tool.function_declarations][0]

libs/vertexai/langchain_google_vertexai/functions_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,8 @@ class _ToolDictLike(TypedDict):
5353
retrieval: Optional[_RetrievalLike]
5454

5555

56-
_ToolsType = Sequence[
57-
Union[gapic.Tool, vertexai.Tool, _ToolDictLike, _FunctionDeclarationLike]
58-
]
56+
_ToolType = Union[gapic.Tool, vertexai.Tool, _ToolDictLike, _FunctionDeclarationLike]
57+
_ToolsType = Sequence[_ToolType]
5958

6059
_ALLOWED_SCHEMA_FIELDS = []
6160
_ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields])

libs/vertexai/langchain_google_vertexai/model_garden.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def with_structured_output(
350350
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
351351
"""Model wrapper that returns outputs formatted to match the given schema."""
352352

353-
llm = self.bind_tools([schema], tool_choice="any")
353+
tool_name = convert_to_anthropic_tool(schema)["name"]
354+
llm = self.bind_tools([schema], tool_choice=tool_name)
354355
if isinstance(schema, type) and issubclass(schema, BaseModel):
355356
output_parser = ToolsOutputParser(
356357
first_tool_only=True, pydantic_schemas=[schema]

0 commit comments

Comments
 (0)