From 1ce0ada4df2f530226c1d9d609b45b87453dec71 Mon Sep 17 00:00:00 2001 From: Jacob Sundstrom Date: Sat, 29 Oct 2022 13:55:46 -0700 Subject: [PATCH] Updated DataCollector to accept an arbitary scheudule and to return a default value of None if no such attribute exists. --- mesa/datacollection.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 695742ad33b..afc7c00f250 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -53,7 +53,7 @@ class DataCollector: """ - def __init__(self, model_reporters=None, agent_reporters=None, tables=None): + def __init__(self, model_reporters=None, agent_reporters=None, tables=None, schedule=None): """Instantiate a DataCollector with lists of model and agent reporters. Both model_reporters and agent_reporters accept a dictionary mapping a variable name to either an attribute name, or a method. @@ -76,6 +76,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): model_reporters: Dictionary of reporter names and attributes/funcs agent_reporters: Dictionary of reporter names and attributes/funcs. tables: Dictionary of table names to lists of column names. + schedule: A scheduler from the mesa.time module. If not supplied, this defaults to `model.schedule`. Notes: If you want to pickle your model you must not use lambda functions. @@ -100,6 +101,8 @@ class attributes of model self._agent_records = {} self.tables = {} + self.schedule = schedule + if model_reporters is not None: for name, reporter in model_reporters.items(): self._new_model_reporter(name, reporter) @@ -151,28 +154,29 @@ def _new_table(self, table_name, table_columns): new_table = {column: [] for column in table_columns} self.tables[table_name] = new_table - def _record_agents(self, model): + def _record_agents(self, model, schedule): """Record agents data in a mapping of functions and agents.""" - rep_funcs = self.agent_reporters.values() - if all(hasattr(rep, "attribute_name") for rep in rep_funcs): - prefix = ["model.schedule.steps", "unique_id"] - attributes = [func.attribute_name for func in rep_funcs] - get_reports = attrgetter(*prefix + attributes) - else: - - def get_reports(agent): - _prefix = (agent.model.schedule.steps, agent.unique_id) - reports = tuple(rep(agent) for rep in rep_funcs) - return _prefix + reports - - agent_records = map(get_reports, model.schedule.agents) + agent_records = map(partial(self._get_reports, self, schedule.steps), schedule.agents) return agent_records + @staticmethod + def _get_reports(collector, steps, agent): + """Get the agent reports for a given agent and return them in a tuple. """ + rep_funcs = collector.agent_reporters.values() + _prefix = (steps, agent.unique_id) + reports = tuple(rep(agent) for rep in rep_funcs) + return _prefix + reports + def _reporter_decorator(self, reporter): return reporter() def collect(self, model): """Collect all the data for the given model object.""" + if self.schedule is None: + schedule = model.schedule + else: + schedule = self.schedule + if self.model_reporters: for var, reporter in self.model_reporters.items(): @@ -189,8 +193,8 @@ def collect(self, model): self.model_vars[var].append(self._reporter_decorator(reporter)) if self.agent_reporters: - agent_records = self._record_agents(model) - self._agent_records[model.schedule.steps] = list(agent_records) + agent_records = self._record_agents(model, schedule) + self._agent_records[schedule.steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """Add a row dictionary to a specific table.