Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement exclude_none_values in DataCollector #1702

Merged
merged 1 commit into from
May 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ class DataCollector:
one and stores the results.
"""

def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
def __init__(
self,
model_reporters=None,
agent_reporters=None,
tables=None,
exclude_none_values=False,
):
"""Instantiate a DataCollector with lists of model and agent reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, or a method.
Expand All @@ -74,6 +80,8 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
model_reporters: Dictionary of reporter names and attributes/funcs
agent_reporters: Dictionary of reporter names and attributes/funcs.
tables: Dictionary of table names to lists of column names.
exclude_none_values: Boolean of whether to drop records which values
are None, in the final result.

Notes:
If you want to pickle your model you must not use lambda functions.
Expand All @@ -97,6 +105,7 @@ class attributes of a model
self.model_vars = {}
self._agent_records = {}
self.tables = {}
self.exclude_none_values = exclude_none_values

if model_reporters is not None:
for name, reporter in model_reporters.items():
Expand Down Expand Up @@ -151,7 +160,23 @@ def _new_table(self, table_name, table_columns):
def _record_agents(self, model):
"""Record agents data in a mapping of functions and agents."""
rep_funcs = self.agent_reporters.values()
if self.exclude_none_values:
# Drop records which values are None.

def get_reports(agent):
_prefix = (agent.model.schedule.steps, agent.unique_id)
reports = (rep(agent) for rep in rep_funcs)
reports_without_none = tuple(r for r in reports if r is not None)
if len(reports_without_none) == 0:
return None
return _prefix + reports_without_none

agent_records = (get_reports(agent) for agent in model.schedule.agents)
agent_records_without_none = (r for r in agent_records if r is not None)
return agent_records_without_none

if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
# This branch is for performance optimization purpose.
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
get_reports = attrgetter(*prefix + attributes)
Expand Down
7 changes: 6 additions & 1 deletion mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def reset_randomizer(self, seed: int | None = None) -> None:
self._seed = seed

def initialize_data_collector(
self, model_reporters=None, agent_reporters=None, tables=None
self,
model_reporters=None,
agent_reporters=None,
tables=None,
exclude_none_values=False,
) -> None:
if not hasattr(self, "schedule") or self.schedule is None:
raise RuntimeError(
Expand All @@ -80,6 +84,7 @@ def initialize_data_collector(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
tables=tables,
exclude_none_values=exclude_none_values,
)
# Collect data for the first time during initialization.
self.datacollector.collect(self)
65 changes: 60 additions & 5 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,35 @@ def write_final_values(self):
self.model.datacollector.add_table_row("Final_Values", row)


class DifferentMockAgent(MockAgent):
# We define a different MockAgent to test for attributes that are present
# only in 1 type of agent, but not the other.
def __init__(self, unique_id, model, val=0):
super().__init__(unique_id, model, val=val)
self.val3 = val + 42


class MockModel(Model):
"""
Minimalistic model for testing purposes.
"""

schedule = BaseScheduler(None)

def __init__(self):
def __init__(self, test_exclude_none_values=False):
self.schedule = BaseScheduler(self)
self.model_val = 100

for i in range(10):
a = MockAgent(i, self, val=i)
self.schedule.add(a)
self.n = 10
for i in range(self.n):
self.schedule.add(MockAgent(i, self, val=i))
if test_exclude_none_values:
self.schedule.add(DifferentMockAgent(self.n + i, self, val=i))
if test_exclude_none_values:
# Only DifferentMockAgent has val3.
agent_reporters = {"value": lambda a: a.val, "value3": "val3"}
else:
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
self.initialize_data_collector(
{
"total_agents": lambda m: m.schedule.get_agent_count(),
Expand All @@ -54,8 +69,9 @@ def __init__(self):
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
{"value": lambda a: a.val, "value2": "val2"},
agent_reporters,
{"Final_Values": ["agent_id", "final_value"]},
exclude_none_values=test_exclude_none_values,
)

def test_model_calc_comp(self, input1, input2):
Expand Down Expand Up @@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self):
)


class TestDataCollectorExcludeNone(unittest.TestCase):
def setUp(self):
"""
Create the model and run it a set number of steps.
"""
self.model = MockModel(test_exclude_none_values=True)
for i in range(7):
if i == 4:
self.model.schedule.remove(self.model.schedule._agents[3])
self.model.step()

def test_agent_records(self):
"""
Test agent-level variable collection.
"""
data_collector = self.model.datacollector
agent_table = data_collector.get_agent_vars_dataframe()

assert len(data_collector._agent_records) == 8
for step, records in data_collector._agent_records.items():
if step < 5:
assert len(records) == 20
else:
assert len(records) == 19

for values in records:
agent_id = values[1]
if agent_id < self.model.n:
assert len(values) == 3
else:
# Agents with agent_id >= self.model.n are
# DifferentMockAgent, which additionally contains val3.
assert len(values) == 4

assert "value" in list(agent_table.columns)
assert "value2" not in list(agent_table.columns)
assert "value3" in list(agent_table.columns)


if __name__ == "__main__":
unittest.main()