diff --git a/mesa/datacollection.py b/mesa/datacollection.py index fcaedc7c8c8..d3e365a68dc 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -255,3 +255,25 @@ def get_table_dataframe(self, table_name): if table_name not in self.tables: raise Exception("No such table.") return pd.DataFrame(self.tables[table_name]) + + +class DataCollectorWithoutNone(DataCollector): + """ + DataCollector where None values of the records are excluded + """ + + def _record_agents(self, model): + """Record agents data in a mapping of functions and agents.""" + rep_funcs = self.agent_reporters.values() + + def get_reports(agent): + _prefix = (agent.model.schedule.steps, agent.unique_id) + reports = (rep(agent) for rep in rep_funcs) + reports_without_none = tuple(r for r in reports if r is not None) + if len(reports_without_none) == 0: + return None + return _prefix + reports_without_none + + agent_records = (get_reports(agent) for agent in model.schedule.agents) + agent_records_without_none = (r for r in agent_records if r is not None) + return agent_records_without_none diff --git a/mesa/model.py b/mesa/model.py index be6072663bf..59a6ccd5197 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -12,7 +12,7 @@ # mypy from typing import Any -from mesa.datacollection import DataCollector +from mesa.datacollection import DataCollector, DataCollectorWithoutNone class Model: @@ -66,7 +66,7 @@ def reset_randomizer(self, seed: int | None = None) -> None: self._seed = seed def initialize_data_collector( - self, model_reporters=None, agent_reporters=None, tables=None + self, model_reporters=None, agent_reporters=None, tables=None, exclude_none_values=False ) -> None: if not hasattr(self, "schedule") or self.schedule is None: raise RuntimeError( @@ -76,10 +76,18 @@ def initialize_data_collector( raise RuntimeError( "You must add agents to the scheduler before initializing the data collector." ) - self.datacollector = DataCollector( - model_reporters=model_reporters, - agent_reporters=agent_reporters, - tables=tables, - ) + if exclude_none_values: + datacollector = DataCollectorWithoutNone( + model_reporters=model_reporters, + agent_reporters=agent_reporters, + tables=tables, + ) + else: + datacollector = DataCollector( + model_reporters=model_reporters, + agent_reporters=agent_reporters, + tables=tables, + ) + self.datacollector = datacollector # Collect data for the first time during initialization. self.datacollector.collect(self) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 7ba72c73df5..a968dd19ebe 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -32,6 +32,14 @@ def write_final_values(self): self.model.datacollector.add_table_row("Final_Values", row) +class DifferentMockAgent(MockAgent): + # We define a different MockAgent to test for attributes that are present + # only in 1 type of agent, but not the other. + def __init__(self, unique_id, model, val=0): + super().__init__(unique_id, model, val=val) + self.val3 = val + 42 + + class MockModel(Model): """ Minimalistic model for testing purposes. @@ -39,13 +47,20 @@ class MockModel(Model): schedule = BaseScheduler(None) - def __init__(self): + def __init__(self, test_exclude_none_values=False): self.schedule = BaseScheduler(self) self.model_val = 100 - for i in range(10): - a = MockAgent(i, self, val=i) - self.schedule.add(a) + self.n = 10 + for i in range(self.n): + self.schedule.add(MockAgent(i, self, val=i)) + if test_exclude_none_values: + self.schedule.add(DifferentMockAgent(self.n + i, self, val=i)) + if test_exclude_none_values: + # Only DifferentMockAgent has val3. + agent_reporters = {"value": lambda a: a.val, "value3": "val3"} + else: + agent_reporters = {"value": lambda a: a.val, "value2": "val2"} self.initialize_data_collector( { "total_agents": lambda m: m.schedule.get_agent_count(), @@ -54,8 +69,9 @@ def __init__(self): "model_calc_comp": [self.test_model_calc_comp, [3, 4]], "model_calc_fail": [self.test_model_calc_comp, [12, 0]], }, - {"value": lambda a: a.val, "value2": "val2"}, + agent_reporters, {"Final_Values": ["agent_id", "final_value"]}, + exclude_none_values=test_exclude_none_values, ) def test_model_calc_comp(self, input1, input2): @@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self): ) +class TestDataCollectorWithoutNone(unittest.TestCase): + def setUp(self): + """ + Create the model and run it a set number of steps. + """ + self.model = MockModel(test_exclude_none_values=True) + for i in range(7): + if i == 4: + self.model.schedule.remove(self.model.schedule._agents[3]) + self.model.step() + + def test_agent_records(self): + """ + Test agent-level variable collection. + """ + data_collector = self.model.datacollector + agent_table = data_collector.get_agent_vars_dataframe() + + assert len(data_collector._agent_records) == 8 + for step, records in data_collector._agent_records.items(): + if step < 5: + assert len(records) == 20 + else: + assert len(records) == 19 + + for values in records: + agent_id = values[1] + if agent_id < self.model.n: + assert len(values) == 3 + else: + # Agents with agent_id >= self.model.n are + # DifferentMockAgent, which additionally contains val3. + assert len(values) == 4 + + assert "value" in list(agent_table.columns) + assert "value2" not in list(agent_table.columns) + assert "value3" in list(agent_table.columns) + + if __name__ == "__main__": unittest.main()