diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 620ff84c..63150243 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -123,6 +123,9 @@ defmodule Bumblebee do {Bumblebee.Vision.Deit, :for_image_classification_with_teacher}, "DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling}, "DeiTModel" => {Bumblebee.Vision.Deit, :base}, + "Dinov2Model" => {Bumblebee.Vision.DinoV2, :base}, + "Dinov2Backbone" => {Bumblebee.Vision.DinoV2, :backbone}, + "Dinov2ForImageClassification" => {Bumblebee.Vision.DinoV2, :for_image_classification}, "DistilBertModel" => {Bumblebee.Text.Distilbert, :base}, "DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling}, "DistilBertForSequenceClassification" => @@ -203,7 +206,8 @@ defmodule Bumblebee do } @transformers_image_processor_type_to_featurizer %{ - "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer + "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer, + "BitImageProcessor" => Bumblebee.Vision.BitFeaturizer } @model_type_to_featurizer %{ diff --git a/lib/bumblebee/diffusion/layers/unet.ex b/lib/bumblebee/diffusion/layers/unet.ex index dc520ed5..0cd72fcb 100644 --- a/lib/bumblebee/diffusion/layers/unet.ex +++ b/lib/bumblebee/diffusion/layers/unet.ex @@ -323,7 +323,7 @@ defmodule Bumblebee.Diffusion.Layers.UNet do epsilon: 1.0e-5 ], dropout_rate: dropout, - ffn: &ffn_geglu(&1, hidden_size, dropout: dropout, name: &2), + ffn: &ffn_geglu(&1, 4 * hidden_size, hidden_size, dropout: dropout, name: &2), block_type: :norm_first, name: join(name, "blocks") ) @@ -347,12 +347,10 @@ defmodule Bumblebee.Diffusion.Layers.UNet do end # A feed-forward network with GEGLU nonlinearity as in https://arxiv.org/abs/2002.05202 - defp ffn_geglu(x, size, opts) do + defp ffn_geglu(x, intermediate_size, output_size, opts) do name = opts[:name] dropout = opts[:dropout] || 0.0 - intermediate_size = 4 * size - {x, gate} = x |> Axon.dense(intermediate_size * 2, name: join(name, "intermediate")) @@ -362,6 +360,6 @@ defmodule Bumblebee.Diffusion.Layers.UNet do x |> Axon.dropout(rate: dropout, name: join(name, "dropout")) - |> Axon.dense(size, name: join(name, "output")) + |> Axon.dense(output_size, name: join(name, "output")) end end diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index b5b4a6dc..78a6a63d 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -265,6 +265,10 @@ defmodule Bumblebee.Layers.Transformer do * `:parallel` - block with attention and FFN independently (in parallel). This type doesn't support cross-attention + Alternatively a custom 3-arity function may be given. The function + receives the input hidden state, a map with block steps and a + name to prefix any additional layers. + * `:scale_attention_weights` - whether to scale query in the traditional style of multi-headed attention. 
Defaults to `true` @@ -469,17 +473,25 @@ defmodule Bumblebee.Layers.Transformer do ffn = &ffn_fun.(&1, join(name, "ffn")) + block_impl = + case block_type do + type when is_atom(type) -> &block_impl(type, &1, &2, &3) + fun when is_function(fun) -> fun + end + {hidden_state, attention_info, cross_attention_info} = - block_impl( - block_type, + block_impl.( hidden_state, - self_attention_norm, - self_attention, - cross_attention_maybe, - cross_attention_norm, - cross_attention, - output_norm, - ffn + %{ + self_attention_norm: self_attention_norm, + self_attention: self_attention, + cross_attention_maybe: cross_attention_maybe, + cross_attention_norm: cross_attention_norm, + cross_attention: cross_attention, + output_norm: output_norm, + ffn: ffn + }, + name ) {attention, self_attention_cache, attention_relative_bias} = attention_info @@ -495,36 +507,26 @@ defmodule Bumblebee.Layers.Transformer do {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} end - defp block_impl( - :standard, - hidden_state, - self_attention_norm, - self_attention, - cross_attention_maybe, - cross_attention_norm, - cross_attention, - output_norm, - ffn - ) do + defp block_impl(:standard, hidden_state, steps, _name) do shortcut = hidden_state - {hidden_state, attention_info} = self_attention.(hidden_state) + {hidden_state, attention_info} = steps.self_attention.(hidden_state) hidden_state = hidden_state |> Axon.add(shortcut) - |> self_attention_norm.() + |> steps.self_attention_norm.() {hidden_state, cross_attention_info} = - cross_attention_maybe.(hidden_state, fn hidden_state -> + steps.cross_attention_maybe.(hidden_state, fn hidden_state -> shortcut = hidden_state - {hidden_state, cross_attention_info} = cross_attention.(hidden_state) + {hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state) hidden_state = hidden_state |> Axon.add(shortcut) - |> cross_attention_norm.() + |> steps.cross_attention_norm.() {hidden_state, cross_attention_info} end) @@ -533,41 +535,31 @@ defmodule Bumblebee.Layers.Transformer do hidden_state = hidden_state - |> ffn.() + |> steps.ffn.() |> Axon.add(shortcut) - |> output_norm.() + |> steps.output_norm.() {hidden_state, attention_info, cross_attention_info} end - defp block_impl( - :norm_first, - hidden_state, - self_attention_norm, - self_attention, - cross_attention_maybe, - cross_attention_norm, - cross_attention, - output_norm, - ffn - ) do + defp block_impl(:norm_first, hidden_state, steps, _name) do shortcut = hidden_state {hidden_state, attention_info} = hidden_state - |> self_attention_norm.() - |> self_attention.() + |> steps.self_attention_norm.() + |> steps.self_attention.() hidden_state = Axon.add(hidden_state, shortcut) {hidden_state, cross_attention_info} = - cross_attention_maybe.(hidden_state, fn hidden_state -> + steps.cross_attention_maybe.(hidden_state, fn hidden_state -> shortcut = hidden_state {hidden_state, cross_attention_info} = hidden_state - |> cross_attention_norm.() - |> cross_attention.() + |> steps.cross_attention_norm.() + |> steps.cross_attention.() hidden_state = Axon.add(hidden_state, shortcut) @@ -578,40 +570,30 @@ defmodule Bumblebee.Layers.Transformer do hidden_state = hidden_state - |> output_norm.() - |> ffn.() + |> steps.output_norm.() + |> steps.ffn.() |> Axon.add(shortcut) {hidden_state, attention_info, cross_attention_info} end - defp block_impl( - :parallel, - hidden_state, - self_attention_norm, - self_attention, - cross_attention_maybe, - _cross_attention_norm, - _cross_attention, - output_norm, - 
ffn - ) do + defp block_impl(:parallel, hidden_state, steps, _name) do shortcut = hidden_state {attention_hidden_state, attention_info} = hidden_state - |> self_attention_norm.() - |> self_attention.() + |> steps.self_attention_norm.() + |> steps.self_attention.() {_hidden_state, cross_attention_info} = - cross_attention_maybe.(hidden_state, fn _hidden_state -> + steps.cross_attention_maybe.(hidden_state, fn _hidden_state -> raise "cross attention not supported" end) ffn_hidden_state = hidden_state - |> output_norm.() - |> ffn.() + |> steps.output_norm.() + |> steps.ffn.() hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state]) diff --git a/lib/bumblebee/vision/bit_featurizer.ex b/lib/bumblebee/vision/bit_featurizer.ex new file mode 100644 index 00000000..71ee5545 --- /dev/null +++ b/lib/bumblebee/vision/bit_featurizer.ex @@ -0,0 +1,179 @@ +defmodule Bumblebee.Vision.BitFeaturizer do + alias Bumblebee.Shared + + options = [ + resize: [ + default: true, + doc: "whether to resize the input to the given `:size`" + ], + size: [ + default: %{shortest_edge: 448}, + doc: """ + the size to resize the input to, either `%{height: ..., width: ...}` or `%{shortest_edge: ...}`. + Only has an effect if `:resize` is `true` + """ + ], + resize_method: [ + default: :bicubic, + doc: + "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`" + ], + center_crop: [ + default: true, + doc: """ + whether to crop the input at the center. If the input size is smaller than `:crop_size` along + any edge, the image is padded with zeros and then center cropped + """ + ], + crop_size: [ + default: %{height: 448, width: 448}, + doc: """ + the size to center crop the image to, given as `%{height: ..., width: ...}`. Only has an effect + if `:center_crop` is `true` + """ + ], + rescale: [ + default: true, + doc: "whether to rescale the input by the given `:rescale_factor`" + ], + rescale_factor: [ + default: 0.00392156862745098, + doc: """ + the factor by which to rescale the input. A single number + Only has an effect if `:rescale` is `true` + """ + ], + normalize: [ + default: true, + doc: "whether or not to normalize the input with mean and standard deviation" + ], + image_mean: [ + default: [0.5, 0.5, 0.5], + doc: "the sequence of mean values for each channel, to be used when normalizing images" + ], + image_std: [ + default: [0.5, 0.5, 0.5], + doc: + "the sequence of standard deviations for each channel, to be used when normalizing images" + ] + ] + + @moduledoc """ + BiT featurizer for image data. + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct Shared.option_defaults(options) + + @behaviour Bumblebee.Featurizer + @behaviour Bumblebee.Configurable + + alias Bumblebee.Utils.Image + + @impl true + def config(featurizer, opts) do + featurizer = Shared.put_config_attrs(featurizer, opts) + + if featurizer.resize and Shared.featurizer_size_fixed?(featurizer.size) and + not featurizer.center_crop do + raise ArgumentError, + "the resize shape depends on the input shape and cropping is disabled." 
<> + "You must either configure a fixed size or enable cropping" + end + + featurizer + end + + @impl true + def process_input(featurizer, images) do + images = List.wrap(images) + + for image <- images do + images = + image + |> Image.to_batched_tensor() + |> Nx.as_type(:f32) + |> Image.normalize_channels(length(featurizer.image_mean)) + + images = + if featurizer.resize do + size = Shared.featurizer_resize_size(images, featurizer.size) + NxImage.resize(images, size, method: featurizer.resize_method) + else + images + end + + if featurizer.center_crop do + %{height: height, width: width} = featurizer.crop_size + NxImage.center_crop(images, {height, width}) + else + images + end + end + |> Nx.concatenate() + end + + @impl true + def batch_template(featurizer, batch_size) do + num_channels = length(featurizer.image_mean) + + {height, width} = + case featurizer do + %{center_crop: true, crop_size: %{height: height, width: width}} -> + {height, width} + + %{resize: true, size: %{height: height, width: width}} -> + {height, width} + end + + Nx.template({batch_size, height, width, num_channels}, :f32) + end + + @impl true + def process_batch(featurizer, images) do + images = + if featurizer.rescale do + Nx.multiply(images, featurizer.rescale_factor) + else + images + end + + images = + if featurizer.normalize do + NxImage.normalize( + images, + Nx.tensor(featurizer.image_mean), + Nx.tensor(featurizer.image_std) + ) + else + images + end + + %{"pixel_values" => images} + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(featurizer, data) do + import Shared.Converters + + opts = + convert!(data, + resize: {"do_resize", boolean()}, + size: {"size", image_size(single_as: :shortest_edge)}, + resize_method: {"resample", resize_method()}, + center_crop: {"do_center_crop", boolean()}, + crop_size: {"crop_size", image_size()}, + rescale: {"do_rescale", boolean()}, + rescale_factor: {"rescale_factor", number()}, + normalize: {"do_normalize", boolean()}, + image_mean: {"image_mean", list(number())}, + image_std: {"image_std", list(number())} + ) + + @for.config(featurizer, opts) + end + end +end diff --git a/lib/bumblebee/vision/deit.ex b/lib/bumblebee/vision/deit.ex index 6d437c55..61fd7c27 100644 --- a/lib/bumblebee/vision/deit.ex +++ b/lib/bumblebee/vision/deit.ex @@ -249,11 +249,11 @@ defmodule Bumblebee.Vision.Deit do name: join(name, "norm") ) - pooled = pooler(hidden_state, spec, name: join(name, "pooler")) + pooled_state = pooler(hidden_state, spec, name: join(name, "pooler")) %{ hidden_state: hidden_state, - pooled_state: pooled, + pooled_state: pooled_state, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions } diff --git a/lib/bumblebee/vision/dino_v2.ex b/lib/bumblebee/vision/dino_v2.ex new file mode 100644 index 00000000..8af774ee --- /dev/null +++ b/lib/bumblebee/vision/dino_v2.ex @@ -0,0 +1,520 @@ +defmodule Bumblebee.Vision.DinoV2 do + alias Bumblebee.Shared + + options = + [ + image_size: [ + default: 518, + doc: """ + the size of the input spatial dimensions. 
The model is trained for this size, however + the model supports any other input size by interpolating position embeddings + """ + ], + num_channels: [ + default: 3, + doc: "the number of channels in the input" + ], + patch_size: [ + default: 14, + doc: "the size of the patch spatial dimensions" + ], + hidden_size: [ + default: 384, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size_ratio: [ + default: 4, + doc: """ + the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder, + expressed as a multiplier of `:hidden_size` + """ + ], + use_qkv_bias: [ + default: true, + doc: "whether to use bias in query, key, and value projections" + ], + activation: [ + default: :gelu, + doc: "the activation function" + ], + ffn_swiglu_activation: [ + default: false, + doc: + "whether to use the gated SwiGLU activation function in the feed-forward network (FFN)" + ], + scale_initial_value: [ + default: 1.0, + doc: "the initial value for scaling layers" + ], + dropout_rate: [ + default: 0.0, + doc: "the dropout rate for encoder and decoder" + ], + attention_dropout_rate: [ + default: 0.0, + doc: "the dropout rate for attention weights" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by the layer normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + backbone_output_indices: [ + default: nil, + doc: """ + list of indices indicating which feature maps to include in the output. If not specified, only + the last feature map is included + """ + ], + backbone_use_norm: [ + default: true, + doc: + "whether to add layer normalization layer to each of the feature maps returned by the backbone" + ] + ] ++ + Shared.common_options([ + :output_hidden_states, + :output_attentions, + :num_labels, + :id_to_label + ]) + + @moduledoc """ + DINOv2 model family. + + ## Architectures + + * `:base` - plain DINOv2 without any head on top + + * `:for_image_classification` - DINOv2 with head for image classification + + * `:backbone` - DINOv2 with feature maps output + + ## Inputs + + * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}` + + Featurized image pixel values. + + * `"patch_mask"` - `{batch_size, num_patches}` + + Mask to nullify selected embedded patches. 
+ + ## Configuration + + #{Shared.options_doc(options)} + + ## References + + * [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) + + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base, :backbone, :for_image_classification] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(spec) do + %{ + "pixel_values" => + Nx.template({1, spec.image_size, spec.image_size, spec.num_channels}, :f32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + hidden_state = + Axon.layer_norm(outputs.hidden_state, epsilon: spec.layer_norm_epsilon, name: "norm") + + pooled_state = Layers.take_token(hidden_state, index: 0, axis: 1) + + Layers.output(%{ + hidden_state: hidden_state, + pooled_state: pooled_state, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :for_image_classification} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + hidden_state = + Axon.layer_norm(outputs.hidden_state, epsilon: spec.layer_norm_epsilon, name: "norm") + + class_token = Layers.take_token(hidden_state, index: 0, axis: 1) + + patch_embeddings_mean = + Axon.nx(hidden_state, fn hidden_state -> + patch_embeddings = hidden_state[[.., 1..-1//1, ..]] + Nx.mean(patch_embeddings, axes: [1]) + end) + + logits = + Axon.concatenate(class_token, patch_embeddings_mean) + |> Axon.dense(spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "image_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + def model(%__MODULE__{architecture: :backbone} = spec) do + inputs = inputs(spec) + outputs = core(inputs, %{spec | output_hidden_states: true}) + feature_maps = feature_maps(outputs.hidden_states, inputs["pixel_values"], spec) + + Layers.output(%{ + feature_maps: feature_maps, + hidden_states: + if(spec.output_hidden_states, do: outputs.hidden_states, else: Layers.none()), + attentions: outputs.attentions + }) + end + + defp inputs(spec) do + shape = {nil, nil, nil, spec.num_channels} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("pixel_values", shape: shape), + Axon.input("patch_mask", shape: {nil, nil}, optional: true) + ]) + end + + defp core(inputs, spec, opts \\ []) do + name = opts[:name] + + embeddings = + embedder(inputs["pixel_values"], inputs["patch_mask"], spec, name: join(name, "embedder")) + + encoder_outputs = encoder(embeddings, spec, name: join(name, "encoder")) + + %{ + hidden_state: encoder_outputs.hidden_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions + } + end + + defp feature_maps(hidden_states, pixel_values, spec, opts \\ []) do + name = opts[:name] + + num_feature_maps = spec.num_blocks + 1 + output_indices = spec.backbone_output_indices || [num_feature_maps - 1] + + for index <- output_indices do + hidden_state = Axon.nx(hidden_states, &elem(&1, index)) + + hidden_state = + if spec.backbone_use_norm do + Axon.layer_norm(hidden_state, + epsilon: spec.layer_norm_epsilon, + name: join(name, "norm") + ) + 
else + hidden_state + end + + Axon.layer( + fn hidden_state, pixel_values, _opts -> + {batch_size, height, width, _channels} = Nx.shape(pixel_values) + + hidden_state = hidden_state[[.., 1..-1//1, ..]] + + Nx.reshape( + hidden_state, + {batch_size, div(height, spec.patch_size), div(width, spec.patch_size), :auto} + ) + end, + [hidden_state, pixel_values] + ) + end + |> List.to_tuple() + |> Axon.container() + end + + defp embedder(pixel_values, patch_mask, spec, opts) do + name = opts[:name] + + patch_embeddings = + pixel_values + |> patch_embedding(spec, name: join(name, "patch_embedding")) + |> Layers.apply_vision_patch_mask(patch_mask, name: join(name, "mask_tokens")) + + class_embedding = + Layers.learned_embeddings(1, spec.hidden_size, name: join(name, "class_embedding")) + + input_embeddings = Layers.concatenate_embeddings([class_embedding, patch_embeddings]) + + num_patches = div(spec.image_size, spec.patch_size) ** 2 + + position_embeddings = + Layers.learned_embeddings(num_patches + 1, spec.hidden_size, + initializer: :zeros, + name: join(name, "position_embedding") + ) + |> interpolate_position_embeddings(pixel_values, spec) + + Axon.add(input_embeddings, position_embeddings) + |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout")) + end + + defp patch_embedding(pixel_values, spec, opts) do + name = opts[:name] + + pixel_values + |> Axon.conv(spec.hidden_size, + kernel_size: spec.patch_size, + strides: spec.patch_size, + padding: :valid, + kernel_initializer: kernel_initializer(spec), + name: join(name, "projection") + ) + |> Axon.reshape({:batch, :auto, spec.hidden_size}, name: join(name, "reshape")) + end + + defp interpolate_position_embeddings(position_embeddings, pixel_values, spec) do + Axon.layer( + fn position_embeddings, pixel_values, _opts -> + original_positions = div(spec.image_size, spec.patch_size) + {batch_size, height, width, _channels} = Nx.shape(pixel_values) + resized_height = div(height, spec.patch_size) + resized_width = div(width, spec.patch_size) + + class_position_embedding = position_embeddings[[.., 0..0//1, ..]] + input_position_embeddings = position_embeddings[[.., 1..-1//1, ..]] + + interpolated_position_embeddings = + input_position_embeddings + |> Nx.reshape({batch_size, original_positions, original_positions, spec.hidden_size}) + |> Axon.Layers.resize( + size: {resized_height, resized_width}, + method: :bicubic, + antialias: false + ) + |> Nx.reshape({batch_size, :auto, spec.hidden_size}) + + Nx.concatenate([class_position_embedding, interpolated_position_embeddings], axis: 1) + end, + [position_embeddings, pixel_values], + op_name: :interpolate_position_embeddings + ) + end + + defp encoder(hidden_state, spec, opts) do + name = opts[:name] + + ffn = + if spec.ffn_swiglu_activation do + intermediate_size = + div(floor(floor(spec.hidden_size * spec.intermediate_size_ratio) * 2 / 3 + 7), 8) * 8 + + &ffn_swiglu(&1, intermediate_size, spec.hidden_size, name: &2) + else + intermediate_size = floor(spec.hidden_size * spec.intermediate_size_ratio) + + [ + intermediate_size: intermediate_size, + activation: spec.activation + ] + end + + Layers.Transformer.blocks(hidden_state, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + attention_dropout_rate: spec.attention_dropout_rate, + query_use_bias: spec.use_qkv_bias, + key_use_bias: spec.use_qkv_bias, + value_use_bias: spec.use_qkv_bias, + layer_norm: 
[ + epsilon: spec.layer_norm_epsilon + ], + ffn: ffn, + block_type: &block_impl(&1, &2, &3, spec), + output_hidden_states: spec.output_hidden_states, + output_attentions: spec.output_attentions, + name: join(name, "blocks") + ) + end + + # A feed-forward network with SwiGLU nonlinearity as in https://arxiv.org/abs/2002.05202 + defp ffn_swiglu(x, intermediate_size, output_size, opts) do + name = opts[:name] + dropout = opts[:dropout] || 0.0 + + {gate, x} = + x + |> Axon.dense(intermediate_size * 2, name: join(name, "intermediate")) + |> Axon.split(2, axis: -1) + + x = Axon.multiply(x, Axon.silu(gate)) + + x + |> Axon.dropout(rate: dropout, name: join(name, "dropout")) + |> Axon.dense(output_size, name: join(name, "output")) + end + + # :norm_first block with additional scaling layers + defp block_impl(hidden_state, steps, name, spec) do + shortcut = hidden_state + + {hidden_state, attention_info} = + hidden_state + |> steps.self_attention_norm.() + |> steps.self_attention.() + + hidden_state = + hidden_state + |> Bumblebee.Layers.scale( + scale_initializer: Axon.Initializers.full(spec.scale_initial_value), + name: join(name, "self_attention_scale") + ) + |> Axon.add(shortcut) + + {_hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn _hidden_state -> + raise "cross attention not supported" + end) + + shortcut = hidden_state + + hidden_state = + hidden_state + |> steps.output_norm.() + |> steps.ffn.() + |> Bumblebee.Layers.scale( + scale_initializer: Axon.Initializers.full(spec.scale_initial_value), + name: join(name, "output_scale") + ) + |> Axon.add(shortcut) + + {hidden_state, attention_info, cross_attention_info} + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + image_size: {"image_size", number()}, + num_channels: {"num_channels", number()}, + patch_size: {"patch_size", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size_ratio: {"mlp_ratio", number()}, + activation: {"hidden_act", activation()}, + use_qkv_bias: {"qkv_bias", boolean()}, + dropout_rate: {"hidden_dropout_prob", number()}, + attention_dropout_rate: {"attention_probs_dropout_prob", number()}, + layer_norm_epsilon: {"layer_norm_eps", number()}, + initializer_scale: {"initializer_range", number()}, + scale_initial_value: {"layerscale_value", number()}, + ffn_swiglu_activation: {"use_swiglu_ffn", boolean()}, + backbone_output_indices: {"out_indices", list(number())}, + backbone_use_norm: {"use_backbone_norm", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + %{ + "embedder.patch_embedding.projection" => "dinov2.embeddings.patch_embeddings.projection", + "embedder.class_embedding" => %{ + "embeddings" => { + [{"dinov2.embeddings", "cls_token"}], + fn [value] -> Nx.squeeze(value, axes: [0]) end + } + }, + "embedder.position_embedding" => %{ + "embeddings" => { + [{"dinov2.embeddings", "position_embeddings"}], + fn [value] -> Nx.squeeze(value, axes: [0]) end + } + }, + "encoder.blocks.{n}.self_attention_norm" => "dinov2.encoder.layer.{n}.norm1", + "encoder.blocks.{n}.self_attention.key" => + 
"dinov2.encoder.layer.{n}.attention.attention.key", + "encoder.blocks.{n}.self_attention.query" => + "dinov2.encoder.layer.{n}.attention.attention.query", + "encoder.blocks.{n}.self_attention.value" => + "dinov2.encoder.layer.{n}.attention.attention.value", + "encoder.blocks.{n}.self_attention.output" => + "dinov2.encoder.layer.{n}.attention.output.dense", + "encoder.blocks.{n}.self_attention_scale" => %{ + "scale" => { + [{"dinov2.encoder.layer.{n}.layer_scale1", "lambda1"}], + fn [lambda1] -> lambda1 end + } + }, + "encoder.blocks.{n}.ffn.intermediate" => + if(spec.ffn_swiglu_activation, + do: "dinov2.encoder.layer.{n}.mlp.weights_in", + else: "dinov2.encoder.layer.{n}.mlp.fc1" + ), + "encoder.blocks.{n}.ffn.output" => + if(spec.ffn_swiglu_activation, + do: "dinov2.encoder.layer.{n}.mlp.weights_out", + else: "dinov2.encoder.layer.{n}.mlp.fc2" + ), + "encoder.blocks.{n}.output_norm" => "dinov2.encoder.layer.{n}.norm2", + "encoder.blocks.{n}.output_scale" => %{ + "scale" => { + [{"dinov2.encoder.layer.{n}.layer_scale2", "lambda1"}], + fn [lambda1] -> lambda1 end + } + }, + "norm" => "dinov2.layernorm", + "image_classification_head.output" => "classifier" + } + end + end +end diff --git a/lib/bumblebee/vision/vit.ex b/lib/bumblebee/vision/vit.ex index de75644c..a2792fd9 100644 --- a/lib/bumblebee/vision/vit.ex +++ b/lib/bumblebee/vision/vit.ex @@ -203,11 +203,11 @@ defmodule Bumblebee.Vision.Vit do name: join(name, "norm") ) - pooled = pooler(hidden_state, spec, name: join(name, "pooler")) + pooled_state = pooler(hidden_state, spec, name: join(name, "pooler")) %{ hidden_state: hidden_state, - pooled_state: pooled, + pooled_state: pooled_state, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions } diff --git a/mix.exs b/mix.exs index 811477cb..5bf8c197 100644 --- a/mix.exs +++ b/mix.exs @@ -99,12 +99,14 @@ defmodule Bumblebee.MixProject do Bumblebee.Vision.ClipVision, Bumblebee.Vision.ConvNext, Bumblebee.Vision.Deit, + Bumblebee.Vision.DinoV2, Bumblebee.Vision.ResNet, Bumblebee.Vision.Vit ], Preprocessors: [ Bumblebee.Audio.WhisperFeaturizer, Bumblebee.Text.PreTrainedTokenizer, + Bumblebee.Vision.BitFeaturizer, Bumblebee.Vision.BlipFeaturizer, Bumblebee.Vision.ClipFeaturizer, Bumblebee.Vision.ConvNextFeaturizer, diff --git a/mix.lock b/mix.lock index 3962cd64..844dba7c 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,5 @@ %{ - "axon": {:git, "https://github.com/elixir-nx/axon.git", "67b48c7a43438f5eec2a35311572565cafe889d7", []}, + "axon": {:git, "https://github.com/elixir-nx/axon.git", "885d2d2a5d85970e9430de4414cab91a2b77a75a", []}, "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"}, "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"}, @@ -20,7 +20,7 @@ "mime": {:hex, :mime, "2.0.3", 
"3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "nx": {:git, "https://github.com/elixir-nx/nx.git", "b40321c2b75d34f8e781f5805ef20c7be853fa57", [sparse: "nx"]}, - "nx_image": {:hex, :nx_image, "0.1.1", "69cf0d2fd873d12b028583aa49b5e0a25f6aca307afc337a5d871851a20fba1d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "55c8206a822237f6027168f11214e3887263c5b8a1f8e0634eea82c96e5093e3"}, + "nx_image": {:hex, :nx_image, "0.1.2", "0c6e3453c1dc30fc80c723a54861204304cebc8a89ed3b806b972c73ee5d119d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9161863c42405ddccb6dbbbeae078ad23e30201509cc804b3b3a7c9e98764b81"}, "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, "plug": {:hex, :plug, "1.14.2", "cff7d4ec45b4ae176a227acd94a7ab536d9b37b942c8e8fa6dfc0fff98ff4d80", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "842fc50187e13cf4ac3b253d47d9474ed6c296a8732752835ce4a86acdf68d13"}, "plug_cowboy": {:hex, :plug_cowboy, "2.6.1", "9a3bbfceeb65eff5f39dab529e5cd79137ac36e913c02067dba3963a26efe9b2", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "de36e1a21f451a18b790f37765db198075c25875c64834bcc82d90b309eb6613"}, diff --git a/test/bumblebee/vision/bit_featurizer_test.ex b/test/bumblebee/vision/bit_featurizer_test.ex new file mode 100644 index 00000000..ab2125f3 --- /dev/null +++ b/test/bumblebee/vision/bit_featurizer_test.ex @@ -0,0 +1,15 @@ +defmodule Bumblebee.Vision.BitFeaturizerTest do + use ExUnit.Case, async: true + + test "encodes image" do + assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "google/bit-50"}) + + assert %Bumblebee.Vision.BitFeaturizer{} = featurizer + + image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) + + inputs = Bumblebee.apply_featurizer(featurizer, image) + + assert Nx.shape(inputs["pixel_values"]) == {1, 448, 448, 3} + end +end diff --git a/test/bumblebee/vision/dino_v2_test.exs b/test/bumblebee/vision/dino_v2_test.exs new file mode 100644 index 00000000..c9140cc1 --- /dev/null +++ b/test/bumblebee/vision/dino_v2_test.exs @@ -0,0 +1,179 @@ +defmodule Bumblebee.Vision.DinoV2Test do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Model"}) + + assert %Bumblebee.Vision.DinoV2{architecture: :base} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert 
Nx.shape(outputs.hidden_state) == {1, 226, 32} + assert Nx.shape(outputs.pooled_state) == {1, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.1210, -0.3567, -0.4570], [-1.0003, -0.8821, -0.5325], [-0.6919, -0.5334, -0.4568]] + ]) + ) + + assert_all_close( + outputs.pooled_state[[.., 1..3]], + Nx.tensor([[-0.7099, -0.6118, 0.7679]]) + ) + end + + test ":base with position embedding interpolation (different input size)" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Model"}) + + assert %Bumblebee.Vision.DinoV2{architecture: :base} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 64, 64, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 1025, 32} + assert Nx.shape(outputs.pooled_state) == {1, 32} + + # Note: the interpolation has a slightly different implementation + # in PyTorch, so the numbers don't match exactly, but this is used + # at inference and should be fine in practice. We do a low-precision + # comparison for the reference. + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.2287, -0.2291, -0.4323], [-1.1548, -0.4430, -0.4710], [-1.0547, -0.7580, -0.4654]] + ]), + atol: 1.0e-1 + ) + + assert_all_close( + outputs.pooled_state[[.., 1..3]], + Nx.tensor([[-0.7270, -0.5913, 0.7701]]), + atol: 1.0e-2 + ) + end + + test ":base with swiglu ffn" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-Dinov2Model-use_swiglu_ffn-True"} + ) + + assert %Bumblebee.Vision.DinoV2{architecture: :base} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 226, 32} + assert Nx.shape(outputs.pooled_state) == {1, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.4022, 0.2361, 0.6539], [-1.0799, -0.3041, 0.3125], [-0.7367, -0.0650, 0.6671]] + ]) + ) + + assert_all_close( + outputs.pooled_state[[.., 1..3]], + Nx.tensor([[-0.5637, 0.7523, 1.0458]]) + ) + end + + test ":for_image_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "hf-internal-testing/tiny-random-Dinov2ForImageClassification"} + ) + + assert %Bumblebee.Vision.DinoV2{architecture: :for_image_classification} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[0.1091, 0.0126]]) + ) + end + + test ":backbone" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"}) + + assert %Bumblebee.Vision.DinoV2{architecture: :backbone} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert {feature_map} = outputs.feature_maps + + assert Nx.shape(feature_map) == {1, 15, 15, 32} + + assert_all_close( + feature_map[[.., 1..2, 1..2, 1..2]], + Nx.tensor([[[[1.3373, 0.6393], [0.5469, 1.4045]], [[1.1879, 0.7435], [1.1777, 0.6638]]]]) + ) + end + + test ":backbone with different feature map subset" do + assert {:ok, spec} = + Bumblebee.load_spec({:hf, 
"hf-internal-testing/tiny-random-Dinov2Backbone"}) + + spec = Bumblebee.configure(spec, backbone_output_indices: [0, 2]) + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"}, + spec: spec + ) + + assert %Bumblebee.Vision.DinoV2{architecture: :backbone} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert {feature_map0, feature_map2} = outputs.feature_maps + + assert Nx.shape(feature_map0) == {1, 15, 15, 32} + assert Nx.shape(feature_map2) == {1, 15, 15, 32} + + assert_all_close( + feature_map0[[.., 1..2, 1..2, 1..2]], + Nx.tensor([[[[0.8425, 0.5487], [0.2003, 1.2553]], [[0.6486, 0.4550], [1.3376, 0.7091]]]]) + ) + + assert_all_close( + feature_map2[[.., 1..2, 1..2, 1..2]], + Nx.tensor([[[[1.3373, 0.6393], [0.5469, 1.4045]], [[1.1879, 0.7435], [1.1777, 0.6638]]]]) + ) + end +end