Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataCollector: Allow agent reporters to take class methods and functions with parameter lists #1838

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 51 additions & 31 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,49 @@ def __init__(
agent_reporters=None,
tables=None,
):
"""Instantiate a DataCollector with lists of model and agent reporters.
"""
Instantiate a DataCollector with lists of model and agent reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, or a method.
For example, if there was only one model-level reporter for number of
agents, it might look like:
{"agent_count": lambda m: m.schedule.get_agent_count() }
If there was only one agent-level reporter (e.g. the agent's energy),
it might look like this:
{"energy": "energy"}
or like this:
{"energy": lambda a: a.energy}
variable name to either an attribute name, a function, a method of a class/instance,
or a function with parameters placed in a list.

Model reporters can take four types of arguments:
1. Lambda function:
{"agent_count": lambda m: m.schedule.get_agent_count()}
2. Method of a class/instance:
{"agent_count": self.get_agent_count} # self here is a class instance
{"agent_count": Model.get_agent_count} # Model here is a class
3. Class attributes of a model:
{"model_attribute": "model_attribute"}
4. Functions with parameters that have been placed in a list:
{"Model_Function": [function, [param_1, param_2]]}

Agent reporters can similarly take:
1. Attribute name (string) referring to agent's attribute:
{"energy": "energy"}
2. Lambda function:
{"energy": lambda a: a.energy}
3. Method of an agent class/instance:
{"agent_action": self.do_action} # self here is an agent class instance
{"agent_action": Agent.do_action} # Agent here is a class
4. Functions with parameters placed in a list:
{"Agent_Function": [function, [param_1, param_2]]}

The tables arg accepts a dictionary mapping names of tables to lists of
columns. For example, if we want to allow agents to write their age
when they are destroyed (to keep track of lifespans), it might look
like:
{"Lifespan": ["unique_id", "age"]}
{"Lifespan": ["unique_id", "age"]}

Args:
model_reporters: Dictionary of reporter names and attributes/funcs
agent_reporters: Dictionary of reporter names and attributes/funcs.
model_reporters: Dictionary of reporter names and attributes/funcs/methods.
agent_reporters: Dictionary of reporter names and attributes/funcs/methods.
tables: Dictionary of table names to lists of column names.

Notes:
If you want to pickle your model you must not use lambda functions.
If your model includes a large number of agents, you should *only*
use attribute names for the agent reporter, it will be much faster.

Model reporters can take four types of arguments:
lambda like above:
{"agent_count": lambda m: m.schedule.get_agent_count() }
method of a class/instance:
{"agent_count": self.get_agent_count} # self here is a class instance
{"agent_count": Model.get_agent_count} # Model here is a class
class attributes of a model
{"model_attribute": "model_attribute"}
functions with parameters that have placed in a list
{"Model_Function":[function, [param_1, param_2]]}
- If you want to pickle your model you must not use lambda functions.
- If your model includes a large number of agents, it is recommended to
use attribute names for the agent reporter, as it will be faster.
"""
self.model_reporters = {}
self.agent_reporters = {}
Expand Down Expand Up @@ -130,16 +135,31 @@ def _new_agent_reporter(self, name, reporter):

Args:
name: Name of the agent-level variable to collect.
reporter: Attribute string, or function object that returns the
variable when given a model instance.
reporter: Attribute string, function object, method of a class/instance, or
function with parameters placed in a list that returns the
variable when given an agent instance.
"""
# Check if the reporter is an attribute string
if isinstance(reporter, str):
attribute_name = reporter

def reporter(agent):
def attr_reporter(agent):
return getattr(agent, attribute_name, None)

reporter.attribute_name = attribute_name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a bug introduced in bc5a5cd, and you fixed in this PR.

reporter = attr_reporter

# Check if the reporter is a function with arguments placed in a list
elif isinstance(reporter, list):
func, params = reporter[0], reporter[1]

def func_with_params(agent):
return func(agent, *params)

reporter = func_with_params

# For other types (like lambda functions, method of a class/instance),
# it's already suitable to be used as a reporter directly.

self.agent_reporters[name] = reporter

def _new_table(self, table_name, table_columns):
Expand Down
36 changes: 30 additions & 6 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def step(self):
self.val += 1
self.val2 += 1

def double_val(self):
return self.val * 2

def write_final_values(self):
"""
Write the final value to the appropriate table.
Expand All @@ -32,6 +35,10 @@ def write_final_values(self):
self.model.datacollector.add_table_row("Final_Values", row)


def agent_function_with_params(agent, multiplier, offset):
return (agent.val * multiplier) + offset


class DifferentMockAgent(MockAgent):
# We define a different MockAgent to test for attributes that are present
# only in 1 type of agent, but not the other.
Expand All @@ -54,17 +61,21 @@ def __init__(self):
self.n = 10
for i in range(self.n):
self.schedule.add(MockAgent(i, self, val=i))
agent_reporters = {"value": lambda a: a.val, "value2": "val2"}
self.initialize_data_collector(
{
model_reporters={
"total_agents": lambda m: m.schedule.get_agent_count(),
"model_value": "model_val",
"model_calc": self.schedule.get_agent_count,
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
agent_reporters,
{"Final_Values": ["agent_id", "final_value"]},
agent_reporters={
"value": lambda a: a.val,
"value2": "val2",
"double_value": MockAgent.double_val,
"value_with_params": [agent_function_with_params, [2, 3]],
},
tables={"Final_Values": ["agent_id", "final_value"]},
)

def test_model_calc_comp(self, input1, input2):
Expand Down Expand Up @@ -132,6 +143,19 @@ def test_agent_records(self):
data_collector = self.model.datacollector
agent_table = data_collector.get_agent_vars_dataframe()

assert "double_value" in list(agent_table.columns)
assert "value_with_params" in list(agent_table.columns)

# Check the double_value column
for (step, agent_id), value in agent_table["double_value"].items():
expected_value = (step + agent_id) * 2
self.assertEqual(value, expected_value)

# Check the value_with_params column
for (step, agent_id), value in agent_table["value_with_params"].items():
expected_value = ((step + agent_id) * 2) + 3
self.assertEqual(value, expected_value)

assert len(data_collector._agent_records) == 8
for step, records in data_collector._agent_records.items():
if step < 5:
Expand All @@ -140,7 +164,7 @@ def test_agent_records(self):
assert len(records) == 9

for values in records:
assert len(values) == 4
assert len(values) == 6

assert "value" in list(agent_table.columns)
assert "value2" in list(agent_table.columns)
Expand Down Expand Up @@ -175,7 +199,7 @@ def test_exports(self):
agent_vars = data_collector.get_agent_vars_dataframe()
table_df = data_collector.get_table_dataframe("Final_Values")
assert model_vars.shape == (8, 5)
assert agent_vars.shape == (77, 2)
assert agent_vars.shape == (77, 4)
assert table_df.shape == (9, 2)

with self.assertRaises(Exception):
Expand Down