feature(pu): add task_exploitation_weight option

puyuan authored and PaParaZz1 committed Jan 14, 2025
1 parent d098e71 commit 8beb492
Showing 17 changed files with 4,019 additions and 205 deletions.
80 changes: 80 additions & 0 deletions lzero/entry/compute_task_weight.py
@@ -0,0 +1,80 @@



import numpy as np
import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
"""
Symlog normalization, which reduces magnitude differences in the target values.
symlog(x) = sign(x) * log(|x| + 1)
"""
return torch.sign(x) * torch.log(torch.abs(x) + 1)


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
"""
Inverse of symlog, used to recover the original values.
inv_symlog(x) = sign(x) * (exp(|x|) - 1)
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def compute_task_weights(
task_rewards: dict,
epsilon: float = 1e-6,
min_weight: float = 0.1,
max_weight: float = 0.5,
temperature: float = 1.0,
use_symlog: bool = True,
) -> dict:
"""
Improved task-weight computation that adds symlog preprocessing and robustness safeguards.
Args:
task_rewards (dict): Per-task dict mapping task_id to its evaluation reward.
epsilon (float): Small value to avoid division by zero.
min_weight (float): Minimum weight, used for clipping.
max_weight (float): Maximum weight, used for clipping.
temperature (float): Temperature coefficient controlling the weight distribution.
use_symlog (bool): Whether to rescale task_rewards with symlog.
Returns:
dict: Per-task dict mapping task_id to its normalized and clipped weight.
"""
# Step 1: Optionally rescale the reward values with symlog
if use_symlog:
rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
corrected_rewards = symlog(rewards_tensor).numpy()  # rescale with symlog
task_rewards = dict(zip(task_rewards.keys(), corrected_rewards))

# Step 2: Compute initial weights (inversely proportional to the rewards)
raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}

# Step 3: Temperature scaling
scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}

# Step 4: Normalize the weights
total_weight = sum(scaled_weights.values())
normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}

# Step 5: Clip the weights into the [min_weight, max_weight] range
clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

final_weights = clipped_weights
return final_weights

task_rewards_list = [
{"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300},
{"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000},
{"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10},
]

for i, task_rewards in enumerate(task_rewards_list, start=1):
print(f"Case {i}: Original Rewards: {task_rewards}")
print("Original Weights:")
print(compute_task_weights(task_rewards, use_symlog=False))
print("Improved Weights with Symlog:")
print(compute_task_weights(task_rewards, use_symlog=True))
print()
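
As a quick sanity check beyond the demo loop above, the new helpers could be exercised as follows. This is an illustrative sketch, not part of the commit; it assumes the file is importable as lzero.entry.compute_task_weight, and importing that module also executes the demo loop at the bottom of the file.

import torch
from lzero.entry.compute_task_weight import symlog, inv_symlog, compute_task_weights

# symlog compresses large magnitudes; inv_symlog recovers the original values.
x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
assert torch.allclose(inv_symlog(symlog(x)), x, rtol=1e-4)

# Lower evaluation rewards receive larger weights (inverse relationship),
# and the final weights are clipped into [min_weight, max_weight].
weights = compute_task_weights({"task1": 10.0, "task2": 100.0, "task3": 1000.0}, use_symlog=True)
print(weights)  # roughly {'task1': 0.50, 'task2': 0.28, 'task3': 0.19}; after clipping they no longer sum to exactly 1
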
217 changes: 182 additions & 35 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -13,15 +13,18 @@
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from ding.utils import EasyTimer
import torch.nn.functional as F

import torch.distributed as dist

import concurrent.futures


# Timeout in seconds
TIMEOUT = 12000  # e.g., 200 minutes

@@ -132,42 +135,124 @@ def allocate_batch_size(

import numpy as np

def compute_task_weights(task_rewards: dict, epsilon: float = 1e-6,
min_weight: float = 0.05, max_weight: float = 0.5,
temperature: float = 1.0) -> dict:

def symlog(x: torch.Tensor) -> torch.Tensor:
"""
Symlog normalization, which reduces magnitude differences in the target values.
symlog(x) = sign(x) * log(|x| + 1)
"""
return torch.sign(x) * torch.log(torch.abs(x) + 1)

def inv_symlog(x: torch.Tensor) -> torch.Tensor:
"""
Compute task weights from each task's evaluation reward, with robustness safeguards that keep weights from becoming too small or too large.
Inverse of symlog, used to recover the original values.
inv_symlog(x) = sign(x) * (exp(|x|) - 1)
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

# Global maximum and minimum (used by "run-max-min")
GLOBAL_MAX = -float('inf')
GLOBAL_MIN = float('inf')

def compute_task_weights(
task_rewards: dict,
option: str = "symlog",
epsilon: float = 1e-6,
temperature: float = 1.0,
use_softmax: bool = False,  # whether to use Softmax
reverse: bool = False,  # proportional (False) or inversely proportional (True)
clip_min: float = 1e-2,  # minimum weight after clipping
clip_max: float = 1.0,  # maximum weight after clipping
) -> dict:
"""
Improved task-weight computation that supports multiple normalization schemes, optional Softmax weighting, proportional or inverse weighting, and clipping of the weight range.
Args:
task_rewards (dict): Per-task dict mapping task_id to its evaluation reward. The rewards need to be normalized, or the maxima of the different tasks must be on a comparable scale.
task_rewards (dict): Per-task dict mapping task_id to its evaluation reward or loss.
option (str): Normalization scheme; one of "symlog", "max-min", "run-max-min", "rank", "none".
epsilon (float): Small value to avoid division by zero.
min_weight (float): Minimum weight, used for clipping.
max_weight (float): Maximum weight, used for clipping.
temperature (float): Temperature coefficient controlling the weight distribution; larger values make it more uniform.
temperature (float): Temperature coefficient controlling the weight distribution.
use_softmax (bool): Whether to use Softmax to assign the weights.
reverse (bool): If True, weights are inversely proportional to the values; if False, proportional.
clip_min (float): Minimum weight, used for clipping.
clip_max (float): Maximum weight, used for clipping.
Returns:
dict: Per-task dict mapping task_id to its normalized and clipped weight.
dict: Per-task dict mapping task_id to its normalized weight.
"""
# Step 1: Compute initial weights (inversely proportional to the rewards)
# The lower a task's reward, the higher its weight; epsilon avoids division by zero
raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}

# Step 2: Apply temperature scaling to control how uniform the weights are
# Temperature scaling: w_i = (1 / r_i)^(1/temperature)
scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}

# Step 3: Normalize the weights
total_weight = sum(scaled_weights.values())
normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}

# Step 4: Clip the weights into the [min_weight, max_weight] range
clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

# Step 5: Re-normalize so that the clipped weights sum to 1
total_clipped_weight = sum(clipped_weights.values())
final_weights = {task_id: weight / total_clipped_weight for task_id, weight in clipped_weights.items()}

return final_weights
import torch
import torch.nn.functional as F

global GLOBAL_MAX, GLOBAL_MIN

# Return an empty result immediately if the input dict is empty
if not task_rewards:
return {}

# Step 1: Build a tensor from the values of task_rewards
task_ids = list(task_rewards.keys())
rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)

if option == "symlog":
# Normalize with symlog
scaled_rewards = symlog(rewards_tensor)
elif option == "max-min":
# Max-min normalization
max_reward = rewards_tensor.max().item()
min_reward = rewards_tensor.min().item()
scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon)
elif option == "run-max-min":
# Normalize with the global (running) max and min
GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item())
GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item())
scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon)
elif option == "rank":
# Rank-based normalization
# Ranks are based on value magnitude: 1 denotes the smallest value, and larger values get higher ranks
sorted_indices = torch.argsort(rewards_tensor)
scaled_rewards = torch.empty_like(rewards_tensor)
rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32)  # 1 to N
scaled_rewards[sorted_indices] = rank_values
elif option == "none":
# No normalization
scaled_rewards = rewards_tensor
else:
raise ValueError(f"Unsupported option: {option}")

# Step 2: Use reverse to decide whether the weights are proportional or inversely proportional to the values
if not reverse:
# Proportional: weights increase with the values
raw_weights = scaled_rewards
else:
# Inversely proportional: weights decrease as the values increase
# Prevent scaled_rewards from being negative or zero
scaled_rewards = torch.clamp(scaled_rewards, min=epsilon)
raw_weights = 1.0 / scaled_rewards

# Step 3: Compute the weights, with or without Softmax
if use_softmax:
# Use Softmax to assign the weights
beta = 1.0 / max(temperature, epsilon)  # make sure temperature is not zero
logits = -beta * raw_weights
softmax_weights = F.softmax(logits, dim=0).numpy()
weights = dict(zip(task_ids, softmax_weights))
else:
# Without Softmax, compute the weights directly
# Temperature scaling
scaled_weights = raw_weights ** (1 / max(temperature, epsilon))  # make sure the temperature is not zero

# Normalize the weights
total_weight = scaled_weights.sum()
normalized_weights = scaled_weights / total_weight

# Convert to a dict
weights = dict(zip(task_ids, normalized_weights.numpy()))

# Step 4: Clip the weight range
for task_id in weights:
weights[task_id] = max(min(weights[task_id], clip_max), clip_min)

return weights
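
For reference, a minimal illustrative call of this extended version (not part of the commit; the task ids and loss values below are made up), mirroring how the training loop later uses option="rank" for the exploitation weights:

# Rank-based weighting: values are replaced by their ranks (1 = smallest),
# weights are proportional to the ranks (reverse=False), normalized to sum to 1,
# and finally clipped into [clip_min, clip_max].
obs_losses = {0: 0.12, 1: 0.45, 2: 0.03}
w = compute_task_weights(obs_losses, option="rank", temperature=1)
print(w)  # approximately {0: 0.333, 1: 0.5, 2: 0.167} -- the task with the largest loss gets the largest weight
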

def train_unizero_multitask_segment_ddp(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
@@ -193,6 +278,17 @@ def train_unizero_multitask_segment_ddp(
Returns:
- policy (:obj:`Policy`): 收敛的策略。
"""
# Initialize the temperature scheduler
initial_temperature = 10.0
final_temperature = 1.0
threshold_steps = int(1e4)  # the temperature decays to 1.0 once training reaches 10k steps
temperature_scheduler = TemperatureScheduler(
initial_temp=initial_temperature,
final_temp=final_temperature,
threshold_steps=threshold_steps,
mode='linear'  # or 'exponential'
)

# Get the rank of the current process and the total number of processes
rank = get_rank()
world_size = get_world_size()
@@ -325,6 +421,8 @@ def train_unizero_multitask_segment_ddp(
update_per_collect = cfg.policy.update_per_collect

task_complexity_weight = cfg.policy.task_complexity_weight
use_task_exploitation_weight = cfg.policy.use_task_exploitation_weight
task_exploitation_weight = None

# Create the task-reward dictionary
task_rewards = {} # {task_id: reward}
@@ -428,7 +526,10 @@ def train_unizero_multitask_segment_ddp(
for replay_buffer in game_buffers
)


# Get the current temperature
current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter)
# collector._policy._task_weight_temperature = current_temperature_task_weight
# policy.collect_mode.get_attribute('task_weight_temperature') = current_temperature_task_weight

# Compute task weights
try:
@@ -443,13 +544,12 @@ def train_unizero_multitask_segment_ddp(
if rewards:
merged_task_rewards.update(rewards)
# Compute global task weights
task_weights = compute_task_weights(merged_task_rewards)
task_weights = compute_task_weights(merged_task_rewards, temperature=current_temperature_task_weight)
# Synchronize the task weights
dist.broadcast_object_list([task_weights], src=0)
print(f"rank{rank}, 全局任务权重 (按 task_id 排列): {task_weights}")
else:
task_weights = None

except Exception as e:
logging.error(f'Rank {rank}: failed to synchronize task weights, error: {e}')
break
@@ -485,10 +585,57 @@ def train_unizero_multitask_segment_ddp(
break

if train_data_multi_task:
learn_kwargs = {'task_weights':task_weights}
# learn_kwargs = {'task_exploitation_weight':task_exploitation_weight, 'task_weights':task_weights, }
learn_kwargs = {'task_weights':task_exploitation_weight}

# During training, DDP automatically synchronizes gradients and parameters
log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)

# Decide whether task_exploitation_weight needs to be computed
if i == 0:
# Compute task weights
try:
dist.barrier()  # wait for all processes to synchronize
if use_task_exploitation_weight:
# Collect the obs_loss of all tasks
all_obs_loss = [None for _ in range(world_size)]
# Build the obs_loss data for the tasks on the current process
merged_obs_loss_task = {}
for cfg, replay_buffer in zip(cfgs, game_buffers):
task_id = cfg.policy.task_id
if f'noreduce_obs_loss_task{task_id}' in log_vars[0]:
merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}']
# Gather the obs_loss data from all processes
dist.all_gather_object(all_obs_loss, merged_obs_loss_task)
# Merge the obs_loss data from all processes
global_obs_loss_task = {}
for obs_loss_task in all_obs_loss:
if obs_loss_task:
global_obs_loss_task.update(obs_loss_task)
# Compute global task weights
if global_obs_loss_task:
task_exploitation_weight = compute_task_weights(
global_obs_loss_task,
option="rank",
# temperature=current_temperature_task_weight # TODO
temperature=1,
)
# Broadcast the task weights to all processes
dist.broadcast_object_list([task_exploitation_weight], src=0)
print(f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}")
else:
logging.warning(f"Rank {rank}: 未能计算全局 obs_loss 任务权重,obs_loss 数据为空。")
task_exploitation_weight = None
else:
task_exploitation_weight = None
# Update the training kwargs so that they include the computed task weights
learn_kwargs['task_weight'] = task_exploitation_weight
except Exception as e:
logging.error(f'Rank {rank}: failed to synchronize task weights, error: {e}')
raise e  # keep re-raising the exception so it can be caught and analyzed externally



if cfg.policy.use_priority:
for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)):
# Update task-specific replay buffer priorities

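TemperatureScheduler is imported from lzero.entry.utils, whose implementation is not shown in this diff. A minimal sketch consistent with how it is constructed and queried above (an assumption for illustration, not the actual library code) could look like:

class TemperatureScheduler:
    """Anneal a temperature from initial_temp to final_temp over threshold_steps training iterations."""

    def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'):
        assert mode in ('linear', 'exponential')
        self.initial_temp = initial_temp
        self.final_temp = final_temp
        self.threshold_steps = threshold_steps
        self.mode = mode

    def get_temperature(self, train_iter: int) -> float:
        # Fraction of the annealing schedule completed, capped at 1.0 after threshold_steps.
        progress = min(train_iter / self.threshold_steps, 1.0)
        if self.mode == 'linear':
            # Linear interpolation between the initial and final temperature.
            return self.initial_temp + (self.final_temp - self.initial_temp) * progress
        # 'exponential': geometric interpolation between the two temperatures.
        return self.initial_temp * (self.final_temp / self.initial_temp) ** progress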