Commit 01c8385

fix(pu): fix self.past_keys_values_cache.popitem cuda memory bug, use torch.cuda.empty_cache() every 200 steps now

puyuan1996 committed Jan 11, 2024
1 parent c13fea7 commit 01c8385
Showing 9 changed files with 80 additions and 11 deletions.
6 changes: 3 additions & 3 deletions lzero/entry/train_muzero_gpt.py
@@ -238,9 +238,9 @@ def train_muzero_gpt(

# NOTE: TODO
# TODO: for the batch world model, to improve KV reuse, we could avoid resetting
policy._learn_model.world_model.past_keys_values_cache.clear()

torch.cuda.empty_cache() # TODO
policy._learn_model.world_model.past_keys_values_cache.clear() # very important
del policy._learn_model.world_model.keys_values_wm
torch.cuda.empty_cache() # TODO: NOTE

# if collector.envstep > 0:
# # TODO: only for debug
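A minimal sketch of why the three lines added above are ordered this way (the helper name is hypothetical; the attribute names come from the diff): torch.cuda.empty_cache() can only release allocator blocks whose tensors are no longer referenced anywhere, so the cache must be cleared and keys_values_wm deleted before the call.

    import torch

    def release_world_model_memory(world_model) -> None:
        world_model.past_keys_values_cache.clear()  # drop all cached KeysValues entries
        if hasattr(world_model, 'keys_values_wm'):
            del world_model.keys_values_wm  # drop the live KV reference held by the model
        torch.cuda.empty_cache()  # only now can the freed blocks be returned to the driver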
1 change: 1 addition & 0 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -286,6 +286,7 @@ def search(
# _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens)

# model.world_model.past_keys_values_cache.clear()  # clear the cache
# del model.world_model.keys_values_wm  # TODO: clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation we expand a new node, so one search creates at most ``num_simulations`` nodes.

3 changes: 2 additions & 1 deletion lzero/model/gpt_models/cfg_atari.py
@@ -66,11 +66,12 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:1',
"device": 'cuda:0',
# "device": 'cpu',
'support_size': 21,
'action_shape': 6,  # TODO: for atari
'max_cache_size': 500,
# 'max_cache_size': 100,
# 'max_cache_size': 1000,
"env_num": 8,
}
38 changes: 38 additions & 0 deletions lzero/model/gpt_models/world_model.py
@@ -630,10 +630,48 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
if len(self.past_keys_values_cache) > self.max_cache_size:
# TODO: lru_cache
self.past_keys_values_cache.popitem(last=False) # Removes the earliest inserted item
# popitem returns a key-value pair; the second element is the value
# _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
# If popped_kv_cache is a container holding tensors or other complex objects, you may need to delete those objects explicitly as well
# For example:
# del popped_kv_cache  # do not use this line
# torch.cuda.empty_cache()  # note: frequent calls may hurt performance; del-ing first still fails to release the occupied 2MB of cache

# Example usage:
# Assuming `past_keys_values_cache` is a populated instance of `KeysValues`
# and `num_layers` is the number of transformer layers
# cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2)
# print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB')

return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value


# Computes the CUDA memory used by the KV cache, in GB
def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int):
    total_memory_bytes = 0

    # Iterate over all KeysValues instances in the OrderedDict
    for kv_instance in past_keys_values_cache.values():
        num_layers = len(kv_instance)  # number of layers (overrides the argument)
        for layer in range(num_layers):
            kv_cache = kv_instance[layer]
            k_shape = kv_cache._k_cache.shape  # shape of the keys cache
            v_shape = kv_cache._v_cache.shape  # shape of the values cache

            # Number of elements times bytes per element (4 for float32)
            k_memory = torch.prod(torch.tensor(k_shape)) * 4
            v_memory = torch.prod(torch.tensor(v_shape)) * 4

            # Accumulate the memory of the keys and values caches
            layer_memory = k_memory + v_memory
            total_memory_bytes += layer_memory.item()  # .item() converts to a plain Python number

    # Convert the total from bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024 ** 3)
    return total_memory_gb



def compute_loss(self, batch, tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses:

if len(batch['observations'][0, 0].shape) == 3:
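Popping an entry from the OrderedDict alone does not return GPU memory while other references to the KeysValues tensors stay alive, which is the leak this commit works around. In the spirit of the `# TODO: lru_cache` above, a minimal sketch of a bounded LRU container that deletes the evicted entry and only calls torch.cuda.empty_cache() on an interval (class name, parameters, and interval are assumptions, not the repo's API):

    from collections import OrderedDict

    import torch

    class BoundedKVCache:
        """Hypothetical LRU-style cache that frees GPU memory on eviction."""

        def __init__(self, max_cache_size: int = 500, empty_cache_interval: int = 200):
            self.cache = OrderedDict()
            self.max_cache_size = max_cache_size
            self.empty_cache_interval = empty_cache_interval
            self._evictions = 0

        def put(self, key, keys_values) -> None:
            # Re-inserting an existing key moves it to the end, preserving LRU order.
            self.cache.pop(key, None)
            self.cache[key] = keys_values
            if len(self.cache) > self.max_cache_size:
                _, evicted = self.cache.popitem(last=False)  # evict the oldest entry
                del evicted  # drop our reference so the tensors become collectible
                self._evictions += 1
                if self._evictions % self.empty_cache_interval == 0:
                    torch.cuda.empty_cache()  # return cached blocks to the driver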
21 changes: 19 additions & 2 deletions lzero/policy/muzero_gpt.py
@@ -280,8 +280,8 @@ def _init_learn(self) -> None:
# )
self._optimizer_world_model = configure_optimizer(
model=self._model.world_model,
learning_rate=3e-3,
# learning_rate=1e-4, # NOTE: TODO
# learning_rate=3e-3,
learning_rate=1e-4, # NOTE: TODO
weight_decay=self._cfg.weight_decay,
# weight_decay=0.01,
exclude_submodules=['none'] # NOTE
@@ -534,7 +534,22 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict())


# Ensure all CUDA kernels have finished so the memory statistics are accurate
torch.cuda.synchronize()
# Total GPU memory currently allocated (bytes)
current_memory_allocated = torch.cuda.memory_allocated()
# Peak GPU memory allocated so far in this run (bytes)
max_memory_allocated = torch.cuda.max_memory_allocated()

# Convert the memory usage from bytes to GB
current_memory_allocated_gb = current_memory_allocated / (1024**3)
max_memory_allocated_gb = max_memory_allocated / (1024**3)
# Record the current and peak GPU memory usage via the SummaryWriter


return_loss_dict = {
'Current_GPU': current_memory_allocated_gb,
'Max_GPU': max_memory_allocated_gb,
'collect_mcts_temperature': self._collect_mcts_temperature,
'collect_epsilon': self.collect_epsilon,
'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
@@ -948,6 +963,8 @@ def _monitor_vars_learn(self) -> List[str]:
tensorboard according to the return value ``_forward_learn``.
"""
return [
'Current_GPU',
'Max_GPU',
'collect_epsilon',
'collect_mcts_temperature',
# 'cur_lr',
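The memory bookkeeping added to _forward_learn_transformer, condensed into a self-contained sketch (function name assumed; a CUDA device is assumed available): synchronize() makes the counters reflect all queued kernels, and max_memory_allocated() reports the high-water mark since process start or the last reset_peak_memory_stats().

    import torch

    def gpu_memory_gb() -> dict:
        torch.cuda.synchronize()  # wait for pending kernels so the counters are accurate
        return {
            'Current_GPU': torch.cuda.memory_allocated() / (1024 ** 3),  # live tensors
            'Max_GPU': torch.cuda.max_memory_allocated() / (1024 ** 3),  # peak since start/reset
        }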
6 changes: 6 additions & 0 deletions lzero/worker/muzero_collector.py
@@ -510,6 +510,12 @@ def collect(self,
completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id]))

eps_steps_lst[env_id] += 1

if eps_steps_lst[env_id] % 200 == 0:
torch.cuda.empty_cache() # TODO: NOTE
print('torch.cuda.empty_cache()')
# print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')

total_transitions += 1

if self.policy_config.use_priority:
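For intuition on what the per-200-steps call buys, a hedged illustration (not part of the commit; function name assumed): empty_cache() returns cached but unused allocator blocks to the driver, so memory_reserved() drops while memory_allocated(), which counts live tensors, is unchanged.

    import torch

    def report_empty_cache_effect() -> None:
        before = torch.cuda.memory_reserved() / 1024 ** 2  # cached + live blocks (MB)
        torch.cuda.empty_cache()
        after = torch.cuda.memory_reserved() / 1024 ** 2
        live = torch.cuda.memory_allocated() / 1024 ** 2  # live tensors only, unchanged
        print(f'reserved: {before:.1f} MB -> {after:.1f} MB; allocated: {live:.1f} MB')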
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
DI-engine[common_env]>=0.4.7
DI-engine[common_env]>=0.5.0
gym[accept-rom-license]==0.25.1
numpy>=1.22.4
pympler
10 changes: 7 additions & 3 deletions zoo/atari/config/atari_muzero_gpt_config_stack4.py
@@ -1,6 +1,6 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(1)
torch.cuda.set_device(0)

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
env_name = 'PongNoFrameskip-v4'
@@ -64,9 +64,13 @@
atari_muzero_config = dict(
# TODO: world_model.py decode_obs_tokens
# TODO: tokenizer.py: lpips loss
# exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr3e-3-gcv05-reconslossw01-minmax-latentgrad02_stack4_upc1000_seed0',
# exp_name=f'data_mz_gpt_ctree_0111_debug/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss-obsloss2_kllw0-lr1e-4-gcv05-biasfalse-reconslossw0-minmax_stack4_mcs500_seed0',

exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr3e-3-gcv05-biasfalse-minmax-iter60k-fixed_stack4_seed0',
# exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss-obsloss2_kllw0-lr1e-4-gcv05-biasfalse-reconslossw01-minmax-latentgrad1_stack4_mcs500_seed0',

# exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss-obsloss2_kllw0-lr1e-4-gcv05-reconslossw01-minmax-latentgrad02-iter60ktrain_stack4_emptycache-per200_seed0',

exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss-obsloss2_kllw0-lr1e-4-gcv05-biasfalse-reconslossw0-minmax-iter60k-fixed_stack4_mcs500_emptycache-per1_seed0',

# exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-reconslossw01-minmax-latentgrad0.2-fromm60ktrain_stack4_upc1000_seed0',
# exp_name=f'data_mz_gpt_ctree_0111/{env_name[:-14]}_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_mcs500_contembdings_mz-repenet-lastlinear-lsd1024-6488-kaiming-lerelu_obsmseloss_obsloss2_kllw0-lr1e-4-gcv05-onlyreconslossw1-biasfalse-minmax_stack4_seed0',
4 changes: 3 additions & 1 deletion zoo/atari/config/atari_muzero_gpt_config_stack4_debug.py
@@ -24,7 +24,9 @@
collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
update_per_collect = 1000
# update_per_collect = 1000
update_per_collect = 10

# update_per_collect = None
# model_update_ratio = 0.25
model_update_ratio = 0.25
