
Add M2M100 & NLLB model #392

Merged: 6 commits into elixir-nx:main on Aug 19, 2024
Conversation

aymanosman (Contributor)

The values used in the tests were obtained by running this Python script:

from torch import tensor
from transformers import (
    AutoModelForSeq2SeqLM,
    M2M100ForConditionalGeneration,
    M2M100Model,
)

inputs = {
    "input_ids": tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
    "decoder_input_ids": tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
    "attention_mask": tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# M2M100: base model and conditional-generation head
base = M2M100Model.from_pretrained("hf-internal-testing/tiny-random-M2M100Model")
generation = M2M100ForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-M2M100ForConditionalGeneration"
)

print(base(**inputs).last_hidden_state[:, 1:4, 1:4])
print(generation(**inputs).logits[:, 1:4, 1:4])

# NLLB: same inputs, loaded via the auto class
generation = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-nllb")
print(generation(**inputs).logits[:, 1:4, 1:4])
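
For comparison, here is a minimal Elixir sketch of reproducing the base-model slice on the Bumblebee side, assuming the tiny checkpoint from the script above loads with the M2M100 support added in this PR. The actual test helpers in the PR may differ; this just mirrors the Python slice:

# Sketch: reproduce the M2M100 base-model slice from the Python script above.
{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-M2M100Model"})

inputs = %{
  "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
  "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
  "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

# Python's [:, 1:4, 1:4] corresponds to ranges 1..3 along the last two axes
outputs.hidden_state[[.., 1..3, 1..3]]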

Here is a notebook demonstrating that it works, but there are issues regarding adding the language codes.

aymanosman (Contributor, Author)

I also saw these warnings that I do not understand:

21:24:16.049 [debug] the following parameters were missing:

  * language_modeling_head.logits_bias.bias

21:24:16.049 [debug] the following PyTorch parameters were unused:

  * lm_head.weight

jonatanklosko (Member)

@aymanosman thanks for the PR, looks clean, and great job adding the tests :)

Review comments on lib/bumblebee.ex and lib/bumblebee/text/pre_trained_tokenizer.ex were marked outdated and resolved.
jonatanklosko (Member) left a comment

@aymanosman thanks! I will have a look at improving the language configuration soon :)

jonatanklosko merged commit 17e4397 into elixir-nx:main on Aug 19, 2024 (2 checks passed).
aymanosman deleted the m2m100-and-nllb branch on Aug 19, 2024 at 18:20.
jonatanklosko (Member) commented Aug 20, 2024

@aymanosman here's how you can specify the source and target languages as expected by the model:

# Assuming model_info, tokenizer, and generation_config were loaded from an
# NLLB checkpoint, e.g. (the repo name here is illustrative):
#
#   {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/nllb-200-distilled-600M"})
#   {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/nllb-200-distilled-600M"})
#   {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/nllb-200-distilled-600M"})

source_lang = "eng_Latn"
target_lang = "fra_Latn"

# Force the decoder to start with the target-language token
target_token = Bumblebee.Tokenizer.token_to_id(tokenizer, target_lang)
generation_config = Bumblebee.configure(generation_config, forced_bos_token_id: target_token)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# The source language is passed as a prefix token on the input text
text = "The bank of the river is beautiful in spring"
Nx.Serving.run(serving, "#{source_lang} #{text}")

(From brief testing it seems that specifying the source language improves the translation quality)

I am still considering how to best fit this into Bumblebee APIs, but this should unblock your testing :)

aymanosman (Contributor, Author)

@jonatanklosko Thanks. Looking forward to the improved API. I was thinking of modifying the generation serving to take structured input containing not just the input text but also the source and target languages.

So something like this would be possible:

Nx.Serving.run(serving, {"eng_Latn", "fra_Latn", "The bank of the river is beautiful in spring"})

But I'll wait for you to come up with a better idea :)

jonatanklosko (Member)

Yeah, we already accept map inputs in this and other servings. In this case we need a separate translation serving, because languages require special handling. The tricky part is to allow languages to be dynamic (part of serving input, rather than known at model compilation time), but I have an idea how we can do that reasonably.
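
For illustration, a rough sketch of what such a serving call could look like. The Bumblebee.Text.translation name and the input keys below are assumptions for the sake of the example, not a confirmed API:

# Hypothetical translation serving with per-request languages.
# The function name and input keys are assumptions, not the shipped API.
serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config)

Nx.Serving.run(serving, %{
  text: "The bank of the river is beautiful in spring",
  source_language_token: "eng_Latn",
  target_language_token: "fra_Latn"
})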

jonatanklosko (Member)

@aymanosman implemented in #395 :)
