From d9e718d7bde45c8180edbc084673720c21e207b8 Mon Sep 17 00:00:00 2001 From: HongyuHansonYao <159659100+HongyuHansonYao@users.noreply.github.com> Date: Fri, 31 May 2024 11:48:43 -0700 Subject: [PATCH] Layout property not pushed through on rx.plotly (#3394) * init fix * Update reflex/components/plotly/plotly.py Co-authored-by: Masen Furer * plotly: treat `data` as a `dict`-type Var in _render this allows the data to be passed directly as a figure or from a state var * removed width height prop as they are no longer needed * updated * reverted some of the changes * fixed unit tests * regen pyi --------- Co-authored-by: Hongyu Yao Co-authored-by: Masen Furer Co-authored-by: Hongyu Yao --- reflex/components/plotly/plotly.py | 18 ++++++++++++------ reflex/components/plotly/plotly.pyi | 4 ---- reflex/utils/serializers.py | 6 +++--- tests/components/graphing/test_plotly.py | 2 +- tests/utils/test_format.py | 4 +++- 5 files changed, 19 insertions(+), 15 deletions(-) diff --git a/reflex/components/plotly/plotly.py b/reflex/components/plotly/plotly.py index 7a0dd835f6c..3ee1977c356 100644 --- a/reflex/components/plotly/plotly.py +++ b/reflex/components/plotly/plotly.py @@ -35,11 +35,17 @@ class Plotly(PlotlyLib): # The config of the graph. config: Var[Dict] - # The width of the graph. - width: Var[str] - - # The height of the graph. - height: Var[str] - # If true, the graph will resize when the window is resized. use_resize_handler: Var[bool] + + def _render(self): + tag = super()._render() + figure = self.data.to(dict) + if self.layout is None: + tag.remove_props("data", "layout") + tag.special_props.add( + Var.create_safe(f"{{...{figure._var_name_unwrapped}}}") + ) + else: + tag.add_props(data=figure["data"]) + return tag diff --git a/reflex/components/plotly/plotly.pyi b/reflex/components/plotly/plotly.pyi index 7c8068e1e03..02288804f0b 100644 --- a/reflex/components/plotly/plotly.pyi +++ b/reflex/components/plotly/plotly.pyi @@ -101,8 +101,6 @@ class Plotly(PlotlyLib): data: Optional[Union[Var[Figure], Figure]] = None, # type: ignore layout: Optional[Union[Var[Dict], Dict]] = None, config: Optional[Union[Var[Dict], Dict]] = None, - width: Optional[Union[Var[str], str]] = None, - height: Optional[Union[Var[str], str]] = None, use_resize_handler: Optional[Union[Var[bool], bool]] = None, style: Optional[Style] = None, key: Optional[Any] = None, @@ -164,8 +162,6 @@ class Plotly(PlotlyLib): data: The figure to display. This can be a plotly figure or a plotly data json. layout: The layout of the graph. config: The config of the graph. - width: The width of the graph. - height: The height of the graph. use_resize_handler: If true, the graph will resize when the window is resized. style: The style of the component. key: A unique key for the component. diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 1e1de7507b4..f3f9e635f92 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -255,7 +255,7 @@ def serialize_enum(en: Enum) -> str: en: The enum to serialize. Returns: - The serialized enum. + The serialized enum. """ return en.value @@ -313,7 +313,7 @@ def serialize_dataframe(df: DataFrame) -> dict: from plotly.io import to_json @serializer - def serialize_figure(figure: Figure) -> list: + def serialize_figure(figure: Figure) -> dict: """Serialize a plotly figure. Args: @@ -322,7 +322,7 @@ def serialize_figure(figure: Figure) -> list: Returns: The serialized figure. """ - return json.loads(str(to_json(figure)))["data"] + return json.loads(str(to_json(figure))) except ImportError: pass diff --git a/tests/components/graphing/test_plotly.py b/tests/components/graphing/test_plotly.py index 0e17789b573..69b046bea34 100644 --- a/tests/components/graphing/test_plotly.py +++ b/tests/components/graphing/test_plotly.py @@ -30,7 +30,7 @@ def test_serialize_plotly(plotly_fig: go.Figure): plotly_fig: The figure to serialize. """ value = serialize(plotly_fig) - assert isinstance(value, list) + assert isinstance(value, dict) assert value == serialize_figure(plotly_fig) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 19f3851759c..e9c79a4a2ca 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -3,12 +3,14 @@ import datetime from typing import Any, List +import plotly.graph_objects as go import pytest from reflex.components.tags.tag import Tag from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent from reflex.style import Style from reflex.utils import format +from reflex.utils.serializers import serialize_figure from reflex.vars import BaseVar, Var from tests.test_state import ( ChildState, @@ -661,7 +663,7 @@ def test_format_query_params(input, output): 2: {"prop1": 42, "prop2": "hello"}, }, "dt": "1989-11-09 18:53:00+01:00", - "fig": [], + "fig": serialize_figure(go.Figure()), "key": "", "map_key": "a", "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},