Skip to content

Commit

Permalink
Updated DataCollector to accept an arbitary scheudule and to return a…
Browse files Browse the repository at this point in the history
… default value of None if no such attribute exists.
  • Loading branch information
woolgathering committed Oct 29, 2022
1 parent 659b00a commit 1ce0ada
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class DataCollector:
"""

def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
def __init__(self, model_reporters=None, agent_reporters=None, tables=None, schedule=None):
"""Instantiate a DataCollector with lists of model and agent reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, or a method.
Expand All @@ -76,6 +76,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
model_reporters: Dictionary of reporter names and attributes/funcs
agent_reporters: Dictionary of reporter names and attributes/funcs.
tables: Dictionary of table names to lists of column names.
schedule: A scheduler from the mesa.time module. If not supplied, this defaults to `model.schedule`.
Notes:
If you want to pickle your model you must not use lambda functions.
Expand All @@ -100,6 +101,8 @@ class attributes of model
self._agent_records = {}
self.tables = {}

self.schedule = schedule

if model_reporters is not None:
for name, reporter in model_reporters.items():
self._new_model_reporter(name, reporter)
Expand Down Expand Up @@ -151,28 +154,29 @@ def _new_table(self, table_name, table_columns):
new_table = {column: [] for column in table_columns}
self.tables[table_name] = new_table

def _record_agents(self, model):
def _record_agents(self, model, schedule):
"""Record agents data in a mapping of functions and agents."""
rep_funcs = self.agent_reporters.values()
if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
get_reports = attrgetter(*prefix + attributes)
else:

def get_reports(agent):
_prefix = (agent.model.schedule.steps, agent.unique_id)
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

agent_records = map(get_reports, model.schedule.agents)
agent_records = map(partial(self._get_reports, self, schedule.steps), schedule.agents)
return agent_records

@staticmethod
def _get_reports(collector, steps, agent):
"""Get the agent reports for a given agent and return them in a tuple. """
rep_funcs = collector.agent_reporters.values()
_prefix = (steps, agent.unique_id)
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

def _reporter_decorator(self, reporter):
return reporter()

def collect(self, model):
"""Collect all the data for the given model object."""
if self.schedule is None:
schedule = model.schedule
else:
schedule = self.schedule

if self.model_reporters:

for var, reporter in self.model_reporters.items():
Expand All @@ -189,8 +193,8 @@ def collect(self, model):
self.model_vars[var].append(self._reporter_decorator(reporter))

if self.agent_reporters:
agent_records = self._record_agents(model)
self._agent_records[model.schedule.steps] = list(agent_records)
agent_records = self._record_agents(model, schedule)
self._agent_records[schedule.steps] = list(agent_records)

def add_table_row(self, table_name, row, ignore_missing=False):
"""Add a row dictionary to a specific table.
Expand Down

0 comments on commit 1ce0ada

Please sign in to comment.