diff --git a/mesa/space.py b/mesa/space.py index 16da1fbd7b5..362099671a2 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -507,6 +507,26 @@ def position_agent( coords = (x, y) self.place_agent(agent, coords) + def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None: + """Swap agents positions""" + agents_no_pos = [] + if (pos_a := agent_a.pos) is None: + agents_no_pos.append(agent_a) + if (pos_b := agent_b.pos) is None: + agents_no_pos.append(agent_b) + if agents_no_pos: + agents_no_pos = [f"" for a in agents_no_pos] + raise Exception(f"{', '.join(agents_no_pos)} - not on the grid") + + if pos_a == pos_b: + return + + self.remove_agent(agent_a) + self.remove_agent(agent_b) + + self.place_agent(agent_a, pos_b) + self.place_agent(agent_b, pos_a) + def place_agent(self, agent: Agent, pos: Coordinate) -> None: if self.is_cell_empty(pos): super().place_agent(agent, pos) diff --git a/mesa/time.py b/mesa/time.py index 3b42a56aa9f..34fc8705fad 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -28,7 +28,7 @@ from collections import defaultdict # mypy -from typing import Iterator, Union +from typing import Iterator, Union, Iterable from mesa.agent import Agent from mesa.model import Model @@ -100,8 +100,9 @@ def agent_buffer(self, shuffled: bool = False) -> Iterator[Agent]: remove and/or add agents during stepping. """ - agent_keys = list(self._agents.keys()) + agent_keys = self._agents.keys() if shuffled: + agent_keys = list(agent_keys) self.model.random.shuffle(agent_keys) for key in agent_keys: @@ -192,16 +193,18 @@ def __init__( def step(self) -> None: """Executes all the stages for all agents.""" - agent_keys = list(self._agents.keys()) + agent_keys = self._agents.keys() if self.shuffle: + agent_keys = list(agent_keys) self.model.random.shuffle(agent_keys) for stage in self.stage_list: for agent_key in agent_keys: getattr(self._agents[agent_key], stage)() # Run stage # We recompute the keys because some agents might have been removed # in the previous loop. - agent_keys = list(self._agents.keys()) + agent_keys = self._agents.keys() if self.shuffle_between_stages: + agent_keys = list(agent_keys) self.model.random.shuffle(agent_keys) self.time += self.stage_time @@ -264,8 +267,9 @@ def step(self, shuffle_types: bool = True, shuffle_agents: bool = True) -> None: shuffle_agents: If True, the order of execution of each agents in a type group is shuffled. """ - type_keys: list[type[Agent]] = list(self.agents_by_type.keys()) + type_keys: Iterable[type[Agent]] = self.agents_by_type.keys() if shuffle_types: + type_keys = list(type_keys) self.model.random.shuffle(type_keys) for agent_class in type_keys: self.step_type(agent_class, shuffle_agents=shuffle_agents) @@ -280,8 +284,9 @@ def step_type(self, type_class: type[Agent], shuffle_agents: bool = True) -> Non Args: type_class: Class object of the type to run. """ - agent_keys: list[int] = list(self.agents_by_type[type_class].keys()) + agent_keys: Iterable[int] = self.agents_by_type[type_class].keys() if shuffle_agents: + agent_keys = list(agent_keys) self.model.random.shuffle(agent_keys) for agent_key in agent_keys: self.agents_by_type[type_class][agent_key].step() diff --git a/tests/test_grid.py b/tests/test_grid.py index 3b43872ee43..3104ba2e68e 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -270,6 +270,34 @@ def test_enforcement(self, mock_model): with self.assertRaises(Exception): self.move_to_empty(self.agents[0], num_agents=self.num_agents) + # Swap agents positions + agent_a, agent_b = random.sample(list(self.grid), k=2) + pos_a = agent_a.pos + pos_b = agent_b.pos + + self.grid.swap_pos(agent_a, agent_b) + + assert agent_a.pos == pos_b + assert agent_b.pos == pos_a + assert self.grid[pos_a] == agent_b + assert self.grid[pos_b] == agent_a + + # Swap the same agents + self.grid.swap_pos(agent_a, agent_a) + + assert agent_a.pos == pos_b + assert self.grid[pos_b] == agent_a + + # Raise for agents not on the grid + self.grid.remove_agent(agent_a) + self.grid.remove_agent(agent_b) + + id_a = agent_a.unique_id + id_b = agent_b.unique_id + e_message = f", - not on the grid" + with self.assertRaisesRegex(Exception, e_message): + self.grid.swap_pos(agent_a, agent_b) + # Number of agents at each position for testing # Initial agent positions for testing