@@ -112,7 +112,7 @@ def call(self, x, last_cache_x=None, train_mode=True):
             Mixed output tensor.
         """
         xx = self.time_shift(x, last_cache_x) - x
-        if last_cache_x is not None or not train_mode:
+        if not train_mode and last_cache_x is not None:
             last_cache_x = x[:, -1:]
         k = x + xx * self.x_k
         k = ops.relu(self.key(k)) ** 2
@@ -390,7 +390,7 @@ def call(
         B, T, C = ops.shape(x)
         H = self.n_head
         xx = self.time_shift(x, last_cache_x) - x
-        if last_cache_x is not None or not train_mode:
+        if not train_mode and last_cache_x is not None:
             last_cache_x = x[:, -1:]
         if padding_mask is not None:
             xx *= padding_mask
@@ -472,13 +472,14 @@ def reshape_and_cast(x, new_shape, dtype="float32"):
         x = x + ops.reshape(rwkv, (B, T, C))
         x = self.output_layer(x * g)
         if train_mode:
-            return x, v_first
+            return x, v_first, finnal_state
         return x, v_first, last_cache_x, finnal_state

     def compute_output_shape(self, input_shape):
         output_shapes = [
             [None, None, self.hidden_size],
             [None, None, self.hidden_size],
+            [None, self.n_head, self.head_size, self.head_size],
         ]
         return output_shapes

@@ -621,7 +622,7 @@ def call(
         if self.use_initial_norm:
             x = self.ln0(x)
         if train_mode:
-            xx, v_first = self.att(
+            xx, v_first, state = self.att(
                 self.ln1(x),
                 v_first=v_first,
                 padding_mask=padding_mask,
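
For callers, here is a minimal runnable sketch of the new return contract. It uses NumPy and a hypothetical stand-in function (`toy_att` is not the layer's real implementation, and the head counts are illustrative): in train mode the attention block now returns the per-head state as a third output, shaped `(batch, n_head, head_size, head_size)`, matching the extra entry added to `compute_output_shape`, while the inference path still returns the shifted-token cache as well.

```python
import numpy as np

def toy_att(x, v_first, train_mode=True, last_cache_x=None):
    """Hypothetical stand-in mirroring the updated return signature."""
    B, T, C = x.shape
    n_head, head_size = 2, C // 2  # illustrative values, n_head * head_size == C
    state = np.zeros((B, n_head, head_size, head_size), dtype="float32")
    if train_mode:
        # New behavior: the state is now returned in train mode as well.
        return x, v_first, state
    # Inference path additionally returns the cached last token.
    return x, v_first, x[:, -1:], state

x = np.ones((1, 4, 8), dtype="float32")
out, v_first, state = toy_att(x, v_first=x, train_mode=True)
print(out.shape, state.shape)  # (1, 4, 8) (1, 2, 4, 4)
```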