feature(pu): add task_complexity_weight option
puyuan committed Jan 6, 2025
1 parent 22fd7c1 commit d098e71
Showing 3 changed files with 91 additions and 14 deletions.
82 changes: 77 additions & 5 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -130,6 +130,44 @@ def allocate_batch_size(

return batch_sizes

import numpy as np

def compute_task_weights(task_rewards: dict, epsilon: float = 1e-6,
                         min_weight: float = 0.05, max_weight: float = 0.5,
                         temperature: float = 1.0) -> dict:
    """
    Compute per-task weights from evaluation rewards, with robustness safeguards
    so that no weight becomes too small or too large.
    Args:
        task_rewards (dict): Mapping from task_id to evaluation reward. Rewards
            should be normalized, or at least the maxima of the different tasks
            should be on a comparable scale.
        epsilon (float): Small value that keeps the denominator away from zero.
        min_weight (float): Lower bound used when clipping the weights.
        max_weight (float): Upper bound used when clipping the weights.
        temperature (float): Temperature controlling how uniform the weight
            distribution is; larger values give a more uniform distribution.
    Returns:
        dict: Mapping from task_id to its normalized and clipped weight.
    """
    # Step 1: initial weights, inversely proportional to the rewards.
    # Lower reward -> higher weight; epsilon keeps the denominator non-zero.
    raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}

    # Step 2: temperature scaling, controlling how uniform the weights are.
    # Scaling rule: w_i = (1 / r_i)^(1 / temperature)
    scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}

    # Step 3: normalize the weights so they sum to 1.
    total_weight = sum(scaled_weights.values())
    normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}

    # Step 4: clip the weights into the [min_weight, max_weight] range.
    clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

    # Step 5: renormalize so that the clipped weights again sum to 1.
    total_clipped_weight = sum(clipped_weights.values())
    final_weights = {task_id: weight / total_clipped_weight for task_id, weight in clipped_weights.items()}

    return final_weights
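# A quick sanity check of the pipeline above (hypothetical, made-up reward
# values; assumes rewards are normalized to a comparable scale across tasks):
#
#     >>> compute_task_weights({0: 0.9, 1: 0.2, 2: 0.5})
#     {0: 0.155, 1: 0.566, 2: 0.279}  # approximately
#
# Task 1 (lowest reward) is clipped to max_weight=0.5 in Step 4, but the
# renormalization in Step 5 can push its weight back above that bound.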

def train_unizero_multitask_segment_ddp(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
@@ -286,6 +324,11 @@ def train_unizero_multitask_segment_ddp(
reanalyze_batch_size = cfg.policy.reanalyze_batch_size
update_per_collect = cfg.policy.update_per_collect

task_complexity_weight = cfg.policy.task_complexity_weight

# Per-task evaluation rewards: {task_id: reward}
task_rewards = {}

while True:
# Dynamically adjust the batch sizes
if cfg.policy.allocated_batch_sizes:
@@ -338,8 +381,17 @@
# Check whether evaluation succeeded.
if stop is None or reward is None:
    print(f"Rank {rank} ran into a problem during evaluation; continuing training...")
    task_rewards[cfg.policy.task_id] = float('inf')  # On a failed evaluation, fall back to an infinite reward.
else:
    print(f"Evaluation succeeded: stop={stop}, reward={reward}")
    # Extract `eval_episode_return_mean` from the evaluation results as the reward value.
    try:
        eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
        print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}")
        task_rewards[cfg.policy.task_id] = eval_mean_reward
    except Exception as e:
        print(f"Error while extracting the evaluation reward: {e}")
        task_rewards[cfg.policy.task_id] = float('inf')  # On any error, fall back to an infinite reward.


print('=' * 20)
print(f'Rank {rank} starts collecting for task_id: {cfg.policy.task_id}...')
@@ -376,14 +428,33 @@
for replay_buffer in game_buffers
)

# Synchronize all ranks' readiness state before training.

# Compute the task weights.
try:
    # Aggregate the task rewards across ranks.
    dist.barrier()
    logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
    if task_complexity_weight:
        all_task_rewards = [None for _ in range(world_size)]
        dist.all_gather_object(all_task_rewards, task_rewards)
        # Merge the task rewards from all ranks.
        merged_task_rewards = {}
        for rewards in all_task_rewards:
            if rewards:
                merged_task_rewards.update(rewards)
        # Compute the global task weights.
        task_weights = compute_task_weights(merged_task_rewards)
        # Synchronize the task weights from rank 0. `broadcast_object_list`
        # fills the passed list in place, so read the element back out.
        weights_list = [task_weights]
        dist.broadcast_object_list(weights_list, src=0)
        task_weights = weights_list[0]
        print(f"rank{rank}, global task weights (keyed by task_id): {task_weights}")
    else:
        task_weights = None

except Exception as e:
    logging.error(f'Rank {rank}: failed to synchronize the task weights, error: {e}')
    break


# Learn the policy
if not not_enough_data:
for i in range(update_per_collect):
@@ -414,8 +485,9 @@
break

if train_data_multi_task:
    learn_kwargs = {'task_weights': task_weights}
    # During training, DDP automatically synchronizes gradients and parameters.
    log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)
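    # Note: `policy_kwargs` is assumed to be forwarded by the learner to the
    # policy's learn-mode forward, i.e. `_forward_learn(data, task_weights=...)`
    # as modified below.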

if cfg.policy.use_priority:
for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)):
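For reference, a minimal standalone sketch of the gather-then-broadcast pattern used above (the helper name `sync_task_weights` is hypothetical; assumes `torch.distributed` is initialized and `compute_task_weights` is in scope). `dist.broadcast_object_list` fills the passed list in place, which is why the element is read back out:

import torch.distributed as dist

def sync_task_weights(local_task_rewards: dict, world_size: int) -> dict:
    # Gather every rank's {task_id: reward} dict.
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_task_rewards)

    # Merge into a single global reward dict.
    merged = {}
    for part in gathered:
        if part:
            merged.update(part)

    # Compute weights, then make rank 0's result authoritative on every rank.
    obj = [compute_task_weights(merged)]
    dist.broadcast_object_list(obj, src=0)  # mutates `obj` in place on receiving ranks
    return obj[0]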
6 changes: 4 additions & 2 deletions lzero/policy/sampled_unizero_multitask.py
@@ -297,7 +297,7 @@ def _init_learn(self) -> None:
self.task_id = self._cfg.task_id
self.task_num_for_current_rank = self._cfg.task_num

def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]:
"""
Forward function for learning policy in learn mode, handling multiple tasks.
"""
@@ -394,8 +394,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
self.inverse_scalar_transform_handle,
task_id=task_id
)
if task_weights is not None:
    weighted_total_loss += losses.loss_total * task_weights[task_id]
else:
    weighted_total_loss += losses.loss_total
assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
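# Hypothetical illustration of the weighting above: with two tasks on this
# rank and task_weights = {0: 0.35, 1: 0.65}, the loop accumulates
#     weighted_total_loss = 0.35 * loss_total_task0 + 0.65 * loss_total_task1
# so tasks with lower evaluation returns (larger weights) dominate the update.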

17 changes: 10 additions & 7 deletions (DMC multi-task DDP config)
@@ -16,6 +16,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
observation_shape_list=observation_shape_list,
action_space_size_list=action_space_size_list,
from_pixels=False,
# ===== only for debug =====
# frame_skip=10, # 100
frame_skip=2,
continuous=True, # Assuming all DMC tasks use continuous action spaces
collector_env_num=collector_env_num,
@@ -48,8 +50,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
action_space_size_list=action_space_size_list,
policy_loss_type='kl',
obs_type='vector',
# use_shared_projection=True,  # TODO
use_shared_projection=False,
num_unroll_steps=num_unroll_steps,
policy_entropy_weight=5e-2,
continuous_action_space=True,
@@ -85,6 +87,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
num_experts_of_moe_in_transformer=4,
),
),
task_complexity_weight=True, # TODO
total_batch_size=total_batch_size,
allocated_batch_sizes=False,
# train_start_after_envsteps=int(2e3),
@@ -141,7 +144,7 @@ def generate_configs(env_id_list: List[str],
num_segments: int,
total_batch_size: int):
configs = []
exp_name_prefix = f'data_suz_mt_20250102/ddp_8gpu_nlayer8_upc80_notusp_taskweight_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list]

@@ -202,10 +205,10 @@ def create_env_manager():
# os.environ["NCCL_TIMEOUT"] = "3600000000"

# Define the list of environments
env_id_list = [
    'acrobot-swingup',  # 6 1
    'cartpole-swingup',  # 5 1
]

# DMC 8games
env_id_list = [
