Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement VideoViz to record model runs in a video #2453

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ dmypy.json
# JS dependencies
mesa/visualization/templates/external/
mesa/visualization/templates/js/external/

# Video
**/*.mp4
40 changes: 40 additions & 0 deletions mesa/examples/basic/schelling/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Example of using VideoViz with the Schelling model."""

from mesa.examples.basic.schelling.model import Schelling
from mesa.visualization.video_viz import (

Check warning on line 4 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L3-L4

Added lines #L3 - L4 were not covered by tests
VideoViz,
make_plot_component,
make_space_component,
)

# Create model
model = Schelling(10, 10)

Check warning on line 11 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L11

Added line #L11 was not covered by tests


def agent_portrayal(agent):

Check warning on line 14 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L14

Added line #L14 was not covered by tests
"""Portray agents based on their type."""
if agent is None:
return {}

Check warning on line 17 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L17

Added line #L17 was not covered by tests

portrayal = {

Check warning on line 19 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L19

Added line #L19 was not covered by tests
"color": "red" if agent.type == 0 else "blue",
"size": 25,
"marker": "s", # square marker
}
return portrayal

Check warning on line 24 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L24

Added line #L24 was not covered by tests


# Create visualization with space and some metrics
viz = VideoViz(

Check warning on line 28 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L28

Added line #L28 was not covered by tests
model,
[
make_space_component(agent_portrayal=agent_portrayal),
make_plot_component("happy"),
],
title="Schelling's Segregation Model",
)

# Record simulation
if __name__ == "__main__":
viz.record(steps=50, filepath="schelling.mp4")
print("Video saved to: schelling.mp4")

Check warning on line 40 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L39-L40

Added lines #L39 - L40 were not covered by tests
4 changes: 3 additions & 1 deletion mesa/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Solara based visualization for Mesa models."""

from mesa.visualization.mpl_space_drawing import (
from mesa.visualization.mpl_drawing import (
draw_plot,
draw_space,
)

Expand All @@ -15,6 +16,7 @@
"Slider",
"make_space_altair",
"draw_space",
"draw_plot",
"make_plot_component",
"make_space_component",
]
23 changes: 3 additions & 20 deletions mesa/visualization/components/matplotlib_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import warnings
from collections.abc import Callable

import matplotlib.pyplot as plt
import solara
from matplotlib.figure import Figure

from mesa.visualization.mpl_space_drawing import draw_space
from mesa.visualization.mpl_drawing import draw_plot, draw_space
from mesa.visualization.utils import update_counter


Expand Down Expand Up @@ -151,26 +150,10 @@ def PlotMatplotlib(
"""
update_counter.get()
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
ax.legend(loc="best")
elif isinstance(measure, list | tuple):
for m in measure:
ax.plot(df.loc[:, m], label=m)
ax.legend(loc="best")

ax = fig.add_subplot()
draw_plot(model, measure, ax)
if post_process is not None:
post_process(ax)

ax.set_xlabel("Step")
# Set integer x axis
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
solara.FigureMatplotlib(
fig, format=save_format, bbox_inches="tight", dependencies=dependencies
)
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,41 @@ def _scatter(ax: Axes, arguments, **kwargs):
**{k: v[logical] for k, v in arguments.items()},
**kwargs,
)


def draw_plot(
model,
measure,
ax: Axes | None = None,
):
"""Create a Matplotlib-based plot for a measure or measures.

Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
ax: the axes upon which to draw the plot
post_process: a user-specified callable to do post-processing called with the Axes instance.

Returns:
plt.Axes: The Axes object with the plot drawn onto it.
"""
if ax is None:
_, ax = plt.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
ax.legend(loc="best")
elif isinstance(measure, list | tuple):
for m in measure:
ax.plot(df.loc[:, m], label=m)
ax.legend(loc="best")

ax.set_xlabel("Step")
# Set integer x axis
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

return ax
201 changes: 201 additions & 0 deletions mesa/visualization/video_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Mesa visualization module for recording videos of model simulations.

This module uses Matplotlib to create visualizations of model spaces and
measures, and records them as videos.

Please install FFmpeg to use this module:
- macOS: brew install ffmpeg
- Linux: sudo apt-get install ffmpeg
- Windows: download from https://ffmpeg.org/download.html
"""

import shutil
from collections.abc import Callable, Sequence
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

import mesa
from mesa.visualization.mpl_drawing import (
draw_plot,
draw_space,
)


def make_space_component(
agent_portrayal: Callable | None = None,
propertylayer_portrayal: dict | None = None,
post_process: Callable | None = None,
**space_drawing_kwargs,
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]:
"""Create a Matplotlib-based space visualization component.

Args:
agent_portrayal: Function to portray agents.
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
the functions for drawing the various spaces for further details.

``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.


Returns:
function: A function that returns a Axes instance with the space drawn
"""
if agent_portrayal is None:

def agent_portrayal(a):
return {}

def _make_space_component(model, ax=None):
space = getattr(model, "grid", None) or getattr(model, "space", None)
ax = draw_space(
space,
agent_portrayal,
propertylayer_portrayal,
ax,
**space_drawing_kwargs,
)
if post_process:
post_process(ax)
return ax

return _make_space_component


def make_plot_component(
measure: Callable,
post_process: Callable | None = None,
**kwargs,
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]:
"""Create a plotting function for a specified measure.

Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor.

Returns:
function: A function that returns a Axes instance with the measure(s) drawn
"""

def _make_plot_component(model, ax=None):
ax = draw_plot(model, measure, ax, **kwargs)
if post_process:
post_process(ax)
return ax

return _make_plot_component


class VideoViz:
"""Create high-quality video recordings of model simulations."""

def __init__(
self,
model: mesa.Model,
components: Sequence[Callable[[mesa.Model, plt.Axes | None], plt.Axes]],
*,
title: str | None = None,
figsize: tuple[float, float] | None = None,
grid: tuple[int, int] | None = None,
):
"""Initialize video visualization configuration.

Args:
model: The model to simulate and record
components: Sequence of component objects defining what to visualize
title: Optional title for the video
figsize: Optional figure size in inches (width, height)
grid: Optional (rows, cols) for custom layout. Auto-calculated if None.
"""
# Check if FFmpeg is available
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is perhaps for users the most tricky part. I agree on the error handling here. I would suggest also adding this to the module level docstring.

if not shutil.which("ffmpeg"):
raise RuntimeError(
"FFmpeg not found. Please install FFmpeg to save animations:\n"
" - macOS: brew install ffmpeg\n"
" - Linux: sudo apt-get install ffmpeg\n"
" - Windows: download from https://ffmpeg.org/download.html"
)
self.model = model
self.components = components
self.title = title
self.figsize = figsize
self.grid = grid or self._calculate_grid(len(components))

# Setup figure and axes
self.fig, self.axes = self._setup_figure()

def record(
self,
*,
steps: int,
filepath: str | Path,
dpi: int = 100,
fps: int = 10,
codec: str = "h264",
bitrate: int = 2000,
) -> None:
"""Record model simulation to video file.

Args:
steps: Number of simulation steps to record
filepath: Where to save the video file
dpi: Resolution of the output video
fps: Frames per second in the output video
codec: Video codec to use
bitrate: Video bitrate in kbps (default: 2000)

Raises:
RuntimeError: If FFmpeg is not installed
"""
filepath = Path(filepath)

def update(frame_num):
# Update model state
self.model.step()

# Render all visualization frames
for component, ax in zip(self.components, self.axes):
ax.clear()
component(self.model, ax)
return self.axes

# Create and save animation
anim = animation.FuncAnimation(
self.fig, update, frames=steps, interval=1000 / fps, blit=False
)

writer = animation.FFMpegWriter(
fps=fps,
codec=codec,
bitrate=bitrate, # Now passing as integer
)

anim.save(filepath, writer=writer, dpi=dpi)

def _calculate_grid(self, n_frames: int) -> tuple[int, int]:
"""Calculate optimal grid layout for given number of frames."""
cols = min(3, n_frames) # Max 3 columns
rows = int(np.ceil(n_frames / cols))
return (rows, cols)

def _setup_figure(self):
"""Setup matplotlib figure and axes."""
if not self.figsize:
self.figsize = (5 * self.grid[1], 5 * self.grid[0])
fig = plt.figure(figsize=self.figsize)
axes = []

for i in range(len(self.components)):
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1)
axes.append(ax)

if self.title:
fig.suptitle(self.title, fontsize=16)
fig.tight_layout()
return fig, axes
2 changes: 1 addition & 1 deletion tests/test_components_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PropertyLayer,
SingleGrid,
)
from mesa.visualization.mpl_space_drawing import (
from mesa.visualization.mpl_drawing import (
draw_continuous_space,
draw_hex_grid,
draw_network,
Expand Down
Loading