Skip to content

Commit c2afdde

Browse files
committed
debug for remat
1 parent b2b1573 commit c2afdde

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

keras_hub/src/models/rwkv7/rwkv7_layer.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ def call(self, x, last_cache_x=None, train_mode=True):
112112
Mixed output tensor.
113113
"""
114114
xx = self.time_shift(x, last_cache_x) - x
115-
if last_cache_x is not None or not train_mode:
115+
if not train_mode and last_cache_x is not None:
116116
last_cache_x = x[:, -1:]
117117
k = x + xx * self.x_k
118118
k = ops.relu(self.key(k)) ** 2
@@ -390,7 +390,7 @@ def call(
390390
B, T, C = ops.shape(x)
391391
H = self.n_head
392392
xx = self.time_shift(x, last_cache_x) - x
393-
if last_cache_x is not None or not train_mode:
393+
if not train_mode and last_cache_x is not None:
394394
last_cache_x = x[:, -1:]
395395
if padding_mask is not None:
396396
xx *= padding_mask
@@ -472,13 +472,14 @@ def reshape_and_cast(x, new_shape, dtype="float32"):
472472
x = x + ops.reshape(rwkv, (B, T, C))
473473
x = self.output_layer(x * g)
474474
if train_mode:
475-
return x, v_first
475+
return x, v_first, finnal_state
476476
return x, v_first, last_cache_x, finnal_state
477477

478478
def compute_output_shape(self, input_shape):
479479
output_shapes = [
480480
[None, None, self.hidden_size],
481481
[None, None, self.hidden_size],
482+
[None, self.n_head, self.head_size, self.head_size],
482483
]
483484
return output_shapes
484485

@@ -621,7 +622,7 @@ def call(
621622
if self.use_initial_norm:
622623
x = self.ln0(x)
623624
if train_mode:
624-
xx, v_first = self.att(
625+
xx, v_first, state = self.att(
625626
self.ln1(x),
626627
v_first=v_first,
627628
padding_mask=padding_mask,

0 commit comments

Comments
 (0)