From 59f580d019fe965fc849cb507052e7df0bedf340 Mon Sep 17 00:00:00 2001 From: rht Date: Sat, 20 May 2023 20:17:33 -0400 Subject: [PATCH] feat: Implement DataCollectorWithoutNone --- mesa/datacollection.py | 27 ++++++++++++++- mesa/model.py | 10 ++++-- tests/test_datacollector.py | 65 ++++++++++++++++++++++++++++++++++--- 3 files changed, 94 insertions(+), 8 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index d58133f1ad8..f002df5cb88 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -51,7 +51,13 @@ class DataCollector: one and stores the results. """ - def __init__(self, model_reporters=None, agent_reporters=None, tables=None): + def __init__( + self, + model_reporters=None, + agent_reporters=None, + tables=None, + exclude_none_values=False, + ): """Instantiate a DataCollector with lists of model and agent reporters. Both model_reporters and agent_reporters accept a dictionary mapping a variable name to either an attribute name, or a method. @@ -74,6 +80,8 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): model_reporters: Dictionary of reporter names and attributes/funcs agent_reporters: Dictionary of reporter names and attributes/funcs. tables: Dictionary of table names to lists of column names. + exclude_none_values: Boolean of whether to drop records whose values + are None in the final result. Notes: If you want to pickle your model you must not use lambda functions. @@ -97,6 +105,7 @@ class attributes of a model self.model_vars = {} self._agent_records = {} self.tables = {} + self.exclude_none_values = exclude_none_values if model_reporters is not None: for name, reporter in model_reporters.items(): @@ -151,7 +160,23 @@ def _record_agents(self, model): """Record agents data in a mapping of functions and agents.""" rep_funcs = self.agent_reporters.values() + if self.exclude_none_values: + # Drop records whose values are None.
+ + def get_reports(agent): + _prefix = (agent.model.schedule.steps, agent.unique_id) + reports = (rep(agent) for rep in rep_funcs) + reports_without_none = tuple(r for r in reports if r is not None) + if len(reports_without_none) == 0: + return None + return _prefix + reports_without_none + + agent_records = (get_reports(agent) for agent in model.schedule.agents) + agent_records_without_none = (r for r in agent_records if r is not None) + return agent_records_without_none + if all(hasattr(rep, "attribute_name") for rep in rep_funcs): + # This branch is a performance optimization. prefix = ["model.schedule.steps", "unique_id"] attributes = [func.attribute_name for func in rep_funcs] get_reports = attrgetter(*prefix + attributes) diff --git a/mesa/model.py b/mesa/model.py index be6072663bf..ffe31b7cc73 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -66,7 +66,11 @@ def reset_randomizer(self, seed: int | None = None) -> None: self._seed = seed def initialize_data_collector( - self, model_reporters=None, agent_reporters=None, tables=None + self, + model_reporters=None, + agent_reporters=None, + tables=None, + exclude_none_values=False, ) -> None: if not hasattr(self, "schedule") or self.schedule is None: raise RuntimeError( @@ -76,10 +80,12 @@ def initialize_data_collector( raise RuntimeError( "You must add agents to the scheduler before initializing the data collector." ) - self.datacollector = DataCollector( + datacollector = DataCollector( model_reporters=model_reporters, agent_reporters=agent_reporters, tables=tables, + exclude_none_values=exclude_none_values, ) + self.datacollector = datacollector # Collect data for the first time during initialization.
self.datacollector.collect(self) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 7ba72c73df5..c44c708e31d 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -32,6 +32,14 @@ def write_final_values(self): self.model.datacollector.add_table_row("Final_Values", row) +class DifferentMockAgent(MockAgent): + # We define a different MockAgent to test for attributes that are present + # only in 1 type of agent, but not the other. + def __init__(self, unique_id, model, val=0): + super().__init__(unique_id, model, val=val) + self.val3 = val + 42 + + class MockModel(Model): """ Minimalistic model for testing purposes. @@ -39,13 +47,20 @@ class MockModel(Model): schedule = BaseScheduler(None) - def __init__(self): + def __init__(self, test_exclude_none_values=False): self.schedule = BaseScheduler(self) self.model_val = 100 - for i in range(10): - a = MockAgent(i, self, val=i) - self.schedule.add(a) + self.n = 10 + for i in range(self.n): + self.schedule.add(MockAgent(i, self, val=i)) + if test_exclude_none_values: + self.schedule.add(DifferentMockAgent(self.n + i, self, val=i)) + if test_exclude_none_values: + # Only DifferentMockAgent has val3. 
+ agent_reporters = {"value": lambda a: a.val, "value3": "val3"} + else: + agent_reporters = {"value": lambda a: a.val, "value2": "val2"} self.initialize_data_collector( { "total_agents": lambda m: m.schedule.get_agent_count(), @@ -54,8 +69,9 @@ def __init__(self): "model_calc_comp": [self.test_model_calc_comp, [3, 4]], "model_calc_fail": [self.test_model_calc_comp, [12, 0]], }, - {"value": lambda a: a.val, "value2": "val2"}, + agent_reporters, {"Final_Values": ["agent_id", "final_value"]}, + exclude_none_values=test_exclude_none_values, ) def test_model_calc_comp(self, input1, input2): @@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self): ) +class TestDataCollectorExcludeNone(unittest.TestCase): + def setUp(self): + """ + Create the model and run it a set number of steps. + """ + self.model = MockModel(test_exclude_none_values=True) + for i in range(7): + if i == 4: + self.model.schedule.remove(self.model.schedule._agents[3]) + self.model.step() + + def test_agent_records(self): + """ + Test agent-level variable collection. + """ + data_collector = self.model.datacollector + agent_table = data_collector.get_agent_vars_dataframe() + + assert len(data_collector._agent_records) == 8 + for step, records in data_collector._agent_records.items(): + if step < 5: + assert len(records) == 20 + else: + assert len(records) == 19 + + for values in records: + agent_id = values[1] + if agent_id < self.model.n: + assert len(values) == 3 + else: + # Agents with agent_id >= self.model.n are + # DifferentMockAgent, which additionally contains val3. + assert len(values) == 4 + + assert "value" in list(agent_table.columns) + assert "value2" not in list(agent_table.columns) + assert "value3" in list(agent_table.columns) + + if __name__ == "__main__": unittest.main()