From 958e8a3fed8c55e36636b2039c6115836e9753ea Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 25 Oct 2024 20:20:04 +0200 Subject: [PATCH 1/5] Add JinaBert with copy of Bert implementation --- lib/bumblebee.ex | 2 + lib/bumblebee/text/jina_bert.ex | 658 ++++++++++++++++++++++++++++++++ 2 files changed, 660 insertions(+) create mode 100644 lib/bumblebee/text/jina_bert.ex diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 51f2330f..0e5a38f7 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -150,6 +150,8 @@ defmodule Bumblebee do "GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling}, "GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification}, "GPTNeoXForTokenClassification" => {Bumblebee.Text.GptNeoX, :for_token_classification}, + "JinaBertForMaskedLM" => {Bumblebee.Text.JinaBert, :for_masked_language_modeling}, + "JinaBertModel" => {Bumblebee.Text.JinaBert, :base}, "LayoutLMForMaskedLM" => {Bumblebee.Multimodal.LayoutLm, :for_masked_language_modeling}, "LayoutLMForQuestionAnswering" => {Bumblebee.Multimodal.LayoutLm, :for_question_answering}, "LayoutLMForSequenceClassification" => diff --git a/lib/bumblebee/text/jina_bert.ex b/lib/bumblebee/text/jina_bert.ex new file mode 100644 index 00000000..eb5980d5 --- /dev/null +++ b/lib/bumblebee/text/jina_bert.ex @@ -0,0 +1,658 @@ +defmodule Bumblebee.Text.JinaBert do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 30522, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 512, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + max_token_types: [ + default: 2, + doc: """ + the vocabulary size of the token type embedding (also referred to as segment embedding). + This corresponds to how many different token groups can be distinguished in the input + """ + ], + hidden_size: [ + default: 768, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size: [ + default: 3072, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :gelu, + doc: "the activation function" + ], + dropout_rate: [ + default: 0.1, + doc: "the dropout rate for embedding and encoder" + ], + attention_dropout_rate: [ + default: 0.1, + doc: "the dropout rate for attention weights" + ], + classifier_dropout_rate: [ + default: nil, + doc: + "the dropout rate for the classification head. If not specified, the value of `:dropout_rate` is used instead" + ], + layer_norm_epsilon: [ + default: 1.0e-12, + doc: "the epsilon used by the layer normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ] + ] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label]) + + @moduledoc """ + BERT model family. 
+ + ## Architectures + + * `:base` - plain BERT without any head on top + + * `:for_masked_language_modeling` - BERT with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - BERT with a sequence + classification head. The head returns logits corresponding to + possible classes + + * `:for_token_classification` - BERT with a token classification + head. The head returns logits for each token in the original + sequence + + * `:for_question_answering` - BERT with a span classification head. + The head returns logits for the span start and end positions + + * `:for_multiple_choice` - BERT with a multiple choice prediction + head. Each input in the batch consists of several sequences to + choose from and the model returns logits corresponding to those + choices + + * `:for_next_sentence_prediction` - BERT with a next sentence + prediction head. The head returns logits predicting whether the + second sentence is random or in context + + * `:for_pre_training` - BERT with both MLM and NSP heads as done + during the pre-training + + * `:for_causal_language_modeling` - BERT working as a decoder with + a language modeling head. The head returns logits for each token + in the original sequence + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"token_type_ids"` - `{batch_size, sequence_length}` + + Mask distinguishing groups in the input sequence. This is used + in when the input sequence is a semantically a pair of sequences. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{num_blocks, num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. + + ### Exceptions + + The `:for_multiple_choice` model accepts groups of sequences, so the + expected sequence shape is `{batch_size, num_choices, sequence_length}`. + + The `:for_causal_language_modeling` model is a decoder and accepts + the following additional inputs: `"encoder_hidden_state"`, + `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + + ## References + + * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) + + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_masked_language_modeling, + :for_sequence_classification, + :for_token_classification, + :for_question_answering, + :for_multiple_choice, + :for_next_sentence_prediction, + :for_pre_training, + :for_causal_language_modeling + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(%{architecture: :for_multiple_choice}) do + %{"input_ids" => Nx.template({1, 1, 1}, :u32)} + end + + def input_template(_spec) do + %{"input_ids" => Nx.template({1, 1}, :u32)} + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_masked_language_modeling} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + outputs.pooled_state + |> Axon.dropout( + rate: classifier_dropout_rate(spec), + name: "sequence_classification_head.dropout" + ) + |> Axon.dense(spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_token_classification} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + outputs.hidden_state + |> Axon.dropout( + rate: classifier_dropout_rate(spec), + name: "token_classification_head.dropout" + ) + |> Axon.dense(spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "token_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_question_answering} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, 2, + kernel_initializer: kernel_initializer(spec), + name: "question_answering_head.output" + ) + + {start_logits, end_logits} = Layers.split_pair(logits) + + Layers.output(%{ + start_logits: start_logits, + end_logits: end_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_multiple_choice} = spec) do + inputs = inputs(spec, shape: {nil, nil, nil}) + + group_inputs = ["input_ids", "attention_mask", "token_type_ids", "position_ids"] + + flat_inputs = + Enum.reduce(group_inputs, inputs, fn name, inputs -> + 
Map.update!(inputs, name, &Layers.flatten_leading/1) + end) + + outputs = core(flat_inputs, spec) + + logits = + outputs.pooled_state + |> Axon.dropout(rate: classifier_dropout_rate(spec), name: "multiple_choice_head.dropout") + |> Axon.dense(1, + kernel_initializer: kernel_initializer(spec), + name: "multiple_choice_head.output" + ) + + # The final shape depends on the dynamic batch size and number + # of choices, so we do a reshape based on the input shape + logits = + Axon.layer( + fn logits, input_ids, _opts -> + num_choices = Nx.axis_size(input_ids, 1) + Nx.reshape(logits, {:auto, num_choices}) + end, + [logits, inputs["input_ids"]] + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_next_sentence_prediction} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.pooled_state, 2, + kernel_initializer: kernel_initializer(spec), + name: "next_sentence_prediction_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_pre_training} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + lm_logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + nsp_logits = + Axon.dense(outputs.pooled_state, 2, + kernel_initializer: kernel_initializer(spec), + name: "next_sentence_prediction_head.output" + ) + + Layers.output(%{ + language_modeling_logits: lm_logits, + next_sentence_prediction_logits: nsp_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec, decoder?: true) + outputs = core(inputs, spec, decoder?: true) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cross_attentions: outputs.cross_attentions, + cache: outputs.cache + }) + end + + @impl true + def init_cache(spec, batch_size, max_length, inputs) do + encoder_sequence_length = + if encoder_hidden_state = inputs["encoder_hidden_state"] do + Nx.axis_size(encoder_hidden_state, 1) + end + + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + decoder_num_attention_heads: spec.num_attention_heads, + encoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks, + encoder_sequence_length: encoder_sequence_length + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + defp inputs(spec, opts \\ []) do + shape = Keyword.get(opts, :shape, {nil, nil}) + decoder? 
= Keyword.get(opts, :decoder?, false) + + hidden_shape = Tuple.append(shape, spec.hidden_size) + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("token_type_ids", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape) + ]) + + extra_decoder_inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("encoder_hidden_state", optional: true, shape: hidden_shape), + Axon.input("encoder_attention_mask", optional: true, shape: shape), + Axon.input("cross_attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("cache", optional: true) + ]) + + extra_decoder_inputs = + if decoder? do + extra_decoder_inputs + else + Map.new(extra_decoder_inputs, fn {name, _input} -> {name, Layers.none()} end) + end + + Map.merge(inputs, extra_decoder_inputs) + end + + defp core(inputs, spec, opts \\ []) do + decoder? = Keyword.get(opts, :decoder?, false) + + embeddings = + embedder(inputs["input_ids"], inputs["position_ids"], inputs["token_type_ids"], spec, + name: "embedder" + ) + + encoder_outputs = + encoder( + embeddings, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["encoder_hidden_state"], + inputs["encoder_attention_mask"], + inputs["cross_attention_head_mask"], + inputs["cache"], + spec, + decoder?: decoder?, + name: "encoder" + ) + + pooled_state = pooler(encoder_outputs.hidden_state, spec, name: "pooler") + + %{ + hidden_state: encoder_outputs.hidden_state, + pooled_state: pooled_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions, + cross_attentions: encoder_outputs.cross_attentions, + cache: encoder_outputs.cache + } + end + + defp embedder(input_ids, position_ids, token_type_ids, spec, opts) do + name = opts[:name] + + position_ids = + Layers.default position_ids do + Layers.default_position_ids(input_ids) + end + + token_type_ids = + Layers.default token_type_ids do + Layers.default_token_type_ids(input_ids) + end + + inputs_embeddings = + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + + position_embeddings = + Axon.embedding(position_ids, spec.max_positions, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "position_embedding") + ) + + token_type_embeddings = + Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_type_embedding") + ) + + Axon.add([inputs_embeddings, position_embeddings, token_type_embeddings]) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout")) + end + + defp encoder( + hidden_state, + attention_mask, + attention_head_mask, + encoder_hidden_state, + encoder_attention_mask, + cross_attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + decoder? = opts[:decoder?] + + cross_attention? = decoder? 
and spec.use_cross_attention + + Layers.Transformer.blocks( + hidden_state, + [ + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + causal: decoder?, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + attention_dropout_rate: spec.attention_dropout_rate, + layer_norm: [ + epsilon: spec.layer_norm_epsilon + ], + ffn: [ + intermediate_size: spec.intermediate_size, + activation: spec.activation + ], + name: join(name, "blocks") + ] ++ + if(cross_attention?, + do: [ + cross_hidden_state: encoder_hidden_state, + cross_attention_mask: encoder_attention_mask, + cross_attention_head_mask: cross_attention_head_mask + ], + else: [] + ) + ) + end + + defp pooler(hidden_state, spec, opts) do + name = opts[:name] + + hidden_state + |> Layers.take_token(index: 0, axis: 1) + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + |> Axon.tanh() + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: use a shared parameter with embeddings.word_embeddings.kernel + # if spec.tie_word_embeddings is true (relevant for training) + + hidden_state + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "dense") + ) + |> Layers.activation(spec.activation, name: join(name, "activation")) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + # We reuse the kernel of input embeddings and add bias for each token + |> Layers.dense_transposed(spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + |> Axon.bias(name: join(name, "bias")) + end + + defp classifier_dropout_rate(spec) do + spec.classifier_dropout_rate || spec.dropout_rate + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + max_token_types: {"type_vocab_size", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + dropout_rate: {"hidden_dropout_prob", number()}, + attention_dropout_rate: {"attention_probs_dropout_prob", number()}, + classifier_dropout_rate: {"classifier_dropout", optional(number())}, + layer_norm_epsilon: {"layer_norm_eps", number()}, + initializer_scale: {"initializer_range", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "bert.embeddings.word_embeddings", + "embedder.position_embedding" => "bert.embeddings.position_embeddings", + "embedder.token_type_embedding" => "bert.embeddings.token_type_embeddings", + "embedder.norm" => "bert.embeddings.LayerNorm", + "encoder.blocks.{n}.self_attention.query" => + "bert.encoder.layer.{n}.attention.self.query", + "encoder.blocks.{n}.self_attention.key" => "bert.encoder.layer.{n}.attention.self.key", + 
"encoder.blocks.{n}.self_attention.value" => + "bert.encoder.layer.{n}.attention.self.value", + "encoder.blocks.{n}.self_attention.output" => + "bert.encoder.layer.{n}.attention.output.dense", + "encoder.blocks.{n}.self_attention_norm" => + "bert.encoder.layer.{n}.attention.output.LayerNorm", + "encoder.blocks.{n}.cross_attention.query" => + "bert.encoder.layer.{n}.crossattention.self.query", + "encoder.blocks.{n}.cross_attention.key" => + "bert.encoder.layer.{n}.crossattention.self.key", + "encoder.blocks.{n}.cross_attention.value" => + "bert.encoder.layer.{n}.crossattention.self.value", + "encoder.blocks.{n}.cross_attention.output" => + "bert.encoder.layer.{n}.crossattention.output.dense", + "encoder.blocks.{n}.cross_attention_norm" => + "bert.encoder.layer.{n}.crossattention.output.LayerNorm", + "encoder.blocks.{n}.ffn.intermediate" => "bert.encoder.layer.{n}.intermediate.dense", + "encoder.blocks.{n}.ffn.output" => "bert.encoder.layer.{n}.output.dense", + "encoder.blocks.{n}.output_norm" => "bert.encoder.layer.{n}.output.LayerNorm", + "pooler.output" => "bert.pooler.dense", + "language_modeling_head.dense" => "cls.predictions.transform.dense", + "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm", + "language_modeling_head.output" => "cls.predictions.decoder", + "language_modeling_head.bias" => "cls.predictions", + "next_sentence_prediction_head.output" => "cls.seq_relationship", + "sequence_classification_head.output" => "classifier", + "token_classification_head.output" => "classifier", + "multiple_choice_head.output" => "classifier", + "question_answering_head.output" => "qa_outputs" + } + end + end +end From b1b6ec3c9149b69639045707bb3b5d411ce9e988 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Sat, 26 Oct 2024 19:50:18 +0200 Subject: [PATCH 2/5] Implement JinaBert --- lib/bumblebee/text/jina_bert.ex | 151 +++++++++++++++++++++++++++----- 1 file changed, 129 insertions(+), 22 deletions(-) diff --git a/lib/bumblebee/text/jina_bert.ex b/lib/bumblebee/text/jina_bert.ex index eb5980d5..b4c48acf 100644 --- a/lib/bumblebee/text/jina_bert.ex +++ b/lib/bumblebee/text/jina_bert.ex @@ -71,7 +71,7 @@ defmodule Bumblebee.Text.JinaBert do ] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label]) @moduledoc """ - BERT model family. + Jina adaption of BERT model family. ## Architectures @@ -432,9 +432,7 @@ defmodule Bumblebee.Text.JinaBert do decoder? 
= Keyword.get(opts, :decoder?, false) embeddings = - embedder(inputs["input_ids"], inputs["position_ids"], inputs["token_type_ids"], spec, - name: "embedder" - ) + embedder(inputs["input_ids"], inputs["token_type_ids"], spec, name: "embedder") encoder_outputs = encoder( @@ -462,14 +460,9 @@ defmodule Bumblebee.Text.JinaBert do } end - defp embedder(input_ids, position_ids, token_type_ids, spec, opts) do + defp embedder(input_ids, token_type_ids, spec, opts) do name = opts[:name] - position_ids = - Layers.default position_ids do - Layers.default_position_ids(input_ids) - end - token_type_ids = Layers.default token_type_ids do Layers.default_token_type_ids(input_ids) @@ -481,23 +474,55 @@ defmodule Bumblebee.Text.JinaBert do name: join(name, "token_embedding") ) - position_embeddings = - Axon.embedding(position_ids, spec.max_positions, spec.hidden_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "position_embedding") - ) - token_type_embeddings = Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size, kernel_initializer: kernel_initializer(spec), name: join(name, "token_type_embedding") ) - Axon.add([inputs_embeddings, position_embeddings, token_type_embeddings]) + Axon.add([inputs_embeddings, token_type_embeddings]) |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout")) end + defp get_slopes_power_of_2(n) do + start = 2 ** -(2 ** -(:math.log2(n) - 3)) + ratio = start + for i <- 0..(n - 1), do: start * ratio ** i + end + + defp integer?(number) do + round(number) == number + end + + defp get_alibi_head_slopes(n_heads) do + if integer?(:math.log2(n_heads)) do + get_slopes_power_of_2(n_heads) + else + closest_power_of_2 = 2 ** round(:math.floor(:math.log2(n_heads))) + + get_slopes_power_of_2(closest_power_of_2) ++ + (get_alibi_head_slopes(2 * closest_power_of_2) + |> Enum.take_every(2) + |> Enum.take(n_heads - closest_power_of_2)) + end + end + + defp alibi_matrix(num_attention_heads, size) do + context_position = Nx.iota({1, size, 1}, axis: 1) + memory_position = Nx.iota({1, size}, axis: 1) + relative_position = Nx.abs(Nx.subtract(context_position, memory_position)) + + relative_position = Nx.tile(relative_position, [num_attention_heads, 1, 1]) + slopes = Nx.tensor(get_alibi_head_slopes(num_attention_heads)) |> Nx.multiply(-1) + + slopes + |> Nx.new_axis(-1) + |> Nx.new_axis(-1) + |> Nx.multiply(relative_position) + |> Nx.new_axis(0) + end + defp encoder( hidden_state, attention_mask, @@ -514,11 +539,22 @@ defmodule Bumblebee.Text.JinaBert do cross_attention? = decoder? 
and spec.use_cross_attention + # we build the alibi matrix only once instead of rebuilding + # for this we must use the maximum seqlen + alibi_relative_bias_matrix = + Axon.nx(hidden_state, fn hidden_state -> + {_, seqlen, _} = Nx.shape(hidden_state) + matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions) + + matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]] + end) + Layers.Transformer.blocks( hidden_state, [ attention_mask: attention_mask, attention_head_mask: attention_head_mask, + attention_relative_bias: alibi_relative_bias_matrix, cache: cache, causal: decoder?, num_blocks: spec.num_blocks, @@ -530,10 +566,8 @@ defmodule Bumblebee.Text.JinaBert do layer_norm: [ epsilon: spec.layer_norm_epsilon ], - ffn: [ - intermediate_size: spec.intermediate_size, - activation: spec.activation - ], + ffn: &glumlp(&1, spec, name: &2), + block_type: &jina_block_impl/3, name: join(name, "blocks") ] ++ if(cross_attention?, @@ -547,6 +581,45 @@ defmodule Bumblebee.Text.JinaBert do ) end + def glumlp( + hidden_states, + spec, + opts + ) do + name = opts[:name] + intermediate_size = spec.intermediate_size + activation = spec.activation + hidden_dropout_prob = spec.dropout_rate + hidden_size = spec.hidden_size + layer_norm_eps = spec.layer_norm_epsilon + + residual_connection = hidden_states + + hidden_states = + hidden_states + |> Axon.dense(intermediate_size * 2, use_bias: false, name: join(name, "gated_layers")) + + gated = + Axon.nx(hidden_states, fn hidden_states -> + hidden_states[[.., .., 0..(intermediate_size - 1)]] + end) + |> Axon.activation(activation) + + non_gated = + Axon.nx(hidden_states, fn hidden_states -> + hidden_states[[.., .., intermediate_size..-1//1]] + end) + + hidden_states = + Axon.multiply(gated, non_gated) + |> Axon.dropout(rate: hidden_dropout_prob) + |> Axon.dense(hidden_size, name: join(name, "wo")) + + hidden_states + |> Axon.add(residual_connection) + |> Axon.layer_norm(epsilon: layer_norm_eps, name: join(name, "layernorm")) + end + defp pooler(hidden_state, spec, opts) do name = opts[:name] @@ -588,6 +661,37 @@ defmodule Bumblebee.Text.JinaBert do Axon.Initializers.normal(scale: spec.initializer_scale) end + defp jina_block_impl(hidden_state, steps, _name) do + shortcut = hidden_state + + {hidden_state, attention_info} = steps.self_attention.(hidden_state) + + hidden_state = + hidden_state + |> Axon.add(shortcut) + |> steps.self_attention_norm.() + + {hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn hidden_state -> + shortcut = hidden_state + + {hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state) + + hidden_state = + hidden_state + |> Axon.add(shortcut) + |> steps.cross_attention_norm.() + + {hidden_state, cross_attention_info} + end) + + hidden_state = + hidden_state + |> steps.ffn.() + + {hidden_state, attention_info, cross_attention_info} + end + defimpl Bumblebee.HuggingFace.Transformers.Config do def load(spec, data) do import Shared.Converters @@ -651,7 +755,10 @@ defmodule Bumblebee.Text.JinaBert do "sequence_classification_head.output" => "classifier", "token_classification_head.output" => "classifier", "multiple_choice_head.output" => "classifier", - "question_answering_head.output" => "qa_outputs" + "question_answering_head.output" => "qa_outputs", + "encoder.blocks.{n}.ffn.wo" => "encoder.layer.{n}.mlp.wo", + "encoder.blocks.{n}.ffn.layernorm" => "encoder.layer.{n}.mlp.layernorm", + "encoder.blocks.{n}.ffn.gated_layers" => "encoder.layer.{n}.mlp.gated_layers" } end end From 
ae2baf7930881df64d27739f9bfc2e66340e8b5d Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 8 Nov 2024 11:23:39 +0100 Subject: [PATCH 3/5] Clean up --- lib/bumblebee/text/jina_bert.ex | 210 +------------------------------- 1 file changed, 5 insertions(+), 205 deletions(-) diff --git a/lib/bumblebee/text/jina_bert.ex b/lib/bumblebee/text/jina_bert.ex index b4c48acf..4675e217 100644 --- a/lib/bumblebee/text/jina_bert.ex +++ b/lib/bumblebee/text/jina_bert.ex @@ -75,39 +75,12 @@ defmodule Bumblebee.Text.JinaBert do ## Architectures - * `:base` - plain BERT without any head on top + * `:base` - plain Jina BERT without any head on top - * `:for_masked_language_modeling` - BERT with a language modeling + * `:for_masked_language_modeling` - Jina BERT with a language modeling head. The head returns logits for each token in the original sequence - * `:for_sequence_classification` - BERT with a sequence - classification head. The head returns logits corresponding to - possible classes - - * `:for_token_classification` - BERT with a token classification - head. The head returns logits for each token in the original - sequence - - * `:for_question_answering` - BERT with a span classification head. - The head returns logits for the span start and end positions - - * `:for_multiple_choice` - BERT with a multiple choice prediction - head. Each input in the batch consists of several sequences to - choose from and the model returns logits corresponding to those - choices - - * `:for_next_sentence_prediction` - BERT with a next sentence - prediction head. The head returns logits predicting whether the - second sentence is random or in context - - * `:for_pre_training` - BERT with both MLM and NSP heads as done - during the pre-training - - * `:for_causal_language_modeling` - BERT working as a decoder with - a language modeling head. The head returns logits for each token - in the original sequence - ## Inputs * `"input_ids"` - `{batch_size, sequence_length}` @@ -135,15 +108,6 @@ defmodule Bumblebee.Text.JinaBert do Mask to nullify selected heads of the self-attention blocks in the encoder. - ### Exceptions - - The `:for_multiple_choice` model accepts groups of sequences, so the - expected sequence shape is `{batch_size, num_choices, sequence_length}`. - - The `:for_causal_language_modeling` model is a decoder and accepts - the following additional inputs: `"encoder_hidden_state"`, - `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`. 
- ## Global layer options #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} @@ -155,6 +119,7 @@ defmodule Bumblebee.Text.JinaBert do ## References * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) + * [Jina Embeddings 2: 8192-Token General-Purpose Text Embeddings for Long Documents](https://arxiv.org/abs/2310.19923) """ @@ -172,14 +137,7 @@ defmodule Bumblebee.Text.JinaBert do def architectures(), do: [ :base, - :for_masked_language_modeling, - :for_sequence_classification, - :for_token_classification, - :for_question_answering, - :for_multiple_choice, - :for_next_sentence_prediction, - :for_pre_training, - :for_causal_language_modeling + :for_masked_language_modeling ] @impl true @@ -220,159 +178,6 @@ defmodule Bumblebee.Text.JinaBert do }) end - def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do - inputs = inputs(spec) - outputs = core(inputs, spec) - - logits = - outputs.pooled_state - |> Axon.dropout( - rate: classifier_dropout_rate(spec), - name: "sequence_classification_head.dropout" - ) - |> Axon.dense(spec.num_labels, - kernel_initializer: kernel_initializer(spec), - name: "sequence_classification_head.output" - ) - - Layers.output(%{ - logits: logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_token_classification} = spec) do - inputs = inputs(spec) - outputs = core(inputs, spec) - - logits = - outputs.hidden_state - |> Axon.dropout( - rate: classifier_dropout_rate(spec), - name: "token_classification_head.dropout" - ) - |> Axon.dense(spec.num_labels, - kernel_initializer: kernel_initializer(spec), - name: "token_classification_head.output" - ) - - Layers.output(%{ - logits: logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_question_answering} = spec) do - inputs = inputs(spec) - outputs = core(inputs, spec) - - logits = - Axon.dense(outputs.hidden_state, 2, - kernel_initializer: kernel_initializer(spec), - name: "question_answering_head.output" - ) - - {start_logits, end_logits} = Layers.split_pair(logits) - - Layers.output(%{ - start_logits: start_logits, - end_logits: end_logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_multiple_choice} = spec) do - inputs = inputs(spec, shape: {nil, nil, nil}) - - group_inputs = ["input_ids", "attention_mask", "token_type_ids", "position_ids"] - - flat_inputs = - Enum.reduce(group_inputs, inputs, fn name, inputs -> - Map.update!(inputs, name, &Layers.flatten_leading/1) - end) - - outputs = core(flat_inputs, spec) - - logits = - outputs.pooled_state - |> Axon.dropout(rate: classifier_dropout_rate(spec), name: "multiple_choice_head.dropout") - |> Axon.dense(1, - kernel_initializer: kernel_initializer(spec), - name: "multiple_choice_head.output" - ) - - # The final shape depends on the dynamic batch size and number - # of choices, so we do a reshape based on the input shape - logits = - Axon.layer( - fn logits, input_ids, _opts -> - num_choices = Nx.axis_size(input_ids, 1) - Nx.reshape(logits, {:auto, num_choices}) - end, - [logits, inputs["input_ids"]] - ) - - Layers.output(%{ - logits: logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_next_sentence_prediction} = spec) do - inputs 
= inputs(spec) - outputs = core(inputs, spec) - - logits = - Axon.dense(outputs.pooled_state, 2, - kernel_initializer: kernel_initializer(spec), - name: "next_sentence_prediction_head.output" - ) - - Layers.output(%{ - logits: logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_pre_training} = spec) do - inputs = inputs(spec) - outputs = core(inputs, spec) - - lm_logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") - - nsp_logits = - Axon.dense(outputs.pooled_state, 2, - kernel_initializer: kernel_initializer(spec), - name: "next_sentence_prediction_head.output" - ) - - Layers.output(%{ - language_modeling_logits: lm_logits, - next_sentence_prediction_logits: nsp_logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions - }) - end - - def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do - inputs = inputs(spec, decoder?: true) - outputs = core(inputs, spec, decoder?: true) - logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") - - Layers.output(%{ - logits: logits, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions, - cross_attentions: outputs.cross_attentions, - cache: outputs.cache - }) - end - @impl true def init_cache(spec, batch_size, max_length, inputs) do encoder_sequence_length = @@ -539,11 +344,10 @@ defmodule Bumblebee.Text.JinaBert do cross_attention? = decoder? and spec.use_cross_attention - # we build the alibi matrix only once instead of rebuilding - # for this we must use the maximum seqlen alibi_relative_bias_matrix = Axon.nx(hidden_state, fn hidden_state -> {_, seqlen, _} = Nx.shape(hidden_state) + matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions) matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]] @@ -653,10 +457,6 @@ defmodule Bumblebee.Text.JinaBert do |> Axon.bias(name: join(name, "bias")) end - defp classifier_dropout_rate(spec) do - spec.classifier_dropout_rate || spec.dropout_rate - end - defp kernel_initializer(spec) do Axon.Initializers.normal(scale: spec.initializer_scale) end From ae6024cdbecfaed73f3edd8fc2fc34c7e2d46b95 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 8 Nov 2024 11:24:23 +0100 Subject: [PATCH 4/5] Add jina_bert_test.exs --- test/bumblebee/text/jina_bert_test.exs | 84 ++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 test/bumblebee/text/jina_bert_test.exs diff --git a/test/bumblebee/text/jina_bert_test.exs b/test/bumblebee/text/jina_bert_test.exs new file mode 100644 index 00000000..f50fba2e --- /dev/null +++ b/test/bumblebee/text/jina_bert_test.exs @@ -0,0 +1,84 @@ +defmodule Bumblebee.Text.JinaBertTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + @tag slow: true + test "jina-embeddings-v2-small-en" do + repo = {:hf, "jinaai/jina-embeddings-v2-small-en"} + + {:ok, %{model: model, params: params, spec: _spec}} = + Bumblebee.load_model(repo, + params_filename: "model.safetensors", + spec_overrides: [architecture: :base] + ) + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [-0.1346, 0.1457, 0.5572], + [-0.1383, 0.1412, 0.5643], + [-0.1125, 0.1354, 0.5599] + ]) + ) + end + + @tag 
:skip + test ":base" do + repo = {:hf, "doesnotexist/tiny-random-JinaBert"} + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model(repo) + + assert %Bumblebee.Text.JinaBert{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]] + ]) + ) + end + + @tag :skip + test ":for_masked_language_modeling" do + repo = {:hf, "doesnotexist/tiny-random-JinaBert"} + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model(repo) + + assert %Bumblebee.Text.Bert{architecture: :for_masked_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1124} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([[[-0.0127, 0.0508, 0.0904], [0.1151, 0.1189, 0.0922], [0.0089, 0.1132, -0.2470]]]) + ) + end +end From cbe59c65479a8261247de1c653cb3070bec9427f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 8 Nov 2024 12:17:22 +0100 Subject: [PATCH 5/5] Some more clean up --- lib/bumblebee/text/jina_bert.ex | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/text/jina_bert.ex b/lib/bumblebee/text/jina_bert.ex index 4675e217..180da891 100644 --- a/lib/bumblebee/text/jina_bert.ex +++ b/lib/bumblebee/text/jina_bert.ex @@ -391,37 +391,34 @@ defmodule Bumblebee.Text.JinaBert do opts ) do name = opts[:name] - intermediate_size = spec.intermediate_size - activation = spec.activation - hidden_dropout_prob = spec.dropout_rate - hidden_size = spec.hidden_size - layer_norm_eps = spec.layer_norm_epsilon residual_connection = hidden_states hidden_states = - hidden_states - |> Axon.dense(intermediate_size * 2, use_bias: false, name: join(name, "gated_layers")) + Axon.dense(hidden_states, spec.intermediate_size * 2, + use_bias: false, + name: join(name, "gated_layers") + ) gated = Axon.nx(hidden_states, fn hidden_states -> - hidden_states[[.., .., 0..(intermediate_size - 1)]] + hidden_states[[.., .., 0..(spec.intermediate_size - 1)]] end) - |> Axon.activation(activation) + |> Axon.activation(spec.activation) non_gated = Axon.nx(hidden_states, fn hidden_states -> - hidden_states[[.., .., intermediate_size..-1//1]] + hidden_states[[.., .., spec.intermediate_size..-1//1]] end) hidden_states = Axon.multiply(gated, non_gated) - |> Axon.dropout(rate: hidden_dropout_prob) - |> Axon.dense(hidden_size, name: join(name, "wo")) + |> Axon.dropout(rate: spec.dropout_rate) + |> Axon.dense(spec.hidden_size, name: join(name, "wo")) hidden_states |> Axon.add(residual_connection) - |> Axon.layer_norm(epsilon: layer_norm_eps, name: join(name, "layernorm")) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm")) end defp pooler(hidden_state, spec, opts) do @@ -485,9 +482,7 @@ defmodule Bumblebee.Text.JinaBert do {hidden_state, cross_attention_info} end) - hidden_state = - hidden_state - |> steps.ffn.() + hidden_state = steps.ffn.(hidden_state) {hidden_state, attention_info, 
cross_attention_info} end
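
A minimal usage sketch for the converted checkpoint, built on the same calls as the slow test above. It assumes the jinaai/jina-embeddings-v2-small-en repository plus Bumblebee's standard tokenizer helpers (Bumblebee.load_tokenizer/1 and Bumblebee.apply_tokenizer/2), and it mean-pools the hidden state into one sentence embedding per input, which is a common convention for this model rather than anything mandated by the patches:

    repo = {:hf, "jinaai/jina-embeddings-v2-small-en"}

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model(repo,
        params_filename: "model.safetensors",
        spec_overrides: [architecture: :base]
      )

    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

    # Tokenize a batch of texts; returns "input_ids" and "attention_mask"
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["How is the weather today?"])

    # Run the :base model; outputs.hidden_state is {batch, sequence_length, hidden_size}
    outputs = Axon.predict(model, params, inputs)

    # Mean-pool over the sequence axis to get one embedding per text
    # (for padded batches, weight by the attention mask instead)
    embedding = Nx.mean(outputs.hidden_state, axes: [1])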