From 3e19a3fd360403ad4b5ab6ff4a5ee3e0b3092a43 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 28 Aug 2024 16:31:14 -0300 Subject: [PATCH] feat: allow passing a Component to the set method (#3597) * refactor: Add _find_matching_output_method to Component class * feat: allow components to be passed in set method * fix: Add test for graph set method with valid component * set value variable to the output callable * refactor: Update test_component.py to use set_component method This commit refactors the test_component.py file in the custom_component directory. The test_set_invalid_input() function has been renamed to test_set_component() to better reflect its purpose. Additionally, the test_set_component() function now sets the agent parameter using the set_component() method instead of raising a ValueError. This change improves the readability and maintainability of the code. * refactor: Fix formatting issue in _build_error_string_from_matching_pairs The _build_error_string_from_matching_pairs method in the Component class had a formatting issue when input types were empty. This commit fixes the issue by adding a check for empty input types and providing an empty list as a fallback. This improves the accuracy and readability of the error string generated by the method. * fix(component.py): add validation to ensure output method is a string to prevent potential runtime errors --- .../custom/custom_component/component.py | 31 +++++++++++++++++++ .../custom/custom_component/test_component.py | 10 +++--- .../tests/unit/graph/graph/test_base.py | 15 +++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index be88d2d251ed..da14c673c627 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -310,9 +310,40 @@ def _method_is_valid_output(self, method: Callable): ) return method_is_output + def _build_error_string_from_matching_pairs(self, matching_pairs: list[tuple[Output, Input]]): + text = "" + for output, input_ in matching_pairs: + text += f"{output.name}[{','.join(output.types)}]->{input_.name}[{','.join(input_.input_types or [])}]\n" + return text + + def _find_matching_output_method(self, value: "Component"): + # get all outputs of the value component + outputs = value.outputs + # check if the any of the types in the output.types matches ONLY one input in the current component + matching_pairs = [] + for output in outputs: + for input_ in self.inputs: + for output_type in output.types: + if input_.input_types and output_type in input_.input_types: + matching_pairs.append((output, input_)) + if len(matching_pairs) > 1: + matching_pairs_str = self._build_error_string_from_matching_pairs(matching_pairs) + raise ValueError( + f"There are multiple outputs from {value.__class__.__name__} that can connect to inputs in {self.__class__.__name__}: {matching_pairs_str}" + ) + output, input_ = matching_pairs[0] + if not isinstance(output.method, str): + raise ValueError(f"Method {output.method} is not a valid output of {value.__class__.__name__}") + return getattr(value, output.method) + def _process_connection_or_parameter(self, key, value): _input = self._get_or_create_input(key) # We need to check if callable AND if it is a method from a class that inherits from Component + if isinstance(value, Component): + # We need to find the Output that can connect to an input of the current component + # if there's more than one output that matches, we need to raise an error + # because we don't know which one to connect to + value = self._find_matching_output_method(value) if callable(value) and self._inherits_from_component(value): try: self._method_is_valid_output(value) diff --git a/src/backend/tests/unit/custom/custom_component/test_component.py b/src/backend/tests/unit/custom/custom_component/test_component.py index b52d993e0e27..d5d1821ece99 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component.py +++ b/src/backend/tests/unit/custom/custom_component/test_component.py @@ -18,11 +18,9 @@ def test_set_invalid_output(): chatoutput.set(input_value=chatinput.build_config) -def test_set_invalid_input(): +def test_set_component(): crewai_agent = CrewAIAgentComponent() task = SequentialTaskComponent() - with pytest.raises( - ValueError, - match="You set CrewAI Agent as value for `agent`. You should pass one of the following: 'build_output'", - ): - task.set(agent=crewai_agent) + task.set(agent=crewai_agent) + assert task._edges[0]["source"] == crewai_agent._id + assert crewai_agent in task._components diff --git a/src/backend/tests/unit/graph/graph/test_base.py b/src/backend/tests/unit/graph/graph/test_base.py index 59908d9dca43..136dbdc56108 100644 --- a/src/backend/tests/unit/graph/graph/test_base.py +++ b/src/backend/tests/unit/graph/graph/test_base.py @@ -2,9 +2,11 @@ import pytest +from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent from langflow.components.inputs.ChatInput import ChatInput from langflow.components.outputs.ChatOutput import ChatOutput from langflow.components.outputs.TextOutput import TextOutputComponent +from langflow.components.tools.YfinanceTool import YfinanceToolComponent from langflow.graph.graph.base import Graph from langflow.graph.graph.constants import Finish @@ -139,3 +141,16 @@ def test_graph_functional_start_end(): assert len(results) == len(ids) + 1 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) assert results[-1] == Finish() + + +def test_graph_set_with_invalid_component(): + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + with pytest.raises(ValueError, match="There are multiple outputs"): + chat_output.set(sender_name=chat_input) + + +def test_graph_set_with_valid_component(): + tool = YfinanceToolComponent() + tool_calling_agent = ToolCallingAgentComponent() + tool_calling_agent.set(tools=[tool])