Skip to content

Commit

Permalink
Fix: Make actor ID optional for single-network torch policies
Browse files Browse the repository at this point in the history
  • Loading branch information
EnliteAI Bot authored and enliteai committed Apr 29, 2021
1 parent 8b26b32 commit 457169e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion maze/core/agent/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,17 @@ def compute_logits_dict(self, observation: Any, actor_id: ActorIDType = None) ->
obs_t = convert_to_torch(observation, device=self._device, cast=None, in_place=True)
return self.network_for(actor_id)(obs_t)

def network_for(self, actor_id: ActorIDType) -> nn.Module:
def network_for(self, actor_id: Optional[ActorIDType]) -> nn.Module:
"""Helper function for returning a network for the given policy ID (using either just the sub-step ID
or the full Actor ID as key, depending on the separated agent networks mode.
:param actor_id: Actor ID to get a network for
:return: Network corresponding to the given policy ID.
"""
if actor_id is None:
assert len(self.networks) == 1, "multiple networks are available, please specify the actor ID explicitly"
return list(self.networks.values())[0]

network_key = actor_id if actor_id[0] in self.substeps_with_separate_agent_nets else actor_id[0]
return self.networks[network_key]

Expand Down
13 changes: 13 additions & 0 deletions maze/test/core/agent/test_torch_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Torch policy mechanics tests."""

from maze.test.shared_test_utils.helper_functions import build_dummy_maze_env, \
flatten_concat_probabilistic_policy_for_env


def test_actor_id_is_optional_for_single_network_policies():
env = build_dummy_maze_env()
policy = flatten_concat_probabilistic_policy_for_env(env)

obs = env.reset()
action = policy.compute_action(obs) # No actor ID provided
assert action in env.action_space

0 comments on commit 457169e

Please sign in to comment.