From 08a3e4734fede4846329a8b8db1296ac70f7a4d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Thu, 30 May 2024 16:16:52 +0700 Subject: [PATCH] Adjust scaling Adjust scaling strategies --- lib/bumblebee/layers.ex | 91 ++++++++++++++++++++----------- lib/bumblebee/text/phi3.ex | 34 ++++-------- test/bumblebee/text/phi3_test.exs | 52 ++++++++++++++++++ 3 files changed, 122 insertions(+), 55 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 782ad6b7..0c87b979 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1230,39 +1230,66 @@ defmodule Bumblebee.Layers do ) do position = Nx.iota({sequence_length}) - {base, position} = - case scaling_strategy do - %{type: :linear, factor: factor} -> - {base, Nx.divide(position, factor)} - - %{type: :dynamic, factor: factor} when sequence_length > max_positions -> - base = - base - |> Nx.multiply(factor * sequence_length / max_positions - (factor - 1)) - |> Nx.pow(size / (size - 2)) - - {base, position} - - %{type: :su, short_factor: sf, long_factor: lf, original_max_positions: omp} -> - scaling_factor = - if sequence_length > omp do - Nx.tensor(lf, type: :f32) - else - Nx.tensor(sf, type: :f32) - end - - # Define how you want to use scaling_factor for base and position - scaled_base = Nx.multiply(base, scaling_factor) - # scaled_position = Nx.divide(position, scaling_factor) - - {scaled_base, position} - - _other -> - {base, position} - end - range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size) - inv_frequency = Nx.divide(1.0, Nx.pow(base, range)) + + case scaling_strategy do + %{type: :linear, factor: factor} -> + frequency = Nx.pow(base, range) + position = Nx.divide(position, factor) + positions_cos_sin(position, frequency) + + %{type: :dynamic, factor: factor} when sequence_length > max_positions -> + base = + base + |> Nx.multiply(factor * sequence_length / max_positions - (factor - 1)) + |> Nx.pow(size / (size - 2)) + + frequency = Nx.pow(base, range) + positions_cos_sin(position, frequency) + + %{ + type: type, + short_factor: short_factor, + long_factor: long_factor, + original_max_positions: original_max_positions + } + when type in [:su, :yarn] -> + factor = + if sequence_length > original_max_positions do + Nx.tensor(long_factor, type: :f32) + else + Nx.tensor(short_factor, type: :f32) + end + + scale = max_positions / original_max_positions + + cos_sin_factor = + cond do + scale <= 1.0 -> + 1.0 + + type == :su -> + Nx.divide(Nx.log(scale), Nx.log(original_max_positions)) + |> Nx.add(1) + |> Nx.sqrt() + + type == :yarn -> + Nx.multiply(0.1, Nx.log(scale)) + |> Nx.add(1.0) + end + + frequency = Nx.multiply(factor, Nx.pow(base, range)) + {cos, sin} = positions_cos_sin(position, frequency) + {Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)} + + _other -> + frequency = Nx.pow(base, range) + positions_cos_sin(position, frequency) + end + end + + defnp positions_cos_sin(position, frequency) do + inv_frequency = 1.0 / frequency angle = Nx.outer(position, inv_frequency) angle = Nx.concatenate([angle, angle], axis: -1) diff --git a/lib/bumblebee/text/phi3.ex b/lib/bumblebee/text/phi3.ex index cac2aceb..025ffd51 100644 --- a/lib/bumblebee/text/phi3.ex +++ b/lib/bumblebee/text/phi3.ex @@ -18,12 +18,6 @@ defmodule Bumblebee.Text.Phi3 do such as 512, 1024 or 2048 """ ], - original_max_positions: [ - default: 4096, - doc: """ - the original vocabulary size of the position embedding. - """ - ], hidden_size: [ default: 2048, doc: "the dimensionality of hidden layers" @@ -66,13 +60,10 @@ defmodule Bumblebee.Text.Phi3 do doc: """ scaling configuration for rotary embedding. Currently the supported values are: - * `%{type: :linear, factor: number()}` - - * `%{type: :dynamic, factor: number()}` + * `%{type: :su, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}` - * `%{type: :su, short_factor: [number], long_factor: [number]}` + * `%{type: :yarn, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}` - For more details see https://www.reddit.com/r/LocalLlama/comments/14mrgpr/dynamically_scaled_rope_further_increases """ ], layer_norm_epsilon: [ @@ -433,20 +424,18 @@ defmodule Bumblebee.Text.Phi3 do import Shared.Converters scaling_strategy_converter = fn name, value -> - case value do - %{"type" => "linear", "factor" => factor} when is_number(factor) -> - {:ok, %{type: :linear, factor: factor}} - - %{"type" => "dynamic", "factor" => factor} when is_number(factor) -> - {:ok, %{type: :dynamic, factor: factor}} + original_max_positions = data["original_max_position_embeddings"] - %{"type" => "su", "long_factor" => lf, "short_factor" => sf} -> + case value do + %{"type" => type, "long_factor" => long_factor, "short_factor" => short_factor} + when type in ["su", "yarn"] and is_list(long_factor) and is_list(short_factor) and + is_number(original_max_positions) -> {:ok, %{ - type: :su, - long_factor: lf, - short_factor: sf, - original_max_positions: spec.original_max_positions + type: String.to_atom(type), + long_factor: long_factor, + short_factor: short_factor, + original_max_positions: original_max_positions }} _other -> @@ -458,7 +447,6 @@ defmodule Bumblebee.Text.Phi3 do convert!(data, vocab_size: {"vocab_size", number()}, max_positions: {"max_position_embeddings", number()}, - original_max_positions: {"original_max_position_embeddings", number()}, hidden_size: {"hidden_size", number()}, num_blocks: {"num_hidden_layers", number()}, num_attention_heads: {"num_attention_heads", number()}, diff --git a/test/bumblebee/text/phi3_test.exs b/test/bumblebee/text/phi3_test.exs index 5c79ae1f..6ac45a9e 100644 --- a/test/bumblebee/text/phi3_test.exs +++ b/test/bumblebee/text/phi3_test.exs @@ -28,6 +28,58 @@ defmodule Bumblebee.Text.Phi3Test do ) end + test ":base rotary embedding scaling strategy :su" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, + "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-su-original_max_position_embeddings-256"} + ) + + assert %Bumblebee.Text.Phi3{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.4528, 0.5995, 0.1573], [-0.2664, 1.9339, 0.5336], [1.1053, -0.1643, 0.5989]] + ]) + ) + end + + test ":base rotary embedding scaling strategy :yarn" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, + "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-yarn-original_max_position_embeddings-256"} + ) + + assert %Bumblebee.Text.Phi3{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.4530, 0.5995, 0.1574], [-0.2663, 1.9339, 0.5336], [1.1052, -0.1642, 0.5989]] + ]) + ) + end + test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(