diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..28f4e071fc 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -297,6 +297,9 @@ from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( Gemma3VisionEncoder as Gemma3VisionEncoder, ) +from keras_hub.src.models.gemma3n.gemma3n_backbone import ( + Gemma3nBackbone as Gemma3nBackbone, +) from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone from keras_hub.src.models.gpt2.gpt2_causal_lm import ( GPT2CausalLM as GPT2CausalLM, diff --git a/keras_hub/src/models/gemma3n/__init__.py b/keras_hub/src/models/gemma3n/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/gemma3n/gemma3n_attention.py b/keras_hub/src/models/gemma3n/gemma3n_attention.py new file mode 100644 index 0000000000..dc1adaadff --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_attention.py @@ -0,0 +1,605 @@ +import math + +import keras +import numpy as np + +from keras_hub.src.models.gemma3n.gemma3n_utils import apply_rotary_pos_emb +from keras_hub.src.models.gemma3n.gemma3n_utils import eager_attention_forward +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nAudioRelativePositionEmbedding(keras.layers.Layer): + """A layer for learning relative position embeddings for audio sequences. + + This layer implements the relative position embedding mechanism used in the + audio tower of the Gemma3n model. It computes position-aware attention + scores by generating a timing signal based on relative positions between + queries and keys, which is then projected and added to the content-based + attention logits. + + Args: + hidden_size: int. The size of the hidden state. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_context_left: int. The number of steps to attend to in + the past, including the current step. + conf_attention_context_right: int. The number of steps to attend to in + the future. 
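+
+    Example (an illustrative sketch; the sizes below are arbitrary and
+    much smaller than the real Gemma3n audio configuration):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.gemma3n.gemma3n_attention import (
+        Gemma3nAudioRelativePositionEmbedding,
+    )
+
+    layer = Gemma3nAudioRelativePositionEmbedding(
+        hidden_size=64,
+        conf_num_attention_heads=4,
+        conf_attention_context_left=3,
+        conf_attention_context_right=2,
+    )
+    # queries: (batch, num_query_blocks, query_block_size, heads, head_dim)
+    queries = keras.random.uniform((1, 1, 4, 4, 16))
+    # keys: (batch, num_query_blocks, key_context_size, heads, head_dim)
+    keys = keras.random.uniform((1, 1, 8, 4, 16))
+    scores = layer(queries, keys)  # (1, 4, 1, 4, 8)
+    ```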
+ """ + + def __init__( + self, + hidden_size, + conf_num_attention_heads, + conf_attention_context_left, + conf_attention_context_right, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_context_right = conf_attention_context_right + self.num_heads = conf_num_attention_heads + self.channels = hidden_size + self.head_dim = self.channels // self.num_heads + self.max_backward = max(0, conf_attention_context_left - 1) + self.max_forward = conf_attention_context_right + self.pos_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="pos_proj", + dtype=self.dtype_policy, + ) + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * np.exp( + np.arange(num_timescales, dtype="float32") + * -log_timescale_increment + ) + self.inv_timescales = keras.ops.expand_dims( + keras.ops.expand_dims( + keras.ops.convert_to_tensor(inv_timescales), 0 + ), + 0, + ) + + def build(self, input_shape): + self.pos_proj.build((None, self.channels)) + super().build(input_shape) + + def _get_timing_signal_1d_pos(self, position, dtype): + position = keras.ops.cast( + keras.ops.expand_dims(position, axis=-1), "float32" + ) + scaled_time = position * keras.ops.cast(self.inv_timescales, "float32") + timing_signal = keras.ops.concatenate( + [keras.ops.sin(scaled_time), keras.ops.cos(scaled_time)], axis=-1 + ) + return keras.ops.cast(timing_signal, dtype) + + def _relative_shift( + self, + term_bd_before_shift, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ): + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + padding_tuple = [[0, 0]] * (len(term_bd_before_shift.shape) - 1) + [ + [0, pad_amount_last_dim] + ] + term_bd_padded = keras.ops.pad(term_bd_before_shift, padding_tuple) + term_bd_reshaped = keras.ops.reshape( + term_bd_padded, + ( + batch_size, + num_heads, + -1, + ), + )[:, :, : query_block_size * key_context_size] + term_bd_shifted = keras.ops.reshape( + term_bd_reshaped, + ( + batch_size, + num_heads, + -1, + query_block_size, + key_context_size, + ), + ) + return term_bd_shifted + + def _int8_call(self, queries, keys): + original_dtype = queries.dtype + queries_calc = keras.ops.cast(queries, "float32") + keys_calc = keras.ops.cast(keys, "float32") + result_calc = self.call(queries_calc, keys_calc) + return keras.ops.cast(result_calc, original_dtype) + + def call(self, queries, keys): + batch_size = keras.ops.shape(queries)[0] + ( + _, + num_query_blocks, + query_block_size, + num_heads, + head_dim, + ) = queries.shape + _, _, key_context_size, _, _ = keys.shape + pos_indices = keras.ops.expand_dims( + keras.ops.arange( + self.max_backward, -self.max_forward - 1, -1, dtype="float32" + ), + 0, + ) + max_span_plus_1 = pos_indices.shape[1] + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) + projected_sin_emb = self.pos_proj(sin_emb_timing_signal) + sin_emb = keras.ops.squeeze( + keras.ops.reshape( + projected_sin_emb, + (1, max_span_plus_1, self.num_heads, self.head_dim), + ), + axis=0, + ) + queries_p = keras.ops.transpose(queries, (0, 3, 1, 2, 4)) + keys_p_t = keras.ops.transpose(keys, (0, 
3, 1, 4, 2)) + term_ac = keras.ops.matmul(queries_p, keys_p_t) + q_permuted = keras.ops.transpose(queries, (0, 3, 1, 2, 4)) + s_permuted = keras.ops.transpose(sin_emb, (1, 2, 0)) + + q_reshaped_dim = -1 + if num_query_blocks is not None: + q_reshaped_dim = num_query_blocks * query_block_size + + q_reshaped = keras.ops.reshape( + q_permuted, + ( + batch_size * num_heads, + q_reshaped_dim, + head_dim, + ), + ) + term_bd_unshifed_matmul = keras.ops.matmul(q_reshaped, s_permuted) + term_bd_unshifed = keras.ops.reshape( + term_bd_unshifed_matmul, + ( + batch_size, + num_heads, + -1, + query_block_size, + max_span_plus_1, + ), + ) + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) + return term_ac + term_bd_shifted + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + } + ) + return config + + +class Gemma3nTextAttention(keras.layers.Layer): + """A multi-head attention layer for text sequences. + + This layer implements the text attention mechanism for the Gemma3n model, + which is a standard multi-head attention architecture. It includes features + such as Grouped-Query Attention (GQA), RMS Normalization for query and key + states, and Rotary Position Embeddings (RoPE) to incorporate positional + information. + + Args: + hidden_size: int. The size of the hidden state. + num_attention_heads: int. The number of query attention heads. + num_key_value_heads: int. The number of key and value attention heads. + If `num_key_value_heads` is not equal to `num_attention_heads`, this + layer implements Grouped-Query Attention. + head_dim: int. The dimension of each attention head. + attention_dropout: float. Dropout probability for the attention scores. + attention_bias: bool. If `True`, dense layers for query, key, value, + and output projections will use a bias term. + rms_norm_eps: float. The epsilon value for RMS Normalization layers. + sliding_window: int, optional. The size of the sliding window for + local attention. If `None`, global attention is used. Defaults to + `None`. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_dropout, + attention_bias, + rms_norm_eps, + sliding_window=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.rms_norm_eps = rms_norm_eps + self.sliding_window = sliding_window + self.num_key_value_groups = ( + self.num_attention_heads // self.num_key_value_heads + ) + self.q_proj = keras.layers.Dense( + self.num_attention_heads * self.head_dim, + use_bias=self.attention_bias, + name="q_proj", + dtype=self.dtype_policy, + ) + self.k_proj = keras.layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="k_proj", + dtype=self.dtype_policy, + ) + self.v_proj = keras.layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="v_proj", + dtype=self.dtype_policy, + ) + self.o_proj = keras.layers.Dense( + self.hidden_size, + use_bias=self.attention_bias, + name="o_proj", + dtype=self.dtype_policy, + ) + self.q_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + name="q_norm", + dtype=self.dtype_policy, + ) + self.k_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + name="k_norm", + dtype=self.dtype_policy, + ) + self.v_norm = Gemma3nRMSNorm( + dim=self.head_dim, + eps=self.rms_norm_eps, + with_scale=False, + name="v_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.q_proj.build(input_shape) + self.k_proj.build(input_shape) + self.v_proj.build(input_shape) + self.o_proj.build( + input_shape[:-1] + (self.num_attention_heads * self.head_dim,) + ) + norm_shape = input_shape[:-1] + ( + self.num_attention_heads, + self.head_dim, + ) + self.q_norm.build(norm_shape) + self.k_norm.build(norm_shape) + self.v_norm.build(norm_shape) + super().build(input_shape) + + def call( + self, hidden_states, position_embeddings, attention_mask, training=False + ): + input_shape = keras.ops.shape(hidden_states)[:-1] + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states) + query_states = keras.ops.reshape( + query_states, + input_shape + (self.num_attention_heads, self.head_dim), + ) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb( + query_states, cos, sin, unsqueeze_dim=2 + ) + query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) + key_states = self.k_proj(hidden_states) + key_states = keras.ops.reshape( + key_states, input_shape + (self.num_key_value_heads, self.head_dim) + ) + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) + value_states = self.v_proj(hidden_states) + value_states = keras.ops.reshape( + value_states, + input_shape + (self.num_key_value_heads, self.head_dim), + ) + value_states = self.v_norm(value_states) + value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) + attn_output, attn_weights = eager_attention_forward( + query_states, + key_states, + value_states, + self.num_key_value_groups, + self.head_dim, + attention_mask, + dropout=self.attention_dropout if training else 0.0, + training=training, + ) + attn_output = keras.ops.reshape(attn_output, input_shape + (-1,)) + 
attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "attention_dropout": self.attention_dropout, + "attention_bias": self.attention_bias, + "rms_norm_eps": self.rms_norm_eps, + "sliding_window": self.sliding_window, + } + ) + return config + + +class Gemma3nAudioAttention(keras.layers.Layer): + """An attention layer specialized for audio sequences. + + This layer implements the attention mechanism for the audio tower of the + Gemma3n model. It is designed to handle long audio sequences by processing + the input in fixed-size chunks. For each chunk of queries, it attends to a + larger context of keys and values, defined by a left (past) and right + (future) context window. This allows the model to capture local and more + distant dependencies efficiently. + + Args: + hidden_size: int. The size of the hidden state. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. The size of each processing chunk. + conf_attention_context_right: int. The number of steps to attend to in + the future. + conf_attention_context_left: int. The number of steps to attend to in + the past, including the current step. + conf_attention_logit_cap: float. The soft cap value to apply to the + attention logits. + """ + + def __init__( + self, + hidden_size, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.num_heads = conf_num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = conf_attention_chunk_size + self.max_future_horizon = conf_attention_context_right + self.max_past_horizon = max(0, conf_attention_context_left - 1) + self.attention_logits_soft_cap = conf_attention_logit_cap + self.context_size = ( + self.chunk_size + self.max_past_horizon + self.max_future_horizon + ) + self.relative_position_embedding = ( + Gemma3nAudioRelativePositionEmbedding( + hidden_size, + conf_num_attention_heads, + conf_attention_context_left, + conf_attention_context_right, + name="relative_position_embedding", + dtype=self.dtype_policy, + ) + ) + self.q_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="q_proj", + dtype=self.dtype_policy, + ) + self.k_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="k_proj", + dtype=self.dtype_policy, + ) + self.v_proj = keras.layers.Dense( + self.num_heads * self.head_dim, + use_bias=False, + name="v_proj", + dtype=self.dtype_policy, + ) + q_scale = self.head_dim**-0.5 + r_softplus_0 = 1.0 / np.log(1 + np.exp(0.0)) # softplus(0) for numpy + self.q_scale = q_scale * r_softplus_0 + + lower_causal_mask = np.tril( + np.ones((self.context_size, self.chunk_size), dtype=bool), k=0 + ).T + upper_causal_mask = np.tril( + np.ones((self.chunk_size, self.context_size), 
dtype=bool), + k=self.max_past_horizon + self.max_future_horizon, + ) + local_causal_valid_mask = np.ones( + (self.chunk_size, self.context_size), dtype=bool + ) + local_causal_valid_mask = ( + local_causal_valid_mask * lower_causal_mask * upper_causal_mask + ) + self.local_causal_valid_mask = keras.ops.convert_to_tensor( + local_causal_valid_mask + ) + self.softcap = keras.ops.convert_to_tensor( + self.attention_logits_soft_cap, dtype="float32" + ) + + def build(self, input_shape): + self.per_dim_scale = self.add_weight( + shape=(self.head_dim,), + initializer="zeros", + trainable=True, + name="per_dim_scale", + dtype=self.dtype_policy.variable_dtype, + ) + self.q_proj.build(input_shape) + self.k_proj.build(input_shape) + self.v_proj.build(input_shape) + self.relative_position_embedding.build(input_shape) + super().build(input_shape) + + def _pad_dim1(self, x, pad_left, pad_right): + paddings = [[0, 0], [pad_left, pad_right]] + [ + [0, 0] for _ in range(len(x.shape) - 2) + ] + return keras.ops.pad(x, paddings) + + def _convert_to_block(self, hidden_states): + b, t = keras.ops.shape(hidden_states)[:2] + tail_shape_list = list(hidden_states.shape[2:]) + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + padding_len = num_blocks * self.chunk_size - t + hidden_states = self._pad_dim1(hidden_states, 0, padding_len) + permute_dims = [b, num_blocks, self.chunk_size] + tail_shape_list + return keras.ops.reshape(hidden_states, permute_dims) + + def _extract_block_context(self, hidden_states): + pad_left = self.max_past_horizon + pad_right = self.max_future_horizon + self.chunk_size - 1 + hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right) + _, t = keras.ops.shape(hidden_states)[:2] + frame_len = self.context_size + frame_step = self.chunk_size + num_frames = (t - frame_len) // frame_step + 1 + + start_indices = keras.ops.arange(0, num_frames) * frame_step + frame_offsets = keras.ops.arange(0, frame_len) + indices = keras.ops.expand_dims( + start_indices, axis=1 + ) + keras.ops.expand_dims(frame_offsets, axis=0) + return keras.ops.take(hidden_states, indices, axis=1) + + def call(self, hidden_states, mask): + qkv_shape = keras.ops.shape(hidden_states)[:-1] + ( + self.num_heads, + self.head_dim, + ) + query_states = keras.ops.reshape(self.q_proj(hidden_states), qkv_shape) + key_states = keras.ops.reshape(self.k_proj(hidden_states), qkv_shape) + value_states = keras.ops.reshape(self.v_proj(hidden_states), qkv_shape) + per_dim_scale_sp = keras.ops.softplus(self.per_dim_scale) + query_states = query_states * self.q_scale * per_dim_scale_sp + batch_size, q_time = keras.ops.shape(query_states)[:2] + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = keras.ops.shape(query_blocks)[1] + original_valid_mask = keras.ops.logical_not(mask) + extracted_valid_mask_blocks = self._extract_block_context( + original_valid_mask + ) + if ( + len(extracted_valid_mask_blocks.shape) == 4 + and extracted_valid_mask_blocks.shape[2] + * extracted_valid_mask_blocks.shape[3] + == self.context_size + ): + extracted_valid_mask_blocks = keras.ops.reshape( + extracted_valid_mask_blocks, + (batch_size, num_query_blocks, self.context_size), + ) + condition_from_input_validity = keras.ops.expand_dims( + keras.ops.expand_dims(extracted_valid_mask_blocks, 1), -2 + ) + condition_from_causality = keras.ops.expand_dims( + keras.ops.expand_dims( + 
keras.ops.expand_dims(self.local_causal_valid_mask, 0), 0 + ), + 0, + ) + final_condition_for_where = keras.ops.logical_and( + condition_from_input_validity, + keras.ops.cast(condition_from_causality, "bool"), + ) + logits = self.relative_position_embedding(query_blocks, key_blocks) + softcap = keras.ops.cast(self.softcap, dtype=logits.dtype) + logits = logits / softcap + logits = keras.ops.tanh(logits) + logits = logits * softcap + min_val = np.finfo(keras.backend.floatx()).min + logits = keras.ops.where(final_condition_for_where, logits, min_val) + probabilities = keras.ops.softmax( + keras.ops.cast(logits, "float32"), axis=-1 + ) + probabilities = keras.ops.cast(probabilities, value_blocks.dtype) + context_vectors = keras.ops.einsum( + "bnuwc,bucnh->buwnh", probabilities, value_blocks + ) + context_vectors = keras.ops.reshape( + context_vectors, + ( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ), + ) + context_vectors = context_vectors[:, :q_time] + return context_vectors + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py b/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py new file mode 100644 index 0000000000..0a4cdc05e6 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py @@ -0,0 +1,511 @@ +import keras + +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerAttention, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerFeedForward, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioConformerLightConv1d, +) +from keras_hub.src.models.gemma3n.gemma3n_audio_layers import ( + Gemma3nAudioSSCPConvBlock, +) +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nAudioSubSampleConvProjection(keras.layers.Layer): + """A convolutional projection layer that subsamples audio features. + + This layer applies two blocks of 2D convolutions to the input audio + spectrogram. Each block subsamples the input along the time and frequency + dimensions. The output is then flattened and projected to the model's + hidden size. + + Args: + input_feat_size: int. The number of frequency bins in the input + spectrogram. + hidden_size: int. The dimensionality of the output embeddings. + sscp_conv_channel_size: list of int. The number of output channels for + each of the two convolutional blocks. + sscp_conv_kernel_size: list of tuple of int. The kernel sizes for each + of the two convolutional blocks. + sscp_conv_stride_size: list of tuple of int. The stride sizes for each + of the two convolutional blocks. + sscp_conv_group_norm_eps: float. Epsilon value for the Group + Normalization layers within the convolutional blocks. 
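+
+    Example (an illustrative sketch; the sizes below are arbitrary and
+    much smaller than the real Gemma3n audio configuration):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import (
+        Gemma3nAudioSubSampleConvProjection,
+    )
+
+    layer = Gemma3nAudioSubSampleConvProjection(
+        input_feat_size=32,
+        hidden_size=64,
+        sscp_conv_channel_size=[8, 16],
+        sscp_conv_kernel_size=[(3, 3), (3, 3)],
+        sscp_conv_stride_size=[(2, 2), (2, 2)],
+        sscp_conv_group_norm_eps=1e-3,
+    )
+    # (batch, num_frames, num_mel_bins) spectrogram features.
+    audio_mel = keras.random.uniform((1, 10, 32))
+    features = layer(audio_mel)  # (1, 3, 64)
+    ```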
+ """ + + def __init__( + self, + input_feat_size, + hidden_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.input_feat_size = input_feat_size + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + current_f_for_block_input = input_feat_size + self.calculated_block_padding = [] + self.calculated_f_out_dims = [] + for i in range(2): + kernel_h, kernel_w = sscp_conv_kernel_size[i] + _, stride_w = sscp_conv_stride_size[i] + pad_t_top, pad_t_bottom, pad_f_left, pad_f_right = ( + 0, + kernel_h - 1, + 1, + 1, + ) + manual_padding_tuple = ( + pad_f_left, + pad_f_right, + pad_t_top, + pad_t_bottom, + ) + self.calculated_block_padding.append(manual_padding_tuple) + f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right + f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 + self.calculated_f_out_dims.append(f_out_after_conv) + current_f_for_block_input = f_out_after_conv + self.conv_0 = Gemma3nAudioSSCPConvBlock( + idx=0, + input_freq_dim=input_feat_size, + sscp_conv_channel_size=sscp_conv_channel_size, + sscp_conv_kernel_size=sscp_conv_kernel_size, + sscp_conv_stride_size=sscp_conv_stride_size, + sscp_conv_group_norm_eps=sscp_conv_group_norm_eps, + manual_padding=self.calculated_block_padding[0], + name="conv_0", + dtype=self.dtype_policy, + ) + self.conv_1 = Gemma3nAudioSSCPConvBlock( + idx=1, + name="conv_1", + input_freq_dim=self.calculated_f_out_dims[0], + sscp_conv_channel_size=sscp_conv_channel_size, + sscp_conv_kernel_size=sscp_conv_kernel_size, + sscp_conv_stride_size=sscp_conv_stride_size, + sscp_conv_group_norm_eps=sscp_conv_group_norm_eps, + manual_padding=self.calculated_block_padding[1], + dtype=self.dtype_policy, + ) + self.input_proj_linear = keras.layers.Dense( + hidden_size, + use_bias=False, + name="input_proj_linear", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + _, t_in, f_in = input_shape + conv0_input_shape = (None, 1, t_in, f_in) + self.conv_0.build(conv0_input_shape) + if t_in is not None: + pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0][2:4] + kernel_h_0, _ = self.sscp_conv_kernel_size[0] + stride_h_0, _ = self.sscp_conv_stride_size[0] + t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0 + t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1 + else: + t_out_0 = None + c_out_0 = self.sscp_conv_channel_size[0] + f_out_0 = self.calculated_f_out_dims[0] + conv1_input_shape = (None, c_out_0, t_out_0, f_out_0) + self.conv_1.build(conv1_input_shape) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + b, t_in, f_in = input_shape + if t_in is not None: + _, _, pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0] + kernel_h_0, _ = self.sscp_conv_kernel_size[0] + stride_h_0, _ = self.sscp_conv_stride_size[0] + t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0 + t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1 + _, _, pad_t_top_1, pad_t_bottom_1 = self.calculated_block_padding[1] + kernel_h_1, _ = self.sscp_conv_kernel_size[1] + stride_h_1, _ = self.sscp_conv_stride_size[1] + t_padded_1 = t_out_0 + pad_t_top_1 + pad_t_bottom_1 + t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1 + else: + t_out_1 = None + return (b, t_out_1, self.hidden_size) + + def 
call(self, audio_encodings): + audio_encodings_reshaped = keras.ops.expand_dims(audio_encodings, 1) + x = self.conv_0(audio_encodings_reshaped) + x = self.conv_1(x) + b, c_out, t_out, f_out = keras.ops.shape(x) + x_permuted = keras.ops.transpose(x, (0, 2, 3, 1)) + output_flattened = keras.ops.reshape( + x_permuted, (b, t_out, f_out * c_out) + ) + return self.input_proj_linear(output_flattened) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_feat_size": self.input_feat_size, + "hidden_size": self.hidden_size, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + } + ) + return config + + +class Gemma3nAudioConformerBlock(keras.layers.Layer): + """A single conformer block for processing audio sequences. + + This layer implements the conformer architecture, which consists of a + sequence of four modules: a feed-forward module, a multi-head + self-attention module, a convolution module, and a final feed-forward + module. The output of each module is added to its input through a residual + connection. + + Args: + hidden_size: int. The dimensionality of the input and output embeddings. + rms_norm_eps: float. Epsilon value for the Gemma 3n RMS normalization + layers. + gradient_clipping: float. The maximum absolute value for the gradient. + conf_residual_weight: float. The weight for the residual connection in + the feed-forward layers. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. The size of chunks for local attention. + conf_attention_context_right: int. The right context size for local + attention. + conf_attention_context_left: int. The left context size for local + attention. + conf_attention_logit_cap: float. The maximum value for the attention + logits. + conf_conv_kernel_size: int. The kernel size for the 1D convolution + layer. 
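+
+    Example (an illustrative sketch; the values below are arbitrary and
+    much smaller than the real Gemma3n audio configuration):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import (
+        Gemma3nAudioConformerBlock,
+    )
+
+    block = Gemma3nAudioConformerBlock(
+        hidden_size=64,
+        rms_norm_eps=1e-6,
+        gradient_clipping=10000.0,
+        conf_residual_weight=0.5,
+        conf_num_attention_heads=4,
+        conf_attention_chunk_size=12,
+        conf_attention_context_right=0,
+        conf_attention_context_left=13,
+        conf_attention_logit_cap=50.0,
+        conf_conv_kernel_size=5,
+    )
+    audio_encodings = keras.random.uniform((1, 10, 64))
+    # Boolean padding mask where `True` marks padded frames.
+    audio_mel_mask = keras.ops.zeros((1, 10), dtype="bool")
+    outputs = block((audio_encodings, audio_mel_mask))  # (1, 10, 64)
+    ```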
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_conv_kernel_size = conf_conv_kernel_size + self.ffw_layer_start = Gemma3nAudioConformerFeedForward( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_residual_weight=conf_residual_weight, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="ffw_layer_start", + ) + self.attention = Gemma3nAudioConformerAttention( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_num_attention_heads=conf_num_attention_heads, + conf_attention_chunk_size=conf_attention_chunk_size, + conf_attention_context_right=conf_attention_context_right, + conf_attention_context_left=conf_attention_context_left, + conf_attention_logit_cap=conf_attention_logit_cap, + dtype=self.dtype_policy, + name="attention", + ) + self.lconv1d = Gemma3nAudioConformerLightConv1d( + hidden_size=hidden_size, + rms_norm_eps=rms_norm_eps, + conf_conv_kernel_size=conf_conv_kernel_size, + gradient_clipping=gradient_clipping, + dtype=self.dtype_policy, + name="lconv1d", + ) + self.ffw_layer_end = Gemma3nAudioConformerFeedForward( + hidden_size=hidden_size, + gradient_clipping=gradient_clipping, + conf_residual_weight=conf_residual_weight, + rms_norm_eps=rms_norm_eps, + dtype=self.dtype_policy, + name="ffw_layer_end", + ) + self.norm = Gemma3nRMSNorm( + hidden_size, eps=rms_norm_eps, name="norm", dtype=self.dtype_policy + ) + + def build(self, input_shape): + audio_encodings_shape, _ = input_shape + self.ffw_layer_start.build(audio_encodings_shape) + self.attention.build(audio_encodings_shape) + self.lconv1d.build(audio_encodings_shape) + self.ffw_layer_end.build(audio_encodings_shape) + self.norm.build(audio_encodings_shape) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + audio_encodings_shape, _ = input_shape + return audio_encodings_shape + + def call(self, inputs): + audio_encodings, audio_mel_mask = inputs + audio_encodings = self.ffw_layer_start(audio_encodings) + audio_encodings = self.attention(audio_encodings, audio_mel_mask) + validity_mask_for_lconv = keras.ops.logical_not(audio_mel_mask) + audio_encodings_for_lconv_input = audio_encodings * keras.ops.cast( + keras.ops.expand_dims(validity_mask_for_lconv, -1), + audio_encodings.dtype, + ) + audio_encodings = self.lconv1d(audio_encodings_for_lconv_input) + audio_encodings = self.ffw_layer_end(audio_encodings) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + output = self.norm(audio_encodings) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "gradient_clipping": self.gradient_clipping, + 
"conf_residual_weight": self.conf_residual_weight, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + } + ) + return config + + +class Gemma3nAudioEncoder(keras.layers.Layer): + """The main audio encoder for the Gemma3n model. + + This layer combines a subsampling convolutional projection with a stack of + conformer blocks to encode audio spectrograms into a sequence of hidden + states. + + Args: + hidden_size: int. The dimensionality of the embeddings. + input_feat_size: int. The number of frequency bins in the input + spectrogram. + sscp_conv_channel_size: list of int. The number of output channels for + each of the two convolutional blocks in the subsampler. + sscp_conv_kernel_size: list of tuple of int. The kernel sizes for each + of the two convolutional blocks in the subsampler. + sscp_conv_stride_size: list of tuple of int. The stride sizes for each + of the two convolutional blocks in the subsampler. + sscp_conv_group_norm_eps: float. Epsilon value for the Group + Normalization layers in the subsampler. + conf_num_hidden_layers: int. The number of conformer blocks. + rms_norm_eps: float. Epsilon value for the Gemma 3n RMS normalization + layers. + gradient_clipping: float. The maximum absolute value for the gradient. + conf_residual_weight: float. The weight for the residual connection in + the feed-forward layers of the conformer blocks. + conf_num_attention_heads: int. The number of attention heads in the + conformer blocks. + conf_attention_chunk_size: int. The size of chunks for local attention + in the conformer blocks. + conf_attention_context_right: int. The right context size for local + attention in the conformer blocks. + conf_attention_context_left: int. The left context size for local + attention in the conformer blocks. + conf_attention_logit_cap: float. The maximum value for the attention + logits in the conformer blocks. + conf_conv_kernel_size: int. The kernel size for the 1D convolution + layer in the conformer blocks. + conf_reduction_factor: int. The factor by which to reduce the sequence + length of the final output. 
+ """ + + def __init__( + self, + hidden_size, + input_feat_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + conf_num_hidden_layers, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + conf_reduction_factor, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.input_feat_size = input_feat_size + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.conf_num_hidden_layers = conf_num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_conv_kernel_size = conf_conv_kernel_size + self.conf_reduction_factor = conf_reduction_factor + self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection( + input_feat_size, + hidden_size, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + dtype=self.dtype_policy, + name="subsample_conv_projection", + ) + self.conformer = [ + Gemma3nAudioConformerBlock( + hidden_size, + rms_norm_eps, + gradient_clipping, + conf_residual_weight, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + conf_conv_kernel_size, + dtype=self.dtype_policy, + name=f"conformer_block_{i}", + ) + for i in range(conf_num_hidden_layers) + ] + + def build(self, input_shape): + audio_mel_shape, _ = input_shape + self.subsample_conv_projection.build(audio_mel_shape) + encodings_shape = self.subsample_conv_projection.compute_output_shape( + audio_mel_shape + ) + t_sub = encodings_shape[1] + time_stride_product = 1 + for stride_pair in self.sscp_conv_stride_size: + time_stride_product *= stride_pair[0] + batch_size = ( + audio_mel_shape[0] if audio_mel_shape[0] is not None else -1 + ) + current_mask_shape = (batch_size, t_sub) + current_encodings_shape = encodings_shape + for block in self.conformer: + block.build((current_encodings_shape, current_mask_shape)) + current_encodings_shape = block.compute_output_shape( + (current_encodings_shape, current_mask_shape) + ) + super().build(input_shape) + + def compute_output_shape(self, input_shape): + audio_mel_shape, _ = input_shape + encodings_shape = self.subsample_conv_projection.compute_output_shape( + audio_mel_shape + ) + current_encodings_shape = encodings_shape + for block in self.conformer: + current_encodings_shape = block.compute_output_shape( + (current_encodings_shape, None) + ) + if self.conf_reduction_factor > 1: + t_sub = current_encodings_shape[1] + if t_sub is not None: + new_t = t_sub // self.conf_reduction_factor + current_encodings_shape = ( + current_encodings_shape[0], + new_t, + current_encodings_shape[2], + ) + return current_encodings_shape, None + + def call(self, inputs): + audio_mel, 
audio_mel_mask = inputs + audio_encodings = self.subsample_conv_projection(audio_mel) + t_sub = keras.ops.shape(audio_encodings)[1] + time_stride_product = 1 + for stride_pair in self.sscp_conv_stride_size: + time_stride_product *= stride_pair[0] + indices = keras.ops.arange(0, t_sub) * time_stride_product + indices = keras.ops.clip( + indices, 0, keras.ops.shape(audio_mel_mask)[1] - 1 + ) + current_mask = keras.ops.take(audio_mel_mask, indices, axis=1) + for block in self.conformer: + audio_encodings = block((audio_encodings, current_mask)) + + if self.conf_reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.conf_reduction_factor] + current_mask = current_mask[:, :: self.conf_reduction_factor] + return audio_encodings * keras.ops.cast( + keras.ops.logical_not(keras.ops.expand_dims(current_mask, axis=-1)), + audio_encodings.dtype, + ), current_mask + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "input_feat_size": self.input_feat_size, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + "conf_num_hidden_layers": self.conf_num_hidden_layers, + "rms_norm_eps": self.rms_norm_eps, + "gradient_clipping": self.gradient_clipping, + "conf_residual_weight": self.conf_residual_weight, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + "conf_reduction_factor": self.conf_reduction_factor, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py b/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py new file mode 100644 index 0000000000..11d15813b9 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_audio_layers.py @@ -0,0 +1,526 @@ +import keras + +from keras_hub.src.models.gemma3n.gemma3n_attention import Gemma3nAudioAttention +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nAudioCumulativeGroupNorm(keras.layers.Layer): + """A cumulative group normalization layer for audio features. + + This layer normalizes the input hidden states based on cumulative statistics + calculated over the time dimension. It is designed to process audio + spectrograms or similar sequential data. + + Args: + num_channels: int. The number of channels for normalization. + feature_dims: tuple. The dimensions of the features to be normalized. + eps: float. A small epsilon value to add to the variance to avoid + division by zero. 
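+
+    Example (an illustrative sketch with arbitrary sizes):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.gemma3n.gemma3n_audio_layers import (
+        Gemma3nAudioCumulativeGroupNorm,
+    )
+
+    norm = Gemma3nAudioCumulativeGroupNorm(
+        num_channels=8, feature_dims=(16,)
+    )
+    # (batch, time, freq, channels) activations.
+    x = keras.random.uniform((2, 5, 16, 8))
+    y = norm(x)  # (2, 5, 16, 8), cumulative statistics over time
+    ```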
+ """ + + def __init__( + self, + num_channels, + feature_dims, + eps=1e-3, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_channels = num_channels + self.feature_dims = tuple(feature_dims) + self.eps = eps + self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1)) + + def build(self, input_shape): + self.scale = self.add_weight( + shape=(self.num_channels,), + initializer="ones", + trainable=True, + name="scale", + dtype=self.dtype_policy.variable_dtype, + ) + super().build(input_shape) + + def _int8_call(self, hidden_states): + original_dtype = hidden_states.dtype + x_calc = keras.ops.cast(hidden_states, "float32") + result_calc = self.call(x_calc) + return keras.ops.cast(result_calc, original_dtype) + + def call(self, hidden_states): + input_dtype = hidden_states.dtype + x_calc = keras.ops.cast(hidden_states, "float32") + mask_calc = keras.ops.ones_like(x_calc, dtype="float32") + sum_values_at_t = keras.ops.sum( + x_calc, axis=self.reduction_axes, keepdims=True + ) + cum_sum_values = keras.ops.cumsum(sum_values_at_t, axis=1) + elements_in_group_at_t = keras.ops.sum( + mask_calc, axis=self.reduction_axes, keepdims=True + ) + cum_count_elements = keras.ops.cumsum(elements_in_group_at_t, axis=1) + safe_cum_count_elements = keras.ops.maximum(cum_count_elements, 1.0) + cum_mean = cum_sum_values / safe_cum_count_elements + squared_diff_from_mean = keras.ops.square(x_calc - cum_mean) + sum_sq_diff_at_t = keras.ops.sum( + squared_diff_from_mean, axis=self.reduction_axes, keepdims=True + ) + cum_sum_sq_diff = keras.ops.cumsum(sum_sq_diff_at_t, axis=1) + cum_variance = cum_sum_sq_diff / safe_cum_count_elements + normalized_x = (x_calc - cum_mean) * keras.ops.rsqrt( + cum_variance + self.eps + ) + scale_view_shape = [1] * (len(hidden_states.shape) - 1) + [ + self.num_channels + ] + reshaped_scale = keras.ops.reshape(self.scale, scale_view_shape) + normalized_x = normalized_x * keras.ops.cast(reshaped_scale, "float32") + final_output = normalized_x * mask_calc + return keras.ops.cast(final_output, input_dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_channels": self.num_channels, + "feature_dims": self.feature_dims, + "eps": self.eps, + } + ) + return config + + +class Gemma3nAudioSSCPConvBlock(keras.layers.Layer): + """A single SSCP (Spectrogram Sub-sampling Convolutional Preprocessor) + block. + + This block consists of a 2D convolution, a cumulative group normalization + layer, and a ReLU activation. It is used to process and downsample audio + spectrograms. + + Args: + idx: int. The index of the convolutional block. + input_freq_dim: int. The frequency dimension of the input spectrogram. + sscp_conv_channel_size: list or tuple. A sequence containing the number + of output channels for each convolutional block in the SSCP stack. + sscp_conv_kernel_size: list or tuple. A sequence of kernel sizes for + each convolutional block. + sscp_conv_stride_size: list or tuple. A sequence of stride sizes for + each convolutional block. + sscp_conv_group_norm_eps: float. The epsilon value for the cumulative + group normalization layer. + manual_padding: tuple. A tuple of 4 integers specifying the manual + padding to be applied as (pad_w_left, pad_w_right, pad_h_top, + pad_h_bottom). 
+ """ + + def __init__( + self, + idx, + input_freq_dim, + sscp_conv_channel_size, + sscp_conv_kernel_size, + sscp_conv_stride_size, + sscp_conv_group_norm_eps, + manual_padding=(0, 0, 0, 0), + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.idx = idx + self.input_freq_dim = input_freq_dim + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.manual_padding = manual_padding + out_channels = sscp_conv_channel_size[idx] + kernel_h, kernel_w = sscp_conv_kernel_size[idx] + stride_h, stride_w = sscp_conv_stride_size[idx] + self.conv = keras.layers.Conv2D( + filters=out_channels, + kernel_size=(kernel_h, kernel_w), + strides=(stride_h, stride_w), + padding="valid", + use_bias=False, + data_format="channels_last", + name="conv", + dtype=self.dtype_policy, + ) + f_in_padded = ( + input_freq_dim + self.manual_padding[0] + self.manual_padding[1] + ) + f_out_conv = (f_in_padded - kernel_w) // stride_w + 1 + self.norm = Gemma3nAudioCumulativeGroupNorm( + num_channels=out_channels, + feature_dims=(f_out_conv,), + eps=sscp_conv_group_norm_eps, + name="norm", + dtype=self.dtype_policy, + ) + self.activation = keras.layers.ReLU( + name="activation", dtype=self.dtype_policy + ) + + def build(self, input_shape): + _, c_in, h, w = input_shape + if h is not None: + padded_h = h + self.manual_padding[2] + self.manual_padding[3] + else: + padded_h = None + padded_w = w + self.manual_padding[0] + self.manual_padding[1] + conv_input_shape = (None, padded_h, padded_w, c_in) + if not self.conv.built: + self.conv.build(conv_input_shape) + if h is not None: + h_out = (padded_h - self.conv.kernel_size[0]) // self.conv.strides[ + 0 + ] + 1 + else: + h_out = None + w_out = (padded_w - self.conv.kernel_size[1]) // self.conv.strides[ + 1 + ] + 1 + norm_input_shape = (None, h_out, w_out, self.conv.filters) + if not self.norm.built: + self.norm.build(norm_input_shape) + super().build(input_shape) + + def call(self, audio_encodings): + audio_encodings_nhwc = keras.ops.transpose( + audio_encodings, (0, 2, 3, 1) + ) + keras_padding = [ + [0, 0], + [self.manual_padding[2], self.manual_padding[3]], + [self.manual_padding[0], self.manual_padding[1]], + [0, 0], + ] + audio_encodings_padded = keras.ops.pad( + audio_encodings_nhwc, + keras_padding, + mode="constant", + constant_values=0.0, + ) + audio_encodings_conv = self.conv(audio_encodings_padded) + x_normed = self.norm(audio_encodings_conv) + audio_encodings_normed = keras.ops.transpose(x_normed, (0, 3, 1, 2)) + return self.activation(audio_encodings_normed) + + def get_config(self): + config = super().get_config() + config.update( + { + "idx": self.idx, + "input_freq_dim": self.input_freq_dim, + "sscp_conv_channel_size": self.sscp_conv_channel_size, + "sscp_conv_kernel_size": self.sscp_conv_kernel_size, + "sscp_conv_stride_size": self.sscp_conv_stride_size, + "sscp_conv_group_norm_eps": self.sscp_conv_group_norm_eps, + "manual_padding": self.manual_padding, + } + ) + return config + + +class Gemma3nAudioConformerFeedForward(keras.layers.Layer): + """The feed-forward module for the Conformer block. + + This module implements the feed-forward sub-layer of a Conformer block, + which consists of pre-layer normalization, two dense layers with a SiLU + activation function in between, post-layer normalization, and a residual + connection. + + Args: + hidden_size: int. 
The hidden size of the input and output tensors. + gradient_clipping: float. The maximum absolute value for gradient + clipping. + conf_residual_weight: float. The weight applied to the output of the + sub-layer before adding the residual connection. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + """ + + def __init__( + self, + hidden_size, + gradient_clipping, + conf_residual_weight, + rms_norm_eps, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.gradient_clipping = gradient_clipping + self.conf_residual_weight = conf_residual_weight + self.rms_norm_eps = rms_norm_eps + self.pre_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_layer_norm", + dtype=self.dtype_policy, + ) + self.ffw_layer_1 = keras.layers.Dense( + hidden_size * 4, + use_bias=False, + name="ffw_layer_1", + dtype=self.dtype_policy, + ) + self.ffw_layer_2 = keras.layers.Dense( + hidden_size, + use_bias=False, + name="ffw_layer_2", + dtype=self.dtype_policy, + ) + self.post_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_layer_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.pre_layer_norm.build(input_shape) + self.ffw_layer_1.build(input_shape) + ffw1_output_shape = input_shape[:-1] + (self.hidden_size * 4,) + self.ffw_layer_2.build(ffw1_output_shape) + self.post_layer_norm.build(input_shape) + super().build(input_shape) + + def call(self, audio_encodings): + residual = audio_encodings + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.ffw_layer_1(audio_encodings) + audio_encodings = keras.activations.silu(audio_encodings) + audio_encodings = self.ffw_layer_2(audio_encodings) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.post_layer_norm(audio_encodings) + return residual + (audio_encodings * self.conf_residual_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "gradient_clipping": self.gradient_clipping, + "conf_residual_weight": self.conf_residual_weight, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class Gemma3nAudioConformerLightConv1d(keras.layers.Layer): + """The lightweight 1D convolution module for the Conformer block. + + This module implements the convolution sub-layer of a Conformer block, + which consists of pre-layer normalization, a gated linear unit (GLU), a + lightweight depthwise 1D convolution, and a final projection, followed by a + residual connection. + + Args: + hidden_size: int. The hidden size of the input and output tensors. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + conf_conv_kernel_size: int. The kernel size for the depthwise 1D + convolution. + gradient_clipping: float. The maximum absolute value for gradient + clipping. 
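+
+    Example (an illustrative sketch with arbitrary sizes):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.gemma3n.gemma3n_audio_layers import (
+        Gemma3nAudioConformerLightConv1d,
+    )
+
+    lconv = Gemma3nAudioConformerLightConv1d(
+        hidden_size=64,
+        rms_norm_eps=1e-6,
+        conf_conv_kernel_size=5,
+        gradient_clipping=10000.0,
+    )
+    x = keras.random.uniform((1, 10, 64))
+    y = lconv(x)  # (1, 10, 64), causal depthwise convolution over time
+    ```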
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + conf_conv_kernel_size, + gradient_clipping, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.conf_conv_kernel_size = conf_conv_kernel_size + self.gradient_clipping = gradient_clipping + self.pre_layer_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_layer_norm", + dtype=self.dtype_policy, + ) + self.linear_start = keras.layers.Dense( + hidden_size * 2, + use_bias=False, + name="linear_start", + dtype=self.dtype_policy, + ) + self.depthwise_conv1d = keras.layers.DepthwiseConv1D( + kernel_size=conf_conv_kernel_size, + strides=1, + padding="valid", + use_bias=False, + data_format="channels_last", + name="depthwise_conv1d", + dtype=self.dtype_policy, + ) + self.conv_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="conv_norm", + dtype=self.dtype_policy, + ) + self.linear_end = keras.layers.Dense( + hidden_size, + use_bias=False, + name="linear_end", + dtype=self.dtype_policy, + ) + self.causal_padding = conf_conv_kernel_size - 1 + + def build(self, input_shape): + self.pre_layer_norm.build(input_shape) + self.linear_start.build(input_shape) + glu_output_shape = input_shape[:-1] + (self.hidden_size,) + self.depthwise_conv1d.build(glu_output_shape) + self.conv_norm.build(glu_output_shape) + self.linear_end.build(glu_output_shape) + super().build(input_shape) + + def call(self, audio_encodings): + residual = audio_encodings + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.linear_start(audio_encodings) + gated, activated = keras.ops.split(audio_encodings, 2, axis=-1) + audio_encodings = gated * keras.activations.sigmoid(activated) + + padded = keras.ops.pad( + audio_encodings, + [[0, 0], [self.causal_padding, 0], [0, 0]], + ) + audio_encodings = self.depthwise_conv1d(padded) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = keras.activations.silu(audio_encodings) + audio_encodings = self.linear_end(audio_encodings) + return audio_encodings + residual + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "conf_conv_kernel_size": self.conf_conv_kernel_size, + "gradient_clipping": self.gradient_clipping, + } + ) + return config + + +class Gemma3nAudioConformerAttention(keras.layers.Layer): + """The attention module for the Conformer block. + + This module implements the multi-head self-attention sub-layer of a + Conformer block. It wraps the core attention mechanism with pre and post + layer normalization, a final dense projection, and a residual connection. + + Args: + hidden_size: int. The hidden size of the input and output tensors. + gradient_clipping: float. The maximum absolute value for gradient + clipping. + conf_num_attention_heads: int. The number of attention heads. + conf_attention_chunk_size: int. The chunk size for attention + computation, used for memory efficiency. + conf_attention_context_right: int. The right context size for attention. + conf_attention_context_left: int. The left context size for attention. + conf_attention_logit_cap: float. The value to which attention logits + are capped. 
+ """ + + def __init__( + self, + hidden_size, + gradient_clipping, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.gradient_clipping = gradient_clipping + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_logit_cap = conf_attention_logit_cap + self.pre_attn_norm = Gemma3nRMSNorm( + hidden_size, name="pre_attn_norm", dtype=self.dtype_policy + ) + self.attn = Gemma3nAudioAttention( + hidden_size, + conf_num_attention_heads, + conf_attention_chunk_size, + conf_attention_context_right, + conf_attention_context_left, + conf_attention_logit_cap, + dtype=self.dtype_policy, + name="attn", + ) + self.post = keras.layers.Dense( + hidden_size, use_bias=False, name="post", dtype=self.dtype_policy + ) + self.post_norm = Gemma3nRMSNorm( + hidden_size, name="post_norm", dtype=self.dtype_policy + ) + + def build(self, input_shape): + self.pre_attn_norm.build(input_shape) + self.attn.build(input_shape) + self.post.build(input_shape) + self.post_norm.build(input_shape) + super().build(input_shape) + + def call(self, audio_encodings, audio_mel_mask): + residual = audio_encodings + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + audio_encodings_attn_out = self.attn( + audio_encodings_norm, audio_mel_mask + ) + b, t, num_heads, head_dim = keras.ops.shape(audio_encodings_attn_out) + audio_encodings_reshaped = keras.ops.reshape( + audio_encodings_attn_out, (b, t, num_heads * head_dim) + ) + audio_encodings = self.post(audio_encodings_reshaped) + audio_encodings = keras.ops.clip( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + return residual + self.post_norm(audio_encodings) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "gradient_clipping": self.gradient_clipping, + "conf_num_attention_heads": self.conf_num_attention_heads, + "conf_attention_chunk_size": self.conf_attention_chunk_size, + "conf_attention_context_right": self.conf_attention_context_right, # noqa: E501 + "conf_attention_context_left": self.conf_attention_context_left, + "conf_attention_logit_cap": self.conf_attention_logit_cap, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_backbone.py b/keras_hub/src/models/gemma3n/gemma3n_backbone.py new file mode 100644 index 0000000000..e939297432 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_backbone.py @@ -0,0 +1,865 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, +) +from keras_hub.src.models.gemma3n.gemma3n_text_model import Gemma3nTextModel +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nMultimodalEmbedder(keras.layers.Layer): + """A layer for handling multimodal embeddings. + + This layer manages embeddings for different modalities (here, vision, text, + and audio). 
It can take either token IDs or pre-computed embedding vectors + as input. The embeddings are normalized and projected to match the text + model's hidden size. + + Args: + multimodal_hidden_size: int. The hidden size of the multimodal + embeddings. + text_hidden_size: int. The hidden size of the text model. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + vocab_offset: int. The vocabulary offset for the specific modality. + vocab_size: int. The vocabulary size for the specific modality. + """ + + def __init__( + self, + multimodal_hidden_size, + text_hidden_size, + rms_norm_eps, + vocab_offset, + vocab_size, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.multimodal_hidden_size = multimodal_hidden_size + self.text_hidden_size = text_hidden_size + self.rms_norm_eps = rms_norm_eps + self.vocab_offset = vocab_offset + self.vocab_size = vocab_size + self.embedding = keras.layers.Embedding( + vocab_size, + multimodal_hidden_size, + name="embedding", + dtype=self.dtype_policy, + ) + self.hard_embedding_norm = Gemma3nRMSNorm( + multimodal_hidden_size, + eps=rms_norm_eps, + name="hard_embedding_norm", + dtype=self.dtype_policy, + ) + self.soft_embedding_norm = Gemma3nRMSNorm( + multimodal_hidden_size, + eps=rms_norm_eps, + name="soft_embedding_norm", + dtype=self.dtype_policy, + ) + self.embedding_projection = keras.layers.Dense( + text_hidden_size, + use_bias=False, + name="embedding_projection", + dtype=self.dtype_policy, + ) + self.embedding_post_projection_norm = Gemma3nRMSNorm( + text_hidden_size, + eps=rms_norm_eps, + with_scale=False, + name="embedding_post_projection_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + embeds_shape = (None, None, self.multimodal_hidden_size) + self.hard_embedding_norm.build(embeds_shape) + self.soft_embedding_norm.build(embeds_shape) + self.embedding_projection.build(embeds_shape) + proj_shape = (None, None, self.text_hidden_size) + self.embedding_post_projection_norm.build(proj_shape) + self.embedding.build((None, None)) + super().build(input_shape) + + def call(self, inputs): + input_ids, inputs_embeds = None, None + if isinstance(inputs, list): + input_ids, inputs_embeds = inputs + elif "int" in str(inputs.dtype): + input_ids = inputs + else: + inputs_embeds = inputs + if (input_ids is None) and (inputs_embeds is None): + raise ValueError( + "You must specify either input_ids or inputs_embeds" + ) + if (input_ids is not None) and (inputs_embeds is not None): + raise ValueError( + "You can only specify one of input_ids or inputs_embeds" + ) + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + index_to_lookup = input_ids - self.vocab_offset + hard_emb = self.embedding(index_to_lookup) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + def get_config(self): + config = super().get_config() + config.update( + { + "multimodal_hidden_size": self.multimodal_hidden_size, + "text_hidden_size": self.text_hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "vocab_offset": self.vocab_offset, + "vocab_size": self.vocab_size, + } + ) + return config + + +class Gemma3nMultimodalEmbeddingProcessor(keras.layers.Layer): + """Processes and interleaves text, vision, and audio embeddings. 
+ + This layer takes raw token IDs and multimodal inputs (pixel values, audio + features) and produces a final sequence of embeddings ready for the + decoder. It handles the embedding lookup for text and special tokens, + and replaces the special tokens with the processed features from the + vision and audio encoders. + + Args: + language_model: `keras_hub.models.gemma3n.Gemma3nTextModel`. The + underlying text model containing embedding layers. + vision_encoder: `keras.Model`. The vision encoder model. + embed_vision: `keras_hub.models.gemma3n.Gemma3nMultimodalEmbedder`. The + embedder for vision. + audio_encoder: `keras_hub.models.gemma3n.Gemma3nAudioEncoder`. The audio + encoder model. + embed_audio: `keras_hub.models.gemma3n.Gemma3nMultimodalEmbedder`. The + embedder for audio. + vision_soft_tokens_per_image: int. Number of tokens to represent an + image. + audio_soft_tokens_per_image: int. Number of tokens to represent an + audio clip. + image_token_id: int. The special token ID for images. + audio_token_id: int. The special token ID for audio. + vocab_size_per_layer_input: int. The vocabulary size for per-layer + inputs. + """ + + def __init__( + self, + language_model, + vision_encoder, + embed_vision, + audio_encoder, + embed_audio, + vision_soft_tokens_per_image, + audio_soft_tokens_per_image, + image_token_id, + audio_token_id, + vocab_size_per_layer_input, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.language_model = language_model + self.vision_encoder = vision_encoder + self.embed_vision = embed_vision + self.audio_encoder = audio_encoder + self.embed_audio = embed_audio + self.vision_soft_tokens_per_image = vision_soft_tokens_per_image + self.audio_soft_tokens_per_image = audio_soft_tokens_per_image + self.image_token_id = image_token_id + self.audio_token_id = audio_token_id + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.text_hidden_size = language_model.embed_tokens.embedding_dim + + def call(self, inputs): + input_ids = inputs["token_ids"] + pixel_values = inputs.get("pixel_values") + input_features = inputs.get("input_features") + input_features_mask = inputs.get("input_features_mask") + inputs_embeds = self.language_model.embed_tokens(input_ids) + per_layer_inputs_mask = keras.ops.logical_and( + input_ids >= 0, input_ids < self.vocab_size_per_layer_input + ) + per_layer_inputs_tokens = keras.ops.where( + per_layer_inputs_mask, input_ids, keras.ops.zeros_like(input_ids) + ) + per_layer_inputs = self.language_model.get_per_layer_inputs( + per_layer_inputs_tokens + ) + if self.vision_encoder: + vision_mask = keras.ops.logical_and( + input_ids >= self.embed_vision.vocab_offset, + input_ids < self.embed_audio.vocab_offset, + ) + dummy_vision_token_id = ( + self.embed_vision.vocab_offset + + self.embed_vision.embedding.input_dim + - 1 + ) + vision_input_ids = keras.ops.where( + vision_mask, input_ids, dummy_vision_token_id + ) + vision_embeds_from_vocab = self.embed_vision(vision_input_ids) + expanded_vision_mask = keras.ops.expand_dims(vision_mask, axis=-1) + inputs_embeds = keras.ops.where( + expanded_vision_mask, + vision_embeds_from_vocab, + inputs_embeds, + ) + if self.audio_encoder: + audio_mask = input_ids >= self.embed_audio.vocab_offset + dummy_audio_token_id = ( + self.embed_audio.vocab_offset + + self.embed_audio.embedding.input_dim + - 1 + ) + audio_input_ids = keras.ops.where( + audio_mask, input_ids, dummy_audio_token_id + ) + audio_embeds_from_vocab = self.embed_audio(audio_input_ids) + 
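# Where the token id falls in the audio vocab range, use the
+            # audio-vocab embedding; elsewhere keep the text embedding.
+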
expanded_audio_mask = keras.ops.expand_dims(audio_mask, axis=-1) + inputs_embeds = keras.ops.where( + expanded_audio_mask, audio_embeds_from_vocab, inputs_embeds + ) + + if pixel_values is not None and self.vision_encoder: + reshape_target = (-1,) + tuple(self.vision_encoder.image_shape) + pixel_values = keras.ops.reshape(pixel_values, reshape_target) + vision_features = self.vision_encoder(pixel_values) + if self.vision_encoder.data_format == "channels_first": + vision_features = keras.ops.transpose( + vision_features, (0, 2, 3, 1) + ) + shape = keras.ops.shape(vision_features) + vision_features = keras.ops.reshape( + vision_features, (shape[0], shape[1] * shape[2], shape[3]) + ) + vision_features *= keras.ops.sqrt( + keras.ops.cast( + self.vision_encoder.num_features, dtype=inputs_embeds.dtype + ) + ) + vision_embeds = self.embed_vision(vision_features) + image_token_mask = keras.ops.equal(input_ids, self.image_token_id) + + def scatter_vision_features(): + batch_size, seq_len, hidden_size = keras.ops.shape( + inputs_embeds + ) + num_soft_tokens = self.vision_soft_tokens_per_image + start_mask_f32 = keras.ops.cast( + image_token_mask, dtype="float32" + ) + start_mask_f32 = keras.ops.expand_dims(start_mask_f32, axis=-1) + kernel = keras.ops.ones( + (num_soft_tokens, 1, 1), dtype="float32" + ) + padded_mask = keras.ops.pad( + start_mask_f32, + [[0, 0], [num_soft_tokens - 1, 0], [0, 0]], + ) + full_mask_f32 = keras.ops.conv( + padded_mask, kernel, strides=1, padding="valid" + ) + full_mask = keras.ops.cast( + keras.ops.squeeze(full_mask_f32, axis=-1) > 0.5, "bool" + ) + flat_vision_embeds = keras.ops.reshape( + vision_embeds, [-1, hidden_size] + ) + flat_full_mask = keras.ops.reshape(full_mask, [-1]) + gather_indices = ( + keras.ops.cumsum(keras.ops.cast(flat_full_mask, "int32")) + - 1 + ) + gather_indices = keras.ops.where( + flat_full_mask, gather_indices, 0 + ) + replacement_values = keras.ops.take( + flat_vision_embeds, gather_indices, axis=0 + ) + replacement_tensor = keras.ops.reshape( + replacement_values, (batch_size, seq_len, hidden_size) + ) + expanded_full_mask = keras.ops.expand_dims(full_mask, axis=-1) + return keras.ops.where( + expanded_full_mask, replacement_tensor, inputs_embeds + ) + + inputs_embeds = keras.ops.cond( + keras.ops.any(image_token_mask), + scatter_vision_features, + lambda: inputs_embeds, + ) + + if ( + input_features is not None + and input_features_mask is not None + and self.audio_encoder + ): + audio_features, _ = self.audio_encoder( + (input_features, input_features_mask) + ) + audio_embeds = self.embed_audio(audio_features) + shape = keras.ops.shape(audio_embeds) + audio_batch_size, audio_seq_len, hidden_size = ( + shape[0], + shape[1], + shape[2], + ) + target_len = self.audio_soft_tokens_per_image + last_audio_token_id = ( + self.embed_audio.vocab_offset + + self.embed_audio.embedding.input_dim + - 1 + ) + padding_toks = keras.ops.convert_to_tensor( + [[last_audio_token_id]], dtype="int64" + ) + padding_embs = self.embed_audio(padding_toks) + padding_token = keras.ops.squeeze(padding_embs, axis=[0]) + flat_audio_embeds = keras.ops.reshape( + audio_embeds, [-1, hidden_size] + ) + vocab = keras.ops.concatenate( + [flat_audio_embeds, padding_token], axis=0 + ) + pad_token_index = keras.ops.shape(flat_audio_embeds)[0] + indices = keras.ops.arange(target_len) + is_real_token = indices < audio_seq_len + batch_offsets = keras.ops.arange(audio_batch_size) * audio_seq_len + real_indices = keras.ops.expand_dims( + indices, 0 + ) + 
keras.ops.expand_dims(batch_offsets, 1) + final_indices = keras.ops.where( + keras.ops.expand_dims(is_real_token, 0), + real_indices, + pad_token_index, + ) + audio_embeds = keras.ops.take(vocab, final_indices, axis=0) + audio_token_mask = keras.ops.equal(input_ids, self.audio_token_id) + + def scatter_audio_features(): + batch_size, seq_len, hidden_size = keras.ops.shape( + inputs_embeds + ) + num_soft_tokens = self.audio_soft_tokens_per_image + start_mask_f32 = keras.ops.cast( + audio_token_mask, dtype="float32" + ) + start_mask_f32 = keras.ops.expand_dims(start_mask_f32, axis=-1) + kernel = keras.ops.ones( + (num_soft_tokens, 1, 1), dtype="float32" + ) + padded_mask = keras.ops.pad( + start_mask_f32, + [[0, 0], [num_soft_tokens - 1, 0], [0, 0]], + ) + full_mask_f32 = keras.ops.conv( + padded_mask, kernel, strides=1, padding="valid" + ) + full_mask = keras.ops.cast( + keras.ops.squeeze(full_mask_f32, axis=-1) > 0.5, "bool" + ) + flat_audio_embeds = keras.ops.reshape( + audio_embeds, [-1, hidden_size] + ) + flat_full_mask = keras.ops.reshape(full_mask, [-1]) + gather_indices = ( + keras.ops.cumsum(keras.ops.cast(flat_full_mask, "int32")) + - 1 + ) + gather_indices = keras.ops.where( + flat_full_mask, gather_indices, 0 + ) + replacement_values = keras.ops.take( + flat_audio_embeds, gather_indices, axis=0 + ) + replacement_tensor = keras.ops.reshape( + replacement_values, (batch_size, seq_len, hidden_size) + ) + expanded_full_mask = keras.ops.expand_dims(full_mask, axis=-1) + return keras.ops.where( + expanded_full_mask, replacement_tensor, inputs_embeds + ) + + inputs_embeds = keras.ops.cond( + keras.ops.any(audio_token_mask), + scatter_audio_features, + lambda: inputs_embeds, + ) + projected_per_layer_inputs = ( + self.language_model.project_per_layer_inputs( + inputs_embeds, per_layer_inputs + ) + ) + return inputs_embeds, projected_per_layer_inputs + + def get_config(self): + config = super().get_config() + config.update( + { + "language_model": keras.layers.serialize(self.language_model), + "vision_encoder": keras.layers.serialize(self.vision_encoder), + "embed_vision": keras.layers.serialize(self.embed_vision), + "audio_encoder": keras.layers.serialize(self.audio_encoder), + "embed_audio": keras.layers.serialize(self.embed_audio), + "vision_soft_tokens_per_image": self.vision_soft_tokens_per_image, # noqa: E501 + "audio_soft_tokens_per_image": self.audio_soft_tokens_per_image, + "image_token_id": self.image_token_id, + "audio_token_id": self.audio_token_id, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + } + ) + return config + + @classmethod + def from_config(cls, config): + config = config.copy() + language_model = keras.layers.deserialize(config.pop("language_model")) + vision_encoder = keras.layers.deserialize(config.pop("vision_encoder")) + embed_vision = keras.layers.deserialize(config.pop("embed_vision")) + audio_encoder = keras.layers.deserialize(config.pop("audio_encoder")) + embed_audio = keras.layers.deserialize(config.pop("embed_audio")) + return cls( + language_model=language_model, + vision_encoder=vision_encoder, + embed_vision=embed_vision, + audio_encoder=audio_encoder, + embed_audio=embed_audio, + **config, + ) + + +@keras_hub_export("keras_hub.models.Gemma3nBackbone") +class Gemma3nBackbone(Backbone): + """The Gemma3n model backbone. + + This model is a multimodal transformer that can process text, image, and + audio inputs. It consists of a text decoder and optional vision and audio + encoders. + + Args: + text_vocab_size: int. 
The size of the text vocabulary. + text_hidden_size: int. The hidden size of the text model. + num_hidden_layers: int. The number of hidden layers in the text model. + pad_token_id: int. The ID of the padding token. + num_attention_heads: int. The number of attention heads in the text + model. + num_key_value_heads: int. The number of key-value heads for GQA. + head_dim: int. The dimension of each attention head. + intermediate_size: list[int]. A list of intermediate sizes for the MLP + layers. + hidden_activation: str. The activation function for the MLP layers. + layer_types: list[str]. A list of layer types ('full_attention' or + 'sliding_attention'). + sliding_window: int. The sliding window size for sliding window + attention. + rope_theta: float. The theta value for RoPE. + max_position_embeddings: int. The maximum sequence length. + vocab_size_per_layer_input: int. The vocab size for per-layer inputs. + hidden_size_per_layer_input: int. The hidden size for per-layer inputs. + altup_num_inputs: int. The number of inputs for the AltUp mechanism. + laurel_rank: int. The rank for the Laurel block. + attention_bias: bool. Whether to use a bias in the attention + projections. + attention_dropout: float. The dropout rate for attention weights. + rope_scaling: float. The scaling factor for RoPE. + rope_local_base_freq: float. The base frequency for local RoPE. + activation_sparsity_pattern: list[float]. The sparsity pattern for MLP + activations. + altup_coef_clip: float. The coefficient clipping value for AltUp. + altup_active_idx: int. The active index for AltUp. + altup_correct_scale: bool. Whether to correct the scale in AltUp. + num_kv_shared_layers: int. The number of shared KV layers. + vision_encoder_config: dict. The config for the vision encoder. + vision_hidden_size: int. The hidden size of the vision embeddings. + vision_vocab_size: int. The vocabulary size for vision tokens. + vision_vocab_offset: int. The vocabulary offset for vision tokens. + vision_soft_tokens_per_image: int. The number of tokens per image. + image_token_id: int. The special token ID for images. + audio_encoder_config: dict. The config for the audio encoder. + audio_hidden_size: int. The hidden size of the audio embeddings. + audio_vocab_size: int. The vocabulary size for audio tokens. + audio_vocab_offset: int. The vocabulary offset for audio tokens. + audio_soft_tokens_per_image: int. The number of tokens per audio clip. + audio_token_id: int. The special token ID for audio. + rms_norm_eps: float. The epsilon value for RMS normalization. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. Defaults to `None`. + + Example: + ```python + import numpy as np + from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, + ) + from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone + from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, + ) + from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + decode_arch_def, + ) + + # Vision encoder config. + vision_arch_def = [["er_r1_k3_s1_e1_c16"]] + vision_block_args = decode_arch_def(vision_arch_def) + vision_encoder = MobileNetV5Backbone( + block_args=vision_block_args, + num_features=4, + image_shape=(224, 224, 3), + use_msfa=False, + ) + + # Audio encoder config. 
+ audio_encoder = Gemma3nAudioEncoder( + hidden_size=8, + input_feat_size=32, + sscp_conv_channel_size=[4, 8], + sscp_conv_kernel_size=[(3, 3), (3, 3)], + sscp_conv_stride_size=[(2, 2), (2, 2)], + sscp_conv_group_norm_eps=1e-5, + conf_num_hidden_layers=1, + rms_norm_eps=1e-6, + gradient_clipping=1.0, + conf_residual_weight=0.5, + conf_num_attention_heads=1, + conf_attention_chunk_size=4, + conf_attention_context_right=5, + conf_attention_context_left=5, + conf_attention_logit_cap=50.0, + conf_conv_kernel_size=5, + conf_reduction_factor=1, + ) + + # Backbone config. + backbone = Gemma3nBackbone( + text_vocab_size=50, + text_hidden_size=8, + num_hidden_layers=1, + pad_token_id=0, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=8, + intermediate_size=[16], + hidden_activation="gelu_approximate", + layer_types=["full_attention"], + sliding_window=4, + rope_theta=10000.0, + max_position_embeddings=16, + vocab_size_per_layer_input=50, + hidden_size_per_layer_input=2, + altup_num_inputs=2, + laurel_rank=1, + vision_encoder_config=vision_encoder.get_config(), + vision_hidden_size=16, + audio_encoder_config=audio_encoder.get_config(), + audio_hidden_size=8, + ) + + # Create dummy inputs. + input_data = { + "token_ids": np.random.randint(0, 50, size=(1, 16), dtype="int32"), + "attention_mask": np.ones((1, 1, 16, 16), dtype=bool), + "pixel_values": np.random.rand(1, 1, 224, 224, 3).astype("float32"), + "input_features": np.random.rand(1, 16, 32).astype("float32"), + "input_features_mask": np.zeros((1, 16), dtype=bool), + } + + # Forward pass. + outputs = backbone(input_data) + ``` + """ + + def __init__( + self, + text_vocab_size, + text_hidden_size, + num_hidden_layers, + pad_token_id, + num_attention_heads, + num_key_value_heads, + head_dim, + intermediate_size, + hidden_activation, + layer_types, + sliding_window, + rope_theta, + max_position_embeddings, + vocab_size_per_layer_input, + hidden_size_per_layer_input, + altup_num_inputs, + laurel_rank, + attention_bias=False, + attention_dropout=0.0, + rope_scaling=None, + rope_local_base_freq=10000.0, + activation_sparsity_pattern=None, + altup_coef_clip=None, + altup_active_idx=0, + altup_correct_scale=True, + num_kv_shared_layers=0, + vision_encoder_config=None, + vision_hidden_size=2048, + vision_vocab_size=128, + vision_vocab_offset=100, + vision_soft_tokens_per_image=256, + image_token_id=98, + audio_encoder_config=None, + audio_hidden_size=32, + audio_vocab_size=128, + audio_vocab_offset=228, + audio_soft_tokens_per_image=188, + audio_token_id=99, + rms_norm_eps=1e-6, + dtype=None, + **kwargs, + ): + # === Layers === + self.vision_encoder = None + if vision_encoder_config: + from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, + ) + + vision_encoder_config["dtype"] = dtype + self.vision_encoder = MobileNetV5Backbone.from_config( + vision_encoder_config + ) + self.audio_encoder = None + if audio_encoder_config: + audio_config = audio_encoder_config.copy() + audio_config.pop("dtype", None) + self.audio_encoder = Gemma3nAudioEncoder( + dtype=dtype, **audio_config + ) + self.language_model = Gemma3nTextModel( + pad_token_id=pad_token_id, + vocab_size=text_vocab_size, + hidden_size=text_hidden_size, + num_hidden_layers=num_hidden_layers, + rms_norm_eps=rms_norm_eps, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + layer_types=layer_types, + 
sliding_window=sliding_window, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_local_base_freq=rope_local_base_freq, + max_position_embeddings=max_position_embeddings, + intermediate_size=intermediate_size, + hidden_activation=hidden_activation, + activation_sparsity_pattern=activation_sparsity_pattern, + altup_num_inputs=altup_num_inputs, + altup_coef_clip=altup_coef_clip, + altup_active_idx=altup_active_idx, + altup_correct_scale=altup_correct_scale, + laurel_rank=laurel_rank, + hidden_size_per_layer_input=hidden_size_per_layer_input, + vocab_size_per_layer_input=vocab_size_per_layer_input, + num_kv_shared_layers=num_kv_shared_layers, + dtype=dtype, + name="text_model", + ) + self.embed_vision = None + if self.vision_encoder: + self.embed_vision = Gemma3nMultimodalEmbedder( + multimodal_hidden_size=vision_hidden_size, + text_hidden_size=text_hidden_size, + rms_norm_eps=rms_norm_eps, + vocab_offset=vision_vocab_offset, + vocab_size=vision_vocab_size, + dtype=dtype, + name="vision_embedder", + ) + self.embed_audio = None + if self.audio_encoder: + self.embed_audio = Gemma3nMultimodalEmbedder( + multimodal_hidden_size=audio_hidden_size, + text_hidden_size=text_hidden_size, + rms_norm_eps=rms_norm_eps, + vocab_offset=audio_vocab_offset, + vocab_size=audio_vocab_size, + dtype=dtype, + name="audio_embedder", + ) + self.embedding_processor = Gemma3nMultimodalEmbeddingProcessor( + language_model=self.language_model, + vision_encoder=self.vision_encoder, + embed_vision=self.embed_vision, + audio_encoder=self.audio_encoder, + embed_audio=self.embed_audio, + vision_soft_tokens_per_image=vision_soft_tokens_per_image, + audio_soft_tokens_per_image=audio_soft_tokens_per_image, + image_token_id=image_token_id, + audio_token_id=audio_token_id, + vocab_size_per_layer_input=vocab_size_per_layer_input, + dtype=dtype, + name="multimodal_embedding_processor", + ) + + # === Functional Model === + # === Model Inputs === + token_ids_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + attention_mask_input = keras.Input( + shape=(None, None, None), dtype="bool", name="attention_mask" + ) + processor_inputs = { + "token_ids": token_ids_input, + } + model_inputs = { + "token_ids": token_ids_input, + "attention_mask": attention_mask_input, + } + + # === Modality Feature Extraction and Interleaving === + if self.vision_encoder: + input_shape = (None,) + tuple(self.vision_encoder.image_shape) + pixel_values_input = keras.Input( + shape=input_shape, + dtype="float32", + name="pixel_values", + ) + processor_inputs["pixel_values"] = pixel_values_input + model_inputs["pixel_values"] = pixel_values_input + if self.audio_encoder: + input_features_input = keras.Input( + shape=(None, self.audio_encoder.input_feat_size), + dtype="float32", + name="input_features", + ) + input_features_mask_input = keras.Input( + shape=(None,), dtype="bool", name="input_features_mask" + ) + processor_inputs["input_features"] = input_features_input + processor_inputs["input_features_mask"] = input_features_mask_input + model_inputs["input_features"] = input_features_input + model_inputs["input_features_mask"] = input_features_mask_input + final_embeds, per_layer_inputs = self.embedding_processor( + processor_inputs + ) + + # === Decoder layers === + # The Gemma3nTextModel encapsulates the decoder loop and final norm. + # It requires `input_ids` for its internal per-layer logic. 
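+        # Output shape: (batch_size, sequence_length, text_hidden_size).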
+ sequence_output = self.language_model( + token_ids_input, + attention_mask_input, + final_embeds, + per_layer_inputs, + ) + super().__init__( + inputs=model_inputs, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.text_vocab_size = text_vocab_size + self.text_hidden_size = text_hidden_size + self.num_hidden_layers = num_hidden_layers + self.pad_token_id = pad_token_id + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.layer_types = layer_types + self.sliding_window = sliding_window + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.altup_num_inputs = altup_num_inputs + self.laurel_rank = laurel_rank + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.activation_sparsity_pattern = activation_sparsity_pattern + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.num_kv_shared_layers = num_kv_shared_layers + self.vision_encoder_config = vision_encoder_config + self.vision_hidden_size = vision_hidden_size + self.vision_vocab_size = vision_vocab_size + self.vision_vocab_offset = vision_vocab_offset + self.vision_soft_tokens_per_image = vision_soft_tokens_per_image + self.image_token_id = image_token_id + self.audio_encoder_config = audio_encoder_config + self.audio_hidden_size = audio_hidden_size + self.audio_vocab_size = audio_vocab_size + self.audio_vocab_offset = audio_vocab_offset + self.audio_soft_tokens_per_image = audio_soft_tokens_per_image + self.audio_token_id = audio_token_id + self.rms_norm_eps = rms_norm_eps + + def get_config(self): + config = super().get_config() + config.update( + { + "text_vocab_size": self.text_vocab_size, + "text_hidden_size": self.text_hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "pad_token_id": self.pad_token_id, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "layer_types": self.layer_types, + "sliding_window": self.sliding_window, + "rope_theta": self.rope_theta, + "max_position_embeddings": self.max_position_embeddings, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + "altup_num_inputs": self.altup_num_inputs, + "laurel_rank": self.laurel_rank, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "rope_scaling": self.rope_scaling, + "rope_local_base_freq": self.rope_local_base_freq, + "activation_sparsity_pattern": self.activation_sparsity_pattern, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "num_kv_shared_layers": self.num_kv_shared_layers, + "vision_encoder_config": self.vision_encoder_config, + "vision_hidden_size": self.vision_hidden_size, + "vision_vocab_size": self.vision_vocab_size, + "vision_vocab_offset": self.vision_vocab_offset, + 
"vision_soft_tokens_per_image": self.vision_soft_tokens_per_image, # noqa: E501 + "image_token_id": self.image_token_id, + "audio_encoder_config": self.audio_encoder_config, + "audio_hidden_size": self.audio_hidden_size, + "audio_vocab_size": self.audio_vocab_size, + "audio_vocab_offset": self.audio_vocab_offset, + "audio_soft_tokens_per_image": self.audio_soft_tokens_per_image, + "audio_token_id": self.audio_token_id, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py b/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py new file mode 100644 index 0000000000..23e261f98c --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_backbone_test.py @@ -0,0 +1,185 @@ +from copy import deepcopy + +import numpy as np +import pytest +from absl.testing import parameterized + +try: + from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, + ) + from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + decode_arch_def, + ) + + mobilenetv5 = True +except ImportError: + mobilenetv5 = False + +from keras_hub.src.models.gemma3n.gemma3n_audio_encoder import ( + Gemma3nAudioEncoder, +) +from keras_hub.src.models.gemma3n.gemma3n_backbone import Gemma3nBackbone +from keras_hub.src.tests.test_case import TestCase + + +@pytest.mark.skipif( + not mobilenetv5, + reason="The pull request for MobileNetV5 is still open.", +) +class Gemma3nBackboneTest(TestCase): + def setUp(self): + self.batch_size = 1 + self.text_vocab_size = 50 + self.text_sequence_length = 16 + self.image_height = 224 + self.image_width = 224 + self.audio_sequence_length = 16 + self.audio_feature_size = 32 + # === Vision Encoder === + if mobilenetv5: + vision_arch_def = [["er_r1_k3_s1_e1_c16"]] + vision_block_args = decode_arch_def(vision_arch_def) + vision_encoder = MobileNetV5Backbone( + block_args=vision_block_args, + num_features=4, + image_shape=(self.image_height, self.image_width, 3), + use_msfa=False, + ) + vision_encoder_config = vision_encoder.get_config() + else: + vision_encoder_config = None + # === Audio Encoder === + audio_encoder = Gemma3nAudioEncoder( + hidden_size=8, + input_feat_size=self.audio_feature_size, + sscp_conv_channel_size=[4, 8], + sscp_conv_kernel_size=[(3, 3), (3, 3)], + sscp_conv_stride_size=[(2, 2), (2, 2)], + sscp_conv_group_norm_eps=1e-5, + conf_num_hidden_layers=1, + rms_norm_eps=1e-6, + gradient_clipping=1.0, + conf_residual_weight=0.5, + conf_num_attention_heads=1, + conf_attention_chunk_size=4, + conf_attention_context_right=5, + conf_attention_context_left=5, + conf_attention_logit_cap=50.0, + conf_conv_kernel_size=5, + conf_reduction_factor=1, + ) + # === Multimodal === + self.multimodal_init_kwargs = { + "text_vocab_size": self.text_vocab_size, + "text_hidden_size": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "num_attention_heads": 1, + "num_key_value_heads": 1, + "head_dim": 8, # hidden_size / num_attention_heads + "intermediate_size": [16], + "hidden_activation": "gelu_approximate", + "layer_types": ["full_attention"], + "sliding_window": 4, + "rope_theta": 10000.0, + "max_position_embeddings": self.text_sequence_length, + "vocab_size_per_layer_input": 50, + "hidden_size_per_layer_input": 2, + "altup_num_inputs": 2, + "laurel_rank": 1, + "vision_encoder_config": vision_encoder_config, + "vision_hidden_size": 16, + "audio_encoder_config": audio_encoder.get_config(), + "audio_hidden_size": 8, + } + 
self.multimodal_input_data = { + "token_ids": np.random.randint( + 0, + self.text_vocab_size, + size=(self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "attention_mask": np.ones( + ( + self.batch_size, + 1, + self.text_sequence_length, + self.text_sequence_length, + ), + dtype=bool, + ), + "pixel_values": np.random.rand( + self.batch_size, 1, self.image_height, self.image_width, 3 + ).astype("float32"), + "input_features": np.random.rand( + self.batch_size, + self.audio_sequence_length, + self.audio_feature_size, + ).astype("float32"), + "input_features_mask": np.zeros( + (self.batch_size, self.audio_sequence_length), dtype=bool + ), + } + # === Text-Only === + self.text_init_kwargs = deepcopy(self.multimodal_init_kwargs) + del self.text_init_kwargs["vision_encoder_config"] + del self.text_init_kwargs["audio_encoder_config"] + del self.text_init_kwargs["vision_hidden_size"] + del self.text_init_kwargs["audio_hidden_size"] + self.text_input_data = deepcopy(self.multimodal_input_data) + del self.text_input_data["pixel_values"] + del self.text_input_data["input_features"] + del self.text_input_data["input_features_mask"] + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + def test_backbone_basics(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + input_data = self.multimodal_input_data + else: + init_kwargs = self.text_init_kwargs + input_data = self.text_input_data + self.run_backbone_test( + cls=Gemma3nBackbone, + init_kwargs=init_kwargs, + input_data=input_data, + expected_output_shape=( + self.batch_size, + self.text_sequence_length, + init_kwargs["text_hidden_size"], + ), + ) + + @parameterized.named_parameters( + ("multimodal", "multimodal"), ("text_only", "text_only") + ) + def test_saved_model(self, backbone_type): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + input_data = self.multimodal_input_data + else: + init_kwargs = self.text_init_kwargs + input_data = self.text_input_data + self.run_model_saving_test( + cls=Gemma3nBackbone, + init_kwargs=init_kwargs, + input_data=input_data, + ) + + @parameterized.named_parameters( + ("multimodal", "multimodal", 10354, 7), + ("text_only", "text_only", 1450, 4), + ) + def test_architecture_characteristics( + self, backbone_type, num_params, num_layers + ): + if backbone_type == "multimodal": + init_kwargs = self.multimodal_init_kwargs + else: + init_kwargs = self.text_init_kwargs + model = Gemma3nBackbone(**init_kwargs) + self.assertEqual(model.count_params(), num_params) + self.assertEqual(len(model.layers), num_layers) diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py b/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py new file mode 100644 index 0000000000..a071dd10fe --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_decoder.py @@ -0,0 +1,274 @@ +import math + +import keras + +from keras_hub.src.models.gemma3n.gemma3n_attention import Gemma3nTextAttention +from keras_hub.src.models.gemma3n.gemma3n_text_layers import Gemma3nTextAltUp +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextLaurelBlock, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import Gemma3nTextMLP +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextDecoderBlock(keras.layers.Layer): + """A layer that implements a single Gemma3n decoder block. 
+ + This layer combines self-attention, feed-forward networks, and normalization + to process sequences. It includes specialized components like AltUp and + Laurel blocks for enhanced performance. + + Args: + hidden_size: int. The size of the hidden states. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + num_attention_heads: int. The number of attention heads. + num_key_value_heads: int. The number of key and value heads for + Grouped-Query Attention. + head_dim: int. The dimension of each attention head. + attention_bias: bool. If `True`, attention layers will use a bias. + attention_dropout: float. The dropout rate for the attention mechanism. + is_sliding: bool. If `True`, enables sliding window attention. + sliding_window: int. The size of the sliding window for attention. + intermediate_size: int. The size of the intermediate layer in the MLP. + hidden_activation: str. The activation function for the MLP. + activation_sparsity: float. Sparsity factor for the activation function. + altup_num_inputs: int. The number of inputs for the AltUp layer. + altup_coef_clip: float. Coefficient clipping value for the AltUp layer. + altup_active_idx: int. The index of the active prediction in the + AltUp layer. + altup_correct_scale: bool. Whether to scale the corrected output from + the AltUp layer. + laurel_rank: int. The rank for the Laurel block. + hidden_size_per_layer_input: int. The hidden size for the per-layer + input projection. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + is_sliding, + sliding_window, + intermediate_size, + hidden_activation, + activation_sparsity, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.is_sliding = is_sliding + self.sliding_window = sliding_window + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity = activation_sparsity + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.laurel_rank = laurel_rank + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.attention = Gemma3nTextAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if is_sliding else None, + name="attention", + dtype=self.dtype_policy, + ) + self.mlp = Gemma3nTextMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_activation=hidden_activation, + activation_sparsity=activation_sparsity, + name="mlp", + dtype=self.dtype_policy, + ) + self.input_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="input_layernorm", + dtype=self.dtype_policy, + ) + self.post_attention_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + 
name="post_attention_layernorm", + dtype=self.dtype_policy, + ) + self.pre_feedforward_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="pre_feedforward_layernorm", + dtype=self.dtype_policy, + ) + self.post_feedforward_layernorm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_feedforward_layernorm", + dtype=self.dtype_policy, + ) + self.altup = Gemma3nTextAltUp( + hidden_size=hidden_size, + altup_num_inputs=altup_num_inputs, + altup_coef_clip=altup_coef_clip, + altup_active_idx=altup_active_idx, + rms_norm_eps=rms_norm_eps, + altup_correct_scale=altup_correct_scale, + name="altup", + dtype=self.dtype_policy, + ) + self.laurel = Gemma3nTextLaurelBlock( + hidden_size=hidden_size, + laurel_rank=laurel_rank, + rms_norm_eps=rms_norm_eps, + name="laurel", + dtype=self.dtype_policy, + ) + self.per_layer_input_gate = keras.layers.Dense( + hidden_size_per_layer_input, + use_bias=False, + name="per_layer_input_gate", + dtype=self.dtype_policy, + ) + self.per_layer_projection = keras.layers.Dense( + hidden_size, + use_bias=False, + name="per_layer_projection", + dtype=self.dtype_policy, + ) + self.post_per_layer_input_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_per_layer_input_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + ( + hidden_states_shape, + _, + _, + per_layer_input_shape, + _, + ) = input_shape + active_prediction_shape = hidden_states_shape[1:] + self.input_layernorm.build(active_prediction_shape) + self.laurel.build(active_prediction_shape) + self.attention.build(active_prediction_shape) + self.post_attention_layernorm.build(active_prediction_shape) + self.pre_feedforward_layernorm.build(active_prediction_shape) + self.mlp.build(active_prediction_shape) + self.post_feedforward_layernorm.build(active_prediction_shape) + self.altup.build(hidden_states_shape) + self.per_layer_input_gate.build(active_prediction_shape) + self.per_layer_projection.build(per_layer_input_shape) + self.post_per_layer_input_norm.build(active_prediction_shape) + if self.hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. 
+ self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(self.hidden_activation) + super().build(input_shape) + + def call(self, inputs): + ( + hidden_states, + position_embeddings_global, + position_embeddings_local, + per_layer_input, + attention_mask, + ) = inputs + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.altup_active_idx] + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + position_embeddings = ( + position_embeddings_local + if self.is_sliding + else position_embeddings_global + ) + attn, _ = self.attention( + active_prediction_normed, position_embeddings, attention_mask + ) + attn = self.post_attention_layernorm(attn) + attn_gated = active_prediction + attn + attn_laurel = (attn_gated + laurel_output) / math.sqrt(2) + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + corrected_predictions = self.altup.correct( + predictions, attn_ffw_laurel_gated + ) + corrected_predictions_list = [ + corrected_predictions[i] + for i in range(corrected_predictions.shape[0]) + ] + first_prediction = corrected_predictions_list[self.altup_active_idx] + if self.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output( + first_prediction + ) + first_prediction_gated = self.per_layer_input_gate(first_prediction) + first_prediction_activated = self.act_fn(first_prediction_gated) + first_prediction_multiplied = ( + first_prediction_activated * per_layer_input + ) + first_prediction_projected = self.per_layer_projection( + first_prediction_multiplied + ) + first_prediction_normed = self.post_per_layer_input_norm( + first_prediction_projected + ) + for i in range(1, len(corrected_predictions_list)): + corrected_predictions_list[i] += first_prediction_normed + return keras.ops.stack(corrected_predictions_list, axis=0) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "is_sliding": self.is_sliding, + "sliding_window": self.sliding_window, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity": self.activation_sparsity, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "laurel_rank": self.laurel_rank, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_layers.py b/keras_hub/src/models/gemma3n/gemma3n_text_layers.py new file mode 100644 index 0000000000..ba8d36eb04 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_layers.py @@ -0,0 +1,426 @@ +import keras +import numpy as np + +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextScaledWordEmbedding(keras.layers.Layer): + """A layer that computes scaled word embeddings for Gemma3n models. 
+ + This layer performs a standard embedding lookup and then scales the + resulting vectors by a specified factor. + + Args: + num_embeddings: int. The size of the vocabulary. + embedding_dim: int. The dimension of the embedding vectors. + embed_scale: float. The scaling factor applied to the embeddings. + """ + + def __init__( + self, + num_embeddings, + embedding_dim, + embed_scale=1.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.embed_scale = embed_scale + self.embedding = keras.layers.Embedding( + self.num_embeddings, + self.embedding_dim, + name="embedding", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.embedding.build(input_shape) + super().build(input_shape) + + def call(self, inputs): + return self.embedding(inputs) * self.embed_scale + + def get_config(self): + config = super().get_config() + config.update( + { + "num_embeddings": self.num_embeddings, + "embedding_dim": self.embedding_dim, + "embed_scale": self.embed_scale, + } + ) + return config + + +class Gemma3nTextRotaryEmbedding(keras.layers.Layer): + """A layer that computes rotary positional embeddings for Gemma3n models. + + This layer calculates the cosine and sine matrices for Rotary Positional + Embedding (RoPE), which are then applied to query and key tensors in the + attention mechanism to inject positional information. + + Args: + head_dim: int. The dimension of each attention head. + rope_theta: float. The base for the rotary frequency. + max_position_embeddings: int. The maximum sequence length that this + model might be used with. + rope_scaling: dict or `None`. Specifies the scaling strategy for RoPE. + base: float. The base value for the inverse frequency calculation. + """ + + def __init__( + self, + head_dim, + rope_theta, + max_position_embeddings, + rope_scaling, + base=10000, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.rope_scaling = rope_scaling + self.base = base + inv_freq = 1.0 / ( + self.base + ** (np.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim) + ) + self.inv_freq = keras.ops.convert_to_tensor(inv_freq) + self.attention_scaling = 1.0 + + def call(self, x, position_ids): + inv_freq_expanded = keras.ops.expand_dims( + keras.ops.expand_dims(self.inv_freq, 0), -1 + ) + inv_freq_expanded = keras.ops.repeat( + inv_freq_expanded, repeats=keras.ops.shape(position_ids)[0], axis=0 + ) + position_ids_expanded = keras.ops.expand_dims( + keras.ops.cast(position_ids, "float32"), 1 + ) + + freqs = keras.ops.transpose( + keras.ops.matmul(inv_freq_expanded, position_ids_expanded), + (0, 2, 1), + ) + emb = keras.ops.concatenate([freqs, freqs], axis=-1) + cos = keras.ops.cos(emb) * self.attention_scaling + sin = keras.ops.sin(emb) * self.attention_scaling + return keras.ops.cast(cos, x.dtype), keras.ops.cast(sin, x.dtype) + + def get_config(self): + config = super().get_config() + config.update( + { + "head_dim": self.head_dim, + "rope_theta": self.rope_theta, + "max_position_embeddings": self.max_position_embeddings, + "rope_scaling": self.rope_scaling, + "base": self.base, + } + ) + return config + + +class Gemma3nTextMLP(keras.layers.Layer): + """A Gemma3n-specific feed-forward network (MLP) layer. 
+ + This layer implements the MLP block used in Gemma3n transformer layers, + featuring a gated linear unit (GLU) structure. It can also apply activation + sparsity using a Gaussian top-k mechanism. + + Args: + hidden_size: int. The dimension of the hidden state. + intermediate_size: int. The dimension of the intermediate layer in the + MLP. + hidden_activation: str or callable. The activation function to use. + activation_sparsity: float. The target sparsity for activations, + enabling the Gaussian top-k mechanism if greater than 0. + """ + + def __init__( + self, + hidden_size, + intermediate_size, + hidden_activation, + activation_sparsity, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity = activation_sparsity + self.gate_proj = keras.layers.Dense( + intermediate_size, + use_bias=False, + name="gate_proj", + dtype=self.dtype_policy, + ) + self.up_proj = keras.layers.Dense( + intermediate_size, + use_bias=False, + name="up_proj", + dtype=self.dtype_policy, + ) + self.down_proj = keras.layers.Dense( + hidden_size, + use_bias=False, + name="down_proj", + dtype=self.dtype_policy, + ) + if hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(hidden_activation) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + self.down_proj.build((None, self.intermediate_size)) + super().build(input_shape) + + def _gaussian_topk(self, inputs): + target_sparsity_tensor = keras.ops.convert_to_tensor( + self.activation_sparsity, dtype="float32" + ) + std_multiplier = keras.ops.erfinv( + 2 * target_sparsity_tensor - 1 + ) * keras.ops.sqrt(keras.ops.convert_to_tensor(2.0, dtype="float32")) + std_multiplier = keras.ops.cast(std_multiplier, dtype=inputs.dtype) + inputs_mean = keras.ops.mean(inputs, axis=-1, keepdims=True) + inputs_std = keras.ops.std(inputs, axis=-1, keepdims=True) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return keras.ops.relu(inputs - cutoff_x) + + def call(self, hidden_states): + gate_proj = self.gate_proj(hidden_states) + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + activations = self.act_fn(gate_proj) + up_proj = self.up_proj(hidden_states) + down_proj = self.down_proj(activations * up_proj) + return down_proj + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity": self.activation_sparsity, + } + ) + return config + + +class Gemma3nTextLaurelBlock(keras.layers.Layer): + """A Laurel block layer for the Gemma3n model. + + This layer implements a low-rank residual block which applies a + down-projection to a specified rank, followed by an up-projection. The + result is normalized and added back to the original input, forming a + residual connection. + + Args: + hidden_size: int. The dimension of the hidden state. + laurel_rank: int. The rank of the low-rank adaptation. + rms_norm_eps: float. The epsilon value for the RMS normalization layer. 
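+
+    Example:
+
+    A minimal stand-alone sketch with toy sizes (the values below are
+    illustrative only):
+
+    ```python
+    import numpy as np
+
+    from keras_hub.src.models.gemma3n.gemma3n_text_layers import (
+        Gemma3nTextLaurelBlock,
+    )
+
+    laurel = Gemma3nTextLaurelBlock(
+        hidden_size=8, laurel_rank=2, rms_norm_eps=1e-6
+    )
+    # (batch_size, sequence_length, hidden_size).
+    x = np.random.rand(1, 4, 8).astype("float32")
+    y = laurel(x)  # Low-rank residual update, same shape as `x`.
+    ```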
+ """ + + def __init__( + self, hidden_size, laurel_rank, rms_norm_eps, dtype=None, **kwargs + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.laurel_rank = laurel_rank + self.rms_norm_eps = rms_norm_eps + self.linear_left = keras.layers.Dense( + laurel_rank, + use_bias=False, + name="linear_left", + dtype=self.dtype_policy, + ) + self.linear_right = keras.layers.Dense( + hidden_size, + use_bias=False, + name="linear_right", + dtype=self.dtype_policy, + ) + self.post_laurel_norm = Gemma3nRMSNorm( + hidden_size, + eps=rms_norm_eps, + name="post_laurel_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.linear_left.build(input_shape) + self.linear_right.build((None, self.laurel_rank)) + self.post_laurel_norm.build(input_shape) + super().build(input_shape) + + def call(self, hidden_states): + laurel_hidden_states = self.linear_left(hidden_states) + laurel_hidden_states = self.linear_right(laurel_hidden_states) + normed_laurel_hidden_states = self.post_laurel_norm( + laurel_hidden_states + ) + return hidden_states + normed_laurel_hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "laurel_rank": self.laurel_rank, + "rms_norm_eps": self.rms_norm_eps, + } + ) + return config + + +class Gemma3nTextAltUp(keras.layers.Layer): + """An Alternating Update (AltUp) layer for the Gemma3n model. + + This layer implements the AltUp mechanism, which combines multiple input + modalities through a predict-and-correct cycle. It uses a router to compute + modality-specific coefficients for predicting and correcting hidden states. + + Args: + hidden_size: int. The dimension of the hidden state. + altup_num_inputs: int. The number of input modalities to the AltUp + block. + altup_coef_clip: float. The clipping value for coefficients. + altup_active_idx: int. The index of the currently active input. + rms_norm_eps: float. The epsilon value for the Gemma 3n RMS + normalization layers. + altup_correct_scale: bool. If `True`, enables a learnable scaling + factor on the corrected output. 
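+
+    Example:
+
+    A minimal sketch of the predict/correct cycle with toy sizes (the
+    values below are illustrative only):
+
+    ```python
+    import numpy as np
+
+    from keras_hub.src.models.gemma3n.gemma3n_text_layers import (
+        Gemma3nTextAltUp,
+    )
+
+    altup = Gemma3nTextAltUp(
+        hidden_size=8,
+        altup_num_inputs=2,
+        altup_coef_clip=120.0,
+        altup_active_idx=0,
+        rms_norm_eps=1e-6,
+        altup_correct_scale=True,
+    )
+    # Stacked inputs: (altup_num_inputs, batch, sequence_length, hidden).
+    hidden_states = np.random.rand(2, 1, 4, 8).astype("float32")
+    # Build explicitly since `predict` and `correct` are plain methods.
+    altup.build(hidden_states.shape)
+    predictions = altup.predict(hidden_states)  # Same shape as the input.
+    # In the decoder, the active prediction goes through attention and
+    # the MLP before correction; here it is reused directly.
+    corrected = altup.correct(predictions, predictions[0])
+    ```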
+ """ + + def __init__( + self, + hidden_size, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + rms_norm_eps, + altup_correct_scale, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.rms_norm_eps = rms_norm_eps + + self.altup_correct_scale = altup_correct_scale + self.correct_output_scale = None + self.correction_coefs = keras.layers.Dense( + self.altup_num_inputs, + use_bias=False, + name="correction_coefs", + dtype=self.dtype_policy, + ) + self.prediction_coefs = keras.layers.Dense( + self.altup_num_inputs**2, + use_bias=False, + name="prediction_coefs", + dtype=self.dtype_policy, + ) + self.modality_router = keras.layers.Dense( + self.altup_num_inputs, + use_bias=False, + name="modality_router", + dtype=self.dtype_policy, + ) + self.router_norm = Gemma3nRMSNorm( + self.hidden_size, + eps=self.rms_norm_eps, + name="router_norm", + dtype=self.dtype_policy, + ) + self.router_input_scale = self.hidden_size**-1.0 + + def build(self, input_shape): + if self.altup_correct_scale: + self.correct_output_scale = self.add_weight( + shape=(self.hidden_size,), + initializer="zeros", + trainable=True, + name="correct_output_scale", + dtype=self.dtype_policy.variable_dtype, + ) + router_input_shape = input_shape[1:] + self.router_norm.build(router_input_shape) + self.modality_router.build(router_input_shape) + coefs_input_shape = router_input_shape[:-1] + (self.altup_num_inputs,) + self.correction_coefs.build(coefs_input_shape) + self.prediction_coefs.build(coefs_input_shape) + super().build(input_shape) + + def compute_router_modalities(self, x): + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return keras.ops.cast( + keras.ops.tanh(keras.ops.cast(routed, "float32")), x.dtype + ) + + def predict(self, hidden_states): + modalities = self.compute_router_modalities( + hidden_states[self.altup_active_idx] + ) + all_coefs = keras.ops.reshape( + self.prediction_coefs(modalities), + modalities.shape[:-1] + + (self.altup_num_inputs, self.altup_num_inputs), + ) + all_coefs = keras.ops.transpose(all_coefs, (0, 1, 3, 2)) + predictions = keras.ops.matmul( + keras.ops.transpose(hidden_states, (1, 2, 3, 0)), all_coefs + ) + predictions = keras.ops.transpose(predictions, (3, 0, 1, 2)) + predictions += hidden_states + return predictions + + def correct(self, predictions, activated): + modalities = self.compute_router_modalities(activated) + innovation = activated - predictions[self.altup_active_idx] + innovation = keras.ops.repeat( + keras.ops.expand_dims(innovation, 0), self.altup_num_inputs, axis=0 + ) + all_coefs = self.correction_coefs(modalities) + 1.0 + all_coefs = keras.ops.expand_dims( + keras.ops.transpose(all_coefs, (2, 0, 1)), -1 + ) + corrected = innovation * all_coefs + corrected += predictions + return corrected + + def scale_corrected_output(self, corrected): + return corrected * self.correct_output_scale + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "rms_norm_eps": self.rms_norm_eps, + "altup_correct_scale": self.altup_correct_scale, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_text_model.py 
b/keras_hub/src/models/gemma3n/gemma3n_text_model.py new file mode 100644 index 0000000000..2a668cdc35 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_text_model.py @@ -0,0 +1,384 @@ +import math + +import keras + +from keras_hub.src.models.gemma3n.gemma3n_text_decoder import ( + Gemma3nTextDecoderBlock, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextRotaryEmbedding, +) +from keras_hub.src.models.gemma3n.gemma3n_text_layers import ( + Gemma3nTextScaledWordEmbedding, +) +from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm + + +class Gemma3nTextModel(keras.layers.Layer): + """The core Gemma3n text model layer. + + This layer implements the transformer architecture of the Gemma3n model. + It includes token embeddings, multiple decoder blocks, and final + normalization. + + Args: + pad_token_id: int. The id for the padding token. + vocab_size: int. The size of the vocabulary. + hidden_size: int. The size of the hidden states. + num_hidden_layers: int. The number of hidden layers in the transformer. + rms_norm_eps: float. The epsilon value for the RMS normalization layers. + num_attention_heads: int. The number of attention heads. + num_key_value_heads: int. The number of key-value heads for GQA. + head_dim: int. The dimension of each attention head. + attention_bias: bool. Whether to use a bias in the attention mechanism. + attention_dropout: float. The dropout rate for the attention scores. + layer_types: list of str. The type of each layer, e.g., + "sliding_attention". + sliding_window: int. The sliding window size for sliding window + attention. + rope_theta: float. The base frequency for Rotary Positional Embeddings. + rope_scaling: float or None. The scaling factor for RoPE. + rope_local_base_freq: float. The base frequency for local RoPE. + max_position_embeddings: int. The maximum sequence length. + intermediate_size: list of int. The size of the intermediate layer in + each of the feed-forward networks. + hidden_activation: str. The activation function for the hidden layers. + activation_sparsity_pattern: list of float or None. The sparsity pattern + for activations. + altup_num_inputs: int. The number of inputs for the AltUp mechanism. + altup_coef_clip: float. The coefficient clipping value for AltUp. + altup_active_idx: int. The active index for AltUp. + altup_correct_scale: bool. Whether to correct scaling in AltUp. + laurel_rank: int. The rank for LAUREL factorization. + hidden_size_per_layer_input: int. The hidden size for per-layer inputs. + vocab_size_per_layer_input: int. The vocabulary size for per-layer + inputs. + num_kv_shared_layers: int. The number of shared key-value layers. 
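+
+    Example:
+
+    A text-only sketch using the same toy configuration as the
+    `Gemma3nBackbone` docstring example (values are illustrative only):
+
+    ```python
+    import numpy as np
+
+    from keras_hub.src.models.gemma3n.gemma3n_text_model import (
+        Gemma3nTextModel,
+    )
+
+    model = Gemma3nTextModel(
+        pad_token_id=0,
+        vocab_size=50,
+        hidden_size=8,
+        num_hidden_layers=1,
+        rms_norm_eps=1e-6,
+        num_attention_heads=1,
+        num_key_value_heads=1,
+        head_dim=8,
+        attention_bias=False,
+        attention_dropout=0.0,
+        layer_types=["full_attention"],
+        sliding_window=4,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        rope_local_base_freq=10000.0,
+        max_position_embeddings=16,
+        intermediate_size=[16],
+        hidden_activation="gelu_approximate",
+        activation_sparsity_pattern=None,
+        altup_num_inputs=2,
+        altup_coef_clip=None,
+        altup_active_idx=0,
+        altup_correct_scale=True,
+        laurel_rank=1,
+        hidden_size_per_layer_input=2,
+        vocab_size_per_layer_input=50,
+        num_kv_shared_layers=0,
+    )
+    token_ids = np.random.randint(0, 50, size=(1, 16), dtype="int32")
+    attention_mask = np.ones((1, 1, 16, 16), dtype=bool)
+    inputs_embeds = model.embed_tokens(token_ids)
+    per_layer_inputs = model.project_per_layer_inputs(
+        inputs_embeds, model.get_per_layer_inputs(token_ids)
+    )
+    # (1, 16, 8): one hidden state per token.
+    hidden_states = model(
+        token_ids, attention_mask, inputs_embeds, per_layer_inputs
+    )
+    ```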
+ """ + + def __init__( + self, + pad_token_id, + vocab_size, + hidden_size, + num_hidden_layers, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + layer_types, + sliding_window, + rope_theta, + rope_scaling, + rope_local_base_freq, + max_position_embeddings, + intermediate_size, + hidden_activation, + activation_sparsity_pattern, + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + vocab_size_per_layer_input, + num_kv_shared_layers, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.layer_types = layer_types + self.sliding_window = sliding_window + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.rope_local_base_freq = rope_local_base_freq + self.max_position_embeddings = max_position_embeddings + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.activation_sparsity_pattern = activation_sparsity_pattern + self.altup_num_inputs = altup_num_inputs + self.altup_coef_clip = altup_coef_clip + self.altup_active_idx = altup_active_idx + self.altup_correct_scale = altup_correct_scale + self.laurel_rank = laurel_rank + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + self.padding_idx = pad_token_id + self.embed_tokens = Gemma3nTextScaledWordEmbedding( + vocab_size, + hidden_size, + embed_scale=hidden_size**0.5, + name="embed_tokens", + dtype=self.dtype_policy, + ) + if activation_sparsity_pattern is None: + self.activation_sparsity_pattern = [0.0] * num_hidden_layers + self.layers = [ + Gemma3nTextDecoderBlock( + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + head_dim, + attention_bias, + attention_dropout, + layer_types[i] == "sliding_attention", + sliding_window, + intermediate_size[i], + hidden_activation, + self.activation_sparsity_pattern[i], + altup_num_inputs, + altup_coef_clip, + altup_active_idx, + altup_correct_scale, + laurel_rank, + hidden_size_per_layer_input, + name=f"decoder_block_{i}", + dtype=self.dtype_policy, + ) + for i in range(num_hidden_layers) + ] + self.norm = Gemma3nRMSNorm( + hidden_size, eps=rms_norm_eps, name="norm", dtype=self.dtype_policy + ) + self.rotary_emb = Gemma3nTextRotaryEmbedding( + head_dim, + rope_theta, + max_position_embeddings, + rope_scaling, + dtype=self.dtype_policy, + name="rotary_emb", + ) + self.rotary_emb_local = Gemma3nTextRotaryEmbedding( + head_dim, + rope_local_base_freq, + max_position_embeddings, + None, + dtype=self.dtype_policy, + name="rotary_emb_local", + ) + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + vocab_size_per_layer_input, + num_hidden_layers * hidden_size_per_layer_input, + embed_scale=hidden_size_per_layer_input**0.5, + name="embed_tokens_per_layer", + dtype=self.dtype_policy, + ) + self.per_layer_model_projection = keras.layers.Dense( + num_hidden_layers * hidden_size_per_layer_input, + use_bias=False, + 
name="per_layer_model_projection", + dtype=self.dtype_policy, + ) + self.per_layer_projection_norm = Gemma3nRMSNorm( + hidden_size_per_layer_input, + eps=rms_norm_eps, + name="per_layer_projection_norm", + dtype=self.dtype_policy, + ) + self.altup_projections = [ + keras.layers.Dense( + hidden_size, + use_bias=False, + name=f"altup_projection_{i}", + dtype=self.dtype_policy, + ) + for i in range(1, altup_num_inputs) + ] + self.altup_unembed_projections = [ + keras.layers.Dense( + hidden_size, + use_bias=False, + name=f"altup_unembed_projection_{i}", + dtype=self.dtype_policy, + ) + for i in range(1, altup_num_inputs) + ] + self.per_layer_projection_scale = hidden_size**-0.5 + self.per_layer_input_scale = 1.0 / math.sqrt(2.0) + + def build(self, input_shape): + if isinstance(input_shape, (list, tuple)) and isinstance( + input_shape[0], (list, tuple) + ): + input_ids_shape, _, inputs_embeds_shape, _ = input_shape + else: + input_ids_shape = input_shape + hidden_size = self.embed_tokens.embedding_dim + inputs_embeds_shape = input_ids_shape[:-1] + (hidden_size,) + self.embed_tokens.build(input_ids_shape) + self.embed_tokens_per_layer.build(input_ids_shape) + if not self.per_layer_model_projection.built: + self.per_layer_model_projection.build(inputs_embeds_shape) + per_layer_projection_norm_shape = ( + None, + None, + None, + self.hidden_size_per_layer_input, + ) + if not self.per_layer_projection_norm.built: + self.per_layer_projection_norm.build( + per_layer_projection_norm_shape + ) + for proj in self.altup_projections: + proj.build(inputs_embeds_shape) + for proj in self.altup_unembed_projections: + proj.build(inputs_embeds_shape) + decoder_hidden_states_shape = ( + self.altup_num_inputs, + ) + inputs_embeds_shape + decoder_per_layer_input_shape = input_ids_shape + ( + self.hidden_size_per_layer_input, + ) + decoder_input_shape = ( + decoder_hidden_states_shape, + None, # position_embeddings_global + None, # position_embeddings_local + decoder_per_layer_input_shape, + None, # attention_mask + ) + for layer in self.layers: + layer.build(decoder_input_shape) + self.norm.build(inputs_embeds_shape) + super().build(input_shape) + + def get_per_layer_inputs(self, input_ids): + embeds = self.embed_tokens_per_layer(input_ids) + return keras.ops.reshape( + embeds, + keras.ops.shape(input_ids) + + (self.num_hidden_layers, self.hidden_size_per_layer_input), + ) + + def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None): + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = ( + per_layer_projection * self.per_layer_projection_scale + ) + per_layer_projection = keras.ops.reshape( + per_layer_projection, + keras.ops.shape(inputs_embeds)[:-1] + + (self.num_hidden_layers, self.hidden_size_per_layer_input), + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection + ) + if per_layer_inputs is None: + return per_layer_projection + return ( + per_layer_projection + per_layer_inputs + ) * self.per_layer_input_scale + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, (list, tuple)) and isinstance( + input_shape[0], (list, tuple) + ): + input_ids_shape = input_shape[0] + else: + input_ids_shape = input_shape + hidden_size = self.embed_tokens.embedding_dim + return input_ids_shape + (hidden_size,) + + def call(self, input_ids, attention_mask, inputs_embeds, per_layer_inputs): + position_ids = keras.ops.expand_dims( + keras.ops.arange(0, keras.ops.shape(input_ids)[1]), 0 + ) + hidden_states_0 = 
inputs_embeds + cos_global, sin_global = self.rotary_emb(hidden_states_0, position_ids) + cos_local, sin_local = self.rotary_emb_local( + hidden_states_0, position_ids + ) + target_magnitude = keras.ops.sqrt( + keras.ops.mean(hidden_states_0**2, axis=-1, keepdims=True) + ) + epsilon = 1e-5 + temp_hidden_states = [hidden_states_0] + for proj in self.altup_projections: + altup_proj = proj(hidden_states_0) + new_magnitude = keras.ops.sqrt( + keras.ops.maximum( + keras.ops.mean(altup_proj**2, axis=-1, keepdims=True), + epsilon, + ) + ) + current_hidden_state = altup_proj * target_magnitude / new_magnitude + temp_hidden_states.append(current_hidden_state) + hidden_states = keras.ops.stack(temp_hidden_states, axis=0) + for i, decoder_layer in enumerate(self.layers): + per_layer_input = per_layer_inputs[:, :, i, :] + hidden_states = decoder_layer( + ( + hidden_states, + (cos_global, sin_global), + (cos_local, sin_local), + per_layer_input, + attention_mask, + ) + ) + target_magnitude = keras.ops.sqrt( + keras.ops.mean(hidden_states[0] ** 2, axis=-1, keepdims=True) + ) + temp_hidden_states = [hidden_states[0]] + for i, proj in enumerate(self.altup_unembed_projections): + altup_unemb_proj = proj(hidden_states[i + 1]) + new_magnitude = keras.ops.sqrt( + keras.ops.maximum( + keras.ops.mean(altup_unemb_proj**2, axis=-1, keepdims=True), + epsilon, + ) + ) + current_hidden_state = ( + altup_unemb_proj * target_magnitude / new_magnitude + ) + temp_hidden_states.append(current_hidden_state) + hidden_states = keras.ops.stack(temp_hidden_states) + hidden_states = keras.ops.mean(hidden_states, axis=0) + return self.norm(hidden_states) + + def get_config(self): + config = super().get_config() + config.update( + { + "pad_token_id": self.pad_token_id, + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "head_dim": self.head_dim, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "layer_types": self.layer_types, + "sliding_window": self.sliding_window, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "rope_local_base_freq": self.rope_local_base_freq, + "max_position_embeddings": self.max_position_embeddings, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "activation_sparsity_pattern": self.activation_sparsity_pattern, + "altup_num_inputs": self.altup_num_inputs, + "altup_coef_clip": self.altup_coef_clip, + "altup_active_idx": self.altup_active_idx, + "altup_correct_scale": self.altup_correct_scale, + "laurel_rank": self.laurel_rank, + "hidden_size_per_layer_input": self.hidden_size_per_layer_input, + "vocab_size_per_layer_input": self.vocab_size_per_layer_input, + "num_kv_shared_layers": self.num_kv_shared_layers, + } + ) + return config diff --git a/keras_hub/src/models/gemma3n/gemma3n_utils.py b/keras_hub/src/models/gemma3n/gemma3n_utils.py new file mode 100644 index 0000000000..0db8706d63 --- /dev/null +++ b/keras_hub/src/models/gemma3n/gemma3n_utils.py @@ -0,0 +1,122 @@ +import keras + + +def rotate_half(x): + """Rotates half of the hidden dimensions of the input tensor. + + This function is used to implement rotary positional embeddings. It splits + the last dimension of the input tensor into two halves, negates the second + half, and then concatenates them back together. 
+ + Args: + x: The input tensor. + + Returns: + A new tensor with the second half of the last dimension rotated. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return keras.ops.concatenate([-x2, x1], axis=-1) + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key and value states for Grouped-Query Attention. + + This function is used in Grouped-Query Attention (GQA) to expand the key + and value states to match the number of query heads. + + Args: + hidden_states: The key or value tensor to be repeated, with a shape of + `[batch, num_key_value_heads, seq_len, head_dim]`. + n_rep: int. The number of times to repeat the key/value heads. + + Returns: + The repeated tensor with a shape of + `[batch, num_key_value_heads * n_rep, seq_len, head_dim]`. + """ + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = keras.ops.expand_dims(hidden_states, 2) + hidden_states = keras.ops.repeat(hidden_states, n_rep, axis=2) + return keras.ops.reshape( + hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim) + ) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + """Applies rotary positional embedding to the input tensor. + + Args: + x: The input tensor. + cos: The cosine part of the rotary embedding. + sin: The sine part of the rotary embedding. + unsqueeze_dim: int. The dimension to unsqueeze `cos` and `sin` before + applying the embedding. Defaults to 1. + + Returns: + The tensor with rotary positional embeddings applied. + """ + cos = keras.ops.expand_dims(cos, axis=unsqueeze_dim) + sin = keras.ops.expand_dims(sin, axis=unsqueeze_dim) + return (x * cos) + (rotate_half(x) * sin) + + +def eager_attention_forward( + query, + key, + value, + num_key_value_groups, + head_dim, + attention_mask, + dropout=0.0, + scaling=None, + softcap=None, + training=False, +): + """Forward pass for an eager attention implementation. + + Args: + query: The query tensor. + key: The key tensor. + value: The value tensor. + num_key_value_groups: int. The number of key-value groups. + head_dim: int. The dimension of each attention head. + attention_mask: The attention mask to apply. + dropout: float. The dropout rate. Defaults to 0.0. + scaling: float, optional. The scaling factor for attention scores. + If `None`, it defaults to `head_dim**-0.5`. + softcap: float, optional. A softcap value to apply to attention weights. + Defaults to `None`. + training: bool. Whether the model is in training mode. Defaults to + `False`. 
+    """
+    if scaling is None:
+        scaling = head_dim**-0.5
+    key_states = repeat_kv(key, num_key_value_groups)
+    value_states = repeat_kv(value, num_key_value_groups)
+    attn_weights = (
+        keras.ops.matmul(query, keras.ops.transpose(key_states, (0, 1, 3, 2)))
+        * scaling
+    )
+    if softcap is not None:
+        attn_weights = attn_weights / softcap
+        attn_weights = keras.ops.tanh(attn_weights)
+        attn_weights = attn_weights * softcap
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + keras.ops.cast(
+            causal_mask, dtype=attn_weights.dtype
+        )
+    attn_weights_dtype = attn_weights.dtype
+    attn_weights = keras.ops.softmax(
+        keras.ops.cast(attn_weights, "float32"), axis=-1
+    )
+    attn_weights = keras.ops.cast(attn_weights, attn_weights_dtype)
+    if training:
+        attn_weights = keras.layers.Dropout(dropout)(
+            attn_weights, training=training
+        )
+    attn_output = keras.ops.matmul(attn_weights, value_states)
+    attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3))
+    return attn_output, attn_weights
diff --git a/keras_hub/src/models/gemma3n/rms_normalization.py b/keras_hub/src/models/gemma3n/rms_normalization.py
new file mode 100644
index 0000000000..48955699d0
--- /dev/null
+++ b/keras_hub/src/models/gemma3n/rms_normalization.py
@@ -0,0 +1,67 @@
+import keras
+
+
+class Gemma3nRMSNorm(keras.layers.Layer):
+    """The Gemma 3n-specific RMS normalization layer.
+
+    Args:
+        dim: int. The dimension of the input tensor.
+        eps: float. A small constant added to the denominator for numerical
+            stability. Defaults to `1e-6`.
+        with_scale: bool. Whether to include a learnable scaling parameter.
+            Defaults to `True`.
+    """
+
+    def __init__(self, dim, eps=1e-6, with_scale=True, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)
+        self.dim = dim
+        self.eps = eps
+        self.with_scale = with_scale
+
+    def build(self, input_shape):
+        if self.with_scale:
+            self.scale = self.add_weight(
+                shape=(self.dim,),
+                initializer="ones",
+                trainable=True,
+                name="scale",
+                dtype=self.dtype_policy.variable_dtype,
+            )
+        else:
+            self.scale = 1.0
+        super().build(input_shape)
+
+    def call(self, x):
+        norm_x = x * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x), axis=-1, keepdims=True)
+            + self.eps
+        )
+        return norm_x * self.scale
+
+    def _int8_call(self, x):
+        x_calc = keras.ops.cast(x, "float32")
+        norm_x = x_calc * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x_calc), axis=-1, keepdims=True)
+            + self.eps
+        )
+        norm_x = norm_x * self.scale
+        return keras.ops.cast(norm_x, x.dtype)
+
+    def _float8_call(self, x):
+        x_calc = keras.ops.cast(x, "float32")
+        norm_x = x_calc * keras.ops.rsqrt(
+            keras.ops.mean(keras.ops.square(x_calc), axis=-1, keepdims=True)
+            + self.eps
+        )
+        return keras.ops.cast(norm_x * self.scale, x.dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "dim": self.dim,
+                "eps": self.eps,
+                "with_scale": self.with_scale,
+            }
+        )
+        return config
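A shape-only sketch of the per-layer input plumbing in `Gemma3nTextModel` (`get_per_layer_inputs` / `project_per_layer_inputs`), using random tensors in place of the per-layer embedding table and the `Dense` projection and omitting the RMS norm; only the reshape, the `hidden_size**-0.5` scale, and the `1/sqrt(2)` averaging from the code above are reproduced, and every dimension is a toy value.

import math

import keras

# Toy dimensions standing in for the model configuration.
batch, seq_len = 2, 5
hidden_size, num_layers, hidden_per_layer = 16, 3, 4

# `get_per_layer_inputs`: the per-layer vocabulary embedding is stored flat
# and then unflattened to (..., num_layers, hidden_per_layer).
flat = keras.random.normal((batch, seq_len, num_layers * hidden_per_layer))
per_layer_inputs = keras.ops.reshape(
    flat, (batch, seq_len, num_layers, hidden_per_layer)
)

# `project_per_layer_inputs`: a projection of the main embeddings (random
# here) is scaled by hidden_size**-0.5 and averaged with the per-layer
# embeddings; the 1/sqrt(2) keeps the variance of the sum roughly constant.
projection = keras.random.normal(
    (batch, seq_len, num_layers, hidden_per_layer)
)
projection = projection * hidden_size**-0.5
combined = (projection + per_layer_inputs) * (1.0 / math.sqrt(2.0))
print(combined.shape)  # (2, 5, 3, 4)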
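The AltUp renormalization inside `Gemma3nTextModel.call` is easier to follow in isolation. The sketch below repeats the same magnitude matching on random stand-in tensors, with no model weights involved; the shapes and the scale factor of 3.0 are illustrative only.

import keras

# Random stand-ins for the embedded tokens and one AltUp projection.
hidden_states_0 = keras.random.normal((2, 5, 16))
altup_proj = keras.random.normal((2, 5, 16)) * 3.0  # deliberately larger

epsilon = 1e-5
# Per-token RMS of the reference stream.
target_magnitude = keras.ops.sqrt(
    keras.ops.mean(hidden_states_0**2, axis=-1, keepdims=True)
)
# Per-token RMS of the projected stream, clamped away from zero.
new_magnitude = keras.ops.sqrt(
    keras.ops.maximum(
        keras.ops.mean(altup_proj**2, axis=-1, keepdims=True), epsilon
    )
)
# Rescale so every AltUp stream enters the decoder stack with the same
# per-token RMS as the token embeddings.
rescaled = altup_proj * target_magnitude / new_magnitude
print(keras.ops.sqrt(keras.ops.mean(rescaled**2, axis=-1)))
# Matches the per-token RMS of hidden_states_0.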
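A sketch of how `rotate_half` and `apply_rotary_pos_emb` from `gemma3n_utils.py` are meant to be fed, assuming this branch is installed so the import path resolves. The cos/sin tables are built inline with an illustrative base frequency of 10000.0 rather than by `Gemma3nTextRotaryEmbedding`, so treat the shapes, not the constants, as the point.

import keras

from keras_hub.src.models.gemma3n.gemma3n_utils import apply_rotary_pos_emb

batch, num_heads, seq_len, head_dim = 1, 2, 6, 8
x = keras.random.normal((batch, num_heads, seq_len, head_dim))

# Build toy cos/sin tables: one inverse frequency per channel pair,
# multiplied by the absolute positions.
positions = keras.ops.arange(seq_len, dtype="float32")
inv_freq = 1.0 / keras.ops.power(
    10000.0, keras.ops.arange(0, head_dim, 2, dtype="float32") / head_dim
)
freqs = keras.ops.expand_dims(positions, 1) * keras.ops.expand_dims(inv_freq, 0)
emb = keras.ops.concatenate([freqs, freqs], axis=-1)  # (seq_len, head_dim)
cos = keras.ops.expand_dims(keras.ops.cos(emb), 0)  # (1, seq_len, head_dim)
sin = keras.ops.expand_dims(keras.ops.sin(emb), 0)

# With the default unsqueeze_dim=1, cos/sin broadcast over the head axis,
# and rotate_half pairs channel i with channel i + head_dim // 2.
x_rot = apply_rotary_pos_emb(x, cos, sin)
print(x_rot.shape)  # (1, 2, 6, 8)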
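A shape-only sketch of `repeat_kv`, which expands grouped key/value heads to match the number of query heads; the head counts below are toy values.

import keras

from keras_hub.src.models.gemma3n.gemma3n_utils import repeat_kv

batch, num_kv_heads, seq_len, head_dim = 1, 2, 5, 4
n_rep = 4  # e.g. 8 query heads shared across 2 key/value heads
kv = keras.random.normal((batch, num_kv_heads, seq_len, head_dim))

# Each key/value head is tiled n_rep times along the head axis so it can
# be matched one-to-one with the query heads.
expanded = repeat_kv(kv, n_rep)
print(expanded.shape)  # (1, 8, 5, 4)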
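A sketch of calling `eager_attention_forward` with an additive causal mask, again assuming the import path above. Tensor sizes are toy values, and the mask convention (0.0 where a position may attend, a large negative value where it may not) mirrors what the function adds to the attention logits.

import keras

from keras_hub.src.models.gemma3n.gemma3n_utils import eager_attention_forward

batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 6, 8
query = keras.random.normal((batch, num_heads, seq_len, head_dim))
key = keras.random.normal((batch, num_kv_heads, seq_len, head_dim))
value = keras.random.normal((batch, num_kv_heads, seq_len, head_dim))

# Additive causal mask: 0.0 on and below the diagonal, -1e9 above it,
# broadcast over the batch and head axes.
causal = keras.ops.triu(keras.ops.full((seq_len, seq_len), -1e9), k=1)
attention_mask = keras.ops.reshape(causal, (1, 1, seq_len, seq_len))

attn_output, attn_weights = eager_attention_forward(
    query,
    key,
    value,
    num_key_value_groups=num_heads // num_kv_heads,
    head_dim=head_dim,
    attention_mask=attention_mask,
)
print(attn_output.shape)  # (1, 6, 4, 8) after the final transpose
print(attn_weights.shape)  # (1, 4, 6, 6)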
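A minimal usage sketch of `Gemma3nRMSNorm`, assuming this branch is installed so `keras_hub.src.models.gemma3n.rms_normalization` resolves; the feature size of 8 and the random input are illustrative only.

import keras

from keras_hub.src.models.gemma3n.rms_normalization import Gemma3nRMSNorm

x = keras.random.normal((2, 4, 8))  # (batch, sequence, features)
norm = Gemma3nRMSNorm(dim=8, eps=1e-6)
y = norm(x)

# Each feature vector is divided by its root-mean-square; with the
# default ones-initialized `scale`, the per-token RMS of `y` is ~1.
print(keras.ops.sqrt(keras.ops.mean(keras.ops.square(y), axis=-1)))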