
Commit 73e449d

Merge pull request #2777 from AI-Hypercomputer:nicogrande/fix-kvcache-unpacking
PiperOrigin-RevId: 839806334
2 parents: a0d9ab3 + 5499607

14 files changed (+85 lines, -3 lines)

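Every hunk below applies the same guard at the top of a decoder layer's `__call__`: when layers are chained, the previous layer may hand back a `(hidden_states, kv_cache)` tuple instead of a bare activation array, so each layer first unwraps the hidden states before applying sharding constraints and checkpoint naming. A minimal sketch of the pattern, using illustrative `layer_a`/`layer_b` functions rather than the actual MaxText modules:

```python
import jax.numpy as jnp

def layer_a(inputs, kv_cache=None):
  # Stand-in for a decoder layer that returns (hidden_states, kv_cache).
  hidden_states = inputs * 2.0
  return hidden_states, kv_cache

def layer_b(inputs, kv_cache=None):
  # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)).
  if isinstance(inputs, tuple):
    inputs = inputs[0]
  return inputs + 1.0, kv_cache

x = jnp.ones((2, 4, 8))            # [batch, length, emb_dim]
out, cache = layer_b(layer_a(x))   # layer_b receives a (hidden_states, kv_cache) tuple and unwraps it
print(out.shape)                   # (2, 4, 8)
```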

src/MaxText/layers/deepseek.py (7 additions, 0 deletions)

@@ -171,6 +171,9 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -240,6 +243,10 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

src/MaxText/layers/deepseek_batchsplit.py (3 additions, 0 deletions)

@@ -61,6 +61,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")

src/MaxText/layers/gemma.py (3 additions, 0 deletions)

@@ -132,6 +132,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]

src/MaxText/layers/gemma2.py (3 additions, 0 deletions)

@@ -226,6 +226,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]

src/MaxText/layers/gemma3.py (3 additions, 0 deletions)

@@ -193,6 +193,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

src/MaxText/layers/gpt3.py (4 additions, 0 deletions)

@@ -430,6 +430,10 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
+
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx = self.pre_self_attention_norm(inputs)

src/MaxText/layers/gpt_oss.py (3 additions, 0 deletions)

@@ -149,6 +149,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]

     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")

src/MaxText/layers/llama2.py (3 additions, 0 deletions)

@@ -152,6 +152,9 @@ def __call__(
   ):
     cfg = self.config

+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.activation_axis_names))

src/MaxText/layers/llama4.py (3 additions, 0 deletions)

@@ -454,6 +454,9 @@ def __call__(
     cfg = self.config
     assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."

+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")

src/MaxText/layers/mistral.py (3 additions, 1 deletion)

@@ -135,9 +135,11 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
-
     cfg = self.config

+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx = self.pre_self_attention_layer_norm(inputs)
