From 715aac5f9977a9caf03e4e4f264e7d2018d48340 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Wed, 28 Feb 2024 18:46:58 -0500
Subject: [PATCH] WIP

---
 lib/bumblebee/text/generation.ex              |  3 +
 .../text/generation/grammar_constraint.ex     | 95 ++++++++++++++++++++
 .../text/generation/logits_processing.ex      | 17 +++
 lib/bumblebee/text/generation/stack.ex        | 68 ++++++++++++
 lib/bumblebee/text/generation/token_trie.ex   | 61 +++++++++++
 mix.exs                                       |  3 +-
 mix.lock                                      |  1 +
 7 files changed, 247 insertions(+), 1 deletion(-)
 create mode 100644 lib/bumblebee/text/generation/grammar_constraint.ex
 create mode 100644 lib/bumblebee/text/generation/stack.ex
 create mode 100644 lib/bumblebee/text/generation/token_trie.ex

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 17bfcf8b..15c3571d 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -361,6 +361,9 @@ defmodule Bumblebee.Text.Generation do
       end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
+      end,
+      if config.grammar do
+        &grammar_constrained_processor(&1, &2, grammar: config.grammar)
       end
     ] ++
       if config.strategy.type == :multinomial_sampling do
diff --git a/lib/bumblebee/text/generation/grammar_constraint.ex b/lib/bumblebee/text/generation/grammar_constraint.ex
new file mode 100644
index 00000000..87361c24
--- /dev/null
+++ b/lib/bumblebee/text/generation/grammar_constraint.ex
@@ -0,0 +1,95 @@
+defmodule Bumblebee.Text.Generation.GrammarConstraint do
+  @moduledoc false
+
+  alias Bumblebee.Text.Generation.TokenTrie
+  alias Bumblebee.Text.Generation.Stack
+  alias EBNF.ParseState
+
+  alias __MODULE__
+
+  import Nx.Defn
+
+  # Models a grammar constraint over token generation
+
+  defstruct [
+    :token_trie,
+    :grammar_encoding,
+    :tokenizer,
+    :start_rule_id,
+    :start_rule_position,
+    :rule_positions
+  ]
+
+  def create(grammar, root, tokenizer) do
+    %ParseState{symbol_ids: symbols, grammar_encoding: encoding} = EBNF.encode(grammar)
+    trie = TokenTrie.create(tokenizer)
+    start_rule_id = Map.fetch!(symbols, root)
+    rule_positions = get_rule_positions(encoding)
+
+    %GrammarConstraint{
+      token_trie: trie,
+      grammar_encoding: encoding,
+      tokenizer: tokenizer,
+      start_rule_id: start_rule_id,
+      start_rule_position: Map.fetch!(rule_positions, start_rule_id),
+      rule_positions: rule_positions
+    }
+  end
+
+  def init_stacks(constraint) do
+    # the stack will never exceed the grammar encoding size
+    Stack.new(length(constraint.grammar_encoding))
+    |> Stack.push(constraint.start_rule_position + 2)
+    |> advance_stack()
+  end
+
+  defn advance_stack(stack) do
+    if Nx.equal(Stack.length(stack), 0) do
+      stack
+    else
+      top = Stack.peek(stack)
+
+      if Nx.equal(top, 2) do
+        stack
+      else
+        # TODO: expand the rule reference at the top of the stack
+        stack
+      end
+    end
+  end
+
+  defp get_rule_positions(grammar_encoding) do
+    recur_get_rule_positions(grammar_encoding, 0, %{})
+  end
+
+  defp recur_get_rule_positions([0xFFFF], _pos, rule_positions), do: rule_positions
+
+  defp recur_get_rule_positions([rule_id | rest], pos, rule_positions) do
+    rule_positions = Map.put(rule_positions, rule_id, pos)
+
+    case find_next_rule(rest, pos + 1) do
+      {[_ | leftover], pos} ->
+        recur_get_rule_positions(leftover, pos + 1, rule_positions)
+
+      {[], _} ->
+        rule_positions
+    end
+  end
+
+  defp find_next_rule([0 | rest], pos) do
+    {rest, pos + 1}
+  end
+
+  defp find_next_rule([rule_size | _] = leftover, pos) do
+    leftover = Enum.drop(leftover, rule_size + 1)
+    pos = pos + rule_size + 1
+
+    case leftover do
+      [0 | _] ->
+        {leftover, pos}
+
+      leftover ->
+        find_next_rule(leftover, pos)
+    end
+  end
+end
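
Note on the buffer walked by get_rule_positions/1: the scan implies a llama.cpp-style
layout in which each rule appears as its rule id, followed by one or more alternatives
(an element count and then that many element codes), followed by a terminating 0, with
0xFFFF closing the whole buffer. Below is a minimal standalone sketch of the same walk
over a hypothetical two-rule encoding (rule ids 0 and 1; the element codes 5, 5, and 6
are made up for illustration; real buffers come from EBNF.encode/1). In this layout,
init_stacks/1 pushing start_rule_position + 2 lands exactly on the first element of the
start rule's first alternative.

    # Standalone sketch mirroring the private helpers above.
    defmodule RulePositionsSketch do
      # Returns %{rule_id => index_of_rule_id_in_the_encoding}
      def scan([0xFFFF], _pos, acc), do: acc

      def scan([rule_id | rest], pos, acc) do
        acc = Map.put(acc, rule_id, pos)
        {rest, pos} = skip_rule(rest, pos + 1)
        scan(rest, pos, acc)
      end

      # Skip alternatives (a size followed by that many elements) until the
      # 0 closing the rule; return the list starting at the next rule id.
      defp skip_rule([0 | rest], pos), do: {rest, pos + 1}

      defp skip_rule([alt_size | _] = list, pos) do
        skip_rule(Enum.drop(list, alt_size + 1), pos + alt_size + 1)
      end
    end

    RulePositionsSketch.scan([0, 2, 5, 5, 0, 1, 1, 6, 0, 0xFFFF], 0, %{})
    #=> %{0 => 0, 1 => 5}
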
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index bf8a924c..a97f6598 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
 
   import Nx.Defn
 
+  alias Bumblebee.Text.Generation.GrammarConstraint
+
   deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
     opts = Keyword.validate!(opts, [:suppressed_token_ids])
 
@@ -255,4 +257,19 @@
       {idx, _token_id} -> idx + 1
     end
   end
+
+  deftransform grammar_constrained_processor(logits, input_ids, opts \\ []) do
+    opts = Keyword.validate!(opts, [:grammar, :tokenizer])
+
+    grammar = opts[:grammar]
+    tokenizer = opts[:tokenizer]
+
+    constraint = GrammarConstraint.create(grammar, "root", tokenizer)
+    batch_stacks = GrammarConstraint.init_stacks(constraint)
+
+    # TODO: use the stacks and token trie to mask out tokens that are
+    # invalid under the grammar. For now the logits pass through untouched.
+    _ = {input_ids, batch_stacks}
+    logits
+  end
 end
diff --git a/lib/bumblebee/text/generation/stack.ex b/lib/bumblebee/text/generation/stack.ex
new file mode 100644
index 00000000..c7669f15
--- /dev/null
+++ b/lib/bumblebee/text/generation/stack.ex
@@ -0,0 +1,68 @@
+defmodule Bumblebee.Text.Generation.Stack do
+  @moduledoc false
+
+  # A "stack" like data structure represented as an Nx container
+  # to make constrained sampling possible/easier. The HF implementation
+  # uses a dynamic stack, but we need all shapes up front and cannot
+  # resize tensors, so we use a fixed-size buffer plus a pointer into it.
+
+  alias __MODULE__
+
+  @derive {Nx.Container, containers: [:data, :pointer]}
+  defstruct [:data, :pointer]
+
+  import Nx.Defn
+  import Kernel, except: [length: 1]
+
+  @empty_value -1
+
+  @doc """
+  Initializes a new stack.
+  """
+  def new(size, opts \\ []) do
+    opts = Keyword.validate!(opts, type: :s64)
+
+    %Stack{
+      data: Nx.broadcast(Nx.tensor(@empty_value, type: opts[:type]), {size}),
+      pointer: Nx.tensor(0)
+    }
+  end
+
+  @doc """
+  Pushes a value to the top of the stack.
+  """
+  deftransform push(%Stack{data: data, pointer: pointer} = stack, value) do
+    value = Nx.to_tensor(value)
+
+    unless Nx.rank(value) == 0, do: raise("can only push scalar values to stack")
+
+    %{
+      stack
+      | data: Nx.put_slice(data, [pointer], Nx.reshape(value, {1})),
+        pointer: Nx.add(pointer, 1)
+    }
+  end
+
+  @doc """
+  Pops the value at the top of the stack.
+  """
+  defn pop(%Stack{data: data, pointer: pointer} = stack) do
+    # the pointer names the next free slot, so the top lives at pointer - 1
+    value = data[[Nx.subtract(pointer, 1)]]
+    {value, %{stack | pointer: Nx.subtract(pointer, 1)}}
+  end
+
+  @doc """
+  Peeks at the value on top of the stack.
+  """
+  defn peek(%Stack{data: data, pointer: pointer}) do
+    data[[Nx.subtract(pointer, 1)]]
+  end
+
+  @doc """
+  Returns the length of the stack.
+  """
+  defn length(%Stack{pointer: pointer}) do
+    pointer
+  end
+end
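
Note on the stack semantics: new/1 starts the pointer at 0 and push/2 writes then
increments, so the pointer always names the next free slot and the live top sits at
pointer - 1, which is what peek/1 and pop/1 read. Illustrative usage, assuming the
module as defined above (not part of the patch):

    alias Bumblebee.Text.Generation.Stack

    stack = Stack.new(8)
    stack = Stack.push(stack, 42)
    stack = Stack.push(stack, 7)

    Stack.peek(stack) |> Nx.to_number()
    #=> 7

    {value, stack} = Stack.pop(stack)
    Nx.to_number(value)
    #=> 7

    Stack.length(stack) |> Nx.to_number()
    #=> 1
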
+ """ + def create(%PreTrainedTokenizer{native_tokenizer: tokenizer, special_tokens: %{eos: eos_token}}) do + vocab = Tokenizers.Tokenizer.get_vocab(tokenizer) + eos_token_id = Map.fetch!(vocab, eos_token) + + tokens = + Map.new(vocab, fn {token, id} -> + # TODO: Special cases for GPT2 and Llama + {id, String.to_charlist(token)} + end) + + trie = + Enum.reduce(tokens, %{}, fn {token_id, token_bytes}, acc -> + insert_into_trie(acc, token_bytes, token_id) + end) + + %TokenTrie{tokens: tokens, trie: trie, eos_token_id: eos_token_id} + end + + ## Helpers + + defp insert_into_trie(trie, token_bytes, token_id) do + do_insert_into_trie(trie, token_bytes, token_id) + end + + defp do_insert_into_trie(trie, [], token_id), do: Map.put(trie, @leaf, token_id) + + defp do_insert_into_trie(trie, [byte | rest_bytes], token_id) do + current = Map.get(trie, byte, %{}) + updated = do_insert_into_trie(current, rest_bytes, token_id) + Map.put(trie, byte, updated) + end +end diff --git a/mix.exs b/mix.exs index 07ab0ba5..b07d0628 100644 --- a/mix.exs +++ b/mix.exs @@ -49,7 +49,8 @@ defmodule Bumblebee.MixProject do {:stb_image, "~> 0.6.0", only: :test}, {:bypass, "~> 2.1", only: :test}, {:ex_doc, "~> 0.28", only: :dev, runtime: false}, - {:nx_signal, "~> 0.2.0"} + {:nx_signal, "~> 0.2.0"}, + {:ebnf, github: "seanmor5/ebnf"} ] end diff --git a/mix.lock b/mix.lock index 314aded5..f08a24c4 100644 --- a/mix.lock +++ b/mix.lock @@ -9,6 +9,7 @@ "cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"}, "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, + "ebnf": {:git, "https://github.com/seanmor5/ebnf.git", "a69e84619881b27fa8eceff29713ff9b496814cd", []}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.31.0", "06eb1dfd787445d9cab9a45088405593dd3bb7fe99e097eaa71f37ba80c7a676", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5350cafa6b7f77bdd107aa2199fe277acf29d739aba5aee7e865fc680c62a110"}, "exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"},