diff --git a/fla/layers/rwkv7.py b/fla/layers/rwkv7.py index 689d6b7ed..6edad23d6 100644 --- a/fla/layers/rwkv7.py +++ b/fla/layers/rwkv7.py @@ -150,7 +150,7 @@ def forward( v = self.v_proj(xv) if self.layer_idx == 0: - v_first.copy_(v) + v_first = v else: v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid()) a = self.a_lora(xa).sigmoid() @@ -162,7 +162,7 @@ def forward( # dealing with left-padding if attention_mask is not None: - v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + v = v * attention_mask[:, -v.shape[-2]:, None] r, w, k, v, kk, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (r, w, k, v, kk, a)) recurrent_state = last_state['recurrent_state'] if last_state is not None else None