-
Notifications
You must be signed in to change notification settings - Fork 927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement VideoViz to record model runs in a video #2453
Open
wang-boyu
wants to merge
2
commits into
projectmesa:main
Choose a base branch
from
wang-boyu:video-viz
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Example of using VideoViz with the Schelling model.""" | ||
|
||
from mesa.examples.basic.schelling.model import Schelling | ||
from mesa.visualization.video_viz import ( | ||
VideoViz, | ||
make_plot_component, | ||
make_space_component, | ||
) | ||
|
||
# Create model | ||
model = Schelling(10, 10) | ||
|
||
|
||
def agent_portrayal(agent): | ||
"""Portray agents based on their type.""" | ||
if agent is None: | ||
return {} | ||
|
||
portrayal = { | ||
"color": "red" if agent.type == 0 else "blue", | ||
"size": 25, | ||
"marker": "s", # square marker | ||
} | ||
return portrayal | ||
|
||
|
||
# Create visualization with space and some metrics | ||
viz = VideoViz( | ||
model, | ||
[ | ||
make_space_component(agent_portrayal=agent_portrayal), | ||
make_plot_component("happy"), | ||
], | ||
title="Schelling's Segregation Model", | ||
) | ||
|
||
# Record simulation | ||
if __name__ == "__main__": | ||
viz.record(steps=50, filepath="schelling.mp4") | ||
print("Video saved to: schelling.mp4") | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
"""Mesa visualization module for recording videos of model simulations. | ||
|
||
This module uses Matplotlib to create visualizations of model spaces and | ||
measures, and records them as videos. | ||
|
||
Please install FFmpeg to use this module: | ||
- macOS: brew install ffmpeg | ||
- Linux: sudo apt-get install ffmpeg | ||
- Windows: download from https://ffmpeg.org/download.html | ||
""" | ||
|
||
import shutil | ||
from collections.abc import Callable, Sequence | ||
from pathlib import Path | ||
|
||
import matplotlib.animation as animation | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import mesa | ||
from mesa.visualization.mpl_drawing import ( | ||
draw_plot, | ||
draw_space, | ||
) | ||
|
||
|
||
def make_space_component( | ||
agent_portrayal: Callable | None = None, | ||
propertylayer_portrayal: dict | None = None, | ||
post_process: Callable | None = None, | ||
**space_drawing_kwargs, | ||
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]: | ||
"""Create a Matplotlib-based space visualization component. | ||
|
||
Args: | ||
agent_portrayal: Function to portray agents. | ||
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications | ||
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) | ||
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See | ||
the functions for drawing the various spaces for further details. | ||
|
||
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", | ||
"size", "marker", and "zorder". Other field are ignored and will result in a user warning. | ||
|
||
|
||
Returns: | ||
function: A function that returns a Axes instance with the space drawn | ||
""" | ||
if agent_portrayal is None: | ||
|
||
def agent_portrayal(a): | ||
return {} | ||
|
||
def _make_space_component(model, ax=None): | ||
space = getattr(model, "grid", None) or getattr(model, "space", None) | ||
ax = draw_space( | ||
space, | ||
agent_portrayal, | ||
propertylayer_portrayal, | ||
ax, | ||
**space_drawing_kwargs, | ||
) | ||
if post_process: | ||
post_process(ax) | ||
return ax | ||
|
||
return _make_space_component | ||
|
||
|
||
def make_plot_component( | ||
measure: Callable, | ||
post_process: Callable | None = None, | ||
**kwargs, | ||
) -> Callable[[mesa.Model, plt.Axes | None], plt.Axes]: | ||
"""Create a plotting function for a specified measure. | ||
|
||
Args: | ||
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. | ||
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) | ||
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor. | ||
|
||
Returns: | ||
function: A function that returns a Axes instance with the measure(s) drawn | ||
""" | ||
|
||
def _make_plot_component(model, ax=None): | ||
ax = draw_plot(model, measure, ax, **kwargs) | ||
if post_process: | ||
post_process(ax) | ||
return ax | ||
|
||
return _make_plot_component | ||
|
||
|
||
class VideoViz: | ||
"""Create high-quality video recordings of model simulations.""" | ||
|
||
def __init__( | ||
self, | ||
model: mesa.Model, | ||
components: Sequence[Callable[[mesa.Model, plt.Axes | None], plt.Axes]], | ||
*, | ||
title: str | None = None, | ||
figsize: tuple[float, float] | None = None, | ||
grid: tuple[int, int] | None = None, | ||
): | ||
"""Initialize video visualization configuration. | ||
|
||
Args: | ||
model: The model to simulate and record | ||
components: Sequence of component objects defining what to visualize | ||
title: Optional title for the video | ||
figsize: Optional figure size in inches (width, height) | ||
grid: Optional (rows, cols) for custom layout. Auto-calculated if None. | ||
""" | ||
# Check if FFmpeg is available | ||
if not shutil.which("ffmpeg"): | ||
raise RuntimeError( | ||
"FFmpeg not found. Please install FFmpeg to save animations:\n" | ||
" - macOS: brew install ffmpeg\n" | ||
" - Linux: sudo apt-get install ffmpeg\n" | ||
" - Windows: download from https://ffmpeg.org/download.html" | ||
) | ||
self.model = model | ||
self.components = components | ||
self.title = title | ||
self.figsize = figsize | ||
self.grid = grid or self._calculate_grid(len(components)) | ||
|
||
# Setup figure and axes | ||
self.fig, self.axes = self._setup_figure() | ||
|
||
def record( | ||
self, | ||
*, | ||
steps: int, | ||
filepath: str | Path, | ||
dpi: int = 100, | ||
fps: int = 10, | ||
codec: str = "h264", | ||
bitrate: int = 2000, | ||
) -> None: | ||
"""Record model simulation to video file. | ||
|
||
Args: | ||
steps: Number of simulation steps to record | ||
filepath: Where to save the video file | ||
dpi: Resolution of the output video | ||
fps: Frames per second in the output video | ||
codec: Video codec to use | ||
bitrate: Video bitrate in kbps (default: 2000) | ||
|
||
Raises: | ||
RuntimeError: If FFmpeg is not installed | ||
""" | ||
filepath = Path(filepath) | ||
|
||
def update(frame_num): | ||
# Update model state | ||
self.model.step() | ||
|
||
# Render all visualization frames | ||
for component, ax in zip(self.components, self.axes): | ||
ax.clear() | ||
component(self.model, ax) | ||
return self.axes | ||
|
||
# Create and save animation | ||
anim = animation.FuncAnimation( | ||
self.fig, update, frames=steps, interval=1000 / fps, blit=False | ||
) | ||
|
||
writer = animation.FFMpegWriter( | ||
fps=fps, | ||
codec=codec, | ||
bitrate=bitrate, # Now passing as integer | ||
) | ||
|
||
anim.save(filepath, writer=writer, dpi=dpi) | ||
|
||
def _calculate_grid(self, n_frames: int) -> tuple[int, int]: | ||
"""Calculate optimal grid layout for given number of frames.""" | ||
cols = min(3, n_frames) # Max 3 columns | ||
rows = int(np.ceil(n_frames / cols)) | ||
return (rows, cols) | ||
|
||
def _setup_figure(self): | ||
"""Setup matplotlib figure and axes.""" | ||
if not self.figsize: | ||
self.figsize = (5 * self.grid[1], 5 * self.grid[0]) | ||
fig = plt.figure(figsize=self.figsize) | ||
axes = [] | ||
|
||
for i in range(len(self.components)): | ||
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1) | ||
axes.append(ax) | ||
|
||
if self.title: | ||
fig.suptitle(self.title, fontsize=16) | ||
fig.tight_layout() | ||
return fig, axes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is perhaps for users the most tricky part. I agree on the error handling here. I would suggest also adding this to the module level docstring.