Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add :spec_overrides to Bumblebee.load_model/2 #340

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ defmodule Bumblebee do
* `:spec` - the model specification to use when building the model.
By default the specification is loaded using `load_spec/2`

* `:spec_overrides` - additional options to configure the model
specification with. This is a shorthand for using `load_spec/2`,
`configure/2` and passing as `:spec`

* `:module` - the model specification module. By default it is
inferred from the configuration file, if that is not possible,
it must be specified explicitly
Expand Down Expand Up @@ -534,6 +538,11 @@ defmodule Bumblebee do
spec = Bumblebee.configure(spec, num_labels: 10)
{:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)

Or as a shorthand, you can pass just the options to override:

{:ok, resnet} =
Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec_overrides: [num_labels: 10])

"""
@doc type: :model
@spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()}
Expand All @@ -543,6 +552,7 @@ defmodule Bumblebee do
opts =
Keyword.validate!(opts, [
:spec,
:spec_overrides,
:module,
:architecture,
:params_variant,
Expand All @@ -561,10 +571,19 @@ defmodule Bumblebee do
end

defp maybe_load_model_spec(opts, repository, repo_files) do
if spec = opts[:spec] do
{:ok, spec}
else
do_load_spec(repository, repo_files, opts[:module], opts[:architecture])
spec_result =
if spec = opts[:spec] do
{:ok, spec}
else
do_load_spec(repository, repo_files, opts[:module], opts[:architecture])
end

with {:ok, spec} <- spec_result do
if options = opts[:spec_overrides] do
{:ok, configure(spec, options)}
else
{:ok, spec}
end
end
end

Expand Down
7 changes: 1 addition & 6 deletions test/bumblebee/vision/dino_v2_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,9 @@ defmodule Bumblebee.Vision.DinoV2Test do
end

test ":backbone with different feature map subset" do
assert {:ok, spec} =
Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"})

spec = Bumblebee.configure(spec, backbone_output_indices: [0, 2])

assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"},
spec: spec
spec_overrides: [backbone_output_indices: [0, 2]]
)

assert %Bumblebee.Vision.DinoV2{architecture: :backbone} = spec
Expand Down
Loading