Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Altair Grid #1902

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
99 changes: 99 additions & 0 deletions mesa/experimental/altair_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Callable, Optional

import altair as alt
import solara

import mesa


def get_agent_data_from_coord_iter(data):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my current implementation, I'm using an agent_portrayal method to generate the values needed to draw the space. That may be drawn from the old mesa visualization approach, IDK if there's a good way to pass in something like that to jupyterviz.

I mention because I wonder if it would be cleaner and more explicit than the way you're using json to dump and filter the agent dict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @rlskoeser; the JSON approach is not very clean. Is there a more explicit way to do this?

"""
Extracts agent data from a sequence of tuples containing agent objects and their coordinates.

Parameters:
- data (iterable): A sequence of tuples where each tuple contains an agent object and its coordinates.

Yields:
- dict: A dictionary containing agent data with updated coordinates. The dictionary excludes 'model' and 'pos' attributes.
"""
for agent, (x, y) in data:
if agent:
agent_data = agent[0].__dict__.copy()
agent_data.update({"x": x, "y": y})
agent_data.pop("model", None)
agent_data.pop("pos", None)
yield agent_data


def create_grid(
color: Optional[str] = None,
on_click: Optional[Callable[[mesa.Model, mesa.space.Coordinate], None]] = None,
) -> Callable[[mesa.Model], solara.component]:
"""
Factory function for creating a grid component for a Mesa model.

Parameters:
- color (Optional[str]): Color of the grid lines. Defaults to None.
- on_click (Optional[Callable[[mesa.Model, mesa.space.Coordinate], None]]):
Function to be called when a grid cell is clicked. Defaults to None.

Returns:
- Callable[[mesa.Model], solara.component]: A function that creates a grid component for the given model.
"""

def create_grid_function(model: mesa.Model) -> solara.component:
return Grid(model, color, on_click)

return create_grid_function


def Grid(model, color=None, on_click=None):
"""
Handles click events on grid cells.

Parameters:
- datum (dict): Data associated with the clicked cell.

Notes:
- Invokes the provided `on_click` function with the model and cell coordinates.
- Updates the data displayed on the grid.
"""
if color is None:
color = "unique_id:N"

if color[-2] != ":":
color = color + ":N"

print(model.grid.coord_iter())

data = solara.reactive(
list(get_agent_data_from_coord_iter(model.grid.coord_iter()))
)

def update_data():
data.value = list(get_agent_data_from_coord_iter(model.grid.coord_iter()))

def click_handler(datum):
if datum is None:
return
on_click(model, datum["x"], datum["y"])
update_data()

default_tooltip = [
f"{key}:N" for key in data.value[0]
] # add all agent attributes to tooltip
chart = (
alt.Chart(alt.Data(values=data.value))
.mark_rect()
.encode(
x=alt.X("x:N", scale=alt.Scale(domain=list(range(model.grid.width)))),
y=alt.Y(
"y:N",
scale=alt.Scale(domain=list(range(model.grid.height - 1, -1, -1))),
),
color=color,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In one of my models where I'm using a custom altair space drawer, I'm setting color, size, and shape. Probably reasonable not to support all of those on the first pass, but it would be good to think about a more generalized approach (like the agent portrayal method) that would make it possible to customize this without having to completely re-implement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be pretty straightforward to do with the create_grid function

tooltip=default_tooltip,
)
.properties(width=600, height=600)
)
return solara.FigureAltair(chart, on_click=click_handler)
22 changes: 20 additions & 2 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
import threading
from typing import Optional

import matplotlib.pyplot as plt
import reacton.ipywidgets as widgets
import solara
from solara.alias import rv

import mesa.experimental.components.matplotlib as components_matplotlib
from mesa.experimental.altair_grid import create_grid
from mesa.experimental.UserParam import Slider

# Avoid interactive backend
Expand Down Expand Up @@ -113,6 +115,9 @@ def render_in_jupyter():
components_matplotlib.SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer == "altair":
# draw with the default implementation
SpaceAltair(model, agent_portrayal, dependencies=[current_step.value])
elif space_drawer:
# if specified, draw agent space with an alternate renderer
space_drawer(model, agent_portrayal)
Expand All @@ -128,7 +133,7 @@ def render_in_jupyter():
model, measure, dependencies=[current_step.value]
)

def render_in_browser():
def render_in_browser(statistics=False):
# if space drawer is disabled, do not include it
layout_types = [{"Space": "default"}] if space_drawer else []

Expand All @@ -144,6 +149,13 @@ def render_in_browser():
ModelController(model, play_interval, current_step, reset_counter)
with solara.Card("Progress", margin=1, elevation=2):
solara.Markdown(md_text=f"####Step - {current_step}")
with solara.Card("Analytics", margin=1, elevation=2):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to the PR.

if statistics:
df = model.datacollector.get_model_vars_dataframe()
for col in list(df.columns):
solara.Markdown(
md_text=f"####Avg. {col} - {df.loc[:, f'{col}'].mean()}"
)

items = [
Card(
Expand Down Expand Up @@ -345,6 +357,12 @@ def change_handler(value, name=name):
raise ValueError(f"{input_type} is not a supported input type")


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: Optional[list[any]] = None):
grid = create_grid(color="wealth")
grid(model)


def make_text(renderer):
def function(model):
solara.Markdown(renderer(model))
Expand All @@ -356,7 +374,7 @@ def get_initial_grid_layout(layout_types):
grid_lay = []
y_coord = 0
for ii in range(len(layout_types)):
template_layout = {"h": 10, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0}
template_layout = {"h": 20, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0}
if ii == 0:
grid_lay.append(template_layout)
else:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"pandas",
"solara",
"tqdm",
"altair"
]
dynamic = ["version"]

Expand Down
3 changes: 3 additions & 0 deletions tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,18 @@ def test_call_space_drawer(self, mock_space_matplotlib):
}
current_step = 0
dependencies = [current_step]

# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
solara.render(
JupyterViz(
model_class=mock_model_class,
model_params={},
agent_portrayal=agent_portrayal,
space_drawer="default",
)
)

# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(
mock_model_class.return_value, agent_portrayal, dependencies=dependencies
Expand Down