diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 18209a94..52e32443 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -608,7 +608,7 @@ defmodule Bumblebee do loader_fun: loader_fun ] ++ Keyword.take(opts, [:backend, :log_params_diff]) - params = Bumblebee.Conversion.PyTorch.load_params!(model, input_template, paths, opts) + params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts) {:ok, params} end end @@ -709,7 +709,7 @@ defmodule Bumblebee do end defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true) - defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorch.Loader.load!/1 + defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorchLoader.load!/1 @doc """ Featurizes `input` with the given featurizer. diff --git a/lib/bumblebee/conversion/pytorch/loader.ex b/lib/bumblebee/conversion/pytorch_loader.ex similarity index 98% rename from lib/bumblebee/conversion/pytorch/loader.ex rename to lib/bumblebee/conversion/pytorch_loader.ex index 6dc08c2e..97ebb2d9 100644 --- a/lib/bumblebee/conversion/pytorch/loader.ex +++ b/lib/bumblebee/conversion/pytorch_loader.ex @@ -1,4 +1,4 @@ -defmodule Bumblebee.Conversion.PyTorch.Loader do +defmodule Bumblebee.Conversion.PyTorchLoader do @moduledoc false @doc """ @@ -84,7 +84,7 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do {:storage, storage_type, storage} = storage type = storage_type_to_nx(storage_type) - lazy_tensor = %Bumblebee.Conversion.PyTorch.FileTensor{ + lazy_tensor = %Bumblebee.Conversion.PyTorchLoader.FileTensor{ shape: shape, type: type, offset: offset, diff --git a/lib/bumblebee/conversion/pytorch/file_tensor.ex b/lib/bumblebee/conversion/pytorch_loader/file_tensor.ex similarity index 87% rename from lib/bumblebee/conversion/pytorch/file_tensor.ex rename to lib/bumblebee/conversion/pytorch_loader/file_tensor.ex index 30f8a8d1..1ae84008 100644 --- a/lib/bumblebee/conversion/pytorch/file_tensor.ex +++ b/lib/bumblebee/conversion/pytorch_loader/file_tensor.ex @@ -1,11 +1,11 @@ -defmodule Bumblebee.Conversion.PyTorch.FileTensor do +defmodule Bumblebee.Conversion.PyTorchLoader.FileTensor do @moduledoc false defstruct [:shape, :type, :offset, :strides, :storage] end -defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorch.FileTensor do - alias Bumblebee.Conversion.PyTorch.Loader +defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorchLoader.FileTensor do + alias Bumblebee.Conversion.PyTorchLoader def traverse(lazy_tensor, acc, fun) do template = Nx.template(lazy_tensor.shape, lazy_tensor.type) @@ -14,8 +14,8 @@ defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorch.FileTensor do binary = case lazy_tensor.storage do {:zip, path, file_name} -> - Loader.open_zip!(path, fn unzip -> - Loader.read_zip_file(unzip, file_name) + PyTorchLoader.open_zip!(path, fn unzip -> + PyTorchLoader.read_zip_file(unzip, file_name) end) {:file, path, offset, size} -> diff --git a/lib/bumblebee/conversion/pytorch.ex b/lib/bumblebee/conversion/pytorch_params.ex similarity index 99% rename from lib/bumblebee/conversion/pytorch.ex rename to lib/bumblebee/conversion/pytorch_params.ex index 7a0bd492..ff5ef4f7 100644 --- a/lib/bumblebee/conversion/pytorch.ex +++ b/lib/bumblebee/conversion/pytorch_params.ex @@ -1,4 +1,4 @@ -defmodule Bumblebee.Conversion.PyTorch do +defmodule Bumblebee.Conversion.PyTorchParams do @moduledoc false require Logger @@ -26,7 +26,7 @@ defmodule Bumblebee.Conversion.PyTorch do * `:loader_fun` - a 1-arity function that takes a path argument and loads the params file. Defaults to - `Bumblebee.Conversion.PyTorch.Loader.load!/1` + `Bumblebee.Conversion.PyTorchLoader.load!/1` """ @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map() @@ -37,7 +37,7 @@ defmodule Bumblebee.Conversion.PyTorch do :log_params_diff, :backend, params_mapping: %{}, - loader_fun: &Bumblebee.Conversion.PyTorch.Loader.load!/1 + loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1 ]) with_default_backend(opts[:backend], fn -> diff --git a/test/bumblebee/conversion/pytorch/loader_test.exs b/test/bumblebee/conversion/pytorch_loader_test.exs similarity index 85% rename from test/bumblebee/conversion/pytorch/loader_test.exs rename to test/bumblebee/conversion/pytorch_loader_test.exs index 1f63352e..3a5e672d 100644 --- a/test/bumblebee/conversion/pytorch/loader_test.exs +++ b/test/bumblebee/conversion/pytorch_loader_test.exs @@ -1,14 +1,14 @@ -defmodule Bumblebee.Conversion.PyTorch.LoaderTest do +defmodule Bumblebee.Conversion.PyTorchLoaderTest do use ExUnit.Case, async: true - alias Bumblebee.Conversion.PyTorch.Loader + alias Bumblebee.Conversion.PyTorchLoader setup do Nx.default_backend(Nx.BinaryBackend) :ok end - @dir Path.expand("../../../fixtures/pytorch", __DIR__) + @dir Path.expand("../../fixtures/pytorch", __DIR__) for format <- ["zip", "legacy"] do @format format @@ -17,7 +17,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do test "tensors" do path = Path.join(@dir, "tensors.#{@format}.pt") - assert path |> Loader.load!() |> Enum.map(&Nx.to_tensor/1) == [ + assert path |> PyTorchLoader.load!() |> Enum.map(&Nx.to_tensor/1) == [ Nx.tensor([-1.0, 1.0], type: :f64), Nx.tensor([-1.0, 1.0], type: :f32), Nx.tensor([-1.0, 1.0], type: :f16), @@ -36,7 +36,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do test "numpy arrays" do path = Path.join(@dir, "numpy_arrays.#{@format}.pt") - assert Loader.load!(path) == [ + assert PyTorchLoader.load!(path) == [ Nx.tensor([-1.0, 1.0], type: :f64), Nx.tensor([-1.0, 1.0], type: :f32), Nx.tensor([-1.0, 1.0], type: :f16), @@ -57,20 +57,20 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do test "ordered dict" do path = Path.join(@dir, "ordered_dict.#{@format}.pt") - assert Loader.load!(path) == %{"x" => 1, "y" => 2} + assert PyTorchLoader.load!(path) == %{"x" => 1, "y" => 2} end test "noncontiguous tensor" do path = Path.join(@dir, "noncontiguous_tensor.#{@format}.pt") - assert path |> Loader.load!() |> Nx.to_tensor() == + assert path |> PyTorchLoader.load!() |> Nx.to_tensor() == Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], type: :s64) end test "numpy array in Fortran order" do path = Path.join(@dir, "noncontiguous_numpy_array.#{@format}.pt") - assert Loader.load!(path) == + assert PyTorchLoader.load!(path) == Nx.tensor([[1, 4], [2, 5], [3, 6]], type: :s64) end end @@ -84,7 +84,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do assert { {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage1}, {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage2} - } = Loader.load!(path) + } = PyTorchLoader.load!(path) assert {:file, path, offset, size} = storage1 assert path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>> @@ -95,7 +95,7 @@ defmodule Bumblebee.Conversion.PyTorch.LoaderTest do test "raises if the files does not exist" do assert_raise File.Error, ~r/no such file or directory/, fn -> - Loader.load!("nonexistent") + PyTorchLoader.load!("nonexistent") end end end diff --git a/test/bumblebee/conversion/pytorch_test.exs b/test/bumblebee/conversion/pytorch_params_test.exs similarity index 82% rename from test/bumblebee/conversion/pytorch_test.exs rename to test/bumblebee/conversion/pytorch_params_test.exs index ba55d33b..c73dbef7 100644 --- a/test/bumblebee/conversion/pytorch_test.exs +++ b/test/bumblebee/conversion/pytorch_params_test.exs @@ -1,9 +1,9 @@ -defmodule Bumblebee.Conversion.PyTorchTest do +defmodule Bumblebee.Conversion.PyTorchParamsTest do use ExUnit.Case, async: true import Bumblebee.TestHelpers - alias Bumblebee.Conversion.PyTorch + alias Bumblebee.Conversion.PyTorchParams @dir Path.expand("../../fixtures/pytorch", __DIR__) @@ -40,7 +40,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do log = ExUnit.CaptureLog.capture_log(fn -> params = - PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping()) + PyTorchParams.load_params!(model, input_template(), path, + params_mapping: params_mapping() + ) assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2})) @@ -55,7 +57,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do log = ExUnit.CaptureLog.capture_log(fn -> - PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping()) + PyTorchParams.load_params!(model, input_template(), path, + params_mapping: params_mapping() + ) end) assert log =~ """ @@ -86,7 +90,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do log = ExUnit.CaptureLog.capture_log(fn -> params = - PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping()) + PyTorchParams.load_params!(model, input_template(), path, + params_mapping: params_mapping() + ) assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2})) @@ -102,7 +108,9 @@ defmodule Bumblebee.Conversion.PyTorchTest do log = ExUnit.CaptureLog.capture_log(fn -> params = - PyTorch.load_params!(model, input_template(), path, params_mapping: params_mapping()) + PyTorchParams.load_params!(model, input_template(), path, + params_mapping: params_mapping() + ) assert_equal(params["base.conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) assert_equal(params["base.conv"]["bias"], Nx.broadcast(0.0, {2}))