Commit 09d2fdd

Keep rope at float32 precision (#1497)
* Keep rope at float32 precision
* Carry out all of RoPE in float32
* Formatting
* Cleanup
* Do not cast x
1 parent a8da424 commit 09d2fdd
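
In short, the change stops doing the rotary-embedding arithmetic in the layer's `compute_dtype`, which is float16 or bfloat16 under mixed precision. As a rough standalone NumPy sketch (illustrative only, not code from the commit), float16 represents integers exactly only up to 2048, so position indices in a longer sequence are already rounded before any rotary angle is computed:

```python
import numpy as np

# Illustrative sketch, not part of the commit: casting position indices to
# float16 loses exactness for positions above 2048.
positions_f32 = np.arange(4096, dtype="float32")
positions_f16 = positions_f32.astype("float16")
mismatched = int(np.sum(positions_f16.astype("float32") != positions_f32))
print(mismatched)  # 1024 of the positions above 2048 are no longer exact
```

Keeping `positions` and the angles in float32, and casting only the final `sin`/`cos` values, sidesteps this while leaving the rest of the layer in the compute dtype.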

File tree

1 file changed: +7 -7 lines changed

keras_nlp/models/gemma/gemma_attention.py

Lines changed: 7 additions & 7 deletions
@@ -94,13 +94,14 @@ def _apply_rope(self, x, positions):
         # TODO: refactor to use RotaryEmbedding layer?
         max_wavelength = 10000
         x_shape = ops.shape(x)
-        freq_exponents = (2.0 / x_shape[-1]) * ops.cast(
-            ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype
+        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
+            x_shape[-1] // 2, dtype="float32"
         )
         timescale = max_wavelength**freq_exponents
         radians = positions[..., None] / timescale[None, None, :]
         radians = radians[..., None, :]
-        sin, cos = ops.sin(radians), ops.cos(radians)
+        sin = ops.cast(ops.sin(radians), self.compute_dtype)
+        cos = ops.cast(ops.cos(radians), self.compute_dtype)
         x1, x2 = ops.split(x, 2, axis=-1)
         # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
         # compilation on jax. We should be able to remove this once the
@@ -156,10 +157,9 @@ def call(
     ):
         seq_len = ops.shape(x)[1]
         start_index = cache_update_index
-        positions = ops.cast(
-            ops.arange(seq_len, dtype="float32"), self.compute_dtype
-        )
-        positions = positions + ops.cast(start_index, self.compute_dtype)
+        positions = ops.arange(seq_len, dtype="float32")
+
+        positions = positions + ops.cast(start_index, "float32")
         query = self.query_dense(x)
         query = self._apply_rope(query, positions)
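
The same concern applies to the angles inside `_apply_rope`: `radians = positions / timescale` grows large for the small-timescale (fast-rotating) dimensions, and half precision cannot resolve large angles. A rough NumPy sketch of the effect follows; `head_dim` and `position` are made-up illustrative values, not the Gemma configuration:

```python
import numpy as np

# Illustrative sketch of why the RoPE angles are kept in float32.
max_wavelength = 10000
head_dim = 256   # made-up head size
position = 8000  # a late position in a long sequence

freq_exponents = (2.0 / head_dim) * np.arange(head_dim // 2, dtype="float32")
timescale = np.float32(max_wavelength) ** freq_exponents

radians_f32 = np.float32(position) / timescale                    # float32 angles
radians_f16 = np.float16(position) / timescale.astype("float16")  # float16 angles

# Near 8000 radians the float16 spacing is 4.0, so sin/cos of the
# half-precision angles are essentially noise for the fastest frequencies.
err = np.abs(np.sin(radians_f32) - np.sin(radians_f16.astype("float32")))
print("max |sin| error:", float(err.max()))

# The pattern adopted by the diff: compute the angle in float32, then cast
# only the result down to the compute dtype.
sin_half = np.sin(radians_f32).astype("float16")
```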

0 commit comments
