Make sure the initial decoding cache has the proper types #346

Merged · 1 commit · Feb 23, 2024
29 changes: 26 additions & 3 deletions lib/bumblebee/text/generation.ex
@@ -249,7 +249,7 @@ defmodule Bumblebee.Text.Generation do
       })

       max_length = max_length_fun.(1)
-      inputs = prepare_decoder_inputs(inputs, "decoder_", spec, max_length)
+      inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
       {inputs, inputs["decoder_input_ids"], max_length}
     end

@@ -260,7 +260,7 @@ defmodule Bumblebee.Text.Generation do
     prepare_inputs_fun = fn inputs, _params ->
       sequence_length = Nx.axis_size(inputs["input_ids"], 1)
       max_length = max_length_fun.(sequence_length)
-      inputs = prepare_decoder_inputs(inputs, "", spec, max_length)
+      inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
       {inputs, inputs["input_ids"], max_length}
     end

@@ -279,7 +279,7 @@ defmodule Bumblebee.Text.Generation do
     inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
   end

-  defp prepare_decoder_inputs(inputs, prefix, spec, max_length) do
+  defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
     input_ids = inputs[prefix <> "input_ids"]
     attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
@@ -295,9 +295,32 @@ defmodule Bumblebee.Text.Generation do

     batch_size = Nx.axis_size(input_ids, 0)
     cache = init_cache(spec, batch_size, max_length, inputs)
+
+    output_policy = model_output_policy(model)
+
+    # TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
+    # match Axon compiler
+
+    # Cast all float cache tensors to match the model output. This way
+    # we make sure the cache we pass as input has the same types as
+    # the updated cache returned from the model
+    cache =
+      Bumblebee.Utils.Nx.map(cache, fn tensor ->
+        if Nx.Type.integer?(Nx.type(tensor)) do
+          tensor
+        else
+          Axon.MixedPrecision.cast(output_policy, tensor, :output)
+        end
+      end)
+
     Map.put(inputs, "cache", cache)
   end
+
+  defp model_output_policy(model) do
+    {node, _} = Axon.pop_node(model)
+    node.policy
+  end

   defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
     inputs
     |> Map.replace!(prefix <> "input_ids", token_ids)
28 changes: 28 additions & 0 deletions test/bumblebee/text/bart_test.exs
@@ -154,4 +154,32 @@ defmodule Bumblebee.Text.BartTest do

     assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
   end
+
+  test "generation with :for_conditional_generation and lower precision" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"},
+               type: :f16
+             )
+
+    {:ok, generation_config} =
+      Bumblebee.load_generation_config(
+        {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
+      )
+
+    assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+      "seed" => Nx.tensor([0])
+    }
+
+    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
+
+    generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
+    %{token_ids: token_ids} = generate.(params, inputs)
+
+    assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
+  end
 end
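
The test loads a tiny random BART checkpoint with `type: :f16`, which applies a half-precision policy to the whole model, so generation only passes if the initial cache types line up with the cache the model returns. A hypothetical end-to-end sketch of the same code path (not from the PR; any generation model loaded in half precision goes through it, and the checkpoint name here is only illustrative):

    {:ok, model_info} =
      Bumblebee.load_model({:hf, "facebook/bart-large-cnn"}, type: :f16)

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})

    {:ok, generation_config} =
      Bumblebee.load_generation_config({:hf, "facebook/bart-large-cnn"})

    # The generation serving calls build_generate/3 under the hood
    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

    Nx.Serving.run(serving, "Some text to summarize.")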