polish(pu): add moco multigpu support
puyuan committed Feb 7, 2025
1 parent 8fe1a6d commit 97988e2
Showing 2 changed files with 60 additions and 31 deletions.
36 changes: 31 additions & 5 deletions lzero/policy/sampled_unizero_multitask.py
@@ -27,6 +27,8 @@
)
from lzero.policy.unizero import UniZeroPolicy
from .utils import configure_optimizers_nanogpt
import torch.nn.functional as F
import torch.distributed as dist
import sys
sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/')
from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
@@ -64,7 +66,7 @@ def parameters(self):
list(self.tokenizer.parameters()) +
list(self.transformer.parameters()) +
list(self.pos_emb.parameters()) +
list(self.task_emb.parameters()) +
# list(self.task_emb.parameters()) +
list(self.act_embedding_table.parameters())
)

@@ -73,7 +75,7 @@ def zero_grad(self, set_to_none=False):
self.tokenizer.zero_grad(set_to_none=set_to_none)
self.transformer.zero_grad(set_to_none=set_to_none)
self.pos_emb.zero_grad(set_to_none=set_to_none)
self.task_emb.zero_grad(set_to_none=set_to_none)
# self.task_emb.zero_grad(set_to_none=set_to_none)
self.act_embedding_table.zero_grad(set_to_none=set_to_none)


@@ -308,7 +310,8 @@ def _init_learn(self) -> None:
# TODO
# If needed, gradient-correction methods (e.g. MoCo, CAGrad) can be initialized here.
# self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device) # only compatible with 1-GPU training
# self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device, self._cfg.multi_gpu) # only compatible with 1-GPU training
self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) # now also supports multi-GPU training via the multi_gpu flag

self.grad_correct.init_param()
self.grad_correct.rep_grad = False
@@ -463,10 +466,33 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
self._optimizer_world_model.zero_grad()

if self._cfg.use_moco:
# Gradient-correction methods such as MoCo or CAGrad can be integrated here; with 1 GPU, the gradients of all tasks must be known.
# If distributed training is initialized with multiple GPUs, only rank 0 gathers the loss_list from the other GPUs.
if dist.is_initialized() and dist.get_world_size() > 1:
rank = dist.get_rank()
world_size = dist.get_world_size()
# Use distributed gather_object: only rank 0 provides a receive buffer.
if rank == 0:
gathered_losses = [None for _ in range(world_size)]
else:
gathered_losses = None # other ranks do not need a receive buffer
# gather_object requires all ranks to participate: each rank sends its own losses_list and rank 0 receives them.
dist.gather_object(losses_list, gathered_losses, dst=0)
if rank == 0:
# Flatten the losses_list from every GPU into a single global losses_list.
all_losses_list = []
for loss_list_tmp in gathered_losses:
all_losses_list.extend(loss_list_tmp)
losses_list = all_losses_list
else:
# Set to None on non-rank-0 processes to prevent accidental use.
losses_list = None

# Call the MoCo backward: the backward implemented in grad_correct performs the gradient correction.
# Note: moco.backward checks whether the current rank is 0; only rank 0 computes gradients from losses_list,
# while the other ranks simply wait for the corrected gradients to be broadcast and shared.
lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
else:
# Without gradient correction
# Without gradient correction, each rank performs its own backward pass.
lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
weighted_total_loss.backward()

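For orientation, the gather logic added above can be read as the following minimal standalone sketch, assuming losses_list holds one loss tensor per task on each rank; the helper name gather_task_losses is illustrative and not part of this repository.

import torch.distributed as dist

def gather_task_losses(losses_list):
    # Collect every rank's per-task losses onto rank 0; other ranks return None.
    if not (dist.is_initialized() and dist.get_world_size() > 1):
        return losses_list  # single-process case: nothing to gather
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gathered_losses = [None for _ in range(world_size)] if rank == 0 else None
    # gather_object is collective: every rank sends, only dst=0 receives.
    dist.gather_object(losses_list, gathered_losses, dst=0)
    if rank != 0:
        return None
    all_losses_list = []
    for per_rank_losses in gathered_losses:
        all_losses_list.extend(per_rank_losses)
    return all_losses_list

On rank 0 the flattened list is then suitable for grad_correct.backward(...), while the other ranks only take part in the subsequent broadcast of the corrected gradients, as the comments in the diff describe.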
@@ -17,8 +17,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
action_space_size_list=action_space_size_list,
from_pixels=False,
# ===== only for debug =====
# frame_skip=100, # 100
frame_skip=2,
frame_skip=50, # 100
# frame_skip=2,
continuous=True, # Assuming all DMC tasks use continuous action spaces
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
@@ -38,6 +38,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
calpha=0.5, rescale=1,
),
use_moco=True, # ==============TODO==============
total_task_num=len(env_id_list),
task_num=len(env_id_list),
task_id=0, # To be set per task
model=dict(
@@ -54,6 +55,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
# use_shared_projection=True, # TODO
use_shared_projection=False,
# use_task_embed=True, # TODO
task_embed_option=None, # ==============TODO: none ==============
use_task_embed=False, # ==============TODO==============
num_unroll_steps=num_unroll_steps,
policy_entropy_weight=5e-2,
@@ -90,6 +92,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
num_experts_of_moe_in_transformer=4,
),
),
use_task_exploitation_weight=False, # TODO
# task_complexity_weight=True, # TODO
task_complexity_weight=False, # TODO
total_batch_size=total_batch_size,
@@ -153,7 +156,7 @@ def generate_configs(env_id_list: List[str],
# TODO: debug
# exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

exp_name_prefix = f'data_suz_mt_20250113/ddp_1gpu-moco_nlayer8_upc80_notaskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
exp_name_prefix = f'data_suz_mt_20250207_debug/ddp_2gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

# exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

@@ -205,7 +208,7 @@ def create_env_manager():
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
"""

@@ -236,16 +239,16 @@ def create_env_manager():
# ]

# DMC 8games
env_id_list = [
'acrobot-swingup',
'cartpole-balance',
'cartpole-balance_sparse',
'cartpole-swingup',
'cartpole-swingup_sparse',
'cheetah-run',
"ball_in_cup-catch",
"finger-spin",
]
# env_id_list = [
# 'acrobot-swingup',
# 'cartpole-balance',
# 'cartpole-balance_sparse',
# 'cartpole-swingup',
# 'cartpole-swingup_sparse',
# 'cheetah-run',
# "ball_in_cup-catch",
# "finger-spin",
# ]

# DMC 18games
# env_id_list = [
@@ -278,18 +281,18 @@ def create_env_manager():
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
# max_env_step = int(5e5)
max_env_step = int(1e6)
max_env_step = int(5e5)
# max_env_step = int(1e6)

reanalyze_ratio = 0.0

# nlayer=4
# nlayer=4/8
total_batch_size = 512
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

# nlayer=8/12
total_batch_size = 256
batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
# # nlayer=12
# total_batch_size = 256
# batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

num_unroll_steps = 5
infer_context_length = 2
@@ -299,12 +302,12 @@
reanalyze_partition = 0.75

# ======== TODO: only for debug ========
# collector_env_num = 2
# num_segments = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 1
# batch_size = [4 for _ in range(len(env_id_list))]
collector_env_num = 2
num_segments = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 1
batch_size = [4 for _ in range(len(env_id_list))]
# =======================================

seed = 0 # You can iterate over multiple seeds if needed
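As a closing note, a small sketch of how the per-task batch sizes in this config are derived from total_batch_size; the helper name split_batch_size and the example numbers are illustrative, taken from the nlayer=8/12 setting above.

def split_batch_size(total_batch_size, num_tasks, per_task_cap):
    # Each task gets an equal share of the total batch, clipped to a per-task cap,
    # mirroring the list comprehension used in the config above.
    per_task = int(min(per_task_cap, total_batch_size / num_tasks))
    return [per_task for _ in range(num_tasks)]

# With total_batch_size=256, 8 DMC tasks and a cap of 32, every task gets a batch of 32.
print(split_batch_size(256, 8, 32))  # -> [32, 32, 32, 32, 32, 32, 32, 32]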