diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py
index 45f77ce494..b494d559bd 100644
--- a/keras_nlp/layers/modeling/rotary_embedding.py
+++ b/keras_nlp/layers/modeling/rotary_embedding.py
@@ -85,30 +85,42 @@ def __init__(
         self.built = True
 
     def call(self, inputs, start_index=0):
+        inputs = ops.moveaxis(
+            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
+        )
         cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
-        return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        return ops.moveaxis(
+            output, (-1, 1), (self.feature_axis, self.sequence_axis)
+        )
 
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
-        half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
+        x1, x2 = ops.split(tensor, 2, axis=-1)
+        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
+        # compilation on jax. We should be able to remove this once the
+        # following PR is in all jax releases we care about:
+        # https://github.com/openxla/xla/pull/7875
+        half_rot_tensor = ops.stack((-x2, x1), axis=-2)
+        half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)
 
     def _compute_cos_sin_embedding(self, inputs, start_index=0):
-        def get_axis(axis):
-            return axis if axis > 0 else len(inputs.shape) + axis
+        start_index = ops.cast(start_index, dtype="float32")
 
-        feature_axis = get_axis(self.feature_axis)
-        sequence_axis = get_axis(self.sequence_axis)
+        feature_axis = len(inputs.shape) - 1
+        sequence_axis = 1
 
         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)
 
-        seq_len = ops.shape(inputs)[self.sequence_axis]
-        tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
+        seq_len = ops.shape(inputs)[sequence_axis]
+        tensor = ops.arange(seq_len, dtype="float32") + start_index
 
-        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
         freq = ops.einsum("i,j->ij", tensor, inverse_freq)
-        embedding = ops.concatenate((freq, freq), axis=-1)
+        embedding = ops.stack((freq, freq), axis=-2)
+        embedding = ops.reshape(
+            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
+        )
 
         # Reshape the embedding to be broadcastable with input shape.
         if feature_axis < sequence_axis:
@@ -117,17 +129,16 @@ def get_axis(axis):
             if axis != sequence_axis and axis != feature_axis:
                 embedding = ops.expand_dims(embedding, axis)
 
-        return ops.cos(embedding), ops.sin(embedding)
+        cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
+        sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
+        return cos_emb, sin_emb
 
     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2)
-        freq_range = ops.cast(freq_range, self.compute_dtype)
-        freq_range = freq_range / ops.cast(
-            self.scaling_factor, self.compute_dtype
-        )
+        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
+        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
         inverse_freq = 1.0 / (
             self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
+            ** (freq_range / ops.cast(rotary_dim, "float32"))
         )
         return inverse_freq
 
diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py
index e01c1f8ce4..4b391264a2 100644
--- a/keras_nlp/models/gemma/gemma_attention.py
+++ b/keras_nlp/models/gemma/gemma_attention.py
@@ -15,6 +15,7 @@
 
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
+from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_nlp.utils.keras_utils import clone_initializer
 
 
@@ -87,28 +88,23 @@ def build(self, inputs_shape):
             (None, None, self.num_query_heads, self.head_dim)
         )
         self.softmax = keras.layers.Softmax(dtype="float32")
+
+        self.rope_layer = RotaryEmbedding(
+            max_wavelength=10_000.0, dtype=self.dtype_policy
+        )
+
         self.built = True
 
-    def _apply_rope(self, x, positions):
+    def _apply_rope(self, x, start_index):
         """Rope rotate q or k."""
-        # TODO: refactor to use RotaryEmbedding layer?
-        max_wavelength = 10000
-        x_shape = ops.shape(x)
-        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
-            x_shape[-1] // 2, dtype="float32"
+        x = self.rope_layer(x, start_index=start_index)
+        # Gemma uses a different layout for positional embeddings.
+        # The transformation below ensures the embeddings are numerically
+        # equivalent to the original gemma implementation.
+        x = ops.reshape(
+            ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
         )
-        timescale = max_wavelength**freq_exponents
-        radians = positions[..., None] / timescale[None, None, :]
-        radians = radians[..., None, :]
-        sin = ops.cast(ops.sin(radians), self.compute_dtype)
-        cos = ops.cast(ops.cos(radians), self.compute_dtype)
-        x1, x2 = ops.split(x, 2, axis=-1)
-        # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
-        # compilation on jax. We should be able to remove this once the
-        # following PR is in all jax releases we care about:
-        # https://github.com/openxla/xla/pull/7875
-        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
-        return ops.reshape(output, x_shape)
+        return x
 
     def _compute_attention(
         self,
@@ -155,19 +151,14 @@ def call(
         cache_update_index=0,
         training=False,
     ):
-        seq_len = ops.shape(x)[1]
-        start_index = cache_update_index
-        positions = ops.arange(seq_len, dtype="float32")
-
-        positions = positions + ops.cast(start_index, "float32")
         query = self.query_dense(x)
-        query = self._apply_rope(query, positions)
+        query = self._apply_rope(query, cache_update_index)
 
         if cache is not None:
             key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            key_update = self.key_dense(x)
-            key_update = self._apply_rope(key_update, positions)
+            key_update = self._apply_rope(key_update, cache_update_index)
            value_update = self.value_dense(x)
            start = [0, cache_update_index, 0, 0]
            key = ops.slice_update(key_cache, start, key_update)
@@ -175,7 +166,7 @@ def call(
             cache = ops.stack((key, value), axis=1)
         else:
             key = self.key_dense(x)
-            key = self._apply_rope(key, positions)
+            key = self._apply_rope(key, cache_update_index)
             value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
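Reviewer note, not part of the patch: the two tensor-layout tricks in this diff are easy to misread, so here is a small standalone sketch, using numpy as a stand-in for `ops` and made-up shapes, that checks (1) the `stack` + `reshape` pair used to sidestep the XLA/jax `concatenate` issue is equivalent to concatenating along the feature axis, and (2) the Gemma-specific shuffle in `_apply_rope` converts the split-half rotary layout into an interleaved one.

```python
import numpy as np

# Illustrative shapes only: (batch, seq_len, num_heads, head_dim).
batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
x = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
x1, x2 = np.split(x, 2, axis=-1)

# (1) Stacking on a new second-to-last axis and reshaping back flattens the
# two halves contiguously, i.e. it matches concatenation on the last axis.
via_concat = np.concatenate((-x2, x1), axis=-1)
via_stack = np.reshape(np.stack((-x2, x1), axis=-2), x.shape)
assert np.array_equal(via_concat, via_stack)

# (2) The Gemma layout shuffle: split-half order [a0..a3, b0..b3] becomes the
# interleaved order [a0, b0, a1, b1, ...] along the feature axis.
interleaved = np.reshape(np.stack(np.split(x, 2, axis=-1), axis=-1), x.shape)
assert np.array_equal(interleaved[..., 0], x[..., 0])
assert np.array_equal(interleaved[..., 1], x[..., head_dim // 2])
```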