Skip to content

Commit

Permalink
Adjust scaling
Browse files Browse the repository at this point in the history
Adjust scaling strategies
  • Loading branch information
jonatanklosko committed May 30, 2024
1 parent f89f96a commit 08a3e47
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 55 deletions.
91 changes: 59 additions & 32 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1230,39 +1230,66 @@ defmodule Bumblebee.Layers do
) do
position = Nx.iota({sequence_length})

{base, position} =
case scaling_strategy do
%{type: :linear, factor: factor} ->
{base, Nx.divide(position, factor)}

%{type: :dynamic, factor: factor} when sequence_length > max_positions ->
base =
base
|> Nx.multiply(factor * sequence_length / max_positions - (factor - 1))
|> Nx.pow(size / (size - 2))

{base, position}

%{type: :su, short_factor: sf, long_factor: lf, original_max_positions: omp} ->
scaling_factor =
if sequence_length > omp do
Nx.tensor(lf, type: :f32)
else
Nx.tensor(sf, type: :f32)
end

# Define how you want to use scaling_factor for base and position
scaled_base = Nx.multiply(base, scaling_factor)
# scaled_position = Nx.divide(position, scaling_factor)

{scaled_base, position}

_other ->
{base, position}
end

range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size)
inv_frequency = Nx.divide(1.0, Nx.pow(base, range))

case scaling_strategy do
%{type: :linear, factor: factor} ->
frequency = Nx.pow(base, range)
position = Nx.divide(position, factor)
positions_cos_sin(position, frequency)

%{type: :dynamic, factor: factor} when sequence_length > max_positions ->
base =
base
|> Nx.multiply(factor * sequence_length / max_positions - (factor - 1))
|> Nx.pow(size / (size - 2))

frequency = Nx.pow(base, range)
positions_cos_sin(position, frequency)

%{
type: type,
short_factor: short_factor,
long_factor: long_factor,
original_max_positions: original_max_positions
}
when type in [:su, :yarn] ->
factor =
if sequence_length > original_max_positions do
Nx.tensor(long_factor, type: :f32)
else
Nx.tensor(short_factor, type: :f32)
end

scale = max_positions / original_max_positions

cos_sin_factor =
cond do
scale <= 1.0 ->
1.0

type == :su ->
Nx.divide(Nx.log(scale), Nx.log(original_max_positions))
|> Nx.add(1)
|> Nx.sqrt()

type == :yarn ->
Nx.multiply(0.1, Nx.log(scale))
|> Nx.add(1.0)
end

frequency = Nx.multiply(factor, Nx.pow(base, range))
{cos, sin} = positions_cos_sin(position, frequency)
{Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)}

_other ->
frequency = Nx.pow(base, range)
positions_cos_sin(position, frequency)
end
end

defnp positions_cos_sin(position, frequency) do
inv_frequency = 1.0 / frequency
angle = Nx.outer(position, inv_frequency)

angle = Nx.concatenate([angle, angle], axis: -1)
Expand Down
34 changes: 11 additions & 23 deletions lib/bumblebee/text/phi3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ defmodule Bumblebee.Text.Phi3 do
such as 512, 1024 or 2048
"""
],
original_max_positions: [
default: 4096,
doc: """
the original vocabulary size of the position embedding.
"""
],
hidden_size: [
default: 2048,
doc: "the dimensionality of hidden layers"
Expand Down Expand Up @@ -66,13 +60,10 @@ defmodule Bumblebee.Text.Phi3 do
doc: """
scaling configuration for rotary embedding. Currently the supported values are:
* `%{type: :linear, factor: number()}`
* `%{type: :dynamic, factor: number()}`
* `%{type: :su, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}`
* `%{type: :su, short_factor: [number], long_factor: [number]}`
* `%{type: :yarn, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}`
For more details see https://www.reddit.com/r/LocalLlama/comments/14mrgpr/dynamically_scaled_rope_further_increases
"""
],
layer_norm_epsilon: [
Expand Down Expand Up @@ -433,20 +424,18 @@ defmodule Bumblebee.Text.Phi3 do
import Shared.Converters

scaling_strategy_converter = fn name, value ->
case value do
%{"type" => "linear", "factor" => factor} when is_number(factor) ->
{:ok, %{type: :linear, factor: factor}}

%{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
{:ok, %{type: :dynamic, factor: factor}}
original_max_positions = data["original_max_position_embeddings"]

%{"type" => "su", "long_factor" => lf, "short_factor" => sf} ->
case value do
%{"type" => type, "long_factor" => long_factor, "short_factor" => short_factor}
when type in ["su", "yarn"] and is_list(long_factor) and is_list(short_factor) and
is_number(original_max_positions) ->
{:ok,
%{
type: :su,
long_factor: lf,
short_factor: sf,
original_max_positions: spec.original_max_positions
type: String.to_atom(type),
long_factor: long_factor,
short_factor: short_factor,
original_max_positions: original_max_positions
}}

_other ->
Expand All @@ -458,7 +447,6 @@ defmodule Bumblebee.Text.Phi3 do
convert!(data,
vocab_size: {"vocab_size", number()},
max_positions: {"max_position_embeddings", number()},
original_max_positions: {"original_max_position_embeddings", number()},
hidden_size: {"hidden_size", number()},
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
Expand Down
52 changes: 52 additions & 0 deletions test/bumblebee/text/phi3_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,58 @@ defmodule Bumblebee.Text.Phi3Test do
)
end

test ":base rotary embedding scaling strategy :su" do
  # Checkpoint is selected by name; the name indicates the :su rope scaling
  # variant (presumably with original_max_position_embeddings of 256 - confirm
  # against the repository config).
  repository =
    {:hf,
     "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-su-original_max_position_embeddings-256"}

  assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repository)

  assert %Bumblebee.Text.Phi3{architecture: :base} = spec

  # One sequence of ten token ids; the last two ids are 0 and are masked out
  # by the matching zeros in the attention mask.
  input_ids = Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
  attention_mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])

  outputs =
    Axon.predict(model, params, %{
      "input_ids" => input_ids,
      "attention_mask" => attention_mask
    })

  hidden_state = outputs.hidden_state

  assert Nx.shape(hidden_state) == {1, 10, 32}

  # Compare a 3x3 slice (indices 1..3 on the last two axes) against the
  # expected values for this checkpoint.
  expected =
    Nx.tensor([
      [[-1.4528, 0.5995, 0.1573], [-0.2664, 1.9339, 0.5336], [1.1053, -0.1643, 0.5989]]
    ])

  assert_all_close(hidden_state[[.., 1..3, 1..3]], expected)
end

test ":base rotary embedding scaling strategy :yarn" do
  # Checkpoint is selected by name; the name indicates the :yarn rope scaling
  # variant (presumably with original_max_position_embeddings of 256 - confirm
  # against the repository config).
  repository =
    {:hf,
     "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-yarn-original_max_position_embeddings-256"}

  assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repository)

  assert %Bumblebee.Text.Phi3{architecture: :base} = spec

  # One sequence of ten token ids; the last two ids are 0 and are masked out
  # by the matching zeros in the attention mask.
  input_ids = Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
  attention_mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])

  outputs =
    Axon.predict(model, params, %{
      "input_ids" => input_ids,
      "attention_mask" => attention_mask
    })

  hidden_state = outputs.hidden_state

  assert Nx.shape(hidden_state) == {1, 10, 32}

  # Compare a 3x3 slice (indices 1..3 on the last two axes) against the
  # expected values for this checkpoint.
  expected =
    Nx.tensor([
      [[-1.4530, 0.5995, 0.1574], [-0.2663, 1.9339, 0.5336], [1.1052, -0.1642, 0.5989]]
    ])

  assert_all_close(hidden_state[[.., 1..3, 1..3]], expected)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
Expand Down

0 comments on commit 08a3e47

Please sign in to comment.