diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index 3ccbb539377..f98d499a3c6 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -32,7 +32,7 @@ overload, ) -from typing_extensions import ParamSpec, get_origin, get_type_hints, override +from typing_extensions import ParamSpec, get_type_hints, override from reflex import constants from reflex.base import Base @@ -40,6 +40,7 @@ from reflex.utils import console, imports, serializers, types from reflex.utils.exceptions import VarDependencyError, VarTypeError, VarValueError from reflex.utils.format import format_state_name +from reflex.utils.types import get_origin from reflex.vars import ( ComputedVar, ImmutableVarData, diff --git a/reflex/ivars/object.py b/reflex/ivars/object.py index 4586ffa0a65..5401b678ae9 100644 --- a/reflex/ivars/object.py +++ b/reflex/ivars/object.py @@ -19,11 +19,9 @@ overload, ) -from typing_extensions import get_origin - from reflex.utils import types from reflex.utils.exceptions import VarAttributeError -from reflex.utils.types import GenericType, get_attribute_access_type +from reflex.utils.types import GenericType, get_attribute_access_type, get_origin from reflex.vars import ImmutableVarData, Var, VarData from .base import ( diff --git a/reflex/ivars/sequence.py b/reflex/ivars/sequence.py index d05ec5aad6d..47f5c6f0150 100644 --- a/reflex/ivars/sequence.py +++ b/reflex/ivars/sequence.py @@ -22,11 +22,9 @@ overload, ) -from typing_extensions import get_origin - from reflex import constants from reflex.constants.base import REFLEX_VAR_OPENING_TAG -from reflex.utils.types import GenericType +from reflex.utils.types import GenericType, get_origin from reflex.vars import ( ImmutableVarData, Var, diff --git a/reflex/utils/types.py b/reflex/utils/types.py index c164bb93e93..a5e4eccda22 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -6,7 +6,7 @@ import inspect import sys import types -from functools import cached_property, wraps +from functools import cached_property, lru_cache, wraps from typing import ( Any, Callable, @@ -21,9 +21,11 @@ Union, _GenericAlias, # type: ignore get_args, - get_origin, get_type_hints, ) +from typing import ( + get_origin as get_origin_og, +) import sqlalchemy @@ -133,6 +135,20 @@ def __bool__(self) -> bool: return False +@lru_cache() +def get_origin(tp): + """Get the origin of a class. + + Args: + tp: The class to get the origin of. + + Returns: + The origin of the class. + """ + return get_origin_og(tp) + + +@lru_cache() def is_generic_alias(cls: GenericType) -> bool: """Check whether the class is a generic alias. @@ -157,6 +173,7 @@ def is_none(cls: GenericType) -> bool: return cls is type(None) or cls is None +@lru_cache() def is_union(cls: GenericType) -> bool: """Check if a class is a Union. @@ -169,6 +186,7 @@ def is_union(cls: GenericType) -> bool: return get_origin(cls) in UnionTypes +@lru_cache() def is_literal(cls: GenericType) -> bool: """Check if a class is a Literal. @@ -314,6 +332,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None return None # Attribute is not accessible. +@lru_cache() def get_base_class(cls: GenericType) -> Type: """Get the base class of a class. diff --git a/reflex/vars.py b/reflex/vars.py index a3d40722d55..ea7d28b974a 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -29,7 +29,6 @@ _GenericAlias, # type: ignore cast, get_args, - get_origin, get_type_hints, ) @@ -51,7 +50,7 @@ ParsedImportDict, parse_imports, ) -from reflex.utils.types import override +from reflex.utils.types import get_origin, override if TYPE_CHECKING: from reflex.state import BaseState @@ -182,15 +181,14 @@ def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None: var_data.interpolations if isinstance(var_data, VarData) else [] ) - return ( - cls( + if state or _imports or hooks or interpolations: + return cls( state=state, imports=_imports, hooks=hooks, interpolations=interpolations, ) - or None - ) + return None def __bool__(self) -> bool: """Check if the var data is non-empty. @@ -302,14 +300,13 @@ def merge( else {k: None for k in var_data.hooks} ) - return ( - ImmutableVarData( + if state or _imports or hooks: + return ImmutableVarData( state=state, imports=_imports, hooks=hooks, ) - or None - ) + return None def __bool__(self) -> bool: """Check if the var data is non-empty. diff --git a/scripts/integration.sh b/scripts/integration.sh index 17ba66ec795..dc8b5d5537a 100755 --- a/scripts/integration.sh +++ b/scripts/integration.sh @@ -34,4 +34,4 @@ if [ -f /proc/$pid/winpid ]; then echo "Windows detected, passing winpid $pid to port waiter" fi -python scripts/wait_for_listening_port.py $check_ports --timeout=1800 --server-pid "$pid" +python scripts/wait_for_listening_port.py $check_ports --timeout=900 --server-pid "$pid"