From 4a0a2d37ba369ddd3440ce596dd901299694fbd8 Mon Sep 17 00:00:00 2001 From: Shahar Yair Date: Fri, 27 Dec 2024 09:37:56 +0200 Subject: [PATCH] genai: Fixed nested pydantic structures recursion (#658) --- .../langchain_google_genai/_function_utils.py | 24 +- .../tests/unit_tests/test_function_utils.py | 218 +++++++++++++++++- 2 files changed, 237 insertions(+), 5 deletions(-) diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index d7227268..829f9549 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -315,10 +315,26 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): properties_item["items"] = _get_items_from_schema_any(v.get("items")) - if properties_item.get("type_") == glm.Type.OBJECT and v.get("properties"): - properties_item["properties"] = _get_properties_from_schema_any( - v.get("properties") - ) + if properties_item.get("type_") == glm.Type.OBJECT: + if ( + v.get("anyOf") + and isinstance(v["anyOf"], list) + and isinstance(v["anyOf"][0], dict) + ): + v = v["anyOf"][0] + v_properties = v.get("properties") + if v_properties: + properties_item["properties"] = _get_properties_from_schema_any( + v_properties + ) + if isinstance(v_properties, dict): + properties_item["required"] = [ + k for k, v in v_properties.items() if "default" not in v + ] + else: + # Providing dummy type for object without properties + properties_item["type_"] = glm.Type.STRING + if k == "title" and "description" not in properties_item: properties_item["description"] = k + " is " + str(v) diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 536c2a99..482248e9 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Generator, Optional, Union +from typing import Any, Generator, List, Optional, Union from unittest.mock import MagicMock, patch import google.ai.generativelanguage as glm @@ -17,6 +17,222 @@ ) +def test_tool_with_anyof_nullable_param() -> None: + """ + Example test that checks a string parameter marked as Optional, + verifying it's recognized as a 'string' & 'nullable'. + """ + + @tool(parse_docstring=True) + def possibly_none( + a: Optional[str] = None, + ) -> str: + """ + A test function whose argument can be a string or None. + + Args: + a: Possibly none. + """ + return "value" + + # Convert to OpenAI, then to GenAI, then to dict + oai_tool = convert_to_openai_tool(possibly_none) + genai_tool = convert_to_genai_function_declarations([oai_tool]) + genai_tool_dict = tool_to_dict(genai_tool) + assert isinstance(genai_tool_dict, dict), "Expected a dict." + + function_declarations = genai_tool_dict.get("function_declarations") + assert isinstance( + function_declarations, + list, + ), "Expected a list." + + fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected a dict." + + parameters = fn_decl.get("parameters") + assert isinstance(parameters, dict), "Expected a dict." + + properties = parameters.get("properties") + assert isinstance(properties, dict), "Expected a dict." + + a_property = properties.get("a") + assert isinstance(a_property, dict), "Expected a dict." + + assert a_property.get("type_") == glm.Type.STRING, "Expected 'a' to be STRING." + assert a_property.get("nullable") is True, "Expected 'a' to be marked as nullable." + + +def test_tool_with_array_anyof_nullable_param() -> None: + """ + Checks an array parameter marked as Optional, verifying it's recognized + as an 'array' & 'nullable', and that the items are correctly typed. + """ + + @tool(parse_docstring=True) + def possibly_none_list( + items: Optional[List[str]] = None, + ) -> str: + """ + A test function whose argument can be a list of strings or None. + + Args: + items: Possibly a list of strings or None. + """ + return "value" + + # Convert to OpenAI tool + oai_tool = convert_to_openai_tool(possibly_none_list) + + # Manually assign the 'items' type in the parameters + oai_tool["function"]["parameters"]["properties"]["items"]["items"] = { + "type": "string" + } + + # Convert to GenAI, then to dict + genai_tool = convert_to_genai_function_declarations([oai_tool]) + genai_tool_dict = tool_to_dict(genai_tool) + assert isinstance(genai_tool_dict, dict), "Expected a dict." + + function_declarations = genai_tool_dict.get("function_declarations") + assert isinstance(function_declarations, list), "Expected a list." + + fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected a dict." + + parameters = fn_decl.get("parameters") + assert isinstance(parameters, dict), "Expected a dict." + + properties = parameters.get("properties") + assert isinstance(properties, dict), "Expected a dict." + + items_property = properties.get("items") + assert isinstance(items_property, dict), "Expected a dict." + + # Assertions + assert ( + items_property.get("type_") == glm.Type.ARRAY + ), "Expected 'items' to be ARRAY." + assert items_property.get("nullable"), "Expected 'items' to be marked as nullable." + # Check that the array items are recognized as strings + + items = items_property.get("items") + assert isinstance(items, dict), "Expected 'items' to be a dict." + + assert items.get("type_") == glm.Type.STRING, "Expected array items to be STRING." + + +def test_tool_with_nested_object_anyof_nullable_param() -> None: + """ + Checks an object parameter (dict) marked as Optional, verifying it's recognized + as an 'object' but defaults to string if there are no real properties, + and that it is 'nullable'. + """ + + @tool(parse_docstring=True) + def possibly_none_dict( + data: Optional[dict] = None, + ) -> str: + """ + A test function whose argument can be an object (dict) or None. + + Args: + data: Possibly a dict or None. + """ + return "value" + + # Convert to OpenAI, then to GenAI, then to dict + oai_tool = convert_to_openai_tool(possibly_none_dict) + genai_tool = convert_to_genai_function_declarations([oai_tool]) + genai_tool_dict = tool_to_dict(genai_tool) + assert isinstance(genai_tool_dict, dict), "Expected a dict." + + function_declarations = genai_tool_dict.get("function_declarations") + assert isinstance(function_declarations, list), "Expected a list." + + fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected a dict." + + parameters = fn_decl.get("parameters") + assert isinstance(parameters, dict), "Expected a dict." + + properties = parameters.get("properties") + assert isinstance(properties, dict), "Expected a dict." + + data_property = properties.get("data") + assert isinstance(data_property, dict), "Expected a dict." + + assert data_property.get("type_") in [ + glm.Type.OBJECT, + glm.Type.STRING, + ], "Expected 'data' to be recognized as an OBJECT or fallback to STRING." + assert ( + data_property.get("nullable") is True + ), "Expected 'data' to be marked as nullable." + + +def test_tool_with_enum_anyof_nullable_param() -> None: + """ + Checks a parameter with an enum, marked as Optional, verifying it's recognized + as 'string' & 'nullable', and that the 'enum' field is captured. + """ + + @tool(parse_docstring=True) + def possibly_none_enum( + status: Optional[str] = None, + ) -> str: + """ + A test function whose argument can be an enum string or None. + + Args: + status: Possibly one of ("active", "inactive", "pending") or None. + """ + return "value" + + # Convert to OpenAI tool + oai_tool = convert_to_openai_tool(possibly_none_enum) + + # Manually override the 'enum' for the 'status' property in the parameters + oai_tool["function"]["parameters"]["properties"]["status"]["enum"] = [ + "active", + "inactive", + "pending", + ] + + # Convert to GenAI, then to dict + genai_tool = convert_to_genai_function_declarations([oai_tool]) + genai_tool_dict = tool_to_dict(genai_tool) + assert isinstance(genai_tool_dict, dict), "Expected a dict." + + function_declarations = genai_tool_dict.get("function_declarations") + assert isinstance(function_declarations, list), "Expected a list." + + fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected a dict." + + parameters = fn_decl.get("parameters") + assert isinstance(parameters, dict), "Expected a dict." + + properties = parameters.get("properties") + assert isinstance(properties, dict), "Expected a dict." + + status_property = properties.get("status") + assert isinstance(status_property, dict), "Expected a dict." + + # Assertions + assert ( + status_property.get("type_") == glm.Type.STRING + ), "Expected 'status' to be STRING." + assert ( + status_property.get("nullable") is True + ), "Expected 'status' to be marked as nullable." + assert status_property.get("enum") == [ + "active", + "inactive", + "pending", + ], "Expected 'status' to have enum values." + + def test_format_tool_to_genai_function() -> None: @tool def get_datetime() -> str: