polish(pu): polish world_model kv_cache statistics
puyuan1996 committed Mar 4, 2024
1 parent 9f5b941 commit d9ed770
Showing 10 changed files with 1,250 additions and 72 deletions.
7 changes: 6 additions & 1 deletion lzero/entry/train_muzero_gpt.py
@@ -224,7 +224,7 @@ def train_muzero_gpt(
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0:
if i % 20 == 0:
# if i % 2 == 0:# for reanalyze_ratio>0
# if i % 2 == 0:# for reanalyze_ratio>0
policy._target_model.world_model.past_keys_values_cache.clear()
policy._target_model.world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
torch.cuda.empty_cache()  # TODO: decide whether GPU memory needs to be freed immediately
@@ -260,6 +260,11 @@ def train_muzero_gpt(
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

policy._target_model.world_model.past_keys_values_cache.clear()
policy._target_model.world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
torch.cuda.empty_cache()  # TODO: decide whether GPU memory needs to be freed immediately
print('sample target_model past_keys_values_cache.clear()')

# NOTE: TODO
# TODO: for the batch world model, to improve kv reuse, we could avoid resetting
policy._learn_model.world_model.past_keys_values_cache.clear() # very important
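Both the target model (after the sampled-batch update above) and the learn model now run the same cache maintenance: clear past_keys_values_cache, clear keys_values_wm_list, then release GPU memory. A minimal sketch of a helper capturing this repeated sequence is shown below; clear_world_model_caches is a hypothetical name and not part of this commit.

import torch

def clear_world_model_caches(model) -> None:
    # Illustrative helper (assumed structure): drop the world model's cached
    # transformer keys/values and release GPU memory.
    model.world_model.past_keys_values_cache.clear()
    model.world_model.keys_values_wm_list.clear()  # only used by recurrent_inference() batch_pad
    torch.cuda.empty_cache()  # freeing memory immediately may not be necessary (see TODOs above)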
10 changes: 7 additions & 3 deletions lzero/model/gpt_models/cfg_atari.py
@@ -57,22 +57,26 @@
# "max_tokens": 2 * 6, # TODO: horizon

# 'embed_dim':512, # TODO:for atari
'embed_dim':1024, # TODO:for atari
# 'embed_dim':256, # TODO:for atari
# 'embed_dim':1024, # TODO:for atari
'embed_dim':768, # TODO:for atari


'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
# 'num_layers': 2, # TODO:for atari debug
# 'num_heads': 4,
# 'num_layers': 1, # TODO:for atari debug
# 'num_heads': 1,

'num_layers': 2, # TODO:for atari debug
'num_heads': 2,
'num_heads': 4,


'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:6',
"device": 'cuda:0',
# "device": 'cpu',
# 'support_size': 21,
'support_size': 601,
84 changes: 70 additions & 14 deletions lzero/model/gpt_models/world_model.py
@@ -27,6 +27,24 @@
from ding.torch_utils import to_device
# from memory_profiler import profile
from line_profiler import line_profiler
import hashlib
# def quantize_state(state, num_buckets=1000):
def quantize_state(state, num_buckets=15):
# def quantize_state(state, num_buckets=10):
"""
量化状态向量。
参数:
state: 要量化的状态向量。
num_buckets: 量化的桶数。
返回:
量化后的状态向量的哈希值。
"""
# 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中
quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets))
# 使用更稳定的哈希函数
quantized_state_bytes = quantized_state.tobytes()
hash_object = hashlib.sha256(quantized_state_bytes)
return hash_object.hexdigest()
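# Example usage (illustrative sketch, assuming numpy is imported at module level):
#   state_a = np.random.rand(256)
#   state_b = state_a + 1e-4  # a tiny perturbation usually falls into the same buckets
#   quantize_state(state_a) == quantize_state(state_b)  # True unless a value sits on a bucket edge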

@dataclass
class WorldModelOutput:
@@ -242,6 +260,10 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
)
self.hit_count = 0
self.total_query_count = 0
self.length3_context_cnt = 0
self.length2_context_cnt = 0
self.root_hit_cnt = 0
self.root_total_query_cnt = 0



@@ -440,15 +462,28 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent
self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)
for i in range(latent_state.size(0)):  # iterate over each environment
state_single_env = latent_state[i]  # get the latent state of a single environment
cache_key = hash(state_single_env.detach().cpu().numpy())  # compute the hash value
# cache_key = hash(state_single_env.detach().cpu().numpy())  # compute the hash value
quantized_state = state_single_env.detach().cpu().numpy()
cache_key = quantize_state(quantized_state)  # compute the hash from the quantized state
for layer in range(self.num_layers):
self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512])
self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0)
self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size
self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size
# keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0))
self.root_total_query_cnt += 1
if cache_key not in self.past_keys_values_cache:
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
else:
self.root_hit_cnt += 1
root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt
print('root_total_query_cnt:', self.root_total_query_cnt)
print(f'root_hit_ratio:{root_hit_ratio}')
print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}')
if self.past_keys_values_cache[cache_key].size>1:
print(f'=='*20)
print(f'NOTE: root_hit find size > 1')
print(f'=='*20)

elif n == int(256):
# TODO: n=256 means training the tokenizer; no need to compute the target value
@@ -610,13 +645,13 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# But if we assume the environment is an MDP, we can simply look up the current latest_state s_t in this list
# TODO: but if the environment is assumed to be non-MDP, do we need to maintain a list of {(root state_action_history: kv_cache)}?

# if self.total_query_count>0 and self.total_query_count%99999==0:
# # if self.total_query_count>0 and self.total_query_count%1==0:
# self.hit_freq = self.hit_count/(self.total_query_count)
# print('hit_freq:', self.hit_freq)
# print('hit_count:', self.hit_count)
# print('total_query_count:', self.total_query_count)
# print(self.keys_values_wm_size_list)
if self.total_query_count>0 and self.total_query_count%5000==0:
# if self.total_query_count>0 and self.total_query_count%1==0:
self.hit_freq = self.hit_count/(self.total_query_count)
# print('hit_freq:', self.hit_freq)
# print('hit_count:', self.hit_count)
print('total_query_count:', self.total_query_count)


latest_state = state_action_history[-1][0]

@@ -627,8 +662,8 @@
for i in range(ready_env_num):
self.total_query_count += 1
state_single_env = latest_state[i]  # get the latent state of a single environment
hash_latest_state = hash(state_single_env)  # compute the hash value
matched_value = self.past_keys_values_cache.get(hash_latest_state)  # look up the cached value
cache_key = quantize_state(state_single_env)  # compute the hash from the quantized state
matched_value = self.past_keys_values_cache.get(cache_key)  # look up the cached value
if matched_value is not None:
# if a matching value is found, add it to the list
self.hit_count += 1
@@ -645,16 +680,34 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1
output_sequence, latent_state = [], []

# print(self.keys_values_wm_size_list)
reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens]
self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices)

# reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens]
# self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices)

action = state_action_history[-1][-1]
token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long)
token = token.reshape(-1, 1).to(self.device) # (B, 1)

# print(self.keys_values_wm_size_list)
# get the minimum value of self.keys_values_wm_size_list
min_size = min(self.keys_values_wm_size_list)
if min_size >= self.config.max_tokens - 5:
self.length3_context_cnt += len(self.keys_values_wm_size_list)
if min_size >= 3:
self.length2_context_cnt += len(self.keys_values_wm_size_list)
# if max(self.keys_values_wm_size_list) == 7:
# print('max(self.keys_values_wm_size_list) == 7')
# if self.total_query_count>0 and self.total_query_count%1==0:
if self.total_query_count>0 and self.total_query_count%5000==0:
# if the total query count is greater than 0, compute and print the counter ratios
length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count
print('>=3 node context_cnt:', self.length3_context_cnt)
print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio)
length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count
print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio)
print('>=2 node context_cnt:', self.length2_context_cnt)
# print(self.keys_values_wm_size_list)

for layer in range(self.num_layers):
# per-layer lists of k and v caches
kv_cache_k_list = []
@@ -730,7 +783,9 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# # TODO: after the computation finishes, should the latest cache be updated? Is a deepcopy needed?
for i in range(self.latent_state.size(0)):  # iterate over each environment
state_single_env = self.latent_state[i]  # get the latent state of a single environment
cache_key = hash(state_single_env.detach().cpu().numpy())  # compute the hash value
# cache_key = hash(state_single_env.detach().cpu().numpy())  # compute the hash value
quantized_state = state_single_env.detach().cpu().numpy()
cache_key = quantize_state(quantized_state)  # compute the hash from the quantized state
# copy the keys_values_wm corresponding to a single environment and store it
for layer in range(self.num_layers):
self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512])
@@ -754,6 +809,7 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# TODO: lru_cache
_, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
del popped_kv_cache  # do not keep this line
# print('len(self.past_keys_values_cache) > self.max_cache_size')

# Example usage:
# Assuming `past_keys_values_cache` is a populated instance of `KeysValues`
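Taken together, the world_model.py changes key the per-environment KV cache on a quantized hash of the latent state, track root and recurrent-inference hit counts plus context-length statistics, and evict the oldest entry once max_cache_size is exceeded. The following sketch reproduces that bookkeeping pattern in isolation, assuming the quantize_state helper defined above; QuantizedKVCache, its fields, and the default size are illustrative and not the commit's actual API.

from collections import OrderedDict

class QuantizedKVCache:
    # Minimal sketch of the cache bookkeeping pattern above (illustrative only).

    def __init__(self, max_cache_size=500):  # default size is an assumption
        self.cache = OrderedDict()  # quantized-state hash -> cached keys/values
        self.max_cache_size = max_cache_size
        self.hit_count = 0
        self.total_query_count = 0

    def get(self, state):
        # Look up a cached entry by the quantized hash of the latent state.
        self.total_query_count += 1
        key = quantize_state(state)  # same hashing as in the diff above
        entry = self.cache.get(key)
        if entry is not None:
            self.hit_count += 1
        return entry

    def put(self, state, kv_entry):
        # Insert a new entry; evict the oldest one once the cache is full (FIFO, as above).
        key = quantize_state(state)
        if key not in self.cache:
            self.cache[key] = kv_entry
        if len(self.cache) > self.max_cache_size:
            self.cache.popitem(last=False)

    @property
    def hit_ratio(self):
        return self.hit_count / self.total_query_count if self.total_query_count else 0.0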