Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel committed Nov 13, 2024
1 parent 54d7e28 commit 69bc762
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def __init__(self, model: Model, *args, **kwargs) -> None:

self.model: Model = model
self.unique_id: int = next(self._ids[model])
self._mesa_weakref = weakref.ref(self) ## fixme change code and test impact
self.pos: Position | None = None
self.model.register_agent(self)


def remove(self) -> None:
"""Remove and delete the agent from the model.
Expand Down Expand Up @@ -135,19 +137,22 @@ def __init__(self, agents: Iterable[Agent], random: Random | None = None):
Random()
) # FIXME see issue 1981, how to get the central rng from model
self.random = random
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
# self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
self._agents = {agent._mesa_weakref: None for agent in agents}

def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
return len(self._agents)

def __iter__(self) -> Iterator[Agent]:
"""Provide an iterator over the agents in the AgentSet."""
return self._agents.keys()

# fixme should this not also be used to cleanout the weakrefs that resolve to None?
return (agent for entry in self._agents.keys() if (agent := entry()) is not None)

def __contains__(self, agent: Agent) -> bool:
"""Check if an agent is in the AgentSet. Can be used like `agent in agentset`."""
return agent in self._agents
return agent._mesa_weakref in self._agents

def select(
self,
Expand Down Expand Up @@ -210,7 +215,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet:
Using inplace = True is more performant
"""
weakrefs = list(self._agents.keyrefs())
weakrefs = list(self._agents.keys())
self.random.shuffle(weakrefs)

if inplace:
Expand Down Expand Up @@ -240,7 +245,7 @@ def sort(
if isinstance(key, str):
key = operator.attrgetter(key)

sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending)
sorted_agents = sorted(self, key=key, reverse=not ascending)

return (
AgentSet(sorted_agents, self.random)
Expand All @@ -253,7 +258,7 @@ def _update(self, agents: Iterable[Agent]):
This is a private method primarily used internally by other methods like select, shuffle, and sort.
"""
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
self._agents = {agent: None for agent in agents}
return self

def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
Expand All @@ -273,11 +278,11 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
"""
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
if isinstance(method, str):
for agentref in self._agents.keyrefs():
for agentref in self._agents:
if (agent := agentref()) is not None:
getattr(agent, method)(*args, **kwargs)
else:
for agentref in self._agents.keyrefs():
for agentref in self._agents:
if (agent := agentref()) is not None:
method(agent, *args, **kwargs)

Expand All @@ -288,7 +293,7 @@ def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
It's a fast, optimized version of calling shuffle() followed by do().
"""
weakrefs = list(self._agents.keyrefs())
weakrefs = list(self._agents)
self.random.shuffle(weakrefs)

if isinstance(method, str):
Expand Down Expand Up @@ -321,13 +326,13 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
if isinstance(method, str):
res = [
getattr(agent, method)(*args, **kwargs)
for agentref in self._agents.keyrefs()
for agentref in self._agents
if (agent := agentref()) is not None
]
else:
res = [
method(agent, *args, **kwargs)
for agentref in self._agents.keyrefs()
for agentref in self._agents
if (agent := agentref()) is not None
]

Expand Down Expand Up @@ -390,22 +395,22 @@ def get(

if handle_missing == "error":
if is_single_attr:
return [getattr(agent, attr_names) for agent in self._agents]
return [getattr(agent, attr_names) for agent_ref in self._agents if (agent := agent_ref()) is not None]
else:
return [
[getattr(agent, attr) for attr in attr_names]
for agent in self._agents
for agent_ref in self._agents if (agent := agent_ref()) is not None
]

elif handle_missing == "default":
if is_single_attr:
return [
getattr(agent, attr_names, default_value) for agent in self._agents
getattr(agent, attr_names, default_value) for agent_ref in self._agents if (agent := agent_ref()) is not None
]
else:
return [
[getattr(agent, attr, default_value) for attr in attr_names]
for agent in self._agents
for agent_ref in self._agents if (agent := agent_ref()) is not None
]

else:
Expand Down Expand Up @@ -437,7 +442,7 @@ def __getitem__(self, item: int | slice) -> Agent:
Returns:
Agent | list[Agent]: The selected agent or list of agents based on the index or slice provided.
"""
return list(self._agents.keys())[item]
return list(self._agents.keys())[item] #fixme

def add(self, agent: Agent):
"""Add an agent to the AgentSet.
Expand All @@ -448,7 +453,7 @@ def add(self, agent: Agent):
Note:
This method is an implementation of the abstract method from MutableSet.
"""
self._agents[agent] = None
self._agents[agent._mesa_weakref] = None

def discard(self, agent: Agent):
"""Remove an agent from the AgentSet if it exists.
Expand All @@ -462,7 +467,7 @@ def discard(self, agent: Agent):
This method is an implementation of the abstract method from MutableSet.
"""
with contextlib.suppress(KeyError):
del self._agents[agent]
del self._agents[agent._mesa_weakref]

def remove(self, agent: Agent):
"""Remove an agent from the AgentSet.
Expand All @@ -475,15 +480,15 @@ def remove(self, agent: Agent):
Note:
This method is an implementation of the abstract method from MutableSet.
"""
del self._agents[agent]
del self._agents[agent._mesa_weakref]

def __getstate__(self):
"""Retrieve the state of the AgentSet for serialization.
Returns:
dict: A dictionary representing the state of the AgentSet.
"""
return {"agents": list(self._agents.keys()), "random": self.random}
return {"agents": list(self._agents.keys()), "random": self.random} # fixme, we have to resolve all weakrefs here

def __setstate__(self, state):
"""Set the state of the AgentSet during deserialization.
Expand All @@ -492,7 +497,7 @@ def __setstate__(self, state):
state (dict): A dictionary representing the state to restore.
"""
self.random = state["random"]
self._update(state["agents"])
self._update(state["agents"]) # fixme, we have to get the weakrefs out again here

def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
"""Group agents by the specified attribute or return from the callable.
Expand Down

0 comments on commit 69bc762

Please sign in to comment.