diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex index e9d9817e..d9518678 100644 --- a/lib/bumblebee/audio/speech_to_text_whisper.ex +++ b/lib/bumblebee/audio/speech_to_text_whisper.ex @@ -117,7 +117,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do {stream, {}} end) - |> maybe_stream(opts[:stream], spec, featurizer, tokenizer, options) + |> add_postprocessing(opts[:stream], spec, featurizer, tokenizer, options) end defp validate_input(%{audio: audio} = input, sampling_rate, chunk_num_seconds) do @@ -351,7 +351,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do end end - defp maybe_stream(serving, false, spec, featurizer, tokenizer, options) do + defp add_postprocessing(serving, false, spec, featurizer, tokenizer, options) do Nx.Serving.client_postprocessing(serving, fn {outputs, _metadata}, {} -> outputs = Nx.to_list(outputs) state = decode_chunk_outputs_init(spec, featurizer, tokenizer) @@ -362,7 +362,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do end) end - defp maybe_stream(serving, true, spec, featurizer, tokenizer, options) do + defp add_postprocessing(serving, true, spec, featurizer, tokenizer, options) do serving |> Nx.Serving.streaming() |> Nx.Serving.client_postprocessing(fn stream, {} -> diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index aa4ca19b..b11f47a7 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -196,7 +196,8 @@ defmodule Bumblebee.Text do #=> %{ #=> results: [ #=> %{ - #=> text: "Elixir is a functional programming language that is designed to be used in a variety of applications. It" + #=> text: " programming language that is designed to be used in a variety of applications. It", + #=> token_summary: %{input: 5, output: 15, padding: 0} #=> } #=> ] #=> } @@ -224,6 +225,69 @@ defmodule Bumblebee.Text do defdelegate generation(model_info, tokenizer, generation_config, opts \\ []), to: Bumblebee.Text.TextGeneration + @type translation_input :: + %{ + :text => String.t(), + :source_language_token => String.t(), + :target_language_token => String.t(), + optional(:seed) => integer() | nil + } + @type translation_output :: generation_output() + + @doc """ + Builds serving for text translation. + + The serving accepts `t:translation_input/0` and returns `t:translation_output/0`. + A list of inputs is also supported. + + This serving is an extension of `generation/4` that handles per-input + language configuration. + + Note that this serving is designed for multilingual models that + require source/target language to be specified. Some text models are + trained for specific language pairs, others expect a command such as + "translate English to Spanish", in such cases you most likely want + to use `generation/4`. + + ## Options + + See `generation/4` for available options. + + ## Examples + + repository_id = "facebook/nllb-200-distilled-600M" + + {:ok, model_info} = Bumblebee.load_model({:hf, repository_id}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repository_id}) + {:ok, generation_config} = Bumblebee.load_generation_config({:hf, repository_id}) + + serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config) + + text = "The bank of the river is beautiful in spring" + + Nx.Serving.run(serving, %{ + text: text, + source_language_token: "eng_Latn", + target_language_token: "pol_Latn" + }) + #=> %{ + #=> results: [ + #=> %{ + #=> text: "W wiosnę brzeg rzeki jest piękny", + #=> token_summary: %{input: 11, output: 13, padding: 0} + #=> } + #=> ] + #=> } + """ + @spec translation( + Bumblebee.model_info(), + Bumblebee.Tokenizer.t(), + Bumblebee.Text.GenerationConfig.t(), + keyword() + ) :: Nx.Serving.t() + defdelegate translation(model_info, tokenizer, generation_config, opts \\ []), + to: Bumblebee.Text.Translation + @type text_classification_input :: String.t() @type text_classification_output :: %{predictions: list(text_classification_prediction())} @type text_classification_prediction :: %{score: number(), label: String.t()} @@ -316,7 +380,7 @@ defmodule Bumblebee.Text do it is not already a pooled embedding. Supported values: * `:mean_pooling` - performs a mean across all tokens - + * `cls_token_pooling` - takes the embedding for the special CLS token. Note that we currently assume that the CLS token is the first token in the sequence diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index d95e64be..274b56ad 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -244,13 +244,13 @@ defmodule Bumblebee.Text.Generation do encoder_outputs = encoder_predict_fun.(params, inputs) batch_size = Nx.axis_size(encoder_input(inputs), 0) - decoder_input_ids = Nx.broadcast(decoder_start_token_id, {batch_size, 1}) + + inputs = Map.put(inputs, "encoder_hidden_state", encoder_outputs.hidden_state) inputs = - Map.merge(inputs, %{ - "encoder_hidden_state" => encoder_outputs.hidden_state, - "decoder_input_ids" => decoder_input_ids - }) + Map.put_new_lazy(inputs, "decoder_input_ids", fn -> + Nx.broadcast(decoder_start_token_id, {batch_size, 1}) + end) max_length = max_length_fun.(1) inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length) diff --git a/lib/bumblebee/text/pre_trained_tokenizer.ex b/lib/bumblebee/text/pre_trained_tokenizer.ex index fa07ee97..59ab3468 100644 --- a/lib/bumblebee/text/pre_trained_tokenizer.ex +++ b/lib/bumblebee/text/pre_trained_tokenizer.ex @@ -53,6 +53,16 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do whether to return the sequence length. The length is the effective number of tokens, so it is calculated after truncation, but does not include padding """ + ], + template_options: [ + default: [], + doc: """ + options configuring the tokenization template, specific to the given tokenizer type. + Recognised options are: + + * `:language_token` - for tokenizers: `:nllb` + + """ ] ] @@ -187,7 +197,8 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do pad: "", cls: "", mask: "" - } + }, + default_template_options: [language_token: "eng_Latn"] }, roberta: %{ special_tokens: %{ @@ -246,8 +257,15 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do # special tokens added by a template post-processor. By setting # truncation upfront, the tokenizer will apply it before the # post-processor accounting for the extra special tokens - if Keyword.has_key?(opts, :length) or Keyword.has_key?(opts, :truncation_direction) do - update_truncation(tokenizer) + tokenizer = + if Keyword.has_key?(opts, :length) or Keyword.has_key?(opts, :truncation_direction) do + update_truncation(tokenizer) + else + tokenizer + end + + if Keyword.has_key?(opts, :template_options) do + set_template(tokenizer) else tokenizer end @@ -269,6 +287,42 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do ) end + defp set_template(%{type: :nllb} = tokenizer) do + language_token = Keyword.fetch!(tokenizer.template_options, :language_token) + eos_token = tokenizer.special_tokens.eos + + set_template_postprocessor( + tokenizer, + "#{language_token} $A #{eos_token}", + "#{language_token} $A $B #{eos_token}", + [language_token, eos_token] + ) + end + + defp set_template(%{type: type} = tokenizer) do + if tokenizer.template_options != [] do + raise ArgumentError, + "#{inspect(type)} tokenizer expects no :template_options," <> + " got: #{inspect(tokenizer.template_options)}" + end + + tokenizer + end + + defp set_template_postprocessor(tokenizer, single, pair, special_tokens) do + post_processor = + Tokenizers.PostProcessor.template( + single: single, + pair: pair, + special_tokens: + for token <- special_tokens do + {token, Tokenizer.token_to_id(tokenizer.native_tokenizer, token)} + end + ) + + update_in(tokenizer.native_tokenizer, &Tokenizer.set_post_processor(&1, post_processor)) + end + @impl true def apply(tokenizer, input) do input = List.wrap(input) @@ -480,7 +534,7 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do " but got: #{inspect(tokenizer.type)}" end - %{special_tokens: special_tokens} = tokenizer_types[tokenizer.type] + tokenizer_type = %{special_tokens: special_tokens} = tokenizer_types[tokenizer.type] special_tokens = load_special_tokens(special_tokens, special_tokens_map) @@ -493,12 +547,15 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do [] end + template_options = tokenizer_type[:default_template_options] || [] + %{ tokenizer | native_tokenizer: native_tokenizer, special_tokens: special_tokens, additional_special_tokens: additional_special_tokens } + |> @for.config(template_options: template_options) end defp load_special_tokens(special_tokens, data) do diff --git a/lib/bumblebee/text/text_generation.ex b/lib/bumblebee/text/text_generation.ex index 65329025..c2169071 100644 --- a/lib/bumblebee/text/text_generation.ex +++ b/lib/bumblebee/text/text_generation.ex @@ -100,7 +100,7 @@ defmodule Bumblebee.Text.TextGeneration do {batch, {multi?, input_length, input_padded_length}} end) - |> maybe_stream(opts[:stream], opts[:stream_done], tokenizer) + |> add_postprocessing(opts[:stream], opts[:stream_done], tokenizer) end defp validate_input(text) when is_binary(text), do: validate_input(%{text: text}) @@ -117,7 +117,10 @@ defmodule Bumblebee.Text.TextGeneration do {:error, "expected either a string or a map, got: #{inspect(input)}"} end - defp maybe_stream(serving, false, _stream_done, tokenizer) do + @doc false + def add_postprocessing(serving, stream, stream_done, tokenizer) + + def add_postprocessing(serving, false, _stream_done, tokenizer) do Nx.Serving.client_postprocessing( serving, fn {%{token_ids: token_ids, length: length}, _metadata}, @@ -138,7 +141,7 @@ defmodule Bumblebee.Text.TextGeneration do ) end - defp maybe_stream(serving, true, stream_done, tokenizer) do + def add_postprocessing(serving, true, stream_done, tokenizer) do serving |> Nx.Serving.streaming(hooks: [:token]) |> Nx.Serving.client_postprocessing(fn stream, diff --git a/lib/bumblebee/text/translation.ex b/lib/bumblebee/text/translation.ex new file mode 100644 index 00000000..79082fc4 --- /dev/null +++ b/lib/bumblebee/text/translation.ex @@ -0,0 +1,176 @@ +defmodule Bumblebee.Text.Translation do + @moduledoc false + + alias Bumblebee.Shared + alias Bumblebee.Text + + def translation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :compile, + defn_options: [], + preallocate_params: false, + stream: false, + stream_done: false + ]) + + %{model: model, params: params, spec: spec} = model_info + + Shared.validate_architecture!(spec, [:for_conditional_generation]) + + preallocate_params = opts[:preallocate_params] + defn_options = opts[:defn_options] + + compile = + if compile = opts[:compile] do + compile + |> Keyword.validate!([:batch_size, :sequence_length]) + |> Shared.require_options!([:batch_size, :sequence_length]) + end + + batch_size = compile[:batch_size] + sequence_length = compile[:sequence_length] + + tokenizer = + Bumblebee.configure(tokenizer, + length: sequence_length, + pad_direction: :left, + return_token_type_ids: false, + return_length: true + ) + + generate_fun = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + ignore_output: opts[:stream] + ) + + batch_keys = Shared.sequence_batch_keys(sequence_length) + + Nx.Serving.new( + fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + + generate_fun = + Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + + inputs = %{ + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "decoder_input_ids" => Nx.template({batch_size, 2}, :u32), + "seed" => Nx.template({batch_size}, :s64) + } + + [params, inputs] + end) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + generate_fun.(params, inputs) |> Shared.serving_post_computation() + end + end, + defn_options + ) + |> Nx.Serving.batch_size(batch_size) + |> Nx.Serving.process_options(batch_keys: batch_keys) + |> Nx.Serving.client_preprocessing(fn input -> + if opts[:stream] do + Shared.validate_input_for_stream!(input) + end + + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + + texts = Enum.map(inputs, & &1.text) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend) + + source_language_token = source_language_token!(inputs) + + validate_language_token!(source_language_token, tokenizer) + + tokenizer = + Bumblebee.configure(tokenizer, template_options: [language_token: source_language_token]) + + # We specify custom decoder_input_ids input to include the dynamic + # language token id after the start token + decoder_input_ids = + inputs + |> Enum.map(fn %{target_language_token: target_language_token} -> + validate_language_token!(target_language_token, tokenizer) + token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, target_language_token) + + decoder_start_token_id = + generation_config.decoder_start_token_id || generation_config.bos_token_id + + [decoder_start_token_id, token_id] + end) + |> Nx.tensor(type: :u32, backend: Nx.BinaryBackend) + + inputs = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts) + end) + + {input_length, inputs} = Map.pop!(inputs, "length") + input_padded_length = Nx.axis_size(inputs["input_ids"], 1) + + inputs = Map.put(inputs, "seed", seed) + inputs = Map.put(inputs, "decoder_input_ids", decoder_input_ids) + + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {multi?, input_length, input_padded_length}} + end) + |> Text.TextGeneration.add_postprocessing(opts[:stream], opts[:stream_done], tokenizer) + end + + defp validate_input( + %{ + text: text, + source_language_token: source_language_token, + target_language_token: target_language_token + } = input + ) do + {:ok, + %{ + text: text, + source_language_token: source_language_token, + target_language_token: target_language_token, + seed: input[:seed] || :erlang.system_time() + }} + end + + defp validate_input(%{} = input) do + {:error, + "expected the input map to have :text, :source_language_token and :target_language_token keys, got: #{inspect(input)}"} + end + + defp validate_input(input) do + {:error, "expected a map, got: #{inspect(input)}"} + end + + defp source_language_token!(inputs) do + source_language_tokens = for input <- inputs, uniq: true, do: input.source_language_token + + case source_language_tokens do + [token] -> + token + + _tokens -> + raise ArgumentError, + "the translation serving supports a list of inputs only when all" <> + " of them have the same :source_language_token. To process multiple" <> + " inputs with different source language, configure :compile options" <> + " with a desired batch size, start a serving process and use" <> + " Task.async_stream/1 with Nx.Serving.batched_run/1" + end + end + + defp validate_language_token!(language_token, tokenizer) do + unless Bumblebee.Tokenizer.token_to_id(tokenizer, language_token) do + raise ArgumentError, + "the specified language token #{inspect(language_token)} is not" <> + " a valid token for this tokenizer" + end + end +end diff --git a/lib/bumblebee/tokenizer.ex b/lib/bumblebee/tokenizer.ex index a1cf4172..bfcfac17 100644 --- a/lib/bumblebee/tokenizer.ex +++ b/lib/bumblebee/tokenizer.ex @@ -52,12 +52,12 @@ defmodule Bumblebee.Tokenizer do @doc """ Converts the given token into the corresponding numeric id. """ - @callback token_to_id(t(), token()) :: token_id() + @callback token_to_id(t(), token()) :: token_id() | nil @doc """ Converts the given token id the corresponding token. """ - @callback id_to_token(t(), token_id()) :: token() + @callback id_to_token(t(), token_id()) :: token() | nil @doc """ Returns a map with special tokens. @@ -86,7 +86,7 @@ defmodule Bumblebee.Tokenizer do @doc """ Converts the given token into the corresponding numeric id. """ - @spec token_to_id(t(), token()) :: token_id() + @spec token_to_id(t(), token()) :: token_id() | nil def token_to_id(%module{} = tokenizer, token) do module.token_to_id(tokenizer, token) end @@ -94,7 +94,7 @@ defmodule Bumblebee.Tokenizer do @doc """ Converts the given token id to the corresponding token. """ - @spec id_to_token(t(), token_id()) :: token() + @spec id_to_token(t(), token_id()) :: token() | nil def id_to_token(%module{} = tokenizer, id) do module.id_to_token(tokenizer, id) end diff --git a/mix.exs b/mix.exs index 2a7be069..6c301497 100644 --- a/mix.exs +++ b/mix.exs @@ -95,9 +95,11 @@ defmodule Bumblebee.MixProject do Bumblebee.Text.GptBigCode, Bumblebee.Text.GptNeoX, Bumblebee.Text.Llama, + Bumblebee.Text.M2m100, Bumblebee.Text.Mbart, Bumblebee.Text.Mistral, Bumblebee.Text.Phi, + Bumblebee.Text.Phi3, Bumblebee.Text.Roberta, Bumblebee.Text.T5, Bumblebee.Vision.BlipVision, diff --git a/test/bumblebee/text/pre_trained_tokenizer_test.exs b/test/bumblebee/text/pre_trained_tokenizer_test.exs index dceaf134..19537fdc 100644 --- a/test/bumblebee/text/pre_trained_tokenizer_test.exs +++ b/test/bumblebee/text/pre_trained_tokenizer_test.exs @@ -339,6 +339,47 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do ) end + test ":nllb" do + assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/nllb-200-distilled-600M"}) + + assert %Bumblebee.Text.PreTrainedTokenizer{type: :nllb} = tokenizer + + inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello, my dog is cute "]) + + assert_equal( + inputs["input_ids"], + Nx.tensor([ + [256_047, 94124, 248_079, 1537, 6658, 248, 95740, 256_203, 2] + ]) + ) + + assert_equal( + inputs["attention_mask"], + Nx.tensor([ + [1, 1, 1, 1, 1, 1, 1, 1, 1] + ]) + ) + + tokenizer = + Bumblebee.configure(tokenizer, template_options: [language_token: "fra_Latn"]) + + inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello, my dog is cute "]) + + assert_equal( + inputs["input_ids"], + Nx.tensor([ + [256_057, 94124, 248_079, 1537, 6658, 248, 95740, 256_203, 2] + ]) + ) + + assert_equal( + inputs["attention_mask"], + Nx.tensor([ + [1, 1, 1, 1, 1, 1, 1, 1, 1] + ]) + ) + end + test ":roberta" do assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"}) diff --git a/test/bumblebee/text/translation_test.exs b/test/bumblebee/text/translation_test.exs new file mode 100644 index 00000000..a3b8f9f2 --- /dev/null +++ b/test/bumblebee/text/translation_test.exs @@ -0,0 +1,34 @@ +defmodule Bumblebee.Text.TranslationTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + @moduletag serving_test_tags() + + test "generates text with greedy generation" do + {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/nllb-200-distilled-600M"}) + + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/nllb-200-distilled-600M"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "facebook/nllb-200-distilled-600M"}) + + serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config) + + text = "The bank of the river is beautiful in spring" + + assert %{ + results: [ + %{ + text: "W wiosnę brzeg rzeki jest piękny", + token_summary: %{input: 11, output: 13, padding: 0} + } + ] + } = + Nx.Serving.run(serving, %{ + text: text, + source_language_token: "eng_Latn", + target_language_token: "pol_Latn" + }) + end +end