Skip to content

Commit

Permalink
feat: Implement DataCollectorWithoutNone
Browse files Browse the repository at this point in the history
  • Loading branch information
rht committed May 21, 2023
1 parent e3af2a5 commit 196b5c9
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 12 deletions.
22 changes: 22 additions & 0 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,25 @@ def get_table_dataframe(self, table_name):
if table_name not in self.tables:
raise Exception("No such table.")
return pd.DataFrame(self.tables[table_name])


class DataCollectorWithoutNone(DataCollector):
"""
DataCollector where None values of the records are excluded
"""

def _record_agents(self, model):
"""Record agents data in a mapping of functions and agents."""
rep_funcs = self.agent_reporters.values()

def get_reports(agent):
_prefix = (agent.model.schedule.steps, agent.unique_id)
reports = (rep(agent) for rep in rep_funcs)
reports_without_none = tuple(r for r in reports if r is not None)
if len(reports_without_none) == 0:
return None
return _prefix + reports_without_none

agent_records = (get_reports(agent) for agent in model.schedule.agents)
agent_records_without_none = (r for r in agent_records if r is not None)
return agent_records_without_none
22 changes: 15 additions & 7 deletions mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# mypy
from typing import Any

from mesa.datacollection import DataCollector
from mesa.datacollection import DataCollector, DataCollectorWithoutNone


class Model:
Expand Down Expand Up @@ -66,7 +66,7 @@ def reset_randomizer(self, seed: int | None = None) -> None:
self._seed = seed

def initialize_data_collector(
self, model_reporters=None, agent_reporters=None, tables=None
self, model_reporters=None, agent_reporters=None, tables=None, exclude_none_values=False
) -> None:
if not hasattr(self, "schedule") or self.schedule is None:
raise RuntimeError(
Expand All @@ -76,10 +76,18 @@ def initialize_data_collector(
raise RuntimeError(
"You must add agents to the scheduler before initializing the data collector."
)
self.datacollector = DataCollector(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
tables=tables,
)
if exclude_none_values:
datacollector = DataCollectorWithoutNone(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
tables=tables,
)
else:
datacollector = DataCollector(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
tables=tables,
)
self.datacollector = datacollector
# Collect data for the first time during initialization.
self.datacollector.collect(self)
65 changes: 60 additions & 5 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,35 @@ def write_final_values(self):
self.model.datacollector.add_table_row("Final_Values", row)


class DifferentMockAgent(MockAgent):
# We define a different MockAgent to test for attributes that are present
# only in 1 type of agent, but not the other.
def __init__(self, unique_id, model, val=0):
super().__init__(unique_id, model, val=val)
self.val3 = val + 42


class MockModel(Model):
"""
Minimalistic model for testing purposes.
"""

schedule = BaseScheduler(None)

def __init__(self):
def __init__(self, test_exclude_none_values=False):
self.schedule = BaseScheduler(self)
self.model_val = 100

for i in range(10):
a = MockAgent(i, self, val=i)
self.schedule.add(a)
self.n = 10
for i in range(self.n):
self.schedule.add(MockAgent(i, self, val=i))
if test_exclude_none_values:
self.schedule.add(DifferentMockAgent(self.n + i, self, val=i))
if test_exclude_none_values:
# Only DifferentMockAgent has val3.
agent_reporters = {"value": lambda a: a.val, "value3": "val3"}
else:
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
self.initialize_data_collector(
{
"total_agents": lambda m: m.schedule.get_agent_count(),
Expand All @@ -54,8 +69,9 @@ def __init__(self):
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
{"value": lambda a: a.val, "value2": "val2"},
agent_reporters,
{"Final_Values": ["agent_id", "final_value"]},
exclude_none_values=test_exclude_none_values,
)

def test_model_calc_comp(self, input1, input2):
Expand Down Expand Up @@ -195,5 +211,44 @@ def test_initialize_before_agents_added_to_scheduler(self):
)


class TestDataCollectorWithoutNone(unittest.TestCase):
def setUp(self):
"""
Create the model and run it a set number of steps.
"""
self.model = MockModel(test_exclude_none_values=True)
for i in range(7):
if i == 4:
self.model.schedule.remove(self.model.schedule._agents[3])
self.model.step()

def test_agent_records(self):
"""
Test agent-level variable collection.
"""
data_collector = self.model.datacollector
agent_table = data_collector.get_agent_vars_dataframe()

assert len(data_collector._agent_records) == 8
for step, records in data_collector._agent_records.items():
if step < 5:
assert len(records) == 20
else:
assert len(records) == 19

for values in records:
agent_id = values[1]
if agent_id < self.model.n:
assert len(values) == 3
else:
# Agents with agent_id >= self.model.n are
# DifferentMockAgent, which additionally contains val3.
assert len(values) == 4

assert "value" in list(agent_table.columns)
assert "value2" not in list(agent_table.columns)
assert "value3" in list(agent_table.columns)


if __name__ == "__main__":
unittest.main()

0 comments on commit 196b5c9

Please sign in to comment.