fix(pu): fix kv_cache, use one env in recurrent_inference() in search
puyuan1996 committed Feb 22, 2024
1 parent 7f80343 commit 0ad129e
Showing 8 changed files with 2,171 additions and 146 deletions.
6 changes: 1 addition & 5 deletions lzero/model/gpt_models/cfg_atari.py
@@ -11,8 +11,7 @@
'embed_dim': 128, # z_channels
# 'embed_dim': 1024, # z_channels
# 'embed_dim': 256, # z_channels

'encoder':
'encoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug
@@ -61,7 +60,6 @@
'embed_dim':1024, # TODO:for atari
# 'embed_dim':256, # TODO:for atari


'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
# 'num_layers': 2, # TODO:for atari debug
@@ -93,11 +91,9 @@
# "env_num":1, # TODO
'latent_recon_loss_weight':0.05,
'perceptual_loss_weight':0.05,

# 'latent_recon_loss_weight':0.,
# 'perceptual_loss_weight':0.,


}
from easydict import EasyDict
cfg = EasyDict(cfg)
6 changes: 5 additions & 1 deletion lzero/model/gpt_models/slicer.py
@@ -39,7 +39,11 @@ def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Mo
self.head_module = head_module

def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E)
if isinstance(prev_steps, torch.Tensor):
x_sliced = [x[i, self.compute_slice(num_steps, prev_steps[i].item())] for i in range(prev_steps.shape[0])]
x_sliced = torch.cat(x_sliced, dim=0)
elif isinstance(prev_steps, int):
x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E)
return self.head_module(x_sliced)


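The slicer.py change lets prev_steps arrive either as a single int or as a per-sample tensor. Below is a minimal standalone sketch of the two branches, assuming a hypothetical compute_slice stub (the repo's version depends on block_mask), to show how the shapes come out:

import torch

def compute_slice(num_steps: int, prev_steps: int) -> torch.Tensor:
    # stand-in for Slicer.compute_slice: just take the first num_steps
    # positions; the repo's version depends on block_mask and prev_steps
    return torch.arange(num_steps)

def head_forward(x: torch.Tensor, num_steps: int, prev_steps):
    if isinstance(prev_steps, torch.Tensor):
        # per-sample slice, concatenated along dim 0
        x_sliced = torch.cat(
            [x[i, compute_slice(num_steps, int(prev_steps[i]))] for i in range(prev_steps.shape[0])],
            dim=0,
        )
    else:
        x_sliced = x[:, compute_slice(num_steps, prev_steps)]  # x is (B, T, E)
    return x_sliced

x = torch.randn(4, 6, 8)                                     # (B, T, E)
print(head_forward(x, 2, 3).shape)                           # torch.Size([4, 2, 8])
print(head_forward(x, 2, torch.tensor([0, 1, 2, 3])).shape)  # torch.Size([8, 8])

Note that the tensor branch concatenates 2-D per-sample slices along dim 0, so its output rank differs from the int branch; the downstream head_module has to accept both shapes.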
17 changes: 8 additions & 9 deletions lzero/model/gpt_models/transformer.py
@@ -145,24 +145,23 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.
kv_cache.update(k, v)
k, v = kv_cache.get()

# method1: manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v

# TODO
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)

# method1: efficient attention using Flash Attention CUDA kernels
# method2: efficient attention using Flash Attention CUDA kernels
# attn_mask = self.mask[L:L + T, :L + T].bool() # assuming your mask is a ByteTensor
# eval performance is very poor and inconsistent with collect
# y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.attn_pdrop if self.training else 0, is_causal=True)
# test
# y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.attn_pdrop, is_causal=True)

# method2: manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v


y = rearrange(y, 'b h t e -> b t (h e)')

y = self.resid_drop(self.proj(y))
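The transformer.py change keeps the manual attention path as the one actually executed when reading from the KV cache. A self-contained sketch (shapes are illustrative, not the repo's config) of how the causal mask is sliced once L steps are already cached and T new steps arrive:

import math
import torch
import torch.nn.functional as F

B, nh, hs = 2, 4, 16                   # batch, heads, head size (illustrative)
L, T = 5, 3                            # cached steps, new steps
mask = torch.tril(torch.ones(32, 32))  # full causal mask

q = torch.randn(B, nh, T, hs)          # queries only for the new tokens
k = torch.randn(B, nh, L + T, hs)      # keys include the cached prefix
v = torch.randn(B, nh, L + T, hs)      # values include the cached prefix

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, L+T)
att = att.masked_fill(mask[L:L + T, :L + T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v                                                       # (B, nh, T, hs)
print(y.shape)                                                    # torch.Size([2, 4, 3, 16])

Because only the T new tokens are queried while keys and values cover all L + T cached-plus-new steps, rows L..L+T-1 of the causal mask are taken against all L + T columns, which is exactly the self.mask[L:L + T, :L + T] indexing in the diff.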
230 changes: 111 additions & 119 deletions lzero/model/gpt_models/world_model.py

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions lzero/model/gpt_models/world_model_envnum1_kv-latent-1-env.py
@@ -651,11 +651,11 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
# But if we assume the environment is an MDP, it suffices to look up the current latest_state s_t in this list
# TODO: but if the environment is assumed to be non-MDP, do we need to maintain a list of {(rootstate_action_history: kv_cache)} entries?

if self.total_query_count>0:
self.hit_freq = self.hit_count/self.total_query_count
print('hit_freq:', self.hit_freq)
print('hit_count:', self.hit_count)
print('total_query_count:', self.total_query_count)
# if self.total_query_count>0:
# self.hit_freq = self.hit_count/self.total_query_count
# print('hit_freq:', self.hit_freq)
# print('hit_count:', self.hit_count)
# print('total_query_count:', self.total_query_count)


latest_state = state_action_history[-1][0]
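The counters commented out above relate to querying a KV cache keyed by the latest latent state, under the MDP assumption discussed in the comments at the top of this hunk. A hedged sketch of that idea (not the repo's actual data structure), with the same hit_count / total_query_count bookkeeping:

import torch
from typing import Dict, Optional

class LatentKVCacheStore:
    # Toy store mapping a latent state to its KV cache, with hit statistics.

    def __init__(self) -> None:
        self._store: Dict[bytes, object] = {}
        self.hit_count = 0
        self.total_query_count = 0

    @staticmethod
    def _key(latent_state: torch.Tensor) -> bytes:
        # hashable key derived from the latent state tensor
        return latent_state.detach().cpu().numpy().tobytes()

    def query(self, latent_state: torch.Tensor) -> Optional[object]:
        self.total_query_count += 1
        kv = self._store.get(self._key(latent_state))
        if kv is not None:
            self.hit_count += 1
        return kv

    def insert(self, latent_state: torch.Tensor, kv_cache: object) -> None:
        self._store[self._key(latent_state)] = kv_cache

    @property
    def hit_freq(self) -> float:
        return self.hit_count / self.total_query_count if self.total_query_count else 0.0

For a non-MDP environment (the TODO above), the key would instead have to encode the whole state-action history rather than only the latest latent state.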
