From 3d40750504556caf4f0765aca43bc582cc699885 Mon Sep 17 00:00:00 2001
From: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
Date: Wed, 13 Mar 2024 15:50:00 +0000
Subject: [PATCH] Do not cast x

---
 keras_nlp/models/gemma/gemma_attention.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py
index f14ff79caf..e01c1f8ce4 100644
--- a/keras_nlp/models/gemma/gemma_attention.py
+++ b/keras_nlp/models/gemma/gemma_attention.py
@@ -92,9 +92,6 @@ def build(self, inputs_shape):
     def _apply_rope(self, x, positions):
         """Rope rotate q or k."""
         # TODO: refactor to use RotaryEmbedding layer?
-        x = ops.cast(
-            x, dtype="float32"
-        )  # Carry out rope in float32, then downcast
         max_wavelength = 10000
         x_shape = ops.shape(x)
         freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
@@ -103,14 +100,14 @@ def _apply_rope(self, x, positions):
         timescale = max_wavelength**freq_exponents
         radians = positions[..., None] / timescale[None, None, :]
         radians = radians[..., None, :]
-        sin, cos = ops.sin(radians), ops.cos(radians)
+        sin = ops.cast(ops.sin(radians), self.compute_dtype)
+        cos = ops.cast(ops.cos(radians), self.compute_dtype)
         x1, x2 = ops.split(x, 2, axis=-1)
         # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
         # compilation on jax. We should be able to remove this once the
         # following PR is in all jax releases we care about:
         # https://github.com/openxla/xla/pull/7875
         output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
-        output = ops.cast(output, dtype=self.compute_dtype)
         return ops.reshape(output, x_shape)
 
     def _compute_attention(
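
Note (not part of the patch): a minimal standalone sketch of the rotary embedding
math as it reads after this change, assuming the Keras 3 `keras.ops` API. The
angles are still computed in float32; only sin/cos are cast down to the compute
dtype before the rotation, instead of upcasting `x` to float32 and downcasting
the result. The function name `apply_rope_sketch`, the `compute_dtype` argument,
and the shapes in the usage lines are illustrative assumptions, not taken from
the source file.

    import numpy as np
    from keras import ops

    def apply_rope_sketch(x, positions, compute_dtype="bfloat16", max_wavelength=10000):
        # x: (batch, seq_len, num_heads, head_dim), already in compute_dtype.
        # positions: float token positions of shape (batch, seq_len).
        x_shape = ops.shape(x)
        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
            x_shape[-1] // 2, dtype="float32"
        )
        timescale = max_wavelength**freq_exponents
        # Angles are built in float32 for accuracy at large positions.
        radians = positions[..., None] / timescale[None, None, :]
        radians = radians[..., None, :]
        # Cast sin/cos down so the rotation itself runs in compute_dtype.
        sin = ops.cast(ops.sin(radians), compute_dtype)
        cos = ops.cast(ops.cos(radians), compute_dtype)
        x1, x2 = ops.split(x, 2, axis=-1)
        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
        return ops.reshape(output, x_shape)

    # Illustrative usage: batch=1, seq_len=4, num_heads=2, head_dim=8.
    x = ops.cast(np.random.rand(1, 4, 2, 8), "bfloat16")
    positions = ops.arange(4, dtype="float32")[None, :]
    out = apply_rope_sketch(x, positions, compute_dtype="bfloat16")  # shape (1, 4, 2, 8)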