From c387f517b6b3818679d494f182cda73ea516cc07 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 12 Dec 2024 19:36:31 +0000 Subject: [PATCH 01/15] [ENG-4100]Throw warnings when Redis lock is held for more than the allowed threshold (#4522) * Throw warnings when Redis lock is held for more than the allowed threshold * initial tests * fix tests and address comments * fix tests fr, and use pydantic validators * darglint fix * increase lock expiration in tests to 2500 * remove print statement --------- Co-authored-by: Khaleel Al-Adhami --- reflex/config.py | 3 + reflex/constants/config.py | 2 + reflex/state.py | 53 +++++++++++++++++ reflex/utils/console.py | 66 +++++++++++++++++++-- reflex/utils/exceptions.py | 4 ++ tests/units/test_state.py | 114 +++++++++++++++++++++++++++++++++++-- 6 files changed, 232 insertions(+), 10 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index ae2c0ea0e8..bbea6a5d0c 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -684,6 +684,9 @@ class Config: # Maximum expiration lock time for redis state manager redis_lock_expiration: int = constants.Expiration.LOCK + # Maximum lock time before warning for redis state manager. + redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD + # Token expiration time for redis state manager redis_token_expiration: int = constants.Expiration.TOKEN diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 970e67844a..7425fd8648 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -29,6 +29,8 @@ class Expiration(SimpleNamespace): LOCK = 10000 # The PING timeout PING = 120 + # The maximum time in milliseconds to hold a lock before throwing a warning. + LOCK_WARNING_THRESHOLD = 1000 class GitIgnore(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index 3e606bf572..434ee39217 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -71,6 +71,11 @@ except ModuleNotFoundError: BaseModelV1 = BaseModelV2 +try: + from pydantic.v1 import validator +except ModuleNotFoundError: + from pydantic import validator + import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError @@ -94,6 +99,7 @@ DynamicRouteArgShadowsStateVar, EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, + InvalidLockWarningThresholdError, InvalidStateManagerMode, LockExpiredError, ReflexRuntimeError, @@ -2834,6 +2840,7 @@ def create(cls, state: Type[BaseState]): redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, + lock_warning_threshold=config.redis_lock_warning_threshold, ) raise InvalidStateManagerMode( f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" @@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +def _default_lock_warning_threshold() -> int: + """Get the default lock warning threshold. + + Returns: + The default lock warning threshold. + """ + return get_config().redis_lock_warning_threshold + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager): # The maximum time to hold a lock (ms). lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration) + # The maximum time to hold a lock (ms) before warning. + lock_warning_threshold: int = pydantic.Field( + default_factory=_default_lock_warning_threshold + ) + # The keyspace subscription string when redis is waiting for lock to be released _redis_notify_keyspace_events: str = ( "K" # Enable keyspace notifications (target a particular key) @@ -3402,6 +3423,17 @@ async def set_state( f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.event(background=True)` decorator for long-running tasks." ) + elif lock_id is not None: + time_taken = self.lock_expiration / 1000 - ( + await self.redis.ttl(self._lock_key(token)) + ) + if time_taken > self.lock_warning_threshold / 1000: + console.warn( + f"Lock for token {token} was held too long {time_taken=}s, " + f"use `@rx.event(background=True)` decorator for long-running tasks.", + dedupe=True, + ) + client_token, substate_name = _split_substate_key(token) # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: @@ -3451,6 +3483,27 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: yield state await self.set_state(token, state, lock_id) + @validator("lock_warning_threshold") + @classmethod + def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values): + """Validate the lock warning threshold. + + Args: + lock_warning_threshold: The lock warning threshold. + values: The validated attributes. + + Returns: + The lock warning threshold. + + Raises: + InvalidLockWarningThresholdError: If the lock warning threshold is invalid. + """ + if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]): + raise InvalidLockWarningThresholdError( + f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." + ) + return lock_warning_threshold + @staticmethod def _lock_key(token: str) -> bytes: """Get the redis key for a token's lock. diff --git a/reflex/utils/console.py b/reflex/utils/console.py index b3ba7163d2..be545140af 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -20,6 +20,24 @@ # Info messages which have been printed. _EMITTED_INFO = set() +# Warnings which have been printed. +_EMIITED_WARNINGS = set() + +# Errors which have been printed. +_EMITTED_ERRORS = set() + +# Success messages which have been printed. +_EMITTED_SUCCESS = set() + +# Debug messages which have been printed. +_EMITTED_DEBUG = set() + +# Logs which have been printed. +_EMITTED_LOGS = set() + +# Prints which have been printed. +_EMITTED_PRINTS = set() + def set_log_level(log_level: LogLevel): """Set the log level. @@ -55,25 +73,37 @@ def is_debug() -> bool: return _LOG_LEVEL <= LogLevel.DEBUG -def print(msg: str, **kwargs): +def print(msg: str, dedupe: bool = False, **kwargs): """Print a message. Args: msg: The message to print. + dedupe: If True, suppress multiple console logs of print message. kwargs: Keyword arguments to pass to the print function. """ + if dedupe: + if msg in _EMITTED_PRINTS: + return + else: + _EMITTED_PRINTS.add(msg) _console.print(msg, **kwargs) -def debug(msg: str, **kwargs): +def debug(msg: str, dedupe: bool = False, **kwargs): """Print a debug message. Args: msg: The debug message. + dedupe: If True, suppress multiple console logs of debug message. kwargs: Keyword arguments to pass to the print function. """ if is_debug(): msg_ = f"[purple]Debug: {msg}[/purple]" + if dedupe: + if msg_ in _EMITTED_DEBUG: + return + else: + _EMITTED_DEBUG.add(msg_) if progress := kwargs.pop("progress", None): progress.console.print(msg_, **kwargs) else: @@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs): print(f"[cyan]Info: {msg}[/cyan]", **kwargs) -def success(msg: str, **kwargs): +def success(msg: str, dedupe: bool = False, **kwargs): """Print a success message. Args: msg: The success message. + dedupe: If True, suppress multiple console logs of success message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_SUCCESS: + return + else: + _EMITTED_SUCCESS.add(msg) print(f"[green]Success: {msg}[/green]", **kwargs) -def log(msg: str, **kwargs): +def log(msg: str, dedupe: bool = False, **kwargs): """Takes a string and logs it to the console. Args: msg: The message to log. + dedupe: If True, suppress multiple console logs of log message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_LOGS: + return + else: + _EMITTED_LOGS.add(msg) _console.log(msg, **kwargs) @@ -129,14 +171,20 @@ def rule(title: str, **kwargs): _console.rule(title, **kwargs) -def warn(msg: str, **kwargs): +def warn(msg: str, dedupe: bool = False, **kwargs): """Print a warning message. Args: msg: The warning message. + dedupe: If True, suppress multiple console logs of warning message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.WARNING: + if dedupe: + if msg in _EMIITED_WARNINGS: + return + else: + _EMIITED_WARNINGS.add(msg) print(f"[orange1]Warning: {msg}[/orange1]", **kwargs) @@ -169,14 +217,20 @@ def deprecate( _EMITTED_DEPRECATION_WARNINGS.add(feature_name) -def error(msg: str, **kwargs): +def error(msg: str, dedupe: bool = False, **kwargs): """Print an error message. Args: msg: The error message. + dedupe: If True, suppress multiple console logs of error message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.ERROR: + if dedupe: + if msg in _EMITTED_ERRORS: + return + else: + _EMITTED_ERRORS.add(msg) print(f"[red]{msg}[/red]", **kwargs) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 6c378e1591..ae5ec01683 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn: " Please install it through your system package manager." + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "") ) + + +class InvalidLockWarningThresholdError(ReflexError): + """Raised when an invalid lock warning threshold is provided.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 9e952e10f7..912d72f4f1 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -56,6 +56,7 @@ from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import ( + InvalidLockWarningThresholdError, ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, @@ -67,7 +68,9 @@ from .states import GenState CI = bool(os.environ.get("CI", False)) -LOCK_EXPIRATION = 2000 if CI else 300 +LOCK_EXPIRATION = 2500 if CI else 300 +LOCK_WARNING_THRESHOLD = 1000 if CI else 100 +LOCK_WARN_SLEEP = 1.5 if CI else 0.15 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 @@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire( substate_token_redis: A token + substate name for looking up in state manager. """ state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(0.01) @@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend( unexp_num1 = 666 state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD order = [] @@ -1840,6 +1845,39 @@ async def _coro_waiter(): assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.mark.asyncio +async def test_state_manager_lock_warning_threshold_contend( + state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker +): + """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. + + Args: + state_manager_redis: A state manager instance. + token: A token. + substate_token_redis: A token + substate name for looking up in state manager. + mocker: Pytest mocker object. + """ + console_warn = mocker.patch("reflex.utils.console.warn") + + state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + + order = [] + + async def _coro_blocker(): + async with state_manager_redis.modify_state(substate_token_redis): + order.append("blocker") + await asyncio.sleep(LOCK_WARN_SLEEP) + + tasks = [ + asyncio.create_task(_coro_blocker()), + ] + + await tasks[0] + console_warn.assert_called() + assert console_warn.call_count == 7 + + class CopyingAsyncMock(AsyncMock): """An AsyncMock, but deepcopy the args and kwargs first.""" @@ -3253,12 +3291,42 @@ async def test_setvar_async_setter(): @pytest.mark.parametrize( "expiration_kwargs, expected_values", [ - ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)), + ( + {"redis_lock_expiration": 20000}, + ( + 20000, + constants.Expiration.TOKEN, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), ( {"redis_lock_expiration": 50000, "redis_token_expiration": 5600}, - (50000, 5600), + (50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD), + ), + ( + {"redis_token_expiration": 7600}, + ( + constants.Expiration.LOCK, + 7600, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), + ( + {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500}, + (50000, constants.Expiration.TOKEN, 1500), + ), + ( + {"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000}, + (constants.Expiration.LOCK, 5600, 3000), + ), + ( + { + "redis_lock_expiration": 50000, + "redis_token_expiration": 5600, + "redis_lock_warning_threshold": 2000, + }, + (50000, 5600, 2000), ), - ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)), ], ) def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values): @@ -3288,6 +3356,44 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_ state_manager = StateManager.create(state=State) assert state_manager.lock_expiration == expected_values[0] # type: ignore assert state_manager.token_expiration == expected_values[1] # type: ignore + assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore + + +@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") +@pytest.mark.parametrize( + "redis_lock_expiration, redis_lock_warning_threshold", + [ + (10000, 10000), + (20000, 30000), + ], +) +def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( + tmp_path, redis_lock_expiration, redis_lock_warning_threshold +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = f""" +import reflex as rx +config = rx.Config( + app_name="project1", + redis_url="redis://localhost:6379", + state_manager_mode="redis", + redis_lock_expiration = {redis_lock_expiration}, + redis_lock_warning_threshold = {redis_lock_warning_threshold}, +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + + with chdir(proj_root): + # reload config for each parameter to avoid stale values + reflex.config.get_config(reload=True) + from reflex.state import State, StateManager + + with pytest.raises(InvalidLockWarningThresholdError): + StateManager.create(state=State) + del sys.modules[constants.Config.MODULE] class MixinState(State, mixin=True): From 60a5b7bc7a73d22fdf3c77850b899f3116df0590 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 12 Dec 2024 14:28:17 -0800 Subject: [PATCH 02/15] [ENG-4194] TypeError: Cannot create property 'token' on string (#4534) --- reflex/app.py | 2 +- tests/integration/test_client_storage.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 67bb203fa2..10dd889b3c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1295,7 +1295,7 @@ async def process( await asyncio.create_task( app.event_namespace.emit( "reload", - data=format.json_dumps(event), + data=event, to=sid, ) ) diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 236d3e14e1..2652d6ccb3 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -637,8 +637,7 @@ async def poll_for_not_hydrated(): assert await AppHarness._poll_for_async(poll_for_not_hydrated) # Trigger event to get a new instance of the state since the old was expired. - state_var_input = driver.find_element(By.ID, "state_var") - state_var_input.send_keys("re-triggering") + set_sub("c1", "c1 post expire") # get new references to all cookie and local storage elements (again) c1 = driver.find_element(By.ID, "c1") @@ -659,7 +658,7 @@ async def poll_for_not_hydrated(): l1s = driver.find_element(By.ID, "l1s") s1s = driver.find_element(By.ID, "s1s") - assert c1.text == "c1 value" + assert c1.text == "c1 post expire" assert c2.text == "c2 value" assert c3.text == "" # temporary cookie expired after reset state! assert c4.text == "c4 value" @@ -690,11 +689,11 @@ async def get_sub_state(): async def poll_for_c1_set(): sub_state = await get_sub_state() - return sub_state.c1 == "c1 value" + return sub_state.c1 == "c1 post expire" assert await AppHarness._poll_for_async(poll_for_c1_set) sub_state = await get_sub_state() - assert sub_state.c1 == "c1 value" + assert sub_state.c1 == "c1 post expire" assert sub_state.c2 == "c2 value" assert sub_state.c3 == "" assert sub_state.c4 == "c4 value" From d5d41a0d9ecbb7a38b058d9afebcb48f52dd33ee Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 12 Dec 2024 14:28:30 -0800 Subject: [PATCH 03/15] raise_console_error during integration tests (#4535) --- tests/integration/conftest.py | 23 ++++++++++++++++++++ tests/integration/test_exception_handlers.py | 2 ++ 2 files changed, 25 insertions(+) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f7b825f162..d11344903f 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,6 +6,7 @@ import pytest +import reflex.app from reflex.config import environment from reflex.testing import AppHarness, AppHarnessProd @@ -76,3 +77,25 @@ def app_harness_env(request): The AppHarness class to use for the test. """ return request.param + + +@pytest.fixture(autouse=True) +def raise_console_error(request, mocker): + """Spy on calls to `console.error` used by the framework. + + Help catch spurious error conditions that might otherwise go unnoticed. + + If a test is marked with `ignore_console_error`, the spy will be ignored + after the test. + + Args: + request: The pytest request object. + mocker: The pytest mocker object. + + Yields: + control to the test function. + """ + spy = mocker.spy(reflex.app.console, "error") + yield + if "ignore_console_error" not in request.keywords: + spy.assert_not_called() diff --git a/tests/integration/test_exception_handlers.py b/tests/integration/test_exception_handlers.py index 406c21e5d9..a645d1de6e 100644 --- a/tests/integration/test_exception_handlers.py +++ b/tests/integration/test_exception_handlers.py @@ -13,6 +13,8 @@ from reflex.testing import AppHarness, AppHarnessProd +pytestmark = [pytest.mark.ignore_console_error] + def TestApp(): """A test app for event exception handler integration.""" From 7ca50c62e2354d142f5870defb2a8a6cc79d43ee Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 12 Dec 2024 23:24:15 +0000 Subject: [PATCH 04/15] [ENG-4153]Use empty string for None values in `rx.input` and `rx.el.input` (#4521) * Use empty string for None values in `rx.input` and `rx.el.input` * fix tests * fix pyi scripts * use nullish coalescing operator * fix unit tests * use ternary operator so old browsers and dynamic components tests pass * address comment on doing this only for optionals * Fix tests * pyright! * fix comments --- reflex/components/el/elements/forms.py | 28 +++++++++++++++++++ reflex/components/el/elements/forms.pyi | 4 +-- .../radix/themes/components/text_field.py | 15 ++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 205aae2673..61ded4fd3c 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -18,6 +18,7 @@ prevent_default, ) from reflex.utils.imports import ImportDict +from reflex.utils.types import is_optional from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -382,6 +383,33 @@ class Input(BaseHTML): # Fired when a key is released on_key_up: EventHandler[key_event] + @classmethod + def create(cls, *children, **props): + """Create an Input component. + + Args: + *children: The children of the component. + **props: The properties of the component. + + Returns: + The component. + """ + from reflex.vars.number import ternary_operation + + value = props.get("value") + + # React expects an empty string(instead of null) for controlled inputs. + if value is not None and is_optional( + (value_var := Var.create(value))._var_type + ): + props["value"] = ternary_operation( + (value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues] + & (value_var != Var(_js_expr="undefined")), + value, + Var.create(""), + ) + return super().create(*children, **props) + class Label(BaseHTML): """Display the label element.""" diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 5870d4b22a..dfab40b219 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -512,7 +512,7 @@ class Input(BaseHTML): on_unmount: Optional[EventType[[], BASE_STATE]] = None, **props, ) -> "Input": - """Create the component. + """Create an Input component. Args: *children: The children of the component. @@ -576,7 +576,7 @@ class Input(BaseHTML): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the component. Returns: The component. diff --git a/reflex/components/radix/themes/components/text_field.py b/reflex/components/radix/themes/components/text_field.py index 3dabe09366..7e6dfe85c9 100644 --- a/reflex/components/radix/themes/components/text_field.py +++ b/reflex/components/radix/themes/components/text_field.py @@ -9,7 +9,9 @@ from reflex.components.core.debounce import DebounceInput from reflex.components.el import elements from reflex.event import EventHandler, input_event, key_event +from reflex.utils.types import is_optional from reflex.vars.base import Var +from reflex.vars.number import ternary_operation from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent @@ -96,6 +98,19 @@ def create(cls, *children, **props) -> Component: Returns: The component. """ + value = props.get("value") + + # React expects an empty string(instead of null) for controlled inputs. + if value is not None and is_optional( + (value_var := Var.create(value))._var_type + ): + props["value"] = ternary_operation( + (value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues] + & (value_var != Var(_js_expr="undefined")), + value, + Var.create(""), + ) + component = super().create(*children, **props) if props.get("value") is not None and props.get("on_change") is not None: # create a debounced input if the user requests full control to avoid typing jank From 72085408557c383cf13b395d23a670f4df1afa2f Mon Sep 17 00:00:00 2001 From: Vy Nguyen <114444436+vydpnguyen@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:50:01 -0800 Subject: [PATCH 05/15] Wrap Checkbox component in Context Menu (partial fix for #4262) (#4479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add ContextMenuCheckBoxItem component * Remove toggle_state method * Import Checkbox and implement ContextMenuCheckbox class * Change parameter for create function * Reformat code and import block * Removed unused vars * Format file * Remove create method * Reformat code * Update reflex/components/radix/themes/components/context_menu.py Co-authored-by: Thomas Brandého * Update reflex/components/radix/themes/components/context_menu.py Co-authored-by: Thomas Brandého * Update reflex/components/radix/themes/components/context_menu.py Co-authored-by: Thomas Brandého * Update reflex/components/radix/themes/components/context_menu.py Co-authored-by: Thomas Brandého * Reformat code * Add automatically modified pyi file * Update context_menu.pyi file --------- Co-authored-by: Thomas Brandého --- .../radix/themes/components/context_menu.py | 11 ++ .../radix/themes/components/context_menu.pyi | 155 ++++++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/reflex/components/radix/themes/components/context_menu.py b/reflex/components/radix/themes/components/context_menu.py index ea49022337..f8512a902e 100644 --- a/reflex/components/radix/themes/components/context_menu.py +++ b/reflex/components/radix/themes/components/context_menu.py @@ -8,6 +8,7 @@ from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent +from .checkbox import Checkbox LiteralDirType = Literal["ltr", "rtl"] @@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent): tag = "ContextMenu.Separator" +class ContextMenuCheckbox(Checkbox): + """The component that contains the checkbox.""" + + tag = "ContextMenu.CheckboxItem" + + # Text to render as shortcut. + shortcut: Var[str] + + class ContextMenu(ComponentNamespace): """Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press.""" @@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace): sub_content = staticmethod(ContextMenuSubContent.create) item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) + checkbox = staticmethod(ContextMenuCheckbox.create) context_menu = ContextMenu() diff --git a/reflex/components/radix/themes/components/context_menu.pyi b/reflex/components/radix/themes/components/context_menu.pyi index c5ef757d13..2d3ffbebcd 100644 --- a/reflex/components/radix/themes/components/context_menu.pyi +++ b/reflex/components/radix/themes/components/context_menu.pyi @@ -12,6 +12,7 @@ from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent +from .checkbox import Checkbox LiteralDirType = Literal["ltr", "rtl"] LiteralSizeType = Literal["1", "2"] @@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent): """ ... +class ContextMenuCheckbox(Checkbox): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + shortcut: Optional[Union[Var[str], str]] = None, + as_child: Optional[Union[Var[bool], bool]] = None, + size: Optional[ + Union[ + Breakpoints[str, Literal["1", "2", "3"]], + Literal["1", "2", "3"], + Var[ + Union[ + Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"] + ] + ], + ] + ] = None, + variant: Optional[ + Union[ + Literal["classic", "soft", "surface"], + Var[Literal["classic", "soft", "surface"]], + ] + ] = None, + color_scheme: Optional[ + Union[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ], + Var[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ] + ], + ] + ] = None, + high_contrast: Optional[Union[Var[bool], bool]] = None, + default_checked: Optional[Union[Var[bool], bool]] = None, + checked: Optional[Union[Var[bool], bool]] = None, + disabled: Optional[Union[Var[bool], bool]] = None, + required: Optional[Union[Var[bool], bool]] = None, + name: Optional[Union[Var[str], str]] = None, + value: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_change: Optional[ + Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]] + ] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "ContextMenuCheckbox": + """Create a new component instance. + + Will prepend "RadixThemes" to the component tag to avoid conflicts with + other UI libraries for common names, like Text and Button. + + Args: + *children: Child components. + shortcut: Text to render as shortcut. + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. + size: Checkbox size "1" - "3" + variant: Variant of checkbox: "classic" | "surface" | "soft" + color_scheme: Override theme color for checkbox + high_contrast: Whether to render the checkbox with higher contrast color against background + default_checked: Whether the checkbox is checked by default + checked: Whether the checkbox is checked + disabled: Whether the checkbox is disabled + required: Whether the checkbox is required + name: The name of the checkbox control when submitting the form. + value: The value of the checkbox control when submitting the form. + on_change: Fired when the checkbox is checked or unchecked. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Component properties. + + Returns: + A new component instance. + """ + ... + class ContextMenu(ComponentNamespace): root = staticmethod(ContextMenuRoot.create) trigger = staticmethod(ContextMenuTrigger.create) @@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace): sub_content = staticmethod(ContextMenuSubContent.create) item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) + checkbox = staticmethod(ContextMenuCheckbox.create) context_menu = ContextMenu() From f4aea1b3abba4d18cb4162d44f177b14712f7fc0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 13 Dec 2024 09:31:27 -0800 Subject: [PATCH 06/15] [ENG-3583] Respect cors_allowed_origins for backend HTTP requests (#4533) --- reflex/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/app.py b/reflex/app.py index 10dd889b3c..935fe7900c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -436,7 +436,7 @@ def _add_cors(self): allow_credentials=True, allow_methods=["*"], allow_headers=["*"], - allow_origins=["*"], + allow_origins=get_config().cors_allowed_origins, ) @property From 206de4df7af09ef31148bd5f7896c4268b4aac4e Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Fri, 13 Dec 2024 19:27:51 +0000 Subject: [PATCH 07/15] Unify `is_external` prop in `rx.redirect` and `rx.link` (#4389) * Unify `is_external` prop in `rx.redirect` and `rx.link` * default external to `None` * address PR comment * use a one-liner * reorder args for api stability Co-authored-by: Masen Furer * reorder doc args * external arg as deprecated in the docs * Update reflex/event.py Co-authored-by: Khaleel Al-Adhami * Fixup typing_extensions import and ruff --------- Co-authored-by: Masen Furer Co-authored-by: Khaleel Al-Adhami --- reflex/event.py | 44 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 05a163d3e1..8342062441 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -25,6 +25,7 @@ overload, ) +import typing_extensions from typing_extensions import ( Concatenate, ParamSpec, @@ -714,26 +715,61 @@ def fn(): ) +@overload +def redirect( + path: str | Var[str], + is_external: Optional[bool] = None, + replace: bool = False, +) -> EventSpec: ... + + +@overload +@typing_extensions.deprecated("`external` is deprecated use `is_external` instead") +def redirect( + path: str | Var[str], + is_external: Optional[bool] = None, + replace: bool = False, + external: Optional[bool] = None, +) -> EventSpec: ... + + def redirect( path: str | Var[str], - external: Optional[bool] = False, - replace: Optional[bool] = False, + is_external: Optional[bool] = None, + replace: bool = False, + external: Optional[bool] = None, ) -> EventSpec: """Redirect to a new path. Args: path: The path to redirect to. - external: Whether to open in new tab or not. + is_external: Whether to open in new tab or not. replace: If True, the current page will not create a new history entry. + external(Deprecated): Whether to open in new tab or not. Returns: An event to redirect to the path. """ + if external is not None: + console.deprecate( + "The `external` prop in `rx.redirect`", + "use `is_external` instead.", + "0.6.6", + "0.7.0", + ) + + # is_external should take precedence over external. + is_external = ( + (False if external is None else external) + if is_external is None + else is_external + ) + return server_side( "_redirect", get_fn_signature(redirect), path=path, - external=external, + external=is_external, replace=replace, ) From ec897021371f7928124982e9383cb465791b727a Mon Sep 17 00:00:00 2001 From: Joodith <67360396+Joodith@users.noreply.github.com> Date: Sat, 14 Dec 2024 01:05:35 +0530 Subject: [PATCH 08/15] Include step attribute in input (#4073) * Include step attribute in input * Remove `step` prop from TextField it is inherited from Input, and does not need to be redefined --------- Co-authored-by: Masen Furer --- .../radix/themes/components/text_field.py | 2 +- .../radix/themes/components/text_field.pyi | 118 +++++++++++++++--- 2 files changed, 104 insertions(+), 16 deletions(-) diff --git a/reflex/components/radix/themes/components/text_field.py b/reflex/components/radix/themes/components/text_field.py index 7e6dfe85c9..c8bdab4435 100644 --- a/reflex/components/radix/themes/components/text_field.py +++ b/reflex/components/radix/themes/components/text_field.py @@ -19,7 +19,7 @@ LiteralTextFieldVariant = Literal["classic", "surface", "soft"] -class TextFieldRoot(elements.Div, RadixThemesComponent): +class TextFieldRoot(elements.Input, RadixThemesComponent): """Captures user input with an optional slot for buttons and icons.""" tag = "TextField.Root" diff --git a/reflex/components/radix/themes/components/text_field.pyi b/reflex/components/radix/themes/components/text_field.pyi index 09d58ed8f7..81c991899c 100644 --- a/reflex/components/radix/themes/components/text_field.pyi +++ b/reflex/components/radix/themes/components/text_field.pyi @@ -17,7 +17,7 @@ from ..base import RadixThemesComponent LiteralTextFieldSize = Literal["1", "2", "3"] LiteralTextFieldVariant = Literal["classic", "surface", "soft"] -class TextFieldRoot(elements.Div, RadixThemesComponent): +class TextFieldRoot(elements.Input, RadixThemesComponent): @overload @classmethod def create( # type: ignore @@ -120,6 +120,30 @@ class TextFieldRoot(elements.Div, RadixThemesComponent): type: Optional[Union[Var[str], str]] = None, value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None, list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + default_checked: Optional[Union[Var[bool], bool]] = None, + dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_enc_type: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_no_validate: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, auto_capitalize: Optional[ Union[Var[Union[bool, int, str]], bool, int, str] @@ -192,12 +216,12 @@ class TextFieldRoot(elements.Div, RadixThemesComponent): Args: *children: The children of the component. - size: Text field size "1" - "3" + size: Specifies the visible width of a text control variant: Variant of text field: "classic" | "surface" | "soft" color_scheme: Override theme color for text field radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full" auto_complete: Whether the input should have autocomplete enabled - default_value: The value of the input when initially rendered. + default_value: The initial value for a text field disabled: Disables the input max_length: Specifies the maximum number of characters allowed in the input min_length: Specifies the minimum number of characters required in the input @@ -208,11 +232,31 @@ class TextFieldRoot(elements.Div, RadixThemesComponent): type: Specifies the type of input value: Value of the input list: References a datalist for suggested options - on_change: Fired when the value of the textarea changes. - on_focus: Fired when the textarea is focused. - on_blur: Fired when the textarea is blurred. - on_key_down: Fired when a key is pressed down. - on_key_up: Fired when a key is released. + on_change: Fired when the input value changes + on_focus: Fired when the input gains focus + on_blur: Fired when the input loses focus + on_key_down: Fired when a key is pressed down + on_key_up: Fired when a key is released + accept: Accepted types of files when the input is file type + alt: Alternate text for input type="image" + auto_focus: Automatically focuses the input when the page loads + capture: Captures media from the user (camera or microphone) + checked: Indicates whether the input is checked (for checkboxes and radio buttons) + default_checked: The initial value (for checkboxes and radio buttons) + dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted + form: Associates the input with a form (by id) + form_action: URL to send the form data to (for type="submit" buttons) + form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons) + form_method: HTTP method to use for sending form data (for type="submit" buttons) + form_no_validate: Bypasses form validation when submitting (for type="submit" buttons) + form_target: Specifies where to display the response after submitting the form (for type="submit" buttons) + max: Specifies the maximum value for the input + min: Specifies the minimum value for the input + multiple: Indicates whether multiple values can be entered in an input of the type email or file + pattern: Regex pattern the input's value must match to be valid + src: URL for image inputs + step: Specifies the legal number intervals for an input + use_map: Name of the image map used with the input access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -457,6 +501,30 @@ class TextField(ComponentNamespace): type: Optional[Union[Var[str], str]] = None, value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None, list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + default_checked: Optional[Union[Var[bool], bool]] = None, + dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_enc_type: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + form_no_validate: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, auto_capitalize: Optional[ Union[Var[Union[bool, int, str]], bool, int, str] @@ -529,12 +597,12 @@ class TextField(ComponentNamespace): Args: *children: The children of the component. - size: Text field size "1" - "3" + size: Specifies the visible width of a text control variant: Variant of text field: "classic" | "surface" | "soft" color_scheme: Override theme color for text field radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full" auto_complete: Whether the input should have autocomplete enabled - default_value: The value of the input when initially rendered. + default_value: The initial value for a text field disabled: Disables the input max_length: Specifies the maximum number of characters allowed in the input min_length: Specifies the minimum number of characters required in the input @@ -545,11 +613,31 @@ class TextField(ComponentNamespace): type: Specifies the type of input value: Value of the input list: References a datalist for suggested options - on_change: Fired when the value of the textarea changes. - on_focus: Fired when the textarea is focused. - on_blur: Fired when the textarea is blurred. - on_key_down: Fired when a key is pressed down. - on_key_up: Fired when a key is released. + on_change: Fired when the input value changes + on_focus: Fired when the input gains focus + on_blur: Fired when the input loses focus + on_key_down: Fired when a key is pressed down + on_key_up: Fired when a key is released + accept: Accepted types of files when the input is file type + alt: Alternate text for input type="image" + auto_focus: Automatically focuses the input when the page loads + capture: Captures media from the user (camera or microphone) + checked: Indicates whether the input is checked (for checkboxes and radio buttons) + default_checked: The initial value (for checkboxes and radio buttons) + dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted + form: Associates the input with a form (by id) + form_action: URL to send the form data to (for type="submit" buttons) + form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons) + form_method: HTTP method to use for sending form data (for type="submit" buttons) + form_no_validate: Bypasses form validation when submitting (for type="submit" buttons) + form_target: Specifies where to display the response after submitting the form (for type="submit" buttons) + max: Specifies the maximum value for the input + min: Specifies the minimum value for the input + multiple: Indicates whether multiple values can be entered in an input of the type email or file + pattern: Regex pattern the input's value must match to be valid + src: URL for image inputs + step: Specifies the legal number intervals for an input + use_map: Name of the image map used with the input access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. From ff510cacc5252e5985c33aaba7193bcd6bc2517b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 12:37:34 -0800 Subject: [PATCH 09/15] enable C4 rule (#4536) --- pyproject.toml | 2 +- reflex/compiler/utils.py | 2 +- reflex/components/component.py | 14 +--- .../radix/themes/components/icon_button.py | 2 +- reflex/components/recharts/charts.py | 8 +- reflex/constants/route.py | 2 +- reflex/event.py | 12 ++- reflex/model.py | 8 +- reflex/reflex.py | 4 +- reflex/state.py | 83 +++++++++---------- reflex/utils/exec.py | 2 +- reflex/utils/prerequisites.py | 6 +- reflex/vars/base.py | 45 +++++----- reflex/vars/function.py | 4 +- reflex/vars/number.py | 2 +- reflex/vars/sequence.py | 2 +- tests/units/components/core/test_banner.py | 6 +- tests/units/test_var.py | 4 +- 18 files changed, 94 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d59e7b8af3..731dbdd463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] target-version = "py39" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "D", "E", "F", "I", "SIM", "W", "RUF", "FURB", "ERA"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "RUF", "SIM", "W"] lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] lint.pydocstyle.convention = "google" diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 29398da87d..85d531be91 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -115,7 +115,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]: default, rest = compile_import_statement(fields) # prevent lib from being rendered on the page if all imports are non rendered kind - if not any({f.render for f in fields}): # type: ignore + if not any(f.render for f in fields): # type: ignore continue if not lib: diff --git a/reflex/components/component.py b/reflex/components/component.py index fd7c93cbd1..85458f16c8 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1208,7 +1208,7 @@ def _iter_parent_classes_with_method(cls, method: str) -> Iterator[Type[Componen Yields: The parent classes that define the method (differently than the base). """ - seen_methods = set([getattr(Component, method)]) + seen_methods = {getattr(Component, method)} for clz in cls.mro(): if clz is Component: break @@ -1390,15 +1390,9 @@ def _get_imports(self) -> ParsedImportDict: # Collect imports from Vars used directly by this component. var_datas = [var._get_all_var_data() for var in self._get_vars()] - var_imports: List[ImmutableParsedImportDict] = list( - map( - lambda var_data: var_data.imports, - filter( - None, - var_datas, - ), - ) - ) + var_imports: List[ImmutableParsedImportDict] = [ + var_data.imports for var_data in var_datas if var_data is not None + ] added_import_dicts: list[ParsedImportDict] = [] for clz in self._iter_parent_classes_with_method("add_imports"): diff --git a/reflex/components/radix/themes/components/icon_button.py b/reflex/components/radix/themes/components/icon_button.py index 2a32afe3a1..68c67485a0 100644 --- a/reflex/components/radix/themes/components/icon_button.py +++ b/reflex/components/radix/themes/components/icon_button.py @@ -79,7 +79,7 @@ def create(cls, *children, **props) -> Component: else: size_map_var = Match.create( props["size"], - *[(size, px) for size, px in RADIX_TO_LUCIDE_SIZE.items()], + *list(RADIX_TO_LUCIDE_SIZE.items()), 12, ) if not isinstance(size_map_var, Var): diff --git a/reflex/components/recharts/charts.py b/reflex/components/recharts/charts.py index 13f1252136..85e10c2c54 100644 --- a/reflex/components/recharts/charts.py +++ b/reflex/components/recharts/charts.py @@ -84,10 +84,10 @@ def create(cls, *children, **props) -> Component: cls._ensure_valid_dimension("width", width) cls._ensure_valid_dimension("height", height) - dim_props = dict( - width=width or "100%", - height=height or "100%", - ) + dim_props = { + "width": width or "100%", + "height": height or "100%", + } # Provide min dimensions so the graph always appears, even if the outer container is zero-size. if width is None: dim_props["min_width"] = 200 diff --git a/reflex/constants/route.py b/reflex/constants/route.py index 2af2f33c69..ab00fab153 100644 --- a/reflex/constants/route.py +++ b/reflex/constants/route.py @@ -31,7 +31,7 @@ class RouteVar(SimpleNamespace): # This subset of router_data is included in chained on_load events. -ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY)) +ROUTER_DATA_INCLUDE = {RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY} class RouteRegex(SimpleNamespace): diff --git a/reflex/event.py b/reflex/event.py index 8342062441..65ef5f3e62 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -297,7 +297,7 @@ def __init__( handler: EventHandler, event_actions: Dict[str, Union[bool, int]] | None = None, client_handler_name: str = "", - args: Tuple[Tuple[Var, Var], ...] = tuple(), + args: Tuple[Tuple[Var, Var], ...] = (), ): """Initialize an EventSpec. @@ -312,7 +312,7 @@ def __init__( object.__setattr__(self, "event_actions", event_actions) object.__setattr__(self, "handler", handler) object.__setattr__(self, "client_handler_name", client_handler_name) - object.__setattr__(self, "args", args or tuple()) + object.__setattr__(self, "args", args or ()) def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec: """Copy the event spec, with updated args. @@ -514,7 +514,7 @@ def no_args_event_spec() -> Tuple[()]: Returns: An empty tuple. """ - return tuple() # type: ignore + return () # type: ignore # These chains can be used for their side effects when no other events are desired. @@ -1137,9 +1137,7 @@ def run_script( Var(javascript_code) if isinstance(javascript_code, str) else javascript_code ) - return call_function( - ArgsFunctionOperation.create(tuple(), javascript_code), callback - ) + return call_function(ArgsFunctionOperation.create((), javascript_code), callback) def get_event(state, event): @@ -1491,7 +1489,7 @@ def get_handler_args( """ args = inspect.getfullargspec(event_spec.handler.fn).args - return event_spec.args if len(args) > 1 else tuple() + return event_spec.args if len(args) > 1 else () def fix_events( diff --git a/reflex/model.py b/reflex/model.py index 03b1cc8280..b1123add14 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -52,12 +52,12 @@ def get_engine_args(url: str | None = None) -> dict[str, Any]: Returns: The database engine arguments as a dict. """ - kwargs: dict[str, Any] = dict( + kwargs: dict[str, Any] = { # Print the SQL queries if the log level is INFO or lower. - echo=environment.SQLALCHEMY_ECHO.get(), + "echo": environment.SQLALCHEMY_ECHO.get(), # Check connections before returning them. - pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(), - ) + "pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(), + } conf = get_config() url = url or conf.db_url if url is not None and url.startswith("sqlite"): diff --git a/reflex/reflex.py b/reflex/reflex.py index 829c7c0d23..bcc9499efd 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -443,13 +443,13 @@ def deploy( hidden=True, ), regions: List[str] = typer.Option( - list(), + [], "-r", "--region", help="The regions to deploy to. `reflex cloud regions` For multiple envs, repeat this option, e.g. --region sjc --region iad", ), envs: List[str] = typer.Option( - list(), + [], "--env", help="The environment variables to set: =. For multiple envs, repeat this option, e.g. --env k1=v2 --env k2=v2.", ), diff --git a/reflex/state.py b/reflex/state.py index 434ee39217..aa40a8418a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -437,9 +437,7 @@ def __init__( ) # Create a fresh copy of the backend variables for this instance - self._backend_vars = copy.deepcopy( - {name: item for name, item in self.backend_vars.items()} - ) + self._backend_vars = copy.deepcopy(self.backend_vars) def __repr__(self) -> str: """Get the string representation of the state. @@ -523,9 +521,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars # Check if another substate class with the same name has already been defined. - if cls.get_name() in set( - c.get_name() for c in parent_state.class_subclasses - ): + if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}: # This should not happen, since we have added module prefix to state names in #3214 raise StateValueError( f"The substate class '{cls.get_name()}' has been defined multiple times. " @@ -788,11 +784,11 @@ def _init_var_dependency_dicts(cls): ) # ComputedVar with cache=False always need to be recomputed - cls._always_dirty_computed_vars = set( + cls._always_dirty_computed_vars = { cvar_name for cvar_name, cvar in cls.computed_vars.items() if not cvar._cache - ) + } # Any substate containing a ComputedVar with cache=False always needs to be recomputed if cls._always_dirty_computed_vars: @@ -1862,11 +1858,11 @@ def _expired_computed_vars(self) -> set[str]: Returns: Set of computed vars to include in the delta. """ - return set( + return { cvar for cvar in self.computed_vars if self.computed_vars[cvar].needs_update(instance=self) - ) + } def _dirty_computed_vars( self, from_vars: set[str] | None = None, include_backend: bool = True @@ -1880,12 +1876,12 @@ def _dirty_computed_vars( Returns: Set of computed vars to include in the delta. """ - return set( + return { cvar for dirty_var in from_vars or self.dirty_vars for cvar in self._computed_var_dependencies[dirty_var] if include_backend or not self.computed_vars[cvar]._backend - ) + } @classmethod def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: @@ -1895,16 +1891,16 @@ def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: Set of State classes that may need to be fetched to recalc computed vars. """ # _always_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = set( + fetch_substates = { cls.get_class_substate((cls.get_name(), *substate_name.split("."))) for substate_name in cls._always_dirty_substates - ) + } for dependent_substates in cls._substate_var_dependencies.values(): fetch_substates.update( - set( + { cls.get_class_substate((cls.get_name(), *substate_name.split("."))) for substate_name in dependent_substates - ) + } ) return fetch_substates @@ -2206,7 +2202,7 @@ def _field_tuple( return md5( pickle.dumps( - list(sorted(_field_tuple(field_name) for field_name in cls.base_vars)) + sorted(_field_tuple(field_name) for field_name in cls.base_vars) ) ).hexdigest() @@ -3654,33 +3650,30 @@ class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes.""" # Methods on wrapped objects which should mark the state as dirty. - __mark_dirty_attrs__ = set( - [ - "add", - "append", - "clear", - "difference_update", - "discard", - "extend", - "insert", - "intersection_update", - "pop", - "popitem", - "remove", - "reverse", - "setdefault", - "sort", - "symmetric_difference_update", - "update", - ] - ) + __mark_dirty_attrs__ = { + "add", + "append", + "clear", + "difference_update", + "discard", + "extend", + "insert", + "intersection_update", + "pop", + "popitem", + "remove", + "reverse", + "setdefault", + "sort", + "symmetric_difference_update", + "update", + } + # Methods on wrapped objects might return mutable objects that should be tracked. - __wrap_mutable_attrs__ = set( - [ - "get", - "setdefault", - ] - ) + __wrap_mutable_attrs__ = { + "get", + "setdefault", + } # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy. __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set( @@ -3723,7 +3716,7 @@ def _mark_dirty( self, wrapped=None, instance=None, - args=tuple(), + args=(), kwargs=None, ) -> Any: """Mark the state as dirty, then call a wrapped function. @@ -3979,7 +3972,7 @@ def _mark_dirty( self, wrapped=None, instance=None, - args=tuple(), + args=(), kwargs=None, ) -> Any: """Raise an exception when an attempt is made to modify the object. diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 3e69ecd0b6..0543c2d16c 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -117,7 +117,7 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True): console.print("New packages detected: Updating app...") else: if any( - [x in line for x in ("bin executable does not exist on disk",)] + x in line for x in ("bin executable does not exist on disk",) ): console.error( "Try setting `REFLEX_USE_NPM=1` and re-running `reflex init` and `reflex run` to use npm instead of bun:\n" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index a83843eeb8..f64ba7458d 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -699,7 +699,7 @@ def _update_next_config( } if transpile_packages: next_config["transpilePackages"] = list( - set((format_library_name(p) for p in transpile_packages)) + {format_library_name(p) for p in transpile_packages} ) if export: next_config["output"] = "export" @@ -925,7 +925,7 @@ def _inner(*args, **kwargs): @cached_procedure( cache_file=str(get_web_dir() / "reflex.install_frontend_packages.cached"), - payload_fn=lambda p, c: f"{sorted(list(p))!r},{c.json()}", + payload_fn=lambda p, c: f"{sorted(p)!r},{c.json()}", ) def install_frontend_packages(packages: set[str], config: Config): """Installs the base and custom frontend packages. @@ -1300,7 +1300,7 @@ def get_release_by_tag(tag: str) -> dict | None: for tp in templates_data: if tp["hidden"] or tp["code_url"] is None: continue - known_fields = set(f.name for f in dataclasses.fields(Template)) + known_fields = {f.name for f in dataclasses.fields(Template)} filtered_templates[tp["name"]] = Template( **{k: v for k, v in tp.items() if k in known_fields} ) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 941a9d81ab..3ff3c52def 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -146,7 +146,7 @@ def old_school_imports(self) -> ImportDict: Returns: The imports as a mutable dict. """ - return dict((k, list(v)) for k, v in self.imports) + return {k: list(v) for k, v in self.imports} def merge(*all: VarData | None) -> VarData | None: """Merge multiple var data objects. @@ -1591,14 +1591,12 @@ def _cached_get_all_var_data(self) -> VarData | None: The cached VarData. """ return VarData.merge( - *map( - lambda value: ( - value._get_all_var_data() if isinstance(value, Var) else None - ), - map( - lambda field: getattr(self, field.name), - dataclasses.fields(self), # type: ignore - ), + *( + value._get_all_var_data() if isinstance(value, Var) else None + for value in ( + getattr(self, field.name) + for field in dataclasses.fields(self) # type: ignore + ) ), self._var_data, ) @@ -1889,20 +1887,20 @@ def _replace(self, merge_var_data=None, **kwargs: Any) -> Self: Raises: TypeError: If kwargs contains keys that are not allowed. """ - field_values = dict( - fget=kwargs.pop("fget", self._fget), - initial_value=kwargs.pop("initial_value", self._initial_value), - cache=kwargs.pop("cache", self._cache), - deps=kwargs.pop("deps", self._static_deps), - auto_deps=kwargs.pop("auto_deps", self._auto_deps), - interval=kwargs.pop("interval", self._update_interval), - backend=kwargs.pop("backend", self._backend), - _js_expr=kwargs.pop("_js_expr", self._js_expr), - _var_type=kwargs.pop("_var_type", self._var_type), - _var_data=kwargs.pop( + field_values = { + "fget": kwargs.pop("fget", self._fget), + "initial_value": kwargs.pop("initial_value", self._initial_value), + "cache": kwargs.pop("cache", self._cache), + "deps": kwargs.pop("deps", self._static_deps), + "auto_deps": kwargs.pop("auto_deps", self._auto_deps), + "interval": kwargs.pop("interval", self._update_interval), + "backend": kwargs.pop("backend", self._backend), + "_js_expr": kwargs.pop("_js_expr", self._js_expr), + "_var_type": kwargs.pop("_var_type", self._var_type), + "_var_data": kwargs.pop( "_var_data", VarData.merge(self._var_data, merge_var_data) ), - ) + } if kwargs: unexpected_kwargs = ", ".join(kwargs.keys()) @@ -2371,10 +2369,7 @@ def _cached_get_all_var_data(self) -> VarData | None: The cached VarData. """ return VarData.merge( - *map( - lambda arg: arg[1]._get_all_var_data(), - self._args, - ), + *(arg[1]._get_all_var_data() for arg in self._args), self._return._get_all_var_data(), self._var_data, ) diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 9879fdb5d3..2a7d50e1b5 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -292,7 +292,7 @@ def create( class DestructuredArg: """Class for destructured arguments.""" - fields: Tuple[str, ...] = tuple() + fields: Tuple[str, ...] = () rest: Optional[str] = None def to_javascript(self) -> str: @@ -314,7 +314,7 @@ def to_javascript(self) -> str: class FunctionArgs: """Class for function arguments.""" - args: Tuple[Union[str, DestructuredArg], ...] = tuple() + args: Tuple[Union[str, DestructuredArg], ...] = () rest: Optional[str] = None diff --git a/reflex/vars/number.py b/reflex/vars/number.py index d14f9695b0..d04aded353 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -51,7 +51,7 @@ def raise_unsupported_operand_types( VarTypeError: The operand types are unsupported. """ raise VarTypeError( - f"Unsupported Operand type(s) for {operator}: {', '.join(map(lambda t: t.__name__, operands_types))}" + f"Unsupported Operand type(s) for {operator}: {', '.join(t.__name__ for t in operands_types)}" ) diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index a026453091..476c1e32c4 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -1177,7 +1177,7 @@ def foreach(self, fn: Any): if num_args == 0: return_value = fn() - function_var = ArgsFunctionOperation.create(tuple(), return_value) + function_var = ArgsFunctionOperation.create((), return_value) else: # generic number var number_var = Var("").to(NumberVar, int) diff --git a/tests/units/components/core/test_banner.py b/tests/units/components/core/test_banner.py index fe6de5eae1..e1498d12c0 100644 --- a/tests/units/components/core/test_banner.py +++ b/tests/units/components/core/test_banner.py @@ -12,7 +12,7 @@ def test_websocket_target_url(): url = WebsocketTargetURL.create() var_data = url._get_all_var_data() assert var_data is not None - assert sorted(tuple((key for key, _ in var_data.imports))) == sorted( + assert sorted(key for key, _ in var_data.imports) == sorted( ("$/utils/state", "$/env.json") ) @@ -20,7 +20,7 @@ def test_websocket_target_url(): def test_connection_banner(): banner = ConnectionBanner.create() _imports = banner._get_all_imports(collapse=True) - assert sorted(tuple(_imports)) == sorted( + assert sorted(_imports) == sorted( ( "react", "$/utils/context", @@ -38,7 +38,7 @@ def test_connection_banner(): def test_connection_modal(): modal = ConnectionModal.create() _imports = modal._get_all_imports(collapse=True) - assert sorted(tuple(_imports)) == sorted( + assert sorted(_imports) == sorted( ( "react", "$/utils/context", diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 048752d119..1072fca1b5 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -372,7 +372,7 @@ def test_basic_operations(TestObj): "var, expected", [ (v([1, 2, 3]), "[1, 2, 3]"), - (v(set([1, 2, 3])), "[1, 2, 3]"), + (v({1, 2, 3}), "[1, 2, 3]"), (v(["1", "2", "3"]), '["1", "2", "3"]'), ( Var(_js_expr="foo")._var_set_state("state").to(list), @@ -903,7 +903,7 @@ def test_literal_var(): True, False, None, - set([1, 2, 3]), + {1, 2, 3}, ] ) assert ( From 682bca7f9a7b6d8c531a2f3192acb70fa68c3a95 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Fri, 13 Dec 2024 21:40:38 +0100 Subject: [PATCH 10/15] improve StateManagerRedis error message (#4444) --- reflex/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index aa40a8418a..e454746a9b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3350,7 +3350,7 @@ async def get_state( state_cls = self.state.get_class_substate(state_path) else: raise RuntimeError( - "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" + f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) # The deserialized or newly created (sub)state instance. From 76ce1120029c38cfe7d849557dfc9fbbec6dbf6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 12:41:29 -0800 Subject: [PATCH 11/15] add datetime var comparison operations (#4406) * add datetime var operations * add future annotations * add LiteralDatetimeVar * remove methods that don't apply * fix serialization * add unit and integrations test * oops, forgot to commit that important change --- reflex/vars/__init__.py | 1 + reflex/vars/datetime.py | 222 ++++++++++++++++++ .../test_datetime_operations.py | 87 +++++++ tests/units/utils/test_serializers.py | 1 + 4 files changed, 311 insertions(+) create mode 100644 reflex/vars/datetime.py create mode 100644 tests/integration/tests_playwright/test_datetime_operations.py diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index 1a4cebe19a..cb02319bc6 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -9,6 +9,7 @@ from .base import get_uuid_string_var as get_uuid_string_var from .base import var_operation as var_operation from .base import var_operation_return as var_operation_return +from .datetime import DateTimeVar as DateTimeVar from .function import FunctionStringVar as FunctionStringVar from .function import FunctionVar as FunctionVar from .function import VarOperationCall as VarOperationCall diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py new file mode 100644 index 0000000000..a4548e6f73 --- /dev/null +++ b/reflex/vars/datetime.py @@ -0,0 +1,222 @@ +"""Immutable datetime and date vars.""" + +from __future__ import annotations + +import dataclasses +import sys +from datetime import date, datetime +from typing import Any, NoReturn, TypeVar, Union, overload + +from reflex.utils.exceptions import VarTypeError +from reflex.vars.number import BooleanVar + +from .base import ( + CustomVarOperationReturn, + LiteralVar, + Var, + VarData, + var_operation, + var_operation_return, +) + +DATETIME_T = TypeVar("DATETIME_T", datetime, date) + +datetime_types = Union[datetime, date] + + +def raise_var_type_error(): + """Raise a VarTypeError. + + Raises: + VarTypeError: Cannot compare a datetime object with a non-datetime object. + """ + raise VarTypeError("Cannot compare a datetime object with a non-datetime object.") + + +class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)): + """A variable that holds a datetime or date object.""" + + @overload + def __lt__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __lt__(self, other: NoReturn) -> NoReturn: ... + + def __lt__(self, other: Any): + """Less than comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_lt_operation(self, other) + + @overload + def __le__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __le__(self, other: NoReturn) -> NoReturn: ... + + def __le__(self, other: Any): + """Less than or equal comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_le_operation(self, other) + + @overload + def __gt__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __gt__(self, other: NoReturn) -> NoReturn: ... + + def __gt__(self, other: Any): + """Greater than comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_gt_operation(self, other) + + @overload + def __ge__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __ge__(self, other: NoReturn) -> NoReturn: ... + + def __ge__(self, other: Any): + """Greater than or equal comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_ge_operation(self, other) + + +@var_operation +def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Greater than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs, strict=True) + + +@var_operation +def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Less than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs, strict=True) + + +@var_operation +def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Less than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs) + + +@var_operation +def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Greater than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs) + + +def date_compare_operation( + lhs: DateTimeVar[DATETIME_T] | Any, + rhs: DateTimeVar[DATETIME_T] | Any, + strict: bool = False, +) -> CustomVarOperationReturn: + """Check if the value is less than the other value. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + strict: Whether to use strict comparison. + + Returns: + The result of the operation. + """ + return var_operation_return( + f"({lhs} { '<' if strict else '<='} {rhs})", + bool, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralDatetimeVar(LiteralVar, DateTimeVar): + """Base class for immutable datetime and date vars.""" + + _var_value: datetime | date = dataclasses.field(default=datetime.now()) + + @classmethod + def create(cls, value: datetime | date, _var_data: VarData | None = None): + """Create a new instance of the class. + + Args: + value: The value to set. + + Returns: + LiteralDatetimeVar: The new instance of the class. + """ + js_expr = f'"{str(value)}"' + return cls( + _js_expr=js_expr, + _var_type=type(value), + _var_value=value, + _var_data=_var_data, + ) + + +DATETIME_TYPES = (datetime, date, DateTimeVar) diff --git a/tests/integration/tests_playwright/test_datetime_operations.py b/tests/integration/tests_playwright/test_datetime_operations.py new file mode 100644 index 0000000000..fafd15c420 --- /dev/null +++ b/tests/integration/tests_playwright/test_datetime_operations.py @@ -0,0 +1,87 @@ +from typing import Generator + +import pytest +from playwright.sync_api import Page, expect + +from reflex.testing import AppHarness + + +def DatetimeOperationsApp(): + from datetime import datetime + + import reflex as rx + + class DtOperationsState(rx.State): + date1: datetime = datetime(2021, 1, 1) + date2: datetime = datetime(2031, 1, 1) + date3: datetime = datetime(2021, 1, 1) + + app = rx.App(state=DtOperationsState) + + @app.add_page + def index(): + return rx.vstack( + rx.text(DtOperationsState.date1, id="date1"), + rx.text(DtOperationsState.date2, id="date2"), + rx.text(DtOperationsState.date3, id="date3"), + rx.text("Operations between date1 and date2"), + rx.text(DtOperationsState.date1 == DtOperationsState.date2, id="1_eq_2"), + rx.text(DtOperationsState.date1 != DtOperationsState.date2, id="1_neq_2"), + rx.text(DtOperationsState.date1 < DtOperationsState.date2, id="1_lt_2"), + rx.text(DtOperationsState.date1 <= DtOperationsState.date2, id="1_le_2"), + rx.text(DtOperationsState.date1 > DtOperationsState.date2, id="1_gt_2"), + rx.text(DtOperationsState.date1 >= DtOperationsState.date2, id="1_ge_2"), + rx.text("Operations between date1 and date3"), + rx.text(DtOperationsState.date1 == DtOperationsState.date3, id="1_eq_3"), + rx.text(DtOperationsState.date1 != DtOperationsState.date3, id="1_neq_3"), + rx.text(DtOperationsState.date1 < DtOperationsState.date3, id="1_lt_3"), + rx.text(DtOperationsState.date1 <= DtOperationsState.date3, id="1_le_3"), + rx.text(DtOperationsState.date1 > DtOperationsState.date3, id="1_gt_3"), + rx.text(DtOperationsState.date1 >= DtOperationsState.date3, id="1_ge_3"), + ) + + +@pytest.fixture() +def datetime_operations_app(tmp_path_factory) -> Generator[AppHarness, None, None]: + """Start Table app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("datetime_operations_app"), + app_source=DatetimeOperationsApp, # type: ignore + ) as harness: + assert harness.app_instance is not None, "app is not running" + yield harness + + +def test_datetime_operations(datetime_operations_app: AppHarness, page: Page): + assert datetime_operations_app.frontend_url is not None + + page.goto(datetime_operations_app.frontend_url) + expect(page).to_have_url(datetime_operations_app.frontend_url + "/") + # Check the actual values + expect(page.locator("id=date1")).to_have_text("2021-01-01 00:00:00") + expect(page.locator("id=date2")).to_have_text("2031-01-01 00:00:00") + expect(page.locator("id=date3")).to_have_text("2021-01-01 00:00:00") + + # Check the operations between date1 and date2 + expect(page.locator("id=1_eq_2")).to_have_text("false") + expect(page.locator("id=1_neq_2")).to_have_text("true") + expect(page.locator("id=1_lt_2")).to_have_text("true") + expect(page.locator("id=1_le_2")).to_have_text("true") + expect(page.locator("id=1_gt_2")).to_have_text("false") + expect(page.locator("id=1_ge_2")).to_have_text("false") + + # Check the operations between date1 and date3 + expect(page.locator("id=1_eq_3")).to_have_text("true") + expect(page.locator("id=1_neq_3")).to_have_text("false") + expect(page.locator("id=1_lt_3")).to_have_text("false") + expect(page.locator("id=1_le_3")).to_have_text("true") + expect(page.locator("id=1_gt_3")).to_have_text("false") + expect(page.locator("id=1_ge_3")).to_have_text("true") diff --git a/tests/units/utils/test_serializers.py b/tests/units/utils/test_serializers.py index 355f40d3fe..329e6b2525 100644 --- a/tests/units/utils/test_serializers.py +++ b/tests/units/utils/test_serializers.py @@ -222,6 +222,7 @@ def test_serialize(value: Any, expected: str): '"2021-01-01 01:01:01.000001"', True, ), + (datetime.date(2021, 1, 1), '"2021-01-01"', True), (Color(color="slate", shade=1), '"var(--slate-1)"', True), (BaseSubclass, '"BaseSubclass"', True), (Path("."), '"."', True), From 144442176687a55e1add5bd115bc6b0e9c73fa96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 13:28:55 -0800 Subject: [PATCH 12/15] add deps and position field in VarData (#4518) * fix memoized event trigger order * allow to declare deps in event signature for memoized event triggers * clean up the code to pass tests * handle position of hooks * clean up code * revert test changes * add future annotations * remove non-necessary stuff * reuse data_callback name if already set during first call to add_hooks * remove HookVar and use Var with VarData instead * remove test change * readd removed line * fix order of stmt for cleaner code * fix typing * something broke during the merge I guess * remove hack and pass proper const for position * oops, bad syntax in jinja * use "hook_position" instead of "hook_positions" match the name of the enum --------- Co-authored-by: Masen Furer --- .../web/pages/stateful_component.js.jinja2 | 6 ++- reflex/compiler/templates.py | 1 + reflex/components/component.py | 52 +++++++++++++++---- reflex/components/core/clipboard.py | 18 ++++--- reflex/components/core/clipboard.pyi | 2 +- reflex/components/datadisplay/dataeditor.py | 7 ++- reflex/constants/compiler.py | 6 +++ reflex/vars/base.py | 50 ++++++++++++++++-- 8 files changed, 115 insertions(+), 27 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 index 4a40ef5456..b04a787815 100644 --- a/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/stateful_component.js.jinja2 @@ -5,11 +5,15 @@ export function {{tag_name}} () { {{ hook }} {% endfor %} + {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %} + {{ hook }} + {% endfor %} + {% for hook in memo_trigger_hooks %} {{ hook }} {% endfor %} - {% for hook in component._get_all_hooks() %} + {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %} {{ hook }} {% endfor %} diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index c868a0cbb7..631aa4ee2d 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -45,6 +45,7 @@ def __init__(self) -> None: "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, + "hook_position": constants.Hooks.HookPosition, } diff --git a/reflex/components/component.py b/reflex/components/component.py index 85458f16c8..46318a30bc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1368,7 +1368,9 @@ def _get_hooks_imports(self) -> ParsedImportDict: if user_hooks_data is not None: other_imports.append(user_hooks_data.imports) other_imports.extend( - hook_imports for hook_imports in self._get_added_hooks().values() + hook_vardata.imports + for hook_vardata in self._get_added_hooks().values() + if hook_vardata is not None ) return imports.merge_imports(_imports, *other_imports) @@ -1516,7 +1518,7 @@ def _get_hooks_internal(self) -> dict[str, None]: **self._get_special_hooks(), } - def _get_added_hooks(self) -> dict[str, ImportDict]: + def _get_added_hooks(self) -> dict[str, VarData | None]: """Get the hooks added via `add_hooks` method. Returns: @@ -1525,17 +1527,15 @@ def _get_added_hooks(self) -> dict[str, ImportDict]: code = {} def extract_var_hooks(hook: Var): - _imports = {} var_data = VarData.merge(hook._get_all_var_data()) if var_data is not None: for sub_hook in var_data.hooks: - code[sub_hook] = {} - if var_data.imports: - _imports = var_data.imports + code[sub_hook] = None + if str(hook) in code: - code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) + code[str(hook)] = VarData.merge(var_data, code[str(hook)]) else: - code[str(hook)] = _imports + code[str(hook)] = var_data # Add the hook code from add_hooks for each parent class (this is reversed to preserve # the order of the hooks in the final output) @@ -1544,7 +1544,7 @@ def extract_var_hooks(hook: Var): if isinstance(hook, Var): extract_var_hooks(hook) else: - code[hook] = {} + code[hook] = None return code @@ -1586,8 +1586,8 @@ def _get_all_hooks(self) -> dict[str, None]: if hooks is not None: code[hooks] = None - for hook in self._get_added_hooks(): - code[hook] = None + for hook, var_data in self._get_added_hooks().items(): + code[hook] = var_data # Add the hook code for the children. for child in self.children: @@ -2189,6 +2189,31 @@ def _get_hook_deps(hook: str) -> list[str]: ] return [var_name] + @staticmethod + def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]: + """Get the dependencies accessed by event triggers. + + Args: + event: The event trigger to extract deps from. + + Returns: + The dependencies accessed by the event triggers. + """ + events: list = [event] + deps = set() + + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + for a in arg: + var_datas = VarData.merge(a._get_all_var_data()) + if var_datas and var_datas.deps is not None: + deps |= {str(dep) for dep in var_datas.deps} + return deps + @classmethod def _get_memoized_event_triggers( cls, @@ -2225,6 +2250,11 @@ def _get_memoized_event_triggers( # Calculate Var dependencies accessed by the handler for useCallback dep array. var_deps = ["addEvents", "Event"] + + # Get deps from event trigger var data. + var_deps.extend(cls._get_deps_from_event_trigger(event)) + + # Get deps from hooks. for arg in event_args: var_data = arg._get_all_var_data() if var_data is None: diff --git a/reflex/components/core/clipboard.py b/reflex/components/core/clipboard.py index 938cd13c07..644de80d07 100644 --- a/reflex/components/core/clipboard.py +++ b/reflex/components/core/clipboard.py @@ -6,11 +6,12 @@ from reflex.components.base.fragment import Fragment from reflex.components.tags.tag import Tag +from reflex.constants.compiler import Hooks from reflex.event import EventChain, EventHandler, passthrough_event_spec from reflex.utils.format import format_prop, wrap from reflex.utils.imports import ImportVar from reflex.vars import get_unique_variable_name -from reflex.vars.base import Var +from reflex.vars.base import Var, VarData class Clipboard(Fragment): @@ -72,7 +73,7 @@ def add_imports(self) -> dict[str, ImportVar]: ), } - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var[str]]: """Add hook to register paste event listener. Returns: @@ -83,13 +84,14 @@ def add_hooks(self) -> list[str]: return [] if isinstance(on_paste, EventChain): on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(") + hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})" + return [ - "usePasteHandler(%s, %s, %s)" - % ( - str(self.targets), - str(self.on_paste_event_actions), - on_paste, - ) + Var( + hook_expr, + _var_type="str", + _var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER), + ), ] diff --git a/reflex/components/core/clipboard.pyi b/reflex/components/core/clipboard.pyi index 69e0e866df..328554f2a4 100644 --- a/reflex/components/core/clipboard.pyi +++ b/reflex/components/core/clipboard.pyi @@ -71,6 +71,6 @@ class Clipboard(Fragment): ... def add_imports(self) -> dict[str, ImportVar]: ... - def add_hooks(self) -> list[str]: ... + def add_hooks(self) -> list[str | Var[str]]: ... clipboard = Clipboard.create diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index 79813205f2..2b80720ea1 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -339,8 +339,11 @@ def add_hooks(self) -> list[str]: editor_id = get_unique_variable_name() # Define the name of the getData callback associated with this component and assign to get_cell_content. - data_callback = f"getData_{editor_id}" - self.get_cell_content = Var(_js_expr=data_callback) # type: ignore + if self.get_cell_content is not None: + data_callback = self.get_cell_content._js_expr + else: + data_callback = f"getData_{editor_id}" + self.get_cell_content = Var(_js_expr=data_callback) # type: ignore code = [f"function {data_callback}([col, row])" "{"] diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b7ffef1613..7ca55f4dd9 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -132,6 +132,12 @@ class Hooks(SimpleNamespace): } })""" + class HookPosition(enum.Enum): + """The position of the hook in the component.""" + + PRE_TRIGGER = "pre_trigger" + POST_TRIGGER = "post_trigger" + class MemoizationDisposition(enum.Enum): """The conditions under which a component should be memoized.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 3ff3c52def..094a478c8c 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -42,7 +42,8 @@ from reflex import constants from reflex.base import Base -from reflex.utils import console, imports, serializers, types +from reflex.constants.compiler import Hooks +from reflex.utils import console, exceptions, imports, serializers, types from reflex.utils.exceptions import ( VarAttributeError, VarDependencyError, @@ -115,12 +116,20 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Dependencies of the var + deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) + + # Position of the hook in the component + position: Hooks.HookPosition | None = None + def __init__( self, state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + deps: list[Var] | None = None, + position: Hooks.HookPosition | None = None, ): """Initialize the var data. @@ -129,6 +138,8 @@ def __init__( field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + deps: Dependencies of the var for useCallback. + position: Position of the hook in the component. """ immutable_imports: ImmutableParsedImportDict = tuple( sorted( @@ -139,6 +150,8 @@ def __init__( object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__(self, "deps", tuple(deps or [])) + object.__setattr__(self, "position", position or None) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -154,6 +167,9 @@ def merge(*all: VarData | None) -> VarData | None: Args: *all: The var data objects to merge. + Raises: + ReflexError: If trying to merge VarData with different positions. + Returns: The merged var data object. @@ -184,12 +200,32 @@ def merge(*all: VarData | None) -> VarData | None: *(var_data.imports for var_data in all_var_datas) ) - if state or _imports or hooks or field_name: + deps = [dep for var_data in all_var_datas for dep in var_data.deps] + + positions = list( + { + var_data.position + for var_data in all_var_datas + if var_data.position is not None + } + ) + if positions: + if len(positions) > 1: + raise exceptions.ReflexError( + f"Cannot merge var data with different positions: {positions}" + ) + position = positions[0] + else: + position = None + + if state or _imports or hooks or field_name or deps or position: return VarData( state=state, field_name=field_name, imports=_imports, hooks=hooks, + deps=deps, + position=position, ) return None @@ -200,7 +236,14 @@ def __bool__(self) -> bool: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks or self.field_name) + return bool( + self.state + or self.imports + or self.hooks + or self.field_name + or self.deps + or self.position + ) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: @@ -480,7 +523,6 @@ def _replace( raise TypeError( "The _var_full_name_needs_state_prefix argument is not supported for Var." ) - value_with_replaced = dataclasses.replace( self, _var_type=_var_type or self._var_type, From 61cb72596e917c807f8e0acce98b589ca0178878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 14:06:26 -0800 Subject: [PATCH 13/15] enable PTH rule (#4476) * enable PTH rule * fix import in test_call_script * fix units tests * reorder ruff rules * Update reflex/utils/build.py Co-authored-by: Masen Furer * format pyproject.toml --------- Co-authored-by: Masen Furer --- benchmarks/benchmark_compile_times.py | 3 +- benchmarks/benchmark_imports.py | 3 +- pyproject.toml | 31 ++++++-------- reflex/config.py | 2 +- reflex/constants/custom_components.py | 2 +- reflex/custom_components/custom_components.py | 42 +++++++++---------- reflex/reflex.py | 3 +- reflex/testing.py | 16 +++---- reflex/utils/build.py | 2 +- reflex/utils/exec.py | 8 ++-- reflex/utils/export.py | 3 +- reflex/utils/path_ops.py | 4 +- reflex/utils/prerequisites.py | 35 ++++++++-------- reflex/vars/datetime.py | 2 +- tests/integration/test_call_script.py | 4 +- tests/units/states/upload.py | 21 ++++------ tests/units/test_prerequisites.py | 10 ++--- tests/units/utils/test_serializers.py | 2 +- tests/units/utils/test_utils.py | 6 +-- 19 files changed, 95 insertions(+), 104 deletions(-) diff --git a/benchmarks/benchmark_compile_times.py b/benchmarks/benchmark_compile_times.py index 2273bd5c83..56cb4e4cc1 100644 --- a/benchmarks/benchmark_compile_times.py +++ b/benchmarks/benchmark_compile_times.py @@ -5,6 +5,7 @@ import argparse import json import os +from pathlib import Path from utils import send_data_to_posthog @@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> list[dict]: Returns: list[dict]: The stats for each test. """ - with open(json_file, "r") as file: + with Path(json_file).open() as file: json_data = json.load(file) # Load the JSON data if it is a string, otherwise assume it's already a dictionary diff --git a/benchmarks/benchmark_imports.py b/benchmarks/benchmark_imports.py index 4706c0cf6f..8c3f9f46c5 100644 --- a/benchmarks/benchmark_imports.py +++ b/benchmarks/benchmark_imports.py @@ -5,6 +5,7 @@ import argparse import json import os +from pathlib import Path from utils import send_data_to_posthog @@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> dict: Returns: dict: The stats for each test. """ - with open(json_file, "r") as file: + with Path(json_file).open() as file: json_data = json.load(file) # Load the JSON data if it is a string, otherwise assume it's already a dictionary diff --git a/pyproject.toml b/pyproject.toml index 731dbdd463..fb6079fa53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,26 +4,19 @@ version = "0.6.7dev1" description = "Web apps in pure Python." license = "Apache-2.0" authors = [ - "Nikhil Rao ", - "Alek Petuskey ", - "Masen Furer ", - "Elijah Ahianyo ", - "Thomas Brandého ", + "Nikhil Rao ", + "Alek Petuskey ", + "Masen Furer ", + "Elijah Ahianyo ", + "Thomas Brandého ", ] readme = "README.md" homepage = "https://reflex.dev" repository = "https://github.com/reflex-dev/reflex" documentation = "https://reflex.dev/docs/getting-started/introduction" -keywords = [ - "web", - "framework", -] -classifiers = [ - "Development Status :: 4 - Beta", -] -packages = [ - {include = "reflex"} -] +keywords = ["web", "framework"] +classifiers = ["Development Status :: 4 - Beta"] +packages = [{ include = "reflex" }] [tool.poetry.dependencies] python = "^3.9" @@ -42,11 +35,11 @@ uvicorn = ">=0.20.0" starlette-admin = ">=0.11.0,<1.0" alembic = ">=1.11.1,<2.0" platformdirs = ">=3.10.0,<5.0" -distro = {version = ">=1.8.0,<2.0", platform = "linux"} +distro = { version = ">=1.8.0,<2.0", platform = "linux" } python-engineio = "!=4.6.0" wrapt = [ - {version = ">=1.14.0,<2.0", python = ">=3.11"}, - {version = ">=1.11.0,<2.0", python = "<3.11"}, + { version = ">=1.14.0,<2.0", python = ">=3.11" }, + { version = ">=1.11.0,<2.0", python = "<3.11" }, ] packaging = ">=23.1,<25.0" reflex-hosting-cli = ">=0.1.29,<2.0" @@ -93,7 +86,7 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] target-version = "py39" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "RUF", "SIM", "W"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PTH", "RUF", "SIM", "W"] lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] lint.pydocstyle.convention = "google" diff --git a/reflex/config.py b/reflex/config.py index bbea6a5d0c..0579b019f1 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -873,7 +873,7 @@ def get_config(reload: bool = False) -> Config: with _config_lock: sys_path = sys.path.copy() sys.path.clear() - sys.path.append(os.getcwd()) + sys.path.append(str(Path.cwd())) try: # Try to import the module with only the current directory in the path. return _get_config() diff --git a/reflex/constants/custom_components.py b/reflex/constants/custom_components.py index d879a01f27..a499327b19 100644 --- a/reflex/constants/custom_components.py +++ b/reflex/constants/custom_components.py @@ -10,7 +10,7 @@ class CustomComponents(SimpleNamespace): """Constants for the custom components.""" # The name of the custom components source directory. - SRC_DIR = "custom_components" + SRC_DIR = Path("custom_components") # The name of the custom components pyproject.toml file. PYPROJECT_TOML = Path("pyproject.toml") # The name of the custom components package README file. diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 41808d60a9..4a169802f4 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -150,27 +150,27 @@ def _populate_demo_app(name_variants: NameVariants): from reflex.compiler import templates from reflex.reflex import _init - demo_app_dir = name_variants.demo_app_dir + demo_app_dir = Path(name_variants.demo_app_dir) demo_app_name = name_variants.demo_app_name - console.info(f"Creating app for testing: {demo_app_dir}") + console.info(f"Creating app for testing: {demo_app_dir!s}") - os.makedirs(demo_app_dir) + demo_app_dir.mkdir(exist_ok=True) with set_directory(demo_app_dir): # We start with the blank template as basis. _init(name=demo_app_name, template=constants.Templates.DEFAULT) # Then overwrite the app source file with the one we want for testing custom components. # This source file is rendered using jinja template file. - with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f: - f.write( - templates.CUSTOM_COMPONENTS_DEMO_APP.render( - custom_component_module_dir=name_variants.custom_component_module_dir, - module_name=name_variants.module_name, - ) + demo_file = Path(f"{demo_app_name}/{demo_app_name}.py") + demo_file.write_text( + templates.CUSTOM_COMPONENTS_DEMO_APP.render( + custom_component_module_dir=name_variants.custom_component_module_dir, + module_name=name_variants.module_name, ) + ) # Append the custom component package to the requirements.txt file. - with open(f"{constants.RequirementsTxt.FILE}", "a") as f: + with Path(f"{constants.RequirementsTxt.FILE}").open(mode="a") as f: f.write(f"{name_variants.package_name}\n") @@ -296,13 +296,14 @@ def _populate_custom_component_project(name_variants: NameVariants): ) console.info( - f"Initializing the component directory: {CustomComponents.SRC_DIR}/{name_variants.custom_component_module_dir}" + f"Initializing the component directory: {CustomComponents.SRC_DIR / name_variants.custom_component_module_dir}" ) - os.makedirs(CustomComponents.SRC_DIR) + CustomComponents.SRC_DIR.mkdir(exist_ok=True) with set_directory(CustomComponents.SRC_DIR): - os.makedirs(name_variants.custom_component_module_dir) + module_dir = Path(name_variants.custom_component_module_dir) + module_dir.mkdir(exist_ok=True, parents=True) _write_source_and_init_py( - custom_component_src_dir=name_variants.custom_component_module_dir, + custom_component_src_dir=module_dir, component_class_name=name_variants.component_class_name, module_name=name_variants.module_name, ) @@ -814,7 +815,7 @@ def _validate_project_info(): ) pyproject_toml["project"] = project try: - with open(CustomComponents.PYPROJECT_TOML, "w") as f: + with CustomComponents.PYPROJECT_TOML.open("w") as f: tomlkit.dump(pyproject_toml, f) except (OSError, TOMLKitError) as ex: console.error(f"Unable to write to pyproject.toml due to {ex}") @@ -922,16 +923,15 @@ def _validate_url_with_protocol_prefix(url: str | None) -> bool: def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None: image_file = file_extension = None while image_file is None: - image_filepath = console.ask( - "Upload a preview image of your demo app (enter to skip)" + image_filepath = Path( + console.ask("Upload a preview image of your demo app (enter to skip)") ) if not image_filepath: break - file_extension = image_filepath.split(".")[-1] + file_extension = image_filepath.suffix try: - with open(image_filepath, "rb") as f: - image_file = f.read() - return image_file, file_extension + image_file = image_filepath.read_bytes() + return image_file, file_extension except OSError as ose: console.error(f"Unable to read the {file_extension} file due to {ose}") raise typer.Exit(code=1) from ose diff --git a/reflex/reflex.py b/reflex/reflex.py index bcc9499efd..20866941f5 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -3,7 +3,6 @@ from __future__ import annotations import atexit -import os from pathlib import Path from typing import List, Optional @@ -298,7 +297,7 @@ def export( True, "--frontend-only", help="Export only frontend.", show_default=False ), zip_dest_dir: str = typer.Option( - os.getcwd(), + str(Path.cwd()), help="The directory to export the zip files to.", show_default=False, ), diff --git a/reflex/testing.py b/reflex/testing.py index 05b7d7c9d9..ca31054b30 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -8,7 +8,6 @@ import functools import inspect import os -import pathlib import platform import re import signal @@ -20,6 +19,7 @@ import time import types from http.server import SimpleHTTPRequestHandler +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -100,7 +100,7 @@ def __init__(self, path): def __enter__(self): """Save current directory and perform chdir.""" - self._old_cwd.append(os.getcwd()) + self._old_cwd.append(Path.cwd()) os.chdir(self.path) def __exit__(self, *excinfo): @@ -120,8 +120,8 @@ class AppHarness: app_source: Optional[ Callable[[], None] | types.ModuleType | str | functools.partial[Any] ] - app_path: pathlib.Path - app_module_path: pathlib.Path + app_path: Path + app_module_path: Path app_module: Optional[types.ModuleType] = None app_instance: Optional[reflex.App] = None frontend_process: Optional[subprocess.Popen] = None @@ -136,7 +136,7 @@ class AppHarness: @classmethod def create( cls, - root: pathlib.Path, + root: Path, app_source: Optional[ Callable[[], None] | types.ModuleType | str | functools.partial[Any] ] = None, @@ -814,7 +814,7 @@ def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler): """SimpleHTTPRequestHandler with custom error page handling.""" - def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs): + def __init__(self, *args, error_page_map: dict[int, Path], **kwargs): """Initialize the handler. Args: @@ -857,8 +857,8 @@ class Subdir404TCPServer(socketserver.TCPServer): def __init__( self, *args, - root: pathlib.Path, - error_page_map: dict[int, pathlib.Path] | None, + root: Path, + error_page_map: dict[int, Path] | None, **kwargs, ): """Initialize the server. diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 14709d99ce..e263374e14 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -150,7 +150,7 @@ def zip_app( _zip( component_name=constants.ComponentName.BACKEND, target=zip_dest_dir / constants.ComponentName.BACKEND.zip(), - root_dir=Path("."), + root_dir=Path.cwd(), dirs_to_exclude={"__pycache__"}, files_to_exclude=files_to_exclude, top_level_dirs_to_exclude={"assets"}, diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 0543c2d16c..621c4a608a 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -24,7 +24,7 @@ frontend_process = None -def detect_package_change(json_file_path: str) -> str: +def detect_package_change(json_file_path: Path) -> str: """Calculates the SHA-256 hash of a JSON file and returns it as a hexadecimal string. Args: @@ -37,7 +37,7 @@ def detect_package_change(json_file_path: str) -> str: >>> detect_package_change("package.json") 'a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0u1v2w3x4y5z6a7b8c9d0e1f2' """ - with open(json_file_path, "r") as file: + with json_file_path.open("r") as file: json_data = json.load(file) # Calculate the hash @@ -81,7 +81,7 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True): from reflex.utils import processes json_file_path = get_web_dir() / constants.PackageJson.PATH - last_hash = detect_package_change(str(json_file_path)) + last_hash = detect_package_change(json_file_path) process = None first_run = True @@ -124,7 +124,7 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True): "`REFLEX_USE_NPM=1 reflex init`\n" "`REFLEX_USE_NPM=1 reflex run`" ) - new_hash = detect_package_change(str(json_file_path)) + new_hash = detect_package_change(json_file_path) if new_hash != last_hash: last_hash = new_hash kill(process.pid) diff --git a/reflex/utils/export.py b/reflex/utils/export.py index 31ac0d0b5f..2fbf633f65 100644 --- a/reflex/utils/export.py +++ b/reflex/utils/export.py @@ -1,6 +1,5 @@ """Export utilities.""" -import os from pathlib import Path from typing import Optional @@ -15,7 +14,7 @@ def export( zipping: bool = True, frontend: bool = True, backend: bool = True, - zip_dest_dir: str = os.getcwd(), + zip_dest_dir: str = str(Path.cwd()), upload_db_file: bool = False, api_url: Optional[str] = None, deploy_url: Optional[str] = None, diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index a2ba2b1512..38560977e2 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -205,14 +205,14 @@ def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]): # Read the existing json object from the file. json_object = {} if fp.stat().st_size: - with open(fp) as f: + with fp.open() as f: json_object = json.load(f) # Update the json object with the new data. json_object.update(update_dict) # Write the updated json object to the file - with open(fp, "w") as f: + with fp.open("w") as f: json.dump(json_object, f, ensure_ascii=False) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index f64ba7458d..25e753d093 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -290,7 +290,7 @@ def get_app(reload: bool = False) -> ModuleType: "If this error occurs in a reflex test case, ensure that `get_app` is mocked." ) module = config.module - sys.path.insert(0, os.getcwd()) + sys.path.insert(0, str(Path.cwd())) app = __import__(module, fromlist=(constants.CompileVars.APP,)) if reload: @@ -438,9 +438,11 @@ def create_config(app_name: str): from reflex.compiler import templates config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config" - with open(constants.Config.FILE, "w") as f: - console.debug(f"Creating {constants.Config.FILE}") - f.write(templates.RXCONFIG.render(app_name=app_name, config_name=config_name)) + + console.debug(f"Creating {constants.Config.FILE}") + constants.Config.FILE.write_text( + templates.RXCONFIG.render(app_name=app_name, config_name=config_name) + ) def initialize_gitignore( @@ -494,14 +496,14 @@ def initialize_requirements_txt(): console.debug(f"Detected encoding for {fp} as {encoding}.") try: other_requirements_exist = False - with open(fp, "r", encoding=encoding) as f: + with fp.open("r", encoding=encoding) as f: for req in f: # Check if we have a package name that is reflex if re.match(r"^reflex[^a-zA-Z0-9]", req): console.debug(f"{fp} already has reflex as dependency.") return other_requirements_exist = True - with open(fp, "a", encoding=encoding) as f: + with fp.open("a", encoding=encoding) as f: preceding_newline = "\n" if other_requirements_exist else "" f.write( f"{preceding_newline}{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n" @@ -732,13 +734,13 @@ def download_and_run(url: str, *args, show_status: bool = False, **env): response.raise_for_status() # Save the script to a temporary file. - script = tempfile.NamedTemporaryFile() - with open(script.name, "w") as f: - f.write(response.text) + script = Path(tempfile.NamedTemporaryFile().name) + + script.write_text(response.text) # Run the script. env = {**os.environ, **env} - process = processes.new_process(["bash", f.name, *args], env=env) + process = processes.new_process(["bash", str(script), *args], env=env) show = processes.show_status if show_status else processes.show_logs show(f"Installing {url}", process) @@ -752,14 +754,14 @@ def download_and_extract_fnm_zip(): # Download the zip file url = constants.Fnm.INSTALL_URL console.debug(f"Downloading {url}") - fnm_zip_file = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip" + fnm_zip_file: Path = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip" # Function to download and extract the FNM zip release. try: # Download the FNM zip release. # TODO: show progress to improve UX response = net.get(url, follow_redirects=True) response.raise_for_status() - with open(fnm_zip_file, "wb") as output_file: + with fnm_zip_file.open("wb") as output_file: for chunk in response.iter_bytes(): output_file.write(chunk) @@ -807,7 +809,7 @@ def install_node(): ) else: # All other platforms (Linux, MacOS). # Add execute permissions to fnm executable. - os.chmod(constants.Fnm.EXE, stat.S_IXUSR) + constants.Fnm.EXE.chmod(stat.S_IXUSR) # Install node. # Specify arm64 arch explicitly for M1s and M2s. architecture_arg = ( @@ -1326,7 +1328,7 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str raise typer.Exit(1) from ose # Use httpx GET with redirects to download the zip file. - zip_file_path = Path(temp_dir) / "template.zip" + zip_file_path: Path = Path(temp_dir) / "template.zip" try: # Note: following redirects can be risky. We only allow this for reflex built templates at the moment. response = net.get(template_url, follow_redirects=True) @@ -1336,9 +1338,8 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str console.error(f"Failed to download the template: {he}") raise typer.Exit(1) from he try: - with open(zip_file_path, "wb") as f: - f.write(response.content) - console.debug(f"Downloaded the zip to {zip_file_path}") + zip_file_path.write_bytes(response.content) + console.debug(f"Downloaded the zip to {zip_file_path}") except OSError as ose: console.error(f"Unable to write the downloaded zip to disk {ose}") raise typer.Exit(1) from ose diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index a4548e6f73..b6f4c24c61 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -210,7 +210,7 @@ def create(cls, value: datetime | date, _var_data: VarData | None = None): Returns: LiteralDatetimeVar: The new instance of the class. """ - js_expr = f'"{str(value)}"' + js_expr = f'"{value!s}"' return cls( _js_expr=js_expr, _var_type=type(value), diff --git a/tests/integration/test_call_script.py b/tests/integration/test_call_script.py index 8c4bab8ce4..203c20e9b7 100644 --- a/tests/integration/test_call_script.py +++ b/tests/integration/test_call_script.py @@ -15,6 +15,7 @@ def CallScript(): """A test app for browser javascript integration.""" + from pathlib import Path from typing import Dict, List, Optional, Union import reflex as rx @@ -186,8 +187,7 @@ def reset_(self): self.reset() app = rx.App(state=rx.State) - with open("assets/external.js", "w") as f: - f.write(external_scripts) + Path("assets/external.js").write_text(external_scripts) @app.add_page def index(): diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 338025bcdb..66d9479b4f 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -61,14 +61,13 @@ async def multi_handle_upload(self, files: List[rx.UploadFile]): """ for file in files: upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" + assert file.filename is not None + outfile = self._tmp_path / file.filename # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) + outfile.write_bytes(upload_data) # Update the img var. - assert file.filename is not None self.img_list.append(file.filename) @rx.event(background=True) @@ -109,14 +108,13 @@ async def multi_handle_upload(self, files: List[rx.UploadFile]): """ for file in files: upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" + assert file.filename is not None + outfile = self._tmp_path / file.filename # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) + outfile.write_bytes(upload_data) # Update the img var. - assert file.filename is not None self.img_list.append(file.filename) @rx.event(background=True) @@ -157,14 +155,13 @@ async def multi_handle_upload(self, files: List[rx.UploadFile]): """ for file in files: upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" + assert file.filename is not None + outfile = self._tmp_path / file.filename # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) + outfile.write_bytes(upload_data) # Update the img var. - assert file.filename is not None self.img_list.append(file.filename) @rx.event(background=True) diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py index 2497318e73..90afe09632 100644 --- a/tests/units/test_prerequisites.py +++ b/tests/units/test_prerequisites.py @@ -105,8 +105,8 @@ def test_initialize_requirements_txt_no_op(mocker): return_value=Mock(best=lambda: Mock(encoding="utf-8")), ) mock_fp_touch = mocker.patch("pathlib.Path.touch") - open_mock = mock_open(read_data="reflex==0.2.9") - mocker.patch("builtins.open", open_mock) + open_mock = mock_open(read_data="reflex==0.6.7") + mocker.patch("pathlib.Path.open", open_mock) initialize_requirements_txt() assert open_mock.call_count == 1 assert open_mock.call_args.kwargs["encoding"] == "utf-8" @@ -122,7 +122,7 @@ def test_initialize_requirements_txt_missing_reflex(mocker): return_value=Mock(best=lambda: Mock(encoding="utf-8")), ) open_mock = mock_open(read_data="random-package=1.2.3") - mocker.patch("builtins.open", open_mock) + mocker.patch("pathlib.Path.open", open_mock) initialize_requirements_txt() # Currently open for read, then open for append assert open_mock.call_count == 2 @@ -138,7 +138,7 @@ def test_initialize_requirements_txt_not_exist(mocker): # File does not exist, create file with reflex mocker.patch("pathlib.Path.exists", return_value=False) open_mock = mock_open() - mocker.patch("builtins.open", open_mock) + mocker.patch("pathlib.Path.open", open_mock) initialize_requirements_txt() assert open_mock.call_count == 2 # By default, use utf-8 encoding @@ -170,7 +170,7 @@ def test_requirements_txt_other_encoding(mocker): ) initialize_requirements_txt() open_mock = mock_open(read_data="random-package=1.2.3") - mocker.patch("builtins.open", open_mock) + mocker.patch("pathlib.Path.open", open_mock) initialize_requirements_txt() # Currently open for read, then open for append assert open_mock.call_count == 2 diff --git a/tests/units/utils/test_serializers.py b/tests/units/utils/test_serializers.py index 329e6b2525..e5a47abaa4 100644 --- a/tests/units/utils/test_serializers.py +++ b/tests/units/utils/test_serializers.py @@ -225,7 +225,7 @@ def test_serialize(value: Any, expected: str): (datetime.date(2021, 1, 1), '"2021-01-01"', True), (Color(color="slate", shade=1), '"var(--slate-1)"', True), (BaseSubclass, '"BaseSubclass"', True), - (Path("."), '"."', True), + (Path(), '"."', True), ], ) def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool): diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 20bad41468..f8573111ce 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -270,7 +270,7 @@ def test_unsupported_literals(cls: type): ("appname2.io", "AppnameioConfig"), ], ) -def test_create_config(app_name, expected_config_name, mocker): +def test_create_config(app_name: str, expected_config_name: str, mocker): """Test templates.RXCONFIG is formatted with correct app name and config class name. Args: @@ -278,7 +278,7 @@ def test_create_config(app_name, expected_config_name, mocker): expected_config_name: Expected config name. mocker: Mocker object. """ - mocker.patch("builtins.open") + mocker.patch("pathlib.Path.write_text") tmpl_mock = mocker.patch("reflex.compiler.templates.RXCONFIG") prerequisites.create_config(app_name) tmpl_mock.render.assert_called_with( @@ -464,7 +464,7 @@ class Resp(Base): mocker.patch("httpx.stream", return_value=Resp()) download = mocker.patch("reflex.utils.prerequisites.download_and_extract_fnm_zip") process = mocker.patch("reflex.utils.processes.new_process") - chmod = mocker.patch("reflex.utils.prerequisites.os.chmod") + chmod = mocker.patch("pathlib.Path.chmod") mocker.patch("reflex.utils.processes.stream_logs") prerequisites.install_node() From d7956c19d3bb025d5eaef55fc6d49f8a7d4cded4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 13 Dec 2024 14:49:37 -0800 Subject: [PATCH 14/15] enable PERF rules (#4469) * enable PERF rules * fix scripts folder * Update reflex/compiler/utils.py Co-authored-by: Masen Furer --------- Co-authored-by: Masen Furer --- pyproject.toml | 5 +++-- reflex/base.py | 15 ++++++++------- reflex/compiler/utils.py | 3 +-- reflex/components/component.py | 8 ++++---- reflex/components/el/elements/__init__.py | 2 +- reflex/components/el/elements/__init__.pyi | 2 +- reflex/event.py | 15 ++++++++------- reflex/model.py | 8 ++++---- reflex/state.py | 17 ++++++++--------- reflex/utils/processes.py | 6 +++--- reflex/utils/pyi_generator.py | 7 +++---- scripts/wait_for_listening_port.py | 9 ++++----- 12 files changed, 48 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fb6079fa53..57d49e3f07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,14 +85,15 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] target-version = "py39" +output-format = "concise" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PTH", "RUF", "SIM", "W"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "W"] lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -"tests/*.py" = ["D100", "D103", "D104", "B018"] +"tests/*.py" = ["D100", "D103", "D104", "B018", "PERF"] "reflex/.templates/*.py" = ["D100", "D103", "D104"] "*.pyi" = ["D301", "D415", "D417", "D418", "E742"] "*/blank.py" = ["I001"] diff --git a/reflex/base.py b/reflex/base.py index 692f123a8a..a88e557ef4 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -30,15 +30,16 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None # can't use reflex.config.environment here cause of circular import reload = os.getenv("__RELOAD_CONFIG", "").lower() == "true" - for base in bases: - try: + base = None + try: + for base in bases: if not reload and getattr(base, field_name, None): pass - except TypeError as te: - raise VarNameError( - f'State var "{field_name}" in {base} has been shadowed by a substate var; ' - f'use a different field name instead".' - ) from te + except TypeError as te: + raise VarNameError( + f'State var "{field_name}" in {base} has been shadowed by a substate var; ' + f'use a different field name instead".' + ) from te # monkeypatch pydantic validate_field_name method to skip validating diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 85d531be91..1d698431cc 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -123,8 +123,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]: raise ValueError("No default field allowed for empty library.") if rest is None or len(rest) == 0: raise ValueError("No fields to import.") - for module in sorted(rest): - import_dicts.append(get_import_dict(module)) + import_dicts.extend(get_import_dict(module) for module in sorted(rest)) continue # remove the version before rendering the package imports diff --git a/reflex/components/component.py b/reflex/components/component.py index 46318a30bc..34800ab6e6 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1403,8 +1403,9 @@ def _get_imports(self) -> ParsedImportDict: if not isinstance(list_of_import_dict, list): list_of_import_dict = [list_of_import_dict] - for import_dict in list_of_import_dict: - added_import_dicts.append(parse_imports(import_dict)) + added_import_dicts.extend( + [parse_imports(import_dict) for import_dict in list_of_import_dict] + ) return imports.merge_imports( *self._get_props_imports(), @@ -1586,8 +1587,7 @@ def _get_all_hooks(self) -> dict[str, None]: if hooks is not None: code[hooks] = None - for hook, var_data in self._get_added_hooks().items(): - code[hook] = var_data + code.update(self._get_added_hooks()) # Add the hook code for the children. for child in self.children: diff --git a/reflex/components/el/elements/__init__.py b/reflex/components/el/elements/__init__.py index 45a7e04b87..f0d4fd2004 100644 --- a/reflex/components/el/elements/__init__.py +++ b/reflex/components/el/elements/__init__.py @@ -127,7 +127,7 @@ EXCLUDE = ["del_", "Del", "image"] -for _, v in _MAPPING.items(): +for v in _MAPPING.values(): v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE]) _SUBMOD_ATTRS: dict[str, list[str]] = _MAPPING diff --git a/reflex/components/el/elements/__init__.pyi b/reflex/components/el/elements/__init__.pyi index c96a80987a..defaa5848d 100644 --- a/reflex/components/el/elements/__init__.pyi +++ b/reflex/components/el/elements/__init__.pyi @@ -339,5 +339,5 @@ _MAPPING = { ], } EXCLUDE = ["del_", "Del", "image"] -for _, v in _MAPPING.items(): +for v in _MAPPING.values(): v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE]) diff --git a/reflex/event.py b/reflex/event.py index 65ef5f3e62..e4ca55c701 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -350,13 +350,14 @@ def add_args(self, *args: Var) -> EventSpec: # Construct the payload. values = [] - for arg in args: - try: - values.append(LiteralVar.create(arg)) - except TypeError as e: - raise EventHandlerTypeError( - f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." - ) from e + arg = None + try: + for arg in args: + values.append(LiteralVar.create(value=arg)) # noqa: PERF401 + except TypeError as e: + raise EventHandlerTypeError( + f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}." + ) from e new_payload = tuple(zip(fn_args, values)) return self.with_args(self.args + new_payload) diff --git a/reflex/model.py b/reflex/model.py index b1123add14..cb8bed29bb 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -4,6 +4,7 @@ import re from collections import defaultdict +from contextlib import suppress from typing import Any, ClassVar, Optional, Type, Union import alembic.autogenerate @@ -290,11 +291,10 @@ def dict(self, **kwargs): relationships = {} # SQLModel relationships do not appear in __fields__, but should be included if present. for name in self.__sqlmodel_relationships__: - try: + with suppress( + sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed. + ): relationships[name] = self._dict_recursive(getattr(self, name)) - except sqlalchemy.orm.exc.DetachedInstanceError: - # This happens when the relationship was never loaded and the session is closed. - continue return { **base_fields, **relationships, diff --git a/reflex/state.py b/reflex/state.py index e454746a9b..b181090da1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3438,17 +3438,16 @@ async def set_state( ) # Recursively set_state on all known substates. - tasks = [] - for substate in state.substates.values(): - tasks.append( - asyncio.create_task( - self.set_state( - token=_substate_key(client_token, substate), - state=substate, - lock_id=lock_id, - ) + tasks = [ + asyncio.create_task( + self.set_state( + _substate_key(client_token, substate), + substate, + lock_id, ) ) + for substate in state.substates.values() + ] # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). if state._get_was_touched(): pickle_state = state._serialize() diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index 4d0e64a963..ef2d364014 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -58,7 +58,9 @@ def get_process_on_port(port) -> Optional[psutil.Process]: The process on the given port. """ for proc in psutil.process_iter(["pid", "name", "cmdline"]): - try: + with contextlib.suppress( + psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess + ): if importlib.metadata.version("psutil") >= "6.0.0": conns = proc.net_connections(kind="inet") # type: ignore else: @@ -66,8 +68,6 @@ def get_process_on_port(port) -> Optional[psutil.Process]: for conn in conns: if conn.laddr.port == int(port): return proc - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): - pass return None diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 2d3d2664eb..c3a7b0ed12 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -287,10 +287,9 @@ def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str: for line in (clz.create.__doc__ or "").splitlines(): if "**" in line: indent = line.split("**")[0] - for nline in [ - f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items() - ]: - new_docstring.append(nline) + new_docstring.extend( + [f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()] + ) new_docstring.append(line) return "\n".join(new_docstring) diff --git a/scripts/wait_for_listening_port.py b/scripts/wait_for_listening_port.py index 247ff4fbaa..857ee7c6d8 100644 --- a/scripts/wait_for_listening_port.py +++ b/scripts/wait_for_listening_port.py @@ -49,11 +49,10 @@ def main(): parser.add_argument("--server-pid", type=int) args = parser.parse_args() executor = ThreadPoolExecutor(max_workers=len(args.port)) - futures = [] - for p in args.port: - futures.append( - executor.submit(_wait_for_port, p, args.server_pid, args.timeout) - ) + futures = [ + executor.submit(_wait_for_port, p, args.server_pid, args.timeout) + for p in args.port + ] for f in as_completed(futures): ok, msg = f.result() if ok: From f71e6f9559cfaaf13e3fc99111ae780527a2ae26 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 16 Dec 2024 12:21:32 -0800 Subject: [PATCH 15/15] Revert "only mark backend vars as dirty if they have changed (#4494)" (#4547) This reverts commit 3d89d74bdcd5afb85408ebb67a672701579f4efc. --- reflex/state.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index b181090da1..e7e6bcf326 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1307,9 +1307,6 @@ def __setattr__(self, name: str, value: Any): return if name in self.backend_vars: - # abort if unchanged - if self._backend_vars.get(name) == value: - return self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty()