From 772fdd30a5f65896d3f472d8ed1d46f1d745d3d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?=
Date: Mon, 3 Jun 2024 15:03:48 +0200
Subject: [PATCH 01/15] Added Swin model

---
 lib/bumblebee.ex             |   1 +
 lib/bumblebee/vision/swin.ex | 433 +++++++++++++++++++++++++++++++++++
 mix.lock                     |   1 +
 3 files changed, 435 insertions(+)
 create mode 100644 lib/bumblebee/vision/swin.ex

diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index f44e9069..36474c9d 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -186,6 +186,7 @@ defmodule Bumblebee do
       "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
       "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
       "RobertaModel" => {Bumblebee.Text.Roberta, :base},
+      "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification},
       "T5Model" => {Bumblebee.Text.T5, :base},
       "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
       "T5EncoderModel" => {Bumblebee.Text.T5, :encoder},
diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex
new file mode 100644
index 00000000..da4a8f4b
--- /dev/null
+++ b/lib/bumblebee/vision/swin.ex
@@ -0,0 +1,433 @@
+defmodule Bumblebee.Vision.Swin do
+  alias Bumblebee.Shared
+
+  options =
+    [
+      attention_dropout_rate: [
+        default: 0.0,
+        doc: "the dropout rate for attention weights"
+      ],
+      depths: [
+        default: [2, 2, 18, 2],
+        doc: "the number of transformer blocks at each stage"
+      ],
+      drop_path_rate: [
+        default: 0.1,
+        doc: "the drop path rate used for stochastic depth"
+      ],
+      # Maybe this should be renamed to :hidden_size
+      embed_dim: [
+        default: 128,
+        doc: "the dimensionality of the patch embedding"
+      ],
+      activation: [
+        default: :gelu,
+        doc: "the activation function"
+      ],
+      dropout_rate: [
+        default: 0.0,
+        doc: "the dropout rate for the embeddings and encoder"
+      ],
+      image_size: [
+        default: 384,
+        doc: "the size of the input spatial dimensions"
+      ],
+      initializer_scale: [
+        default: 0.02,
+        doc:
+          "the standard deviation of the normal initializer used for initializing kernel parameters"
+      ],
+      layer_norm_epsilon: [
+        default: 1.0e-5,
+        doc: "the epsilon used by the layer normalization layers"
+      ],
+      intermediate_size_ratio: [
+        default: 4,
+        doc: """
+        the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder,
+        expressed as a multiplier of `:hidden_size`
+        """
+      ],
+      num_channels: [
+        default: 3,
+        doc: "the number of channels in the input"
+      ],
+      num_heads: [
+        default: [4, 8, 16, 32],
+        doc: "the number of attention heads at each stage"
+      ],
+      patch_size: [
+        default: 4,
+        doc: "the size of the patch spatial dimensions"
+      ],
+      path_norm: [
+        default: true,
+        doc: "whether to apply layer normalization after the patch embedding"
+      ],
+      use_attention_bias: [
+        default: true,
+        doc: "whether to use bias in query, key, and value projections"
+      ],
+      use_absolute_embeddings: [
+        default: false,
+        doc: "whether to add absolute position embeddings to the patch embeddings"
+      ],
+      window_size: [
+        default: 12,
+        doc: "the size of the attention window"
+      ]
+    ] ++ Shared.common_options([:num_labels, :id_to_label])
+
+  @moduledoc """
+  Swin Transformer model.
+
+  ## Architectures
+
+    * `:for_image_classification` - Swin transformer model for image classification.
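+
+  ## Inputs
+
+    * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
+
+      Featurized image pixel values.
+
+    * `"patch_mask"` - `{batch_size, num_patches}`
+
+      Mask to nullify selected embedded patches.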
+
+  ## Global layer options
+
+  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
+
+  ## Configuration
+
+  #{Shared.options_doc(options)}
+
+  ## References
+
+    * [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
+  """
+
+  defstruct [architecture: :base] ++ Shared.option_defaults(options)
+
+  @behaviour Bumblebee.ModelSpec
+  @behaviour Bumblebee.Configurable
+
+  import Bumblebee.Utils.Model, only: [join: 2]
+
+  alias Bumblebee.Layers
+
+  @impl true
+  def architectures(), do: [:for_image_classification]
+
+  @impl true
+  def config(spec, opts) do
+    spec
+    |> Shared.put_config_attrs(opts)
+    |> Shared.validate_label_options()
+  end
+
+  @impl true
+  def input_template(spec) do
+    %{
+      "pixel_values" =>
+        Nx.template({1, spec.image_size, spec.image_size, spec.num_channels}, :f32)
+    }
+  end
+
+  @impl true
+  def model(%__MODULE__{architecture: :base} = spec) do
+    spec
+    |> inputs()
+    |> core(spec)
+    |> Layers.output()
+  end
+
+  def model(%__MODULE__{architecture: :for_image_classification} = spec) do
+    inputs = inputs(spec)
+    outputs = core(inputs, spec)
+
+    logits =
+      outputs.hidden_state
+      |> Layers.take_token(index: 0, axis: 1)
+      |> Axon.dense(spec.num_labels,
+        kernel_initializer: kernel_initializer(spec),
+        name: "image_classification_head.output"
+      )
+
+    Layers.output(%{
+      logits: logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
+  defp inputs(spec) do
+    shape = {nil, spec.image_size, spec.image_size, spec.num_channels}
+
+    Bumblebee.Utils.Model.inputs_to_map([
+      Axon.input("pixel_values", shape: shape),
+      Axon.input("patch_mask", shape: {nil, nil}, optional: true)
+    ])
+  end
+
+  # Unlike the Python implementation, we do not take a bool_masked_pos
+  # argument here. In Python this parameter is propagated from the model
+  # through the core down to the embedder.
+  defp core(inputs, spec, opts \\ []) do
+    name = opts[:name]
+
+    embeddings =
+      embedder(inputs["pixel_values"], spec, name: join(name, "embedder"))
+
+    {hidden_state, hidden_states, attentions} =
+      encoder(embeddings, spec, name: join(name, "encoder"))
+
+    hidden_state =
+      Axon.layer_norm(hidden_state,
+        epsilon: spec.layer_norm_epsilon,
+        name: join(name, "norm")
+      )
+
+    pooled_state =
+      Axon.adaptive_avg_pool(hidden_state, output_size: {1, 1}, name: join(name, "pooler"))
+
+    %{
+      hidden_state: hidden_state,
+      pooled_state: pooled_state,
+      hidden_states: hidden_states,
+      attentions: attentions
+    }
+  end
+
+  defp embedder(pixel_values, spec, opts) do
+    name = opts[:name]
+
+    embeddings =
+      pixel_values
+      |> patch_embedding(spec, name: join(name, "patch_embedding"))
+      |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon)
+
+    embeddings =
+      if spec.use_absolute_embeddings do
+        num_patches = div(spec.image_size, spec.patch_size) ** 2
+
+        position_embeddings =
+          Layers.learned_embeddings(num_patches, spec.embed_dim,
+            initializer: :zeros,
+            name: join(name, "position_embedding")
+          )
+
+        Axon.add(embeddings, position_embeddings)
+      else
+        embeddings
+      end
+
+    embeddings
+    |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout"))
+  end
+
+  # TODO: Work out how to get and return the output dimensions. They are
+  # used later in the Python implementation, but it is not clear how to
+  # get them until we have a loadable implementation.
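+
+  # For reference, a shape sketch assuming the defaults (image_size 384,
+  # patch_size 4, embed_dim 128): the projection below produces
+  # {batch_size, 96, 96, 128}, which the reshape then flattens to
+  # {batch_size, 9216, 128}, i.e. one embedding per 4x4 patch.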
+ defp patch_embedding(pixel_values, spec, opts) do + name = opts[:name] + hidden_size = spec.embed_dim + + pixel_values + |> Axon.conv(hidden_size, + kernel_size: spec.patch_size, + strides: spec.patch_size, + padding: :valid, + kernel_initializer: kernel_initializer(spec), + name: join(name, "projection") + ) + |> Axon.reshape({:batch, :auto, spec.embed_dim}, name: join(name, "reshape")) + end + + defp encoder(hidden_state, spec, opts) do + hidden_states = Axon.container({hidden_state}) + attentions = Axon.container({}) + + 0..(length(spec.depths) - 1) + |> Enum.reduce( + {hidden_state, hidden_states, attentions}, + fn layer_idx, {hidden_state, hidden_states, attentions} -> + {hidden_state, attention, _cross_attention, _block_cache, _position_bias} = + stage(hidden_state, spec, layer_idx, opts) + + { + hidden_state, + Layers.append(hidden_states, hidden_state), + Layers.append(attentions, attention) + } + end + ) + end + + defp stage(hidden_state, spec, layer_idx, opts) do + grid_size = div(spec.image_size, spec.patch_size) + input_resolution = div(grid_size, 2 ** layer_idx) + num_attention_heads = Enum.at(spec.num_heads, layer_idx) + dim = spec.embed_dim * 2 ** layer_idx + + {hidden_state, attention, cross_attention, block_cache, position_bias} = + layer(hidden_state, num_attention_heads, dim, spec, opts) + + hidden_state = + if layer_idx < length(spec.depths) - 1 do + downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon) + else + hidden_state + end + + {hidden_state, attention, cross_attention, block_cache, position_bias} + end + + # Steps in Python implementation: + # Normalization + # if shift_size > 0 -> roll hidden states + # window partition + # attention with attention mask + # window reverse + # if shift_size > 0 -> roll shifted windows + # shortcut + drop_path(attention_windows) + # Normalization + # Intermediate + # add result of intermediate + defp layer(hidden_state, num_attention_heads, dim, spec, opts) do + name = opts[:name] + + # shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) + # depth = Enum.at(spec.depths, layer_idx) + + {hidden_state, attention, cross_attention, block_cache, position_bias} = + Layers.Transformer.block(hidden_state, + block_type: :norm_first, + num_attention_heads: num_attention_heads, + hidden_size: dim, + kernel_initializer: kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + attention_dropout_rate: spec.attention_dropout_rate, + layer_norm: [ + epsilon: spec.layer_norm_epsilon + ], + ffn: [ + intermediate_size: round(spec.intermediate_size_ratio * dim), + activation: spec.activation + ], + name: join(name, "block_#{num_attention_heads}") + ) + + {hidden_state, attention, cross_attention, block_cache, position_bias} + end + + defp downsample(hidden_state, input_resolution, dim, norm_epsilon) do + Axon.nx(hidden_state, fn x -> + {batch_size, _dim, num_channels} = Nx.shape(x) + + x = Nx.reshape(x, {batch_size, input_resolution, input_resolution, :auto}) + + input_feature_0 = x[[.., 0..-1//2, 0..-1//2, ..]] + input_feature_1 = x[[.., 1..-1//2, 0..-1//2, ..]] + input_feature_2 = x[[.., 0..-1//2, 1..-1//2, ..]] + input_feature_3 = x[[.., 1..-1//2, 1..-1//2, ..]] + + Nx.concatenate([input_feature_0, input_feature_1, input_feature_2, input_feature_3], + axis: -1 + ) + |> Nx.reshape({batch_size, :auto, 4 * num_channels}) + end) + |> Axon.layer_norm(epsilon: norm_epsilon, name: "downsample_norm") + |> Axon.dense(2 * dim, + kernel_initializer: Axon.Initializers.uniform(), + name: 
"image_classification_head.output" + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + attention_dropout_rate: {"attention_probs_dropout_prob", number()}, + depths: {"depths", list(number())}, + drop_path_rate: {"drop_path_rate", number()}, + embed_dim: {"embed_dim", number()}, + activation: {"hidden_act", activation()}, + dropout_rate: {"hidden_dropout_prob", number()}, + image_size: {"image_size", number()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"layer_norm_eps", number()}, + intermediate_size_ratio: {"mlp_ratio", number()}, + num_channels: {"num_channels", number()}, + num_heads: {"num_heads", list(number())}, + patch_size: {"patch_size", number()}, + path_norm: {"path_norm", boolean()}, + use_attention_bias: {"qkv_bias", boolean()}, + use_absolute_embeddings: {"use_absolute_embeddings", boolean()}, + window_size: {"window_size", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.patch_embedding.projection" => "swin.embeddings.patch_embeddings.projection", + "embedder.class_embedding" => %{ + "embeddings" => { + [{"swin.embeddings", "cls_token"}], + fn [value] -> Nx.squeeze(value, axes: [0]) end + } + }, + "embedder.position_embedding" => %{ + "embeddings" => { + [{"swin.embeddings", "position_embeddings"}], + fn [value] -> Nx.squeeze(value, axes: [0]) end + } + }, + "encoder.block_{n}.self_attention_norm" => "swin.encoder.layer.{n}.layernorm_before", + "encoder.block_{n}.self_attention.key" => + "swin.encoder.layer.{n}.attention.attention.key", + "encoder.block_{n}.self_attention.query" => + "swin.encoder.layer.{n}.attention.attention.query", + "encoder.block_{n}.self_attention.value" => + "swin.encoder.layer.{n}.attention.attention.value", + "encoder.block_{n}.self_attention.output" => + "swin.encoder.layer.{n}.attention.output.dense", + "encoder.block_{n}.ffn.intermediate" => "swin.encoder.layer.{n}.intermediate.dense", + "encoder.block_{n}.ffn.output" => "swin.encoder.layer.{n}.output.dense", + "encoder.block_{n}.output_norm" => "swin.encoder.layer.{n}.layernorm_after", + "norm" => "swin.layernorm", + "pooler.output" => "swin.pooler.dense", + "image_classification_head.output" => "classifier", + "masked_image_modeling_head.output" => "decoder.0", + "layer_norm_{n}" => "swin.encoder.layers.{n}.blocks.{n}.layernorm", + "layer_{n}_downsample_norm" => "swin.encoder.layers.{n}.downsample.norm", + "downsample_norm" => "swin.encoder.downsample.norm" + } + end + end + + defp roll(%Nx.Tensor{} = x, opts \\ []) do + opts = Keyword.validate!(opts, shifts: [], axes: []) + shifts = opts[:shifts] + axes = opts[:axes] + + if length(shifts) != length(axes) do + raise ArgumentError, "shifts and axes must align, shifts: #{shifts}, axes: #{axes}" + else + shape = Nx.shape(x) |> Tuple.to_list() + + Enum.zip(shifts, axes) + |> Enum.reduce(x, fn {shift, dim}, acc -> + shift = rem(shift, Enum.at(shape, dim)) |> IO.inspect(label: :shift) + + if 0 < shift do + {base, move} = Nx.split(acc, -1 * shift, axis: dim) |> IO.inspect() + Nx.concatenate([move, base], axis: dim) + else + acc + end + end) + end + end +end diff --git a/mix.lock b/mix.lock index bf50a962..1ac77e5d 100644 --- a/mix.lock +++ b/mix.lock @@ -30,6 
+30,7 @@ "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, + "table_rex": {:hex, :table_rex, "4.0.0", "3c613a68ebdc6d4d1e731bc973c233500974ec3993c99fcdabb210407b90959b", [:mix], [], "hexpm", "c35c4d5612ca49ebb0344ea10387da4d2afe278387d4019e4d8111e815df8f55"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, "torchx": {:hex, :torchx, "0.7.0", "c71fd603b0133ed8709450d82aa3434cbcf485a37c9a68e9ebcce86f5e4fb7f0", [:mix], [{:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "a324079c56bb67750b1da16f859d994982bb467020a8c2cba324639552f3adb8"}, From a0dd29907799258051655acf0ecb7a4a0eada2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Thu, 13 Jun 2024 09:58:48 +0200 Subject: [PATCH 02/15] feat: windows partitioning and calculating attention mask --- lib/bumblebee/vision/swin.ex | 64 ++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index da4a8f4b..37fdd77f 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -312,6 +312,70 @@ defmodule Bumblebee.Vision.Swin do {hidden_state, attention, cross_attention, block_cache, position_bias} end + def attention_mask(height, width, window_size, shift_size) do + if shift_size > 0 do + # calculate attention mask for shifted window multi-head self-attention (SW-MSA) + img_mask = Nx.broadcast(0.0, {1, height, width, 1}) + + hslices = [ + 0..(height - window_size - 1), + (height - window_size)..(height - shift_size - 1), + (height - shift_size)..(height - 1) + ] + + wslices = [ + 0..(width - window_size - 1), + (width - window_size)..(width - shift_size - 1), + (width - shift_size)..(width - 1) + ] + + {img_mask, count} = + for hrange <- hslices, wrange <- wslices, reduce: 
{img_mask, 0.0} do
+          {mask, count} ->
+            mask =
+              for hidx <- hrange, widx <- wrange, reduce: mask do
+                deepest_mask ->
+                  Nx.indexed_put(deepest_mask, Nx.tensor([0, hidx, widx, 0]), count)
+              end
+
+            {mask, count + 1.0}
+        end
+
+      mask_windows =
+        img_mask
+        |> window_partition(window_size)
+        |> Nx.reshape({:auto, window_size * window_size})
+
+      mask_windows
+      |> Nx.new_axis(1)
+      |> Nx.subtract(Nx.new_axis(mask_windows, 2))
+      |> Nx.equal(0)
+      |> Nx.logical_not()
+      |> Nx.select(-100.0, 0)
+    else
+      nil
+    end
+  end
+
+  defp window_partition(%Axon{} = input_feature, window_size) do
+    input_feature
+    |> Axon.nx(fn x -> window_partition(x, window_size) end)
+  end
+
+  defp window_partition(%Nx.Tensor{} = tensor, window_size) do
+    {batch_size, height, width, num_channels} = Nx.shape(tensor)
+    windowed_height = div(height, window_size)
+    windowed_width = div(width, window_size)
+
+    # TODO: Check last reshape (in Python - view(-1, window_size, window_size, num_channels))
+    Nx.reshape(
+      tensor,
+      {batch_size, windowed_height, window_size, windowed_width, window_size, num_channels}
+    )
+    |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5])
+    |> Nx.reshape({:auto, window_size, window_size, num_channels})
+  end
+
 defp downsample(hidden_state, input_resolution, dim, norm_epsilon) do
   Axon.nx(hidden_state, fn x ->
     {batch_size, _dim, num_channels} = Nx.shape(x)

From eee7a46fe51dc6c712581f808dfd5655024760d8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?=
Date: Thu, 20 Jun 2024 11:14:33 +0200
Subject: [PATCH 03/15] feat: completed all layers but still not functional

---
 lib/bumblebee/vision/swin.ex | 213 +++++++++++++++++++++++++----------
 1 file changed, 151 insertions(+), 62 deletions(-)

diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex
index 37fdd77f..164d000e 100644
--- a/lib/bumblebee/vision/swin.ex
+++ b/lib/bumblebee/vision/swin.ex
@@ -217,9 +217,6 @@ defmodule Bumblebee.Vision.Swin do
     |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout"))
   end

-  # TODO: Work out how to get and return the output dimensions. They are
-  # used later in the Python implementation, but it is not clear how to
-  # get them until we have a loadable implementation.
defp patch_embedding(pixel_values, spec, opts) do name = opts[:name] hidden_size = spec.embed_dim @@ -239,77 +236,148 @@ defmodule Bumblebee.Vision.Swin do hidden_states = Axon.container({hidden_state}) attentions = Axon.container({}) - 0..(length(spec.depths) - 1) - |> Enum.reduce( - {hidden_state, hidden_states, attentions}, - fn layer_idx, {hidden_state, hidden_states, attentions} -> - {hidden_state, attention, _cross_attention, _block_cache, _position_bias} = - stage(hidden_state, spec, layer_idx, opts) + state = { + hidden_state, + hidden_states, + attentions + } + + for layer_idx <- 0..(length(spec.depths) - 1), reduce: state do + {hidden_state, hidden_states, attentions} -> + grid_size = div(spec.image_size, spec.patch_size) + input_resolution = div(grid_size, 2 ** layer_idx) + dim = spec.embed_dim * 2 ** layer_idx + + {hidden_state, attention} = + layer(hidden_state, dim, layer_idx, spec, opts[:name]) + + hidden_state = + if layer_idx < length(spec.depths) - 1 do + downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon) + else + hidden_state + end { hidden_state, Layers.append(hidden_states, hidden_state), Layers.append(attentions, attention) } - end - ) + end end - defp stage(hidden_state, spec, layer_idx, opts) do - grid_size = div(spec.image_size, spec.patch_size) - input_resolution = div(grid_size, 2 ** layer_idx) - num_attention_heads = Enum.at(spec.num_heads, layer_idx) - dim = spec.embed_dim * 2 ** layer_idx + defp layer(hidden_state, dim, layer_idx, spec, name) do + attention = + hidden_state + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) + |> reshape(layer_idx) + |> hidden_windows(layer_idx, spec) + |> attention_window(layer_idx, dim, spec, name) + |> Axon.dropout(rate: spec.dropout_rate) - {hidden_state, attention, cross_attention, block_cache, position_bias} = - layer(hidden_state, num_attention_heads, dim, spec, opts) + output = + Axon.add(hidden_state, hidden_state) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) + |> Axon.dense(round(spec.intermediate_size_ratio * dim)) + |> Layers.activation(spec.activation) + |> Axon.dense(round(spec.intermediate_size_ratio * dim)) + |> Axon.dropout(rate: spec.dropout_rate) - hidden_state = - if layer_idx < length(spec.depths) - 1 do - downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon) - else - hidden_state - end + hidden_state = Axon.add(hidden_state, output) - {hidden_state, attention, cross_attention, block_cache, position_bias} + {hidden_state, attention} end - # Steps in Python implementation: - # Normalization - # if shift_size > 0 -> roll hidden states - # window partition - # attention with attention mask - # window reverse - # if shift_size > 0 -> roll shifted windows - # shortcut + drop_path(attention_windows) - # Normalization - # Intermediate - # add result of intermediate - defp layer(hidden_state, num_attention_heads, dim, spec, opts) do - name = opts[:name] + defp reshape(input, layer_idx) do + input + |> Axon.nx( + fn x -> + {batch_size, dimension, num_channels} = Nx.shape(x) + height_width = dimension |> :math.sqrt() |> floor() + + x + |> Nx.reshape({batch_size, height_width, height_width, num_channels}) + end, + name: "reshape_#{layer_idx}" + ) + end - # shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) - # depth = Enum.at(spec.depths, layer_idx) + defp hidden_windows(input, layer_idx, spec) do + shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) + + input + |> Axon.nx( + fn x -> + {_batch_size, height, 
width, num_channels} = Nx.shape(x) + + {shift_size, _window_size} = + if min(height, width) <= spec.window_size, + do: {0, min(height, width)}, + else: {shift_size, spec.window_size} + + shiffted_hidden_state = + if shift_size > 0, + do: roll(x, shifts: {-shift_size, -shift_size}, dims: {1, 2}), + else: x + + shiffted_hidden_state + |> window_partition(spec.window_size) + |> Nx.reshape({:auto, spec.window_size * spec.window_size, num_channels}) + end, + name: "hidden_windows_#{layer_idx}" + ) + end - {hidden_state, attention, cross_attention, block_cache, position_bias} = - Layers.Transformer.block(hidden_state, - block_type: :norm_first, - num_attention_heads: num_attention_heads, - hidden_size: dim, - kernel_initializer: kernel_initializer(spec), - dropout_rate: spec.dropout_rate, - attention_dropout_rate: spec.attention_dropout_rate, - layer_norm: [ - epsilon: spec.layer_norm_epsilon - ], - ffn: [ - intermediate_size: round(spec.intermediate_size_ratio * dim), - activation: spec.activation - ], - name: join(name, "block_#{num_attention_heads}") - ) + defp attention_window(input, layer_idx, dim, spec, name) do + num_attention_heads = Enum.at(spec.num_heads, layer_idx) + shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) + + input + |> Axon.nx( + fn x -> + {batch_size, dimension, num_channels} = Nx.shape(x) + height_width = dimension |> :math.sqrt() |> floor() + + {shift_size, window_size} = + if height_width <= spec.window_size, + do: {0, height_width}, + else: {shift_size, spec.window_size} + + attn_mask = attention_mask(height_width, height_width, window_size, shift_size) + + {_hidden_state, attention, _self_attention_cache, _attention_relative_bias} = + Bumblebee.Layers.Transformer.multi_head_attention( + input, + input, + input, + attention_mask: attn_mask, + num_heads: num_attention_heads, + hidden_size: dim, + kernel_initializer: kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + name: join(name, "self_attention_#{layer_idx}") + ) - {hidden_state, attention, cross_attention, block_cache, position_bias} + attention = Axon.dropout(attention, rate: spec.dropout_rate) + + shifted_windows = + attention + |> Axon.reshape({:auto, window_size, window_size, num_channels}) + |> window_reverse(window_size) + + att_window = + if shift_size > 0 do + roll(shifted_windows, shifts: {shift_size, shift_size}, dims: {1, 2}) + |> Nx.reshape({batch_size, height_width * height_width, num_channels}) + else + shifted_windows + |> Nx.reshape({batch_size, height_width * height_width, num_channels}) + end + + att_window + end, + name: "attention_windows_#{layer_idx}" + ) end def attention_mask(height, width, window_size, shift_size) do @@ -329,7 +397,7 @@ defmodule Bumblebee.Vision.Swin do (width - shift_size)..(width - 1) ] - {img_mask, count} = + {img_mask, _count} = for hrange <- hslices, wrange <- wslices, reduce: {img_mask, 0.0} do {mask, count} -> mask = @@ -353,7 +421,7 @@ defmodule Bumblebee.Vision.Swin do |> Nx.logical_not() |> Nx.select(-100.0, 0) else - nil + Layers.none() end end @@ -367,7 +435,6 @@ defmodule Bumblebee.Vision.Swin do windowed_height = div(height, window_size) windowed_width = div(width, window_size) - # TODO: Check last reshape (in Python - view(-1, window_size, window_size, num_channels)) Nx.reshape( tensor, {batch_size, windowed_height, window_size, windowed_width, window_size, num_channels} @@ -376,6 +443,24 @@ defmodule Bumblebee.Vision.Swin do |> Nx.reshape({:auto, window_size, window_size, num_channels}) end + defp 
window_reverse(%Axon{} = input_feature, window_size) do + input_feature + |> Axon.nx(fn x -> window_reverse(x, window_size) end) + end + + defp window_reverse(%Nx.Tensor{} = tensor, window_size) do + {_batch_size, height, width, num_channels} = Nx.shape(tensor) + windowed_height = div(height, window_size) + windowed_width = div(width, window_size) + + Nx.reshape( + tensor, + {:auto, windowed_height, windowed_width, window_size, window_size, num_channels} + ) + |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5]) + |> Nx.reshape({:auto, height, width, num_channels}) + end + defp downsample(hidden_state, input_resolution, dim, norm_epsilon) do Axon.nx(hidden_state, fn x -> {batch_size, _dim, num_channels} = Nx.shape(x) @@ -471,7 +556,11 @@ defmodule Bumblebee.Vision.Swin do end end - defp roll(%Nx.Tensor{} = x, opts \\ []) do + defp roll(%Axon{} = x, opts) do + Axon.nx(x, fn y -> roll(y, opts) end) + end + + defp roll(%Nx.Tensor{} = x, opts) do opts = Keyword.validate!(opts, shifts: [], axes: []) shifts = opts[:shifts] axes = opts[:axes] From a0326f32ade9b7d4a18503311355d7b0710565c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Thu, 20 Jun 2024 15:51:27 +0200 Subject: [PATCH 04/15] fix: attention mask is expected to be boolean mask --- lib/bumblebee/vision/swin.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index 164d000e..ebed5022 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -419,7 +419,6 @@ defmodule Bumblebee.Vision.Swin do |> Nx.subtract(Nx.new_axis(mask_windows, 2)) |> Nx.equal(0) |> Nx.logical_not() - |> Nx.select(-100.0, 0) else Layers.none() end From 132bc60e71c1a6d75aac0ace9bc47c6ed7fe1a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Fri, 21 Jun 2024 10:42:31 +0200 Subject: [PATCH 05/15] feat: splitting layers as suggested by Jonatan (still not able to load model) --- lib/bumblebee/vision/swin.ex | 97 ++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 37 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index ebed5022..f53424b1 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -267,16 +267,19 @@ defmodule Bumblebee.Vision.Swin do end defp layer(hidden_state, dim, layer_idx, spec, name) do + attn_mask = attention_mask_layer(hidden_state, layer_idx, spec) + attention = hidden_state |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) |> reshape(layer_idx) |> hidden_windows(layer_idx, spec) - |> attention_window(layer_idx, dim, spec, name) + |> attention_window(attn_mask, layer_idx, dim, spec, name) + |> unroll(layer_idx, spec) |> Axon.dropout(rate: spec.dropout_rate) output = - Axon.add(hidden_state, hidden_state) + Axon.add(hidden_state, attention) |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) |> Axon.dense(round(spec.intermediate_size_ratio * dim)) |> Layers.activation(spec.activation) @@ -328,14 +331,33 @@ defmodule Bumblebee.Vision.Swin do ) end - defp attention_window(input, layer_idx, dim, spec, name) do + defp attention_window(input, attention_mask, layer_idx, dim, spec, name) do num_attention_heads = Enum.at(spec.num_heads, layer_idx) + + {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = + Bumblebee.Layers.Transformer.multi_head_attention( + input, + input, + input, + attention_mask: attention_mask, + num_heads: num_attention_heads, + hidden_size: dim, + kernel_initializer: 
kernel_initializer(spec), + dropout_rate: spec.dropout_rate, + name: join(name, "self_attention_#{layer_idx}") + ) + + attention = Axon.dropout(attention, rate: spec.dropout_rate) + + {hidden_state, attention} + end + + def unroll({hidden_state, attention}, layer_idx, spec) do shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) - input - |> Axon.nx( - fn x -> - {batch_size, dimension, num_channels} = Nx.shape(x) + Axon.layer( + fn input, att, _ -> + {batch_size, dimension, num_channels} = Nx.shape(input) height_width = dimension |> :math.sqrt() |> floor() {shift_size, window_size} = @@ -343,40 +365,41 @@ defmodule Bumblebee.Vision.Swin do do: {0, height_width}, else: {shift_size, spec.window_size} - attn_mask = attention_mask(height_width, height_width, window_size, shift_size) - - {_hidden_state, attention, _self_attention_cache, _attention_relative_bias} = - Bumblebee.Layers.Transformer.multi_head_attention( - input, - input, - input, - attention_mask: attn_mask, - num_heads: num_attention_heads, - hidden_size: dim, - kernel_initializer: kernel_initializer(spec), - dropout_rate: spec.dropout_rate, - name: join(name, "self_attention_#{layer_idx}") - ) - - attention = Axon.dropout(attention, rate: spec.dropout_rate) - shifted_windows = - attention - |> Axon.reshape({:auto, window_size, window_size, num_channels}) + att + |> Nx.reshape({:auto, window_size, window_size, num_channels}) |> window_reverse(window_size) - att_window = - if shift_size > 0 do - roll(shifted_windows, shifts: {shift_size, shift_size}, dims: {1, 2}) - |> Nx.reshape({batch_size, height_width * height_width, num_channels}) - else - shifted_windows - |> Nx.reshape({batch_size, height_width * height_width, num_channels}) - end + if shift_size > 0 do + roll(shifted_windows, shifts: {shift_size, shift_size}, dims: {1, 2}) + |> Nx.reshape({batch_size, height_width * height_width, num_channels}) + else + shifted_windows + |> Nx.reshape({batch_size, height_width * height_width, num_channels}) + end + end, + [hidden_state, attention], + name: "unroll_#{layer_idx}" + ) + end + + defp attention_mask_layer(hidden_state, layer_idx, spec) do + shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) + + hidden_state + |> Axon.nx( + fn x -> + {_batch_size, dimension, _num_channels} = Nx.shape(x) + height_width = dimension |> :math.sqrt() |> floor() + + {shift_size, window_size} = + if height_width <= spec.window_size, + do: {0, height_width}, + else: {shift_size, spec.window_size} - att_window + attention_mask(height_width, height_width, window_size, shift_size) end, - name: "attention_windows_#{layer_idx}" + name: "att_mask_#{layer_idx}" ) end @@ -420,7 +443,7 @@ defmodule Bumblebee.Vision.Swin do |> Nx.equal(0) |> Nx.logical_not() else - Layers.none() + %Axon.None{} end end From 6ba22ae7c3f549835201fc52b66e9da7f5df2fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Thu, 27 Jun 2024 14:32:24 +0200 Subject: [PATCH 06/15] fix: attention results handling, roll function call and dense calculation --- lib/bumblebee/vision/swin.ex | 40 +++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index f53424b1..092e1980 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -267,23 +267,27 @@ defmodule Bumblebee.Vision.Swin do end defp layer(hidden_state, dim, layer_idx, spec, name) do + shortcut = hidden_state attn_mask = 
attention_mask_layer(hidden_state, layer_idx, spec) - attention = + {hidden_state, attention} = hidden_state |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) |> reshape(layer_idx) |> hidden_windows(layer_idx, spec) - |> attention_window(attn_mask, layer_idx, dim, spec, name) + |> attention(attn_mask, layer_idx, dim, spec, name) + + hidden_state = + {hidden_state, shortcut} |> unroll(layer_idx, spec) |> Axon.dropout(rate: spec.dropout_rate) output = - Axon.add(hidden_state, attention) + Axon.add(shortcut, hidden_state) |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) - |> Axon.dense(round(spec.intermediate_size_ratio * dim)) + |> Axon.dense(dim) |> Layers.activation(spec.activation) - |> Axon.dense(round(spec.intermediate_size_ratio * dim)) + |> Axon.dense(dim) |> Axon.dropout(rate: spec.dropout_rate) hidden_state = Axon.add(hidden_state, output) @@ -320,7 +324,7 @@ defmodule Bumblebee.Vision.Swin do shiffted_hidden_state = if shift_size > 0, - do: roll(x, shifts: {-shift_size, -shift_size}, dims: {1, 2}), + do: roll(x, shifts: [-shift_size, -shift_size], axes: [1, 2]), else: x shiffted_hidden_state @@ -331,7 +335,7 @@ defmodule Bumblebee.Vision.Swin do ) end - defp attention_window(input, attention_mask, layer_idx, dim, spec, name) do + defp attention(input, attention_mask, layer_idx, dim, spec, name) do num_attention_heads = Enum.at(spec.num_heads, layer_idx) {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = @@ -343,20 +347,24 @@ defmodule Bumblebee.Vision.Swin do num_heads: num_attention_heads, hidden_size: dim, kernel_initializer: kernel_initializer(spec), - dropout_rate: spec.dropout_rate, + dropout_rate: spec.attention_dropout_rate, name: join(name, "self_attention_#{layer_idx}") ) - attention = Axon.dropout(attention, rate: spec.dropout_rate) + hidden_state = + Axon.dropout(hidden_state, + rate: spec.dropout_rate, + name: join(name, "self_attention_dropout") + ) {hidden_state, attention} end - def unroll({hidden_state, attention}, layer_idx, spec) do + def unroll({hidden_state, input}, layer_idx, spec) do shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) Axon.layer( - fn input, att, _ -> + fn state, input, _ -> {batch_size, dimension, num_channels} = Nx.shape(input) height_width = dimension |> :math.sqrt() |> floor() @@ -366,19 +374,19 @@ defmodule Bumblebee.Vision.Swin do else: {shift_size, spec.window_size} shifted_windows = - att + state |> Nx.reshape({:auto, window_size, window_size, num_channels}) |> window_reverse(window_size) if shift_size > 0 do - roll(shifted_windows, shifts: {shift_size, shift_size}, dims: {1, 2}) + roll(shifted_windows, shifts: [shift_size, shift_size], axes: [1, 2]) |> Nx.reshape({batch_size, height_width * height_width, num_channels}) else shifted_windows |> Nx.reshape({batch_size, height_width * height_width, num_channels}) end end, - [hidden_state, attention], + [hidden_state, input], name: "unroll_#{layer_idx}" ) end @@ -594,10 +602,10 @@ defmodule Bumblebee.Vision.Swin do Enum.zip(shifts, axes) |> Enum.reduce(x, fn {shift, dim}, acc -> - shift = rem(shift, Enum.at(shape, dim)) |> IO.inspect(label: :shift) + shift = rem(shift, Enum.at(shape, dim)) if 0 < shift do - {base, move} = Nx.split(acc, -1 * shift, axis: dim) |> IO.inspect() + {base, move} = Nx.split(acc, -1 * shift, axis: dim) Nx.concatenate([move, base], axis: dim) else acc From 3c8e8b3e3cc6ddb1bc304b8ceb2321a9b5195f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Wed, 24 Jul 2024 14:45:27 
+0200 Subject: [PATCH 07/15] feat: added all layers for each block --- lib/bumblebee/vision/swin.ex | 115 ++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index 092e1980..e075afb9 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -176,7 +176,7 @@ defmodule Bumblebee.Vision.Swin do hidden_state = Axon.layer_norm(hidden_state, epsilon: spec.layer_norm_epsilon, - name: join(name, "norm") + name: join(name, "layernorm") ) pooled_state = @@ -233,6 +233,7 @@ defmodule Bumblebee.Vision.Swin do end defp encoder(hidden_state, spec, opts) do + name = opts[:name] hidden_states = Axon.container({hidden_state}) attentions = Axon.container({}) @@ -242,21 +243,10 @@ defmodule Bumblebee.Vision.Swin do attentions } - for layer_idx <- 0..(length(spec.depths) - 1), reduce: state do + for stage_idx <- 0..(length(spec.depths) - 1), reduce: state do {hidden_state, hidden_states, attentions} -> - grid_size = div(spec.image_size, spec.patch_size) - input_resolution = div(grid_size, 2 ** layer_idx) - dim = spec.embed_dim * 2 ** layer_idx - {hidden_state, attention} = - layer(hidden_state, dim, layer_idx, spec, opts[:name]) - - hidden_state = - if layer_idx < length(spec.depths) - 1 do - downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon) - else - hidden_state - end + stage(hidden_state, stage_idx, spec, join("#{name}.blocks", stage_idx)) { hidden_state, @@ -266,16 +256,42 @@ defmodule Bumblebee.Vision.Swin do end end - defp layer(hidden_state, dim, layer_idx, spec, name) do + defp stage(hidden_state, stage_idx, spec, name) do + grid_size = div(spec.image_size, spec.patch_size) + input_resolution = div(grid_size, 2 ** stage_idx) + dim = spec.embed_dim * 2 ** stage_idx + num_attention_heads = Enum.at(spec.num_heads, stage_idx) + + {hidden_state, attention} = + for layer_idx <- 0..(Enum.at(spec.depths, stage_idx) - 1), reduce: {hidden_state, nil} do + {hidden_state, _} -> + {hidden_state, attention} = + layer(hidden_state, layer_idx, dim, num_attention_heads, spec, name) + + {hidden_state, attention} + end + + hidden_state = + if stage_idx < length(spec.depths) - 1 do + downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon, name) + else + hidden_state + end + + {hidden_state, attention} + end + + defp layer(hidden_state, layer_idx, dim, num_attention_heads, spec, name) do shortcut = hidden_state attn_mask = attention_mask_layer(hidden_state, layer_idx, spec) + name = join(name, "layer.#{layer_idx}") {hidden_state, attention} = hidden_state - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_before")) |> reshape(layer_idx) |> hidden_windows(layer_idx, spec) - |> attention(attn_mask, layer_idx, dim, spec, name) + |> attention(attn_mask, num_attention_heads, dim, spec, name) hidden_state = {hidden_state, shortcut} @@ -284,10 +300,10 @@ defmodule Bumblebee.Vision.Swin do output = Axon.add(shortcut, hidden_state) - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) - |> Axon.dense(dim) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_after")) + |> Axon.dense(dim, name: join(name, "dense")) |> Layers.activation(spec.activation) - |> Axon.dense(dim) + |> Axon.dense(dim, name: join(name, "dense")) |> Axon.dropout(rate: spec.dropout_rate) hidden_state = Axon.add(hidden_state, output) @@ -335,9 +351,7 @@ defmodule 
Bumblebee.Vision.Swin do ) end - defp attention(input, attention_mask, layer_idx, dim, spec, name) do - num_attention_heads = Enum.at(spec.num_heads, layer_idx) - + defp attention(input, attention_mask, num_attention_heads, dim, spec, name) do {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = Bumblebee.Layers.Transformer.multi_head_attention( input, @@ -348,7 +362,7 @@ defmodule Bumblebee.Vision.Swin do hidden_size: dim, kernel_initializer: kernel_initializer(spec), dropout_rate: spec.attention_dropout_rate, - name: join(name, "self_attention_#{layer_idx}") + name: join(name, "self_attention") ) hidden_state = @@ -491,7 +505,7 @@ defmodule Bumblebee.Vision.Swin do |> Nx.reshape({:auto, height, width, num_channels}) end - defp downsample(hidden_state, input_resolution, dim, norm_epsilon) do + defp downsample(hidden_state, input_resolution, dim, norm_epsilon, name) do Axon.nx(hidden_state, fn x -> {batch_size, _dim, num_channels} = Nx.shape(x) @@ -507,10 +521,10 @@ defmodule Bumblebee.Vision.Swin do ) |> Nx.reshape({batch_size, :auto, 4 * num_channels}) end) - |> Axon.layer_norm(epsilon: norm_epsilon, name: "downsample_norm") + |> Axon.layer_norm(epsilon: norm_epsilon, name: join(name, "downsample_norm")) |> Axon.dense(2 * dim, kernel_initializer: Axon.Initializers.uniform(), - name: "image_classification_head.output" + name: join(name, "downsample_reduction") ) end @@ -550,38 +564,25 @@ defmodule Bumblebee.Vision.Swin do defimpl Bumblebee.HuggingFace.Transformers.Model do def params_mapping(_spec) do %{ + "layernorm" => "swin.layernorm", + "encoder.blocks.{n}.layer.{m}.self_attention.output" => + "swin.encoder.layers.{n}.blocks.{m}.attention.output.dense", + "encoder.blocks.{n}.layer.{m}.self_attention.value" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self.value", + "encoder.blocks.{n}.layer.{m}.self_attention.query" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self.query", + "encoder.blocks.{n}.layer.{m}.self_attention.key" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self.key", + "encoder.blocks.{n}.layer.{m}.layernorm_before" => + "swin.encoder.layers.{n}.blocks.{m}.layernorm_before", + "encoder.blocks.{n}.layer.{m}.layernorm_after" => + "swin.encoder.layers.{n}.blocks.{m}.layernorm_after", "embedder.patch_embedding.projection" => "swin.embeddings.patch_embeddings.projection", - "embedder.class_embedding" => %{ - "embeddings" => { - [{"swin.embeddings", "cls_token"}], - fn [value] -> Nx.squeeze(value, axes: [0]) end - } - }, - "embedder.position_embedding" => %{ - "embeddings" => { - [{"swin.embeddings", "position_embeddings"}], - fn [value] -> Nx.squeeze(value, axes: [0]) end - } - }, - "encoder.block_{n}.self_attention_norm" => "swin.encoder.layer.{n}.layernorm_before", - "encoder.block_{n}.self_attention.key" => - "swin.encoder.layer.{n}.attention.attention.key", - "encoder.block_{n}.self_attention.query" => - "swin.encoder.layer.{n}.attention.attention.query", - "encoder.block_{n}.self_attention.value" => - "swin.encoder.layer.{n}.attention.attention.value", - "encoder.block_{n}.self_attention.output" => - "swin.encoder.layer.{n}.attention.output.dense", - "encoder.block_{n}.ffn.intermediate" => "swin.encoder.layer.{n}.intermediate.dense", - "encoder.block_{n}.ffn.output" => "swin.encoder.layer.{n}.output.dense", - "encoder.block_{n}.output_norm" => "swin.encoder.layer.{n}.layernorm_after", - "norm" => "swin.layernorm", - "pooler.output" => "swin.pooler.dense", + "encoder.blocks.{n}.downsample_norm" => 
"swin.encoder.layers.{n}.downsample.norm", + "encoder.blocks.{n}.downsample_reduction" => + "swin.encoder.layers.{n}.downsample.reduction", "image_classification_head.output" => "classifier", - "masked_image_modeling_head.output" => "decoder.0", - "layer_norm_{n}" => "swin.encoder.layers.{n}.blocks.{n}.layernorm", - "layer_{n}_downsample_norm" => "swin.encoder.layers.{n}.downsample.norm", - "downsample_norm" => "swin.encoder.downsample.norm" + "encoder.blocks.{n}.layer.{m}.dense" => "swin.encoder.layers.{n}.blocks.{m}.output.dense" } end end From ed36ec0913a343b7c1595db007b4473cda37d0d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Wed, 31 Jul 2024 11:18:14 +0200 Subject: [PATCH 08/15] feat: pad if needed and implemented relative position index function --- lib/bumblebee/vision/swin.ex | 46 +++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index e075afb9..d588f165 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -290,7 +290,7 @@ defmodule Bumblebee.Vision.Swin do hidden_state |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_before")) |> reshape(layer_idx) - |> hidden_windows(layer_idx, spec) + |> hidden_state_windows(layer_idx, spec) |> attention(attn_mask, num_attention_heads, dim, spec, name) hidden_state = @@ -325,7 +325,14 @@ defmodule Bumblebee.Vision.Swin do ) end - defp hidden_windows(input, layer_idx, spec) do + defp maybe_pad(input, window_size, height, width) do + pad1 = (window_size - rem(width, window_size)) |> rem(window_size) + pad2 = (window_size - rem(height, window_size)) |> rem(window_size) + + Nx.pad(input, 0, [{0, 0, 0}, {0, pad1, 0}, {0, pad2, 0}, {0, 0, 0}]) + end + + defp hidden_state_windows(input, layer_idx, spec) do shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) input @@ -333,22 +340,54 @@ defmodule Bumblebee.Vision.Swin do fn x -> {_batch_size, height, width, num_channels} = Nx.shape(x) + x = maybe_pad(x, spec.window_size, height, width) + + {_, height, width, _} = Nx.shape(x) + {shift_size, _window_size} = if min(height, width) <= spec.window_size, do: {0, min(height, width)}, else: {shift_size, spec.window_size} + # cyclic shift shiffted_hidden_state = if shift_size > 0, do: roll(x, shifts: [-shift_size, -shift_size], axes: [1, 2]), else: x + # partition windows shiffted_hidden_state |> window_partition(spec.window_size) |> Nx.reshape({:auto, spec.window_size * spec.window_size, num_channels}) end, - name: "hidden_windows_#{layer_idx}" + name: "hidden_state_windows_#{layer_idx}" + ) + end + + defp relative_position_index(window_size) do + coords_h = Nx.iota({window_size}) |> Nx.tile([window_size, 1]) |> Nx.transpose() + coords_w = Nx.iota({window_size}) |> Nx.tile([window_size, 1]) + + coords_flatten = + Nx.stack([coords_h, coords_w]) + # flatten dimension 1 + |> Nx.reshape({2, window_size * window_size}) + + relative_coords = + Nx.subtract(Nx.new_axis(coords_flatten, 2), Nx.new_axis(coords_flatten, 1)) + |> Nx.transpose(axes: [1, 2, 0]) + + relative_coords = + Nx.add( + relative_coords, + Nx.broadcast(Nx.tensor([window_size - 1, window_size - 1]), relative_coords) + ) + + Nx.multiply( + relative_coords, + Nx.broadcast(Nx.tensor([2 * window_size - 1, 1]), relative_coords) ) + |> Nx.sum(axes: [-1]) end defp attention(input, attention_mask, num_attention_heads, dim, spec, name) do @@ -375,6 +414,7 @@ defmodule 
Bumblebee.Vision.Swin do end def unroll({hidden_state, input}, layer_idx, spec) do + # reverse cyclic shift shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) Axon.layer( From 2666d6cba5bcbf795f1fb7960dd7a0bae59c2e50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Mon, 12 Aug 2024 08:54:52 +0200 Subject: [PATCH 09/15] feat: we are calculating attention due to Swin specifics --- lib/bumblebee/vision/swin.ex | 228 +++++++++++++++++++++++++++++++---- 1 file changed, 205 insertions(+), 23 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index d588f165..8f73b9c1 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -196,7 +196,7 @@ defmodule Bumblebee.Vision.Swin do embeddings = pixel_values |> patch_embedding(spec, name: join(name, "patch_embedding")) - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: "layernorm_embeddings") embeddings = if spec.use_absolute_embeddings do @@ -291,8 +291,11 @@ defmodule Bumblebee.Vision.Swin do |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_before")) |> reshape(layer_idx) |> hidden_state_windows(layer_idx, spec) - |> attention(attn_mask, num_attention_heads, dim, spec, name) + # |> attention_bumblebee(attn_mask, num_attention_heads, dim, spec, name) + |> attention(attn_mask, Layers.none(), num_attention_heads, dim, spec, name) + # TODO "unpad" if it was padded + # After unroll we have to reverse padding (before dropout) hidden_state = {hidden_state, shortcut} |> unroll(layer_idx, spec) @@ -301,9 +304,11 @@ defmodule Bumblebee.Vision.Swin do output = Axon.add(shortcut, hidden_state) |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_after")) - |> Axon.dense(dim, name: join(name, "dense")) + |> Axon.dense(round(spec.intermediate_size_ratio * dim), + name: join(name, "intermediate.dense") + ) |> Layers.activation(spec.activation) - |> Axon.dense(dim, name: join(name, "dense")) + |> Axon.dense(dim, name: join(name, "output.dense")) |> Axon.dropout(rate: spec.dropout_rate) hidden_state = Axon.add(hidden_state, output) @@ -390,7 +395,182 @@ defmodule Bumblebee.Vision.Swin do |> Nx.sum(axes: [-1]) end - defp attention(input, attention_mask, num_attention_heads, dim, spec, name) do + defp transpose_for_scores(x, num_attention_heads, attention_head_size) do + new_shape = + x + |> Nx.shape() + |> Tuple.to_list() + |> List.replace_at(-1, [num_attention_heads, attention_head_size]) + |> List.flatten() + |> List.to_tuple() + + x + |> Nx.reshape(new_shape) + |> Nx.transpose(axes: [0, 2, 1, 3]) + end + + defp attention(hidden_state, attention_mask, head_mask, num_attention_heads, dim, spec, name) do + attention_head_size = floor(dim / num_attention_heads) + all_head_size = num_attention_heads * attention_head_size + name = join(name, "self_attention") + + query = + hidden_state + |> Axon.dense(all_head_size, name: join(name, "query")) + |> Axon.nx(fn x -> + transpose_for_scores(x, num_attention_heads, attention_head_size) + end) + + key = + hidden_state + |> Axon.dense(all_head_size, name: join(name, "key")) + |> Axon.nx(fn x -> + transpose_for_scores(x, num_attention_heads, attention_head_size) + end) + + value = + hidden_state + |> Axon.dense(all_head_size, name: join(name, "value")) + |> Axon.nx(fn x -> + transpose_for_scores(x, num_attention_heads, attention_head_size) + end) + + relative_position_bias_table = + 
Axon.param( + "relative_position_bias_table", + {(2 * spec.window_size - 1) * (2 * spec.window_size - 1), num_attention_heads} + ) + + probabilities = + Axon.layer( + &attention_weights_impl/7, + [ + hidden_state, + query, + key, + relative_position_bias_table, + Axon.optional(attention_mask), + Axon.optional(head_mask) + ], + name: join(name, "weights"), + window_size: spec.window_size, + num_heads: num_attention_heads, + head_size: attention_head_size, + dropout_rate: spec.attention_dropout_rate + ) + + output_name = join(name, "output") + + context = + Axon.layer( + &attention_output_impl/3, + [probabilities, value], + name: output_name, + num_heads: num_attention_heads, + head_size: attention_head_size + ) + |> Axon.dense(dim, name: join(output_name, "dense")) + |> Axon.dropout( + rate: spec.attention_dropout_rate, + name: join(output_name, "dropout") + ) + + {context, probabilities} + end + + defp attention_weights_impl( + hidden_state, + query, + key, + relative_position_bias_table, + attention_mask, + head_mask, + opts + ) do + opts = + Keyword.validate!(opts, [:mode, :name, :window_size, :num_heads, :head_size, :dropout_rate]) + + {batch_size, dim, _num_channels} = Nx.shape(hidden_state) + + scores = + query + |> Nx.dot([3], [0, 1], Nx.transpose(key, axes: [0, 1, -1, -2]), [2], [0, 1]) + |> Nx.divide(Nx.sqrt(opts[:head_size])) + + rel_pos_idx = relative_position_index(opts[:window_size]) |> Nx.reshape({:auto}) + + relative_position_bias = + Nx.take(relative_position_bias_table, rel_pos_idx) + |> Nx.reshape( + {opts[:window_size] * opts[:window_size], opts[:window_size] * opts[:window_size], :auto} + ) + |> Nx.transpose(axes: [2, 0, 1]) + |> Nx.new_axis(0) + + scores = Nx.add(scores, relative_position_bias) + + scores = + case attention_mask do + %Axon.None{} -> + scores + + _ -> + {mask_size, _, _} = Nx.shape(attention_mask) + + scores = + Nx.reshape( + scores, + {floor(batch_size / mask_size), mask_size, opts[:num_heads], dim, dim} + ) + + attention_mask = + attention_mask + |> Nx.new_axis(1) + |> Nx.new_axis(0) + + scores + |> Nx.add(attention_mask) + |> Nx.reshape({:auto, opts[:num_heads], dim, dim}) + end + + # Normalize the attention scores to probabilities (softmax). + # + # This is actually dropping out entire tokens to attend to, which + # might seem a bit unusual, but is taken from the original + # Transformer paper (dropout). 
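+    #
+    # Note: with the default :attention_dropout_rate of 0.0 this dropout is
+    # a no-op, so the time-based seed below only matters when a non-zero
+    # rate is configured.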
+ seed = :erlang.system_time() + + probabilities = + Axon.Activations.softmax(scores, axis: -1) + |> Axon.Layers.dropout(Nx.Random.key(seed), rate: opts[:dropout_rate]) + + case head_mask do + %Axon.None{} -> + probabilities + + head_mask -> + Nx.multiply(probabilities, head_mask) + end + end + + def attention_output_impl(weights, value, opts) do + context = + weights + |> Nx.dot([3], [0, 1], value, [2], [0, 1]) + |> Nx.transpose(axes: [0, 2, 1, 3]) + + new_context_shape = + context + |> Nx.shape() + |> Tuple.to_list() + |> Enum.slice(0..-3//1) + |> Kernel.++([opts[:num_heads] * opts[:head_size]]) + |> List.to_tuple() + + Nx.reshape(context, new_context_shape) + end + + defp attention_bumblebee(input, attention_mask, num_attention_heads, dim, spec, name) do {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = Bumblebee.Layers.Transformer.multi_head_attention( input, @@ -509,11 +689,6 @@ defmodule Bumblebee.Vision.Swin do end end - defp window_partition(%Axon{} = input_feature, window_size) do - input_feature - |> Axon.nx(fn x -> window_partition(x, window_size) end) - end - defp window_partition(%Nx.Tensor{} = tensor, window_size) do {batch_size, height, width, num_channels} = Nx.shape(tensor) windowed_height = div(height, window_size) @@ -564,7 +739,8 @@ defmodule Bumblebee.Vision.Swin do |> Axon.layer_norm(epsilon: norm_epsilon, name: join(name, "downsample_norm")) |> Axon.dense(2 * dim, kernel_initializer: Axon.Initializers.uniform(), - name: join(name, "downsample_reduction") + name: join(name, "downsample_reduction"), + use_bias: false ) end @@ -604,25 +780,31 @@ defmodule Bumblebee.Vision.Swin do defimpl Bumblebee.HuggingFace.Transformers.Model do def params_mapping(_spec) do %{ - "layernorm" => "swin.layernorm", - "encoder.blocks.{n}.layer.{m}.self_attention.output" => + "encoder.blocks.{n}.layer.{m}.intermediate.dense" => + "swin.encoder.layers.{n}.blocks.{m}.intermediate.dense", + "encoder.blocks.{n}.layer.{m}.layernorm_after" => + "swin.encoder.layers.{n}.blocks.{m}.layernorm_after", + "encoder.blocks.{n}.layer.{m}.layernorm_before" => + "swin.encoder.layers.{n}.blocks.{m}.layernorm_before", + "encoder.blocks.{n}.layer.{m}.output.dense" => + "swin.encoder.layers.{n}.blocks.{m}.output.dense", + "encoder.blocks.{n}.layer.{m}.self_attention.key" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self.key", + "encoder.blocks.{n}.layer.{m}.self_attention.output.dense" => "swin.encoder.layers.{n}.blocks.{m}.attention.output.dense", - "encoder.blocks.{n}.layer.{m}.self_attention.value" => - "swin.encoder.layers.{n}.blocks.{m}.attention.self.value", "encoder.blocks.{n}.layer.{m}.self_attention.query" => "swin.encoder.layers.{n}.blocks.{m}.attention.self.query", - "encoder.blocks.{n}.layer.{m}.self_attention.key" => - "swin.encoder.layers.{n}.blocks.{m}.attention.self.key", - "encoder.blocks.{n}.layer.{m}.layernorm_before" => - "swin.encoder.layers.{n}.blocks.{m}.layernorm_before", - "encoder.blocks.{n}.layer.{m}.layernorm_after" => - "swin.encoder.layers.{n}.blocks.{m}.layernorm_after", + "encoder.blocks.{n}.layer.{m}.self_attention.value" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self.value", + "encoder.blocks.{n}.layer.{m}.self_attention.weights" => + "swin.encoder.layers.{n}.blocks.{m}.attention.self", + "layernorm" => "swin.layernorm", + "layernorm_embeddings" => "swin.embeddings.norm", "embedder.patch_embedding.projection" => "swin.embeddings.patch_embeddings.projection", "encoder.blocks.{n}.downsample_norm" => 
"swin.encoder.layers.{n}.downsample.norm", "encoder.blocks.{n}.downsample_reduction" => "swin.encoder.layers.{n}.downsample.reduction", - "image_classification_head.output" => "classifier", - "encoder.blocks.{n}.layer.{m}.dense" => "swin.encoder.layers.{n}.blocks.{m}.output.dense" + "image_classification_head.output" => "classifier" } end end From 87e9f92817ef6531b48d25cb65ada3fc392d6d61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Fri, 16 Aug 2024 08:25:24 +0200 Subject: [PATCH 10/15] feat: removed Bumblebee attention usage --- lib/bumblebee/vision/swin.ex | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index 8f73b9c1..cc72e4ec 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -291,7 +291,6 @@ defmodule Bumblebee.Vision.Swin do |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_before")) |> reshape(layer_idx) |> hidden_state_windows(layer_idx, spec) - # |> attention_bumblebee(attn_mask, num_attention_heads, dim, spec, name) |> attention(attn_mask, Layers.none(), num_attention_heads, dim, spec, name) # TODO "unpad" if it was padded @@ -570,29 +569,6 @@ defmodule Bumblebee.Vision.Swin do Nx.reshape(context, new_context_shape) end - defp attention_bumblebee(input, attention_mask, num_attention_heads, dim, spec, name) do - {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = - Bumblebee.Layers.Transformer.multi_head_attention( - input, - input, - input, - attention_mask: attention_mask, - num_heads: num_attention_heads, - hidden_size: dim, - kernel_initializer: kernel_initializer(spec), - dropout_rate: spec.attention_dropout_rate, - name: join(name, "self_attention") - ) - - hidden_state = - Axon.dropout(hidden_state, - rate: spec.dropout_rate, - name: join(name, "self_attention_dropout") - ) - - {hidden_state, attention} - end - def unroll({hidden_state, input}, layer_idx, spec) do # reverse cyclic shift shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) From 42a5a64e3aadcaad389aae3e3cbfd5ceeaf7ba17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Mon, 19 Aug 2024 11:14:50 +0200 Subject: [PATCH 11/15] fix: calculating layer output --- lib/bumblebee/vision/swin.ex | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index cc72e4ec..f51029e5 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -295,13 +295,15 @@ defmodule Bumblebee.Vision.Swin do # TODO "unpad" if it was padded # After unroll we have to reverse padding (before dropout) - hidden_state = + attention_windows = {hidden_state, shortcut} |> unroll(layer_idx, spec) |> Axon.dropout(rate: spec.dropout_rate) + hidden_state = Axon.add(shortcut, attention_windows) + output = - Axon.add(shortcut, hidden_state) + hidden_state |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_after")) |> Axon.dense(round(spec.intermediate_size_ratio * dim), name: join(name, "intermediate.dense") From 598431840048cab0eed65bad9586cef62663d16c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bo=C5=A1ko=20Ivani=C5=A1evi=C4=87?= Date: Mon, 19 Aug 2024 13:08:30 +0200 Subject: [PATCH 12/15] feat: added Swin test Swin returns 1000 probabilities ann test takes into account only first 10 values. 
--- test/bumblebee/vision/swin_test.exs | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 test/bumblebee/vision/swin_test.exs diff --git a/test/bumblebee/vision/swin_test.exs b/test/bumblebee/vision/swin_test.exs new file mode 100644 index 00000000..88c127f0 --- /dev/null +++ b/test/bumblebee/vision/swin_test.exs @@ -0,0 +1,42 @@ +defmodule Bumblebee.Vision.SwinTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":for_image_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "microsoft/swin-base-patch4-window12-384"}) + + assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 384, 384, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 1000} + + compare = outputs.logits[[0, 0..9]] + + assert_all_close( + compare, + Nx.tensor([ + [ + 6.9526e-02, + 8.5011e-01, + 4.5132e-01, + 5.4306e-01, + 2.4646e-01, + -2.2765e-03, + 6.9874e-02, + 1.3368e-01, + 4.6875e-01, + 8.8567e-01 + ] + ]) + ) + end +end From 12de8adecac8253f8fc7d4e52681b01b9d0f2ec7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Fri, 23 Aug 2024 17:32:33 +0900 Subject: [PATCH 13/15] Add tests --- lib/bumblebee.ex | 1 + lib/bumblebee/vision/swin.ex | 7 ++-- test/bumblebee/vision/swin_test.exs | 51 +++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 36474c9d..7935c4ee 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -186,6 +186,7 @@ defmodule Bumblebee do "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, "RobertaModel" => {Bumblebee.Text.Roberta, :base}, + "SwinModel" => {Bumblebee.Vision.Swin, :base}, "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification}, "T5Model" => {Bumblebee.Text.T5, :base}, "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation}, diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index f51029e5..151be6f2 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -83,7 +83,10 @@ defmodule Bumblebee.Vision.Swin do ## Architectures - * `:for_image_classification` - Swin tranformer model for image classification. 
+    * `:base` - plain Swin without any head on top
+
+    * `:for_image_classification` - Swin transformer model with a
+      classification head

   ## Global layer options

@@ -108,7 +111,7 @@ defmodule Bumblebee.Vision.Swin do
   alias Bumblebee.Layers

   @impl true
-  def architectures(), do: [:for_image_classification]
+  def architectures(), do: [:base, :for_image_classification]

   @impl true
   def config(spec, opts) do
diff --git a/test/bumblebee/vision/swin_test.exs b/test/bumblebee/vision/swin_test.exs
index 88c127f0..732c0b4b 100644
--- a/test/bumblebee/vision/swin_test.exs
+++ b/test/bumblebee/vision/swin_test.exs
@@ -5,7 +5,58 @@ defmodule Bumblebee.Vision.SwinTest do

   @moduletag model_test_tags()

+  test ":base" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-SwinModel"})
+
+    assert %Bumblebee.Vision.Swin{architecture: :base} = spec
+
+    inputs = %{
+      "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3})
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 16, 64}
+    assert Nx.shape(outputs.pooled_state) == {1, 64}
+
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[-0.4605, 0.9336, -0.5528], [-0.4449, 0.8927, -0.5424], [-0.5024, 0.2263, 0.2208]]
+      ])
+    )
+
+    assert_all_close(
+      outputs.pooled_state[[.., 1..3]],
+      Nx.tensor([[-0.5004, 0.4605, -0.4949]])
+    )
+  end
+
   test ":for_image_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "hf-internal-testing/tiny-random-SwinForImageClassification"}
+             )
+
+    assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec
+
+    inputs = %{
+      "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3})
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[0.0834, 0.1265]])
+    )
+  end
+
+  # TODO remove before merging
+  test ":for_image_classification actual" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model({:hf, "microsoft/swin-base-patch4-window12-384"})

From 740a34536fc5d4c955d4ec687ac31b8e48f145d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Tue, 3 Sep 2024 15:38:38 +0700
Subject: [PATCH 14/15] Fixes and refactor

---
 lib/bumblebee/layers.ex             |  12 +-
 lib/bumblebee/utils/nx.ex           |  76 +++
 lib/bumblebee/vision/swin.ex        | 872 ++++++++++++----------------
 test/bumblebee/vision/swin_test.exs |  46 +-
 4 files changed, 446 insertions(+), 560 deletions(-)

diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index cde12666..e6a61794 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -200,7 +200,7 @@ defmodule Bumblebee.Layers do
     * `query` - `{batch_size, sequence_length, num_heads, head_size}`
     * `key` - `{batch_size, kv_sequence_length, num_heads, head_size}`
     * `value` - `{batch_size, kv_sequence_length, num_heads, head_size}`
-    * `key_mask` (optional) - `{batch_size, kv_sequence_length}`
+    * `key_mask` (optional) - `{batch_size, kv_sequence_length} | {batch_size, num_heads, sequence_length, kv_sequence_length}`
     * `head_mask` (optional) - `{num_heads}`
     * `bias` (optional) - `{batch_size | 1, num_heads | 1, sequence_length, kv_sequence_length}`
     * `offset` (optional) - `{}`
@@ -273,8 +273,14 @@ defmodule Bumblebee.Layers do

     key_mask =
       case key_mask do
-        %Axon.None{} -> Nx.broadcast(1, {1, 1, 1, 1})
-        key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
+        %Axon.None{} ->
+          
Nx.broadcast(1, {1, 1, 1, 1})
+
+        key_mask ->
+          case Nx.rank(key_mask) do
+            2 -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
+            4 -> key_mask
+          end
       end

     query_sequence_length = Nx.axis_size(query, 2)
diff --git a/lib/bumblebee/utils/nx.ex b/lib/bumblebee/utils/nx.ex
index 51c6dadd..f9806af6 100644
--- a/lib/bumblebee/utils/nx.ex
+++ b/lib/bumblebee/utils/nx.ex
@@ -408,6 +408,82 @@ defmodule Bumblebee.Utils.Nx do
     Nx.take(tensor, flat_idx, axis: opts[:axis])
   end

+  @doc """
+  Shifts elements along the specified axes.
+
+  When a shift is positive, the elements are shifted clockwise.
+  Negative shifts result in a counter-clockwise shift.
+
+  ## Options
+
+    * `:shifts` - the shift size to apply to the corresponding axis
+      from `:axes`
+
+    * `:axes` - the axes to apply the shifts to, must have the same
+      length as `:shifts`
+
+  ## Examples
+
+      iex> x = Nx.iota({3, 3})
+      iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0])
+      #Nx.Tensor<
+        s64[3][3]
+        [
+          [6, 7, 8],
+          [0, 1, 2],
+          [3, 4, 5]
+        ]
+      >
+
+      iex> x = Nx.iota({3, 3})
+      iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0])
+      #Nx.Tensor<
+        s64[3][3]
+        [
+          [3, 4, 5],
+          [6, 7, 8],
+          [0, 1, 2]
+        ]
+      >
+
+      iex> x = Nx.iota({3, 3})
+      iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1])
+      #Nx.Tensor<
+        s64[3][3]
+        [
+          [7, 8, 6],
+          [1, 2, 0],
+          [4, 5, 3]
+        ]
+      >
+
+  """
+  deftransform roll(tensor, opts) do
+    opts = Keyword.validate!(opts, shifts: [], axes: [])
+
+    shifts = opts[:shifts]
+    axes = opts[:axes]
+
+    if length(shifts) != length(axes) do
+      raise ArgumentError,
+            "expected shifts and axes to have the same number of elements," <>
+              " got shifts: #{inspect(shifts)}, axes: #{inspect(axes)}"
+    end
+
+    shifts
+    |> Enum.zip(axes)
+    |> Enum.reduce(tensor, fn {shift, axis}, tensor ->
+      shift = rem(shift, Nx.axis_size(tensor, axis))
+
+      if shift == 0 do
+        tensor
+      else
+        {left, right} = Nx.split(tensor, -shift, axis: axis)
+        Nx.concatenate([right, left], axis: axis)
+      end
+    end)
+  end
+
   @doc """
   Returns size of the given `Nx.Batch`, including padding.
""" diff --git a/lib/bumblebee/vision/swin.ex b/lib/bumblebee/vision/swin.ex index 151be6f2..42da43ca 100644 --- a/lib/bumblebee/vision/swin.ex +++ b/lib/bumblebee/vision/swin.ex @@ -3,78 +3,74 @@ defmodule Bumblebee.Vision.Swin do options = [ - attention_dropout_rate: [ - default: 0.0, - doc: "the dropout rate for attention weights" + image_size: [ + default: 224, + doc: "the size of the input spatial dimensions" ], - depths: [ - default: [2, 2, 18, 2], - doc: "the depth (number of residual blocks) at each stage" + num_channels: [ + default: 3, + doc: "the number of channels in the input" ], - drop_path_rate: [ - default: 0.1, - doc: "the drop path rate used to for stochastic depth" + patch_size: [ + default: 4, + doc: "the size of the patch spatial dimensions" ], - # Maybe it should be renamed to hidden_size - embed_dim: [ - default: 128, - doc: "" + embedding_size: [ + default: 96, + doc: "the dimensionality of patch embedding layer" ], - activation: [ - default: :gelu, - doc: "the activation function" + use_absolute_position_embeddings: [ + default: false, + doc: "whether to add absolute position embeddings to the patch embeddings" ], - dropout_rate: [ - default: 0.0, - doc: "the dropout rate for encoder and decoder" + num_blocks: [ + default: [2, 2, 6, 2], + doc: "the number of Transformer blocks in the encoder at each stage" ], - image_size: [ - default: 384, - doc: "the size of the input spatial dimensions" + num_attention_heads: [ + default: [3, 6, 12, 24], + doc: "the number of attention heads for each attention layer in the encoder at each stage" ], - initializer_scale: [ - default: 0.02, + window_size: [ + default: 7, doc: - "the standard deviation of the normal initializer used for initializing kernel parameters" - ], - layer_norm_epsilon: [ - default: 1.0e-5, - doc: "the epsilon used by the layer normalization layers" + "the window size, used to limit self-attention computation to non-overlapping windows" ], intermediate_size_ratio: [ default: 4, doc: """ the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder, - expressed as a multiplier of `:hidden_size` + expressed as a multiplier of hidden size (at the given stage) """ ], - num_channels: [ - default: 3, - doc: "the number of channels in the input" + use_attention_bias: [ + default: true, + doc: "whether to use bias in query, key, and value projections" ], - num_heads: [ - default: [4, 8, 16, 32], - doc: "number of attention heads" + activation: [ + default: :gelu, + doc: "the activation function" ], - patch_size: [ - default: 4, - doc: "the size of the patch spatial dimensions" + dropout_rate: [ + default: 0.0, + doc: "the dropout rate for encoder and decoder" ], - path_norm: [ - default: true, - doc: "" + attention_dropout_rate: [ + default: 0.0, + doc: "the dropout rate for attention weights" ], - use_attention_bias: [ - default: true, - doc: "whether to use bias in query, key, and value projections" + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" ], - use_absolute_embeddings: [ - default: false, - doc: "" + drop_path_rate: [ + default: 0.1, + doc: "the drop path rate used to for stochastic depth" ], - window_size: [ - default: 12, - doc: "" + layer_norm_epsilon: [ + default: 1.0e-5, + doc: "the epsilon used by the layer normalization layers" ] ] ++ Shared.common_options([:num_labels, :id_to_label]) @@ -90,15 +86,16 @@ defmodule Bumblebee.Vision.Swin do ## Global layer options - # 
{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} ## Configuration - # {Shared.options_doc(options)} + #{Shared.options_doc(options)} ## References * [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) + """ defstruct [architecture: :base] ++ Shared.option_defaults(options) @@ -107,6 +104,7 @@ defmodule Bumblebee.Vision.Swin do @behaviour Bumblebee.Configurable import Bumblebee.Utils.Model, only: [join: 2] + import Nx.Defn alias Bumblebee.Layers @@ -141,9 +139,7 @@ defmodule Bumblebee.Vision.Swin do outputs = core(inputs, spec) logits = - outputs.hidden_state - |> Layers.take_token(index: 0, axis: 1) - |> Axon.dense(spec.num_labels, + Axon.dense(outputs.pooled_state, spec.num_labels, kernel_initializer: kernel_initializer(spec), name: "image_classification_head.output" ) @@ -164,49 +160,49 @@ defmodule Bumblebee.Vision.Swin do ]) end - # Contrary to Python implementation we do not have here argument - # bool_maked_pos. This parameter is propagated from model through - # core to embedder. defp core(inputs, spec, opts \\ []) do name = opts[:name] embeddings = - embedder(inputs["pixel_values"], spec, name: join(name, "embedder")) + embedder(inputs["pixel_values"], inputs["patch_mask"], spec, name: join(name, "embedder")) - {hidden_state, hidden_states, attentions} = + encoder_outputs = encoder(embeddings, spec, name: join(name, "encoder")) hidden_state = - Axon.layer_norm(hidden_state, + Axon.layer_norm(encoder_outputs.hidden_state, epsilon: spec.layer_norm_epsilon, - name: join(name, "layernorm") + name: join(name, "norm") ) pooled_state = - Axon.adaptive_avg_pool(hidden_state, output_size: {1, 1}, name: join(name, "pooler")) + hidden_state + |> Axon.adaptive_avg_pool(output_size: {1}, name: join(name, "pooler")) + |> Axon.flatten() %{ hidden_state: hidden_state, pooled_state: pooled_state, - hidden_states: hidden_states, - attentions: attentions + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions } end - defp embedder(pixel_values, spec, opts) do + defp embedder(pixel_values, patch_mask, spec, opts) do name = opts[:name] embeddings = pixel_values |> patch_embedding(spec, name: join(name, "patch_embedding")) - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: "layernorm_embeddings") + |> Layers.apply_vision_patch_mask(patch_mask, name: join(name, "mask_tokens")) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) embeddings = - if spec.use_absolute_embeddings do + if spec.use_absolute_position_embeddings do num_patches = div(spec.image_size, spec.patch_size) ** 2 position_embeddings = - Layers.learned_embeddings(num_patches, spec.embed_dim, + Layers.learned_embeddings(num_patches, spec.embedding_size, initializer: :zeros, name: join(name, "position_embedding") ) @@ -222,7 +218,7 @@ defmodule Bumblebee.Vision.Swin do defp patch_embedding(pixel_values, spec, opts) do name = opts[:name] - hidden_size = spec.embed_dim + hidden_size = spec.embedding_size pixel_values |> Axon.conv(hidden_size, @@ -232,495 +228,361 @@ defmodule Bumblebee.Vision.Swin do kernel_initializer: kernel_initializer(spec), name: join(name, "projection") ) - |> Axon.reshape({:batch, :auto, spec.embed_dim}, name: join(name, "reshape")) + |> Axon.reshape({:batch, :auto, spec.embedding_size}, name: join(name, "reshape")) end defp encoder(hidden_state, spec, opts) do name = opts[:name] - 
hidden_states = Axon.container({hidden_state})
-    attentions = Axon.container({})

-    state = {
-      hidden_state,
-      hidden_states,
-      attentions
+    state = %{
+      hidden_state: hidden_state,
+      hidden_states: Axon.container({hidden_state}),
+      attentions: Axon.container({})
    }

-    for stage_idx <- 0..(length(spec.depths) - 1), reduce: state do
-      {hidden_state, hidden_states, attentions} ->
-        {hidden_state, attention} =
-          stage(hidden_state, stage_idx, spec, join("#{name}.blocks", stage_idx))
+    for stage_idx <- 0..(length(spec.num_blocks) - 1), reduce: state do
+      state ->
+        name = name |> join("stages") |> join(stage_idx)
+
+        grid_size = div(spec.image_size, spec.patch_size)
+        input_resolution = div(grid_size, 2 ** stage_idx)
+
+        {hidden_state, attention, hidden_state_before_downsample} =
+          stage(state.hidden_state, spec,
+            hidden_size: spec.embedding_size * 2 ** stage_idx,
+            num_blocks: Enum.at(spec.num_blocks, stage_idx),
+            num_attention_heads: Enum.at(spec.num_attention_heads, stage_idx),
+            downsample: stage_idx < length(spec.num_blocks) - 1,
+            input_resolution: input_resolution,
+            name: name
+          )

-        {
-          hidden_state,
-          Layers.append(hidden_states, hidden_state),
-          Layers.append(attentions, attention)
+        %{
+          hidden_state: hidden_state,
+          hidden_states: Layers.append(state.hidden_states, hidden_state_before_downsample),
+          attentions: Layers.append(state.attentions, attention)
        }
     end
   end

-  defp stage(hidden_state, stage_idx, spec, name) do
-    grid_size = div(spec.image_size, spec.patch_size)
-    input_resolution = div(grid_size, 2 ** stage_idx)
-    dim = spec.embed_dim * 2 ** stage_idx
-    num_attention_heads = Enum.at(spec.num_heads, stage_idx)
+  defp stage(hidden_state, spec, opts) do
+    name = opts[:name]
+    downsample = opts[:downsample]
+    hidden_size = opts[:hidden_size]
+    num_blocks = opts[:num_blocks]
+    num_attention_heads = opts[:num_attention_heads]
+    input_resolution = opts[:input_resolution]
+
+    # Note that we only record hidden_state and attention
+    # from the last block in each stage

     {hidden_state, attention} =
-      for layer_idx <- 0..(Enum.at(spec.depths, stage_idx) - 1), reduce: {hidden_state, nil} do
-        {hidden_state, _} ->
+      for block_idx <- 0..(num_blocks - 1), reduce: {hidden_state, nil} do
+        {hidden_state, _attention} ->
+          name = name |> join("blocks") |> join(block_idx)
+
+          shift_size =
+            if rem(block_idx, 2) == 0 do
+              0
+            else
+              div(spec.window_size, 2)
+            end
+
           {hidden_state, attention} =
-            layer(hidden_state, layer_idx, dim, num_attention_heads, spec, name)
+            transformer_block(hidden_state,
+              num_attention_heads: num_attention_heads,
+              hidden_size: hidden_size,
+              kernel_initializer: kernel_initializer(spec),
+              dropout_rate: spec.dropout_rate,
+              attention_dropout_rate: spec.attention_dropout_rate,
+              layer_norm_epsilon: spec.layer_norm_epsilon,
+              intermediate_size: floor(spec.intermediate_size_ratio * hidden_size),
+              activation: spec.activation,
+              name: name,
+              window_size: spec.window_size,
+              shift_size: shift_size,
+              input_resolution: input_resolution
+            )

           {hidden_state, attention}
       end

+    hidden_state_before_downsample = hidden_state
+
     hidden_state =
-      if stage_idx < length(spec.depths) - 1 do
-        downsample(hidden_state, input_resolution, dim, spec.layer_norm_epsilon, name)
+      if downsample do
+        patch_merging(hidden_state,
+          input_resolution: input_resolution,
+          hidden_size: hidden_size,
+          layer_norm_epsilon: spec.layer_norm_epsilon,
+          kernel_initializer: kernel_initializer(spec),
+          name: join(name, "downsample")
+        )
       else
        hidden_state
       end

-    {hidden_state, attention}
+    {hidden_state, attention, hidden_state_before_downsample}
   end

-  defp 
layer(hidden_state, layer_idx, dim, num_attention_heads, spec, name) do - shortcut = hidden_state - attn_mask = attention_mask_layer(hidden_state, layer_idx, spec) - name = join(name, "layer.#{layer_idx}") - - {hidden_state, attention} = - hidden_state - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_before")) - |> reshape(layer_idx) - |> hidden_state_windows(layer_idx, spec) - |> attention(attn_mask, Layers.none(), num_attention_heads, dim, spec, name) - - # TODO "unpad" if it was padded - # After unroll we have to reverse padding (before dropout) - attention_windows = - {hidden_state, shortcut} - |> unroll(layer_idx, spec) - |> Axon.dropout(rate: spec.dropout_rate) - - hidden_state = Axon.add(shortcut, attention_windows) - - output = - hidden_state - |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "layernorm_after")) - |> Axon.dense(round(spec.intermediate_size_ratio * dim), - name: join(name, "intermediate.dense") - ) - |> Layers.activation(spec.activation) - |> Axon.dense(dim, name: join(name, "output.dense")) - |> Axon.dropout(rate: spec.dropout_rate) - - hidden_state = Axon.add(hidden_state, output) - - {hidden_state, attention} - end - - defp reshape(input, layer_idx) do - input - |> Axon.nx( - fn x -> - {batch_size, dimension, num_channels} = Nx.shape(x) - height_width = dimension |> :math.sqrt() |> floor() - - x - |> Nx.reshape({batch_size, height_width, height_width, num_channels}) - end, - name: "reshape_#{layer_idx}" - ) + {hidden_state, attention, hidden_state_before_downsample} end - defp maybe_pad(input, window_size, height, width) do - pad1 = (window_size - rem(width, window_size)) |> rem(window_size) - pad2 = (window_size - rem(height, window_size)) |> rem(window_size) - - Nx.pad(input, 0, [{0, 0, 0}, {0, pad1, 0}, {0, pad2, 0}, {0, 0, 0}]) - end - - defp hidden_state_windows(input, layer_idx, spec) do - shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) - - input - |> Axon.nx( - fn x -> - {_batch_size, height, width, num_channels} = Nx.shape(x) - - x = maybe_pad(x, spec.window_size, height, width) - - {_, height, width, _} = Nx.shape(x) + defp transformer_block(hidden_state, opts) do + num_attention_heads = opts[:num_attention_heads] + hidden_size = opts[:hidden_size] + kernel_initializer = opts[:kernel_initializer] + dropout_rate = opts[:dropout_rate] + attention_dropout_rate = opts[:attention_dropout_rate] + layer_norm_epsilon = opts[:layer_norm_epsilon] + intermediate_size = opts[:intermediate_size] + activation = opts[:activation] + name = opts[:name] + window_size = opts[:window_size] + shift_size = opts[:shift_size] + input_resolution = opts[:input_resolution] - {shift_size, _window_size} = - if min(height, width) <= spec.window_size, - do: {0, min(height, width)}, - else: {shift_size, spec.window_size} + {shift_size, window_size} = + if input_resolution <= window_size do + {0, input_resolution} + else + {shift_size, window_size} + end - # cyclic shift - shiffted_hidden_state = - if shift_size > 0, - do: roll(x, shifts: [-shift_size, -shift_size], axes: [1, 2]), - else: x + shortcut = hidden_state - # partition windows - shiffted_hidden_state - |> window_partition(spec.window_size) - |> Nx.reshape({:auto, spec.window_size * spec.window_size, num_channels}) - end, - name: "hidden_state_windows_#{layer_idx}" - ) - end + attention_mask = + window_attention_mask(hidden_state, shift_size, window_size, input_resolution) - defp relative_position_index(window_size) do - coords_h = 
Nx.iota({window_size}) |> Nx.tile([window_size, 1]) |> Nx.transpose() - coords_w = Nx.iota({window_size}) |> Nx.tile([window_size, 1]) - - coords_flatten = - Nx.stack([coords_h, coords_w]) - # flatten dimension 1 - |> Nx.reshape({2, window_size * window_size}) - - relative_coords = - Nx.subtract(Nx.new_axis(coords_flatten, 2), Nx.new_axis(coords_flatten, 1)) - |> Nx.transpose(axes: [1, 2, 0]) - - relative_coords = - Nx.add( - relative_coords, - Nx.broadcast(Nx.tensor([window_size - 1, window_size - 1]), relative_coords) + relative_attention_bias = + relative_attention_bias(window_size, num_attention_heads, + name: join(name, "self_attention.relative_attention_bias") ) - Nx.multiply( - relative_coords, - Nx.broadcast(Nx.tensor([2 * window_size - 1, 1]), relative_coords) - ) - |> Nx.sum(axes: [-1]) - end - - defp transpose_for_scores(x, num_attention_heads, attention_head_size) do - new_shape = - x - |> Nx.shape() - |> Tuple.to_list() - |> List.replace_at(-1, [num_attention_heads, attention_head_size]) - |> List.flatten() - |> List.to_tuple() - - x - |> Nx.reshape(new_shape) - |> Nx.transpose(axes: [0, 2, 1, 3]) - end - - defp attention(hidden_state, attention_mask, head_mask, num_attention_heads, dim, spec, name) do - attention_head_size = floor(dim / num_attention_heads) - all_head_size = num_attention_heads * attention_head_size - name = join(name, "self_attention") - - query = - hidden_state - |> Axon.dense(all_head_size, name: join(name, "query")) - |> Axon.nx(fn x -> - transpose_for_scores(x, num_attention_heads, attention_head_size) - end) - - key = - hidden_state - |> Axon.dense(all_head_size, name: join(name, "key")) - |> Axon.nx(fn x -> - transpose_for_scores(x, num_attention_heads, attention_head_size) - end) - - value = + hidden_state = hidden_state - |> Axon.dense(all_head_size, name: join(name, "value")) - |> Axon.nx(fn x -> - transpose_for_scores(x, num_attention_heads, attention_head_size) - end) - - relative_position_bias_table = - Axon.param( - "relative_position_bias_table", - {(2 * spec.window_size - 1) * (2 * spec.window_size - 1), num_attention_heads} - ) + |> Axon.layer_norm(epsilon: layer_norm_epsilon, name: join(name, "self_attention_norm")) + |> hidden_state_windows(shift_size, window_size, input_resolution) - probabilities = - Axon.layer( - &attention_weights_impl/7, - [ - hidden_state, - query, - key, - relative_position_bias_table, - Axon.optional(attention_mask), - Axon.optional(head_mask) - ], - name: join(name, "weights"), - window_size: spec.window_size, + {hidden_state, attention, _self_attention_cache, _attention_relative_bias} = + Layers.Transformer.multi_head_attention(hidden_state, hidden_state, hidden_state, + attention_mask: attention_mask, + attention_relative_bias: relative_attention_bias, num_heads: num_attention_heads, - head_size: attention_head_size, - dropout_rate: spec.attention_dropout_rate + hidden_size: hidden_size, + kernel_initializer: kernel_initializer, + dropout_rate: attention_dropout_rate, + name: join(name, "self_attention") ) - output_name = join(name, "output") - - context = - Axon.layer( - &attention_output_impl/3, - [probabilities, value], - name: output_name, - num_heads: num_attention_heads, - head_size: attention_head_size - ) - |> Axon.dense(dim, name: join(output_name, "dense")) - |> Axon.dropout( - rate: spec.attention_dropout_rate, - name: join(output_name, "dropout") - ) - - {context, probabilities} - end - - defp attention_weights_impl( - hidden_state, - query, - key, - relative_position_bias_table, - attention_mask, - 
head_mask, - opts - ) do - opts = - Keyword.validate!(opts, [:mode, :name, :window_size, :num_heads, :head_size, :dropout_rate]) - - {batch_size, dim, _num_channels} = Nx.shape(hidden_state) - - scores = - query - |> Nx.dot([3], [0, 1], Nx.transpose(key, axes: [0, 1, -1, -2]), [2], [0, 1]) - |> Nx.divide(Nx.sqrt(opts[:head_size])) - - rel_pos_idx = relative_position_index(opts[:window_size]) |> Nx.reshape({:auto}) - - relative_position_bias = - Nx.take(relative_position_bias_table, rel_pos_idx) - |> Nx.reshape( - {opts[:window_size] * opts[:window_size], opts[:window_size] * opts[:window_size], :auto} - ) - |> Nx.transpose(axes: [2, 0, 1]) - |> Nx.new_axis(0) - - scores = Nx.add(scores, relative_position_bias) - - scores = - case attention_mask do - %Axon.None{} -> - scores + hidden_state = + Axon.dropout(hidden_state, rate: dropout_rate, name: join(name, "self_attention_dropout")) - _ -> - {mask_size, _, _} = Nx.shape(attention_mask) + hidden_state = + hidden_state + |> reverse_hidden_state_windows(shift_size, window_size, input_resolution) + |> Axon.dropout(rate: dropout_rate) - scores = - Nx.reshape( - scores, - {floor(batch_size / mask_size), mask_size, opts[:num_heads], dim, dim} - ) + hidden_state = Axon.add(hidden_state, shortcut) - attention_mask = - attention_mask - |> Nx.new_axis(1) - |> Nx.new_axis(0) + shortcut = hidden_state - scores - |> Nx.add(attention_mask) - |> Nx.reshape({:auto, opts[:num_heads], dim, dim}) - end + hidden_state = + hidden_state + |> Axon.layer_norm(epsilon: layer_norm_epsilon, name: join(name, "output_norm")) + |> Axon.dense(intermediate_size, name: join(name, "ffn.intermediate")) + |> Layers.activation(activation) + |> Axon.dense(hidden_size, name: join(name, "ffn.output")) + |> Axon.dropout(rate: dropout_rate) - # Normalize the attention scores to probabilities (softmax). - # - # This is actually dropping out entire tokens to attend to, which - # might seem a bit unusual, but is taken from the original - # Transformer paper (dropout). - seed = :erlang.system_time() + hidden_state = Axon.add(hidden_state, shortcut) - probabilities = - Axon.Activations.softmax(scores, axis: -1) - |> Axon.Layers.dropout(Nx.Random.key(seed), rate: opts[:dropout_rate]) + {hidden_state, attention} + end - case head_mask do - %Axon.None{} -> - probabilities + defp window_attention_mask(hidden_state, shift_size, window_size, input_resolution) do + if shift_size > 0 do + # Computes attention mask for shifted window multi-head self-attention (SW-MSA) + + Axon.nx(hidden_state, fn hidden_state -> + {batch_size, _dimension, _hidden_size} = Nx.shape(hidden_state) + height = width = input_resolution + + # See Figure 4. in the paper. We color the 2D patches (tokens) + # into 4 groups. Then, we compute a mask such that each token + # attends only to tokens within the same group. 
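+      #
+      # For example, with height = width = 4 and shift_size = 1 the
+      # resulting group grid is:
+      #
+      #   0 0 0 1
+      #   0 0 0 1
+      #   0 0 0 1
+      #   2 2 2 3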
+ + grid_0 = Nx.broadcast(0, {height - shift_size, width - shift_size}) + grid_b = Nx.broadcast(1, {height - shift_size, shift_size}) + grid_c = Nx.broadcast(2, {shift_size, width - shift_size}) + grid_a = Nx.broadcast(3, {shift_size, shift_size}) + + grid = + Nx.concatenate([ + Nx.concatenate([grid_0, grid_b], axis: 1), + Nx.concatenate([grid_c, grid_a], axis: 1) + ]) + + windowed_patch_groups = + grid + |> Nx.reshape({1, height, width, 1}) + |> window_partition(window_size) + |> Nx.reshape({:auto, window_size * window_size}) + + windows_attention_mask = + Nx.equal( + Nx.new_axis(windowed_patch_groups, 1), + Nx.new_axis(windowed_patch_groups, 2) + ) + |> Nx.new_axis(1) - head_mask -> - Nx.multiply(probabilities, head_mask) + # Note that we repeat the mask for each batched input, so that + # the batch dimension has size batch_size * num_windows, which + # matches the input. This way we can apply the mask as usual, + # without reshaping back and forth. + Nx.tile(windows_attention_mask, [batch_size, 1, 1, 1]) + end) + else + Layers.none() end end - def attention_output_impl(weights, value, opts) do - context = - weights - |> Nx.dot([3], [0, 1], value, [2], [0, 1]) - |> Nx.transpose(axes: [0, 2, 1, 3]) - - new_context_shape = - context - |> Nx.shape() - |> Tuple.to_list() - |> Enum.slice(0..-3//1) - |> Kernel.++([opts[:num_heads] * opts[:head_size]]) - |> List.to_tuple() - - Nx.reshape(context, new_context_shape) - end + defp relative_attention_bias(window_size, num_attention_heads, opts) do + name = opts[:name] - def unroll({hidden_state, input}, layer_idx, spec) do - # reverse cyclic shift - shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) + kernel = + Axon.param("kernel", {(2 * window_size - 1) * (2 * window_size - 1), num_attention_heads}) Axon.layer( - fn state, input, _ -> - {batch_size, dimension, num_channels} = Nx.shape(input) - height_width = dimension |> :math.sqrt() |> floor() - - {shift_size, window_size} = - if height_width <= spec.window_size, - do: {0, height_width}, - else: {shift_size, spec.window_size} + fn kernel, opts -> + window_size = opts[:window_size] - shifted_windows = - state - |> Nx.reshape({:auto, window_size, window_size, num_channels}) - |> window_reverse(window_size) + idx = relative_position_index(window_size) |> Nx.reshape({:auto}) - if shift_size > 0 do - roll(shifted_windows, shifts: [shift_size, shift_size], axes: [1, 2]) - |> Nx.reshape({batch_size, height_width * height_width, num_channels}) - else - shifted_windows - |> Nx.reshape({batch_size, height_width * height_width, num_channels}) - end + kernel + |> Nx.take(idx) + |> Nx.reshape({window_size * window_size, window_size * window_size, :auto}) + |> Nx.transpose(axes: [2, 0, 1]) + |> Nx.new_axis(0) end, - [hidden_state, input], - name: "unroll_#{layer_idx}" + [kernel], + window_size: window_size, + name: name ) end - defp attention_mask_layer(hidden_state, layer_idx, spec) do - shift_size = if 0 == rem(layer_idx, 2), do: 0, else: div(spec.window_size, 2) - - hidden_state - |> Axon.nx( - fn x -> - {_batch_size, dimension, _num_channels} = Nx.shape(x) - height_width = dimension |> :math.sqrt() |> floor() + defp relative_position_index(window_size) do + coords_h = Nx.iota({window_size, window_size}, axis: 0) |> Nx.flatten() + coords_w = Nx.iota({window_size, window_size}, axis: 1) |> Nx.flatten() + coord_pairs = Nx.stack([coords_h, coords_w]) - {shift_size, window_size} = - if height_width <= spec.window_size, - do: {0, height_width}, - else: {shift_size, spec.window_size} + 
relative_coords = Nx.subtract(Nx.new_axis(coord_pairs, 2), Nx.new_axis(coord_pairs, 1)) - attention_mask(height_width, height_width, window_size, shift_size) - end, - name: "att_mask_#{layer_idx}" - ) + relative_coords + |> Nx.add(Nx.reshape(Nx.tensor([window_size - 1, window_size - 1]), {2, 1, 1})) + |> Nx.multiply(Nx.reshape(Nx.tensor([2 * window_size - 1, 1]), {2, 1, 1})) + |> Nx.sum(axes: [0]) end - def attention_mask(height, width, window_size, shift_size) do - if shift_size > 0 do - # calculate attention mask for shifted window multi-head self-attention (SW-MSA) - img_mask = Nx.broadcast(0.0, {1, height, width, 1}) + defp hidden_state_windows(hidden_state, shift_size, window_size, input_resolution) do + Axon.nx(hidden_state, fn hidden_state -> + {batch_size, _dimension, hidden_size} = Nx.shape(hidden_state) - hslices = [ - 0..(height - window_size - 1), - (height - window_size)..(height - shift_size - 1), - (height - shift_size)..(height - 1) - ] + height = width = input_resolution + hidden_state = Nx.reshape(hidden_state, {batch_size, height, width, hidden_size}) - wslices = [ - 0..(width - window_size - 1), - (width - window_size)..(width - shift_size - 1), - (width - shift_size)..(width - 1) - ] + # Apply cyclic shift + hidden_state = + if shift_size > 0 do + Bumblebee.Utils.Nx.roll(hidden_state, shifts: [-shift_size, -shift_size], axes: [1, 2]) + else + hidden_state + end + + # Partition windows + hidden_state + |> window_partition(window_size) + |> Nx.reshape({:auto, window_size * window_size, hidden_size}) + end) + end - {img_mask, _count} = - for hrange <- hslices, wrange <- wslices, reduce: {img_mask, 0.0} do - {mask, count} -> - mask = - for hidx <- hrange, widx <- wrange, reduce: mask do - deepest_mask -> - Nx.indexed_put(deepest_mask, Nx.tensor([0, hidx, widx, 0]), count) - end + defp reverse_hidden_state_windows(hidden_state, shift_size, window_size, input_resolution) do + Axon.nx(hidden_state, fn hidden_state -> + {_batch_size, _dimension, hidden_size} = Nx.shape(hidden_state) + height = width = input_resolution - {mask, count + 1.0} - end + # Reverse window partitioning + hidden_state = + hidden_state + |> Nx.reshape({:auto, window_size, window_size, hidden_size}) + |> window_unpartition(window_size, height, width) - mask_windows = - img_mask - |> window_partition(window_size) - |> Nx.reshape({:auto, window_size * window_size}) + # Reverse cyclic shift + hidden_state = + if shift_size > 0 do + Bumblebee.Utils.Nx.roll(hidden_state, shifts: [shift_size, shift_size], axes: [1, 2]) + else + hidden_state + end - mask_windows - |> Nx.new_axis(1) - |> Nx.subtract(Nx.new_axis(mask_windows, 2)) - |> Nx.equal(0) - |> Nx.logical_not() - else - %Axon.None{} - end + Nx.reshape(hidden_state, {:auto, height * width, hidden_size}) + end) end - defp window_partition(%Nx.Tensor{} = tensor, window_size) do - {batch_size, height, width, num_channels} = Nx.shape(tensor) + defnp window_partition(tensor, window_size) do + {batch_size, height, width, hidden_size} = Nx.shape(tensor) windowed_height = div(height, window_size) windowed_width = div(width, window_size) Nx.reshape( tensor, - {batch_size, windowed_height, window_size, windowed_width, window_size, num_channels} + {batch_size, windowed_height, window_size, windowed_width, window_size, hidden_size} ) |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5]) - |> Nx.reshape({:auto, window_size, window_size, num_channels}) - end - - defp window_reverse(%Axon{} = input_feature, window_size) do - input_feature - |> Axon.nx(fn x -> window_reverse(x, 
window_size) end) + |> Nx.reshape({:auto, window_size, window_size, hidden_size}) end - defp window_reverse(%Nx.Tensor{} = tensor, window_size) do - {_batch_size, height, width, num_channels} = Nx.shape(tensor) + defnp window_unpartition(tensor, window_size, height, width) do + {_batch_size, _height, _width, hidden_size} = Nx.shape(tensor) windowed_height = div(height, window_size) windowed_width = div(width, window_size) Nx.reshape( tensor, - {:auto, windowed_height, windowed_width, window_size, window_size, num_channels} + {:auto, windowed_height, windowed_width, window_size, window_size, hidden_size} ) |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5]) - |> Nx.reshape({:auto, height, width, num_channels}) + |> Nx.reshape({:auto, height, width, hidden_size}) end - defp downsample(hidden_state, input_resolution, dim, norm_epsilon, name) do - Axon.nx(hidden_state, fn x -> - {batch_size, _dim, num_channels} = Nx.shape(x) + defp patch_merging(hidden_state, opts) do + input_resolution = opts[:input_resolution] + hidden_size = opts[:hidden_size] + layer_norm_epsilon = opts[:layer_norm_epsilon] + kernel_initializer = opts[:kernel_initializer] + name = opts[:name] + + # We group patches from each 2x2 square and apply a dense layer + # against each group + + hidden_state + |> Axon.nx(fn hidden_state -> + {batch_size, _sequence_length, _hidden_size} = Nx.shape(hidden_state) - x = Nx.reshape(x, {batch_size, input_resolution, input_resolution, :auto}) + hidden_state = + Nx.reshape(hidden_state, {batch_size, input_resolution, input_resolution, :auto}) - input_feature_0 = x[[.., 0..-1//2, 0..-1//2, ..]] - input_feature_1 = x[[.., 1..-1//2, 0..-1//2, ..]] - input_feature_2 = x[[.., 0..-1//2, 1..-1//2, ..]] - input_feature_3 = x[[.., 1..-1//2, 1..-1//2, ..]] + input_feature_0 = hidden_state[[.., 0..-1//2, 0..-1//2, ..]] + input_feature_1 = hidden_state[[.., 1..-1//2, 0..-1//2, ..]] + input_feature_2 = hidden_state[[.., 0..-1//2, 1..-1//2, ..]] + input_feature_3 = hidden_state[[.., 1..-1//2, 1..-1//2, ..]] Nx.concatenate([input_feature_0, input_feature_1, input_feature_2, input_feature_3], axis: -1 ) - |> Nx.reshape({batch_size, :auto, 4 * num_channels}) + |> Nx.reshape({batch_size, :auto, 4 * hidden_size}) end) - |> Axon.layer_norm(epsilon: norm_epsilon, name: join(name, "downsample_norm")) - |> Axon.dense(2 * dim, - kernel_initializer: Axon.Initializers.uniform(), - name: join(name, "downsample_reduction"), + |> Axon.layer_norm(epsilon: layer_norm_epsilon, name: join(name, "norm")) + |> Axon.dense(2 * hidden_size, + kernel_initializer: kernel_initializer, + name: join(name, "reduction"), use_bias: false ) end @@ -736,9 +598,9 @@ defmodule Bumblebee.Vision.Swin do opts = convert!(data, attention_dropout_rate: {"attention_probs_dropout_prob", number()}, - depths: {"depths", list(number())}, + num_blocks: {"depths", list(number())}, drop_path_rate: {"drop_path_rate", number()}, - embed_dim: {"embed_dim", number()}, + embedding_size: {"embed_dim", number()}, activation: {"hidden_act", activation()}, dropout_rate: {"hidden_dropout_prob", number()}, image_size: {"image_size", number()}, @@ -746,11 +608,10 @@ defmodule Bumblebee.Vision.Swin do layer_norm_epsilon: {"layer_norm_eps", number()}, intermediate_size_ratio: {"mlp_ratio", number()}, num_channels: {"num_channels", number()}, - num_heads: {"num_heads", list(number())}, + num_attention_heads: {"num_heads", list(number())}, patch_size: {"patch_size", number()}, - path_norm: {"path_norm", boolean()}, use_attention_bias: {"qkv_bias", boolean()}, - 
use_absolute_embeddings: {"use_absolute_embeddings", boolean()}, + use_absolute_position_embeddings: {"use_absolute_embeddings", boolean()}, window_size: {"window_size", number()} ) ++ Shared.common_options_from_transformers(data, spec) @@ -761,60 +622,39 @@ defmodule Bumblebee.Vision.Swin do defimpl Bumblebee.HuggingFace.Transformers.Model do def params_mapping(_spec) do %{ - "encoder.blocks.{n}.layer.{m}.intermediate.dense" => - "swin.encoder.layers.{n}.blocks.{m}.intermediate.dense", - "encoder.blocks.{n}.layer.{m}.layernorm_after" => + "embedder.patch_embedding.projection" => "swin.embeddings.patch_embeddings.projection", + "embedder.norm" => "swin.embeddings.norm", + "encoder.stages.{n}.blocks.{m}.output_norm" => "swin.encoder.layers.{n}.blocks.{m}.layernorm_after", - "encoder.blocks.{n}.layer.{m}.layernorm_before" => + "encoder.stages.{n}.blocks.{m}.self_attention_norm" => "swin.encoder.layers.{n}.blocks.{m}.layernorm_before", - "encoder.blocks.{n}.layer.{m}.output.dense" => - "swin.encoder.layers.{n}.blocks.{m}.output.dense", - "encoder.blocks.{n}.layer.{m}.self_attention.key" => + "encoder.stages.{n}.blocks.{m}.self_attention.key" => "swin.encoder.layers.{n}.blocks.{m}.attention.self.key", - "encoder.blocks.{n}.layer.{m}.self_attention.output.dense" => + "encoder.stages.{n}.blocks.{m}.self_attention.output" => "swin.encoder.layers.{n}.blocks.{m}.attention.output.dense", - "encoder.blocks.{n}.layer.{m}.self_attention.query" => + "encoder.stages.{n}.blocks.{m}.self_attention.query" => "swin.encoder.layers.{n}.blocks.{m}.attention.self.query", - "encoder.blocks.{n}.layer.{m}.self_attention.value" => + "encoder.stages.{n}.blocks.{m}.self_attention.value" => "swin.encoder.layers.{n}.blocks.{m}.attention.self.value", - "encoder.blocks.{n}.layer.{m}.self_attention.weights" => - "swin.encoder.layers.{n}.blocks.{m}.attention.self", - "layernorm" => "swin.layernorm", - "layernorm_embeddings" => "swin.embeddings.norm", - "embedder.patch_embedding.projection" => "swin.embeddings.patch_embeddings.projection", - "encoder.blocks.{n}.downsample_norm" => "swin.encoder.layers.{n}.downsample.norm", - "encoder.blocks.{n}.downsample_reduction" => + "encoder.stages.{n}.blocks.{m}.self_attention.relative_attention_bias" => %{ + "kernel" => { + [ + {"swin.encoder.layers.{n}.blocks.{m}.attention.self", + "relative_position_bias_table"} + ], + fn [kernel] -> kernel end + } + }, + "encoder.stages.{n}.blocks.{m}.ffn.intermediate" => + "swin.encoder.layers.{n}.blocks.{m}.intermediate.dense", + "encoder.stages.{n}.blocks.{m}.ffn.output" => + "swin.encoder.layers.{n}.blocks.{m}.output.dense", + "encoder.stages.{n}.downsample.norm" => "swin.encoder.layers.{n}.downsample.norm", + "encoder.stages.{n}.downsample.reduction" => "swin.encoder.layers.{n}.downsample.reduction", + "norm" => "swin.layernorm", "image_classification_head.output" => "classifier" } end end - - defp roll(%Axon{} = x, opts) do - Axon.nx(x, fn y -> roll(y, opts) end) - end - - defp roll(%Nx.Tensor{} = x, opts) do - opts = Keyword.validate!(opts, shifts: [], axes: []) - shifts = opts[:shifts] - axes = opts[:axes] - - if length(shifts) != length(axes) do - raise ArgumentError, "shifts and axes must align, shifts: #{shifts}, axes: #{axes}" - else - shape = Nx.shape(x) |> Tuple.to_list() - - Enum.zip(shifts, axes) - |> Enum.reduce(x, fn {shift, dim}, acc -> - shift = rem(shift, Enum.at(shape, dim)) - - if 0 < shift do - {base, move} = Nx.split(acc, -1 * shift, axis: dim) - Nx.concatenate([move, base], axis: dim) - else - acc - end - end) - end - end 
end diff --git a/test/bumblebee/vision/swin_test.exs b/test/bumblebee/vision/swin_test.exs index 732c0b4b..081e7401 100644 --- a/test/bumblebee/vision/swin_test.exs +++ b/test/bumblebee/vision/swin_test.exs @@ -12,7 +12,7 @@ defmodule Bumblebee.Vision.SwinTest do assert %Bumblebee.Vision.Swin{architecture: :base} = spec inputs = %{ - "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + "pixel_values" => Nx.broadcast(0.5, {1, 32, 32, 3}) } outputs = Axon.predict(model, params, inputs) @@ -23,13 +23,13 @@ defmodule Bumblebee.Vision.SwinTest do assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[-0.4605, 0.9336, -0.5528], [-0.4449, 0.8927, -0.5424], [-0.5024, 0.2263, 0.2208]] + [[-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528]] ]) ) assert_all_close( outputs.pooled_state[[.., 1..3]], - Nx.tensor([[-0.5004, 0.4605, -0.4949]]) + Nx.tensor([[-0.4605, 0.9336, -0.5528]]) ) end @@ -42,7 +42,7 @@ defmodule Bumblebee.Vision.SwinTest do assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec inputs = %{ - "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) + "pixel_values" => Nx.broadcast(0.5, {1, 32, 32, 3}) } outputs = Axon.predict(model, params, inputs) @@ -51,43 +51,7 @@ defmodule Bumblebee.Vision.SwinTest do assert_all_close( outputs.logits, - Nx.tensor([[0.0834, 0.1265]]) - ) - end - - # TODO remove before merging - test ":for_image_classification actual" do - assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "microsoft/swin-base-patch4-window12-384"}) - - assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec - - inputs = %{ - "pixel_values" => Nx.broadcast(0.5, {1, 384, 384, 3}) - } - - outputs = Axon.predict(model, params, inputs) - - assert Nx.shape(outputs.logits) == {1, 1000} - - compare = outputs.logits[[0, 0..9]] - - assert_all_close( - compare, - Nx.tensor([ - [ - 6.9526e-02, - 8.5011e-01, - 4.5132e-01, - 5.4306e-01, - 2.4646e-01, - -2.2765e-03, - 6.9874e-02, - 1.3368e-01, - 4.6875e-01, - 8.8567e-01 - ] - ]) + Nx.tensor([[0.0361, 0.1352]]) ) end end From 1b3c5a06adbeda52d730fdf7ff2c9ceaa61dd9a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 3 Sep 2024 15:41:47 +0700 Subject: [PATCH 15/15] Up --- mix.lock | 1 - 1 file changed, 1 deletion(-) diff --git a/mix.lock b/mix.lock index 1ac77e5d..bf50a962 100644 --- a/mix.lock +++ b/mix.lock @@ -30,7 +30,6 @@ "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, "safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: 
"hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, - "table_rex": {:hex, :table_rex, "4.0.0", "3c613a68ebdc6d4d1e731bc973c233500974ec3993c99fcdabb210407b90959b", [:mix], [], "hexpm", "c35c4d5612ca49ebb0344ea10387da4d2afe278387d4019e4d8111e815df8f55"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, "torchx": {:hex, :torchx, "0.7.0", "c71fd603b0133ed8709450d82aa3434cbcf485a37c9a68e9ebcce86f5e4fb7f0", [:mix], [{:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "a324079c56bb67750b1da16f859d994982bb467020a8c2cba324639552f3adb8"},