diff --git a/maze/core/agent/torch_policy.py b/maze/core/agent/torch_policy.py
index 2285fd25e..6a62d1eef 100644
--- a/maze/core/agent/torch_policy.py
+++ b/maze/core/agent/torch_policy.py
@@ -134,13 +134,17 @@ def compute_logits_dict(self, observation: Any, actor_id: ActorIDType = None) ->
         obs_t = convert_to_torch(observation, device=self._device, cast=None, in_place=True)
         return self.network_for(actor_id)(obs_t)
 
-    def network_for(self, actor_id: ActorIDType) -> nn.Module:
+    def network_for(self, actor_id: Optional[ActorIDType]) -> nn.Module:
         """Helper function for returning a network for the given policy ID (using either just the
         sub-step ID or the full Actor ID as key, depending on the separated agent networks mode).
 
         :param actor_id: Actor ID to get a network for
         :return: Network corresponding to the given policy ID.
         """
+        if actor_id is None:
+            assert len(self.networks) == 1, "multiple networks are available, please specify the actor ID explicitly"
+            return list(self.networks.values())[0]
+
         network_key = actor_id if actor_id[0] in self.substeps_with_separate_agent_nets else actor_id[0]
         return self.networks[network_key]
 
diff --git a/maze/test/core/agent/test_torch_policy.py b/maze/test/core/agent/test_torch_policy.py
new file mode 100644
index 000000000..139db77ee
--- /dev/null
+++ b/maze/test/core/agent/test_torch_policy.py
@@ -0,0 +1,13 @@
+"""Torch policy mechanics tests."""
+
+from maze.test.shared_test_utils.helper_functions import build_dummy_maze_env, \
+    flatten_concat_probabilistic_policy_for_env
+
+
+def test_actor_id_is_optional_for_single_network_policies():
+    env = build_dummy_maze_env()
+    policy = flatten_concat_probabilistic_policy_for_env(env)
+
+    obs = env.reset()
+    action = policy.compute_action(obs)  # No actor ID provided
+    assert action in env.action_space
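
For reviewers, a short usage sketch (not part of the diff) of what this change enables. It reuses the helpers the new test imports; flatten_concat_probabilistic_policy_for_env builds a policy backed by a single network, so the actor ID can now be omitted:

# Usage sketch, not part of the diff: assumes maze-rl is installed and uses
# the same test helpers imported by the new test above.
from maze.test.shared_test_utils.helper_functions import build_dummy_maze_env, \
    flatten_concat_probabilistic_policy_for_env

env = build_dummy_maze_env()
policy = flatten_concat_probabilistic_policy_for_env(env)  # single-network policy

# The actor ID can now be omitted when only one network is registered:
action = policy.compute_action(env.reset())

# Equivalently, network_for accepts None and resolves to the only network:
network = policy.network_for(None)

# With several networks registered, network_for(None) fails the new assertion
# and asks for an explicit actor ID instead.

Asserting on len(self.networks) == 1 rather than silently picking the first entry keeps multi-network policies unambiguous: callers with several sub-step networks must still pass an explicit actor ID.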