diff --git a/baselines/QLearning/transf_qmix.py b/baselines/QLearning/transf_qmix.py index 3a7d6d13..3d131bb9 100644 --- a/baselines/QLearning/transf_qmix.py +++ b/baselines/QLearning/transf_qmix.py @@ -426,7 +426,7 @@ def _env_sample_step(env_state, unused): # INIT NETWORK # init agent - if env.name=='smax': # smax agent + if 'smax' in env.name.lower(): # smax agent agent_class = TransformerAgentSmax n_entities = wrapped_env._env.num_allies+wrapped_env._env.num_enemies # must be explicit for the n_entities if using policy decoupling init_x = (