From 69bc7620735533fce84fb45dda6c672c869ecde7 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 13 Nov 2024 13:12:47 +0100 Subject: [PATCH] initial commit --- mesa/agent.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 044b40acecf..d6803b4be3d 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -65,9 +65,11 @@ def __init__(self, model: Model, *args, **kwargs) -> None: self.model: Model = model self.unique_id: int = next(self._ids[model]) + self._mesa_weakref = weakref.ref(self) ## fixme change code and test impact self.pos: Position | None = None self.model.register_agent(self) + def remove(self) -> None: """Remove and delete the agent from the model. @@ -135,7 +137,8 @@ def __init__(self, agents: Iterable[Agent], random: Random | None = None): Random() ) # FIXME see issue 1981, how to get the central rng from model self.random = random - self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) + # self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) + self._agents = {agent._mesa_weakref: None for agent in agents} def __len__(self) -> int: """Return the number of agents in the AgentSet.""" @@ -143,11 +146,13 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Agent]: """Provide an iterator over the agents in the AgentSet.""" - return self._agents.keys() + + # fixme should this not also be used to cleanout the weakrefs that resolve to None? + return (agent for entry in self._agents.keys() if (agent := entry()) is not None) def __contains__(self, agent: Agent) -> bool: """Check if an agent is in the AgentSet. Can be used like `agent in agentset`.""" - return agent in self._agents + return agent._mesa_weakref in self._agents def select( self, @@ -210,7 +215,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet: Using inplace = True is more performant """ - weakrefs = list(self._agents.keyrefs()) + weakrefs = list(self._agents.keys()) self.random.shuffle(weakrefs) if inplace: @@ -240,7 +245,7 @@ def sort( if isinstance(key, str): key = operator.attrgetter(key) - sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending) + sorted_agents = sorted(self, key=key, reverse=not ascending) return ( AgentSet(sorted_agents, self.random) @@ -253,7 +258,7 @@ def _update(self, agents: Iterable[Agent]): This is a private method primarily used internally by other methods like select, shuffle, and sort. """ - self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) + self._agents = {agent: None for agent in agents} return self def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: @@ -273,11 +278,11 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: """ # we iterate over the actual weakref keys and check if weakref is alive before calling the method if isinstance(method, str): - for agentref in self._agents.keyrefs(): + for agentref in self._agents: if (agent := agentref()) is not None: getattr(agent, method)(*args, **kwargs) else: - for agentref in self._agents.keyrefs(): + for agentref in self._agents: if (agent := agentref()) is not None: method(agent, *args, **kwargs) @@ -288,7 +293,7 @@ def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: It's a fast, optimized version of calling shuffle() followed by do(). """ - weakrefs = list(self._agents.keyrefs()) + weakrefs = list(self._agents) self.random.shuffle(weakrefs) if isinstance(method, str): @@ -321,13 +326,13 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: if isinstance(method, str): res = [ getattr(agent, method)(*args, **kwargs) - for agentref in self._agents.keyrefs() + for agentref in self._agents if (agent := agentref()) is not None ] else: res = [ method(agent, *args, **kwargs) - for agentref in self._agents.keyrefs() + for agentref in self._agents if (agent := agentref()) is not None ] @@ -390,22 +395,22 @@ def get( if handle_missing == "error": if is_single_attr: - return [getattr(agent, attr_names) for agent in self._agents] + return [getattr(agent, attr_names) for agent_ref in self._agents if (agent := agent_ref()) is not None] else: return [ [getattr(agent, attr) for attr in attr_names] - for agent in self._agents + for agent_ref in self._agents if (agent := agent_ref()) is not None ] elif handle_missing == "default": if is_single_attr: return [ - getattr(agent, attr_names, default_value) for agent in self._agents + getattr(agent, attr_names, default_value) for agent_ref in self._agents if (agent := agent_ref()) is not None ] else: return [ [getattr(agent, attr, default_value) for attr in attr_names] - for agent in self._agents + for agent_ref in self._agents if (agent := agent_ref()) is not None ] else: @@ -437,7 +442,7 @@ def __getitem__(self, item: int | slice) -> Agent: Returns: Agent | list[Agent]: The selected agent or list of agents based on the index or slice provided. """ - return list(self._agents.keys())[item] + return list(self._agents.keys())[item] #fixme def add(self, agent: Agent): """Add an agent to the AgentSet. @@ -448,7 +453,7 @@ def add(self, agent: Agent): Note: This method is an implementation of the abstract method from MutableSet. """ - self._agents[agent] = None + self._agents[agent._mesa_weakref] = None def discard(self, agent: Agent): """Remove an agent from the AgentSet if it exists. @@ -462,7 +467,7 @@ def discard(self, agent: Agent): This method is an implementation of the abstract method from MutableSet. """ with contextlib.suppress(KeyError): - del self._agents[agent] + del self._agents[agent._mesa_weakref] def remove(self, agent: Agent): """Remove an agent from the AgentSet. @@ -475,7 +480,7 @@ def remove(self, agent: Agent): Note: This method is an implementation of the abstract method from MutableSet. """ - del self._agents[agent] + del self._agents[agent._mesa_weakref] def __getstate__(self): """Retrieve the state of the AgentSet for serialization. @@ -483,7 +488,7 @@ def __getstate__(self): Returns: dict: A dictionary representing the state of the AgentSet. """ - return {"agents": list(self._agents.keys()), "random": self.random} + return {"agents": list(self._agents.keys()), "random": self.random} # fixme, we have to resolve all weakrefs here def __setstate__(self, state): """Set the state of the AgentSet during deserialization. @@ -492,7 +497,7 @@ def __setstate__(self, state): state (dict): A dictionary representing the state to restore. """ self.random = state["random"] - self._update(state["agents"]) + self._update(state["agents"]) # fixme, we have to get the weakrefs out again here def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: """Group agents by the specified attribute or return from the callable.