@@ -94,13 +94,14 @@ def _apply_rope(self, x, positions):
94
94
# TODO: refactor to use RotaryEmbedding layer?
95
95
max_wavelength = 10000
96
96
x_shape = ops .shape (x )
97
- freq_exponents = (2.0 / x_shape [- 1 ]) * ops .cast (
98
- ops . arange ( x_shape [- 1 ] // 2 , dtype = "float32" ), self . compute_dtype
97
+ freq_exponents = (2.0 / x_shape [- 1 ]) * ops .arange (
98
+ x_shape [- 1 ] // 2 , dtype = "float32"
99
99
)
100
100
timescale = max_wavelength ** freq_exponents
101
101
radians = positions [..., None ] / timescale [None , None , :]
102
102
radians = radians [..., None , :]
103
- sin , cos = ops .sin (radians ), ops .cos (radians )
103
+ sin = ops .cast (ops .sin (radians ), self .compute_dtype )
104
+ cos = ops .cast (ops .cos (radians ), self .compute_dtype )
104
105
x1 , x2 = ops .split (x , 2 , axis = - 1 )
105
106
# Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
106
107
# compilation on jax. We should be able to remove this once the
@@ -156,10 +157,9 @@ def call(
156
157
):
157
158
seq_len = ops .shape (x )[1 ]
158
159
start_index = cache_update_index
159
- positions = ops .cast (
160
- ops .arange (seq_len , dtype = "float32" ), self .compute_dtype
161
- )
162
- positions = positions + ops .cast (start_index , self .compute_dtype )
160
+ positions = ops .arange (seq_len , dtype = "float32" )
161
+
162
+ positions = positions + ops .cast (start_index , "float32" )
163
163
query = self .query_dense (x )
164
164
query = self ._apply_rope (query , positions )
165
165
0 commit comments