Commit 7bf6e93: Reorganize conversion modules

jonatanklosko committed on Feb 26, 2024
1 parent 9d84d45, commit 7bf6e93

This commit flattens the PyTorch conversion namespace: Bumblebee.Conversion.PyTorch becomes Bumblebee.Conversion.PyTorchParams, Bumblebee.Conversion.PyTorch.Loader becomes Bumblebee.Conversion.PyTorchLoader, and Bumblebee.Conversion.PyTorch.FileTensor moves under the loader as Bumblebee.Conversion.PyTorchLoader.FileTensor. Call sites, docs, and tests are updated to match.

Showing 6 changed files with 36 additions and 28 deletions.
lib/bumblebee.ex (4 changes: 2 additions & 2 deletions)

@@ -608,7 +608,7 @@ defmodule Bumblebee do
         loader_fun: loader_fun
       ] ++ Keyword.take(opts, [:backend, :log_params_diff])

-      params = Bumblebee.Conversion.PyTorch.load_params!(model, input_template, paths, opts)
+      params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts)
       {:ok, params}
     end
   end

@@ -709,7 +709,7 @@ defmodule Bumblebee do
   end

   defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true)
-  defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorch.Loader.load!/1
+  defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorchLoader.load!/1

   @doc """
   Featurizes `input` with the given featurizer.
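For orientation, a hedged sketch (not code from this commit) of what the dispatch above resolves to: safetensors files go through Safetensors, everything else through the renamed PyTorchLoader. The file name is hypothetical.

    # Hedged sketch: reproduce the extension dispatch shown above and load a
    # (hypothetical) params file with the resulting loader function.
    path = "pytorch_model.bin"

    loader_fun =
      case Path.extname(path) do
        ".safetensors" -> &Safetensors.read!(&1, lazy: true)
        _ -> &Bumblebee.Conversion.PyTorchLoader.load!/1
      end

    state = loader_fun.(path)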
Bumblebee.Conversion.PyTorch.Loader → Bumblebee.Conversion.PyTorchLoader

@@ -1,4 +1,4 @@
-defmodule Bumblebee.Conversion.PyTorch.Loader do
+defmodule Bumblebee.Conversion.PyTorchLoader do
   @moduledoc false

   @doc """

@@ -84,7 +84,7 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
     {:storage, storage_type, storage} = storage
     type = storage_type_to_nx(storage_type)

-    lazy_tensor = %Bumblebee.Conversion.PyTorch.FileTensor{
+    lazy_tensor = %Bumblebee.Conversion.PyTorchLoader.FileTensor{
       shape: shape,
       type: type,
       offset: offset,
Bumblebee.Conversion.PyTorch.FileTensor → Bumblebee.Conversion.PyTorchLoader.FileTensor

@@ -1,11 +1,11 @@
-defmodule Bumblebee.Conversion.PyTorch.FileTensor do
+defmodule Bumblebee.Conversion.PyTorchLoader.FileTensor do
   @moduledoc false

   defstruct [:shape, :type, :offset, :strides, :storage]
 end

-defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorch.FileTensor do
-  alias Bumblebee.Conversion.PyTorch.Loader
+defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorchLoader.FileTensor do
+  alias Bumblebee.Conversion.PyTorchLoader

   def traverse(lazy_tensor, acc, fun) do
     template = Nx.template(lazy_tensor.shape, lazy_tensor.type)

@@ -14,8 +14,8 @@ defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorch.FileTensor do
     binary =
       case lazy_tensor.storage do
         {:zip, path, file_name} ->
-          Loader.open_zip!(path, fn unzip ->
-            Loader.read_zip_file(unzip, file_name)
+          PyTorchLoader.open_zip!(path, fn unzip ->
+            PyTorchLoader.read_zip_file(unzip, file_name)
           end)

         {:file, path, offset, size} ->
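Since FileTensor implements Nx.LazyContainer, tensors returned by load!/1 stay on disk until realized. A minimal consumption sketch, mirroring the tests further down; the path is hypothetical.

    # Hedged sketch: load!/1 yields lazy FileTensor structs backed by zip or
    # raw-file storage; Nx.to_tensor/1 forces the actual read.
    "tensors.zip.pt"
    |> Bumblebee.Conversion.PyTorchLoader.load!()
    |> Enum.map(&Nx.to_tensor/1)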
Bumblebee.Conversion.PyTorch → Bumblebee.Conversion.PyTorchParams

@@ -1,4 +1,4 @@
-defmodule Bumblebee.Conversion.PyTorch do
+defmodule Bumblebee.Conversion.PyTorchParams do
   @moduledoc false

   require Logger

@@ -26,7 +26,7 @@ defmodule Bumblebee.Conversion.PyTorch do
     * `:loader_fun` - a 1-arity function that takes a path argument
       and loads the params file. Defaults to
-      `Bumblebee.Conversion.PyTorch.Loader.load!/1`
+      `Bumblebee.Conversion.PyTorchLoader.load!/1`

   """
   @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()

@@ -37,7 +37,7 @@ defmodule Bumblebee.Conversion.PyTorch do
       :log_params_diff,
       :backend,
       params_mapping: %{},
-      loader_fun: &Bumblebee.Conversion.PyTorch.Loader.load!/1
+      loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1
     ])

   with_default_backend(opts[:backend], fn ->
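A hedged usage sketch of the renamed entry point, assuming model, input_template, and a params_mapping/0 helper are in scope (as in the tests below); the path is hypothetical.

    # Hedged sketch: load PyTorch params into an Axon model, spelling out the
    # :loader_fun option, whose documented default is
    # &Bumblebee.Conversion.PyTorchLoader.load!/1.
    params =
      Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, "model.pt",
        params_mapping: params_mapping(),
        loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1
      )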
Bumblebee.Conversion.PyTorch.LoaderTest → Bumblebee.Conversion.PyTorchLoaderTest
(The fixture path drops one "../", consistent with the test file moving up one directory in the rename.)

@@ -1,14 +1,14 @@
-defmodule Bumblebee.Conversion.PyTorch.LoaderTest do
+defmodule Bumblebee.Conversion.PyTorchLoaderTest do
   use ExUnit.Case, async: true

-  alias Bumblebee.Conversion.PyTorch.Loader
+  alias Bumblebee.Conversion.PyTorchLoader

   setup do
     Nx.default_backend(Nx.BinaryBackend)
     :ok
   end

-  @dir Path.expand("../../../fixtures/pytorch", __DIR__)
+  @dir Path.expand("../../fixtures/pytorch", __DIR__)

   for format <- ["zip", "legacy"] do
     @format format

@@ -17,7 +17,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do
     test "tensors" do
       path = Path.join(@dir, "tensors.#{@format}.pt")

-      assert path |> Loader.load!() |> Enum.map(&Nx.to_tensor/1) == [
+      assert path |> PyTorchLoader.load!() |> Enum.map(&Nx.to_tensor/1) == [
               Nx.tensor([-1.0, 1.0], type: :f64),
               Nx.tensor([-1.0, 1.0], type: :f32),
               Nx.tensor([-1.0, 1.0], type: :f16),

@@ -36,7 +36,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do
     test "numpy arrays" do
       path = Path.join(@dir, "numpy_arrays.#{@format}.pt")

-      assert Loader.load!(path) == [
+      assert PyTorchLoader.load!(path) == [
               Nx.tensor([-1.0, 1.0], type: :f64),
               Nx.tensor([-1.0, 1.0], type: :f32),
               Nx.tensor([-1.0, 1.0], type: :f16),

@@ -57,20 +57,20 @@
     test "ordered dict" do
       path = Path.join(@dir, "ordered_dict.#{@format}.pt")

-      assert Loader.load!(path) == %{"x" => 1, "y" => 2}
+      assert PyTorchLoader.load!(path) == %{"x" => 1, "y" => 2}
     end

     test "noncontiguous tensor" do
       path = Path.join(@dir, "noncontiguous_tensor.#{@format}.pt")

-      assert path |> Loader.load!() |> Nx.to_tensor() ==
+      assert path |> PyTorchLoader.load!() |> Nx.to_tensor() ==
               Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], type: :s64)
     end

     test "numpy array in Fortran order" do
       path = Path.join(@dir, "noncontiguous_numpy_array.#{@format}.pt")

-      assert Loader.load!(path) ==
+      assert PyTorchLoader.load!(path) ==
               Nx.tensor([[1, 4], [2, 5], [3, 6]], type: :s64)
     end
   end

@@ -84,7 +84,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do
       assert {
                {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage1},
                {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage2}
-             } = Loader.load!(path)
+             } = PyTorchLoader.load!(path)

       assert {:file, path, offset, size} = storage1
       assert path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>>

@@ -95,7 +95,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do
   test "raises if the files does not exist" do
     assert_raise File.Error, ~r/no such file or directory/, fn ->
-      Loader.load!("nonexistent")
+      PyTorchLoader.load!("nonexistent")
     end
   end
 end
Bumblebee.Conversion.PyTorchTest → Bumblebee.Conversion.PyTorchParamsTest
(The calls below turn multi-line because the longer module name pushes them past the formatter's line limit, hence the 7→9 line counts in the hunk headers.)

@@ -1,9 +1,9 @@
-defmodule Bumblebee.Conversion.PyTorchTest do
+defmodule Bumblebee.Conversion.PyTorchParamsTest do
   use ExUnit.Case, async: true

   import Bumblebee.TestHelpers

-  alias Bumblebee.Conversion.PyTorch
+  alias Bumblebee.Conversion.PyTorchParams

   @dir Path.expand("../../fixtures/pytorch", __DIR__)

@@ -40,7 +40,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do
       log =
         ExUnit.CaptureLog.capture_log(fn ->
           params =
-            PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping())
+            PyTorchParams.load_params!(model, input_template(), path,
+              params_mapping: params_mapping()
+            )

           assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
           assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2}))

@@ -55,7 +57,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do
       log =
         ExUnit.CaptureLog.capture_log(fn ->
-          PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping())
+          PyTorchParams.load_params!(model, input_template(), path,
+            params_mapping: params_mapping()
+          )
         end)

       assert log =~ """

@@ -86,7 +90,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do
       log =
         ExUnit.CaptureLog.capture_log(fn ->
           params =
-            PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping())
+            PyTorchParams.load_params!(model, input_template(), path,
+              params_mapping: params_mapping()
+            )

           assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
           assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2}))

@@ -102,7 +108,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do
       log =
         ExUnit.CaptureLog.capture_log(fn ->
           params =
-            PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping())
+            PyTorchParams.load_params!(model, input_template(), path,
+              params_mapping: params_mapping()
+            )

           assert_equal(params["base.conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
           assert_equal(params["base.conv"]["bias"], Nx.broadcast(0.0, {2}))
