
Commit

fix(pu): always save latest kv_cache for latent_state to tackle POMDP state alias problem
jiayilee65 committed Apr 4, 2024
1 parent 4255702 commit f5a928e
Showing 6 changed files with 118 additions and 103 deletions.
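The substance of the commit is in update_cache_context in world_model.py: instead of overwriting a stored KV cache only when the new one is larger, the latest KV cache is now always saved for a latent state, so that when two different action histories map to the same latent state under partial observability (state aliasing), the cache reflects the most recent history. The same policy change appears for both past_keys_values_cache_init_infer and past_keys_values_cache_recurrent_infer in the diff below. A minimal sketch of the before/after policy, assuming a plain dict keyed by a hash of the latent state (the KV stand-in class and function names are illustrative, not the repository's API):

from copy import deepcopy

class KV:
    """Stand-in for a per-environment key/value cache; only its size matters here."""
    def __init__(self, size):
        self.size = size

def store_kv_old(cache, key, kv, context_length):
    # Old policy: overwrite only when the new cache is larger (and still below
    # the context length), so a stale cache could survive under state aliasing.
    existing = cache.get(key)
    if existing is not None:
        if existing.size < kv.size < context_length - 1:
            cache[key] = deepcopy(kv)
    elif kv.size < context_length - 1:
        cache[key] = deepcopy(kv)

def store_kv_new(cache, key, kv):
    # This commit: always store the latest cache for the latent state.
    cache[key] = deepcopy(kv)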
35 changes: 19 additions & 16 deletions lzero/model/gpt_models/cfg_atari.py
@@ -13,26 +13,29 @@
cfg['world_model'] = {
"device": 'cuda:6',

'tokens_per_block': 2,
'max_blocks': 4,
"max_tokens": 2 * 4, # TODO: horizon:4
"context_length": 2 * 4,
"context_length_for_recurrent":2 * 4,
"recurrent_keep_deepth": 2,
"gru_gating": False,
# "gru_gating": True,
# 'tokens_per_block': 2,
# 'max_blocks': 4,
# "max_tokens": 2 * 4, # TODO: horizon:4
# "context_length": 2 * 4,
# "context_length_for_recurrent":2 * 4,
# "recurrent_keep_deepth": 0,
# "gru_gating": False,
# # "gru_gating": True,


# 'tokens_per_block': 2,
# 'max_blocks': 8,
# "max_tokens": 2 * 8, # TODO: horizon:8
'tokens_per_block': 2,
'max_blocks': 8,
"max_tokens": 2 * 8, # TODO: horizon:8
# "context_length": 2 * 8,
# "context_length_for_recurrent":2 * 8,
# "recurrent_keep_deepth": 1,
# # "context_length": 4,
# # "context_length_for_recurrent":4,
# "gru_gating": False,
# # "gru_gating": True,
"context_length": 10,
"context_length_for_recurrent": 10,
# "recurrent_keep_deepth": 2,
"recurrent_keep_deepth": 100,
# "context_length": 4,
# "context_length_for_recurrent":4,
"gru_gating": False,
# "gru_gating": True,

# 'tokens_per_block': 2,
# 'max_blocks': 8,
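For orientation, the now-active Atari settings in this hunk relate as follows; a minimal sketch under the assumption that tokens_per_block = 2 means one observation token plus one action token per block, and that the two context lengths count tokens of history (neither is spelled out in the diff):

# Illustrative arithmetic for the active cfg_atari.py values (assumed semantics).
tokens_per_block = 2                         # assumed: 1 observation token + 1 action token
max_blocks = 8                               # "TODO: horizon:8"
max_tokens = tokens_per_block * max_blocks   # 2 * 8 = 16 tokens kept by the world model
context_length = 10                          # history tokens for initial inference (assumed meaning)
context_length_for_recurrent = 10            # history tokens for recurrent inference (assumed meaning)
assert max_tokens == 16 and context_length <= max_tokens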
7 changes: 4 additions & 3 deletions lzero/model/gpt_models/cfg_memory.py
@@ -12,6 +12,7 @@
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
cfg['world_model'] = {
'tokens_per_block': 2,
"device": 'cuda:7',

# 'max_blocks': 16,
# "max_tokens": 2 * 16, # 1+0+15 memory_length = 0
@@ -20,8 +21,9 @@

'max_blocks': 16,
"max_tokens": 2 * 16, # 1+0+15 memory_length = 0
"context_length": 8,
"context_length_for_recurrent": 8,
"context_length": 2 * 16,
"context_length_for_recurrent": 2 * 16,
"recurrent_keep_deepth": 100,

# 'max_blocks': 30,
# "max_tokens": 2 * 30, # 15+0+15 memory_length = 0
@@ -65,7 +67,6 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:7',

'support_size': 21,
'action_shape': 4, # NOTE:for memory
145 changes: 80 additions & 65 deletions lzero/model/gpt_models/world_model.py
@@ -339,10 +339,10 @@ def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> to
observations = obs_act_dict['obs']
buffer_action = obs_act_dict['action']
current_obs = obs_act_dict['current_obs']
else:
observations = obs_act_dict
buffer_action = None
current_obs = None
# else:
# observations = obs_act_dict
# buffer_action = None
# current_obs = None
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E)

if current_obs is not None:
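With the else branch disabled, reset_from_initial_observations now expects a dict rather than a bare tensor. A minimal sketch of the expected input (tensor shapes and the commented call site are assumptions for illustration, not taken from the commit):

import torch

obs_act_dict = {
    'obs': torch.zeros(8, 3, 64, 64),  # (B, C, H, W) observations; shape is an assumption
    'action': None,                    # buffer_action; None when not replaying from the buffer
    'current_obs': None,               # optional latest observation
}
# world_model.reset_from_initial_observations(obs_act_dict)  # hypothetical call site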
@@ -386,9 +386,9 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_st
print('root_total_query_cnt:', self.root_total_query_cnt)
print(f'root_hit_ratio:{root_hit_ratio}')
print(f'root_hit find size {self.past_keys_values_cache_init_infer[cache_key].size}')
if self.past_keys_values_cache_init_infer[cache_key].size >= self.config.max_tokens - 5:
if self.past_keys_values_cache_init_infer[cache_key].size >= self.config.max_tokens - 3:
print(f'==' * 20)
print(f'NOTE: root_hit find size >= self.config.max_tokens - 5')
print(f'NOTE: root_hit find size >= self.config.max_tokens - 3')
print(f'==' * 20)
# A deepcopy is needed here because matched_value is modified in place inside the transformer's forward pass
self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda')))
@@ -424,7 +424,8 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_st
act_tokens = rearrange(buffer_action, 'b l -> b l 1')

# Select the last step of each sample
last_steps = act_tokens[:, -1:, :] # This keeps the last column while preserving the dimensions; the last step's target policy/value is never used anyway
###### This keeps the last column while preserving the dimensions; the last step's target policy/value is never used anyway ######
last_steps = act_tokens[:, -1:, :]
# Use torch.cat to concatenate the original act_tokens and last_steps along the second dimension
act_tokens = torch.cat((act_tokens, last_steps), dim=1)

@@ -455,8 +456,13 @@ def forward_initial_inference(self, obs_act_dict):
self.past_keys_values_cache_recurrent_infer.clear()
# print('=='*20)
# print('self.past_keys_values_cache_recurrent_infer.clear() after init_inference')
self.latent_state_index_in_search_path = [[] for i in range(latent_state.shape[0])]
self.next_latent_state_depth = [[] for i in range(latent_state.shape[0])]

# self.latent_state_index_in_search_path = [[] for i in range(latent_state.shape[0])]
# self.next_latent_state_depth = [[] for i in range(latent_state.shape[0])]
# self.last_depth = [0 for i in range(latent_state.shape[0])]
# Maintain a global depth_map dict to store depth information that has already been computed
# self.depth_map = [{0: 1} for i in range(latent_state.shape[0])] # depth map at the root node

return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
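For context, a schematic of the two cache dictionaries this method touches, assuming the init-inference cache persists across steps while the recurrent-inference cache is per-search (only the clear() on initial inference is visible in this diff; the rest is an assumption):

class KVCacheLifecycleSketch:
    """Illustrative only; not the repository's class layout."""
    def __init__(self):
        self.past_keys_values_cache_init_infer = {}       # encoder latent_state -> kv_cache
        self.past_keys_values_cache_recurrent_infer = {}  # predicted latent_state -> kv_cache

    def forward_initial_inference(self):
        # Mirrors the diff above: drop search-time caches at every new root so
        # stale predicted-state caches never leak into the next search.
        self.past_keys_values_cache_recurrent_infer.clear()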

"""
@@ -466,19 +472,21 @@ def forward_initial_inference(self, obs_act_dict):
Internally it also runs the transformer forward inference in batches
"""

def convert_to_depth(self, search_path):
depth_map = {0: 1} # depth map at the root node
depth_in_search_path = []
for index in search_path:
# If the depth for the current index has not been computed yet, derive it from the parent node's depth
if index not in depth_map:
if search_path[index] not in depth_map:
depth_map[search_path[index]] = max(list(depth_map.values())) + 1
else:
depth_map[index] = depth_map[search_path[index]] + 1
depth_in_search_path.append(depth_map[index])

return depth_in_search_path
def convert_to_depth(self, search_path, depth_map, last_depth):
# Get the newly appended element
new_index = search_path[-1]

# If the depth of the newly appended element has not been computed yet, derive it from the parent node's depth
if new_index not in depth_map:
if search_path[new_index] not in depth_map:
depth_map[search_path[new_index]] = max(list(depth_map.values())) + 1
else:
depth_map[new_index] = depth_map[search_path[new_index]] + 1

# Append the depth of the newly appended element to the end of last_depth
last_depth.append(depth_map[new_index])

return last_depth
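The helpers above (now disabled) turn a per-environment search path of parent indices into node depths. A self-contained sketch of that conversion under one reading of the data, assuming entry i holds the in-path index of node i's parent and the root has depth 1 (an illustration, not the exact disabled helper):

def depths_from_parent_indices(parent_indices):
    # parent_indices[i] is the in-path index of node i's parent; the root gets depth 1.
    depths = []
    for i, parent in enumerate(parent_indices):
        depths.append(1 if i == 0 else depths[parent] + 1)
    return depths

# Using the example list that appears in forward_recurrent_inference below:
# depths_from_parent_indices([0, 0, 0, 0, 2, 0, 2, 3, 4]) -> [1, 2, 2, 2, 3, 2, 3, 3, 4]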


@torch.no_grad()
@@ -491,17 +499,17 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0,
latest_state, action = state_action_history[-1]
# print(f'action:{action}')

for i, latent_state_index in enumerate(latent_state_index_in_search_path):
self.latent_state_index_in_search_path[i].append(latent_state_index)
# Example data
# latent_state_index_in_search_path = [0, 0, 0, 0, 2, 0, 2, 3, 4]
# Convert to the depth representation
self.next_latent_state_depth[i] = self.convert_to_depth(self.latent_state_index_in_search_path[i])
# Print the result
# print(f'next_latent_state_depth:{self.next_latent_state_depth[i]}')
# for i, latent_state_index in enumerate(latent_state_index_in_search_path):
# self.latent_state_index_in_search_path[i].append(latent_state_index)

# # If this is the first computation, initialize self.next_latent_state_depth[i]
# if simulation_index == 0:
# self.next_latent_state_depth[i] = self.convert_to_depth(self.latent_state_index_in_search_path[i], self.depth_map[i], [])
# else:
# # Otherwise, only compute the depth of the newly appended element on top of the self.next_latent_state_depth[i] obtained in the previous computation
# self.next_latent_state_depth[i] = self.convert_to_depth(self.latent_state_index_in_search_path[i], self.depth_map[i], self.next_latent_state_depth[i])

# print(f'next_latent_state_depth:{self.next_latent_state_depth}')

# Assume latest_state is the new latent_state, containing information for ready_env_num environments
ready_env_num = latest_state.shape[0]
self.keys_values_wm_list = []
@@ -655,42 +663,49 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length-3
self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length-3

# TODO: memory_env; always store the latest

if is_init_infer:
# init_infer: the latent_state comes from the encoder, so it carries no error
# Compare and store the larger cache
if cache_key in self.past_keys_values_cache_init_infer:
existing_kvcache = self.past_keys_values_cache_init_infer[cache_key]
# Check whether the existing cache and the new cache differ in size
# if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.context_length-1:
# Store only when the size is smaller than max_tokens - 1, to avoid a reset
self.past_keys_values_cache_init_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
elif self.keys_values_wm_single_env.size < self.context_length-1:
# Store only when the size is smaller than max_tokens - 1, to avoid a reset
self.past_keys_values_cache_init_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# TODO: always store the latest
self.past_keys_values_cache_init_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
else:
# recurrent_infer: the latent_state is predicted, so it carries error
# Compare and store the larger cache
if cache_key in self.past_keys_values_cache_recurrent_infer:
existing_kvcache = self.past_keys_values_cache_recurrent_infer[cache_key]
# Check whether the existing cache and the new cache differ in size
# if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.context_length_for_recurrent-1:
# Store only when the size is smaller than max_tokens - 1, to avoid a reset
# if latent_state_index_in_search_path[i] == 0:
if self.next_latent_state_depth[i][-1] <= self.recurrent_keep_deepth:
# TODO: (root_latent_state, a, current_latent_state) kv_cache, i.e. the kv_cache is stored only for predicted states one level below the root
# print('save kv_cache of predicted latent state')
self.past_keys_values_cache_recurrent_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
elif self.keys_values_wm_single_env.size < self.context_length_for_recurrent-1:
# if latent_state_index_in_search_path[i] == 0: # bug: < self.action_shape**(self.recurrent_keep_deepth-1):
if self.next_latent_state_depth[i][-1] <= self.recurrent_keep_deepth:
# TODO: (root_latent_state, a, current_latent_state) kv_cache, i.e. the kv_cache is stored only for predicted states one level below the root
# print('save kv_cache of predicted latent state')
self.past_keys_values_cache_recurrent_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
self.past_keys_values_cache_recurrent_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))

# # TODO: memory_env; always store the latest
# if is_init_infer:
# # init_infer: the latent_state comes from the encoder, so it carries no error
# # Compare and store the larger cache
# if cache_key in self.past_keys_values_cache_init_infer:
# existing_kvcache = self.past_keys_values_cache_init_infer[cache_key]
# # Check whether the existing cache and the new cache differ in size
# # if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
# if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.context_length-1:
# # Store only when the size is smaller than max_tokens - 1, to avoid a reset
# self.past_keys_values_cache_init_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# # elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
# elif self.keys_values_wm_single_env.size < self.context_length-1:
# # Store only when the size is smaller than max_tokens - 1, to avoid a reset
# self.past_keys_values_cache_init_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# else:
# # recurrent_infer: the latent_state is predicted, so it carries error
# # Compare and store the larger cache
# if cache_key in self.past_keys_values_cache_recurrent_infer:
# existing_kvcache = self.past_keys_values_cache_recurrent_infer[cache_key]
# # Check whether the existing cache and the new cache differ in size
# # if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
# if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.context_length_for_recurrent-1:
# # Store only when the size is smaller than max_tokens - 1, to avoid a reset
# # if latent_state_index_in_search_path[i] == 0:
# if self.next_latent_state_depth[i][-1] <= self.recurrent_keep_deepth:
# # TODO: (root_latent_state, a, current_latent_state) kv_cache, i.e. the kv_cache is stored only for predicted states one level below the root
# # print('save kv_cache of predicted latent state')
# self.past_keys_values_cache_recurrent_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
# # elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1:
# elif self.keys_values_wm_single_env.size < self.context_length_for_recurrent-1:
# # if latent_state_index_in_search_path[i] == 0: # bug: < self.action_shape**(self.recurrent_keep_deepth-1):
# if self.next_latent_state_depth[i][-1] <= self.recurrent_keep_deepth:
# # TODO: (root_latent_state, a, current_latent_state) kv_cache, i.e. the kv_cache is stored only for predicted states one level below the root
# # print('save kv_cache of predicted latent state')
# self.past_keys_values_cache_recurrent_infer[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))

def retrieve_or_generate_kvcache(self, latent_state, ready_env_num, simulation_index=0):
"""
4 changes: 2 additions & 2 deletions lzero/model/muzero_gpt_model.py
@@ -120,8 +120,8 @@ def __init__(
from .gpt_models.tokenizer.tokenizer import Tokenizer
from .gpt_models.tokenizer.nets import Encoder, Decoder
# from .gpt_models.cfg_cartpole import cfg
# from .gpt_models.cfg_memory import cfg # NOTE: TODO
from .gpt_models.cfg_atari import cfg
from .gpt_models.cfg_memory import cfg # NOTE: TODO
# from .gpt_models.cfg_atari import cfg

if cfg.world_model.obs_type == 'vector':
self.representation_network = RepresentationNetworkMLP(