Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggegrated agent metric in DataCollection, graph in ChartModule #1145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from operator import attrgetter
import pandas as pd
import types
import statistics


class DataCollector:
Expand Down Expand Up @@ -99,6 +100,7 @@
self.agent_reporters = {}

self.model_vars = {}
self.agent_vars = {}
self._agent_records = {}
self.tables = {}

Expand Down Expand Up @@ -159,6 +161,7 @@
if all([hasattr(rep, "attribute_name") for rep in rep_funcs]):
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
self.agent_vars = {k: v for v, k in enumerate(prefix + attributes)}
get_reports = attrgetter(*prefix + attributes)
else:

Expand Down Expand Up @@ -256,3 +259,36 @@
if table_name not in self.tables:
raise Exception("No such table.")
return pd.DataFrame(self.tables[table_name])

def get_agent_metric(self, var_name, metric="mean"):
"""Get a single aggegrated value from an agent variable.

Args:
var_name: The name of the variable to aggegrate.
metric: Statistics metric to be used (default: mean)
all functions from built-in statistics module are supported
as well as "min", "max", "sum" and "len"

"""
# Get the reporter from the name
reporter = self.agent_reporters[var_name]

Check warning on line 274 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L274

Added line #L274 was not covered by tests

# Get the index of the reporter
attr_index = self.agent_vars[reporter]

Check warning on line 277 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L277

Added line #L277 was not covered by tests

# Create a dictionary with all attributes from all agents
attr_dict = self._agent_records

Check warning on line 280 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L280

Added line #L280 was not covered by tests

# Get the values from all agents in a list
values_tuples = list(attr_dict.values())[-1]

Check warning on line 283 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L283

Added line #L283 was not covered by tests

# Get the correct value using the attribute index
values = [value_tuple[attr_index] for value_tuple in values_tuples]

# Calculate the metric among all agents (mean by default)
if metric in ["min", "max", "sum", "len"]:
value = eval(f"{metric}(values)")

Check warning on line 290 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L290

Added line #L290 was not covered by tests
else:
stat_function = getattr(statistics, metric)
value = stat_function(values)
return value

Check warning on line 294 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L292-L294

Added lines #L292 - L294 were not covered by tests
18 changes: 14 additions & 4 deletions mesa/visualization/modules/ChartVisualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
data_collector_name="datacollector")

TODO:
Have it be able to handle agent-level variables as well.
Aggregate agent level variables other than mean (requires ChartModule
API change)

More Pythonic customization; in particular, have both series-level and
chart-level options settable in Python, and passed to the front-end
Expand Down Expand Up @@ -78,9 +79,18 @@

for s in self.series:
name = s["Label"]
try:
val = data_collector.model_vars[name][-1] # Latest value
except (IndexError, KeyError):
if name in data_collector.model_vars.keys():
try:
val = data_collector.model_vars[name][-1] # Latest value
except (IndexError, KeyError):
val = 0

Check warning on line 86 in mesa/visualization/modules/ChartVisualization.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/modules/ChartVisualization.py#L85-L86

Added lines #L85 - L86 were not covered by tests
elif name in data_collector.agent_reporters.keys():
try:

Check warning on line 88 in mesa/visualization/modules/ChartVisualization.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/modules/ChartVisualization.py#L88

Added line #L88 was not covered by tests
# Returns mean of latest value of all agents
val = data_collector.get_agent_metric(name)
except (IndexError, KeyError):
val = 0

Check warning on line 92 in mesa/visualization/modules/ChartVisualization.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/modules/ChartVisualization.py#L90-L92

Added lines #L90 - L92 were not covered by tests
else:
val = 0
current_values.append(val)
return current_values