
Add JinaBert model #407

Open · wants to merge 5 commits into main

Conversation

joelpaulkoch
Contributor

I want to share my work on the JinaBert model.
I'm not sure if you want to include it at all, since it's not officially part of transformers (you must specify trust_remote_code=True when running it with transformers), and there is still an open issue.

This PR would enable bumblebee users to run the jina embeddings v2 models.

The implementation of JinaBert is here.

Another issue with this being a custom implementation is that there is another variant that I started to work on: jinaai/jina-embeddings-v2-base-code.

Both jinaai/jina-embeddings-v2-base-en and jinaai/jina-embeddings-v2-base-code specify JinaBertForMaskedLM as the architecture but point to different implementations:

 "_name_or_path": "jinaai/jina-bert-implementation",
  "architectures": [
    "JinaBertForMaskedLM"
  ],
  "auto_map": {
    "AutoConfig": "jinaai/jina-bert-implementation--configuration_bert.JinaBertConfig",
    "AutoModelForMaskedLM": "jinaai/jina-bert-implementation--modeling_bert.JinaBertForMaskedLM",
    "AutoModel": "jinaai/jina-bert-implementation--modeling_bert.JinaBertModel",
    "AutoModelForSequenceClassification": "jinaai/jina-bert-implementation--modeling_bert.JinaBertForSequenceClassification"
  },

vs.

  "_name_or_path": "jinaai/jina-bert-v2-qk-post-norm",
  "architectures": [
    "JinaBertForMaskedLM"
  ],
  "auto_map": {
    "AutoConfig": "jinaai/jina-bert-v2-qk-post-norm--configuration_bert.JinaBertConfig",
    "AutoModel": "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel",
    "AutoModelForMaskedLM": "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertForMaskedLM",
    "AutoModelForSequenceClassification": "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertForSequenceClassification"
  },

Is there a mechanism in bumblebee to distinguish these?

There are still some issues in this PR; I will add comments and can work on them over the coming days/weeks.

@@ -150,6 +150,8 @@ defmodule Bumblebee do
"GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling},
"GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification},
"GPTNeoXForTokenClassification" => {Bumblebee.Text.GptNeoX, :for_token_classification},
"JinaBertForMaskedLM" => {Bumblebee.Text.JinaBert, :for_masked_language_modeling},
Contributor Author

The config says it's JinaBertForMaskedLM. However, with this mapping there are missing and unused parameters:

11:51:58.408 [debug] the following parameters were missing:

  * language_modeling_head.dense.kernel
  * language_modeling_head.dense.bias
  * language_modeling_head.output.kernel
  * language_modeling_head.bias.bias
  * language_modeling_head.norm.gamma
  * language_modeling_head.norm.beta


11:51:58.408 [debug] the following PyTorch parameters were unused:

  * pooler.dense.bias
  * pooler.dense.weight

Looks to me like this is not in line with the existing :for_masked_language_modeling implementation of BERT.
So, should we map to the :base architecture here instead?
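
For illustration, the alternative mapping entry would look something like this (sketch only):

"JinaBertForMaskedLM" => {Bumblebee.Text.JinaBert, :base},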

Member

There is a JinaBertForMaskedLM implementation and it has the expected layers. I think the issue is that the model on the Hub is actually JinaBertModel and the config is wrong.

So the correct way to work around this would be to specify the architecture when loading:

Bumblebee.load_model({:hf, "..."}, architecture: :base)

It may be worth opening a PR on the HF repo, changing it to JinaBertModel. Unfortunately, the same is the case for the other checkpoints of this model (small, etc.).

Comment on lines +293 to +314
defp get_slopes_power_of_2(n) do
start = 2 ** -(2 ** -(:math.log2(n) - 3))
ratio = start
for i <- 0..(n - 1), do: start * ratio ** i
end

defp integer?(number) do
round(number) == number
end

defp get_alibi_head_slopes(n_heads) do
if integer?(:math.log2(n_heads)) do
get_slopes_power_of_2(n_heads)
else
closest_power_of_2 = 2 ** round(:math.floor(:math.log2(n_heads)))

get_slopes_power_of_2(closest_power_of_2) ++
(get_alibi_head_slopes(2 * closest_power_of_2)
|> Enum.take_every(2)
|> Enum.take(n_heads - closest_power_of_2))
end
end
Contributor Author

I think I could rewrite all of this using Nx functions. I'm assuming that would theoretically speed things up. Not sure if it's worth the effort.

Member

In terms of performance this is actually better than writing a defn, assuming n_heads is small (which is the case). The reason is that all of this code runs when building the defn expression (defn-compile time), and not as part of the inference. In other words, when compiling the model, we are building this small tensor and it gets embedded as a constant into the computation.

To make the distinction clearer, I would make alibi_matrix a defnp and alibi_head_slopes a deftransformp that returns a tensor.
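
A rough sketch of what I mean (the option names, shapes, and sign convention here are my assumptions, not taken from the reference implementation):

# assuming Nx.Defn is imported in the model module, as in other Bumblebee models
defnp alibi_matrix(opts \\ []) do
  opts = keyword!(opts, [:num_heads, :max_positions])

  slopes = alibi_head_slopes(opts[:num_heads])

  positions = Nx.iota({opts[:max_positions]})

  # |i - j| distances between query and key positions, shape {seq, seq}
  distances = Nx.abs(Nx.new_axis(positions, 0) - Nx.new_axis(positions, 1))

  # per-head negative linear bias, shape {1, num_heads, seq, seq}
  slopes
  |> Nx.reshape({1, opts[:num_heads], 1, 1})
  |> Nx.multiply(Nx.new_axis(Nx.new_axis(distances, 0), 0))
  |> Nx.negate()
end

deftransformp alibi_head_slopes(num_heads) do
  # plain Elixir, runs while the defn expression is being built
  Nx.tensor(get_alibi_head_slopes(num_heads))
end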

Comment on lines +347 to +354
alibi_relative_bias_matrix =
Axon.nx(hidden_state, fn hidden_state ->
{_, seqlen, _} = Nx.shape(hidden_state)

matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions)

matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]]
end)
Contributor Author

This way, we're recalculating the matrix on each run, right?

In the original implementation, they are storing the matrix and only recalculating when it's too small for the current seqlen, otherwise they cut it down to match the dimensions.
I guess we could do the same, maybe using Axon.StatefulOutput?

I'm also wondering, shouldn't we know seqlen anyways when we're compiling the model?

Member

I'm also wondering, shouldn't we know seqlen anyways when we're compiling the model?

Yes! PyTorch models generally accept dynamically sized inputs and adjust the computation accordingly. In our case, we always know the lengths at model compile time.

Consequently, instead of computing the whole matrix and slicing it, ideally we would compute it only for the known length, to avoid unnecessary work and memory consumption.

As for reusing the matrix, in cases like this we always create the tensor at runtime. We prefer the model to be stateless and also let the XLA compiler optimise across operations. There are certain cases where we need some level of statefulness, such as autoregressive text generation, and we do it with an explicit "cache" output/input (though a whole generation request is still stateless).
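
Concretely, the Axon.nx block could build the bias just for the known length, something along these lines (building on the defnp sketch from the earlier thread; the signature is my own, so still just a sketch):

alibi_relative_bias_matrix =
  Axon.nx(hidden_state, fn hidden_state ->
    {_, seq_len, _} = Nx.shape(hidden_state)

    # compute the bias directly for seq_len instead of materialising the
    # max_positions x max_positions matrix and slicing it
    alibi_matrix(num_heads: spec.num_attention_heads, max_positions: seq_len)
  end)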

Comment on lines +559 to +561
"encoder.blocks.{n}.ffn.wo" => "encoder.layer.{n}.mlp.wo",
"encoder.blocks.{n}.ffn.layernorm" => "encoder.layer.{n}.mlp.layernorm",
"encoder.blocks.{n}.ffn.gated_layers" => "encoder.layer.{n}.mlp.gated_layers"
Contributor Author

I carried over the param names from the original implementation, so we might want to change them.


@tag :skip
test ":base" do
repo = {:hf, "doesnotexist/tiny-random-JinaBert"}
Contributor Author

I've tried to make a tiny-random-JinaBert with no success; I might try again.

@jonatanklosko
Member

Hey @joelpaulkoch, thanks for the PR and a great article!

To be honest, I am hesitant to support implementations from the Hub because (a) they are theoretically less stable, since they may still be subject to tweaks; and (b) model proliferation is more likely; the jina-embeddings-v2-base-en vs jina-embeddings-v2-base-code split is a good example.

We generally wait until models make it to hf/transformers, though from huggingface/transformers#27035 it's not clear if that's ever going to happen.

At the moment, I would defer the decision and see how the status quo evolves. People can still use the model by installing bumblebee as {:bumblebee, github: "joelpaulkoch/jina-embeddings-v2"}.
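
For reference, that would look roughly like this (a sketch, e.g. in a Livebook or script):

Mix.install([
  {:bumblebee, github: "joelpaulkoch/jina-embeddings-v2"}
])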
