Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance AgentSet.do to accept a callable #2210

Closed
wants to merge 13 commits into from
33 changes: 22 additions & 11 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,26 +227,37 @@ def _update(self, agents: Iterable[Agent]):
return self

def do(
self, method_name: str, *args, return_results: bool = False, **kwargs
self, method: str | Callable, *args, return_results: bool = False, **kwargs
) -> AgentSet | list[Any]:
"""
Invoke a method on each agent in the AgentSet.
Invoke a method or function on each agent in the AgentSet.

Args:
method_name (str): The name of the method to call on each agent.
method (str, callable): the callable to do on each agents

* in case of str, the name of the method to call on each agent.
* in case of callable, the function to be called with each agent as first argument

return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
*args: Variable length argument list passed to the method being called.
**kwargs: Arbitrary keyword arguments passed to the method being called.
*args: Variable length argument list passed to the callable being called.
**kwargs: Arbitrary keyword arguments passed to the callable being called.

Returns:
AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself.
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
"""
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
res = [
getattr(agent, method_name)(*args, **kwargs)
for agentref in self._agents.keyrefs()
if (agent := agentref()) is not None
]
if isinstance(method, str):
res = [
getattr(agent, method)(*args, **kwargs)
for agentref in self._agents.keyrefs()
if (agent := agentref()) is not None
]
else:
res = [
method(agent, *args, **kwargs)
for agentref in self._agents.keyrefs()
if (agent := agentref()) is not None
]

return res if return_results else self

Expand Down