From 97988e22780751d0d264ededb8aafa9c390e1ccf Mon Sep 17 00:00:00 2001
From: puyuan
Date: Fri, 7 Feb 2025 21:20:26 +0800
Subject: [PATCH] polish(pu): add MoCo multi-GPU support

---
 lzero/policy/sampled_unizero_multitask.py     | 36 ++++++++++--
 ...te_suz_multitask_ddp_8games_moco_config.py | 55 ++++++++++---------
 2 files changed, 60 insertions(+), 31 deletions(-)

diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py
index 54a717199..955c3f09b 100644
--- a/lzero/policy/sampled_unizero_multitask.py
+++ b/lzero/policy/sampled_unizero_multitask.py
@@ -27,6 +27,8 @@
 )
 from lzero.policy.unizero import UniZeroPolicy
 from .utils import configure_optimizers_nanogpt
+import torch.nn.functional as F
+import torch.distributed as dist
 import sys
 sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/')
 from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
@@ -64,7 +66,7 @@ def parameters(self):
             list(self.tokenizer.parameters()) +
             list(self.transformer.parameters()) +
             list(self.pos_emb.parameters()) +
-            list(self.task_emb.parameters()) +
+            # list(self.task_emb.parameters()) +
             list(self.act_embedding_table.parameters())
         )

@@ -73,7 +75,7 @@ def zero_grad(self, set_to_none=False):
         self.tokenizer.zero_grad(set_to_none=set_to_none)
         self.transformer.zero_grad(set_to_none=set_to_none)
         self.pos_emb.zero_grad(set_to_none=set_to_none)
-        self.task_emb.zero_grad(set_to_none=set_to_none)
+        # self.task_emb.zero_grad(set_to_none=set_to_none)
         self.act_embedding_table.zero_grad(set_to_none=set_to_none)


@@ -308,7 +310,8 @@ def _init_learn(self) -> None:
         # TODO
         # If needed, gradient-correction methods (e.g. MoCo, CAGrad) can be initialized here.
         # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
-        self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device)  # only compatible with single-GPU training
+        # self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device, self._cfg.multi_gpu)  # only compatible with single-GPU training
+        self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu)  # pass total_task_num and multi_gpu so MoCo gradient correction also works under multi-GPU (DDP) training
         self.grad_correct.init_param()
         self.grad_correct.rep_grad = False

@@ -463,10 +466,33 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
         self._optimizer_world_model.zero_grad()

         if self._cfg.use_moco:
-            # Gradient-correction methods such as MoCo or CAGrad can be integrated here; a single GPU needs the gradients of all tasks.
+            # In the multi-GPU case (process group initialized), only rank 0 gathers the losses_list from the other GPUs.
+            if dist.is_initialized() and dist.get_world_size() > 1:
+                rank = dist.get_rank()
+                world_size = dist.get_world_size()
+                # Use distributed gather_object: only rank 0 provides a receive buffer.
+                if rank == 0:
+                    gathered_losses = [None for _ in range(world_size)]
+                else:
+                    gathered_losses = None  # non-zero ranks do not receive
+                # gather_object requires all processes to participate: every rank sends its own losses_list and rank 0 receives them.
+                dist.gather_object(losses_list, gathered_losses, dst=0)
+                if rank == 0:
+                    # Flatten the per-GPU losses_list into one global losses_list.
+                    all_losses_list = []
+                    for loss_list_tmp in gathered_losses:
+                        all_losses_list.extend(loss_list_tmp)
+                    losses_list = all_losses_list
+                else:
+                    # Set to None on non-zero ranks to prevent accidental use.
+                    losses_list = None
+
+            # Call the MoCo backward pass; gradient correction is implemented inside grad_correct.backward.
+            # Note: moco.backward checks whether the current rank is 0; only rank 0 computes gradients from losses_list,
+            # while the other ranks wait for the corrected gradients to be broadcast and shared.
             lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
         else:
-            # Without gradient correction
+            # Without gradient correction: each rank runs its own backward pass.
             lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
             weighted_total_loss.backward()

diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
index 74f713ebc..fed7ab967 100644
--- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
@@ -17,8 +17,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
         action_space_size_list=action_space_size_list,
         from_pixels=False,
         # ===== only for debug =====
-        # frame_skip=100, # 100
-        frame_skip=2,
+        frame_skip=50,  # 100
+        # frame_skip=2,
         continuous=True,  # Assuming all DMC tasks use continuous action spaces
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
@@ -38,6 +38,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             calpha=0.5, rescale=1,
         ),
         use_moco=True,  # ==============TODO==============
+        total_task_num=len(env_id_list),
         task_num=len(env_id_list),
         task_id=0,  # To be set per task
         model=dict(
@@ -54,6 +55,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             # use_shared_projection=True,  # TODO
             use_shared_projection=False,
             # use_task_embed=True,  # TODO
+            task_embed_option=None,  # ==============TODO: none ==============
             use_task_embed=False,  # ==============TODO==============
             num_unroll_steps=num_unroll_steps,
             policy_entropy_weight=5e-2,
@@ -90,6 +92,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
                 num_experts_of_moe_in_transformer=4,
             ),
         ),
+        use_task_exploitation_weight=False,  # TODO
         # task_complexity_weight=True,  # TODO
         task_complexity_weight=False,  # TODO
         total_batch_size=total_batch_size,
@@ -153,7 +156,7 @@ def generate_configs(env_id_list: List[str],
     # TODO: debug
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
-    exp_name_prefix = f'data_suz_mt_20250113/ddp_1gpu-moco_nlayer8_upc80_notaskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207_debug/ddp_2gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

     # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

@@ -205,7 +208,7 @@ def create_env_manager():
     Overview:
         This script should be executed with GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
+        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
         torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
     """

@@ -236,16 +239,16 @@ def create_env_manager():
 #     ]

     # DMC 8games
-    env_id_list = [
-        'acrobot-swingup',
-        'cartpole-balance',
-        'cartpole-balance_sparse',
-        'cartpole-swingup',
-        'cartpole-swingup_sparse',
-        'cheetah-run',
-        "ball_in_cup-catch",
-        "finger-spin",
-    ]
+    # env_id_list = [
+    #     'acrobot-swingup',
+    #     'cartpole-balance',
+    #     'cartpole-balance_sparse',
+    #     'cartpole-swingup',
+    #     'cartpole-swingup_sparse',
+    #     'cheetah-run',
+    #     "ball_in_cup-catch",
+    #     "finger-spin",
+    # ]

     # DMC 18games
     # env_id_list = [
@@ -278,18 +281,18 @@ def create_env_manager():
     n_episode = 8
     evaluator_env_num = 3
     num_simulations = 50
-    # max_env_step = int(5e5)
-    max_env_step = int(1e6)
+    max_env_step = int(5e5)
+    # max_env_step = int(1e6)
     reanalyze_ratio = 0.0

-    # nlayer=4
+    # nlayer=4/8
     total_batch_size = 512
     batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

-    # nlayer=8/12
-    total_batch_size = 256
-    batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    # # nlayer=12
+    # total_batch_size = 256
+    # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

     num_unroll_steps = 5
     infer_context_length = 2
@@ -299,12 +302,12 @@ def create_env_manager():
     reanalyze_partition = 0.75

     # ======== TODO: only for debug ========
-    # collector_env_num = 2
-    # num_segments = 2
-    # n_episode = 2
-    # evaluator_env_num = 2
-    # num_simulations = 1
-    # batch_size = [4 for _ in range(len(env_id_list))]
+    collector_env_num = 2
+    num_segments = 2
+    n_episode = 2
+    evaluator_env_num = 2
+    num_simulations = 1
+    batch_size = [4 for _ in range(len(env_id_list))]
     # =======================================

     seed = 0  # You can iterate over multiple seeds if needed
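
Note (not part of the diff): below is a minimal, self-contained sketch of the rank-0 loss-gathering
pattern that the `_forward_learn` hunk above introduces. It assumes a local CPU "gloo" process group
spawned purely for illustration; the function name `demo`, the port 29501, and the placeholder float
losses are hypothetical, while the real policy runs under the DDP job launched via torchrun or
torch.distributed.launch as described in the config docstring.

    import os

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def demo(rank: int, world_size: int) -> None:
        # Local CPU process group, for illustration only.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Stand-in for the per-rank losses_list: each rank owns the losses of its local tasks.
        losses_list = [float(rank * 10 + i) for i in range(2)]

        # Only rank 0 allocates a receive buffer; every rank must still call gather_object.
        gathered_losses = [None for _ in range(world_size)] if rank == 0 else None
        dist.gather_object(losses_list, gathered_losses, dst=0)

        if rank == 0:
            # Flatten the per-rank lists into one global list, as in the patched _forward_learn.
            all_losses_list = []
            for loss_list_tmp in gathered_losses:
                all_losses_list.extend(loss_list_tmp)
            print(f"rank 0 gathered {len(all_losses_list)} task losses: {all_losses_list}")
        else:
            losses_list = None  # non-zero ranks drop their local view, mirroring the patch

        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(demo, args=(2,), nprocs=2)

Running the sketch with two processes prints the four gathered stand-in losses on rank 0, which is
the same shape of data the patched policy hands to grad_correct.backward for MoCo gradient correction.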