From f1528d4a6647f20f3c769c45e7cf74c80b21542e Mon Sep 17 00:00:00 2001 From: nighood Date: Tue, 28 Nov 2023 23:57:30 +0800 Subject: [PATCH 1/8] env(rjy): add mamujoco for LightZero --- zoo/multiagent_mujoco/__init__.py | 0 .../envs/multiagent_mujoco_lightzero_env.py | 106 ++++++++++++++++++ .../test_multiagent_mujoco_lightzero_env.py | 0 3 files changed, 106 insertions(+) create mode 100644 zoo/multiagent_mujoco/__init__.py create mode 100644 zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py create mode 100644 zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py diff --git a/zoo/multiagent_mujoco/__init__.py b/zoo/multiagent_mujoco/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py new file mode 100644 index 000000000..7863f2409 --- /dev/null +++ b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py @@ -0,0 +1,106 @@ +import os +from typing import Union + +import gym +import numpy as np +from ding.envs import BaseEnvTimestep +from ding.envs.common import save_frames_as_gif +from ding.torch_utils import to_ndarray +from ding.utils import ENV_REGISTRY +from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv + + +@ENV_REGISTRY.register('mujoco_lightzero') +class MAMujocoEnvLZ(MujocoEnv): + """ + Overview: + The modified MuJoCo environment with continuous action space for LightZero's algorithms. + """ + + config = dict( + stop_value=int(1e6), + action_clip=False, + delay_reward_step=0, + replay_path=None, + save_replay_gif=False, + replay_path_gif=None, + action_bins_per_branch=None, + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + ) + + def __init__(self, cfg: dict) -> None: + super().__init__(cfg) + self._cfg = cfg + # We use env_name to indicate the env_id in LightZero. + self._cfg.env_id = self._cfg.env_name + self._action_clip = cfg.action_clip + self._delay_reward_step = cfg.delay_reward_step + self._init_flag = False + self._replay_path = None + self._replay_path_gif = cfg.replay_path_gif + self._save_replay_gif = cfg.save_replay_gif + self._action_bins_per_branch = cfg.action_bins_per_branch + + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = self._make_env() + if self._replay_path is not None: + self._env = gym.wrappers.RecordVideo( + self._env, + video_folder=self._replay_path, + episode_trigger=lambda episode_id: True, + name_prefix='rl-video-{}'.format(id(self)) + ) + + self._env.observation_space.dtype = np.float32 + self._observation_space = self._env.observation_space + self._action_space = self._env.action_space + self._reward_space = gym.spaces.Box( + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1,), dtype=np.float32 + ) + self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + self._env.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + self._env.seed(self._seed) + obs = self._env.reset() + obs = to_ndarray(obs).astype('float32') + self._eval_episode_return = 0. 
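
# --- Illustrative sketch (not part of the patch): a minimal interaction loop with the
# --- wrapped env, mirroring the unit test added later in this series. The config keys
# --- (scenario, agent_conf, agent_obsk, episode_limit) follow the MujocoMulti convention
# --- used by that test; the import assumes the package __init__ added in the next commit.
import numpy as np
from easydict import EasyDict
from zoo.multiagent_mujoco.envs import MAMujocoEnvLZ

cfg = EasyDict(dict(
    env_name='mujoco_lightzero',
    scenario='Ant-v2',
    agent_conf='2x4d',
    agent_obsk=2,
    add_agent_id=False,
    episode_limit=1000,
))
env = MAMujocoEnvLZ(cfg)
env.seed(314)
obs = env.reset()                      # {'observation': ..., 'action_mask': None, 'to_play': -1}
for _ in range(5):
    random_action = env.random_action()    # per-agent random actions, as in the test
    timestep = env.step(random_action[0])
    if timestep.done:
        obs = env.reset()
env.close()
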
+ + action_mask = None + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} + + return obs + + def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: + if self._action_bins_per_branch: + action = self.map_action(action) + action = to_ndarray(action) + if self._save_replay_gif: + self._frames.append(self._env.render(mode='rgb_array')) + if self._action_clip: + action = np.clip(action, -1, 1) + obs, rew, done, info = self._env.step(action) + self._eval_episode_return += rew + if done: + if self._save_replay_gif: + path = os.path.join( + self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_name, self._save_replay_count) + ) + save_frames_as_gif(self._frames, path) + self._save_replay_count += 1 + info['eval_episode_return'] = self._eval_episode_return + + obs = to_ndarray(obs).astype(np.float32) + rew = to_ndarray([rew]).astype(np.float32) + + action_mask = None + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} + + return BaseEnvTimestep(obs, rew, done, info) + + def __repr__(self) -> str: + return "LightZero MAMujoco Env({})".format(self._cfg.env_name) + diff --git a/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py b/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py new file mode 100644 index 000000000..e69de29bb From de89e54653ae173052403ac13b6b3b7c0eb3a37d Mon Sep 17 00:00:00 2001 From: nighood Date: Wed, 29 Nov 2023 15:37:47 +0800 Subject: [PATCH 2/8] fix(rjy): fix mamujoco and add test --- zoo/multiagent_mujoco/envs/__init__.py | 1 + .../envs/multiagent_mujoco_lightzero_env.py | 78 +++++++++---------- .../test_multiagent_mujoco_lightzero_env.py | 40 ++++++++++ 3 files changed, 76 insertions(+), 43 deletions(-) create mode 100644 zoo/multiagent_mujoco/envs/__init__.py diff --git a/zoo/multiagent_mujoco/envs/__init__.py b/zoo/multiagent_mujoco/envs/__init__.py new file mode 100644 index 000000000..48b10ddf3 --- /dev/null +++ b/zoo/multiagent_mujoco/envs/__init__.py @@ -0,0 +1 @@ +from .multiagent_mujoco_lightzero_env import MAMujocoEnvLZ \ No newline at end of file diff --git a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py index 7863f2409..9448a9be6 100644 --- a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py +++ b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py @@ -7,24 +7,18 @@ from ding.envs.common import save_frames_as_gif from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY -from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv +from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv,MujocoMulti @ENV_REGISTRY.register('mujoco_lightzero') class MAMujocoEnvLZ(MujocoEnv): """ Overview: - The modified MuJoCo environment with continuous action space for LightZero's algorithms. + The modified Multi-agentMuJoCo environment with continuous action space for LightZero's algorithms. """ config = dict( stop_value=int(1e6), - action_clip=False, - delay_reward_step=0, - replay_path=None, - save_replay_gif=False, - replay_path_gif=None, - action_bins_per_branch=None, norm_obs=dict(use_norm=False, ), norm_reward=dict(use_norm=False, ), ) @@ -34,40 +28,50 @@ def __init__(self, cfg: dict) -> None: self._cfg = cfg # We use env_name to indicate the env_id in LightZero. 
self._cfg.env_id = self._cfg.env_name - self._action_clip = cfg.action_clip - self._delay_reward_step = cfg.delay_reward_step self._init_flag = False - self._replay_path = None - self._replay_path_gif = cfg.replay_path_gif - self._save_replay_gif = cfg.save_replay_gif - self._action_bins_per_branch = cfg.action_bins_per_branch def reset(self) -> np.ndarray: if not self._init_flag: - self._env = self._make_env() - if self._replay_path is not None: - self._env = gym.wrappers.RecordVideo( - self._env, - video_folder=self._replay_path, - episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) - ) - - self._env.observation_space.dtype = np.float32 - self._observation_space = self._env.observation_space - self._action_space = self._env.action_space - self._reward_space = gym.spaces.Box( - low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1,), dtype=np.float32 - ) + self._env = MujocoMulti(env_args=self._cfg) self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: np_seed = 100 * np.random.randint(1, 1000) self._env.seed(self._seed + np_seed) elif hasattr(self, '_seed'): self._env.seed(self._seed) + obs = self._env.reset() - obs = to_ndarray(obs).astype('float32') + obs = to_ndarray(obs) self._eval_episode_return = 0. + self.env_info = self._env.get_env_info() + + self._num_agents = self.env_info['n_agents'] + self._agents = [i for i in range(self._num_agents)] + self._observation_space = gym.spaces.Dict( + { + 'agent_state': gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=obs['agent_state'].shape, dtype=np.float32 + ), + 'global_state': gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=obs['global_state'].shape, dtype=np.float32 + ), + } + ) + self._action_space = gym.spaces.Dict({agent: self._env.action_space[agent] for agent in self._agents}) + single_agent_obs_space = self._env.action_space[self._agents[0]] + if isinstance(single_agent_obs_space, gym.spaces.Box): + self._action_dim = single_agent_obs_space.shape + elif isinstance(single_agent_obs_space, gym.spaces.Discrete): + self._action_dim = (single_agent_obs_space.n, ) + else: + raise Exception('Only support `Box` or `Discrte` obs space for single agent.') + self._reward_space = gym.spaces.Dict( + { + agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) + for agent in self._agents + } + ) action_mask = None obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} @@ -75,25 +79,13 @@ def reset(self) -> np.ndarray: return obs def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: - if self._action_bins_per_branch: - action = self.map_action(action) action = to_ndarray(action) - if self._save_replay_gif: - self._frames.append(self._env.render(mode='rgb_array')) - if self._action_clip: - action = np.clip(action, -1, 1) obs, rew, done, info = self._env.step(action) self._eval_episode_return += rew if done: - if self._save_replay_gif: - path = os.path.join( - self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_name, self._save_replay_count) - ) - save_frames_as_gif(self._frames, path) - self._save_replay_count += 1 info['eval_episode_return'] = self._eval_episode_return - obs = to_ndarray(obs).astype(np.float32) + obs = to_ndarray(obs) rew = to_ndarray([rew]).astype(np.float32) action_mask = None diff --git a/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py 
b/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py index e69de29bb..13d399cd3 100644 --- a/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py +++ b/zoo/multiagent_mujoco/envs/test_multiagent_mujoco_lightzero_env.py @@ -0,0 +1,40 @@ +from time import time +import pytest +import numpy as np +from easydict import EasyDict +from zoo.multiagent_mujoco.envs import MAMujocoEnvLZ + + +@pytest.mark.envtest +@pytest.mark.parametrize( + 'cfg', [ + EasyDict({ + 'env_name': 'mujoco_lightzero', + 'scenario': 'Ant-v2', + 'agent_conf': "2x4d", + 'agent_obsk': 2, + 'add_agent_id': False, + 'episode_limit': 1000, + },) + ] +) + +class TestMAMujocoEnvLZ: + def test_naive(self, cfg): + env = MAMujocoEnvLZ(cfg) + env.seed(314) + assert env._seed == 314 + obs = env.reset() + assert isinstance(obs, dict) + for i in range(10): + random_action = env.random_action() + timestep = env.step(random_action[0]) + print(timestep) + assert isinstance(timestep.obs, dict) + assert isinstance(timestep.done, bool) + assert timestep.obs['observation']['global_state'].shape == (2, 111) + assert timestep.obs['observation']['agent_state'].shape == (2, 54) + assert timestep.reward.shape == (1, ) + assert isinstance(timestep, tuple) + print(env.observation_space, env.action_space, env.reward_space) + env.close() From 933f6730c6320cf5e8154219e55648b03443e25f Mon Sep 17 00:00:00 2001 From: nighood Date: Wed, 6 Dec 2023 10:44:25 +0800 Subject: [PATCH 3/8] feature(rjy): add independent sez pipeline(\train) --- lzero/mcts/buffer/game_buffer.py | 18 +- lzero/mcts/buffer/game_buffer_muzero.py | 5 +- .../game_buffer_sampled_efficientzero.py | 60 +- lzero/mcts/buffer/game_segment.py | 42 +- ..._efficientzero_model_mlp_ma_independent.py | 540 ++++++++++++++++++ lzero/policy/sampled_efficientzero.py | 149 +++-- lzero/policy/scaling_transform.py | 2 +- lzero/policy/utils.py | 18 +- lzero/worker/muzero_collector.py | 34 +- ...ent_mujoco_sampled_efficientzero_config.py | 130 +++++ zoo/multiagent_mujoco/entry/__init__.py | 1 + .../entry/train_sez_independent_mamujoco.py | 195 +++++++ .../envs/multiagent_mujoco_lightzero_env.py | 2 +- 13 files changed, 1091 insertions(+), 105 deletions(-) create mode 100644 lzero/model/sampled_efficientzero_model_mlp_ma_independent.py create mode 100644 zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py create mode 100644 zoo/multiagent_mujoco/entry/__init__.py create mode 100644 zoo/multiagent_mujoco/entry/train_sez_independent_mamujoco.py diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index e5066bca5..1fbb2613b 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -57,6 +57,9 @@ def __init__(self, cfg: dict): self.batch_size = self._cfg.batch_size self._alpha = self._cfg.priority_prob_alpha self._beta = self._cfg.priority_prob_beta + self._multi_agent = self._cfg.model.get('multi_agent', False) + if self._multi_agent: + self._num_agents = self._cfg.model.get('agent_num', 1) self.game_segment_buffer = [] self.game_pos_priorities = [] @@ -344,9 +347,18 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: ) else: assert len(data) == len(meta['priorities']), " priorities should be of same length as the game steps" - priorities = meta['priorities'].copy().reshape(-1) - priorities[valid_len:len(data)] = 0. 
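
# --- Illustrative sketch (not part of the patch): the shape convention behind the
# --- multi-agent priority branch below. Single-agent priorities are flat (T,); with
# --- N agents they are kept as (T, N), and padded steps are zeroed per agent.
# --- T, N and valid_len are example values only.
import numpy as np
T, N, valid_len = 6, 2, 4
priorities = np.random.rand(T, N)                        # per-step, per-agent priorities
priorities[valid_len:T] = np.zeros_like(priorities[0])   # zero out padded steps for all agents
assert priorities.shape == (T, N) and (priorities[valid_len:] == 0).all()
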
- self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities)) + if self._multi_agent: + priorities = meta['priorities'].copy() + priorities[valid_len:len(data)] = np.zeros_like(priorities[0]) + if len(self.game_pos_priorities) == 0: + self.game_pos_priorities = priorities + else: + self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities)) + + else: + priorities = priorities.reshape(-1) + priorities[valid_len:len(data)] = 0. + self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities)) self.game_segment_buffer.append(data) self.game_segment_game_pos_look_up += [ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index daddf6f9f..a09148e2a 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -670,7 +670,10 @@ def _compute_target_policy_non_reanalyzed( else: # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 policy_mask.append(0) - target_policies.append([0 for _ in range(policy_shape)]) + if self._multi_agent: + target_policies.append([np.zeros_like(child_visit[0][0])] * self._cfg.model.agent_num) + else: + target_policies.append([0 for _ in range(policy_shape)]) policy_index += 1 diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py index ede58eb4e..2e5cda8a8 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py @@ -3,6 +3,8 @@ import numpy as np import torch from ding.utils import BUFFER_REGISTRY +from ding.utils.data import default_collate, default_decollate +from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree @@ -140,7 +142,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # sampled related core code # ============================================================== actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() + self._cfg.num_unroll_steps] + if not isinstance(actions_tmp, list): + actions_tmp = actions_tmp.tolist() # NOTE: self._cfg.num_unroll_steps + 1 root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment + @@ -152,14 +156,25 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # pad random action if self._cfg.model.continuous_action_space: - actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + if self._multi_agent: + actions_tmp += [ + np.random.randn(self._cfg.model.agent_num, self._cfg.model.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] - root_sampled_actions_tmp += [ + root_sampled_actions_tmp += [ + np.random.rand(self._cfg.model.agent_num, self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) + ] + else: + actions_tmp += [ + np.random.randn(self._cfg.model.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + root_sampled_actions_tmp += [ np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) for _ in range(self._cfg.num_unroll_steps 
+ 1 - len(root_sampled_actions_tmp)) ] + else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( @@ -192,7 +207,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: mask_list.append(mask_tmp) # formalize the input observations - obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + if not self._multi_agent: + obs_list = prepare_observation(obs_list, self._cfg.model.model_type) # ============================================================== # sampled related core code # ============================================================== @@ -202,7 +218,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: ] for i in range(len(current_batch)): - current_batch[i] = np.asarray(current_batch[i]) + current_batch[i] = to_ndarray(current_batch[i]) total_transitions = self.get_num_of_transitions() @@ -272,16 +288,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values, batch_value_prefixs = [], [] with torch.no_grad(): - value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + if not self._multi_agent: + value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) network_output = [] for i in range(slices): beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_obs = to_dtype(to_device(to_tensor(value_obs_list[beg_index:end_index]), self._cfg.device), torch.float) # calculate the target value + m_obs = default_collate(m_obs) + if self._multi_agent: + m_obs = m_obs[0] m_output = model.initial_inference(m_obs) # TODO(pu) @@ -355,12 +375,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A ] ) else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) + if self._multi_agent: + value_list = value_list.reshape(transition_batch_size, self._cfg.model.agent_num) + factor = np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + value_list = value_list * factor.reshape(transition_batch_size, 1).astype(np.float32) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() + if self._multi_agent: + value_list = value_list * np.array(value_mask)[:, np.newaxis] + else: + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() horizon_id, value_index = 0, 0 for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, @@ -399,7 +427,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A ] # * config.discount_factor ** (current_index - base_index) target_value_prefixs.append(value_prefix) else: - target_values.append(0) + target_values.append(np.zeros_like(value_list[0])) target_value_prefixs.append(value_prefix) value_index += 1 @@ -407,8 +435,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_value_prefixs.append(target_value_prefixs) 
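
# --- Illustrative sketch (not part of the patch): the broadcasting used for multi-agent
# --- target values above. With transition_batch_size=B and agent_num=N, per-agent values
# --- of shape (B, N) are discounted by a per-transition factor of shape (B, 1) and then
# --- masked per transition. B, N and the discount are example values only.
import numpy as np
B, N, discount = 8, 2, 0.997
td_steps_list = np.random.randint(1, 5, size=B)
value_list = np.random.randn(B, N).astype(np.float32)
factor = discount ** td_steps_list                                  # shape (B,)
value_list = value_list * factor.reshape(B, 1).astype(np.float32)   # broadcast to (B, N)
value_mask = np.ones(B)
value_list = value_list * np.array(value_mask)[:, np.newaxis]       # mask padded transitions
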
batch_target_values.append(target_values) - batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) + batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=np.float32) + batch_target_values = np.asarray(batch_target_values, dtype=np.float32) return batch_value_prefixs, batch_target_values diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ae9260921..27aa19fb7 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -5,6 +5,7 @@ from easydict import EasyDict from ding.utils.compression_helper import jpeg_data_decompressor +from ding.torch_utils import to_ndarray class GameSegment: @@ -96,12 +97,23 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool if padding: pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) if pad_len > 0: - pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) - stacked_obs = np.concatenate((stacked_obs, pad_frames)) + pad_frames = [stacked_obs[-1] for _ in range(pad_len)] + stacked_obs += pad_frames if self.transform2string: stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] return stacked_obs + def _zero_obs(self, input_data): + if isinstance(input_data, dict): + # Process dict + return {k: self._zero_obs(v) for k, v in input_data.items()} + elif isinstance(input_data, (list, np.ndarray)): + # Process arrays or lists + return np.zeros_like(input_data) + else: + # Process other types (e.g. numbers, strings, etc.) + return input_data + def zero_obs(self) -> List: """ Overview: @@ -109,7 +121,7 @@ def zero_obs(self) -> List: Returns: ndarray: An array filled with zeros. """ - return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)] + return [self._zero_obs(self.obs_segment[0]) for _ in range(self.frame_stack_num)] def get_obs(self) -> List: """ @@ -212,9 +224,9 @@ def store_search_stats( Overview: store the visit count distributions and value of the root node after MCTS. """ - sum_visits = sum(visit_counts) + sum_visits = np.sum(visit_counts, axis=-1) if idx is None: - self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) + self.child_visit_segment.append([visit_count / sum_visits[i] for i,visit_count in enumerate(visit_counts)]) self.root_value_segment.append(root_value) if self.sampled_algo: self.root_sampled_actions.append(root_sampled_actions) @@ -272,26 +284,26 @@ def game_segment_to_array(self) -> None: For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`. 
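
# --- Illustrative sketch (not part of the patch): what the recursive _zero_obs above
# --- produces for the dict observations used by the multi-agent env. The (2, 54) and
# --- (2, 111) shapes follow the 2-agent Ant test earlier in this series.
import numpy as np
obs = {
    'agent_state': np.random.randn(2, 54).astype(np.float32),
    'global_state': np.random.randn(2, 111).astype(np.float32),
}
zero = {k: np.zeros_like(v) for k, v in obs.items()}   # same structure, all zeros
assert zero['agent_state'].shape == (2, 54) and not zero['global_state'].any()
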
""" - self.obs_segment = np.array(self.obs_segment) - self.action_segment = np.array(self.action_segment) - self.reward_segment = np.array(self.reward_segment) + self.obs_segment = to_ndarray(self.obs_segment) + self.action_segment = to_ndarray(self.action_segment) + self.reward_segment = to_ndarray(self.reward_segment) # Check if all elements in self.child_visit_segment have the same length if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment): - self.child_visit_segment = np.array(self.child_visit_segment) + self.child_visit_segment = to_ndarray(self.child_visit_segment) else: # In the case of environments with a variable action space, such as board games, # the elements in child_visit_segment may have different lengths. # In such scenarios, it is necessary to use the object data type. - self.child_visit_segment = np.array(self.child_visit_segment, dtype=object) + self.child_visit_segment = to_ndarray(self.child_visit_segment, dtype=object) - self.root_value_segment = np.array(self.root_value_segment) - self.improved_policy_probs = np.array(self.improved_policy_probs) + self.root_value_segment = to_ndarray(self.root_value_segment) + self.improved_policy_probs = to_ndarray(self.improved_policy_probs) - self.action_mask_segment = np.array(self.action_mask_segment) - self.to_play_segment = np.array(self.to_play_segment) + self.action_mask_segment = to_ndarray(self.action_mask_segment) + self.to_play_segment = to_ndarray(self.to_play_segment) if self.use_ture_chance_label_in_chance_encoder: - self.chance_segment = np.array(self.chance_segment) + self.chance_segment = to_ndarray(self.chance_segment) def reset(self, init_observations: np.ndarray) -> None: """ diff --git a/lzero/model/sampled_efficientzero_model_mlp_ma_independent.py b/lzero/model/sampled_efficientzero_model_mlp_ma_independent.py new file mode 100644 index 000000000..ea2c20ee6 --- /dev/null +++ b/lzero/model/sampled_efficientzero_model_mlp_ma_independent.py @@ -0,0 +1,540 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.model.common import ReparameterizationHead +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import EZNetworkOutput, RepresentationNetworkMLP +from .efficientzero_model_mlp import DynamicsNetworkMLP +from .utils import renormalize, get_params_mean +from .sampled_efficientzero_model_mlp import SampledEfficientZeroModelMLP + + +@MODEL_REGISTRY.register('SampledEfficientZeroModelMLPMaIndependent') +class SampledEfficientZeroModelMLPMaIndependent(nn.Module): + + def __init__( + self, + global_observation_shape, + agent_observation_shape, + action_space_size, + latent_state_dim: int = 256, + lstm_hidden_size: int = 512, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + # ============================================================== + # specific sampled related config + # ============================================================== + continuous_action_space: bool = False, + num_of_sampled_actions: int = 6, + sigma_type='conditioned', + 
fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = None, + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of Sampled EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, which is an integer number. For discrete action space, it is the num of discrete actions, \ + e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the continuous action, e.g. 4 for bipedalwalker. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + # ============================================================== + # specific sampled related config + # ============================================================== + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - num_of_sampled_actions (:obj:`int`): the number of sampled actions, i.e. the K in original Sampled MuZero paper. + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. 
+ - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(SampledEfficientZeroModelMLPMaIndependent, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.continuous_action_space = continuous_action_space + self.global_observation_shape = global_observation_shape + self.agent_observation_shape = agent_observation_shape + self.action_space_size = action_space_size + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. + self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.latent_state_dim = latent_state_dim + self.fc_reward_layers = fc_reward_layers + self.fc_value_layers = fc_value_layers + self.fc_policy_layers = fc_policy_layers + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.self_supervised_learning_loss = self_supervised_learning_loss + + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.obs_representation_shape = global_observation_shape + agent_observation_shape + self.representation_network = RepresentationNetworkMLP( + observation_shape=self.obs_representation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=self.lstm_hidden_size, + fc_reward_layers=self.fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + continuous_action_space=self.continuous_action_space, + action_space_size=self.action_space_size, + num_channels=self.latent_state_dim, + fc_value_layers=self.fc_value_layers, + fc_policy_layers=self.fc_policy_layers, + output_support_size=self.value_support_size, + 
last_linear_layer_init_zero=self.last_linear_layer_init_zero, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of SampledEfficientZero model, which is the first step of the SampledEfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the Sampled EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. 
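
# --- Illustrative sketch (not part of the patch): how the batch size is derived in the
# --- dict-observation branch of initial_inference below. With env batch B and N agents,
# --- each agent is treated as an independent sample, so the effective batch is B * N and
# --- the LSTM reward hidden state is zero-initialized at that size. Shapes are examples.
import torch
B, N, lstm_hidden_size = 4, 2, 512
obs = {
    'agent_state': torch.randn(B, N, 54),
    'global_state': torch.randn(B, N, 111),
}
first = next(v for v in obs.values() if not isinstance(v, dict))
batch_size = first.shape[0] * first.shape[1]            # B * N = 8
reward_hidden_state = (
    torch.zeros(1, batch_size, lstm_hidden_size),       # h_0
    torch.zeros(1, batch_size, lstm_hidden_size),       # c_0
)
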
+ """ + # batch_size = obs.size(0) + if isinstance(obs, dict): + # If data is a dictionary, find the first non-dictionary element and get its shape[0] + # TODO(rjy): written in recursive form + for k, v in obs.items(): + if not isinstance(v, dict): + batch_size = v.shape[0] * v.shape[1] # rbatch_size=batch_size*num_agent + obs_device = v.device + break + elif isinstance(obs, torch.Tensor): + # If data is a torch.tensor, directly return its shape[0] + batch_size = obs.shape[0] + obs_device = obs.device + + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs_device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs_device) + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of Sampled EfficientZero model, which is the rollout step of the Sampled EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) + policy_logits, value = self._prediction(next_latent_state) + return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. 
+ Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + obs_for_represent = torch.cat((observation['global_state'], observation['agent_state']), dim=-1) + obs_for_represent = obs_for_represent.reshape(-1, obs_for_represent.shape[-1]) + latent_state = self.representation_network(obs_for_represent) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy, value = self.prediction_network(latent_state) + return policy, value + + def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, + action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``value_prefix`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + if not self.continuous_action_space: + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. 
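
# --- Illustrative sketch (not part of the patch): the flattening performed by
# --- _representation above. Global and per-agent states are concatenated along the last
# --- dim and the agent axis is folded into the batch, so every agent is encoded
# --- independently by the shared MLP. Dimensions follow the 2-agent Ant example.
import torch
B, N = 4, 2
global_state = torch.randn(B, N, 111)
agent_state = torch.randn(B, N, 54)
x = torch.cat((global_state, agent_state), dim=-1)   # (B, N, 165)
x = x.reshape(-1, x.shape[-1])                       # (B * N, 165)
assert x.shape == (B * N, 111 + 54)
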
+ # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + else: + # continuous action space + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 3: + # (batch_size, action_dim, 1) -> (batch_size, action_dim) + # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) + action = action.squeeze(-1) + + action_encoding = action + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if not self.state_norm: + return next_latent_state, next_reward_hidden_state, value_prefix + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. 
+ + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self): + return get_params_mean(self) + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + continuous_action_space, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + # ============================================================== + # specific sampled related config + # ============================================================== + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = None, + ): + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + The networks are mainly built on fully connected layers. + Arguments: + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. For continuous action space, it is the dimension of \ + continuous action. + - num_channels (:obj:`int`): The num of channels in latent states. + - num_res_blocks (:obj:`int`): The number of res blocks. + - fc_value_layers (:obj:`SequenceType`): hidden layers of the value prediction head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): hidden layers of the policy prediction head (MLP head). + - output_support_size (:obj:`int`): dim of value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + # ============================================================== + # specific sampled related config + # ============================================================== + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about thee following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. default set it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. 
+ """ + super().__init__() + self.num_channels = num_channels + self.continuous_action_space = continuous_action_space + self.norm_type = norm_type + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.action_space_size = action_space_size + if self.continuous_action_space: + self.action_encoding_dim = self.action_space_size + else: + self.action_encoding_dim = 1 + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=2, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + # sampled related core code + if self.continuous_action_space: + self.fc_policy_head = ReparameterizationHead( + input_size=self.num_channels, + output_size=action_space_size, + layer_num=2, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + activation=nn.ReLU(), + norm_type=None, + bound_type=self.bound_type + ) + else: + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=2, + activation=activation, + norm_type=self.norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, in_channels). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor. If action space is discrete, shape is (B, action_space_size). + If action space is continuous, shape is (B, action_space_size * 2). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
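
# --- Illustrative sketch (not part of the patch): expected output shapes of the
# --- PredictionNetworkMLP defined in this file for the continuous-action case.
# --- ReparameterizationHead returns {'mu', 'sigma'}, which forward() concatenates, so
# --- the policy has 2 * action_space_size channels; values are example numbers only.
import torch
B, latent_dim, action_dim, support = 8, 256, 4, 601
net = PredictionNetworkMLP(
    continuous_action_space=True,
    action_space_size=action_dim,
    num_channels=latent_dim,
    output_support_size=support,
)
policy, value = net(torch.randn(B, latent_dim))
assert policy.shape == (B, 2 * action_dim) and value.shape == (B, support)
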
+ """ + x_prediction_common = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x_prediction_common) + + # sampled related core code + policy = self.fc_policy_head(x_prediction_common) + if self.continuous_action_space: + policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) + + return policy, value diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 184ab8c42..763bc8ff6 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -5,7 +5,8 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.torch_utils import to_tensor +from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray +from ding.utils.data import default_collate, default_decollate from ding.utils import POLICY_REGISTRY from ditk import logging from torch.distributions import Categorical, Independent, Normal @@ -224,7 +225,10 @@ def default_model(self) -> Tuple[str, List[str]]: if self._cfg.model.model_type == "conv": return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] elif self._cfg.model.model_type == "mlp": - return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] + if self._cfg.model.multi_agent is True: + return 'SampledEfficientZeroModelMLPMaIndependent', ['lzero.model.sampled_efficientzero_model_mlp_ma_independent'] + else: + return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) @@ -296,6 +300,7 @@ def _init_learn(self) -> None: self.inverse_scalar_transform_handle = InverseScalarTransform( self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution ) + self._multi_agent = self._cfg.model.multi_agent def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ @@ -330,24 +335,39 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # shape: (batch_size, num_unroll_steps, action_dim) # NOTE: .float(), in continuous action space. 
- action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1) - data_list = [ - mask_batch, - target_value_prefix.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [mask_batch, target_value_prefix, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) + if self._cfg.model.multi_agent: + action_batch = to_dtype(to_device(to_tensor(action_batch), self._cfg.device), torch.float) + action_batch = default_collate(default_collate(action_batch)) # (num_unroll_steps, batch_size, action_dim, 1) + action_batch = action_batch.transpose(0, 1) # (batch_size, num_unroll_steps, action_dim, 1) + mask_batch = to_dtype(default_collate(mask_batch), torch.float) + data_list = [ + target_value_prefix.astype('float32'), + target_value.astype('float32'), target_policy, weights + ] + [target_value_prefix, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + else: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1) + data_list = [ + mask_batch, + target_value_prefix.astype('float32'), + target_value.astype('float32'), target_policy, weights + ] + [mask_batch, target_value_prefix, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) # ============================================================== # sampled related core code # ============================================================== # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. (4, 6, 5, 1, 1) - child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1) - - target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - - assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + if self._cfg.model.multi_agent: + child_sampled_actions_batch = default_collate(default_collate(child_sampled_actions_batch)) + child_sampled_actions_batch = to_dtype(to_device(child_sampled_actions_batch, self._cfg.device), torch.float) + child_sampled_actions_batch = child_sampled_actions_batch.transpose(0, 1) + else: + child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1) + target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) # ``scalar_transform`` to transform the original value to the scaled value, # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. 
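
# --- Illustrative sketch (not part of the patch): the transpose used in the multi-agent
# --- branch above. The nested default_collate stacks per-step items with the unroll axis
# --- first, so transpose(0, 1) moves the batch axis back to the front; the trailing axes
# --- (agent / action dims) are untouched. Shapes here are example values only.
import torch
U, B = 5, 4                              # unroll steps, batch size
stacked = torch.randn(U, B, 2, 4)        # trailing dims: e.g. (agent_num, action_dim)
per_sample = stacked.transpose(0, 1)     # -> (B, U, 2, 4), matching the single-agent layout
assert per_sample.shape == (B, U, 2, 4)
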
@@ -785,6 +805,7 @@ def _init_collect(self) -> None: else: self._mcts_collect = MCTSPtree(self._cfg) self._collect_mcts_temperature = 1 + self._multi_agent = self._cfg.model.multi_agent def _forward_collect( self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, @@ -815,7 +836,18 @@ def _forward_collect( """ self._collect_model.eval() self._collect_mcts_temperature = temperature - active_collect_env_num = data.shape[0] + if isinstance(data, dict): + # If data is a dictionary, find the first non-dictionary element and get its shape[0] + # TODO(rjy): written in recursive form + for k, v in data.items(): + if not isinstance(v, dict): + active_collect_env_num = v.shape[0]*v.shape[1] + agent_num = v.shape[1] # multi-agent + elif isinstance(data, torch.Tensor): + # If data is a torch.tensor, directly return its shape[0] + active_collect_env_num = data.shape[0] + agent_num = 1 # single-agent + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -871,48 +903,62 @@ def _forward_collect( roots_values = roots.get_values() # shape: {list: batch_size} roots_sampled_actions = roots.get_sampled_actions() # {list: 1}->{list:6} + if self._multi_agent: + active_collect_env_num = active_collect_env_num // agent_num data_id = [i for i in range(active_collect_env_num)] output = {i: None for i in data_id} if ready_env_id is None: ready_env_id = np.arange(active_collect_env_num) for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.mcts_ctree: - # In ctree, the method roots.get_sampled_actions() returns a list object. - root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) - else: - # In ptree, the same method roots.get_sampled_actions() returns an Action object. - root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) - - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - - if self._cfg.mcts_ctree: - # In ctree, the method roots.get_sampled_actions() returns a list object. - action = np.array(roots_sampled_actions[i][action]) - else: - # In ptree, the same method roots.get_sampled_actions() returns an Action object. 
- action = roots_sampled_actions[i][action].value - - if not self._cfg.model.continuous_action_space: - if len(action.shape) == 0: - action = int(action) - elif len(action.shape) == 1: - action = int(action[0]) - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'root_sampled_actions': root_sampled_actions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], + 'action': [], + 'visit_count_distributions': [], + 'root_sampled_actions': [], + 'visit_count_distribution_entropy': [], + 'searched_value': [], + 'predicted_value': [], + 'predicted_policy_logits': [], } + for j in range(agent_num): + index = i * agent_num + j + distributions, value = roots_visit_count_distributions[index], roots_values[index] + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. + root_sampled_actions = np.array([action for action in roots_sampled_actions[index]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[index]]) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. + action = np.array(roots_sampled_actions[index][action]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. 
+ action = roots_sampled_actions[index][action].value + + if not self._cfg.model.continuous_action_space: + if len(action.shape) == 0: + action = int(action) + elif len(action.shape) == 1: + action = int(action[0]) + + output[env_id]['action'].append(action) + output[env_id]['visit_count_distributions'].append(distributions) + output[env_id]['root_sampled_actions'].append(root_sampled_actions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['searched_value'].append(value) + output[env_id]['predicted_value'].append(pred_values[index]) + output[env_id]['predicted_policy_logits'].append(policy_logits[index]) + + for k,v in output[env_id].items(): + output[env_id][k] = np.array(v) return output @@ -926,6 +972,7 @@ def _init_eval(self) -> None: self._mcts_eval = MCTSCtree(self._cfg) else: self._mcts_eval = MCTSPtree(self._cfg) + self._multi_agent = self._cfg.model.multi_agent def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): """ diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index 75b170612..d2ba13126 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -137,7 +137,7 @@ def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.T p_high = x - x_low p_low = 1 - p_high - target = torch.zeros(x.shape[0], x.shape[1], set_size).to(x.device) + target = torch.zeros(*x.shape, set_size).to(x.device) x_high_idx, x_low_idx = x_high - min / delta, x_low - min / delta target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 0c49a4ac7..760d12cce 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -9,7 +9,8 @@ from easydict import EasyDict from scipy.stats import entropy from torch.nn import functional as F - +from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray +from ding.utils.data import default_collate, default_decollate def visualize_avg_softmax(logits): """ @@ -311,15 +312,24 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, obs_shape: 4 4 4 4 4 4 ----, ----, ----, ----, ----, ----, """ - obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() + # obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() + obs_batch_ori = to_dtype(to_device(to_tensor(obs_batch_ori), cfg.device), torch.float) # ``obs_batch`` is used in ``initial_inference()``, which is the first stacked obs at timestep t1 in # ``obs_batch_ori``. shape is (4, 4*3) = (4, 12) - obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num * cfg.model.observation_shape] + if cfg.model.multi_agent: + obs_batch_ori = default_collate(obs_batch_ori) + obs_batch = obs_batch_ori[0] + else: + obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num * cfg.model.observation_shape] if cfg.model.self_supervised_learning_loss: # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. 
- obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] + if cfg.model.multi_agent: + obs_target_batch = obs_batch_ori[1:] + obs_target_batch = default_collate(obs_target_batch) # {'agent_state': (num_unroll_steps, batch_size, agent_num, obs_shape) + else: + obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] return obs_batch, obs_target_batch diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index aca581c47..2053e63ef 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,15 +1,16 @@ import time -from collections import deque, namedtuple +from collections import deque, namedtuple, defaultdict from typing import Optional, Any, List import numpy as np import torch from ding.envs import BaseEnvManager -from ding.torch_utils import to_ndarray +from ding.torch_utils import to_ndarray, to_device from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \ broadcast_object_list, allreduce_data from ding.worker.collector.base_serial_collector import ISerialCollector from torch.nn import L1Loss +from ding.utils.data import default_collate from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation @@ -56,6 +57,9 @@ def __init__( self._collect_print_freq = collect_print_freq self._timer = EasyTimer() self._end_flag = False + self._multi_agent = policy_config.model.get('multi_agent', False) + if self._multi_agent: + self._agent_num = policy_config.model.agent_num self._rank = get_rank() self._world_size = get_world_size() @@ -216,6 +220,8 @@ def _compute_priorities(self, i, pred_values_lst, search_values_lst): priorities = L1Loss(reduction='none' )(pred_values, search_values).detach().cpu().numpy() + 1e-6 + if self._multi_agent: + priorities = priorities.reshape(-1, self._agent_num) else: # priorities is None -> use the max priority for all newly collected data priorities = None @@ -368,7 +374,11 @@ def collect(self, improved_policy_lst = [[] for _ in range(env_nums)] # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self._multi_agent: + eps_steps_lst, visit_entropies_lst = np.zeros((env_nums, self._agent_num)), np.zeros( + (env_nums, self._agent_num)) + else: + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) self_play_moves = 0. 
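To make the per-agent priority reshaping in ``_compute_priorities`` above concrete, here is a minimal sketch with invented sizes (three agents, matching the Hopper ``3x1`` configuration used in this patch); it is an illustration, not code from the diff:

import torch
from torch.nn import L1Loss

num_transitions, agent_num = 5, 3  # hypothetical numbers

# Predicted and searched root values are collected per agent, flattened in collection order.
pred_values = torch.randn(num_transitions * agent_num)
search_values = torch.randn(num_transitions * agent_num)

priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6
# The multi-agent branch keeps one priority per agent for every collected transition.
priorities = priorities.reshape(-1, agent_num)
assert priorities.shape == (num_transitions, agent_num)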
@@ -388,8 +398,12 @@ def collect(self, ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} stack_obs = list(stack_obs.values()) + stack_obs = default_collate(stack_obs) + if not isinstance(stack_obs, dict): + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = to_device(stack_obs, self.policy_config.device) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} @@ -399,12 +413,6 @@ def collect(self, chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} chance = [chance_dict[env_id] for env_id in ready_env_id] - stack_obs = to_ndarray(stack_obs) - - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - # ============================================================== # policy forward # ============================================================== @@ -638,15 +646,15 @@ def collect(self, last_game_priorities[env_id] = None # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) + # self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) self_play_moves += eps_steps_lst[env_id] self_play_episodes += 1 pred_values_lst[env_id] = [] search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 + eps_steps_lst[env_id] = np.zeros(self._agent_num) + visit_entropies_lst[env_id] = np.zeros(self._agent_num) # Env reset is done by env_manager automatically self._policy.reset([env_id]) diff --git a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py new file mode 100644 index 000000000..f9a5e9b6e --- /dev/null +++ b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py @@ -0,0 +1,130 @@ +from easydict import EasyDict +import os +os.environ["CUDA_VISIBLE_DEVICES"] = '7' + +# options={'Hopper-v2', 'HalfCheetah-v2', 'Walker2d-v2', 'Ant-v2', 'Humanoid-v2'} +env_name = 'Hopper-v2' +agent_conf = "3x1" +n_agent = 3 + +if env_name == 'Hopper-v2' and agent_conf == "3x1": + action_space_size = 1 + agent_observation_shape = 4 + global_observation_shape = 11 +elif env_name in ['HalfCheetah-v2', 'Walker2d-v2'] and agent_conf == "2x3": + action_space_size = 3 + agent_observation_shape = 8 + global_observation_shape = 17 +elif env_name == 'Ant-v2' and agent_conf == "2x4d": + action_space_size = 4 + agent_observation_shape = 54 + global_observation_shape = 111 +elif env_name == 'Humanoid-v2' and agent_conf == "9|8": + action_space_size = 9,8 + agent_observation_shape = 35 + global_observation_shape = 367 + +ignore_done = False +if env_name == 'HalfCheetah-v2': + # for halfcheetah, we ignore done signal to predict the Q value of the last step correctly. 
+ ignore_done = True + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +seed = 0 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +continuous_action_space = True +K = 20 # num_of_sampled_actions +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 + +max_env_step = int(5e6) +reanalyze_ratio = 0. +policy_entropy_loss_weight = 0.005 + +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +mujoco_sampled_efficientzero_config = dict( + exp_name= + f'marl_result/debug/{env_name[:-3]}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs-{batch_size}_pelw{policy_entropy_loss_weight}_seed{seed}', + env=dict( + env_name=env_name, + scenario=env_name, + agent_conf=agent_conf, + agent_obsk=2, + add_agent_id=False, + episode_limit=1000, + continuous=True, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + multi_agent=True, + agent_num=n_agent, + agent_observation_shape=agent_observation_shape, + global_observation_shape=global_observation_shape, + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, + res_connection_in_dynamics=True, + norm_type=None, + ), + cuda=True, + multi_agent=True, + use_priority=False, + policy_entropy_loss_weight=policy_entropy_loss_weight, + ignore_done=ignore_done, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + discount_factor=0.997, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +mujoco_sampled_efficientzero_config = EasyDict(mujoco_sampled_efficientzero_config) +main_config = mujoco_sampled_efficientzero_config + +mujoco_sampled_efficientzero_create_config = dict( + env=dict( + type='multiagent_mujoco_lightzero', + import_names=['zoo.multiagent_mujoco.envs.multiagent_mujoco_lightzero_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), +) +mujoco_sampled_efficientzero_create_config = EasyDict(mujoco_sampled_efficientzero_create_config) +create_config = mujoco_sampled_efficientzero_create_config + +if __name__ == "__main__": + from zoo.multiagent_mujoco.entry import train_sez_independent_mamujoco + + train_sez_independent_mamujoco([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/multiagent_mujoco/entry/__init__.py b/zoo/multiagent_mujoco/entry/__init__.py new file mode 100644 index 000000000..c99ded2fc --- /dev/null +++ b/zoo/multiagent_mujoco/entry/__init__.py @@ -0,0 +1 @@ +from .train_sez_independent_mamujoco import train_sez_independent_mamujoco \ No newline at end of file diff 
--git a/zoo/multiagent_mujoco/entry/train_sez_independent_mamujoco.py b/zoo/multiagent_mujoco/entry/train_sez_independent_mamujoco.py new file mode 100644 index 000000000..588a923d9 --- /dev/null +++ b/zoo/multiagent_mujoco/entry/train_sez_independent_mamujoco.py @@ -0,0 +1,195 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.rl_utils import get_epsilon_greedy_fn +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.entry.utils import random_collect + + +def train_sez_independent_mamujoco( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel Muzero. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ + "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" + + if create_cfg.policy.type == 'muzero': + from lzero.mcts import MuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import EfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'gumbel_muzero': + from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer + + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # load pretrained model + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if cfg.policy.update_per_collect is not None: + update_per_collect = cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. 
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + # Learner's after_run hook. 
+ learner.call_hook('after_run') + return policy diff --git a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py index 9448a9be6..d3103bf13 100644 --- a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py +++ b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py @@ -10,7 +10,7 @@ from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv,MujocoMulti -@ENV_REGISTRY.register('mujoco_lightzero') +@ENV_REGISTRY.register('multiagent_mujoco_lightzero') class MAMujocoEnvLZ(MujocoEnv): """ Overview: From 1e0454554626e31b6d98e66ecbea3a55d0d0f99e Mon Sep 17 00:00:00 2001 From: nighood Date: Fri, 8 Dec 2023 22:42:18 +0800 Subject: [PATCH 4/8] algo(rjy): fix forward_learn and game_buffer --- lzero/mcts/buffer/game_buffer_muzero.py | 3 +- .../game_buffer_sampled_efficientzero.py | 11 +- lzero/policy/sampled_efficientzero.py | 105 ++++++++++++------ lzero/policy/scaling_transform.py | 4 +- ...ent_mujoco_sampled_efficientzero_config.py | 13 ++- 5 files changed, 88 insertions(+), 48 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index a09148e2a..acd8f8956 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -678,7 +678,8 @@ def _compute_target_policy_non_reanalyzed( policy_index += 1 batch_target_policies_non_re.append(target_policies) - batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) + if not self._multi_agent: + batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) return batch_target_policies_non_re def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py index 2e5cda8a8..ef1a1b27f 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py @@ -428,15 +428,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_value_prefixs.append(value_prefix) else: target_values.append(np.zeros_like(value_list[0])) - target_value_prefixs.append(value_prefix) + target_value_prefixs.append(np.array([0,])) value_index += 1 batch_value_prefixs.append(target_value_prefixs) batch_target_values.append(target_values) - batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=np.float32) - batch_target_values = np.asarray(batch_target_values, dtype=np.float32) + if not self._multi_agent: + batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=np.float32) + batch_target_values = np.asarray(batch_target_values, dtype=np.float32) return batch_value_prefixs, batch_target_values @@ -585,8 +586,8 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: policy_index += 1 batch_target_policies_re.append(target_policies) - - batch_target_policies_re = np.array(batch_target_policies_re) + if not self._multi_agent: + batch_target_policies_re = np.array(batch_target_policies_re) return batch_target_policies_re, root_sampled_actions diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 763bc8ff6..5bda160e5 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -302,6 +302,22 @@ def _init_learn(self) -> None: ) self._multi_agent = self._cfg.model.multi_agent + def _prepocess_data(self, data_list): + def 
get_depth(lst): + if not isinstance(lst, list): + return 0 + return 1 + get_depth(lst[0]) + for i in range(len(data_list)): + depth = get_depth(data_list[i]) + if depth != 0: + for _ in range(depth): + data_list[i] = default_collate(data_list[i]) + data_list[i] = to_dtype(to_device(data_list[i], self._cfg.device), torch.float) + data_list[i] = data_list[i].permute(*range(depth-1, -1, -1), *range(depth, data_list[i].dim())) + else: + data_list[i] = to_dtype(to_device(to_tensor(data_list[i]), self._cfg.device), torch.float) + return data_list + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ Overview: @@ -336,16 +352,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # shape: (batch_size, num_unroll_steps, action_dim) # NOTE: .float(), in continuous action space. if self._cfg.model.multi_agent: - action_batch = to_dtype(to_device(to_tensor(action_batch), self._cfg.device), torch.float) - action_batch = default_collate(default_collate(action_batch)) # (num_unroll_steps, batch_size, action_dim, 1) - action_batch = action_batch.transpose(0, 1) # (batch_size, num_unroll_steps, action_dim, 1) - mask_batch = to_dtype(default_collate(mask_batch), torch.float) - data_list = [ - target_value_prefix.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [target_value_prefix, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) + data_list = [action_batch, mask_batch, target_value_prefix, target_value, target_policy, weights, child_sampled_actions_batch] + [action_batch, mask_batch, target_value_prefix, target_value, + target_policy, weights, child_sampled_actions_batch] = self._prepocess_data(data_list) else: action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1) data_list = [ @@ -355,15 +364,10 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: ] [mask_batch, target_value_prefix, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) - # ============================================================== - # sampled related core code - # ============================================================== - # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. (4, 6, 5, 1, 1) - if self._cfg.model.multi_agent: - child_sampled_actions_batch = default_collate(default_collate(child_sampled_actions_batch)) - child_sampled_actions_batch = to_dtype(to_device(child_sampled_actions_batch, self._cfg.device), torch.float) - child_sampled_actions_batch = child_sampled_actions_batch.transpose(0, 1) - else: + # ============================================================== + # sampled related core code + # ============================================================== + # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. (4, 6, 5, 1, 1) child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1) target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) target_value = target_value.view(self._cfg.batch_size, -1) @@ -398,15 +402,32 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: ).detach().cpu() # calculate the new priorities for each transition. 
- value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) - value_priority = value_priority.data.cpu().numpy() + 1e-6 + if self._cfg.use_priority: + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + else: + value_priority = np.ones(self._cfg.model.agent_num*self._cfg.batch_size) # ============================================================== # calculate policy and value loss for the first step. # ============================================================== - value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) - - policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + if self._multi_agent: + # (B, unroll_step, agent_num, 601) -> (B*agent_num, unroll_step, 601) + target_value_categorical = target_value_categorical.transpose(1, 2) + target_value_categorical = target_value_categorical.reshape((-1, *target_value_categorical.shape[2:])) + # (B, unroll_step, agent_num, action_dim) -> (B*agent_num, unroll_step, action_dim) + action_batch = action_batch.transpose(1,2) + action_batch = action_batch.reshape((-1, *action_batch.shape[2:])) + + target_value_prefix_categorical = torch.repeat_interleave(target_value_prefix_categorical, repeats=self._cfg.model.agent_num, dim=0) + + weights = torch.repeat_interleave(weights, repeats=self._cfg.model.agent_num) + # value shape (B*agent_num, 601) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + policy_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + else: + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) # ============================================================== # sampled related core code: calculate policy loss, typically cross_entropy_loss # ============================================================== @@ -421,8 +442,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0 ) - value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + if self._multi_agent: + value_prefix_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + else: + value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) # ============================================================== # the core recurrent_inference in SampledEfficientZero policy. @@ -494,10 +519,16 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # reset hidden states every ``lstm_horizon_len`` unroll steps. 
if (step_k + 1) % self._cfg.lstm_horizon_len == 0: - reward_hidden_state = ( - torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), - torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + if self._multi_agent: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size*self._cfg.model.agent_num, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size*self._cfg.model.agent_num, self._cfg.model.lstm_hidden_size).to(self._cfg.device) ) + else: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + ) if self._cfg.monitor_extra_statistics: original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) @@ -629,7 +660,9 @@ def _calculate_policy_loss_cont( # take the init hypothetical step k=unroll_step target_normalized_visit_count = target_policy[:, unroll_step] - + if self._multi_agent: + target_normalized_visit_count = target_normalized_visit_count.reshape((self._cfg.batch_size*self._cfg.model.agent_num, -1)) + mask_batch = torch.repeat_interleave(mask_batch, repeats=3, dim=0) # ******* NOTE: target_policy_entropy is only for debug. ****** non_masked_indices = torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) # Check if there are any unmasked rows @@ -639,13 +672,17 @@ def _calculate_policy_loss_cont( ) target_dist = Categorical(target_normalized_visit_count_masked) target_policy_entropy = target_dist.entropy().mean() + + # shape: (batch_size, num_unroll_steps, num_agent, num_of_sampled_actions, action_dim, 1) -> (batch_size*num_agent, + # num_of_sampled_actions, action_dim) e.g. (4, 6, 3, 20, 2, 1) -> (12, 20, 2) + target_sampled_actions = child_sampled_actions_batch[:, unroll_step].view(-1, *child_sampled_actions_batch[:, unroll_step].shape[2:]) else: # Set target_policy_entropy to 0 if all rows are masked target_policy_entropy = 0 - # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) + # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, + # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) + target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) policy_entropy = dist.entropy().mean() policy_entropy_loss = -dist.entropy() @@ -669,7 +706,7 @@ def _calculate_policy_loss_cont( # NOTE: for numerical stability. 
target_sampled_actions_clamped = torch.clamp( - target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6), torch.tensor(1 - 1e-6) + target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6).to(self._cfg.device), torch.tensor(1 - 1e-6).to(self._cfg.device) ) target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index d2ba13126..aa20ce1de 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -139,8 +139,8 @@ def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.T target = torch.zeros(*x.shape, set_size).to(x.device) x_high_idx, x_low_idx = x_high - min / delta, x_low - min / delta - target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) - target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) + target.scatter_(target.dim()-1, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) + target.scatter_(target.dim()-1, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) return target diff --git a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py index f9a5e9b6e..0a5732131 100644 --- a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py +++ b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py @@ -1,6 +1,6 @@ from easydict import EasyDict import os -os.environ["CUDA_VISIBLE_DEVICES"] = '7' +os.environ["CUDA_VISIBLE_DEVICES"] = '6' # options={'Hopper-v2', 'HalfCheetah-v2', 'Walker2d-v2', 'Ant-v2', 'Humanoid-v2'} env_name = 'Hopper-v2' @@ -33,14 +33,14 @@ # begin of the most frequently changed config specified by the user # ============================================================== seed = 0 -collector_env_num = 8 -n_episode = 8 -evaluator_env_num = 3 +collector_env_num = 3 +n_episode = 3 +evaluator_env_num = 1 continuous_action_space = True K = 20 # num_of_sampled_actions num_simulations = 50 -update_per_collect = 200 -batch_size = 256 +update_per_collect = 5 +batch_size = 16 max_env_step = int(5e6) reanalyze_ratio = 0. @@ -86,6 +86,7 @@ cuda=True, multi_agent=True, use_priority=False, + ssl_loss_weight=0, policy_entropy_loss_weight=policy_entropy_loss_weight, ignore_done=ignore_done, env_type='not_board_games', From c54c0a549e81c22c1fe1e8905fb2142e75f73263 Mon Sep 17 00:00:00 2001 From: nighood Date: Sat, 9 Dec 2023 00:16:46 +0800 Subject: [PATCH 5/8] algo(rjy): add pipeline of sez ma (train+eval) --- lzero/policy/sampled_efficientzero.py | 95 ++++++++++++------- lzero/worker/muzero_evaluator.py | 12 ++- ...ent_mujoco_sampled_efficientzero_config.py | 3 +- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 5bda160e5..e54661321 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -1034,7 +1034,16 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
""" self._eval_model.eval() - active_eval_env_num = data.shape[0] + if isinstance(data, dict): + # If data is a dictionary, find the first non-dictionary element and get its shape[0] + for k, v in data.items(): + if not isinstance(v, dict): + active_eval_env_num = v.shape[0]*v.shape[1] + agent_num = v.shape[1] # multi-agent + elif isinstance(data, torch.Tensor): + # If data is a torch.tensor, directly return its shape[0] + active_eval_env_num = data.shape[0] + agent_num = 1 # single-agent with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._eval_model.initial_inference(data) @@ -1088,6 +1097,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read roots_sampled_actions = roots.get_sampled_actions( ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + if self._multi_agent: + active_eval_env_num = active_eval_env_num // agent_num data_id = [i for i in range(active_eval_env_num)] output = {i: None for i in data_id} @@ -1095,44 +1106,56 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ready_env_id = np.arange(active_eval_env_num) for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - try: - root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') - root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. - action, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # ============================================================== - # sampled related core code - # ============================================================== + output[env_id] = { + 'action': [], + 'visit_count_distributions': [], + 'root_sampled_actions': [], + 'visit_count_distribution_entropy': [], + 'searched_value': [], + 'predicted_value': [], + 'predicted_policy_logits': [], + } + for j in range(agent_num): + index = i * agent_num + j + distributions, value = roots_visit_count_distributions[index], roots_values[index] + try: + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[index]]) + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + root_sampled_actions = np.array([action for action in roots_sampled_actions[index]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. 
+ action, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # ============================================================== + # sampled related core code + # ============================================================== - try: - action = roots_sampled_actions[i][action].value - # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') - action = np.array(roots_sampled_actions[i][action]) + try: + action = roots_sampled_actions[index][action].value + # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + action = np.array(roots_sampled_actions[index][action]) - if not self._cfg.model.continuous_action_space: - if len(action.shape) == 0: - action = int(action) - elif len(action.shape) == 1: - action = int(action[0]) + if not self._cfg.model.continuous_action_space: + if len(action.shape) == 0: + action = int(action) + elif len(action.shape) == 1: + action = int(action[0]) - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'root_sampled_actions': root_sampled_actions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } + output[env_id]['action'].append(action) + output[env_id]['visit_count_distributions'].append(distributions) + output[env_id]['root_sampled_actions'].append(root_sampled_actions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['searched_value'].append(value) + output[env_id]['predicted_value'].append(pred_values[index]) + output[env_id]['predicted_policy_logits'].append(policy_logits[index]) + + for k,v in output[env_id].items(): + output[env_id][k] = np.array(v) return output diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 313a07e07..ccb4997ac 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -11,6 +11,8 @@ from ding.utils import get_world_size, get_rank, broadcast_object_list from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor from easydict import EasyDict +from ding.torch_utils import to_ndarray, to_device +from ding.utils.data import default_collate from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation @@ -271,18 +273,18 @@ def eval( ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} stack_obs = list(stack_obs.values()) + stack_obs = default_collate(stack_obs) + if not isinstance(stack_obs, dict): + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = to_device(stack_obs, self.policy_config.device) action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = 
[to_play_dict[env_id] for env_id in ready_env_id] - stack_obs = to_ndarray(stack_obs) - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - # ============================================================== # policy forward # ============================================================== diff --git a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py index 0a5732131..65325a686 100644 --- a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py +++ b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py @@ -101,7 +101,8 @@ num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, - eval_freq=int(2e3), + # eval_freq=int(2e3), + eval_freq=int(2), replay_buffer_size=int(1e6), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, From f888f6f088ad27b5c58ba197569278781cb7d816 Mon Sep 17 00:00:00 2001 From: nighood Date: Sat, 9 Dec 2023 16:01:43 +0800 Subject: [PATCH 6/8] fix(rjy): fix config --- ...iagent_mujoco_sampled_efficientzero_config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py index 65325a686..058f0d0cd 100644 --- a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py +++ b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py @@ -33,16 +33,16 @@ # begin of the most frequently changed config specified by the user # ============================================================== seed = 0 -collector_env_num = 3 -n_episode = 3 -evaluator_env_num = 1 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 continuous_action_space = True K = 20 # num_of_sampled_actions num_simulations = 50 update_per_collect = 5 -batch_size = 16 +batch_size = 256 -max_env_step = int(5e6) +max_env_step = int(5e5) reanalyze_ratio = 0. 
 policy_entropy_loss_weight = 0.005
@@ -52,7 +52,7 @@
 
 mujoco_sampled_efficientzero_config = dict(
     exp_name=
-    f'marl_result/debug/{env_name[:-3]}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs-{batch_size}_pelw{policy_entropy_loss_weight}_seed{seed}',
+    f'marl_result/mamujoco/{env_name[:-3]}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs-{batch_size}_pelw{policy_entropy_loss_weight}_seed{seed}',
     env=dict(
         env_name=env_name,
         scenario=env_name,
@@ -101,8 +101,8 @@
         num_simulations=num_simulations,
         reanalyze_ratio=reanalyze_ratio,
         n_episode=n_episode,
-        # eval_freq=int(2e3),
-        eval_freq=int(2),
+        eval_freq=int(2e3),
+        # eval_freq=int(2),
         replay_buffer_size=int(1e6),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,

From 27420d4195b77a9d53e65594aa430d91c3b2c8d0 Mon Sep 17 00:00:00 2001
From: nighood
Date: Fri, 14 Jun 2024 15:28:27 +0800
Subject: [PATCH 7/8] polish(rjy): polish comments of mamujoco

---
 .../envs/multiagent_mujoco_lightzero_env.py   | 33 ++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py
index d3103bf13..164ac160d 100644
--- a/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py
+++ b/zoo/multiagent_mujoco/envs/multiagent_mujoco_lightzero_env.py
@@ -14,7 +14,10 @@ class MAMujocoEnvLZ(MujocoEnv):
     """
     Overview:
-        The modified Multi-agentMuJoCo environment with continuous action space for LightZero's algorithms.
+        The modified Multi-agent MuJoCo environment with continuous action space for LightZero's algorithms. \
+        You can find the original implementation at \
+        [Multi-Agent Mujoco](https://robotics.farama.org/envs/MaMuJoCo/index.html). The class is registered \
+        in ENV_REGISTRY with the key 'multiagent_mujoco_lightzero'.
     """
 
     config = dict(
@@ -24,6 +27,18 @@ class MAMujocoEnvLZ(MujocoEnv):
     )
 
     def __init__(self, cfg: dict) -> None:
+        """
+        Overview:
+            Initialize the Multi-agent MuJoCo environment.
+        Arguments:
+            - cfg (:obj:`dict`): Config dict. The following keys must be specified:
+                - 'env_name' (:obj:`str`): The name of the environment.
+                - 'scenario' (:obj:`str`): The scenario of the environment.
+                - 'agent_conf' (:obj:`str`): The configuration of the agents.
+                - 'agent_obsk' (:obj:`int`): The agents' observation depth, i.e. each agent observes joints up to this many hops away.
+                - 'add_agent_id' (:obj:`bool`): Whether to add agent id to the observation.
+                - 'episode_limit' (:obj:`int`): The maximum number of steps per episode.
+        """
         super().__init__(cfg)
         self._cfg = cfg
         # We use env_name to indicate the env_id in LightZero.
@@ -31,6 +46,14 @@ def __init__(self, cfg: dict) -> None:
         self._init_flag = False
 
     def reset(self) -> np.ndarray:
+        """
+        Overview:
+            Reset the environment and return the initial observation.
+        Returns:
+            - obs (:obj:`dict`): The initial observation after resetting. The observation is a dict with keys \
+                'observation', 'action_mask', and 'to_play'. The 'observation' is a dict with keys 'agent_state' and \
+                'global_state'.
+        """
         if not self._init_flag:
             self._env = MujocoMulti(env_args=self._cfg)
             self._init_flag = True
@@ -79,6 +102,14 @@ def reset(self) -> np.ndarray:
         return obs
 
     def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+        """
+        Overview:
+            Take a step in the environment with the given action.
+        Arguments:
+            - action (:obj:`np.ndarray`): The action to be taken.
+        Returns:
+            - timestep (:obj:`BaseEnvTimestep`): The timestep information including observation, reward, done flag, and info.
+        """
         action = to_ndarray(action)
         obs, rew, done, info = self._env.step(action)
         self._eval_episode_return += rew

From fc84583ad0419f1e77e97d065bf20688c2f4e50f Mon Sep 17 00:00:00 2001
From: nighood
Date: Fri, 21 Jun 2024 17:39:00 +0800
Subject: [PATCH 8/8] fix(rjy): Divide the handling of single/multi-agent in the code into two separate branches.

---
 .../game_buffer_sampled_efficientzero.py      |  33 +-
 lzero/mcts/buffer/game_segment.py             |  82 +-
 lzero/policy/sampled_efficientzero.py         | 613 +--------------
 lzero/policy/sampled_efficientzero_ma.py      | 702 ++++++++++++++++++
 lzero/policy/scaling_transform.py             |  16 +-
 lzero/policy/utils.py                         |  22 +-
 lzero/worker/muzero_collector.py              |  36 +-
 lzero/worker/muzero_evaluator.py              |  15 +-
 ...ent_mujoco_sampled_efficientzero_config.py |   4 +-
 9 files changed, 857 insertions(+), 666 deletions(-)
 create mode 100644 lzero/policy/sampled_efficientzero_ma.py

diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
index ef1a1b27f..106ad0360 100644
--- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
+++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -267,14 +267,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
             to_play_segment = reward_value_context  # noqa
-        # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
+        # transition_batch_size = game_segment_batch_size * (num_unroll_steps + 1)
         transition_batch_size = len(value_obs_list)
         game_segment_batch_size = len(pos_in_game_segment_list)
 
         to_play, action_mask = self._preprocess_to_play_and_action_mask(
             game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
         )
-        if self._cfg.model.continuous_action_space is True:
+        if self._cfg.model.continuous_action_space:
             # when the action space of the environment is continuous, action_mask[:] is None.
action_mask = [ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) @@ -296,12 +296,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A for i in range(slices): beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = to_dtype(to_device(to_tensor(value_obs_list[beg_index:end_index]), self._cfg.device), torch.float) - - # calculate the target value - m_obs = default_collate(m_obs) if self._multi_agent: + m_obs = to_dtype(to_device(to_tensor(value_obs_list[beg_index:end_index]), self._cfg.device), torch.float) + m_obs = default_collate(m_obs) m_obs = m_obs[0] + else: + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + # calculate the target value m_output = model.initial_inference(m_obs) # TODO(pu) @@ -355,8 +356,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A ) roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree.roots(self._cfg - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) + MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) roots_values = roots.get_values() value_list = np.array(roots_values) @@ -392,8 +392,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A horizon_id, value_index = 0, 0 for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): + pos_in_game_segment_list, + to_play_segment): target_values = [] target_value_prefixs = [] @@ -401,7 +401,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A base_index = state_index for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): for i, reward in enumerate(reward_list[current_index:bootstrap_index]): if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: # TODO(pu): for board_games, very important, to check @@ -423,13 +422,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A target_values.append(value_list[value_index]) # Since the horizon is small and the discount_factor is close to 1. 
# Compute the reward sum to approximate the value prefix for simplification - value_prefix += reward_list[current_index - ] # * config.discount_factor ** (current_index - base_index) + value_prefix += reward_list[current_index] target_value_prefixs.append(value_prefix) else: - target_values.append(np.zeros_like(value_list[0])) - target_value_prefixs.append(np.array([0,])) - + if self._multi_agent: + target_values.append(np.zeros_like(value_list[0])) + target_value_prefixs.append(np.array([0,])) + else: + target_values.append(0) + target_value_prefixs.append(value_prefix) value_index += 1 batch_value_prefixs.append(target_value_prefixs) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index 27aa19fb7..ebf6a098b 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -52,6 +52,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder + self._multi_agent = config.model.get('multi_agent', False) if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: # for vector obs input, e.g. classical control and box2d environments @@ -83,7 +84,6 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = [] - def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: """ Overview: @@ -97,8 +97,12 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool if padding: pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) if pad_len > 0: - pad_frames = [stacked_obs[-1] for _ in range(pad_len)] - stacked_obs += pad_frames + if self._multi_agent: + pad_frames = [stacked_obs[-1] for _ in range(pad_len)] + stacked_obs += pad_frames + else: + pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) + stacked_obs = np.concatenate((stacked_obs, pad_frames)) if self.transform2string: stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] return stacked_obs @@ -113,7 +117,7 @@ def _zero_obs(self, input_data): else: # Process other types (e.g. numbers, strings, etc.) return input_data - + def zero_obs(self) -> List: """ Overview: @@ -121,7 +125,10 @@ def zero_obs(self) -> List: Returns: ndarray: An array filled with zeros. """ - return [self._zero_obs(self.obs_segment[0]) for _ in range(self.frame_stack_num)] + if self._multi_agent: + return [self._zero_obs(self.obs_segment[0]) for _ in range(self.frame_stack_num)] + else: + return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)] def get_obs(self) -> List: """ @@ -224,9 +231,17 @@ def store_search_stats( Overview: store the visit count distributions and value of the root node after MCTS. 
""" - sum_visits = np.sum(visit_counts, axis=-1) + if self._multi_agent: + sum_visits = np.sum(visit_counts, axis=-1) + else: + sum_visits = sum(visit_counts) + if idx is None: - self.child_visit_segment.append([visit_count / sum_visits[i] for i,visit_count in enumerate(visit_counts)]) + if self._multi_agent: + self.child_visit_segment.append([visit_count / sum_visits[i] for i, visit_count in enumerate(visit_counts)]) + else: + self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) + self.root_value_segment.append(root_value) if self.sampled_algo: self.root_sampled_actions.append(root_sampled_actions) @@ -234,7 +249,10 @@ def store_search_stats( if self.gumbel_algo: self.improved_policy_probs.append(improved_policy) else: - self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts] + if self._multi_agent: + self.child_visit_segment[idx] = [visit_count / sum_visits[i] for i, visit_count in enumerate(visit_counts)] + else: + self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts] self.root_value_segment[idx] = root_value self.improved_policy_probs[idx] = improved_policy @@ -261,12 +279,12 @@ def game_segment_to_array(self) -> None: - game_segment_i (obs): 4 20 5 ----|----...----|-----| - game_segment_i+1 (obs): 4 20 5 - ----|----...----|-----| + ----|----...----|-----| - game_segment_i (rew): 20 5 4 ----...----|------|-----| - game_segment_i+1 (rew): 20 5 4 - ----...----|------|-----| + ----...----|------|-----| Postprocessing: - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment. @@ -284,26 +302,44 @@ def game_segment_to_array(self) -> None: For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`. """ - self.obs_segment = to_ndarray(self.obs_segment) - self.action_segment = to_ndarray(self.action_segment) - self.reward_segment = to_ndarray(self.reward_segment) + if self._multi_agent: + self.obs_segment = to_ndarray(self.obs_segment) + self.action_segment = to_ndarray(self.action_segment) + self.reward_segment = to_ndarray(self.reward_segment) + else: + self.obs_segment = np.array(self.obs_segment) + self.action_segment = np.array(self.action_segment) + self.reward_segment = np.array(self.reward_segment) # Check if all elements in self.child_visit_segment have the same length if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment): - self.child_visit_segment = to_ndarray(self.child_visit_segment) + if self._multi_agent: + self.child_visit_segment = to_ndarray(self.child_visit_segment) + else: + self.child_visit_segment = np.array(self.child_visit_segment) else: # In the case of environments with a variable action space, such as board games, # the elements in child_visit_segment may have different lengths. # In such scenarios, it is necessary to use the object data type. 
- self.child_visit_segment = to_ndarray(self.child_visit_segment, dtype=object) - - self.root_value_segment = to_ndarray(self.root_value_segment) - self.improved_policy_probs = to_ndarray(self.improved_policy_probs) - - self.action_mask_segment = to_ndarray(self.action_mask_segment) - self.to_play_segment = to_ndarray(self.to_play_segment) - if self.use_ture_chance_label_in_chance_encoder: - self.chance_segment = to_ndarray(self.chance_segment) + if self._multi_agent: + self.child_visit_segment = to_ndarray(self.child_visit_segment, dtype=object) + else: + self.child_visit_segment = np.array(self.child_visit_segment, dtype=object) + + if self._multi_agent: + self.root_value_segment = to_ndarray(self.root_value_segment) + self.improved_policy_probs = to_ndarray(self.improved_policy_probs) + self.action_mask_segment = to_ndarray(self.action_mask_segment) + self.to_play_segment = to_ndarray(self.to_play_segment) + if self.use_ture_chance_label_in_chance_encoder: + self.chance_segment = to_ndarray(self.chance_segment) + else: + self.root_value_segment = np.array(self.root_value_segment) + self.improved_policy_probs = np.array(self.improved_policy_probs) + self.action_mask_segment = np.array(self.action_mask_segment) + self.to_play_segment = np.array(self.to_play_segment) + if self.use_ture_chance_label_in_chance_encoder: + self.chance_segment = np.array(self.chance_segment) def reset(self, init_observations: np.ndarray) -> None: """ diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index e54661321..f872c9d32 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -20,215 +20,29 @@ prepare_obs, \ configure_optimizers from lzero.policy.muzero import MuZeroPolicy +from lzero.policy.sampled_efficientzero import SampledEfficientZeroPolicy -@POLICY_REGISTRY.register('sampled_efficientzero') -class SampledEfficientZeroPolicy(MuZeroPolicy): +@POLICY_REGISTRY.register('sampled_efficientzero_ma') +class SampledEfficientZeroMAPolicy(SampledEfficientZeroPolicy): """ Overview: The policy class for Sampled EfficientZero proposed in the paper https://arxiv.org/abs/2104.06303. """ - # The default_config for Sampled EfficientZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) the stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=True, - # (int) The size of action space. For discrete action space, it is the number of actions. - # For continuous action space, it is the dimension of action. - action_space_size=6, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) the image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (int) The hidden size in LSTM. - lstm_hidden_size=512, - # (str) The type of sigma. 
options={'conditioned', 'fixed'} - sigma_type='conditioned', - # (float) The fixed sigma value. Only effective when ``sigma_type='fixed'``. - fixed_sigma_value=0.3, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) Whether to use multi-gpu training. - multi_gpu=False, - # (bool) ``sampled_algo=True`` means the policy is sampled-based algorithm (e.g. Sampled EfficientZero), which is used in ``collector``. - sampled_algo=True, - # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) - gumbel_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda in policy. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. The options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use gray scale image. - gray_scale=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ****** learn ****** - # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. - # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, - # we should set it to True to avoid the influence of the done flag. - ignore_done=False, - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. - update_per_collect=None, - # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. 
['SGD', 'Adam', 'AdamW'] - optim_type='SGD', - learning_rate=0.2, # init lr for manually decay schedule - # optim_type='Adam', - # lr_piecewise_constant_decay=False, - # learning_rate=0.003, # lr for Adam optimizer - # (float) Weight uniform initialization range in the last output layer - init_w=3e-3, - normalize_prob_of_sampled_actions=False, - policy_loss_type='cross_entropy', # options={'cross_entropy', 'KL'} - # (int) Frequency of target network update. - target_update_freq=100, - weight_decay=1e-4, - momentum=0.9, - grad_clip_value=10, - # You can use either "n_sample" or "n_episode" in collector.collect. - # Get "n_episode" episodes per collect. - n_episode=8, - # (float) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of step for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. - lstm_horizon_len=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=2, - # (bool) Whether to use the cosine learning rate decay. - cos_lr_scheduler=False, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (bool) Whether to use manually decayed temperature. - # i.e. temperature: 1 -> 0.5 -> 0.25 - manual_temperature_decay=False, - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. - use_ture_chance_label_in_chance_encoder=False, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - - # ****** Explore by random collect ****** - # (int) The number of episodes to collect data randomly before training. - random_collect_episode_num=0, - - # ****** Explore by eps greedy ****** - eps=dict( - # (bool) Whether to use eps greedy exploration in collecting data. 
- eps_greedy_exploration_in_collect=False, - # (str) The type of decaying epsilon. Options are 'linear', 'exp'. - type='linear', - # (float) The start value of eps. - start=1., - # (float) The end value of eps. - end=0.05, - # (int) The decay steps from start to end eps. - decay=int(1e5), - ), - ) - def default_model(self) -> Tuple[str, List[str]]: """ Overview: - Return this algorithm default model setting. + Return this algorithm default model setting for multi-agent. Returns: - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For Sampled EfficientZero, ``lzero.model.sampled_efficientzero_model.SampledEfficientZeroModel`` + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. """ - if self._cfg.model.model_type == "conv": - return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] - elif self._cfg.model.model_type == "mlp": - if self._cfg.model.multi_agent is True: - return 'SampledEfficientZeroModelMLPMaIndependent', ['lzero.model.sampled_efficientzero_model_mlp_ma_independent'] - else: - return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] + # if self._cfg.model.model_type == "conv": + # return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] + if self._cfg.model.model_type == "mlp": + return 'SampledEfficientZeroModelMLPMaIndependent', ['lzero.model.sampled_efficientzero_model_mlp_ma_independent'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) @@ -237,69 +51,7 @@ def _init_learn(self) -> None: Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - if self._cfg.model.continuous_action_space: - # Weight Init for the last output layer of gaussian policy head in prediction network. 
- init_w = self._cfg.init_w - self._model.prediction_network.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) - self._model.prediction_network.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) - self._model.prediction_network.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) - try: - self._model.prediction_network.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) - except Exception as exception: - logging.warning(exception) - - if self._cfg.optim_type == 'SGD': - self._optimizer = optim.SGD( - self._model.parameters(), - lr=self._cfg.learning_rate, - momentum=self._cfg.momentum, - weight_decay=self._cfg.weight_decay, - ) - - elif self._cfg.optim_type == 'Adam': - self._optimizer = optim.Adam( - self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) - - if self._cfg.cos_lr_scheduler is True: - from torch.optim.lr_scheduler import CosineAnnealingLR - self.lr_scheduler = CosineAnnealingLR(self._optimizer, 1e6, eta_min=0, last_epoch=-1) - - if self._cfg.lr_piecewise_constant_decay: - from torch.optim.lr_scheduler import LambdaLR - max_step = self._cfg.threshold_training_steps_for_final_lr - # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. - lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq} - ) - self._learn_model = self._model - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) + super()._init_learn() self._multi_agent = self._cfg.model.multi_agent def _prepocess_data(self, data_list): @@ -610,7 +362,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'policy_sigma_max': sigma.max().item(), 'policy_sigma_min': sigma.min().item(), 'policy_sigma_mean': sigma.mean().item(), - # take the fist dim in action space + # take the first dim in action space 'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(), 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), @@ -621,7 +373,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # sampled related core code # ============================================================== - # take the fist dim in action space + # take the first dim in action space 'target_sampled_actions_max': 
target_sampled_actions[:, :].float().max().item(), 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), @@ -630,218 +382,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: return return_data - def _calculate_policy_loss_cont( - self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, - mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int - ) -> Tuple[torch.Tensor]: - """ - Overview: - Calculate the policy loss for continuous action space. - Arguments: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. - - target_policy (:obj:`torch.Tensor`): The target policy tensor. - - mask_batch (:obj:`torch.Tensor`): The mask tensor. - - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. - - unroll_step (:obj:`int`): The unroll step. - Returns: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. - - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. - - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. - - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. - - mu (:obj:`torch.Tensor`): The mu tensor. - - sigma (:obj:`torch.Tensor`): The sigma tensor. - """ - (mu, sigma - ) = policy_logits[:, :self._cfg.model.action_space_size], policy_logits[:, -self._cfg.model.action_space_size:] - - dist = Independent(Normal(mu, sigma), 1) - - # take the init hypothetical step k=unroll_step - target_normalized_visit_count = target_policy[:, unroll_step] - if self._multi_agent: - target_normalized_visit_count = target_normalized_visit_count.reshape((self._cfg.batch_size*self._cfg.model.agent_num, -1)) - mask_batch = torch.repeat_interleave(mask_batch, repeats=3, dim=0) - # ******* NOTE: target_policy_entropy is only for debug. ****** - non_masked_indices = torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) - # Check if there are any unmasked rows - if len(non_masked_indices) > 0: - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count, 0, non_masked_indices - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy = target_dist.entropy().mean() - - # shape: (batch_size, num_unroll_steps, num_agent, num_of_sampled_actions, action_dim, 1) -> (batch_size*num_agent, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 3, 20, 2, 1) -> (12, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].view(-1, *child_sampled_actions_batch[:, unroll_step].shape[2:]) - else: - # Set target_policy_entropy to 0 if all rows are masked - target_policy_entropy = 0 - - # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) - - policy_entropy = dist.entropy().mean() - policy_entropy_loss = -dist.entropy() - - # Project the sampled-based improved policy back onto the space of representable policies. 
calculate KL - # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is - # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for - # numerical stability. - target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) - log_prob_sampled_actions = [] - for k in range(self._cfg.model.num_of_sampled_actions): - # target_sampled_actions[:,i,:].shape: batch_size, action_dim -> 4,2 - # dist.log_prob(target_sampled_actions[:,i,:]).shape: batch_size -> 4 - # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) - - # way 1: - # log_prob = dist.log_prob(target_sampled_actions[:, k, :]) - - # way 2: SAC-like - y = 1 - target_sampled_actions[:, k, :].pow(2) - - # NOTE: for numerical stability. - target_sampled_actions_clamped = torch.clamp( - target_sampled_actions[:, k, :], torch.tensor(-1 + 1e-6).to(self._cfg.device), torch.tensor(1 - 1e-6).to(self._cfg.device) - ) - target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped) - - # keep dimension for loss computation (usually for action space is 1 env. e.g. pendulum) - log_prob = dist.log_prob(target_sampled_actions_before_tanh).unsqueeze(-1) - log_prob = log_prob - torch.log(y + 1e-6).sum(-1, keepdim=True) - log_prob = log_prob.squeeze(-1) - - log_prob_sampled_actions.append(log_prob) - - # shape: (batch_size, num_of_sampled_actions) e.g. (4,20) - log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) - - if self._cfg.normalize_prob_of_sampled_actions: - # normalize the prob of sampled actions - prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( - -1 - ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() - # the above line is equal to the following line. - # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) - log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) - - # NOTE: the +=. - if self._cfg.policy_loss_type == 'KL': - # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) - policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] - elif self._cfg.policy_loss_type == 'cross_entropy': - # cross_entropy loss: - sum(p * log (q) ) - policy_loss += -torch.sum( - torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 - ) * mask_batch[:, unroll_step] - - return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - - def _calculate_policy_loss_disc( - self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor, - mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int - ) -> Tuple[torch.Tensor]: - """ - Overview: - Calculate the policy loss for discrete action space. - Arguments: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. - - policy_logits (:obj:`torch.Tensor`): The policy logits tensor. - - target_policy (:obj:`torch.Tensor`): The target policy tensor. - - mask_batch (:obj:`torch.Tensor`): The mask tensor. - - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor. - - unroll_step (:obj:`int`): The unroll step. - Returns: - - policy_loss (:obj:`torch.Tensor`): The policy loss tensor. 
- - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor. - - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor. - - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor. - - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor. - """ - prob = torch.softmax(policy_logits, dim=-1) - dist = Categorical(prob) - - # take the init hypothetical step k=unroll_step - target_normalized_visit_count = target_policy[:, unroll_step] - - # Note: The target_policy_entropy is just for debugging. - target_normalized_visit_count_masked = torch.index_select( - target_normalized_visit_count, 0, - torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) - ) - target_dist = Categorical(target_normalized_visit_count_masked) - target_policy_entropy = target_dist.entropy().mean() - - # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size, - # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2) - target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1) - - policy_entropy = dist.entropy().mean() - policy_entropy_loss = -dist.entropy() - - # Project the sampled-based improved policy back onto the space of representable policies. calculate KL - # loss (batch_size, num_of_sampled_actions) -> (4,20) target_normalized_visit_count is - # categorical distribution, the range of target_log_prob_sampled_actions is (-inf, 0), add 1e-6 for - # numerical stability. - target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) - - log_prob_sampled_actions = [] - for k in range(self._cfg.model.num_of_sampled_actions): - # target_sampled_actions[:,i,:] shape: (batch_size, action_dim) e.g. (4,2) - # dist.log_prob(target_sampled_actions[:,i,:]) shape: batch_size e.g. 4 - # dist is normal distribution, the range of log_prob_sampled_actions is (-inf, inf) - - if len(target_sampled_actions.shape) == 2: - target_sampled_actions = target_sampled_actions.unsqueeze(-1) - - log_prob = torch.log(prob.gather(-1, target_sampled_actions[:, k].long()).squeeze(-1) + 1e-6) - log_prob_sampled_actions.append(log_prob) - - # (batch_size, num_of_sampled_actions) e.g. (4,20) - log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1) - - if self._cfg.normalize_prob_of_sampled_actions: - # normalize the prob of sampled actions - prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum( - -1 - ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach() - # the above line is equal to the following line. - # prob_sampled_actions_norm = F.normalize(torch.exp(log_prob_sampled_actions), p=1., dim=-1, eps=1e-6) - log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6) - - # NOTE: the +=. 
- if self._cfg.policy_loss_type == 'KL': - # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) - policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] - elif self._cfg.policy_loss_type == 'cross_entropy': - # cross_entropy loss: - sum(p * log (q) ) - policy_loss += -torch.sum( - torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 - ) * mask_batch[:, unroll_step] - - return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions - def _init_collect(self) -> None: """ Overview: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ - self._collect_model = self._model - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self._collect_mcts_temperature = 1 + super()._init_collect() self._multi_agent = self._cfg.model.multi_agent def _forward_collect( @@ -878,13 +424,13 @@ def _forward_collect( # TODO(rjy): written in recursive form for k, v in data.items(): if not isinstance(v, dict): - active_collect_env_num = v.shape[0]*v.shape[1] + active_collect_env_num = v.shape[0] * v.shape[1] agent_num = v.shape[1] # multi-agent elif isinstance(data, torch.Tensor): # If data is a torch.tensor, directly return its shape[0] - active_collect_env_num = data.shape[0] - agent_num = 1 # single-agent - + active_collect_env_num = data.shape[0] + agent_num = 1 # single-agent + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -994,7 +540,7 @@ def _forward_collect( output[env_id]['predicted_value'].append(pred_values[index]) output[env_id]['predicted_policy_logits'].append(policy_logits[index]) - for k,v in output[env_id].items(): + for k, v in output[env_id].items(): output[env_id][k] = np.array(v) return output @@ -1004,11 +550,7 @@ def _init_eval(self) -> None: Overview: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) + super()._init_eval() self._multi_agent = self._cfg.model.multi_agent def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): @@ -1038,12 +580,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read # If data is a dictionary, find the first non-dictionary element and get its shape[0] for k, v in data.items(): if not isinstance(v, dict): - active_eval_env_num = v.shape[0]*v.shape[1] + active_eval_env_num = v.shape[0] * v.shape[1] agent_num = v.shape[1] # multi-agent elif isinstance(data, torch.Tensor): # If data is a torch.tensor, directly return its shape[0] - active_eval_env_num = data.shape[0] - agent_num = 1 # single-agent + active_eval_env_num = data.shape[0] + agent_num = 1 # single-agent with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._eval_model.initial_inference(data) @@ -1154,112 +696,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read output[env_id]['predicted_value'].append(pred_values[index]) output[env_id]['predicted_policy_logits'].append(policy_logits[index]) - for k,v in output[env_id].items(): + for k, v in output[env_id].items(): output[env_id][k] = np.array(v) - return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - """ - if self._cfg.model.continuous_action_space: - return [ - 'collect_mcts_temperature', - 'cur_lr', - 'total_loss', - 'weighted_total_loss', - 'policy_loss', - 'value_prefix_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_value_prefix', - 'target_value', - 'predicted_value_prefixs', - 'predicted_values', - 'transformed_target_value_prefix', - 'transformed_target_value', - - # ============================================================== - # sampled related core code - # ============================================================== - 'policy_entropy', - 'target_policy_entropy', - 'policy_mu_max', - 'policy_mu_min', - 'policy_mu_mean', - 'policy_sigma_max', - 'policy_sigma_min', - 'policy_sigma_mean', - # take the fist dim in action space - 'target_sampled_actions_max', - 'target_sampled_actions_min', - 'target_sampled_actions_mean', - 'total_grad_norm_before_clip', - ] - else: - return [ - 'collect_mcts_temperature', - 'cur_lr', - 'total_loss', - 'weighted_total_loss', - 'loss_mean', - 'policy_loss', - 'value_prefix_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_value_prefix', - 'target_value', - 'predicted_value_prefixs', - 'predicted_values', - 'transformed_target_value_prefix', - 'transformed_target_value', - - # ============================================================== - # sampled related core code - # ============================================================== - 'policy_entropy', - 'target_policy_entropy', - - # take the fist dim in action space - 'target_sampled_actions_max', - 'target_sampled_actions_min', - 'target_sampled_actions_mean', - 'total_grad_norm_before_clip', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. 
- """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass + return output \ No newline at end of file diff --git a/lzero/policy/sampled_efficientzero_ma.py b/lzero/policy/sampled_efficientzero_ma.py new file mode 100644 index 000000000..f872c9d32 --- /dev/null +++ b/lzero/policy/sampled_efficientzero_ma.py @@ -0,0 +1,702 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray +from ding.utils.data import default_collate, default_decollate +from ding.utils import POLICY_REGISTRY +from ditk import logging +from torch.distributions import Categorical, Independent, Normal +from torch.nn import L1Loss + +from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs, \ + configure_optimizers +from lzero.policy.muzero import MuZeroPolicy +from lzero.policy.sampled_efficientzero import SampledEfficientZeroPolicy + + +@POLICY_REGISTRY.register('sampled_efficientzero_ma') +class SampledEfficientZeroMAPolicy(SampledEfficientZeroPolicy): + """ + Overview: + The policy class for Sampled EfficientZero proposed in the paper https://arxiv.org/abs/2104.06303. + """ + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for multi-agent. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + """ + # if self._cfg.model.model_type == "conv": + # return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] + if self._cfg.model.model_type == "mlp": + return 'SampledEfficientZeroModelMLPMaIndependent', ['lzero.model.sampled_efficientzero_model_mlp_ma_independent'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
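+            For this multi-agent variant the setup is delegated to ``SampledEfficientZeroPolicy._init_learn``;
+            in addition, the ``multi_agent`` flag from the model config is cached as ``self._multi_agent``.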
+ """ + super()._init_learn() + self._multi_agent = self._cfg.model.multi_agent + + def _prepocess_data(self, data_list): + def get_depth(lst): + if not isinstance(lst, list): + return 0 + return 1 + get_depth(lst[0]) + for i in range(len(data_list)): + depth = get_depth(data_list[i]) + if depth != 0: + for _ in range(depth): + data_list[i] = default_collate(data_list[i]) + data_list[i] = to_dtype(to_device(data_list[i], self._cfg.device), torch.float) + data_list[i] = data_list[i].permute(*range(depth-1, -1, -1), *range(depth, data_list[i].dim())) + else: + data_list[i] = to_dtype(to_device(to_tensor(data_list[i]), self._cfg.device), torch.float) + return data_list + + def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch = data + # ============================================================== + # sampled related core code + # ============================================================== + obs_batch_ori, action_batch, child_sampled_actions_batch, mask_batch, indices, weights, make_time = current_batch + target_value_prefix, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .float(), in continuous action space. + if self._cfg.model.multi_agent: + data_list = [action_batch, mask_batch, target_value_prefix, target_value, target_policy, weights, child_sampled_actions_batch] + [action_batch, mask_batch, target_value_prefix, target_value, + target_policy, weights, child_sampled_actions_batch] = self._prepocess_data(data_list) + else: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1) + data_list = [ + mask_batch, + target_value_prefix.astype('float32'), + target_value.astype('float32'), target_policy, weights + ] + [mask_batch, target_value_prefix, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + # ============================================================== + # sampled related core code + # ============================================================== + # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. 
(4, 6, 5, 1, 1) + child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1) + target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_value_prefix = scalar_transform(target_value_prefix) + transformed_target_value = scalar_transform(target_value) + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in SampledEfficientZero policy. + # ============================================================== + network_output = self._learn_model.initial_inference(obs_batch) + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for logging. + predicted_value_prefixs = [] + if self._cfg.monitor_extra_statistics: + latent_state_list = latent_state.detach().cpu().numpy() + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. + if self._cfg.use_priority: + value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + else: + value_priority = np.ones(self._cfg.model.agent_num*self._cfg.batch_size) + + # ============================================================== + # calculate policy and value loss for the first step. 
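+        # (In the multi-agent branch below, targets shaped (B, unroll_step, agent_num, ...) are flattened to
+        # (B * agent_num, unroll_step, ...), so that all per-agent losses share a single batch dimension.)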
+ # ============================================================== + if self._multi_agent: + # (B, unroll_step, agent_num, 601) -> (B*agent_num, unroll_step, 601) + target_value_categorical = target_value_categorical.transpose(1, 2) + target_value_categorical = target_value_categorical.reshape((-1, *target_value_categorical.shape[2:])) + # (B, unroll_step, agent_num, action_dim) -> (B*agent_num, unroll_step, action_dim) + action_batch = action_batch.transpose(1,2) + action_batch = action_batch.reshape((-1, *action_batch.shape[2:])) + + target_value_prefix_categorical = torch.repeat_interleave(target_value_prefix_categorical, repeats=self._cfg.model.agent_num, dim=0) + + weights = torch.repeat_interleave(weights, repeats=self._cfg.model.agent_num) + # value shape (B*agent_num, 601) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + policy_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + else: + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + # ============================================================== + # sampled related core code: calculate policy loss, typically cross_entropy_loss + # ============================================================== + if self._cfg.model.continuous_action_space: + """continuous action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0 + ) + else: + """discrete action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc( + policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0 + ) + + if self._multi_agent: + value_prefix_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size*self._cfg.model.agent_num, device=self._cfg.device) + else: + value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + # ============================================================== + # the core recurrent_inference in SampledEfficientZero policy. + # ============================================================== + for step_k in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, + # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference( + latent_state, reward_hidden_state, action_batch[:, step_k] + ) + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( + network_output + ) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + if self._cfg.model.self_supervised_learning_loss: + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. 
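+                # (Only computed when ``ssl_loss_weight > 0``; the negative cosine similarity between the
+                # dynamics and observation projections is masked by ``mask_batch`` for out-of-trajectory steps.)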
+ # ============================================================== + if self._cfg.ssl_loss_weight > 0: + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch. + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_value_prefix_categorical is calculated in + # game buffer now. + # ============================================================== + # sampled related core code: + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the += in policy loss. + # ============================================================== + if self._cfg.model.continuous_action_space: + """continuous action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + policy_loss, + policy_logits, + target_policy, + mask_batch, + child_sampled_actions_batch, + unroll_step=step_k + 1 + ) + else: + """discrete action space""" + policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc( + policy_loss, + policy_logits, + target_policy, + mask_batch, + child_sampled_actions_batch, + unroll_step=step_k + 1 + ) + + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k]) + + # reset hidden states every ``lstm_horizon_len`` unroll steps. + if (step_k + 1) % self._cfg.lstm_horizon_len == 0: + if self._multi_agent: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size*self._cfg.model.agent_num, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size*self._cfg.model.agent_num, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + ) + else: + reward_hidden_state = ( + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), + torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) + ) + + if self._cfg.monitor_extra_statistics: + original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix) + original_value_prefixs_cpu = original_value_prefixs.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_value_prefixs.append(original_value_prefixs_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) 
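+        # NOTE: in the multi-agent branch the per-sample losses and ``weights`` both have length
+        # batch_size * agent_num (``weights`` was repeat_interleave'd above), so the element-wise
+        # weighting below stays aligned per agent.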
+ loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + + self._cfg.policy_entropy_loss_weight * policy_entropy_loss + ) + weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.cos_lr_scheduler or self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. + # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + if self._cfg.monitor_extra_statistics: + predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) + predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) + + return_data = { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': value_prefix_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + + # ============================================================== + # priority related + # ============================================================== + 'value_priority': value_priority.flatten().mean().item(), + 'value_priority_orig': value_priority, + 'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + 'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(), + 'predicted_values': predicted_values.detach().cpu().numpy().mean().item() + } + + if self._cfg.model.continuous_action_space: + return_data.update({ + # ============================================================== + # sampled related core code + # ============================================================== + 'policy_mu_max': mu[:, 0].max().item(), + 'policy_mu_min': mu[:, 0].min().item(), + 'policy_mu_mean': mu[:, 0].mean().item(), + 'policy_sigma_max': sigma.max().item(), + 'policy_sigma_min': sigma.min().item(), + 'policy_sigma_mean': sigma.mean().item(), + # take the first dim in action space + 'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() + }) + else: + return_data.update({ + # 
============================================================== + # sampled related core code + # ============================================================== + # take the first dim in action space + 'target_sampled_actions_max': target_sampled_actions[:, :].float().max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() + }) + + return return_data + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + super()._init_collect() + self._multi_agent = self._cfg.model.multi_agent + + def _forward_collect( + self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, + epsilon: float = 0.25, ready_env_id: np.array = None, + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + if isinstance(data, dict): + # If data is a dictionary, find the first non-dictionary element and get its shape[0] + # TODO(rjy): written in recursive form + for k, v in data.items(): + if not isinstance(v, dict): + active_collect_env_num = v.shape[0] * v.shape[1] + agent_num = v.shape[1] # multi-agent + elif isinstance(data, torch.Tensor): + # If data is a torch.tensor, directly return its shape[0] + active_collect_env_num = data.shape[0] + agent_num = 1 # single-agent + + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num) + ] + else: + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) + ] + + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots( + active_collect_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + # python mcts_tree + roots = MCTSPtree.roots( + active_collect_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + roots_sampled_actions = roots.get_sampled_actions() # {list: 1}->{list:6} + + if self._multi_agent: + active_collect_env_num = active_collect_env_num // agent_num + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i, env_id in enumerate(ready_env_id): + output[env_id] = { + 'action': [], + 'visit_count_distributions': [], + 'root_sampled_actions': [], + 'visit_count_distribution_entropy': [], + 'searched_value': [], + 'predicted_value': [], + 'predicted_policy_logits': [], + } + for j in range(agent_num): + index = i * agent_num + j + distributions, value = roots_visit_count_distributions[index], roots_values[index] + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. + root_sampled_actions = np.array([action for action in roots_sampled_actions[index]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[index]]) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. 
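+                    # ``select_action`` draws an index from the visit-count distribution, softened by the
+                    # current collect temperature (deterministic=False); that index is then mapped back to
+                    # the corresponding entry of ``roots_sampled_actions`` to recover the concrete action
+                    # executed in the environment.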
+ action, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. + action = np.array(roots_sampled_actions[index][action]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + action = roots_sampled_actions[index][action].value + + if not self._cfg.model.continuous_action_space: + if len(action.shape) == 0: + action = int(action) + elif len(action.shape) == 1: + action = int(action[0]) + + output[env_id]['action'].append(action) + output[env_id]['visit_count_distributions'].append(distributions) + output[env_id]['root_sampled_actions'].append(root_sampled_actions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['searched_value'].append(value) + output[env_id]['predicted_value'].append(pred_values[index]) + output[env_id]['predicted_policy_logits'].append(policy_logits[index]) + + for k, v in output[env_id].items(): + output[env_id][k] = np.array(v) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + super()._init_eval() + self._multi_agent = self._cfg.model.multi_agent + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + if isinstance(data, dict): + # If data is a dictionary, find the first non-dictionary element and get its shape[0] + for k, v in data.items(): + if not isinstance(v, dict): + active_eval_env_num = v.shape[0] * v.shape[1] + agent_num = v.shape[1] # multi-agent + elif isinstance(data, torch.Tensor): + # If data is a torch.tensor, directly return its shape[0] + active_eval_env_num = data.shape[0] + agent_num = 1 # single-agent + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_eval_env_num) + ] + else: + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num) + ] + + # cpp mcts_tree + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + active_eval_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + # python mcts_tree + roots = MCTSPtree.roots( + active_eval_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + # ============================================================== + # sampled related core code + # ============================================================== + roots_sampled_actions = roots.get_sampled_actions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + + if self._multi_agent: + active_eval_env_num = active_eval_env_num // agent_num + data_id = [i for i in range(active_eval_env_num)] + output = {i: None for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + + for i, env_id in enumerate(ready_env_id): + output[env_id] = { + 'action': [], + 'visit_count_distributions': [], + 'root_sampled_actions': [], + 'visit_count_distribution_entropy': [], + 'searched_value': [], + 'predicted_value': [], + 'predicted_policy_logits': [], + } + for j in range(agent_num): + index = i * agent_num + j + distributions, value = roots_visit_count_distributions[index], roots_values[index] + try: + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[index]]) + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + root_sampled_actions = np.array([action for action in roots_sampled_actions[index]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. 
+ # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than sampling during the evaluation phase. + action, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # ============================================================== + # sampled related core code + # ============================================================== + + try: + action = roots_sampled_actions[index][action].value + # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') + except Exception: + # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + action = np.array(roots_sampled_actions[index][action]) + + if not self._cfg.model.continuous_action_space: + if len(action.shape) == 0: + action = int(action) + elif len(action.shape) == 1: + action = int(action[0]) + + output[env_id]['action'].append(action) + output[env_id]['visit_count_distributions'].append(distributions) + output[env_id]['root_sampled_actions'].append(root_sampled_actions) + output[env_id]['visit_count_distribution_entropy'].append(visit_count_distribution_entropy) + output[env_id]['searched_value'].append(value) + output[env_id]['predicted_value'].append(pred_values[index]) + output[env_id]['predicted_policy_logits'].append(policy_logits[index]) + + for k, v in output[env_id].items(): + output[env_id][k] = np.array(v) + + return output \ No newline at end of file diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index aa20ce1de..c364195df 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -118,7 +118,7 @@ def visit_count_temperature( return fixed_temperature_value -def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.Tensor: +def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor, multi_agent: bool = False) -> torch.Tensor: """ Overview: We then apply a transformation ``phi`` to the scalar in order to obtain equivalent categorical representations. 
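The hunk below adjusts this projection for the multi-agent case. As a point of reference, here is a minimal sketch of the two-hot projection that ``phi_transform`` performs, assuming a support of consecutive integers in [-2, 2] (so set_size = 5) and delta = 1; the function name and defaults are illustrative, not the library's API:

    import torch

    def two_hot(x: torch.Tensor, support_min: int = -2, support_max: int = 2) -> torch.Tensor:
        # Project each scalar onto its two neighbouring integer atoms of the support,
        # weighting them by the fractional part (delta = 1 is assumed here).
        set_size = support_max - support_min + 1
        x = x.clamp(support_min, support_max)
        x_low, x_high = x.floor(), x.ceil()
        p_high = x - x_low          # weight on the upper atom
        p_low = 1 - p_high          # weight on the lower atom
        target = torch.zeros(*x.shape, set_size)
        high_idx = (x_high - support_min).long().unsqueeze(-1)
        low_idx = (x_low - support_min).long().unsqueeze(-1)
        # Scatter the high weight first so that, for integer inputs (where p_high == 0
        # and both indices coincide), the later scatter of p_low == 1 wins.
        target.scatter_(-1, high_idx, p_high.unsqueeze(-1))
        target.scatter_(-1, low_idx, p_low.unsqueeze(-1))
        return target

    # two_hot(torch.tensor([0.3, 2.0])) -> [[0.0, 0.0, 0.7, 0.3, 0.0],
    #                                       [0.0, 0.0, 0.0, 0.0, 1.0]]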
@@ -137,10 +137,16 @@ def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.T p_high = x - x_low p_low = 1 - p_high - target = torch.zeros(*x.shape, set_size).to(x.device) - x_high_idx, x_low_idx = x_high - min / delta, x_low - min / delta - target.scatter_(target.dim()-1, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) - target.scatter_(target.dim()-1, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) + if multi_agent: + target = torch.zeros(*x.shape, set_size).to(x.device) + dim = target.dim() - 1 + else: + target = torch.zeros(x.shape[0], x.shape[1], set_size).to(x.device) + dim = 2 + + x_high_idx, x_low_idx = (x_high - min) / delta, (x_low - min) / delta + target.scatter_(dim, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) + target.scatter_(dim, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) return target diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 760d12cce..94ddb4715 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -312,23 +312,21 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, obs_shape: 4 4 4 4 4 4 ----, ----, ----, ----, ----, ----, """ - # obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() - obs_batch_ori = to_dtype(to_device(to_tensor(obs_batch_ori), cfg.device), torch.float) - # ``obs_batch`` is used in ``initial_inference()``, which is the first stacked obs at timestep t1 in - # ``obs_batch_ori``. shape is (4, 4*3) = (4, 12) if cfg.model.multi_agent: + # ``obs_batch`` is used in ``initial_inference()``, which is the first stacked obs at timestep t1 in + # ``obs_batch_ori``. shape is (4, 4*3) = (4, 12) + obs_batch_ori = to_dtype(to_device(to_tensor(obs_batch_ori), cfg.device), torch.float) obs_batch_ori = default_collate(obs_batch_ori) obs_batch = obs_batch_ori[0] + if cfg.model.self_supervised_learning_loss: + # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than + # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. + obs_target_batch = obs_batch_ori[1:] + obs_target_batch = default_collate(obs_target_batch) # {'agent_state': (num_unroll_steps, batch_size, agent_num, obs_shape) else: + obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() obs_batch = obs_batch_ori[:, 0:cfg.model.frame_stack_num * cfg.model.observation_shape] - - if cfg.model.self_supervised_learning_loss: - # ``obs_target_batch`` is only used for calculate consistency loss, which take the all obs other than - # timestep t1, and is only performed in the last 8 timesteps in the second dim in ``obs_batch_ori``. 
- if cfg.model.multi_agent: - obs_target_batch = obs_batch_ori[1:] - obs_target_batch = default_collate(obs_target_batch) # {'agent_state': (num_unroll_steps, batch_size, agent_num, obs_shape) - else: + if cfg.model.self_supervised_learning_loss: obs_target_batch = obs_batch_ori[:, cfg.model.observation_shape:] return obs_batch, obs_target_batch diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 2053e63ef..d7909de5e 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -398,12 +398,17 @@ def collect(self, ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - stack_obs = default_collate(stack_obs) - if not isinstance(stack_obs, dict): + if self._multi_agent: + stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + stack_obs = default_collate(stack_obs) + stack_obs = to_device(stack_obs, self.policy_config.device) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = to_device(stack_obs, self.policy_config.device) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} @@ -646,15 +651,21 @@ def collect(self, last_game_priorities[env_id] = None # log - # self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] + if self._multi_agent: + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + self_play_moves += eps_steps_lst[env_id].sum() + eps_steps_lst[env_id] = np.zeros(self._agent_num) + visit_entropies_lst[env_id] = np.zeros(self._agent_num) + else: + self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + self_play_moves += eps_steps_lst[env_id] + eps_steps_lst[env_id] = 0 + visit_entropies_lst[env_id] = 0 self_play_episodes += 1 pred_values_lst[env_id] = [] search_values_lst[env_id] = [] - eps_steps_lst[env_id] = np.zeros(self._agent_num) - visit_entropies_lst[env_id] = np.zeros(self._agent_num) # Env reset is done by env_manager automatically self._policy.reset([env_id]) @@ -673,11 +684,6 @@ def collect(self, } for i in range(len(self.game_segment_pool)) ] self.game_segment_pool.clear() - # for i in range(len(self.game_segment_pool)): - # print(self.game_segment_pool[i][0].obs_segment.__len__()) - # print(self.game_segment_pool[i][0].reward_segment) - # for i in range(len(return_data[0])): - # print(return_data[0][i].reward_segment) break collected_duration = sum([d['time'] for d in self._episode_info]) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index ccb4997ac..3adb66cc4 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -273,12 +273,17 @@ def eval( ready_env_id = 
ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - stack_obs = default_collate(stack_obs) - if not isinstance(stack_obs, dict): + if self._multi_agent: + stack_obs = {env_id: game_segments[env_id].get_obs()[0] for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + stack_obs = default_collate(stack_obs) + stack_obs = to_device(stack_obs, self.policy_config.device) + else: + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = to_device(stack_obs, self.policy_config.device) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} diff --git a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py index 058f0d0cd..b131863f0 100644 --- a/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py +++ b/zoo/multiagent_mujoco/config/multiagent_mujoco_sampled_efficientzero_config.py @@ -119,8 +119,8 @@ ), env_manager=dict(type='base'), policy=dict( - type='sampled_efficientzero', - import_names=['lzero.policy.sampled_efficientzero'], + type='sampled_efficientzero_ma', + import_names=['lzero.policy.sampled_efficientzero_ma'], ), ) mujoco_sampled_efficientzero_create_config = EasyDict(mujoco_sampled_efficientzero_create_config)
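For context, here is a hedged sketch of how a config wired to the new ``sampled_efficientzero_ma`` policy would typically be launched. It assumes LightZero's usual ``train_muzero`` entry point and that the config module exposes ``main_config``/``create_config`` aliases, as the other zoo configs do; the script name and those aliases are assumptions, not part of this patch:

    # launch_mamujoco_sez.py (illustrative only)
    from lzero.entry import train_muzero
    # ``main_config``/``create_config`` are assumed aliases of the EasyDict objects
    # defined in the config module shown above.
    from zoo.multiagent_mujoco.config.multiagent_mujoco_sampled_efficientzero_config import (
        main_config,
        create_config,
    )

    if __name__ == "__main__":
        # ``create_config.policy.type`` is 'sampled_efficientzero_ma', so the policy registry
        # imports ``lzero.policy.sampled_efficientzero_ma`` before instantiating the policy.
        train_muzero([main_config, create_config], seed=0, max_env_step=int(1e6))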