Skip to content

Commit

Permalink
Make sure the initial decoding cache has the proper types
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 23, 2024
1 parent 601551a commit 53cec79
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
29 changes: 26 additions & 3 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ defmodule Bumblebee.Text.Generation do
})

max_length = max_length_fun.(1)
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, max_length)
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
{inputs, inputs["decoder_input_ids"], max_length}
end

Expand All @@ -260,7 +260,7 @@ defmodule Bumblebee.Text.Generation do
prepare_inputs_fun = fn inputs, _params ->
sequence_length = Nx.axis_size(inputs["input_ids"], 1)
max_length = max_length_fun.(sequence_length)
inputs = prepare_decoder_inputs(inputs, "", spec, max_length)
inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
{inputs, inputs["input_ids"], max_length}
end

Expand All @@ -279,7 +279,7 @@ defmodule Bumblebee.Text.Generation do
inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
end

defp prepare_decoder_inputs(inputs, prefix, spec, max_length) do
defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
input_ids = inputs[prefix <> "input_ids"]
attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)

Expand All @@ -295,9 +295,32 @@ defmodule Bumblebee.Text.Generation do

batch_size = Nx.axis_size(input_ids, 0)
cache = init_cache(spec, batch_size, max_length, inputs)

output_policy = model_output_policy(model)

# TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
# match Axon compiler

# Cast all float cache tensors to match the model output. This way
# we make sure the cache we pass as input has the same types as
# the updated cache returned from the model
cache =
Bumblebee.Utils.Nx.map(cache, fn tensor ->
if Nx.Type.integer?(Nx.type(tensor)) do
tensor
else
Axon.MixedPrecision.cast(output_policy, tensor, :output)
end
end)

Map.put(inputs, "cache", cache)
end

defp model_output_policy(model) do
{node, _} = Axon.pop_node(model)
node.policy
end

defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
inputs
|> Map.replace!(prefix <> "input_ids", token_ids)
Expand Down
28 changes: 28 additions & 0 deletions test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,32 @@ defmodule Bumblebee.Text.BartTest do

assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
end

test "generation with :for_conditional_generation and lower precision" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"},
type: :f16
)

{:ok, generation_config} =
Bumblebee.load_generation_config(
{:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
)

assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
"seed" => Nx.tensor([0])
}

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)

generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
%{token_ids: token_ids} = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
end
end

0 comments on commit 53cec79

Please sign in to comment.