Reduce the output of generation loop when streaming
jonatanklosko committed Feb 15, 2024
1 parent ffd1c3f commit a74c203
Showing 2 changed files with 44 additions and 10 deletions.
49 changes: 40 additions & 9 deletions lib/bumblebee/text/generation.ex
@@ -90,22 +90,46 @@ defmodule Bumblebee.Text.Generation do
   argument. Note that the inputs map should additionally include a
   `"seed"` tensor, with one value per input in the batch.

+  ## Streaming
+
+  This function sets up a hook that is invoked after every generated
+  token. The hook receives a map with the following attributes:
+
+    * `:token_id` - the newly generated token
+
+    * `:finished?` - a boolean indicating if the sequence is finished
+
+    * `:length` - the current length of the generated sequence. Once
+      the sequence is finished, the length does not increase
+
+  Each of the attributes is a tensor with a leading batch dimension.
+
+  When streaming you may not care about the output result, in which
+  case you can enable `:ignore_output` to reduce the output size.
+
   ## Options

     * `:logits_processors` - a list of numerical functions to modify
       predicted scores at each generation step. The functions are
       applied in order, after all default processors

+    * `:ignore_output` - if true, returns a dummy tensor that should
+      be ignored. This is useful when you consume the generated tokens
+      in a stream fashion via the hook, so that the full output does
+      not need to be transferred unnecessarily after the computation.
+      Defaults to `false`
+
   """
   @spec build_generate(
           Axon.t(),
           Bumblebee.ModelSpec.t(),
           Bumblebee.Text.GenerationConfig.t(),
           keyword()
         ) ::
-          (params :: map(), inputs :: map() -> %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()})
+          (params :: map(), inputs :: map() ->
+             %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()} | (ignored :: Nx.Tensor.t()))
   def build_generate(model, spec, config, opts \\ []) do
-    opts = Keyword.validate!(opts, logits_processors: [])
+    opts = Keyword.validate!(opts, logits_processors: [], ignore_output: false)

     decoder_start_token_id = config.decoder_start_token_id || config.bos_token_id
     eos_token_id = config.eos_token_id
@@ -148,7 +172,8 @@ defmodule Bumblebee.Text.Generation do
       traverse_cache_fun,
       pad_token_id: pad_token_id,
       eos_token_id: eos_token_id,
-      strategy: config.strategy
+      strategy: config.strategy,
+      ignore_output: opts[:ignore_output]
     )
   end

@@ -400,11 +425,15 @@ defmodule Bumblebee.Text.Generation do
       )
     end

-    %{
-      # Output only the newly generated tokens
-      token_ids: state.sequences[[.., length..-1//1]],
-      length: state.finished_length - length
-    }
+    if opts[:ignore_output] do
+      state.ignored
+    else
+      %{
+        # Output only the newly generated tokens
+        token_ids: state.sequences[[.., length..-1//1]],
+        length: state.finished_length - length
+      }
+    end
   end

   deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed")
@@ -480,7 +509,9 @@ defmodule Bumblebee.Text.Generation do
       sequences: sequences,
       input_length: length,
       length: length,
-      finished_length: finished_length
+      finished_length: finished_length,
+      # The ignored return value that we attach all hooks to
+      ignored: Nx.broadcast(0, {batch_size})
     }
   end

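The interplay between the per-token hook and `:ignore_output` described in the docstring above can be illustrated with a toy loop in plain Elixir. This is only a sketch under simplifying assumptions, not Bumblebee's implementation: `StreamingSketch` and its `generate/4` are hypothetical names, the "model" is a fixed list of token ids, plain integers and booleans stand in for the batched tensors, and the atom `:ignored` stands in for the dummy tensor.

```elixir
defmodule StreamingSketch do
  # Hypothetical toy module for illustration; not part of Bumblebee.

  # Walks over pre-computed token ids as if they were generated one by
  # one. After every token, `hook` receives a map shaped like the one in
  # the docstring (:token_id, :finished?, :length).
  def generate(token_ids, eos_token_id, hook, opts \\ []) do
    opts = Keyword.validate!(opts, ignore_output: false)

    generated =
      Enum.reduce_while(token_ids, [], fn token_id, acc ->
        acc = [token_id | acc]
        finished? = token_id == eos_token_id

        hook.(%{token_id: token_id, finished?: finished?, length: length(acc)})

        if finished?, do: {:halt, acc}, else: {:cont, acc}
      end)

    if opts[:ignore_output] do
      # Placeholder result; the caller already saw every token via the hook
      :ignored
    else
      %{token_ids: Enum.reverse(generated), length: length(generated)}
    end
  end
end

# Usage: print each hook event and discard the final result
StreamingSketch.generate([5, 9, 2, 7], 2, &IO.inspect/1, ignore_output: true)
```

With `ignore_output: true` the caller gets all tokens through the hook and the final return value carries no information, which is the situation the commit optimizes for when the real output would otherwise have to be transferred after the computation.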
5 changes: 4 additions & 1 deletion lib/bumblebee/text/text_generation.ex
@@ -42,7 +42,10 @@ defmodule Bumblebee.Text.TextGeneration do
         return_length: true
       )

-    generate_fun = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
+    generate_fun =
+      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
+        ignore_output: opts[:stream]
+      )

     batch_keys = Shared.sequence_batch_keys(sequence_length)


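The change above ties `:ignore_output` to the `:stream` option: when the caller consumes tokens as a stream, the final result of the generate function is never read. A rough plain-Elixir sketch of that relationship follows. It is one possible way to turn hook callbacks into a lazy stream and is not claimed to be how Bumblebee's serving does it; `StreamOptionSketch`, `run/2` and `fake_generate/3` are hypothetical names, and strings stand in for decoded tokens.

```elixir
defmodule StreamOptionSketch do
  # Hypothetical toy module for illustration; not part of Bumblebee.

  # Stand-in for a generate function built with `ignore_output`: it calls
  # the hook for every token and returns either the tokens or a placeholder.
  defp fake_generate(tokens, hook, ignore_output) do
    Enum.each(tokens, hook)
    if ignore_output, do: :ignored, else: tokens
  end

  def run(tokens, opts \\ []) do
    opts = Keyword.validate!(opts, stream: false)

    if opts[:stream] do
      Stream.resource(
        fn ->
          parent = self()
          ref = make_ref()

          # Generation runs in another process; the hook forwards each
          # token as a message and the final result is discarded.
          spawn_link(fn ->
            fake_generate(tokens, &send(parent, {ref, {:token, &1}}), true)
            send(parent, {ref, :done})
          end)

          ref
        end,
        fn ref ->
          receive do
            {^ref, {:token, token}} -> {[token], ref}
            {^ref, :done} -> {:halt, ref}
          end
        end,
        fn _ref -> :ok end
      )
    else
      # Without streaming, the full output is the only way to get the tokens
      fake_generate(tokens, fn _token -> :ok end, false)
    end
  end
end

# Usage: write tokens as they arrive
StreamOptionSketch.run(["Hello", ",", " world"], stream: true)
|> Enum.each(&IO.write/1)
```

In the streaming branch every token reaches the consumer through messages, so returning the full sequence as well would only duplicate data that was already delivered.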