
Commit

Merge pull request #1103 from AI-Hypercomputer:fix_gmm_acc_dtype
PiperOrigin-RevId: 708032400
maxtext authors committed Dec 19, 2024
2 parents 4651cb3 + d56d840 commit 6c767f4
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion MaxText/kernels/megablox/gmm.py
@@ -476,7 +476,11 @@ def _accum(is_last_k_tile):
     else:
       loaded_rhs = mask_k_rem_rhs(rhs[...]).astype(input_dtype)

-    acc_scratch[...] += aqt_pl.dot_general(
+    is_quantized = lhs_quantize_dtype or rhs_quantize_dtype
+    # aqt_pl.dot_general does not handle the accumulation dtype correctly when
+    # neither lhs nor rhs is quantized, so fall back to jax.lax.dot_general.
+    dot_general = aqt_pl.dot_general if is_quantized else jax.lax.dot_general
+    acc_scratch[...] += dot_general(
       loaded_lhs,
       loaded_rhs,
       preferred_element_type=jnp.float32,
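
For context, the workaround relies on jax.lax.dot_general honouring preferred_element_type, so low-precision inputs are still accumulated in float32. The standalone sketch below is illustrative only and is not taken from gmm.py; the shapes and dimension numbers are assumptions made for the example.

    # Illustrative sketch: bf16 inputs, float32 accumulation/output via
    # preferred_element_type, mirroring the call in the diff above.
    import jax
    import jax.numpy as jnp

    lhs = jnp.ones((128, 256), dtype=jnp.bfloat16)
    rhs = jnp.ones((256, 64), dtype=jnp.bfloat16)

    # Contract the last dim of lhs with the first dim of rhs; no batch dims.
    dims = (((1,), (0,)), ((), ()))

    out = jax.lax.dot_general(lhs, rhs, dims, preferred_element_type=jnp.float32)
    print(out.dtype)  # float32, despite the bfloat16 operands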
