Skip to content

Commit

Permalink
Update embeddings.py
Browse files Browse the repository at this point in the history
  • Loading branch information
prrathi authored Apr 4, 2024
1 parent de49d83 commit 1120463
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions MaxText/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def __call__(
)
position = position[:, :, jnp.newaxis, jnp.newaxis]
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
sin = jnp.sin(sinusoid_inp).astype(inputs.dtype)
cos = jnp.cos(sinusoid_inp).astype(inputs.dtype)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
Expand Down Expand Up @@ -198,4 +198,4 @@ def __call__(
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = -1)
# signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]])
position_embedding = signal.astype(jnp.float32)
return input_embedding + position_embedding
return input_embedding + position_embedding

0 comments on commit 1120463

Please sign in to comment.