WIP: Constrained sampling based on EBNF grammars #354

Draft · wants to merge 1 commit into main
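The PR has no description yet. The gist: constrain token-by-token sampling so that generated text always conforms to an EBNF grammar, apparently modeled on the Hugging Face / llama.cpp grammar-sampling implementations. A rough sketch of the intended end-user API, assuming the new :grammar generation-config option wired up in this diff (the grammar syntax shown is also an assumption):

grammar = """
root ::= "yes" | "no"
"""

{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

generation_config = Bumblebee.configure(generation_config, grammar: grammar)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Is Elixir a functional language? Answer: ")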
3 changes: 3 additions & 0 deletions lib/bumblebee/text/generation.ex
@@ -361,6 +361,9 @@ defmodule Bumblebee.Text.Generation do
end,
if config.temperature && config.temperature != 1.0 do
&temperature_processor(&1, &2, temperature: config.temperature)
end,
if config.grammar do
&grammar_constrained_processor(&1, &2, grammar: config.grammar)
end
] ++
if config.strategy.type == :multinomial_sampling do
94 changes: 94 additions & 0 deletions lib/bumblebee/text/generation/grammar_constraint.ex
@@ -0,0 +1,94 @@
defmodule Bumblebee.Text.Generation.GrammarConstraint do
@moduledoc false

alias Bumblebee.Text.Generation.TokenTrie
alias Bumblebee.Text.Generation.Stack
alias EBNF.ParseState

alias __MODULE__

import Nx.Defn

# Models a single grammar constraint: the token trie built from the
# tokenizer vocabulary, the flat grammar encoding, and the position
# of each rule within that encoding

defstruct [
:token_trie,
:grammar_encoding,
:tokenizer,
:start_rule_id,
:start_rule_position,
:rule_positions
]

def create(grammar, root, tokenizer) do
%ParseState{symbol_ids: symbols, grammar_encoding: encoding} = EBNF.encode(grammar)
trie = TokenTrie.create(tokenizer)
start_rule_id = Map.fetch!(symbols, root)
rule_positions = get_rule_positions(encoding)

%GrammarConstraint{
token_trie: trie,
grammar_encoding: encoding,
tokenizer: tokenizer,
start_rule_id: start_rule_id,
start_rule_position: Map.fetch!(rule_positions, start_rule_id),
rule_positions: rule_positions
}
end

def init_stacks(constraint) do
  # the stack can never grow past the grammar encoding size
  constraint.grammar_encoding
  |> length()
  |> Stack.new()
  |> Stack.push(constraint.start_rule_position + 2)
  |> advance_stack()
end

defn advance_stack(stack) do

  if Nx.equal(Stack.length(stack), 0) do
    stack
  else
    top = Stack.peek(stack)

    if Nx.equal(top, 2) do
      stack
    else
      # TODO: advance past the element at the top of the stack (still WIP)
      stack
    end
  end
end

# The grammar encoding is a flat list of integers: each rule is its
# rule id, followed by one or more alternates (each prefixed by its
# size) and a terminating 0. The whole encoding ends with 0xFFFF.
defp get_rule_positions(grammar_encoding) do
  recur_get_rule_positions(grammar_encoding, 0, %{})
end

defp recur_get_rule_positions([0xFFFF], _pos, rule_positions), do: rule_positions

defp recur_get_rule_positions([rule_id | rest], pos, rule_positions) do
rule_positions = Map.put(rule_positions, rule_id, pos)

case find_next_rule(rest, pos + 1) do
{[_ | leftover], pos} ->
recur_get_rule_positions(leftover, pos + 1, rule_positions)

{[], _} ->
rule_positions
end
end

defp find_next_rule([0 | rest], pos) do
{rest, pos + 1}
end

defp find_next_rule([rule_size | _] = leftover, pos) do
leftover = Enum.drop(leftover, rule_size + 1)
pos = pos + rule_size + 1

case leftover do
[0 | _] ->
{leftover, pos}

leftover ->
find_next_rule(leftover, pos)
end
end
end
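For reviewers, a rough sketch of how this module is meant to be driven, assuming the EBNF.encode/1 entry point used above and a GBNF-style grammar syntax (the exact syntax accepted by seanmor5/ebnf is an assumption here):

alias Bumblebee.Text.Generation.GrammarConstraint

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

grammar = """
root ::= "true" | "false"
"""

constraint = GrammarConstraint.create(grammar, "root", tokenizer)

# one stack per batch element will eventually be needed; for now this
# initializes a single stack positioned at the start rule
stack = GrammarConstraint.init_stacks(constraint)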
12 changes: 12 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do

import Nx.Defn

alias Bumblebee.Text.Generation.GrammarConstraint

deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
opts = Keyword.validate!(opts, [:suppressed_token_ids])

@@ -255,4 +257,14 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
{idx, _token_id} -> idx + 1
end
end

deftransform grammar_constrained_processor(logits, _input_ids, opts \\ []) do
  opts = Keyword.validate!(opts, [:grammar, :tokenizer])

  grammar = opts[:grammar]
  tokenizer = opts[:tokenizer]

  constraint = GrammarConstraint.create(grammar, "root", tokenizer)
  _batch_stacks = GrammarConstraint.init_stacks(constraint)

  # TODO: use the stacks to mask out grammar-violating tokens (WIP)
  logits
end
end
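The processor above builds the constraint but does not yet touch the logits. Presumably the final step will compute, from the stacks and the token trie, a boolean mask of grammar-legal tokens and push everything else to negative infinity. A minimal sketch of that last step, assuming an allowed mask over the vocabulary (computing that mask is the hard, still-missing part):

defnp apply_grammar_mask(logits, allowed) do
  # disallowed tokens get -inf logits so they can never be sampled
  Nx.select(allowed, logits, Nx.Constants.neg_infinity())
end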
64 changes: 64 additions & 0 deletions lib/bumblebee/text/generation/stack.ex
@@ -0,0 +1,64 @@
defmodule Bumblebee.Text.Generation.Stack do
@moduledoc false

# A "stack" like data structure represented as an Nx container
# to make constrained sampling possible/easier. The HF implementation
# uses a "dynamic" stack, but we need all shapes up front and
# can't manipulate so we use a "stack" and then a pointer in
# the stack

alias __MODULE__

@derive {Nx.Container, containers: [:data, :pointer]}
defstruct [:data, :pointer]

import Nx.Defn

@empty_value -1

@doc """
Initializes a new stack.
"""
def new(size, opts \\ []) do
opts = Keyword.validate!(opts, type: :s64)

%Stack{
data: Nx.broadcast(Nx.tensor(@empty_value, type: opts[:type]), {size}),
pointer: Nx.tensor(0)
}
end

@doc """
Push a value to the top of the stack.
"""
deftransform push(%Stack{data: data, pointer: pointer} = stack, value) do
  unless Nx.rank(value) == 0, do: raise("can only push scalar values to stack")

  # put_slice expects the slice to match the rank of the buffer
  %{
    stack
    | data: Nx.put_slice(data, [pointer], Nx.reshape(value, {1})),
      pointer: Nx.add(pointer, 1)
  }
end

@doc """
Pops a value from the stack.
"""
defn pop(%Stack{data: data, pointer: pointer} = stack) do
  # the pointer always points one slot past the top of the stack
  value = data[[pointer - 1]]
  {value, %{stack | pointer: pointer - 1}}
end

@doc """
Peeks at the head of the stack.
"""
defn peek(%Stack{data: data, pointer: pointer}) do
  data[[pointer - 1]]
end

@doc """
Returns the length of the stack.
"""
defn length(%Stack{pointer: pointer}) do
pointer
end
end
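A quick usage sketch of the stack (values are Nx scalars, and the buffer size fixed at creation bounds the stack depth):

alias Bumblebee.Text.Generation.Stack

stack = Stack.new(8)
stack = Stack.push(stack, 42)
stack = Stack.push(stack, 7)

Stack.peek(stack)
#=> Nx scalar 7

{value, stack} = Stack.pop(stack)
#=> value is the Nx scalar 7

Stack.length(stack)
#=> Nx scalar 1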
61 changes: 61 additions & 0 deletions lib/bumblebee/text/generation/token_trie.ex
@@ -0,0 +1,61 @@
defmodule Bumblebee.Text.Generation.TokenTrie do
@moduledoc false

# Internal data structure used in constrained sampling

alias Bumblebee.Text.PreTrainedTokenizer
alias __MODULE__

defstruct [:tokens, :trie, :eos_token_id]

@leaf -1

@doc """
Returns the token encoded by the given ID.
"""
def id_to_token(%TokenTrie{tokens: tokens}, id) do
Map.fetch!(tokens, id)
end

@doc """
Returns the number of tokens in the trie.
"""
def n_tokens(%TokenTrie{tokens: tokens}) do
  map_size(tokens)
end

@doc """
Creates a trie from the vocabulary in the given tokenizer.
"""
def create(%PreTrainedTokenizer{native_tokenizer: tokenizer, special_tokens: %{eos: eos_token}}) do
vocab = Tokenizers.Tokenizer.get_vocab(tokenizer)
eos_token_id = Map.fetch!(vocab, eos_token)

tokens =
Map.new(vocab, fn {token, id} ->
# TODO: Special cases for GPT2 and Llama
{id, String.to_charlist(token)}
end)

trie =
Enum.reduce(tokens, %{}, fn {token_id, token_bytes}, acc ->
insert_into_trie(acc, token_bytes, token_id)
end)

%TokenTrie{tokens: tokens, trie: trie, eos_token_id: eos_token_id}
end

## Helpers

defp insert_into_trie(trie, token_bytes, token_id) do
do_insert_into_trie(trie, token_bytes, token_id)
end

defp do_insert_into_trie(trie, [], token_id), do: Map.put(trie, @leaf, token_id)

defp do_insert_into_trie(trie, [byte | rest_bytes], token_id) do
current = Map.get(trie, byte, %{})
updated = do_insert_into_trie(current, rest_bytes, token_id)
Map.put(trie, byte, updated)
end
end
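The trie maps byte sequences to token ids, with the -1 key (@leaf) marking where a complete token ends. For illustration, hypothetical helpers (not part of this diff) that walk the trie and collect every token id sharing a given byte prefix, which is the query constrained sampling needs:

def token_ids_with_prefix(%TokenTrie{trie: trie}, prefix) do
  case walk(trie, prefix) do
    nil -> []
    subtrie -> collect_token_ids(subtrie)
  end
end

defp walk(trie, []), do: trie

defp walk(trie, [byte | rest]) do
  case Map.get(trie, byte) do
    nil -> nil
    subtrie -> walk(subtrie, rest)
  end
end

defp collect_token_ids(trie) do
  Enum.flat_map(trie, fn
    {-1, token_id} -> [token_id]
    {_byte, subtrie} -> collect_token_ids(subtrie)
  end)
end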
3 changes: 2 additions & 1 deletion mix.exs
@@ -49,7 +49,8 @@ defmodule Bumblebee.MixProject do
{:stb_image, "~> 0.6.0", only: :test},
{:bypass, "~> 2.1", only: :test},
{:ex_doc, "~> 0.28", only: :dev, runtime: false},
{:nx_signal, "~> 0.2.0"}
{:nx_signal, "~> 0.2.0"},
{:ebnf, github: "seanmor5/ebnf"}
]
end

1 change: 1 addition & 0 deletions mix.lock
@@ -9,6 +9,7 @@
"cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"},
"decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"ebnf": {:git, "https://github.com/seanmor5/ebnf.git", "a69e84619881b27fa8eceff29713ff9b496814cd", []},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"ex_doc": {:hex, :ex_doc, "0.31.0", "06eb1dfd787445d9cab9a45088405593dd3bb7fe99e097eaa71f37ba80c7a676", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5350cafa6b7f77bdd107aa2199fe277acf29d739aba5aee7e865fc680c62a110"},
"exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"},