Do not cast x
grasskin committed Mar 13, 2024
1 parent 854d7b4 commit 3d40750
Showing 1 changed file with 2 additions and 5 deletions.
keras_nlp/models/gemma/gemma_attention.py (7 changes: 2 additions & 5 deletions)
@@ -92,9 +92,6 @@ def build(self, inputs_shape):
     def _apply_rope(self, x, positions):
         """Rope rotate q or k."""
         # TODO: refactor to use RotaryEmbedding layer?
-        x = ops.cast(
-            x, dtype="float32"
-        )  # Carry out rope in float32, then downcast
         max_wavelength = 10000
         x_shape = ops.shape(x)
         freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
@@ -103,14 +100,14 @@ def _apply_rope(self, x, positions):
         timescale = max_wavelength**freq_exponents
         radians = positions[..., None] / timescale[None, None, :]
         radians = radians[..., None, :]
-        sin, cos = ops.sin(radians), ops.cos(radians)
+        sin = ops.cast(ops.sin(radians), self.compute_dtype)
+        cos = ops.cast(ops.cos(radians), self.compute_dtype)
         x1, x2 = ops.split(x, 2, axis=-1)
         # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
         # compilation on jax. We should be able to remove this once the
         # following PR is in all jax releases we care about:
         # https://github.com/openxla/xla/pull/7875
         output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
-        output = ops.cast(output, dtype=self.compute_dtype)
         return ops.reshape(output, x_shape)
 
     def _compute_attention(
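
For context, below is a minimal standalone sketch of the rotary-embedding pattern after this change, written against keras.ops. The free function apply_rope, its compute_dtype and max_wavelength parameters, and the ops.arange arguments (which fall in the collapsed lines between the two hunks) are illustrative assumptions, not the exact layer code; the point it demonstrates is that only sin and cos are computed in float32 and cast down, while x keeps its dtype throughout.

import numpy as np
from keras import ops


def apply_rope(x, positions, compute_dtype="float16", max_wavelength=10000):
    """Sketch of RoPE that does not cast x (dtype handling as in this commit)."""
    # x: (batch, seq_len, num_heads, head_dim); positions: (batch, seq_len).
    x_shape = ops.shape(x)
    # Assumed arange arguments: half the head dimension, computed in float32.
    freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
        x_shape[-1] // 2, dtype="float32"
    )
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    # sin/cos are built in float32 and then cast to the compute dtype;
    # x itself is never upcast to float32 or downcast at the end.
    sin = ops.cast(ops.sin(radians), compute_dtype)
    cos = ops.cast(ops.cos(radians), compute_dtype)
    x1, x2 = ops.split(x, 2, axis=-1)
    output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return ops.reshape(output, x_shape)


# Usage: a float16 input stays float16 end to end.
x = ops.cast(np.random.rand(1, 8, 4, 16), "float16")
positions = ops.expand_dims(ops.arange(8, dtype="float32"), 0)
print(apply_rope(x, positions).dtype)  # float16

Compared with the previous version, which upcast x to float32 and cast the output back at the end, this keeps x in its compute dtype and only pays for the float32 sin/cos tables.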
