From 81ad2d007015ba0d68dd435a3c1ba895ffbbba4c Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Wed, 18 Oct 2023 16:24:29 +0200 Subject: [PATCH 1/5] DataCollector: Allow agent reporters to take class methods and functions with parameter lists Modify the `DataCollector` class to allow agent reporters to take methods of a class/instance and functions with parameters placed in a list (like model reporters), by extending the `_new_agent_reporter` method. This implementation starts by checking if the reporter is an attribute string. If so, it creates a function to retrieve the attribute from an agent. Next, it checks if the reporter is a list. If it is, this indicates that we have a function with parameters, so it wraps that function to pass those parameters when called. For any other type (like lambdas or methods), we assume they're directly suitable to be used as reporters. Now, with this modification, agent reporters in the `DataCollector` class can take: 1. Attribute strings 2. Function objects (like lambdas) 3. Methods of a class/instance 4. Functions with parameters placed in a list This approach ensures backward compatibility because the existing checks for attribute strings and function objects remain unchanged. The added functionality only extends the capabilities of the class without altering the existing behavior. --- mesa/datacollection.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 7f3dc847111..be3b23905a2 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -125,22 +125,37 @@ def _new_model_reporter(self, name, reporter): self.model_reporters[name] = reporter self.model_vars[name] = [] - def _new_agent_reporter(self, name, reporter): - """Add a new agent-level reporter to collect. +def _new_agent_reporter(self, name, reporter): + """Add a new agent-level reporter to collect. + + Args: + name: Name of the agent-level variable to collect. + reporter: Attribute string, function object, method of a class/instance, or + function with parameters placed in a list that returns the + variable when given an agent instance. + """ + # Check if the reporter is an attribute string + if isinstance(reporter, str): + attribute_name = reporter - Args: - name: Name of the agent-level variable to collect. - reporter: Attribute string, or function object that returns the - variable when given a model instance. - """ - if isinstance(reporter, str): - attribute_name = reporter + def attr_reporter(agent): + return getattr(agent, attribute_name, None) + + reporter = attr_reporter + + # Check if the reporter is a function with arguments placed in a list + elif isinstance(reporter, list): + func, params = reporter[0], reporter[1] + + def func_with_params(agent): + return func(agent, *params) + + reporter = func_with_params - def reporter(agent): - return getattr(agent, attribute_name, None) + # For other types (like lambda functions, method of a class/instance), + # it's already suitable to be used as a reporter directly. - reporter.attribute_name = attribute_name - self.agent_reporters[name] = reporter + self.agent_reporters[name] = reporter def _new_table(self, table_name, table_columns): """Add a new table that objects can write to. From 3a128a5ebf1ba4069c3616a022ea3f3e114b9cb8 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Wed, 18 Oct 2023 16:41:13 +0200 Subject: [PATCH 2/5] DataCollector: Add tests for class instance method and function lists --- tests/test_datacollector.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 8cdee4401e5..d2f2d36a2e6 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -24,6 +24,9 @@ def step(self): self.val += 1 self.val2 += 1 + def double_val(self): + return self.val * 2 + def write_final_values(self): """ Write the final value to the appropriate table. @@ -31,6 +34,8 @@ def write_final_values(self): row = {"agent_id": self.unique_id, "final_value": self.val} self.model.datacollector.add_table_row("Final_Values", row) +def agent_function_with_params(agent, multiplier, offset): + return (agent.val * multiplier) + offset class DifferentMockAgent(MockAgent): # We define a different MockAgent to test for attributes that are present @@ -56,17 +61,20 @@ def __init__(self): self.schedule.add(MockAgent(i, self, val=i)) agent_reporters = {"value": lambda a: a.val, "value2": "val2"} self.initialize_data_collector( - { + model_reporters={ "total_agents": lambda m: m.schedule.get_agent_count(), "model_value": "model_val", "model_calc": self.schedule.get_agent_count, "model_calc_comp": [self.test_model_calc_comp, [3, 4]], "model_calc_fail": [self.test_model_calc_comp, [12, 0]], }, - agent_reporters, - {"Final_Values": ["agent_id", "final_value"]}, + agent_reporters={ + "value": lambda a: a.val, + "value2": "val2", + "double_value": MockAgent.double_val, + "value_with_params": [agent_function_with_params, [2, 3]] + } ) - def test_model_calc_comp(self, input1, input2): if input2 > 0: return (self.model_val * input1) / input2 @@ -132,6 +140,19 @@ def test_agent_records(self): data_collector = self.model.datacollector agent_table = data_collector.get_agent_vars_dataframe() + assert "double_value" in list(agent_table.columns) + assert "value_with_params" in list(agent_table.columns) + + # Check the double_value column + for step, agent_id, value in agent_table["double_value"].items(): + expected_value = agent_id * 2 + self.assertEqual(value, expected_value) + + # Check the value_with_params column + for step, agent_id, value in agent_table["value_with_params"].items(): + expected_value = (agent_id * 2) + 3 + self.assertEqual(value, expected_value) + assert len(data_collector._agent_records) == 8 for step, records in data_collector._agent_records.items(): if step < 5: From 92ef8bc214eacf043f182cfcaa1e9f19a58ab0a9 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Oct 2023 14:22:39 +0200 Subject: [PATCH 3/5] datacollection: Fix _new_agent_reporter indentation _new_agent_reporter wasn't intended into the DataCollector class, so thus not seen as a method. --- mesa/datacollection.py | 48 +++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index be3b23905a2..4aa4533a698 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -125,37 +125,37 @@ def _new_model_reporter(self, name, reporter): self.model_reporters[name] = reporter self.model_vars[name] = [] -def _new_agent_reporter(self, name, reporter): - """Add a new agent-level reporter to collect. - - Args: - name: Name of the agent-level variable to collect. - reporter: Attribute string, function object, method of a class/instance, or - function with parameters placed in a list that returns the - variable when given an agent instance. - """ - # Check if the reporter is an attribute string - if isinstance(reporter, str): - attribute_name = reporter + def _new_agent_reporter(self, name, reporter): + """Add a new agent-level reporter to collect. + + Args: + name: Name of the agent-level variable to collect. + reporter: Attribute string, function object, method of a class/instance, or + function with parameters placed in a list that returns the + variable when given an agent instance. + """ + # Check if the reporter is an attribute string + if isinstance(reporter, str): + attribute_name = reporter - def attr_reporter(agent): - return getattr(agent, attribute_name, None) + def attr_reporter(agent): + return getattr(agent, attribute_name, None) - reporter = attr_reporter + reporter = attr_reporter - # Check if the reporter is a function with arguments placed in a list - elif isinstance(reporter, list): - func, params = reporter[0], reporter[1] + # Check if the reporter is a function with arguments placed in a list + elif isinstance(reporter, list): + func, params = reporter[0], reporter[1] - def func_with_params(agent): - return func(agent, *params) + def func_with_params(agent): + return func(agent, *params) - reporter = func_with_params + reporter = func_with_params - # For other types (like lambda functions, method of a class/instance), - # it's already suitable to be used as a reporter directly. + # For other types (like lambda functions, method of a class/instance), + # it's already suitable to be used as a reporter directly. - self.agent_reporters[name] = reporter + self.agent_reporters[name] = reporter def _new_table(self, table_name, table_columns): """Add a new table that objects can write to. From 05d176782500f29a3e6159ae199a00c584f4ab5b Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Oct 2023 14:26:02 +0200 Subject: [PATCH 4/5] Fix DataCollector tests for new agent reporter types - Move agent_reporters specification into initialize_data_collector() - Add back the tables argument (accidentally deleted in previous commit) - Use parentheses to parse step and agent_id from agent records dataframe, since those are the multi-index key - Update expected values for new agent reporter types - Update length values of new agent table and vars (both increase by 2 due to 2 new agent reporter columns) --- tests/test_datacollector.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index d2f2d36a2e6..a45d8891971 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -34,9 +34,11 @@ def write_final_values(self): row = {"agent_id": self.unique_id, "final_value": self.val} self.model.datacollector.add_table_row("Final_Values", row) + def agent_function_with_params(agent, multiplier, offset): return (agent.val * multiplier) + offset + class DifferentMockAgent(MockAgent): # We define a different MockAgent to test for attributes that are present # only in 1 type of agent, but not the other. @@ -59,7 +61,6 @@ def __init__(self): self.n = 10 for i in range(self.n): self.schedule.add(MockAgent(i, self, val=i)) - agent_reporters = {"value": lambda a: a.val, "value2": "val2"} self.initialize_data_collector( model_reporters={ "total_agents": lambda m: m.schedule.get_agent_count(), @@ -72,9 +73,11 @@ def __init__(self): "value": lambda a: a.val, "value2": "val2", "double_value": MockAgent.double_val, - "value_with_params": [agent_function_with_params, [2, 3]] - } + "value_with_params": [agent_function_with_params, [2, 3]], + }, + tables={"Final_Values": ["agent_id", "final_value"]}, ) + def test_model_calc_comp(self, input1, input2): if input2 > 0: return (self.model_val * input1) / input2 @@ -144,13 +147,13 @@ def test_agent_records(self): assert "value_with_params" in list(agent_table.columns) # Check the double_value column - for step, agent_id, value in agent_table["double_value"].items(): - expected_value = agent_id * 2 + for (step, agent_id), value in agent_table["double_value"].items(): + expected_value = (step + agent_id) * 2 self.assertEqual(value, expected_value) # Check the value_with_params column - for step, agent_id, value in agent_table["value_with_params"].items(): - expected_value = (agent_id * 2) + 3 + for (step, agent_id), value in agent_table["value_with_params"].items(): + expected_value = ((step + agent_id) * 2) + 3 self.assertEqual(value, expected_value) assert len(data_collector._agent_records) == 8 @@ -161,7 +164,7 @@ def test_agent_records(self): assert len(records) == 9 for values in records: - assert len(values) == 4 + assert len(values) == 6 assert "value" in list(agent_table.columns) assert "value2" in list(agent_table.columns) @@ -196,7 +199,7 @@ def test_exports(self): agent_vars = data_collector.get_agent_vars_dataframe() table_df = data_collector.get_table_dataframe("Final_Values") assert model_vars.shape == (8, 5) - assert agent_vars.shape == (77, 2) + assert agent_vars.shape == (77, 4) assert table_df.shape == (9, 2) with self.assertRaises(Exception): From 0fe24d78a0b937678058258f93d0950cba53ae4e Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Thu, 19 Oct 2023 14:40:18 +0200 Subject: [PATCH 5/5] DataCollector: Update docs with new agent reporter syntax --- mesa/datacollection.py | 59 +++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 4aa4533a698..8bddfa23da2 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -56,44 +56,49 @@ def __init__( agent_reporters=None, tables=None, ): - """Instantiate a DataCollector with lists of model and agent reporters. + """ + Instantiate a DataCollector with lists of model and agent reporters. Both model_reporters and agent_reporters accept a dictionary mapping a - variable name to either an attribute name, or a method. - For example, if there was only one model-level reporter for number of - agents, it might look like: - {"agent_count": lambda m: m.schedule.get_agent_count() } - If there was only one agent-level reporter (e.g. the agent's energy), - it might look like this: - {"energy": "energy"} - or like this: - {"energy": lambda a: a.energy} + variable name to either an attribute name, a function, a method of a class/instance, + or a function with parameters placed in a list. + + Model reporters can take four types of arguments: + 1. Lambda function: + {"agent_count": lambda m: m.schedule.get_agent_count()} + 2. Method of a class/instance: + {"agent_count": self.get_agent_count} # self here is a class instance + {"agent_count": Model.get_agent_count} # Model here is a class + 3. Class attributes of a model: + {"model_attribute": "model_attribute"} + 4. Functions with parameters that have been placed in a list: + {"Model_Function": [function, [param_1, param_2]]} + + Agent reporters can similarly take: + 1. Attribute name (string) referring to agent's attribute: + {"energy": "energy"} + 2. Lambda function: + {"energy": lambda a: a.energy} + 3. Method of an agent class/instance: + {"agent_action": self.do_action} # self here is an agent class instance + {"agent_action": Agent.do_action} # Agent here is a class + 4. Functions with parameters placed in a list: + {"Agent_Function": [function, [param_1, param_2]]} The tables arg accepts a dictionary mapping names of tables to lists of columns. For example, if we want to allow agents to write their age when they are destroyed (to keep track of lifespans), it might look like: - {"Lifespan": ["unique_id", "age"]} + {"Lifespan": ["unique_id", "age"]} Args: - model_reporters: Dictionary of reporter names and attributes/funcs - agent_reporters: Dictionary of reporter names and attributes/funcs. + model_reporters: Dictionary of reporter names and attributes/funcs/methods. + agent_reporters: Dictionary of reporter names and attributes/funcs/methods. tables: Dictionary of table names to lists of column names. Notes: - If you want to pickle your model you must not use lambda functions. - If your model includes a large number of agents, you should *only* - use attribute names for the agent reporter, it will be much faster. - - Model reporters can take four types of arguments: - lambda like above: - {"agent_count": lambda m: m.schedule.get_agent_count() } - method of a class/instance: - {"agent_count": self.get_agent_count} # self here is a class instance - {"agent_count": Model.get_agent_count} # Model here is a class - class attributes of a model - {"model_attribute": "model_attribute"} - functions with parameters that have placed in a list - {"Model_Function":[function, [param_1, param_2]]} + - If you want to pickle your model you must not use lambda functions. + - If your model includes a large number of agents, it is recommended to + use attribute names for the agent reporter, as it will be faster. """ self.model_reporters = {} self.agent_reporters = {}