@@ -162,22 +162,14 @@ def __call__(
162
162
* (self .max_timescale / self .min_timescale ) ** fraction
163
163
)
164
164
position = position [:, :, jnp .newaxis , jnp .newaxis ]
165
- timescale = timescale [jnp .newaxis , jnp .newaxis , jnp .newaxis , :]
166
165
sinusoid_inp = position / timescale
167
166
sin = jnp .sin (sinusoid_inp )
168
167
cos = jnp .cos (sinusoid_inp )
169
- reshape_tensor = inputs .astype (jnp .float32 ).reshape (
170
- * inputs .shape [:- 1 ], 2 , - 1
171
- )
172
- reshape_tensor = jax .numpy .swapaxes (reshape_tensor , - 1 , - 2 )
173
- first_half = reshape_tensor [..., 0 ]
174
- second_half = reshape_tensor [..., 1 ]
168
+ first_half , second_half = jnp .split (inputs , 2 , axis = - 1 )
175
169
first_part = first_half * cos - second_half * sin
176
170
second_part = second_half * cos + first_half * sin
177
171
if self .cast_as_fprop_dtype :
178
172
first_part = first_part .astype (self .fprop_dtype )
179
173
second_part = second_part .astype (self .fprop_dtype )
180
- x_out = jnp .stack ((first_part , second_part ), axis = - 1 ).reshape (
181
- * first_part .shape [:- 1 ], - 1
182
- )
174
+ x_out = jnp .concatenate ((first_part , second_part ), axis = - 1 )
183
175
return x_out
0 commit comments