Skip to content

Commit

Permalink
Fix projectmesa#1419. DataCollector accepts an arbitrary schedule at …
Browse files Browse the repository at this point in the history
…creation (defaults to model.schedule otherwise) and will return None if an attribute is not found instead of throwing an AttributeError.
  • Loading branch information
jacob-thrackle committed Oct 29, 2022
1 parent 952857c commit f81be13
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class DataCollector:
"""

def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
def __init__(self, model_reporters=None, agent_reporters=None, tables=None, schedule=None):
"""Instantiate a DataCollector with lists of model and agent reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, or a method.
Expand All @@ -76,6 +76,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
model_reporters: Dictionary of reporter names and attributes/funcs
agent_reporters: Dictionary of reporter names and attributes/funcs.
tables: Dictionary of table names to lists of column names.
schedule: A scheduler from the mesa.time module. If not supplied, this defaults to `model.schedule`.
Notes:
If you want to pickle your model you must not use lambda functions.
Expand All @@ -100,6 +101,8 @@ class attributes of model
self._agent_records = {}
self.tables = {}

self.schedule = schedule

if model_reporters is not None:
for name, reporter in model_reporters.items():
self._new_model_reporter(name, reporter)
Expand Down Expand Up @@ -151,28 +154,29 @@ def _new_table(self, table_name, table_columns):
new_table = {column: [] for column in table_columns}
self.tables[table_name] = new_table

def _record_agents(self, model):
def _record_agents(self, model, schedule):
"""Record agents data in a mapping of functions and agents."""
rep_funcs = self.agent_reporters.values()
if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
get_reports = attrgetter(*prefix + attributes)
else:

def get_reports(agent):
_prefix = (agent.model.schedule.steps, agent.unique_id)
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

agent_records = map(get_reports, model.schedule.agents)
agent_records = map(partial(self._get_reports, self, schedule.steps), schedule.agents)
return agent_records

@staticmethod
def _get_reports(collector, steps, agent):
"""Get the agent reports for a given agent and return them in a tuple. """
rep_funcs = collector.agent_reporters.values()
_prefix = (steps, agent.unique_id)
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

def _reporter_decorator(self, reporter):
return reporter()

def collect(self, model):
"""Collect all the data for the given model object."""
if self.schedule is None:
schedule = model.schedule
else:
schedule = self.schedule

if self.model_reporters:

for var, reporter in self.model_reporters.items():
Expand All @@ -189,8 +193,8 @@ def collect(self, model):
self.model_vars[var].append(self._reporter_decorator(reporter))

if self.agent_reporters:
agent_records = self._record_agents(model)
self._agent_records[model.schedule.steps] = list(agent_records)
agent_records = self._record_agents(model, schedule)
self._agent_records[schedule.steps] = list(agent_records)

def add_table_row(self, table_name, row, ignore_missing=False):
"""Add a row dictionary to a specific table.
Expand Down

0 comments on commit f81be13

Please sign in to comment.