From d098e716dc5f0488ccd7eb9c71a52c6c6b92eef8 Mon Sep 17 00:00:00 2001
From: puyuan
Date: Mon, 6 Jan 2025 17:19:31 +0800
Subject: [PATCH] feature(pu): add task_complexity_weight option

---
 .../train_unizero_multitask_segment_ddp.py    | 82 +++++++++++++++++--
 lzero/policy/sampled_unizero_multitask.py     |  7 ++-
 ...m_state_suz_multitask_ddp_8games_config.py | 17 ++--
 3 files changed, 92 insertions(+), 14 deletions(-)

diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py
index 15c199bdc..c9de788da 100644
--- a/lzero/entry/train_unizero_multitask_segment_ddp.py
+++ b/lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -130,6 +130,44 @@ def allocate_batch_size(
     return batch_sizes
 
 
+import numpy as np
+
+def compute_task_weights(task_rewards: dict, epsilon: float = 1e-6,
+                         min_weight: float = 0.05, max_weight: float = 0.5,
+                         temperature: float = 1.0) -> dict:
+    """
+    Compute per-task weights from evaluation rewards, with safeguards against weights that are too small or too large.
+
+    Args:
+        task_rewards (dict): Mapping from task_id to evaluation reward. Rewards should be normalized, or at least the reward scales of different tasks should be comparable.
+        epsilon (float): Small value added to the denominator to avoid division by zero.
+        min_weight (float): Lower bound used when clipping the weights.
+        max_weight (float): Upper bound used when clipping the weights.
+        temperature (float): Temperature controlling the weight distribution; larger values give a more uniform distribution.
+
+    Returns:
+        dict: Mapping from task_id to its normalized and clipped weight.
+    """
+    # Step 1: compute the initial weights (inversely proportional to the reward).
+    # The lower a task's reward, the higher its weight; epsilon avoids division by zero.
+    raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}
+
+    # Step 2: apply temperature scaling to control how uniform the weights are.
+    # Temperature scaling: w_i = (1 / r_i)^(1 / temperature)
+    scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}
+
+    # Step 3: normalize the weights.
+    total_weight = sum(scaled_weights.values())
+    normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}
+
+    # Step 4: clip the weights into the [min_weight, max_weight] range.
+    clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}
+
+    # Step 5: renormalize so that the clipped weights sum to 1.
+    total_clipped_weight = sum(clipped_weights.values())
+    final_weights = {task_id: weight / total_clipped_weight for task_id, weight in clipped_weights.items()}
+
+    return final_weights
 
 def train_unizero_multitask_segment_ddp(
         input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
@@ -286,6 +324,11 @@ def train_unizero_multitask_segment_ddp(
     reanalyze_batch_size = cfg.policy.reanalyze_batch_size
     update_per_collect = cfg.policy.update_per_collect
 
+    task_complexity_weight = cfg.policy.task_complexity_weight
+
+    # Evaluation rewards for the tasks handled on this rank
+    task_rewards = {}  # {task_id: reward}
+
     while True:
         # Dynamically adjust the batch size
         if cfg.policy.allocated_batch_sizes:
@@ -338,8 +381,17 @@ def train_unizero_multitask_segment_ddp(
                 # Check whether evaluation succeeded
                 if stop is None or reward is None:
                     print(f"Rank {rank} ran into a problem during evaluation, continuing training...")
+                    task_rewards[cfg.policy.task_id] = float('inf')  # If evaluation fails, fall back to an infinite reward (the task ends up with the minimum clipped weight)
                 else:
-                    print(f"Evaluation succeeded: stop={stop}, reward={reward}")
+                    # Extract `eval_episode_return_mean` from the evaluation results as the reward value
+                    try:
+                        eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
+                        print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}")
+                        task_rewards[cfg.policy.task_id] = eval_mean_reward
+                    except Exception as e:
+                        print(f"Error while extracting the evaluation reward: {e}")
+                        task_rewards[cfg.policy.task_id] = float('inf')  # Fall back to an infinite reward on failure
+
 
         print('=' * 20)
         print(f'Starting collection for task_id {cfg.policy.task_id} on rank {rank}...')
@@ -376,14 +428,33 @@ def train_unizero_multitask_segment_ddp(
             for replay_buffer in game_buffers
         )
 
-        # Synchronize the readiness of all ranks before training
+
+
+        # Compute the task weights
         try:
+            # Gather the task rewards from all ranks
             dist.barrier()
-            logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
+            if task_complexity_weight:
+                all_task_rewards = [None for _ in range(world_size)]
+                dist.all_gather_object(all_task_rewards, task_rewards)
+                # Merge the task rewards from all ranks
+                merged_task_rewards = {}
+                for rewards in all_task_rewards:
+                    if rewards:
+                        merged_task_rewards.update(rewards)
+                # Compute the global task weights
+                task_weights = compute_task_weights(merged_task_rewards)
+                # Synchronize the task weights (every rank already holds the same gathered rewards)
+                dist.broadcast_object_list([task_weights], src=0)
+                print(f"Rank {rank}, global task weights (keyed by task_id): {task_weights}")
+            else:
+                task_weights = None
+
         except Exception as e:
-            logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
+            logging.error(f'Rank {rank}: failed to synchronize task weights, error: {e}')
             break
 
+
         # Learn the policy
         if not not_enough_data:
             for i in range(update_per_collect):
@@ -414,8 +485,9 @@ def train_unizero_multitask_segment_ddp(
                     break
 
                 if train_data_multi_task:
+                    learn_kwargs = {'task_weights': task_weights}
                     # DDP automatically synchronizes gradients and parameters during training
-                    log_vars = learner.train(train_data_multi_task, envstep_multi_task)
+                    log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)
 
                     if cfg.policy.use_priority:
                         for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)):
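The helper above turns per-task evaluation returns into loss weights: low-reward (hard) tasks get larger weights, the temperature flattens or sharpens the distribution, and the clipping step bounds how extreme any single weight can get before the final renormalization. A minimal standalone sketch of its behaviour, assuming the patched module is importable and using made-up reward values:

    # Hypothetical rewards on a comparable scale; task 2 is the weakest.
    from lzero.entry.train_unizero_multitask_segment_ddp import compute_task_weights

    task_rewards = {0: 0.9, 1: 0.5, 2: 0.1}
    weights = compute_task_weights(task_rewards, temperature=1.0)
    print(weights)                # the hardest task (id 2) receives the largest weight
    print(sum(weights.values()))  # ~1.0 after the final renormalization

    # A larger temperature pushes the weights toward a uniform distribution.
    print(compute_task_weights(task_rewards, temperature=10.0))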
diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py
index 5919bfe69..5a8a342ad 100644
--- a/lzero/policy/sampled_unizero_multitask.py
+++ b/lzero/policy/sampled_unizero_multitask.py
@@ -297,7 +297,7 @@ def _init_learn(self) -> None:
         self.task_id = self._cfg.task_id
         self.task_num_for_current_rank = self._cfg.task_num
 
-    def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
+    def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]:
         """
         Forward function for learning policy in learn mode, handling multiple tasks.
         """
@@ -394,8 +394,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
                 self.inverse_scalar_transform_handle,
                 task_id=task_id
             )
 
-            weighted_total_loss += losses.loss_total
+            if task_weights is not None:
+                weighted_total_loss += losses.loss_total * task_weights[task_id]
+            else:
+                weighted_total_loss += losses.loss_total
 
             assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
             assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
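The policy-side change reduces to the aggregation pattern below: when task_weights is passed in via policy_kwargs, each task's total loss is scaled by its global weight, otherwise the losses are summed unweighted as before. A toy sketch with made-up scalar losses (in the policy the per-task totals are the losses.loss_total values inside _forward_learn):

    import torch

    task_losses = {0: torch.tensor(1.2), 1: torch.tensor(0.4)}  # hypothetical per-task totals
    task_weights = {0: 0.7, 1: 0.3}                             # e.g. output of compute_task_weights

    def aggregate(task_losses, task_weights=None):
        weighted_total_loss = torch.tensor(0.0)
        for task_id, loss_total in task_losses.items():
            if task_weights is not None:
                weighted_total_loss = weighted_total_loss + loss_total * task_weights[task_id]
            else:
                weighted_total_loss = weighted_total_loss + loss_total
        return weighted_total_loss

    print(aggregate(task_losses, task_weights))  # ~tensor(0.9600)
    print(aggregate(task_losses))                # tensor(1.6000), the unweighted fallback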
diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
index 1871fd78c..a1369f30b 100644
--- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
@@ -16,6 +16,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             observation_shape_list=observation_shape_list,
             action_space_size_list=action_space_size_list,
             from_pixels=False,
+            # ===== only for debug =====
+            # frame_skip=10,  # 100
             frame_skip=2,
             continuous=True,  # Assuming all DMC tasks use continuous action spaces
             collector_env_num=collector_env_num,
@@ -48,8 +50,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
                 action_space_size_list=action_space_size_list,
                 policy_loss_type='kl',
                 obs_type='vector',
-                use_shared_projection=True,
-                # use_shared_projection=False,
+                # use_shared_projection=True,  # TODO
+                use_shared_projection=False,
                 num_unroll_steps=num_unroll_steps,
                 policy_entropy_weight=5e-2,
                 continuous_action_space=True,
@@ -85,6 +87,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
                     num_experts_of_moe_in_transformer=4,
                 ),
             ),
+            task_complexity_weight=True,  # TODO
             total_batch_size=total_batch_size,
             allocated_batch_sizes=False,
             # train_start_after_envsteps=int(2e3),
@@ -141,7 +144,7 @@ def generate_configs(env_id_list: List[str],
                      num_segments: int,
                      total_batch_size: int):
     configs = []
-    exp_name_prefix = f'data_suz_mt_20250102/ddp_8gpu_nlayer8_upc80_usp_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250102/ddp_8gpu_nlayer8_upc80_notusp_taskweight_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
     observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list]
@@ -202,10 +205,10 @@ def create_env_manager():
     # os.environ["NCCL_TIMEOUT"] = "3600000000"
 
     # Define the list of environments
-    # env_id_list = [
-    #     'acrobot-swingup',  # 6 1
-    #     'cartpole-swingup',  # 5 1
-    # ]
+    env_id_list = [
+        'acrobot-swingup',  # 6 1
+        'cartpole-swingup',  # 5 1
+    ]
 
     # DMC 8games
     env_id_list = [
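End to end, the new task_complexity_weight config flag gates whether the entry derives global weights from the gathered evaluation rewards or passes None, in which case the loss stays unweighted. A single-process sketch of that control flow, assuming the patched entry module and made-up gathered rewards (no torch.distributed involved):

    from lzero.entry.train_unizero_multitask_segment_ddp import compute_task_weights

    task_complexity_weight = True  # mirrors cfg.policy.task_complexity_weight

    # Stand-in for dist.all_gather_object: one {task_id: reward} dict per rank.
    # Task 3's evaluation failed, so its reward was recorded as inf.
    all_task_rewards = [{0: 0.8, 1: 0.3}, {2: 0.6, 3: float('inf')}]

    if task_complexity_weight:
        merged_task_rewards = {}
        for rewards in all_task_rewards:
            if rewards:
                merged_task_rewards.update(rewards)
        task_weights = compute_task_weights(merged_task_rewards)
    else:
        task_weights = None

    print(task_weights)  # the failed task is pushed down to the weight floor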