diff --git a/docs/apis/visualization.md b/docs/apis/visualization.md index 7815ec568b8..7c39dae5b01 100644 --- a/docs/apis/visualization.md +++ b/docs/apis/visualization.md @@ -21,19 +21,26 @@ For a detailed tutorial, please refer to our [Visualization Tutorial](../tutoria ``` -## Matplotlib-based components +## Matplotlib-based visualizations ```{eval-rst} -.. automodule:: mesa.visualization.components.matplotlib +.. automodule:: mesa.visualization.components.matplotlib_components :members: :undoc-members: :show-inheritance: ``` -## Altair-based components +```{eval-rst} +.. automodule:: mesa.visualization.mpl_space_drawing + :members: + :undoc-members: + :show-inheritance: +``` + +## Altair-based visualizations ```{eval-rst} -.. automodule:: mesa.visualization.components.altair +.. automodule:: mesa.visualization.components.altair_components :members: :undoc-members: :show-inheritance: diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py index fafedde6b09..6664e2c132f 100644 --- a/mesa/examples/advanced/pd_grid/app.py +++ b/mesa/examples/advanced/pd_grid/app.py @@ -3,7 +3,11 @@ """ from mesa.examples.advanced.pd_grid.model import PdGrid -from mesa.visualization import SolaraViz, make_plot_component, make_space_component +from mesa.visualization import ( + SolaraViz, + make_plot_component, + make_space_component, +) from mesa.visualization.UserParam import Slider diff --git a/mesa/examples/basic/boid_flockers/app.py b/mesa/examples/basic/boid_flockers/app.py index bcecb0a3ebd..e374185c05f 100644 --- a/mesa/examples/basic/boid_flockers/app.py +++ b/mesa/examples/basic/boid_flockers/app.py @@ -51,7 +51,7 @@ def boid_draw(agent): page = SolaraViz( model, - [make_space_component(agent_portrayal=boid_draw)], + [make_space_component(agent_portrayal=boid_draw, backend="matplotlib")], model_params=model_params, name="Boid Flocking Model", ) diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index ff329ab9667..7f92cbbbe01 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -1,5 +1,9 @@ from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel -from mesa.visualization import SolaraViz, make_plot_component, make_space_component +from mesa.visualization import ( + SolaraViz, + make_plot_component, + make_space_component, +) def agent_portrayal(agent): diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index 4bac98704cc..66597b0e71d 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -1,7 +1,11 @@ """Solara based visualization for Mesa models.""" -from .components.altair import make_space_altair -from .components.matplotlib import make_plot_component, make_space_component +from mesa.visualization.mpl_space_drawing import ( + draw_space, +) + +from .components import make_plot_component, make_space_component +from .components.altair_components import make_space_altair from .solara_viz import JupyterViz, SolaraViz from .UserParam import Slider @@ -10,6 +14,7 @@ "SolaraViz", "Slider", "make_space_altair", - "make_space_component", + "draw_space", "make_plot_component", + "make_space_component", ] diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py new file mode 100644 index 00000000000..4b70fc2b97c --- /dev/null +++ b/mesa/visualization/components/__init__.py @@ -0,0 +1,83 @@ +"""custom solara components.""" + +from __future__ import annotations + +from collections.abc import Callable + +from .altair_components import SpaceAltair, make_altair_space +from .matplotlib_components import ( + SpaceMatplotlib, + make_mpl_plot_component, + make_mpl_space_component, +) + + +def make_space_component( + agent_portrayal: Callable | None = None, + propertylayer_portrayal: dict | None = None, + post_process: Callable | None = None, + backend: str = "matplotlib", + **space_drawing_kwargs, +) -> SpaceMatplotlib | SpaceAltair: + """Create a Matplotlib-based space visualization component. + + Args: + agent_portrayal: Function to portray agents. + propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications + post_process : a callable that will be called with the Axes instance. Allows for fine-tuning plots (e.g., control ticks) + backend: the backend to use {"matplotlib", "altair"} + space_drawing_kwargs : additional keyword arguments to be passed on to the underlying backend specific space drawer function. See + the functions for drawing the various spaces for the appropriate backend further details. + + + Returns: + function: A function that creates a space component + """ + if backend == "matplotlib": + return make_mpl_space_component( + agent_portrayal, + propertylayer_portrayal, + post_process, + **space_drawing_kwargs, + ) + elif backend == "altair": + return make_altair_space( + agent_portrayal, + propertylayer_portrayal, + post_process, + **space_drawing_kwargs, + ) + else: + raise ValueError( + f"unknown backend {backend}, must be one of matplotlib, altair" + ) + + +def make_plot_component( + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable | None = None, + backend: str = "matplotlib", + **plot_drawing_kwargs, +): + """Create a plotting function for a specified measure using the specified backend. + + Args: + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + backend: the backend to use {"matplotlib", "altair"} + plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component + + Notes: + altair plotting backend is not yet implemented and planned for mesa 3.1. + + Returns: + function: A function that creates a plot component + """ + if backend == "matplotlib": + return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs) + elif backend == "altair": + raise NotImplementedError("altair line plots are not yet implemented") + else: + raise ValueError( + f"unknown backend {backend}, must be one of matplotlib, altair" + ) diff --git a/mesa/visualization/components/altair.py b/mesa/visualization/components/altair_components.py similarity index 81% rename from mesa/visualization/components/altair.py rename to mesa/visualization/components/altair_components.py index 5aeee84761d..b610e46f0d0 100644 --- a/mesa/visualization/components/altair.py +++ b/mesa/visualization/components/altair_components.py @@ -1,6 +1,7 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib +import warnings import solara @@ -12,7 +13,33 @@ from mesa.visualization.utils import update_counter -def make_space_altair(agent_portrayal=None): # noqa: D103 +def make_space_altair(*args, **kwargs): # noqa: D103 + warnings.warn( + "make_space_altair has been renamed to make_altair_space", + DeprecationWarning, + stacklevel=2, + ) + return make_altair_space(*args, **kwargs) + + +def make_altair_space( + agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs +): + """Create an Altair-based space visualization component. + + Args: + agent_portrayal: Function to portray agents. + propertylayer_portrayal: not yet implemented + post_process :not yet implemented + space_drawing_kwargs : not yet implemented + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + + Returns: + function: A function that creates a SpaceMatplotlib component + """ if agent_portrayal is None: def agent_portrayal(a): @@ -25,7 +52,12 @@ def MakeSpaceAltair(model): @solara.component -def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103 +def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): + """Create an Altair-based space visualization component. + + Returns: + a solara FigureAltair instance + """ update_counter.get() space = getattr(model, "grid", None) if space is None: diff --git a/mesa/visualization/components/matplotlib_components.py b/mesa/visualization/components/matplotlib_components.py new file mode 100644 index 00000000000..6c7bb1ae040 --- /dev/null +++ b/mesa/visualization/components/matplotlib_components.py @@ -0,0 +1,176 @@ +"""Matplotlib based solara components for visualization MESA spaces and plots.""" + +from __future__ import annotations + +import warnings +from collections.abc import Callable + +import matplotlib.pyplot as plt +import solara +from matplotlib.figure import Figure + +from mesa.visualization.mpl_space_drawing import draw_space +from mesa.visualization.utils import update_counter + + +def make_space_matplotlib(*args, **kwargs): # noqa: D103 + warnings.warn( + "make_space_matplotlib has been renamed to make_mpl_space_component", + DeprecationWarning, + stacklevel=2, + ) + return make_mpl_space_component(*args, **kwargs) + + +def make_mpl_space_component( + agent_portrayal: Callable | None = None, + propertylayer_portrayal: dict | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, +) -> SpaceMatplotlib: + """Create a Matplotlib-based space visualization component. + + Args: + agent_portrayal: Function to portray agents. + propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications + post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) + space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See + the functions for drawing the various spaces for further details. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + + Returns: + function: A function that creates a SpaceMatplotlib component + """ + if agent_portrayal is None: + + def agent_portrayal(a): + return {} + + def MakeSpaceMatplotlib(model): + return SpaceMatplotlib( + model, + agent_portrayal, + propertylayer_portrayal, + post_process=post_process, + **space_drawing_kwargs, + ) + + return MakeSpaceMatplotlib + + +@solara.component +def SpaceMatplotlib( + model, + agent_portrayal, + propertylayer_portrayal, + dependencies: list[any] | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, +): + """Create a Matplotlib-based space visualization component.""" + update_counter.get() + + space = getattr(model, "grid", None) + if space is None: + space = getattr(model, "space", None) + + fig = Figure() + ax = fig.add_subplot() + + draw_space( + space, + agent_portrayal, + propertylayer_portrayal=propertylayer_portrayal, + ax=ax, + **space_drawing_kwargs, + ) + + if post_process is not None: + post_process(ax) + + solara.FigureMatplotlib( + fig, format="png", bbox_inches="tight", dependencies=dependencies + ) + + +def make_plot_measure(*args, **kwargs): # noqa: D103 + warnings.warn( + "make_plot_measure has been renamed to make_plot_component", + DeprecationWarning, + stacklevel=2, + ) + return make_mpl_plot_component(*args, **kwargs) + + +def make_mpl_plot_component( + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable | None = None, + save_format="png", +): + """Create a plotting function for a specified measure. + + Args: + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + save_format: save format of figure in solara backend + + Returns: + function: A function that creates a PlotMatplotlib component. + """ + + def MakePlotMatplotlib(model): + return PlotMatplotlib( + model, measure, post_process=post_process, save_format=save_format + ) + + return MakePlotMatplotlib + + +@solara.component +def PlotMatplotlib( + model, + measure, + dependencies: list[any] | None = None, + post_process: Callable | None = None, + save_format="png", +): + """Create a Matplotlib-based plot for a measure or measures. + + Args: + model (mesa.Model): The model instance. + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + dependencies (list[any] | None): Optional dependencies for the plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + save_format: format used for saving the figure. + + Returns: + solara.FigureMatplotlib: A component for rendering the plot. + """ + update_counter.get() + fig = Figure() + ax = fig.subplots() + df = model.datacollector.get_model_vars_dataframe() + if isinstance(measure, str): + ax.plot(df.loc[:, measure]) + ax.set_ylabel(measure) + elif isinstance(measure, dict): + for m, color in measure.items(): + ax.plot(df.loc[:, m], label=m, color=color) + ax.legend(loc="best") + elif isinstance(measure, list | tuple): + for m in measure: + ax.plot(df.loc[:, m], label=m) + ax.legend(loc="best") + + if post_process is not None: + post_process(ax) + + ax.set_xlabel("Step") + # Set integer x axis + ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + solara.FigureMatplotlib( + fig, format=save_format, bbox_inches="tight", dependencies=dependencies + ) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/mpl_space_drawing.py similarity index 77% rename from mesa/visualization/components/matplotlib.py rename to mesa/visualization/mpl_space_drawing.py index 2bda984f775..6353d8106b8 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -1,4 +1,10 @@ -"""Matplotlib based solara components for visualization MESA spaces and plots.""" +"""Helper functions for drawing mesa spaces with matplotlib. + +These functions are used by the provided matplotlib components, but can also be used to quickly visualize +a space with matplotlib for example when creating a mp4 of a movie run or when needing a figure +for a paper. + +""" import itertools import math @@ -6,15 +12,13 @@ from collections.abc import Callable from typing import Any -import matplotlib.pyplot as plt import networkx as nx import numpy as np -import solara +from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable from matplotlib.collections import PatchCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba -from matplotlib.figure import Figure from matplotlib.patches import RegularPolygon import mesa @@ -32,95 +36,12 @@ PropertyLayer, SingleGrid, ) -from mesa.visualization.utils import update_counter -# For typing OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid HexGrid = HexSingleGrid | HexMultiGrid | mesa.experimental.cell_space.HexGrid Network = NetworkGrid | mesa.experimental.cell_space.Network -def make_space_matplotlib(*args, **kwargs): # noqa: D103 - warnings.warn( - "make_space_matplotlib has been renamed to make_space_component", - DeprecationWarning, - stacklevel=2, - ) - return make_space_component(*args, **kwargs) - - -def make_space_component( - agent_portrayal: Callable | None = None, - propertylayer_portrayal: dict | None = None, - post_process: Callable | None = None, - **space_drawing_kwargs, -): - """Create a Matplotlib-based space visualization component. - - Args: - agent_portrayal: Function to portray agents. - propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications - post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) - space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See - the functions for drawing the various spaces for further details. - - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. - - - Returns: - function: A function that creates a SpaceMatplotlib component - """ - if agent_portrayal is None: - - def agent_portrayal(a): - return {} - - def MakeSpaceMatplotlib(model): - return SpaceMatplotlib( - model, - agent_portrayal, - propertylayer_portrayal, - post_process=post_process, - **space_drawing_kwargs, - ) - - return MakeSpaceMatplotlib - - -@solara.component -def SpaceMatplotlib( - model, - agent_portrayal, - propertylayer_portrayal, - dependencies: list[any] | None = None, - post_process: Callable | None = None, - **space_drawing_kwargs, -): - """Create a Matplotlib-based space visualization component.""" - update_counter.get() - - space = getattr(model, "grid", None) - if space is None: - space = getattr(model, "space", None) - - fig = Figure() - ax = fig.add_subplot() - - draw_space( - space, - agent_portrayal, - propertylayer_portrayal=propertylayer_portrayal, - ax=ax, - post_process=post_process, - **space_drawing_kwargs, - ) - - solara.FigureMatplotlib( - fig, format="png", bbox_inches="tight", dependencies=dependencies - ) - - def collect_agent_data( space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid, agent_portrayal: Callable, @@ -173,7 +94,6 @@ def draw_space( agent_portrayal: Callable, propertylayer_portrayal: dict | None = None, ax: Axes | None = None, - post_process: Callable | None = None, **space_drawing_kwargs, ): """Draw a Matplotlib-based visualization of the space. @@ -184,8 +104,6 @@ def draw_space( propertylayer_portrayal: a dict specifying how to show propertylayer(s) ax: the axes upon which to draw the plot post_process: a callable called with the Axes instance - postprocess: a user-specified callable to do post-processing called with the Axes instance. This callable - can be used for any further fine-tuning of the plot (e.g., changing ticks, etc.) space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space. Returns: @@ -214,9 +132,6 @@ def draw_space( if propertylayer_portrayal: draw_property_layers(space, propertylayer_portrayal, ax=ax) - if post_process is not None: - post_process(ax=ax) - return ax @@ -641,83 +556,3 @@ def _scatter(ax: Axes, arguments, **kwargs): **{k: v[logical] for k, v in arguments.items()}, **kwargs, ) - - -def make_plot_measure(*args, **kwargs): # noqa: D103 - warnings.warn( - "make_plot_measure has been renamed to make_plot_component", - DeprecationWarning, - stacklevel=2, - ) - return make_plot_component(*args, **kwargs) - - -def make_plot_component( - measure: str | dict[str, str] | list[str] | tuple[str], - post_process: Callable | None = None, - save_format="png", -): - """Create a plotting function for a specified measure. - - Args: - measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. - post_process: a user-specified callable to do post-processing called with the Axes instance. - save_format: save format of figure in solara backend - - Returns: - function: A function that creates a PlotMatplotlib component. - """ - - def MakePlotMatplotlib(model): - return PlotMatplotlib( - model, measure, post_process=post_process, save_format=save_format - ) - - return MakePlotMatplotlib - - -@solara.component -def PlotMatplotlib( - model, - measure, - dependencies: list[any] | None = None, - post_process: Callable | None = None, - save_format="png", -): - """Create a Matplotlib-based plot for a measure or measures. - - Args: - model (mesa.Model): The model instance. - measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. - dependencies (list[any] | None): Optional dependencies for the plot. - post_process: a user-specified callable to do post-processing called with the Axes instance. - save_format: format used for saving the figure. - - Returns: - solara.FigureMatplotlib: A component for rendering the plot. - """ - update_counter.get() - fig = Figure() - ax = fig.subplots() - df = model.datacollector.get_model_vars_dataframe() - if isinstance(measure, str): - ax.plot(df.loc[:, measure]) - ax.set_ylabel(measure) - elif isinstance(measure, dict): - for m, color in measure.items(): - ax.plot(df.loc[:, m], label=m, color=color) - ax.legend(loc="best") - elif isinstance(measure, list | tuple): - for m in measure: - ax.plot(df.loc[:, m], label=m) - ax.legend(loc="best") - - if post_process is not None: - post_process(ax) - - ax.set_xlabel("Step") - # Set integer x axis - ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True)) - solara.FigureMatplotlib( - fig, format=save_format, bbox_inches="tight", dependencies=dependencies - ) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 4bca98d8b54..d5d83f437ea 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -31,7 +31,7 @@ import reacton.core import solara -import mesa.visualization.components.altair as components_altair +import mesa.visualization.components.altair_components as components_altair from mesa.visualization.UserParam import Slider from mesa.visualization.utils import force_update, update_counter diff --git a/tests/test_components_matplotlib.py b/tests/test_components_matplotlib.py index c85dd1ce292..9c454e77b2e 100644 --- a/tests/test_components_matplotlib.py +++ b/tests/test_components_matplotlib.py @@ -17,7 +17,7 @@ PropertyLayer, SingleGrid, ) -from mesa.visualization.components.matplotlib import ( +from mesa.visualization.mpl_space_drawing import ( draw_continuous_space, draw_hex_grid, draw_network, diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index af6badd0bc2..9276696676f 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -6,9 +6,9 @@ import solara import mesa -import mesa.visualization.components.altair -import mesa.visualization.components.matplotlib -from mesa.visualization.components.matplotlib import make_space_component +import mesa.visualization.components.altair_components +import mesa.visualization.components.matplotlib_components +from mesa.visualization.components.matplotlib_components import make_mpl_space_component from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs @@ -88,10 +88,12 @@ def Test(user_params): def test_call_space_drawer(mocker): # noqa: D103 mock_space_matplotlib = mocker.spy( - mesa.visualization.components.matplotlib, "SpaceMatplotlib" + mesa.visualization.components.matplotlib_components, "SpaceMatplotlib" ) - mock_space_altair = mocker.spy(mesa.visualization.components.altair, "SpaceAltair") + mock_space_altair = mocker.spy( + mesa.visualization.components.altair_components, "SpaceAltair" + ) model = mesa.Model() mocker.patch.object(mesa.Model, "__init__", return_value=None) @@ -103,7 +105,9 @@ def test_call_space_drawer(mocker): # noqa: D103 propertylayer_portrayal = None # initialize with space drawer unspecified (use default) # component must be rendered for code to run - solara.render(SolaraViz(model, components=[make_space_component(agent_portrayal)])) + solara.render( + SolaraViz(model, components=[make_mpl_space_component(agent_portrayal)]) + ) # should call default method with class instance and agent portrayal mock_space_matplotlib.assert_called_with( model, agent_portrayal, propertylayer_portrayal, post_process=None @@ -114,7 +118,7 @@ def test_call_space_drawer(mocker): # noqa: D103 solara.render(SolaraViz(model)) # should call default method with class instance and agent portrayal assert mock_space_matplotlib.call_count == 0 - assert mock_space_altair.call_count > 0 + assert mock_space_altair.call_count == 0 # specify a custom space method class AltSpace: @@ -132,7 +136,7 @@ def drawer(model): centroids_coordinates=[(0, 1), (0, 0), (1, 0)], ) solara.render( - SolaraViz(voronoi_model, components=[make_space_component(agent_portrayal)]) + SolaraViz(voronoi_model, components=[make_mpl_space_component(agent_portrayal)]) )