Skip to content

Commit

Permalink
chore: upgrade type annotations in graph and custom modules (#3591)
Browse files Browse the repository at this point in the history
* refactor(tests): update import statements in conftest.py to use collections.abc module for better compatibility and maintainability

* run pyupgrade on graph module

* [autofix.ci] apply automated fixes

* refactor(attributes.py): change import statement from 'typing.Callable' to 'collections.abc.Callable' for better compatibility
refactor(code_parser.py): update type annotations to use '|' for Union types for better readability
refactor(base_component.py): update type annotations to use '|' for Union types for better readability
refactor(component.py): change import statement from 'typing.Callable' to 'collections.abc.Callable' for better compatibility
refactor(component.py): update type annotations to use '|' for Union types for better readability
refactor(component.py): update type annotations to use 'list' instead of 'List' for consistency

refactor(custom_component.py): update typing imports and annotations for better readability and consistency

refactor(utils.py): change type hint 'List' to 'list' for consistency and compatibility with Python 3.9

* run make format

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
ogabrielluiz and autofix-ci[bot] authored Aug 28, 2024
1 parent 5d81e8d commit 9c8dd8e
Show file tree
Hide file tree
Showing 22 changed files with 296 additions and 292 deletions.
2 changes: 1 addition & 1 deletion src/backend/base/langflow/custom/attributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable
from collections.abc import Callable

import emoji

Expand Down
28 changes: 14 additions & 14 deletions src/backend/base/langflow/custom/code_parser/code_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
import inspect
import traceback
from typing import Any, Dict, List, Type, Union
from typing import Any

from cachetools import TTLCache, keys
from fastapi import HTTPException
Expand Down Expand Up @@ -29,7 +29,7 @@ def find_class_ast_node(class_obj):
return None, []

# Read the source code from the file
with open(source_file, "r") as file:
with open(source_file) as file:
source_code = file.read()

# Parse the source code into an AST
Expand Down Expand Up @@ -59,7 +59,7 @@ class CodeParser:
A parser for Python source code, extracting code details.
"""

def __init__(self, code: Union[str, Type]) -> None:
def __init__(self, code: str | type) -> None:
"""
Initializes the parser with the provided code.
"""
Expand All @@ -70,7 +70,7 @@ def __init__(self, code: Union[str, Type]) -> None:
# If the code is a class, get its source code
code = inspect.getsource(code)
self.code = code
self.data: Dict[str, Any] = {
self.data: dict[str, Any] = {
"imports": [],
"functions": [],
"classes": [],
Expand Down Expand Up @@ -99,15 +99,15 @@ def get_tree(self):

return tree

def parse_node(self, node: Union[ast.stmt, ast.AST]) -> None:
def parse_node(self, node: ast.stmt | ast.AST) -> None:
"""
Parses an AST node and updates the data
dictionary with the relevant information.
"""
if handler := self.handlers.get(type(node)): # type: ignore
handler(node) # type: ignore

def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
def parse_imports(self, node: ast.Import | ast.ImportFrom) -> None:
"""
Extracts "imports" from the code, including aliases.
"""
Expand Down Expand Up @@ -161,7 +161,7 @@ def construct_eval_env(self, return_type_str: str, imports) -> dict:
exec(f"import {module} as {alias if alias else module}", eval_env)
return eval_env

def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
def parse_callable_details(self, node: ast.FunctionDef) -> dict[str, Any]:
"""
Extracts details from a single function or method node.
"""
Expand All @@ -187,7 +187,7 @@ def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:

return func.model_dump()

def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_function_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the arguments of a function or method node.
"""
Expand All @@ -202,7 +202,7 @@ def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:

return args

def parse_positional_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_positional_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the positional arguments of a function or method node.
"""
Expand All @@ -220,7 +220,7 @@ def parse_positional_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
return args

def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_varargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the *args argument of a function or method node.
"""
Expand All @@ -231,7 +231,7 @@ def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:

return args

def parse_keyword_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_keyword_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the keyword-only arguments of a function or method node.
"""
Expand All @@ -242,7 +242,7 @@ def parse_keyword_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
return args

def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_kwargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the **kwargs argument of a function or method node.
"""
Expand All @@ -253,7 +253,7 @@ def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:

return args

def parse_function_body(self, node: ast.FunctionDef) -> List[str]:
def parse_function_body(self, node: ast.FunctionDef) -> list[str]:
"""
Parses the body of a function or method node.
"""
Expand Down Expand Up @@ -394,7 +394,7 @@ def execute_and_inspect_classes(self, code: str):
bases.append(bases_base)
return bases

def parse_code(self) -> Dict[str, Any]:
def parse_code(self) -> dict[str, Any]:
"""
Runs all parsing operations and returns the resulting data.
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import operator
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
from uuid import UUID
import warnings

Expand All @@ -24,11 +24,11 @@ class BaseComponent:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."

_code: Optional[str] = None
_code: str | None = None
"""The code of the component. Defaults to None."""
_function_entrypoint_name: str = "build"
field_config: dict = {}
_user_id: Optional[str | UUID] = None
_user_id: str | UUID | None = None
_template_config: dict = {}

def __init__(self, **data):
Expand Down
17 changes: 9 additions & 8 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Union, get_type_hints
from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
from collections.abc import Callable
from uuid import UUID

import nanoid # type: ignore
Expand Down Expand Up @@ -30,8 +31,8 @@


class Component(CustomComponent):
inputs: List["InputTypes"] = []
outputs: List[Output] = []
inputs: list["InputTypes"] = []
outputs: list[Output] = []
code_class_base_inheritance: ClassVar[str] = "Component"
_output_logs: dict[str, Log] = {}

Expand Down Expand Up @@ -228,7 +229,7 @@ def set_output_value(self, name: str, value: Any):
else:
raise ValueError(f"Output {name} not found in {self.__class__.__name__}")

def map_outputs(self, outputs: List[Output]):
def map_outputs(self, outputs: list[Output]):
"""
Maps the given list of outputs to the component.
Expand All @@ -247,7 +248,7 @@ def map_outputs(self, outputs: List[Output]):
raise ValueError("Output name cannot be None.")
self._outputs[output.name] = output

def map_inputs(self, inputs: List["InputTypes"]):
def map_inputs(self, inputs: list["InputTypes"]):
"""
Maps the given inputs to the component.
Expand Down Expand Up @@ -449,7 +450,7 @@ def _map_parameters_on_template(self, template: dict):
)
raise ValueError(f"Parameter {name} not found in {self.__class__.__name__}. ")

def _get_method_return_type(self, method_name: str) -> List[str]:
def _get_method_return_type(self, method_name: str) -> list[str]:
method = getattr(self, method_name)
return_type = get_type_hints(method)["return"]
extracted_return_types = self._extract_return_type(return_type)
Expand Down Expand Up @@ -530,7 +531,7 @@ def set_attributes(self, params: dict):
_attributes[key] = input_obj.value or None
self._attributes = _attributes

def _set_outputs(self, outputs: List[dict]):
def _set_outputs(self, outputs: list[dict]):
self.outputs = [Output(**output) for output in outputs]
for output in self.outputs:
setattr(self, output.name, output)
Expand Down Expand Up @@ -646,7 +647,7 @@ def custom_repr(self):
return str(self.repr_value)
return self.repr_value

def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
def build_inputs(self, user_id: str | UUID | None = None):
"""
Builds the inputs for the custom component.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from collections.abc import Callable, Sequence

import yaml
from cachetools import TTLCache
Expand Down Expand Up @@ -48,42 +49,42 @@ class CustomComponent(BaseComponent):
_tree (Optional[dict]): The code tree of the custom component.
"""

name: Optional[str] = None
name: str | None = None
"""The name of the component used to styles. Defaults to None."""
display_name: Optional[str] = None
display_name: str | None = None
"""The display name of the component. Defaults to None."""
description: Optional[str] = None
description: str | None = None
"""The description of the component. Defaults to None."""
icon: Optional[str] = None
icon: str | None = None
"""The icon of the component. It should be an emoji. Defaults to None."""
is_input: Optional[bool] = None
is_input: bool | None = None
"""The input state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
is_output: Optional[bool] = None
is_output: bool | None = None
"""The output state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
field_config: dict = {}
"""The field configuration of the component. Defaults to an empty dictionary."""
field_order: Optional[List[str]] = None
field_order: list[str] | None = None
"""The field order of the component. Defaults to an empty list."""
frozen: Optional[bool] = False
frozen: bool | None = False
"""The default frozen state of the component. Defaults to False."""
build_parameters: Optional[dict] = None
build_parameters: dict | None = None
"""The build parameters of the component. Defaults to None."""
_vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
_code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
status: Optional[Any] = None
function: Callable | None = None
repr_value: Any | None = ""
status: Any | None = None
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_data: Optional[List[Data]] = None
_outputs: List[OutputValue] = []
_logs: List[Log] = []
_flows_data: list[Data] | None = None
_outputs: list[OutputValue] = []
_logs: list[Log] = []
_output_logs: dict[str, Log] = {}
_tracing_service: Optional["TracingService"] = None
_tree: Optional[dict] = None
_tree: dict | None = None

def __init__(self, **data):
"""
Expand Down Expand Up @@ -215,7 +216,7 @@ def update_build_config(
self,
build_config: dotdict,
field_value: Any,
field_name: Optional[str] = None,
field_name: str | None = None,
):
build_config[field_name] = field_value
return build_config
Expand All @@ -230,7 +231,7 @@ def tree(self):
"""
return self.get_code_tree(self._code or "")

def to_data(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Data]:
def to_data(self, data: Any, keys: list[str] | None = None, silent_errors: bool = False) -> list[Data]:
"""
Converts input data into a list of Data objects.
Expand Down Expand Up @@ -289,7 +290,7 @@ def get_method_return_type(self, method_name: str):

return self._extract_return_type(return_type)

def create_references_from_data(self, data: List[Data], include_data: bool = False) -> str:
def create_references_from_data(self, data: list[Data], include_data: bool = False) -> str:
"""
Create references from a list of data.
Expand Down Expand Up @@ -352,7 +353,7 @@ def get_method(self, method_name: str):
return build_methods[0] if build_methods else {}

@property
def get_function_entrypoint_return_type(self) -> List[Any]:
def get_function_entrypoint_return_type(self) -> list[Any]:
"""
Gets the return type of the function entrypoint for the custom component.
Expand All @@ -361,7 +362,7 @@ def get_function_entrypoint_return_type(self) -> List[Any]:
"""
return self.get_method_return_type(self._function_entrypoint_name)

def _extract_return_type(self, return_type: Any) -> List[Any]:
def _extract_return_type(self, return_type: Any) -> list[Any]:
return post_process_type(return_type)

@property
Expand Down Expand Up @@ -451,7 +452,7 @@ def index(self, value: int = 0):
Callable: A function that returns the value at the given index.
"""

def get_index(iterable: List[Any]):
def get_index(iterable: list[Any]):
return iterable[value] if iterable else iterable

return get_index
Expand All @@ -465,18 +466,18 @@ def get_function(self):
"""
return validate.create_function(self._code, self._function_entrypoint_name)

async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
async def load_flow(self, flow_id: str, tweaks: dict | None = None) -> "Graph":
if not self.user_id:
raise ValueError("Session is invalid")
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)

async def run_flow(
self,
inputs: Optional[Union[dict, List[dict]]] = None,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
output_type: Optional[str] = "chat",
tweaks: Optional[dict] = None,
inputs: dict | list[dict] | None = None,
flow_id: str | None = None,
flow_name: str | None = None,
output_type: str | None = "chat",
tweaks: dict | None = None,
) -> Any:
return await run_flow(
inputs=inputs,
Expand All @@ -487,7 +488,7 @@ async def run_flow(
user_id=str(self._user_id),
)

def list_flows(self) -> List[Data]:
def list_flows(self) -> list[Data]:
if not self.user_id:
raise ValueError("Session is invalid")
try:
Expand All @@ -508,7 +509,7 @@ def build(self, *args: Any, **kwargs: Any) -> Any:
"""
raise NotImplementedError

def log(self, message: LoggableType | list[LoggableType], name: Optional[str] = None):
def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
"""
Logs a message.
Expand All @@ -531,7 +532,7 @@ def post_code_processing(self, new_frontend_node: dict, current_frontend_node: d
)
return frontend_node

def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
def get_langchain_callbacks(self) -> list["BaseCallbackHandler"]:
if self._tracing_service:
return self._tracing_service.get_langchain_callbacks()
return []
Loading

0 comments on commit 9c8dd8e

Please sign in to comment.