From 1b62825fdca9675ca75ec7f3720fd3f318e34cbe Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 12 Feb 2024 18:07:06 +0100 Subject: [PATCH 01/42] ControlNet setup --- .../diffusion/stable_diffusion/control_net.ex | 489 ++++++++++++++++++ .../stable_diffusion/control_net_test.ex | 50 ++ 2 files changed, 539 insertions(+) create mode 100644 lib/bumblebee/diffusion/stable_diffusion/control_net.ex create mode 100644 test/bumblebee/diffusion/stable_diffusion/control_net_test.ex diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex new file mode 100644 index 00000000..a50a374d --- /dev/null +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -0,0 +1,489 @@ +defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do + alias Bumblebee.Shared + + options = [ + sample_size: [ + default: 512, + doc: "the size of the input spatial dimensions" + ], + in_channels: [ + default: 4, + doc: "the number of channels in the input" + ], + out_channels: [ + default: 4, + doc: "the number of channels in the output" + ], + embedding_flip_sin_to_cos: [ + default: true, + doc: "whether to flip the sin to cos in the sinusoidal timestep embedding" + ], + embedding_frequency_correction_term: [ + default: 0, + doc: ~S""" + controls the frequency formula in the timestep sinusoidal embedding. The frequency is computed + as $\\omega_i = \\frac{1}{10000^{\\frac{i}{n - s}}}$, for $i \\in \\{0, ..., n-1\\}$, where $n$ + is half of the embedding size and $s$ is the shift. Historically, certain implementations of + sinusoidal embedding used $s=0$, while others used $s=1$ + """ + ], + hidden_sizes: [ + default: [320, 640, 1280, 1280], + doc: "the dimensionality of hidden layers in each upsample/downsample block" + ], + depth: [ + default: 2, + doc: "the number of residual blocks in each upsample/downsample block" + ], + down_block_types: [ + default: [ + :cross_attention_down_block, + :cross_attention_down_block, + :cross_attention_down_block, + :down_block + ], + doc: + "a list of downsample block types. The supported blocks are: `:down_block`, `:cross_attention_down_block`" + ], + up_block_types: [ + default: [ + :up_block, + :cross_attention_up_block, + :cross_attention_up_block, + :cross_attention_up_block + ], + doc: + "a list of upsample block types. The supported blocks are: `:up_block`, `:cross_attention_up_block`" + ], + downsample_padding: [ + default: [{1, 1}, {1, 1}], + doc: "the padding to use in the downsample convolution" + ], + mid_block_scale_factor: [ + default: 1, + doc: "the scale factor to use for the mid block" + ], + num_attention_heads: [ + default: 8, + doc: + "the number of attention heads for each attention layer. Optionally can be a list with one number per block" + ], + cross_attention_size: [ + default: 1280, + doc: "the dimensionality of the cross attention features" + ], + use_linear_projection: [ + default: false, + doc: + "whether the input/output projection of the transformer block should be linear or convolutional" + ], + activation: [ + default: :silu, + doc: "the activation function" + ], + group_norm_num_groups: [ + default: 32, + doc: "the number of groups used by the group normalization layers" + ], + group_norm_epsilon: [ + default: 1.0e-5, + doc: "the epsilon used by the group normalization layers" + ], + conditioning_embedding_out_channels: [ + default: [16, 32, 96, 256], + doc: "the dimensionality of conditioning embedding" + ] + ] + + @moduledoc """ + ControlNet model with two spatial dimensions and conditional state. + + ## Architectures + + * `:base` - the ControlNet model + + ## Inputs + + * `"sample"` - `{batch_size, sample_size, sample_size, in_channels}` + + Sample input with two spatial dimensions. + + * `"timestep"` - `{}` + + The timestep used to parameterize model behaviour in a multi-step + process, such as diffusion. + + * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}` + + The conditional state (context) to use with cross-attention. + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + alias Bumblebee.Diffusion + + @impl true + def architectures(), do: [:base] + + @impl true + def config(spec, opts) do + Shared.put_config_attrs(spec, opts) + end + + @impl true + def input_template(spec) do + sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} + timestep_shape = {} + encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} + + %{ + "sample" => Nx.template(sample_shape, :f32), + "timestep" => Nx.template(timestep_shape, :u32), + "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs(spec) + |> core(spec) + |> Layers.output() + end + + defp inputs(spec) do + sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("sample", shape: sample_shape), + Axon.input("timestep", shape: {}), + Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) + ]) + end + + defp core(inputs, spec) do + sample = inputs["sample"] + timestep = inputs["timestep"] + encoder_hidden_state = inputs["encoder_hidden_state"] + + timestep = + Axon.layer( + fn sample, timestep, _opts -> + Nx.broadcast(timestep, {Nx.axis_size(sample, 0)}) + end, + [sample, timestep], + op_name: :broadcast + ) + + timestep_embedding = + timestep + |> Diffusion.Layers.timestep_sinusoidal_embedding(hd(spec.hidden_sizes), + flip_sin_to_cos: spec.embedding_flip_sin_to_cos, + frequency_correction_term: spec.embedding_frequency_correction_term + ) + |> Diffusion.Layers.UNet.timestep_embedding_mlp(hd(spec.hidden_sizes) * 4, + name: "time_embedding" + ) + + sample = + Axon.conv(dbg(sample), 4, + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "input_conv" + ) + + control_net_cond_embeddings = control_net_embeddings(sample, spec) + + sample = Axon.add(sample, control_net_cond_embeddings) + + {sample, down_block_residuals} = + down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") + + sample = + mid_block(sample, timestep_embedding, encoder_hidden_state, spec, name: "mid_block") + + conditioning_scale = Axon.constant(1) + + down_block_residuals = + control_net_down_blocks(down_block_residuals, spec, name: "control_net.down_blocks") + + down_block_residuals = + for residual <- Tuple.to_list(down_block_residuals) do + Axon.multiply(residual, conditioning_scale, name: "control_net.down_blocks") + end + |> List.to_tuple() + + mid_block_residual = + control_net_mid_block(sample, spec, name: "control_net.mid_block") + |> Axon.multiply(conditioning_scale) + + %{ + down_block_residuals: down_block_residuals, + mid_block_residual: mid_block_residual + } + end + + defp control_net_down_blocks(down_block_residuals, spec, opts) do + name = opts[:name] + # blocks = Enum.zip(spec.hidden_sizes, Tuple.to_list(down_block_residuals)) + + residuals = + for {{residual, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_residuals)) do + Axon.conv(residual, out_channels, + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: name |> join(i) |> join("zero_conv"), + kernel_initializer: :zeros + ) + end + + List.to_tuple(residuals) + # # last block one less + # for _ <- spec.depth, reduce: sample do + # input -> + # Axon.conv(input, spec.hidden_sizes[-1], + # kernel_size: 3, + # padding: [{1, 1}, {1, 1}], + # name: "first", + # initializer: :zero + # ) + # end + end + + defp control_net_mid_block(input, spec, opts) do + name = opts[:name] + + Axon.conv(input, List.last(spec.hidden_sizes), + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: join(name, "zero_conv"), + kernel_initializer: :zeros + ) + end + + defp control_net_embeddings(sample, spec) do + input = + Axon.conv(sample, hd(spec.hidden_sizes), + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "input_conv", + activation: :silu + ) + + block_in_channels = Enum.drop(spec.conditioning_embedding_out_channels, -1) + block_out_channels = Enum.drop(spec.conditioning_embedding_out_channels, 1) + + state = input + + sample = + for {in_channels, out_channels} <- Enum.zip(block_in_channels, block_out_channels), + reduce: state do + input -> + input + |> Axon.conv(in_channels, + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "first", + activation: :silu + ) + |> Axon.conv(out_channels, + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + strides: 2, + name: "second", + activation: :silu + ) + end + + Axon.conv(sample, hd(spec.hidden_sizes), + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "out_conv", + kernel_initializer: :zeros + ) + end + + defp down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, opts) do + name = opts[:name] + + blocks = + Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)]) + + in_channels = hd(spec.hidden_sizes) + down_block_residuals = [{sample, in_channels}] + + state = {sample, down_block_residuals, in_channels} + + {sample, down_block_residuals, _} = + for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks), + reduce: state do + {sample, down_block_residuals, in_channels} -> + last_block? = idx == length(spec.hidden_sizes) - 1 + + {sample, residuals} = + Diffusion.Layers.UNet.down_block_2d( + block_type, + sample, + timestep_embedding, + encoder_hidden_state, + depth: spec.depth, + in_channels: in_channels, + out_channels: out_channels, + add_downsample: not last_block?, + downsample_padding: spec.downsample_padding, + activation: spec.activation, + norm_epsilon: spec.group_norm_epsilon, + norm_num_groups: spec.group_norm_num_groups, + num_attention_heads: num_attention_heads, + use_linear_projection: spec.use_linear_projection, + name: join(name, idx) + ) + + {sample, down_block_residuals ++ Tuple.to_list(residuals), out_channels} + end + + {sample, List.to_tuple(down_block_residuals)} + end + + defp mid_block(hidden_state, timesteps_embedding, encoder_hidden_state, spec, opts) do + Diffusion.Layers.UNet.mid_cross_attention_block_2d( + hidden_state, + timesteps_embedding, + encoder_hidden_state, + channels: List.last(spec.hidden_sizes), + activation: spec.activation, + norm_epsilon: spec.group_norm_epsilon, + norm_num_groups: spec.group_norm_num_groups, + output_scale_factor: spec.mid_block_scale_factor, + num_attention_heads: spec |> num_attention_heads_per_block() |> List.last(), + use_linear_projection: spec.use_linear_projection, + name: opts[:name] + ) + end + + defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do + spec.num_attention_heads + end + + defp num_attention_heads_per_block(spec) when is_integer(spec.num_attention_heads) do + num_blocks = length(spec.down_block_types) + List.duplicate(spec.num_attention_heads, num_blocks) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + in_channels: {"in_channels", number()}, + out_channels: {"out_channels", number()}, + sample_size: {"sample_size", number()}, + center_input_sample: {"center_input_sample", boolean()}, + embedding_flip_sin_to_cos: {"flip_sin_to_cos", boolean()}, + embedding_frequency_correction_term: {"freq_shift", number()}, + hidden_sizes: {"block_out_channels", list(number())}, + depth: {"layers_per_block", number()}, + down_block_types: { + "down_block_types", + list( + mapping(%{ + "DownBlock2D" => :down_block, + "CrossAttnDownBlock2D" => :cross_attention_down_block + }) + ) + }, + up_block_types: { + "up_block_types", + list( + mapping(%{ + "UpBlock2D" => :up_block, + "CrossAttnUpBlock2D" => :cross_attention_up_block + }) + ) + }, + downsample_padding: {"downsample_padding", padding(2)}, + mid_block_scale_factor: {"mid_block_scale_factor", number()}, + num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])}, + cross_attention_size: {"cross_attention_dim", number()}, + use_linear_projection: {"use_linear_projection", boolean()}, + activation: {"act_fn", activation()}, + group_norm_num_groups: {"norm_num_groups", number()}, + group_norm_epsilon: {"norm_eps", number()}, + conditioning_embedding_out_channels: + {"conditioning_embedding_out_channels", list(number())} + ) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + alias Bumblebee.HuggingFace.Transformers + + def params_mapping(_spec) do + block_mapping = %{ + "transformers.{m}.norm" => "attentions.{m}.norm", + "transformers.{m}.input_projection" => "attentions.{m}.proj_in", + "transformers.{m}.output_projection" => "attentions.{m}.proj_out", + "transformers.{m}.blocks.{l}.self_attention.query" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_q", + "transformers.{m}.blocks.{l}.self_attention.key" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_k", + "transformers.{m}.blocks.{l}.self_attention.value" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_v", + "transformers.{m}.blocks.{l}.self_attention.output" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_out.0", + "transformers.{m}.blocks.{l}.cross_attention.query" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_q", + "transformers.{m}.blocks.{l}.cross_attention.key" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_k", + "transformers.{m}.blocks.{l}.cross_attention.value" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_v", + "transformers.{m}.blocks.{l}.cross_attention.output" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_out.0", + "transformers.{m}.blocks.{l}.ffn.intermediate" => + "attentions.{m}.transformer_blocks.{l}.ff.net.0.proj", + "transformers.{m}.blocks.{l}.ffn.output" => + "attentions.{m}.transformer_blocks.{l}.ff.net.2", + "transformers.{m}.blocks.{l}.self_attention_norm" => + "attentions.{m}.transformer_blocks.{l}.norm1", + "transformers.{m}.blocks.{l}.cross_attention_norm" => + "attentions.{m}.transformer_blocks.{l}.norm2", + "transformers.{m}.blocks.{l}.output_norm" => + "attentions.{m}.transformer_blocks.{l}.norm3", + "residual_blocks.{m}.timestep_projection" => "resnets.{m}.time_emb_proj", + "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1", + "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1", + "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2", + "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2", + "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut", + "downsamples.{m}.conv" => "downsamplers.{m}.conv", + "upsamples.{m}.conv" => "upsamplers.{m}.conv" + } + + blocks_mapping = + ["down_blocks.{n}", "mid_block", "up_blocks.{n}"] + |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1)) + |> Enum.reduce(&Map.merge/2) + + %{ + "time_embedding.intermediate" => "time_embedding.linear_1", + "time_embedding.output" => "time_embedding.linear_2", + "input_conv" => "conv_in", + "output_norm" => "conv_norm_out", + "output_conv" => "conv_out" + } + |> Map.merge(blocks_mapping) + end + end +end diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex new file mode 100644 index 00000000..0ffe2a53 --- /dev/null +++ b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex @@ -0,0 +1,50 @@ +defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + @tag timeout: :infinity + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}, + module: Bumblebee.Diffusion.StableDiffusion.ControlNet, + architecture: :base + ) + + assert %Bumblebee.Diffusion.StableDiffusion.ControlNet{ + architecture: :base + } = spec + + inputs = %{ + "sample" => Nx.broadcast(0.5, {1, 512, 512, 4}), + "timestep" => Nx.tensor(1), + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.sample) == {1, 32, 32, 4} + + assert_all_close( + to_channels_first(outputs.sample)[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [ + [-1.0813, -0.5109, -0.1545], + [-0.8094, -1.2588, -0.8355], + [-0.9218, -1.2142, -0.6982] + ], + [ + [-0.2179, -0.2799, -1.0922], + [-0.9485, -0.8376, 0.0843], + [-0.9650, -0.7105, -0.3920] + ], + [[1.3359, 0.8373, -0.2392], [0.9448, -0.0478, 0.6881], [-0.0154, -0.5304, 0.2081]] + ] + ]), + atol: 1.0e-4 + ) + end +end From 4dd11d80c0080d79027d2915bcc1dfc170f09d19 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 14 Feb 2024 17:30:43 +0100 Subject: [PATCH 02/42] Param mapping --- .../diffusion/stable_diffusion/control_net.ex | 83 ++++++++++++------- .../stable_diffusion/control_net_test.ex | 5 +- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index a50a374d..daa5481e 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -3,7 +3,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do options = [ sample_size: [ - default: 512, + default: 64, doc: "the size of the input spatial dimensions" ], in_channels: [ @@ -144,11 +144,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do def input_template(spec) do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} + controlnet_conditioning_shape = {1, 512, 512, spec.in_channels} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), + "controlnet_conditioning" => Nx.template(controlnet_conditioning_shape, :f32), "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) } end @@ -162,10 +164,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} + controlnet_conditioning_shape = {nil, 512, 512, spec.in_channels} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), + Axon.input("controlnet_conditioning", shape: controlnet_conditioning_shape), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) ]) end @@ -173,6 +177,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp core(inputs, spec) do sample = inputs["sample"] timestep = inputs["timestep"] + controlnet_conditioning = inputs["controlnet_conditioning"] encoder_hidden_state = inputs["encoder_hidden_state"] timestep = @@ -195,15 +200,17 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do ) sample = - Axon.conv(dbg(sample), 4, + Axon.conv(sample, hd(spec.hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], name: "input_conv" ) - control_net_cond_embeddings = control_net_embeddings(sample, spec) + control_net_cond_embeddings = + control_net_embeddings(controlnet_conditioning, spec, name: "controlnet_cond_embedding") - sample = Axon.add(sample, control_net_cond_embeddings) + sample = + Axon.add(sample, control_net_cond_embeddings, name: "add_sample_control_net_embeddings") {sample, down_block_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") @@ -214,16 +221,16 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do conditioning_scale = Axon.constant(1) down_block_residuals = - control_net_down_blocks(down_block_residuals, spec, name: "control_net.down_blocks") + control_net_down_blocks(down_block_residuals, spec, name: "controlnet_down_blocks") down_block_residuals = for residual <- Tuple.to_list(down_block_residuals) do - Axon.multiply(residual, conditioning_scale, name: "control_net.down_blocks") + Axon.multiply(residual, conditioning_scale, name: "conditioning_scale") end |> List.to_tuple() mid_block_residual = - control_net_mid_block(sample, spec, name: "control_net.mid_block") + control_net_mid_block(sample, spec, name: "controlnet_mid_block") |> Axon.multiply(conditioning_scale) %{ @@ -234,29 +241,18 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp control_net_down_blocks(down_block_residuals, spec, opts) do name = opts[:name] - # blocks = Enum.zip(spec.hidden_sizes, Tuple.to_list(down_block_residuals)) residuals = for {{residual, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_residuals)) do Axon.conv(residual, out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name |> join(i) |> join("zero_conv"), + name: name |> join(i) |> join("test"), kernel_initializer: :zeros ) end List.to_tuple(residuals) - # # last block one less - # for _ <- spec.depth, reduce: sample do - # input -> - # Axon.conv(input, spec.hidden_sizes[-1], - # kernel_size: 3, - # padding: [{1, 1}, {1, 1}], - # name: "first", - # initializer: :zero - # ) - # end end defp control_net_mid_block(input, spec, opts) do @@ -265,41 +261,43 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.conv(input, List.last(spec.hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: join(name, "zero_conv"), + name: name, kernel_initializer: :zeros ) end - defp control_net_embeddings(sample, spec) do - input = + defp control_net_embeddings(sample, spec, opts) do + name = opts[:name] + + state = Axon.conv(sample, hd(spec.hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: "input_conv", + name: join(name, "input_conv"), activation: :silu ) block_in_channels = Enum.drop(spec.conditioning_embedding_out_channels, -1) block_out_channels = Enum.drop(spec.conditioning_embedding_out_channels, 1) - state = input + channels = Enum.zip(block_in_channels, block_out_channels) sample = - for {in_channels, out_channels} <- Enum.zip(block_in_channels, block_out_channels), + for {{in_channels, out_channels}, i} <- Enum.with_index(channels), reduce: state do input -> input |> Axon.conv(in_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: "first", + name: name |> join("blocks") |> join(2 * i) |> join("t"), activation: :silu ) |> Axon.conv(out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], strides: 2, - name: "second", + name: name |> join("blocks") |> join(2 * i + 1) |> join("t"), activation: :silu ) end @@ -307,7 +305,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.conv(sample, hd(spec.hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: "out_conv", + name: join(name, "output_conv"), kernel_initializer: :zeros ) end @@ -472,10 +470,36 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do } blocks_mapping = - ["down_blocks.{n}", "mid_block", "up_blocks.{n}"] + ["down_blocks.{n}", "mid_block"] |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1)) |> Enum.reduce(&Map.merge/2) + controlnet_mapping = %{ + "blocks.{n}.t" => %{ + "bias" => { + [{"controlnet_cond_embedding.blocks.{n}", "bias"}], + fn value -> value end + }, + "kernel" => { + [{"controlnet_cond_embedding.blocks.{n}", "weight"}], + fn value -> value end + } + }, + # "blocks.{n}.t" => "controlnet_cond_embedding.blocks.{n}", + # "controlnet_down_blocks.{m}.test" => "controlnet_down_blocks.{m}", + "controlnet_down_blocks.{m}.test" => %{ + "bias" => { + [{"controlnet_down_blocks.{m}", "bias"}], + fn value -> value end + }, + "kernel" => { + [{"controlnet_down_blocks.{m}", "weight"}], + fn value -> value end + } + }, + "controlnet_mid_block" => "controlnet_mid_block" + } + %{ "time_embedding.intermediate" => "time_embedding.linear_1", "time_embedding.output" => "time_embedding.linear_2", @@ -484,6 +508,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do "output_conv" => "conv_out" } |> Map.merge(blocks_mapping) + |> Map.merge(controlnet_mapping) end end end diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex index 0ffe2a53..09ba9b31 100644 --- a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex +++ b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex @@ -18,12 +18,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do } = spec inputs = %{ - "sample" => Nx.broadcast(0.5, {1, 512, 512, 4}), + "sample" => Nx.broadcast(0.5, {1, 64, 64, 4}), + "controlnet_conditioning" => Nx.broadcast(0.8, {1, 512, 512, 4}), "timestep" => Nx.tensor(1), "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) } - outputs = Axon.predict(model, params, inputs) + outputs = Axon.predict(model, params, inputs, debug: true) assert Nx.shape(outputs.sample) == {1, 32, 32, 4} From abdea6bc943b4cae1992704a2b2bb0a610c56107 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 15 Feb 2024 16:46:52 +0100 Subject: [PATCH 03/42] AnyDoor controlnet params --- .../diffusion/stable_diffusion/control_net.ex | 490 +++++++++++++++--- 1 file changed, 406 insertions(+), 84 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index daa5481e..a4f66c7f 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -247,7 +247,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.conv(residual, out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name |> join(i) |> join("test"), + name: name |> join(i) |> join("zero_conv"), kernel_initializer: :zeros ) end @@ -261,7 +261,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.conv(input, List.last(spec.hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name, + name: name |> join("zero_conv"), kernel_initializer: :zeros ) end @@ -270,7 +270,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do name = opts[:name] state = - Axon.conv(sample, hd(spec.hidden_sizes), + Axon.conv(sample, hd(spec.conditioning_embedding_out_channels), kernel_size: 3, padding: [{1, 1}, {1, 1}], name: join(name, "input_conv"), @@ -290,14 +290,14 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do |> Axon.conv(in_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name |> join("blocks") |> join(2 * i) |> join("t"), + name: name |> join(4 * i + 2) |> join("conv"), activation: :silu ) |> Axon.conv(out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], strides: 2, - name: name |> join("blocks") |> join(2 * i + 1) |> join("t"), + name: name |> join(4 * (i + 1)) |> join("conv"), activation: :silu ) end @@ -426,89 +426,411 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do end defimpl Bumblebee.HuggingFace.Transformers.Model do - alias Bumblebee.HuggingFace.Transformers def params_mapping(_spec) do - block_mapping = %{ - "transformers.{m}.norm" => "attentions.{m}.norm", - "transformers.{m}.input_projection" => "attentions.{m}.proj_in", - "transformers.{m}.output_projection" => "attentions.{m}.proj_out", - "transformers.{m}.blocks.{l}.self_attention.query" => - "attentions.{m}.transformer_blocks.{l}.attn1.to_q", - "transformers.{m}.blocks.{l}.self_attention.key" => - "attentions.{m}.transformer_blocks.{l}.attn1.to_k", - "transformers.{m}.blocks.{l}.self_attention.value" => - "attentions.{m}.transformer_blocks.{l}.attn1.to_v", - "transformers.{m}.blocks.{l}.self_attention.output" => - "attentions.{m}.transformer_blocks.{l}.attn1.to_out.0", - "transformers.{m}.blocks.{l}.cross_attention.query" => - "attentions.{m}.transformer_blocks.{l}.attn2.to_q", - "transformers.{m}.blocks.{l}.cross_attention.key" => - "attentions.{m}.transformer_blocks.{l}.attn2.to_k", - "transformers.{m}.blocks.{l}.cross_attention.value" => - "attentions.{m}.transformer_blocks.{l}.attn2.to_v", - "transformers.{m}.blocks.{l}.cross_attention.output" => - "attentions.{m}.transformer_blocks.{l}.attn2.to_out.0", - "transformers.{m}.blocks.{l}.ffn.intermediate" => - "attentions.{m}.transformer_blocks.{l}.ff.net.0.proj", - "transformers.{m}.blocks.{l}.ffn.output" => - "attentions.{m}.transformer_blocks.{l}.ff.net.2", - "transformers.{m}.blocks.{l}.self_attention_norm" => - "attentions.{m}.transformer_blocks.{l}.norm1", - "transformers.{m}.blocks.{l}.cross_attention_norm" => - "attentions.{m}.transformer_blocks.{l}.norm2", - "transformers.{m}.blocks.{l}.output_norm" => - "attentions.{m}.transformer_blocks.{l}.norm3", - "residual_blocks.{m}.timestep_projection" => "resnets.{m}.time_emb_proj", - "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1", - "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1", - "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2", - "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2", - "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut", - "downsamples.{m}.conv" => "downsamplers.{m}.conv", - "upsamples.{m}.conv" => "upsamplers.{m}.conv" - } - - blocks_mapping = - ["down_blocks.{n}", "mid_block"] - |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1)) - |> Enum.reduce(&Map.merge/2) - - controlnet_mapping = %{ - "blocks.{n}.t" => %{ - "bias" => { - [{"controlnet_cond_embedding.blocks.{n}", "bias"}], - fn value -> value end - }, - "kernel" => { - [{"controlnet_cond_embedding.blocks.{n}", "weight"}], - fn value -> value end - } - }, - # "blocks.{n}.t" => "controlnet_cond_embedding.blocks.{n}", - # "controlnet_down_blocks.{m}.test" => "controlnet_down_blocks.{m}", - "controlnet_down_blocks.{m}.test" => %{ - "bias" => { - [{"controlnet_down_blocks.{m}", "bias"}], - fn value -> value end - }, - "kernel" => { - [{"controlnet_down_blocks.{m}", "weight"}], - fn value -> value end - } - }, - "controlnet_mid_block" => "controlnet_mid_block" - } - + # controlnet_cond_embedding_mapping = %{ - "time_embedding.intermediate" => "time_embedding.linear_1", - "time_embedding.output" => "time_embedding.linear_2", - "input_conv" => "conv_in", - "output_norm" => "conv_norm_out", - "output_conv" => "conv_out" + "controlnet_cond_embedding.input_conv" => "control_model.input_hint_block.0", + "controlnet_cond_embedding.output_conv" => "control_model.input_hint_block.14", + "controlnet_cond_embedding.{l}.conv" => "control_model.input_hint_block.{l}", + + # controlnet_down_blocks_mapping = %{ + "controlnet_down_blocks.{m}.zero_conv" => "control_model.zero_convs.{m}.0", + + # controlnet_mid_block_mapping = %{ + "controlnet_mid_block.zero_conv" => "control_model.middle_block_out.0", + + # controlnet_mapping = %{ + "input_conv" => "control_model.input_blocks.0.0", + + # down_blocks_mapping = %{ + # down_blocks + "down_blocks.0.transformers.0.norm" => "control_model.input_blocks.1.1.norm", + "down_blocks.0.transformers.1.norm" => "control_model.input_blocks.2.1.norm", + "down_blocks.1.transformers.0.norm" => "control_model.input_blocks.4.1.norm", + "down_blocks.1.transformers.1.norm" => "control_model.input_blocks.5.1.norm", + "down_blocks.2.transformers.0.norm" => "control_model.input_blocks.7.1.norm", + "down_blocks.2.transformers.1.norm" => "control_model.input_blocks.8.1.norm", + + # self attention 0 0 + "down_blocks.0.transformers.0.blocks.0.self_attention_norm" => + "control_model.input_blocks.1.1.transformer_blocks.0.norm1", + "down_blocks.0.transformers.0.blocks.0.self_attention.key" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k", + "down_blocks.0.transformers.0.blocks.0.self_attention.value" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v", + "down_blocks.0.transformers.0.blocks.0.self_attention.query" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q", + "down_blocks.0.transformers.0.blocks.0.self_attention.output" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0", + + # self attention 0 1 + "down_blocks.0.transformers.1.blocks.0.self_attention_norm" => + "control_model.input_blocks.2.1.transformer_blocks.0.norm1", + "down_blocks.0.transformers.1.blocks.0.self_attention.key" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k", + "down_blocks.0.transformers.1.blocks.0.self_attention.value" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v", + "down_blocks.0.transformers.1.blocks.0.self_attention.query" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q", + "down_blocks.0.transformers.1.blocks.0.self_attention.output" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0", + + # self attention 1 0 + "down_blocks.1.transformers.0.blocks.0.self_attention_norm" => + "control_model.input_blocks.4.1.transformer_blocks.0.norm1", + "down_blocks.1.transformers.0.blocks.0.self_attention.key" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k", + "down_blocks.1.transformers.0.blocks.0.self_attention.value" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v", + "down_blocks.1.transformers.0.blocks.0.self_attention.query" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q", + "down_blocks.1.transformers.0.blocks.0.self_attention.output" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0", + + # self attention 1 1 + "down_blocks.1.transformers.1.blocks.0.self_attention_norm" => + "control_model.input_blocks.5.1.transformer_blocks.0.norm1", + "down_blocks.1.transformers.1.blocks.0.self_attention.key" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k", + "down_blocks.1.transformers.1.blocks.0.self_attention.value" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v", + "down_blocks.1.transformers.1.blocks.0.self_attention.query" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q", + "down_blocks.1.transformers.1.blocks.0.self_attention.output" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0", + + # self attention 2 0 + "down_blocks.2.transformers.0.blocks.0.self_attention_norm" => + "control_model.input_blocks.7.1.transformer_blocks.0.norm1", + "down_blocks.2.transformers.0.blocks.0.self_attention.key" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k", + "down_blocks.2.transformers.0.blocks.0.self_attention.value" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v", + "down_blocks.2.transformers.0.blocks.0.self_attention.query" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q", + "down_blocks.2.transformers.0.blocks.0.self_attention.output" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0", + + # self attention 2 1 + "down_blocks.2.transformers.1.blocks.0.self_attention_norm" => + "control_model.input_blocks.8.1.transformer_blocks.0.norm1", + "down_blocks.2.transformers.1.blocks.0.self_attention.key" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k", + "down_blocks.2.transformers.1.blocks.0.self_attention.value" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v", + "down_blocks.2.transformers.1.blocks.0.self_attention.query" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q", + "down_blocks.2.transformers.1.blocks.0.self_attention.output" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0", + + # cross attention 0 0 + "down_blocks.0.transformers.0.blocks.0.cross_attention_norm" => + "control_model.input_blocks.1.1.transformer_blocks.0.norm2", + "down_blocks.0.transformers.0.blocks.0.cross_attention.key" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k", + "down_blocks.0.transformers.0.blocks.0.cross_attention.value" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v", + "down_blocks.0.transformers.0.blocks.0.cross_attention.query" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q", + "down_blocks.0.transformers.0.blocks.0.cross_attention.output" => + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0", + + # cross attention 0 1 + "down_blocks.0.transformers.1.blocks.0.cross_attention_norm" => + "control_model.input_blocks.2.1.transformer_blocks.0.norm2", + "down_blocks.0.transformers.1.blocks.0.cross_attention.key" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k", + "down_blocks.0.transformers.1.blocks.0.cross_attention.value" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v", + "down_blocks.0.transformers.1.blocks.0.cross_attention.query" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q", + "down_blocks.0.transformers.1.blocks.0.cross_attention.output" => + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0", + + # cross attention 1 0 + "down_blocks.1.transformers.0.blocks.0.cross_attention_norm" => + "control_model.input_blocks.4.1.transformer_blocks.0.norm2", + "down_blocks.1.transformers.0.blocks.0.cross_attention.key" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k", + "down_blocks.1.transformers.0.blocks.0.cross_attention.value" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v", + "down_blocks.1.transformers.0.blocks.0.cross_attention.query" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q", + "down_blocks.1.transformers.0.blocks.0.cross_attention.output" => + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0", + + # cross attention 1 1 + "down_blocks.1.transformers.1.blocks.0.cross_attention_norm" => + "control_model.input_blocks.5.1.transformer_blocks.0.norm2", + "down_blocks.1.transformers.1.blocks.0.cross_attention.key" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k", + "down_blocks.1.transformers.1.blocks.0.cross_attention.value" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v", + "down_blocks.1.transformers.1.blocks.0.cross_attention.query" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q", + "down_blocks.1.transformers.1.blocks.0.cross_attention.output" => + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0", + + # cross attention 2 0 + "down_blocks.2.transformers.0.blocks.0.cross_attention_norm" => + "control_model.input_blocks.7.1.transformer_blocks.0.norm2", + "down_blocks.2.transformers.0.blocks.0.cross_attention.key" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k", + "down_blocks.2.transformers.0.blocks.0.cross_attention.value" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v", + "down_blocks.2.transformers.0.blocks.0.cross_attention.query" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q", + "down_blocks.2.transformers.0.blocks.0.cross_attention.output" => + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0", + + # cross attention 2 1 + "down_blocks.2.transformers.1.blocks.0.cross_attention_norm" => + "control_model.input_blocks.8.1.transformer_blocks.0.norm2", + "down_blocks.2.transformers.1.blocks.0.cross_attention.key" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k", + "down_blocks.2.transformers.1.blocks.0.cross_attention.value" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v", + "down_blocks.2.transformers.1.blocks.0.cross_attention.query" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q", + "down_blocks.2.transformers.1.blocks.0.cross_attention.output" => + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0", + + # ffn 0 0 + "down_blocks.0.transformers.0.blocks.0.ffn.intermediate" => + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.0.transformers.0.blocks.0.ffn.output" => + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2", + + # ffn 0 1 + "down_blocks.0.transformers.1.blocks.0.ffn.intermediate" => + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.0.transformers.1.blocks.0.ffn.output" => + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2", + + # ffn 1 0 + "down_blocks.1.transformers.0.blocks.0.ffn.intermediate" => + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.1.transformers.0.blocks.0.ffn.output" => + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2", + + # ffn 1 1 + "down_blocks.1.transformers.1.blocks.0.ffn.intermediate" => + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.1.transformers.1.blocks.0.ffn.output" => + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2", + + # ffn 2 0 + "down_blocks.2.transformers.0.blocks.0.ffn.intermediate" => + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.2.transformers.0.blocks.0.ffn.output" => + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2", + + # ffn 2 1 + "down_blocks.2.transformers.1.blocks.0.ffn.intermediate" => + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj", + "down_blocks.2.transformers.1.blocks.0.ffn.output" => + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2", + + # residuals 0 0 + "down_blocks.0.residual_blocks.0.norm_1" => "control_model.input_blocks.1.0.in_layers.0", + "down_blocks.0.residual_blocks.0.conv_1" => "control_model.input_blocks.1.0.in_layers.2", + "down_blocks.0.residual_blocks.0.timestep_projection" => + "control_model.input_blocks.1.0.emb_layers.1", + "down_blocks.0.residual_blocks.0.norm_2" => "control_model.input_blocks.1.0.out_layers.0", + "down_blocks.0.residual_blocks.0.conv_2" => "control_model.input_blocks.1.0.out_layers.3", + + # residuals 0 1 + "down_blocks.0.residual_blocks.1.norm_1" => "control_model.input_blocks.2.0.in_layers.0", + "down_blocks.0.residual_blocks.1.conv_1" => "control_model.input_blocks.2.0.in_layers.2", + "down_blocks.0.residual_blocks.1.timestep_projection" => + "control_model.input_blocks.2.0.emb_layers.1", + "down_blocks.0.residual_blocks.1.norm_2" => "control_model.input_blocks.2.0.out_layers.0", + "down_blocks.0.residual_blocks.1.conv_2" => "control_model.input_blocks.2.0.out_layers.3", + + # residuals 1 0 + "down_blocks.1.residual_blocks.0.norm_1" => "control_model.input_blocks.4.0.in_layers.0", + "down_blocks.1.residual_blocks.0.conv_1" => "control_model.input_blocks.4.0.in_layers.2", + "down_blocks.1.residual_blocks.0.timestep_projection" => + "control_model.input_blocks.4.0.emb_layers.1", + "down_blocks.1.residual_blocks.0.norm_2" => "control_model.input_blocks.4.0.out_layers.0", + "down_blocks.1.residual_blocks.0.conv_2" => "control_model.input_blocks.4.0.out_layers.3", + + # residuals 1 1 + "down_blocks.1.residual_blocks.1.norm_1" => "control_model.input_blocks.5.0.in_layers.0", + "down_blocks.1.residual_blocks.1.conv_1" => "control_model.input_blocks.5.0.in_layers.2", + "down_blocks.1.residual_blocks.1.timestep_projection" => + "control_model.input_blocks.5.0.emb_layers.1", + "down_blocks.1.residual_blocks.1.norm_2" => "control_model.input_blocks.5.0.out_layers.0", + "down_blocks.1.residual_blocks.1.conv_2" => "control_model.input_blocks.5.0.out_layers.3", + + # residuals 2 0 + "down_blocks.2.residual_blocks.0.norm_1" => "control_model.input_blocks.7.0.in_layers.0", + "down_blocks.2.residual_blocks.0.conv_1" => "control_model.input_blocks.7.0.in_layers.2", + "down_blocks.2.residual_blocks.0.timestep_projection" => + "control_model.input_blocks.7.0.emb_layers.1", + "down_blocks.2.residual_blocks.0.norm_2" => "control_model.input_blocks.7.0.out_layers.0", + "down_blocks.2.residual_blocks.0.conv_2" => "control_model.input_blocks.7.0.out_layers.3", + + # residuals 2 1 + "down_blocks.2.residual_blocks.1.norm_1" => "control_model.input_blocks.8.0.in_layers.0", + "down_blocks.2.residual_blocks.1.conv_1" => "control_model.input_blocks.8.0.in_layers.2", + "down_blocks.2.residual_blocks.1.timestep_projection" => + "control_model.input_blocks.8.0.emb_layers.1", + "down_blocks.2.residual_blocks.1.norm_2" => "control_model.input_blocks.8.0.out_layers.0", + "down_blocks.2.residual_blocks.1.conv_2" => "control_model.input_blocks.8.0.out_layers.3", + + # residuals 3 0 + "down_blocks.3.residual_blocks.0.norm_1" => "control_model.input_blocks.10.0.in_layers.0", + "down_blocks.3.residual_blocks.0.conv_1" => "control_model.input_blocks.10.0.in_layers.2", + "down_blocks.3.residual_blocks.0.timestep_projection" => + "control_model.input_blocks.10.0.emb_layers.1", + "down_blocks.3.residual_blocks.0.norm_2" => + "control_model.input_blocks.10.0.out_layers.0", + "down_blocks.3.residual_blocks.0.conv_2" => + "control_model.input_blocks.10.0.out_layers.3", + + # residuals 3 1 + "down_blocks.3.residual_blocks.1.norm_1" => "control_model.input_blocks.11.0.in_layers.0", + "down_blocks.3.residual_blocks.1.conv_1" => "control_model.input_blocks.11.0.in_layers.2", + "down_blocks.3.residual_blocks.1.timestep_projection" => + "control_model.input_blocks.11.0.emb_layers.1", + "down_blocks.3.residual_blocks.1.norm_2" => + "control_model.input_blocks.11.0.out_layers.0", + "down_blocks.3.residual_blocks.1.conv_2" => + "control_model.input_blocks.11.0.out_layers.3", + + # projection 0 0 + "down_blocks.0.transformers.0.input_projection" => + "control_model.input_blocks.1.1.proj_in", + "down_blocks.0.transformers.0.output_projection" => + "control_model.input_blocks.1.1.proj_out", + + # projection 0 1 + "down_blocks.0.transformers.1.input_projection" => + "control_model.input_blocks.2.1.proj_in", + "down_blocks.0.transformers.1.output_projection" => + "control_model.input_blocks.2.1.proj_out", + + # projection 1 0 + "down_blocks.1.transformers.0.input_projection" => + "control_model.input_blocks.4.1.proj_in", + "down_blocks.1.transformers.0.output_projection" => + "control_model.input_blocks.4.1.proj_out", + + # projection 1 1 + "down_blocks.1.transformers.1.input_projection" => + "control_model.input_blocks.5.1.proj_in", + "down_blocks.1.transformers.1.output_projection" => + "control_model.input_blocks.5.1.proj_out", + + # projection 2 0 + "down_blocks.2.transformers.0.input_projection" => + "control_model.input_blocks.7.1.proj_in", + "down_blocks.2.transformers.0.output_projection" => + "control_model.input_blocks.7.1.proj_out", + + # projection 2 1 + "down_blocks.2.transformers.1.input_projection" => + "control_model.input_blocks.8.1.proj_in", + "down_blocks.2.transformers.1.output_projection" => + "control_model.input_blocks.8.1.proj_out", + + # shortcut + "down_blocks.1.residual_blocks.0.shortcut.projection" => + "control_model.input_blocks.4.0.skip_connection", + "down_blocks.2.residual_blocks.0.shortcut.projection" => + "control_model.input_blocks.7.0.skip_connection", + + # downsamples + "down_blocks.0.downsamples.0.conv" => "control_model.input_blocks.3.0.op", + "down_blocks.1.downsamples.0.conv" => "control_model.input_blocks.6.0.op", + "down_blocks.2.downsamples.0.conv" => "control_model.input_blocks.9.0.op", + + # out 0 0 + "down_blocks.0.transformers.0.blocks.0.output_norm" => + "control_model.input_blocks.1.1.transformer_blocks.0.norm3", + + # out 0 1 + "down_blocks.0.transformers.1.blocks.0.output_norm" => + "control_model.input_blocks.2.1.transformer_blocks.0.norm3", + + # out 1 0 + "down_blocks.1.transformers.0.blocks.0.output_norm" => + "control_model.input_blocks.4.1.transformer_blocks.0.norm3", + + # out 1 1 + "down_blocks.1.transformers.1.blocks.0.output_norm" => + "control_model.input_blocks.5.1.transformer_blocks.0.norm3", + + # out 2 0 + "down_blocks.2.transformers.0.blocks.0.output_norm" => + "control_model.input_blocks.7.1.transformer_blocks.0.norm3", + + # out 2 1 + "down_blocks.2.transformers.1.blocks.0.output_norm" => + "control_model.input_blocks.8.1.transformer_blocks.0.norm3", + + # mid_block_mapping = %{ + # mid_block + "mid_block.transformers.0.norm" => "control_model.middle_block.1.norm", + # self attention + "mid_block.transformers.0.blocks.0.self_attention_norm" => + "control_model.middle_block.1.transformer_blocks.0.norm1", + "mid_block.transformers.0.blocks.0.self_attention.key" => + "control_model.middle_block.1.transformer_blocks.0.attn1.to_k", + "mid_block.transformers.0.blocks.0.self_attention.value" => + "control_model.middle_block.1.transformer_blocks.0.attn1.to_v", + "mid_block.transformers.0.blocks.0.self_attention.query" => + "control_model.middle_block.1.transformer_blocks.0.attn1.to_q", + "mid_block.transformers.0.blocks.0.self_attention.output" => + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0", + + # cross attention + "mid_block.transformers.0.blocks.0.cross_attention_norm" => + "control_model.middle_block.1.transformer_blocks.0.norm2", + "mid_block.transformers.0.blocks.0.cross_attention.key" => + "control_model.middle_block.1.transformer_blocks.0.attn2.to_k", + "mid_block.transformers.0.blocks.0.cross_attention.value" => + "control_model.middle_block.1.transformer_blocks.0.attn2.to_v", + "mid_block.transformers.0.blocks.0.cross_attention.query" => + "control_model.middle_block.1.transformer_blocks.0.attn2.to_q", + "mid_block.transformers.0.blocks.0.cross_attention.output" => + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0", + + # ffn + "mid_block.transformers.0.blocks.0.ffn.intermediate" => + "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj", + "mid_block.transformers.0.blocks.0.ffn.output" => + "control_model.middle_block.1.transformer_blocks.0.ff.net.2", + + # residuals 0 + "mid_block.residual_blocks.0.norm_1" => "control_model.middle_block.0.in_layers.0", + "mid_block.residual_blocks.0.conv_1" => "control_model.middle_block.0.in_layers.2", + "mid_block.residual_blocks.0.timestep_projection" => + "control_model.middle_block.0.emb_layers.1", + "mid_block.residual_blocks.0.norm_2" => "control_model.middle_block.0.out_layers.0", + "mid_block.residual_blocks.0.conv_2" => "control_model.middle_block.0.out_layers.3", + # residuals 1 + "mid_block.residual_blocks.1.norm_1" => "control_model.middle_block.2.in_layers.0", + "mid_block.residual_blocks.1.conv_1" => "control_model.middle_block.2.in_layers.2", + "mid_block.residual_blocks.1.timestep_projection" => + "control_model.middle_block.2.emb_layers.1", + "mid_block.residual_blocks.1.norm_2" => "control_model.middle_block.2.out_layers.0", + "mid_block.residual_blocks.1.conv_2" => "control_model.middle_block.2.out_layers.3", + + # projection + "mid_block.transformers.0.input_projection" => "control_model.middle_block.1.proj_in", + "mid_block.transformers.0.output_projection" => "control_model.middle_block.1.proj_out", + + # out + "mid_block.transformers.0.blocks.0.output_norm" => + "control_model.middle_block.1.transformer_blocks.0.norm3", + + # others + "time_embedding.intermediate" => "control_model.time_embed.0", + "time_embedding.output" => "control_model.time_embed.2" } - |> Map.merge(blocks_mapping) - |> Map.merge(controlnet_mapping) end end end From c9ce1d2fa740b3adadd79b02091a457cf6b49f7a Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 15 Feb 2024 17:16:54 +0100 Subject: [PATCH 04/42] No complaints from anydoor controlnet params --- .../diffusion/stable_diffusion/control_net.ex | 129 ++++++++++++++---- 1 file changed, 99 insertions(+), 30 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index a4f66c7f..93bac062 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -69,7 +69,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do "the number of attention heads for each attention layer. Optionally can be a list with one number per block" ], cross_attention_size: [ - default: 1280, + default: 1024, doc: "the dimensionality of the cross attention features" ], use_linear_projection: [ @@ -245,7 +245,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do residuals = for {{residual, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_residuals)) do Axon.conv(residual, out_channels, - kernel_size: 3, + kernel_size: 1, padding: [{1, 1}, {1, 1}], name: name |> join(i) |> join("zero_conv"), kernel_initializer: :zeros @@ -259,7 +259,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do name = opts[:name] Axon.conv(input, List.last(spec.hidden_sizes), - kernel_size: 3, + kernel_size: 1, padding: [{1, 1}, {1, 1}], name: name |> join("zero_conv"), kernel_initializer: :zeros @@ -426,7 +426,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do end defimpl Bumblebee.HuggingFace.Transformers.Model do - def params_mapping(_spec) do # controlnet_cond_embedding_mapping = %{ @@ -701,40 +700,100 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do "control_model.input_blocks.11.0.out_layers.3", # projection 0 0 - "down_blocks.0.transformers.0.input_projection" => - "control_model.input_blocks.1.1.proj_in", - "down_blocks.0.transformers.0.output_projection" => - "control_model.input_blocks.1.1.proj_out", + "down_blocks.0.transformers.0.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.1.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.1.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.0.transformers.0.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.1.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.1.1.proj_out", "bias"}], fn [value] -> value end} + }, # projection 0 1 - "down_blocks.0.transformers.1.input_projection" => - "control_model.input_blocks.2.1.proj_in", - "down_blocks.0.transformers.1.output_projection" => - "control_model.input_blocks.2.1.proj_out", + "down_blocks.0.transformers.1.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.2.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.2.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.0.transformers.1.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.2.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.2.1.proj_out", "bias"}], fn [value] -> value end} + }, # projection 1 0 - "down_blocks.1.transformers.0.input_projection" => - "control_model.input_blocks.4.1.proj_in", - "down_blocks.1.transformers.0.output_projection" => - "control_model.input_blocks.4.1.proj_out", + "down_blocks.1.transformers.0.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.4.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.4.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.1.transformers.0.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.4.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.4.1.proj_out", "bias"}], fn [value] -> value end} + }, # projection 1 1 - "down_blocks.1.transformers.1.input_projection" => - "control_model.input_blocks.5.1.proj_in", - "down_blocks.1.transformers.1.output_projection" => - "control_model.input_blocks.5.1.proj_out", + "down_blocks.1.transformers.1.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.5.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.5.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.1.transformers.1.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.5.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.5.1.proj_out", "bias"}], fn [value] -> value end} + }, # projection 2 0 - "down_blocks.2.transformers.0.input_projection" => - "control_model.input_blocks.7.1.proj_in", - "down_blocks.2.transformers.0.output_projection" => - "control_model.input_blocks.7.1.proj_out", + "down_blocks.2.transformers.0.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.7.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.7.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.2.transformers.0.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.7.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.7.1.proj_out", "bias"}], fn [value] -> value end} + }, # projection 2 1 - "down_blocks.2.transformers.1.input_projection" => - "control_model.input_blocks.8.1.proj_in", - "down_blocks.2.transformers.1.output_projection" => - "control_model.input_blocks.8.1.proj_out", + "down_blocks.2.transformers.1.input_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.8.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.8.1.proj_in", "bias"}], fn [value] -> value end} + }, + "down_blocks.2.transformers.1.output_projection" => %{ + "kernel" => + {[{"control_model.input_blocks.8.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => + {[{"control_model.input_blocks.8.1.proj_out", "bias"}], fn [value] -> value end} + }, # shortcut "down_blocks.1.residual_blocks.0.shortcut.projection" => @@ -820,8 +879,18 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do "mid_block.residual_blocks.1.conv_2" => "control_model.middle_block.2.out_layers.3", # projection - "mid_block.transformers.0.input_projection" => "control_model.middle_block.1.proj_in", - "mid_block.transformers.0.output_projection" => "control_model.middle_block.1.proj_out", + "mid_block.transformers.0.input_projection" => %{ + "kernel" => + {[{"control_model.middle_block.1.proj_in", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => {[{"control_model.middle_block.1.proj_in", "bias"}], fn [value] -> value end} + }, + "mid_block.transformers.0.output_projection" => %{ + "kernel" => + {[{"control_model.middle_block.1.proj_out", "weight"}], + fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, + "bias" => {[{"control_model.middle_block.1.proj_out", "bias"}], fn [value] -> value end} + }, # out "mid_block.transformers.0.blocks.0.output_norm" => From fa49ae9268ff1c905f049df7882bdb36ee395f6b Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 16 Feb 2024 17:49:35 +0100 Subject: [PATCH 05/42] Huggingface params --- .../diffusion/stable_diffusion/control_net.ex | 554 +++--------------- .../stable_diffusion/control_net_test.ex | 36 +- 2 files changed, 84 insertions(+), 506 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index 93bac062..c93f2e09 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -144,7 +144,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do def input_template(spec) do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} - controlnet_conditioning_shape = {1, 512, 512, spec.in_channels} + controlnet_conditioning_shape = {1, 512, 512, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} %{ @@ -164,7 +164,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - controlnet_conditioning_shape = {nil, 512, 512, spec.in_channels} + controlnet_conditioning_shape = {nil, 512, 512, 3} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), @@ -212,7 +212,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do sample = Axon.add(sample, control_net_cond_embeddings, name: "add_sample_control_net_embeddings") - {sample, down_block_residuals} = + {sample, down_blocks_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") sample = @@ -220,11 +220,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do conditioning_scale = Axon.constant(1) - down_block_residuals = - control_net_down_blocks(down_block_residuals, spec, name: "controlnet_down_blocks") + down_blocks_residuals = + control_net_down_blocks(down_blocks_residuals, name: "controlnet_down_blocks") - down_block_residuals = - for residual <- Tuple.to_list(down_block_residuals) do + down_blocks_residuals = + for residual <- Tuple.to_list(down_blocks_residuals) do Axon.multiply(residual, conditioning_scale, name: "conditioning_scale") end |> List.to_tuple() @@ -234,19 +234,18 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do |> Axon.multiply(conditioning_scale) %{ - down_block_residuals: down_block_residuals, + down_blocks_residuals: down_blocks_residuals, mid_block_residual: mid_block_residual } end - defp control_net_down_blocks(down_block_residuals, spec, opts) do + defp control_net_down_blocks(down_block_residuals, opts) do name = opts[:name] residuals = for {{residual, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_residuals)) do Axon.conv(residual, out_channels, kernel_size: 1, - padding: [{1, 1}, {1, 1}], name: name |> join(i) |> join("zero_conv"), kernel_initializer: :zeros ) @@ -260,7 +259,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.conv(input, List.last(spec.hidden_sizes), kernel_size: 1, - padding: [{1, 1}, {1, 1}], name: name |> join("zero_conv"), kernel_initializer: :zeros ) @@ -290,14 +288,14 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do |> Axon.conv(in_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name |> join(4 * i + 2) |> join("conv"), + name: name |> join(2 * i) |> join("conv"), activation: :silu ) |> Axon.conv(out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], strides: 2, - name: name |> join(4 * (i + 1)) |> join("conv"), + name: name |> join(2 * i + 1) |> join("conv"), activation: :silu ) end @@ -426,480 +424,68 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do end defimpl Bumblebee.HuggingFace.Transformers.Model do + alias Bumblebee.HuggingFace.Transformers + def params_mapping(_spec) do - # controlnet_cond_embedding_mapping = + block_mapping = %{ + "transformers.{m}.norm" => "attentions.{m}.norm", + "transformers.{m}.input_projection" => "attentions.{m}.proj_in", + "transformers.{m}.output_projection" => "attentions.{m}.proj_out", + "transformers.{m}.blocks.{l}.self_attention.query" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_q", + "transformers.{m}.blocks.{l}.self_attention.key" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_k", + "transformers.{m}.blocks.{l}.self_attention.value" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_v", + "transformers.{m}.blocks.{l}.self_attention.output" => + "attentions.{m}.transformer_blocks.{l}.attn1.to_out.0", + "transformers.{m}.blocks.{l}.cross_attention.query" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_q", + "transformers.{m}.blocks.{l}.cross_attention.key" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_k", + "transformers.{m}.blocks.{l}.cross_attention.value" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_v", + "transformers.{m}.blocks.{l}.cross_attention.output" => + "attentions.{m}.transformer_blocks.{l}.attn2.to_out.0", + "transformers.{m}.blocks.{l}.ffn.intermediate" => + "attentions.{m}.transformer_blocks.{l}.ff.net.0.proj", + "transformers.{m}.blocks.{l}.ffn.output" => + "attentions.{m}.transformer_blocks.{l}.ff.net.2", + "transformers.{m}.blocks.{l}.self_attention_norm" => + "attentions.{m}.transformer_blocks.{l}.norm1", + "transformers.{m}.blocks.{l}.cross_attention_norm" => + "attentions.{m}.transformer_blocks.{l}.norm2", + "transformers.{m}.blocks.{l}.output_norm" => + "attentions.{m}.transformer_blocks.{l}.norm3", + "residual_blocks.{m}.timestep_projection" => "resnets.{m}.time_emb_proj", + "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1", + "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1", + "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2", + "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2", + "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut", + "downsamples.{m}.conv" => "downsamplers.{m}.conv" + } + + blocks_mapping = + ["down_blocks.{n}", "mid_block"] + |> Enum.map(&Transformers.Utils.prefix_params_mapping(block_mapping, &1, &1)) + |> Enum.reduce(&Map.merge/2) + + controlnet = %{ + "controlnet_down_blocks.{m}.zero_conv" => "controlnet_down_blocks.{m}", + "controlnet_cond_embedding.input_conv" => "controlnet_cond_embedding.conv_in", + "controlnet_cond_embedding.{m}.conv" => "controlnet_cond_embedding.blocks.{m}", + "controlnet_cond_embedding.output_conv" => "controlnet_cond_embedding.conv_out", + "controlnet_mid_block.zero_conv" => "controlnet_mid_block" + } + %{ - "controlnet_cond_embedding.input_conv" => "control_model.input_hint_block.0", - "controlnet_cond_embedding.output_conv" => "control_model.input_hint_block.14", - "controlnet_cond_embedding.{l}.conv" => "control_model.input_hint_block.{l}", - - # controlnet_down_blocks_mapping = %{ - "controlnet_down_blocks.{m}.zero_conv" => "control_model.zero_convs.{m}.0", - - # controlnet_mid_block_mapping = %{ - "controlnet_mid_block.zero_conv" => "control_model.middle_block_out.0", - - # controlnet_mapping = %{ - "input_conv" => "control_model.input_blocks.0.0", - - # down_blocks_mapping = %{ - # down_blocks - "down_blocks.0.transformers.0.norm" => "control_model.input_blocks.1.1.norm", - "down_blocks.0.transformers.1.norm" => "control_model.input_blocks.2.1.norm", - "down_blocks.1.transformers.0.norm" => "control_model.input_blocks.4.1.norm", - "down_blocks.1.transformers.1.norm" => "control_model.input_blocks.5.1.norm", - "down_blocks.2.transformers.0.norm" => "control_model.input_blocks.7.1.norm", - "down_blocks.2.transformers.1.norm" => "control_model.input_blocks.8.1.norm", - - # self attention 0 0 - "down_blocks.0.transformers.0.blocks.0.self_attention_norm" => - "control_model.input_blocks.1.1.transformer_blocks.0.norm1", - "down_blocks.0.transformers.0.blocks.0.self_attention.key" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k", - "down_blocks.0.transformers.0.blocks.0.self_attention.value" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v", - "down_blocks.0.transformers.0.blocks.0.self_attention.query" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q", - "down_blocks.0.transformers.0.blocks.0.self_attention.output" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0", - - # self attention 0 1 - "down_blocks.0.transformers.1.blocks.0.self_attention_norm" => - "control_model.input_blocks.2.1.transformer_blocks.0.norm1", - "down_blocks.0.transformers.1.blocks.0.self_attention.key" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k", - "down_blocks.0.transformers.1.blocks.0.self_attention.value" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v", - "down_blocks.0.transformers.1.blocks.0.self_attention.query" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q", - "down_blocks.0.transformers.1.blocks.0.self_attention.output" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0", - - # self attention 1 0 - "down_blocks.1.transformers.0.blocks.0.self_attention_norm" => - "control_model.input_blocks.4.1.transformer_blocks.0.norm1", - "down_blocks.1.transformers.0.blocks.0.self_attention.key" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k", - "down_blocks.1.transformers.0.blocks.0.self_attention.value" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v", - "down_blocks.1.transformers.0.blocks.0.self_attention.query" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q", - "down_blocks.1.transformers.0.blocks.0.self_attention.output" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0", - - # self attention 1 1 - "down_blocks.1.transformers.1.blocks.0.self_attention_norm" => - "control_model.input_blocks.5.1.transformer_blocks.0.norm1", - "down_blocks.1.transformers.1.blocks.0.self_attention.key" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k", - "down_blocks.1.transformers.1.blocks.0.self_attention.value" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v", - "down_blocks.1.transformers.1.blocks.0.self_attention.query" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q", - "down_blocks.1.transformers.1.blocks.0.self_attention.output" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0", - - # self attention 2 0 - "down_blocks.2.transformers.0.blocks.0.self_attention_norm" => - "control_model.input_blocks.7.1.transformer_blocks.0.norm1", - "down_blocks.2.transformers.0.blocks.0.self_attention.key" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k", - "down_blocks.2.transformers.0.blocks.0.self_attention.value" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v", - "down_blocks.2.transformers.0.blocks.0.self_attention.query" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q", - "down_blocks.2.transformers.0.blocks.0.self_attention.output" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0", - - # self attention 2 1 - "down_blocks.2.transformers.1.blocks.0.self_attention_norm" => - "control_model.input_blocks.8.1.transformer_blocks.0.norm1", - "down_blocks.2.transformers.1.blocks.0.self_attention.key" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k", - "down_blocks.2.transformers.1.blocks.0.self_attention.value" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v", - "down_blocks.2.transformers.1.blocks.0.self_attention.query" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q", - "down_blocks.2.transformers.1.blocks.0.self_attention.output" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0", - - # cross attention 0 0 - "down_blocks.0.transformers.0.blocks.0.cross_attention_norm" => - "control_model.input_blocks.1.1.transformer_blocks.0.norm2", - "down_blocks.0.transformers.0.blocks.0.cross_attention.key" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k", - "down_blocks.0.transformers.0.blocks.0.cross_attention.value" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v", - "down_blocks.0.transformers.0.blocks.0.cross_attention.query" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q", - "down_blocks.0.transformers.0.blocks.0.cross_attention.output" => - "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0", - - # cross attention 0 1 - "down_blocks.0.transformers.1.blocks.0.cross_attention_norm" => - "control_model.input_blocks.2.1.transformer_blocks.0.norm2", - "down_blocks.0.transformers.1.blocks.0.cross_attention.key" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k", - "down_blocks.0.transformers.1.blocks.0.cross_attention.value" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v", - "down_blocks.0.transformers.1.blocks.0.cross_attention.query" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q", - "down_blocks.0.transformers.1.blocks.0.cross_attention.output" => - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0", - - # cross attention 1 0 - "down_blocks.1.transformers.0.blocks.0.cross_attention_norm" => - "control_model.input_blocks.4.1.transformer_blocks.0.norm2", - "down_blocks.1.transformers.0.blocks.0.cross_attention.key" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k", - "down_blocks.1.transformers.0.blocks.0.cross_attention.value" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v", - "down_blocks.1.transformers.0.blocks.0.cross_attention.query" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q", - "down_blocks.1.transformers.0.blocks.0.cross_attention.output" => - "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0", - - # cross attention 1 1 - "down_blocks.1.transformers.1.blocks.0.cross_attention_norm" => - "control_model.input_blocks.5.1.transformer_blocks.0.norm2", - "down_blocks.1.transformers.1.blocks.0.cross_attention.key" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k", - "down_blocks.1.transformers.1.blocks.0.cross_attention.value" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v", - "down_blocks.1.transformers.1.blocks.0.cross_attention.query" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q", - "down_blocks.1.transformers.1.blocks.0.cross_attention.output" => - "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0", - - # cross attention 2 0 - "down_blocks.2.transformers.0.blocks.0.cross_attention_norm" => - "control_model.input_blocks.7.1.transformer_blocks.0.norm2", - "down_blocks.2.transformers.0.blocks.0.cross_attention.key" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k", - "down_blocks.2.transformers.0.blocks.0.cross_attention.value" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v", - "down_blocks.2.transformers.0.blocks.0.cross_attention.query" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q", - "down_blocks.2.transformers.0.blocks.0.cross_attention.output" => - "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0", - - # cross attention 2 1 - "down_blocks.2.transformers.1.blocks.0.cross_attention_norm" => - "control_model.input_blocks.8.1.transformer_blocks.0.norm2", - "down_blocks.2.transformers.1.blocks.0.cross_attention.key" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k", - "down_blocks.2.transformers.1.blocks.0.cross_attention.value" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v", - "down_blocks.2.transformers.1.blocks.0.cross_attention.query" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q", - "down_blocks.2.transformers.1.blocks.0.cross_attention.output" => - "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0", - - # ffn 0 0 - "down_blocks.0.transformers.0.blocks.0.ffn.intermediate" => - "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.0.transformers.0.blocks.0.ffn.output" => - "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2", - - # ffn 0 1 - "down_blocks.0.transformers.1.blocks.0.ffn.intermediate" => - "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.0.transformers.1.blocks.0.ffn.output" => - "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2", - - # ffn 1 0 - "down_blocks.1.transformers.0.blocks.0.ffn.intermediate" => - "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.1.transformers.0.blocks.0.ffn.output" => - "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2", - - # ffn 1 1 - "down_blocks.1.transformers.1.blocks.0.ffn.intermediate" => - "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.1.transformers.1.blocks.0.ffn.output" => - "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2", - - # ffn 2 0 - "down_blocks.2.transformers.0.blocks.0.ffn.intermediate" => - "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.2.transformers.0.blocks.0.ffn.output" => - "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2", - - # ffn 2 1 - "down_blocks.2.transformers.1.blocks.0.ffn.intermediate" => - "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj", - "down_blocks.2.transformers.1.blocks.0.ffn.output" => - "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2", - - # residuals 0 0 - "down_blocks.0.residual_blocks.0.norm_1" => "control_model.input_blocks.1.0.in_layers.0", - "down_blocks.0.residual_blocks.0.conv_1" => "control_model.input_blocks.1.0.in_layers.2", - "down_blocks.0.residual_blocks.0.timestep_projection" => - "control_model.input_blocks.1.0.emb_layers.1", - "down_blocks.0.residual_blocks.0.norm_2" => "control_model.input_blocks.1.0.out_layers.0", - "down_blocks.0.residual_blocks.0.conv_2" => "control_model.input_blocks.1.0.out_layers.3", - - # residuals 0 1 - "down_blocks.0.residual_blocks.1.norm_1" => "control_model.input_blocks.2.0.in_layers.0", - "down_blocks.0.residual_blocks.1.conv_1" => "control_model.input_blocks.2.0.in_layers.2", - "down_blocks.0.residual_blocks.1.timestep_projection" => - "control_model.input_blocks.2.0.emb_layers.1", - "down_blocks.0.residual_blocks.1.norm_2" => "control_model.input_blocks.2.0.out_layers.0", - "down_blocks.0.residual_blocks.1.conv_2" => "control_model.input_blocks.2.0.out_layers.3", - - # residuals 1 0 - "down_blocks.1.residual_blocks.0.norm_1" => "control_model.input_blocks.4.0.in_layers.0", - "down_blocks.1.residual_blocks.0.conv_1" => "control_model.input_blocks.4.0.in_layers.2", - "down_blocks.1.residual_blocks.0.timestep_projection" => - "control_model.input_blocks.4.0.emb_layers.1", - "down_blocks.1.residual_blocks.0.norm_2" => "control_model.input_blocks.4.0.out_layers.0", - "down_blocks.1.residual_blocks.0.conv_2" => "control_model.input_blocks.4.0.out_layers.3", - - # residuals 1 1 - "down_blocks.1.residual_blocks.1.norm_1" => "control_model.input_blocks.5.0.in_layers.0", - "down_blocks.1.residual_blocks.1.conv_1" => "control_model.input_blocks.5.0.in_layers.2", - "down_blocks.1.residual_blocks.1.timestep_projection" => - "control_model.input_blocks.5.0.emb_layers.1", - "down_blocks.1.residual_blocks.1.norm_2" => "control_model.input_blocks.5.0.out_layers.0", - "down_blocks.1.residual_blocks.1.conv_2" => "control_model.input_blocks.5.0.out_layers.3", - - # residuals 2 0 - "down_blocks.2.residual_blocks.0.norm_1" => "control_model.input_blocks.7.0.in_layers.0", - "down_blocks.2.residual_blocks.0.conv_1" => "control_model.input_blocks.7.0.in_layers.2", - "down_blocks.2.residual_blocks.0.timestep_projection" => - "control_model.input_blocks.7.0.emb_layers.1", - "down_blocks.2.residual_blocks.0.norm_2" => "control_model.input_blocks.7.0.out_layers.0", - "down_blocks.2.residual_blocks.0.conv_2" => "control_model.input_blocks.7.0.out_layers.3", - - # residuals 2 1 - "down_blocks.2.residual_blocks.1.norm_1" => "control_model.input_blocks.8.0.in_layers.0", - "down_blocks.2.residual_blocks.1.conv_1" => "control_model.input_blocks.8.0.in_layers.2", - "down_blocks.2.residual_blocks.1.timestep_projection" => - "control_model.input_blocks.8.0.emb_layers.1", - "down_blocks.2.residual_blocks.1.norm_2" => "control_model.input_blocks.8.0.out_layers.0", - "down_blocks.2.residual_blocks.1.conv_2" => "control_model.input_blocks.8.0.out_layers.3", - - # residuals 3 0 - "down_blocks.3.residual_blocks.0.norm_1" => "control_model.input_blocks.10.0.in_layers.0", - "down_blocks.3.residual_blocks.0.conv_1" => "control_model.input_blocks.10.0.in_layers.2", - "down_blocks.3.residual_blocks.0.timestep_projection" => - "control_model.input_blocks.10.0.emb_layers.1", - "down_blocks.3.residual_blocks.0.norm_2" => - "control_model.input_blocks.10.0.out_layers.0", - "down_blocks.3.residual_blocks.0.conv_2" => - "control_model.input_blocks.10.0.out_layers.3", - - # residuals 3 1 - "down_blocks.3.residual_blocks.1.norm_1" => "control_model.input_blocks.11.0.in_layers.0", - "down_blocks.3.residual_blocks.1.conv_1" => "control_model.input_blocks.11.0.in_layers.2", - "down_blocks.3.residual_blocks.1.timestep_projection" => - "control_model.input_blocks.11.0.emb_layers.1", - "down_blocks.3.residual_blocks.1.norm_2" => - "control_model.input_blocks.11.0.out_layers.0", - "down_blocks.3.residual_blocks.1.conv_2" => - "control_model.input_blocks.11.0.out_layers.3", - - # projection 0 0 - "down_blocks.0.transformers.0.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.1.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.1.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.0.transformers.0.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.1.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.1.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # projection 0 1 - "down_blocks.0.transformers.1.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.2.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.2.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.0.transformers.1.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.2.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.2.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # projection 1 0 - "down_blocks.1.transformers.0.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.4.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.4.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.1.transformers.0.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.4.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.4.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # projection 1 1 - "down_blocks.1.transformers.1.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.5.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.5.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.1.transformers.1.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.5.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.5.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # projection 2 0 - "down_blocks.2.transformers.0.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.7.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.7.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.2.transformers.0.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.7.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.7.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # projection 2 1 - "down_blocks.2.transformers.1.input_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.8.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.8.1.proj_in", "bias"}], fn [value] -> value end} - }, - "down_blocks.2.transformers.1.output_projection" => %{ - "kernel" => - {[{"control_model.input_blocks.8.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => - {[{"control_model.input_blocks.8.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # shortcut - "down_blocks.1.residual_blocks.0.shortcut.projection" => - "control_model.input_blocks.4.0.skip_connection", - "down_blocks.2.residual_blocks.0.shortcut.projection" => - "control_model.input_blocks.7.0.skip_connection", - - # downsamples - "down_blocks.0.downsamples.0.conv" => "control_model.input_blocks.3.0.op", - "down_blocks.1.downsamples.0.conv" => "control_model.input_blocks.6.0.op", - "down_blocks.2.downsamples.0.conv" => "control_model.input_blocks.9.0.op", - - # out 0 0 - "down_blocks.0.transformers.0.blocks.0.output_norm" => - "control_model.input_blocks.1.1.transformer_blocks.0.norm3", - - # out 0 1 - "down_blocks.0.transformers.1.blocks.0.output_norm" => - "control_model.input_blocks.2.1.transformer_blocks.0.norm3", - - # out 1 0 - "down_blocks.1.transformers.0.blocks.0.output_norm" => - "control_model.input_blocks.4.1.transformer_blocks.0.norm3", - - # out 1 1 - "down_blocks.1.transformers.1.blocks.0.output_norm" => - "control_model.input_blocks.5.1.transformer_blocks.0.norm3", - - # out 2 0 - "down_blocks.2.transformers.0.blocks.0.output_norm" => - "control_model.input_blocks.7.1.transformer_blocks.0.norm3", - - # out 2 1 - "down_blocks.2.transformers.1.blocks.0.output_norm" => - "control_model.input_blocks.8.1.transformer_blocks.0.norm3", - - # mid_block_mapping = %{ - # mid_block - "mid_block.transformers.0.norm" => "control_model.middle_block.1.norm", - # self attention - "mid_block.transformers.0.blocks.0.self_attention_norm" => - "control_model.middle_block.1.transformer_blocks.0.norm1", - "mid_block.transformers.0.blocks.0.self_attention.key" => - "control_model.middle_block.1.transformer_blocks.0.attn1.to_k", - "mid_block.transformers.0.blocks.0.self_attention.value" => - "control_model.middle_block.1.transformer_blocks.0.attn1.to_v", - "mid_block.transformers.0.blocks.0.self_attention.query" => - "control_model.middle_block.1.transformer_blocks.0.attn1.to_q", - "mid_block.transformers.0.blocks.0.self_attention.output" => - "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0", - - # cross attention - "mid_block.transformers.0.blocks.0.cross_attention_norm" => - "control_model.middle_block.1.transformer_blocks.0.norm2", - "mid_block.transformers.0.blocks.0.cross_attention.key" => - "control_model.middle_block.1.transformer_blocks.0.attn2.to_k", - "mid_block.transformers.0.blocks.0.cross_attention.value" => - "control_model.middle_block.1.transformer_blocks.0.attn2.to_v", - "mid_block.transformers.0.blocks.0.cross_attention.query" => - "control_model.middle_block.1.transformer_blocks.0.attn2.to_q", - "mid_block.transformers.0.blocks.0.cross_attention.output" => - "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0", - - # ffn - "mid_block.transformers.0.blocks.0.ffn.intermediate" => - "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj", - "mid_block.transformers.0.blocks.0.ffn.output" => - "control_model.middle_block.1.transformer_blocks.0.ff.net.2", - - # residuals 0 - "mid_block.residual_blocks.0.norm_1" => "control_model.middle_block.0.in_layers.0", - "mid_block.residual_blocks.0.conv_1" => "control_model.middle_block.0.in_layers.2", - "mid_block.residual_blocks.0.timestep_projection" => - "control_model.middle_block.0.emb_layers.1", - "mid_block.residual_blocks.0.norm_2" => "control_model.middle_block.0.out_layers.0", - "mid_block.residual_blocks.0.conv_2" => "control_model.middle_block.0.out_layers.3", - # residuals 1 - "mid_block.residual_blocks.1.norm_1" => "control_model.middle_block.2.in_layers.0", - "mid_block.residual_blocks.1.conv_1" => "control_model.middle_block.2.in_layers.2", - "mid_block.residual_blocks.1.timestep_projection" => - "control_model.middle_block.2.emb_layers.1", - "mid_block.residual_blocks.1.norm_2" => "control_model.middle_block.2.out_layers.0", - "mid_block.residual_blocks.1.conv_2" => "control_model.middle_block.2.out_layers.3", - - # projection - "mid_block.transformers.0.input_projection" => %{ - "kernel" => - {[{"control_model.middle_block.1.proj_in", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => {[{"control_model.middle_block.1.proj_in", "bias"}], fn [value] -> value end} - }, - "mid_block.transformers.0.output_projection" => %{ - "kernel" => - {[{"control_model.middle_block.1.proj_out", "weight"}], - fn [value] -> value |> Nx.new_axis(0) |> Nx.new_axis(0) end}, - "bias" => {[{"control_model.middle_block.1.proj_out", "bias"}], fn [value] -> value end} - }, - - # out - "mid_block.transformers.0.blocks.0.output_norm" => - "control_model.middle_block.1.transformer_blocks.0.norm3", - - # others - "time_embedding.intermediate" => "control_model.time_embed.0", - "time_embedding.output" => "control_model.time_embed.2" + "time_embedding.intermediate" => "time_embedding.linear_1", + "time_embedding.output" => "time_embedding.linear_2", + "input_conv" => "conv_in" } + |> Map.merge(blocks_mapping) + |> Map.merge(controlnet) end end end diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex index 09ba9b31..6c32077f 100644 --- a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex +++ b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex @@ -5,7 +5,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do @moduletag model_test_tags() - @tag timeout: :infinity test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}, @@ -19,33 +18,26 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do inputs = %{ "sample" => Nx.broadcast(0.5, {1, 64, 64, 4}), - "controlnet_conditioning" => Nx.broadcast(0.8, {1, 512, 512, 4}), - "timestep" => Nx.tensor(1), - "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) + "controlnet_conditioning" => Nx.broadcast(0.8, {1, 512, 512, 3}), + "timestep" => Nx.tensor(0), + "encoder_hidden_state" => Nx.broadcast(0.8, {1, 1, 768}) } outputs = Axon.predict(model, params, inputs, debug: true) - assert Nx.shape(outputs.sample) == {1, 32, 32, 4} + assert Nx.shape(outputs.mid_block_residual) == {1, 8, 8, 1280} assert_all_close( - to_channels_first(outputs.sample)[[.., 1..3, 1..3, 1..3]], - Nx.tensor([ - [ - [ - [-1.0813, -0.5109, -0.1545], - [-0.8094, -1.2588, -0.8355], - [-0.9218, -1.2142, -0.6982] - ], - [ - [-0.2179, -0.2799, -1.0922], - [-0.9485, -0.8376, 0.0843], - [-0.9650, -0.7105, -0.3920] - ], - [[1.3359, 0.8373, -0.2392], [0.9448, -0.0478, 0.6881], [-0.0154, -0.5304, 0.2081]] - ] - ]), - atol: 1.0e-4 + outputs.mid_block_residual[[0, 0, 0, 1..3]], + Nx.tensor([-1.2827045917510986, -0.6995724439620972, -0.610561192035675]) + ) + + first_down_residual = elem(outputs.down_blocks_residuals, 0) + assert Nx.shape(first_down_residual) == {1, 64, 64, 320} + + assert_all_close( + first_down_residual[[0, 0, 0, 1..3]], + Nx.tensor([-0.029463158920407295, 0.04885300621390343, -0.12834328413009644]) ) end end From 74fdcfbef1b0ca35e109fea6e1e1c5e01722c464 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Sat, 17 Feb 2024 23:28:44 +0100 Subject: [PATCH 06/42] unet with controlnet --- lib/bumblebee.ex | 1 + .../diffusion/unet_2d_conditional.ex | 130 +++++++++++++++++- .../diffusion/unet_2d_conditional_test.exs | 56 ++++++++ 3 files changed, 185 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 9ba92127..b0f956d2 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -116,6 +116,7 @@ defmodule Bumblebee do "CLIPModel" => {Bumblebee.Multimodal.Clip, :base}, "CLIPTextModel" => {Bumblebee.Text.ClipText, :base}, "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base}, + "ControlNetModel" => {Bumblebee.Diffusion.StableDiffusion.ControlNet, :base}, "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification}, "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base}, "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification}, diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 03d39169..5854f53d 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -133,7 +133,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do alias Bumblebee.Diffusion @impl true - def architectures(), do: [:base] + def architectures(), do: [:base, :with_controlnet] @impl true def config(spec, opts) do @@ -146,11 +146,28 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do timestep_shape = {} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} + mid_dim = List.last(spec.hidden_sizes) + + mid_residual_shape = {1, 16, 16, mid_dim} + + out_channels = [32, 32, 32, 32, 64, 64] + out_spatials = [32, 32, 32, 16, 16, 16] + + down_zip = Enum.zip(out_channels, out_spatials) + + down_residuals = + for {{out_channel, out_spatial}, i} <- Enum.with_index(down_zip), into: %{} do + shape = {1, out_spatial, out_spatial, out_channel} + {"controlnet_down_residual_#{i}", Nx.template(shape, :f32)} + end + %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), - "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) + "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32), + "controlnet_mid_residual" => Nx.template(mid_residual_shape, :f32) } + |> Map.merge(down_residuals) end @impl true @@ -160,6 +177,13 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Layers.output(%{sample: sample}) end + @impl true + def model(%__MODULE__{architecture: :with_controlnet} = spec) do + inputs = inputs_with_controlnet(spec) + sample = core_with_controlnet(inputs, spec) + Layers.output(%{sample: sample}) + end + defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} @@ -170,6 +194,27 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ]) end + defp inputs_with_controlnet(spec) do + sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} + mid_dim = List.last(spec.hidden_sizes) + + num_down_residuals = length(spec.hidden_sizes) * (spec.depth + 1) + + down_residuals = + for i <- 0..(num_down_residuals - 1) do + Axon.input("controlnet_down_residual_#{i}") + end + + Bumblebee.Utils.Model.inputs_to_map( + [ + Axon.input("sample", shape: sample_shape), + Axon.input("timestep", shape: {}), + Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), + Axon.input("controlnet_mid_residual", shape: {1, 16, 16, mid_dim}) + ] ++ down_residuals + ) + end + defp core(inputs, spec) do sample = inputs["sample"] timestep = inputs["timestep"] @@ -228,6 +273,87 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ) end + defp core_with_controlnet(inputs, spec) do + sample = inputs["sample"] + timestep = inputs["timestep"] + encoder_hidden_state = inputs["encoder_hidden_state"] + controlnet_mid_residual = inputs["controlnet_mid_residual"] + + num_down_residuals = length(spec.hidden_sizes) * (spec.depth + 1) + + controlnet_down_residuals = + for i <- 0..(num_down_residuals - 1) do + inputs["controlnet_down_residual_#{i}"] + end + |> dbg() + + sample = + if spec.center_input_sample do + Axon.nx(sample, fn sample -> 2 * sample - 1.0 end, op_name: :center) + else + sample + end + + timestep = + Axon.layer( + fn sample, timestep, _opts -> + Nx.broadcast(timestep, {Nx.axis_size(sample, 0)}) + end, + [sample, timestep], + op_name: :broadcast + ) + + timestep_embedding = + timestep + |> Diffusion.Layers.timestep_sinusoidal_embedding(hd(spec.hidden_sizes), + flip_sin_to_cos: spec.embedding_flip_sin_to_cos, + frequency_correction_term: spec.embedding_frequency_correction_term + ) + |> Diffusion.Layers.UNet.timestep_embedding_mlp(hd(spec.hidden_sizes) * 4, + name: "time_embedding" + ) + + sample = + Axon.conv(sample, hd(spec.hidden_sizes), + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "input_conv" + ) + + {sample, down_block_residuals} = + down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") + + down_residual_zip = Enum.zip(Tuple.to_list(down_block_residuals), controlnet_down_residuals) + + down_block_residuals = + for {{{down_residual, out_channel}, controlnet_down_residual}, i} <- + Enum.with_index(down_residual_zip) do + {Axon.add(down_residual, controlnet_down_residual, name: "add_controlnet_down_#{i}"), + out_channel} + end + |> List.to_tuple() + + mid_block_residual = + sample + |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") + |> Axon.add(controlnet_mid_residual, name: "add_controlnet_mid") + + mid_block_residual + |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, + name: "up_blocks" + ) + |> Axon.group_norm(spec.group_norm_num_groups, + epsilon: spec.group_norm_epsilon, + name: "output_norm" + ) + |> Axon.activation(:silu) + |> Axon.conv(spec.out_channels, + kernel_size: 3, + padding: [{1, 1}, {1, 1}], + name: "output_conv" + ) + end + defp down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, opts) do name = opts[:name] diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index 00663a28..095551d9 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -43,4 +43,60 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do atol: 1.0e-4 ) end + + test ":with_controlnet" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"}, + architecture: :with_controlnet + ) + + assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :with_controlnet} = spec + + num_down_residuals = (length(spec.hidden_sizes) * (spec.depth + 1)) |> dbg() + + out_channels = [32, 32, 32, 32, 64, 64] + out_spatials = [32, 32, 32, 16, 16, 16] + + down_zip = Enum.zip(out_channels, out_spatials) + + down_residuals = + for {{out_channel, out_spatial}, i} <- Enum.with_index(down_zip), into: %{} do + shape = {1, out_spatial, out_spatial, out_channel} + {"controlnet_down_residual_#{i}", Nx.broadcast(0.5, shape)} + end + + inputs = + %{ + "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), + "timestep" => Nx.tensor(1), + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), + "controlnet_mid_residual" => Nx.broadcast(0.5, {1, 16, 16, 64}) + } + |> Map.merge(down_residuals) + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.sample) == {1, 32, 32, 4} + + assert_all_close( + to_channels_first(outputs.sample)[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [ + [-1.0813, -0.5109, -0.1545], + [-0.8094, -1.2588, -0.8355], + [-0.9218, -1.2142, -0.6982] + ], + [ + [-0.2179, -0.2799, -1.0922], + [-0.9485, -0.8376, 0.0843], + [-0.9650, -0.7105, -0.3920] + ], + [[1.3359, 0.8373, -0.2392], [0.9448, -0.0478, 0.6881], [-0.0154, -0.5304, 0.2081]] + ] + ]), + atol: 1.0e-4 + ) + end end From 8f55875a61f28b26fde203e92fbcec1ef203cf9d Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 19 Feb 2024 10:35:51 +0100 Subject: [PATCH 07/42] Define inputs for unet with controlnet --- .../diffusion/unet_2d_conditional.ex | 57 ++++++++++++++----- .../diffusion/unet_2d_conditional_test.exs | 39 +++++++++---- 2 files changed, 73 insertions(+), 23 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 5854f53d..46d10375 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -146,21 +146,34 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do timestep_shape = {} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} - mid_dim = List.last(spec.hidden_sizes) + first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} + + state = {spec.sample_size, [first]} - mid_residual_shape = {1, 16, 16, mid_dim} + {mid_spatial, out_shapes} = + for block_out_channel <- spec.hidden_sizes, reduce: state do + {spatial_size, acc} -> + residuals = + for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channel} - out_channels = [32, 32, 32, 32, 64, 64] - out_spatials = [32, 32, 32, 16, 16, 16] + downsampled_spatial = div(spatial_size, 2) + downsample = {1, downsampled_spatial, downsampled_spatial, block_out_channel} - down_zip = Enum.zip(out_channels, out_spatials) + {div(spatial_size, 2), acc ++ residuals ++ [downsample]} + end + + mid_spatial = 2 * mid_spatial + out_shapes = Enum.drop(out_shapes, -1) down_residuals = - for {{out_channel, out_spatial}, i} <- Enum.with_index(down_zip), into: %{} do - shape = {1, out_spatial, out_spatial, out_channel} + for {shape, i} <- Enum.with_index(out_shapes), into: %{} do {"controlnet_down_residual_#{i}", Nx.template(shape, :f32)} end + mid_dim = List.last(spec.hidden_sizes) + + mid_residual_shape = {1, mid_spatial, mid_spatial, mid_dim} + %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), @@ -196,21 +209,40 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do defp inputs_with_controlnet(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - mid_dim = List.last(spec.hidden_sizes) + first = {nil, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} - num_down_residuals = length(spec.hidden_sizes) * (spec.depth + 1) + state = {spec.sample_size, [first]} + + {mid_spatial, out_shapes} = + for block_out_channel <- spec.hidden_sizes, reduce: state do + {spatial_size, acc} -> + residuals = + for _ <- 1..spec.depth, do: {nil, spatial_size, spatial_size, block_out_channel} + + downsampled_spatial = div(spatial_size, 2) + downsample = {nil, downsampled_spatial, downsampled_spatial, block_out_channel} + + {div(spatial_size, 2), acc ++ residuals ++ [downsample]} + end + + mid_spatial = 2 * mid_spatial + out_shapes = Enum.drop(out_shapes, -1) down_residuals = - for i <- 0..(num_down_residuals - 1) do - Axon.input("controlnet_down_residual_#{i}") + for {shape, i} <- Enum.with_index(out_shapes) do + Axon.input("controlnet_down_residual_#{i}", shape: shape) end + mid_dim = List.last(spec.hidden_sizes) + + mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_dim} + Bumblebee.Utils.Model.inputs_to_map( [ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("controlnet_mid_residual", shape: {1, 16, 16, mid_dim}) + Axon.input("controlnet_mid_residual", shape: mid_residual_shape) ] ++ down_residuals ) end @@ -285,7 +317,6 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do for i <- 0..(num_down_residuals - 1) do inputs["controlnet_down_residual_#{i}"] end - |> dbg() sample = if spec.center_input_sample do diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index 095551d9..f2fa9de6 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -44,34 +44,53 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do ) end + @tag timeout: :infinity test ":with_controlnet" do + compvis = "CompVis/stable-diffusion-v1-4" + tiny = "bumblebee-testing/tiny-stable-diffusion" + assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( - {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"}, + {:hf, tiny, subdir: "unet"}, architecture: :with_controlnet ) assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :with_controlnet} = spec - num_down_residuals = (length(spec.hidden_sizes) * (spec.depth + 1)) |> dbg() + first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} + + state = {spec.sample_size, [first]} - out_channels = [32, 32, 32, 32, 64, 64] - out_spatials = [32, 32, 32, 16, 16, 16] + {mid_spatial, out_shapes} = + for block_out_channel <- spec.hidden_sizes, reduce: state do + {spatial_size, acc} -> + residuals = + for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channel} - down_zip = Enum.zip(out_channels, out_spatials) + downsampled_spatial = div(spatial_size, 2) + downsample = {1, downsampled_spatial, downsampled_spatial, block_out_channel} + + {div(spatial_size, 2), acc ++ residuals ++ [downsample]} + end + + mid_spatial = 2 * mid_spatial + out_shapes = Enum.drop(out_shapes, -1) |> dbg() down_residuals = - for {{out_channel, out_spatial}, i} <- Enum.with_index(down_zip), into: %{} do - shape = {1, out_spatial, out_spatial, out_channel} + for {shape, i} <- Enum.with_index(out_shapes), into: %{} do {"controlnet_down_residual_#{i}", Nx.broadcast(0.5, shape)} end + mid_dim = List.last(spec.hidden_sizes) + + mid_residual_shape = {1, mid_spatial, mid_spatial, mid_dim} + inputs = %{ - "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), + "sample" => Nx.broadcast(0.5, {1, spec.sample_size, spec.sample_size, 4}), "timestep" => Nx.tensor(1), - "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), - "controlnet_mid_residual" => Nx.broadcast(0.5, {1, 16, 16, 64}) + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, spec.cross_attention_size}), + "controlnet_mid_residual" => Nx.broadcast(0.5, mid_residual_shape) } |> Map.merge(down_residuals) From 854ecd86dcd5def0cea397d9a53a0023adfb3191 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 19 Feb 2024 11:28:08 +0100 Subject: [PATCH 08/42] Unet with controlnet is same as in transformers --- .../diffusion/unet_2d_conditional.ex | 33 ++++++++----------- .../diffusion/unet_2d_conditional_test.exs | 28 ++++++++-------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 46d10375..d6c5daf5 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -146,24 +146,12 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do timestep_shape = {} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} - first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} + {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec) - state = {spec.sample_size, [first]} - - {mid_spatial, out_shapes} = - for block_out_channel <- spec.hidden_sizes, reduce: state do - {spatial_size, acc} -> - residuals = - for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channel} - - downsampled_spatial = div(spatial_size, 2) - downsample = {1, downsampled_spatial, downsampled_spatial, block_out_channel} - - {div(spatial_size, 2), acc ++ residuals ++ [downsample]} - end - - mid_spatial = 2 * mid_spatial - out_shapes = Enum.drop(out_shapes, -1) + out_shapes = + Enum.map(out_shapes, fn {_, spatial, spatial, channels} -> + {1, spatial, spatial, channels} + end) down_residuals = for {shape, i} <- Enum.with_index(out_shapes), into: %{} do @@ -207,8 +195,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ]) end - defp inputs_with_controlnet(spec) do - sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} + defp mid_spatial_and_residual_shapes(spec) do first = {nil, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} state = {spec.sample_size, [first]} @@ -228,6 +215,14 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do mid_spatial = 2 * mid_spatial out_shapes = Enum.drop(out_shapes, -1) + {mid_spatial, out_shapes} + end + + defp inputs_with_controlnet(spec) do + sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} + + {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec) + down_residuals = for {shape, i} <- Enum.with_index(out_shapes) do Axon.input("controlnet_down_residual_#{i}", shape: shape) diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index f2fa9de6..0a7c7637 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -74,7 +74,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do end mid_spatial = 2 * mid_spatial - out_shapes = Enum.drop(out_shapes, -1) |> dbg() + out_shapes = Enum.drop(out_shapes, -1) down_residuals = for {shape, i} <- Enum.with_index(out_shapes), into: %{} do @@ -96,23 +96,25 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.sample) == {1, 32, 32, 4} + assert Nx.shape(outputs.sample) == {1, spec.sample_size, spec.sample_size, spec.in_channels} assert_all_close( to_channels_first(outputs.sample)[[.., 1..3, 1..3, 1..3]], Nx.tensor([ [ - [ - [-1.0813, -0.5109, -0.1545], - [-0.8094, -1.2588, -0.8355], - [-0.9218, -1.2142, -0.6982] - ], - [ - [-0.2179, -0.2799, -1.0922], - [-0.9485, -0.8376, 0.0843], - [-0.9650, -0.7105, -0.3920] - ], - [[1.3359, 0.8373, -0.2392], [0.9448, -0.0478, 0.6881], [-0.0154, -0.5304, 0.2081]] + [-2.1599538326263428, -0.6707635521888733, 0.16482116281986237], + [-1.13632333278656, -1.7593439817428589, -0.9655789136886597], + [-1.3559075593948364, -2.5425026416778564, -1.8208105564117432] + ], + [ + [-0.06256292015314102, -0.6823181509971619, -1.2743796110153198], + [-1.3499518632888794, -1.599103569984436, 0.8080697655677795], + [-1.177065134048462, -1.2682275772094727, 0.9214832186698914] + ], + [ + [1.7675844430923462, 1.0919926166534424, -0.03096655011177063], + [0.597271203994751, -0.1870473176240921, 1.1974149942398071], + [-0.5016229152679443, -0.6805112957954407, 0.5924324989318848] ] ]), atol: 1.0e-4 From 35b417fd111ac4aba60c8b9c0043a00e4bfa845b Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 19 Feb 2024 18:18:40 +0100 Subject: [PATCH 09/42] StableDiffusionControlNet serving --- .../diffusion/stable_diffusion/control_net.ex | 8 +- .../diffusion/stable_diffusion_controlnet.ex | 521 ++++++++++++++++++ .../stable_diffusion_controlnet_test.exs | 98 ++++ 3 files changed, 625 insertions(+), 2 deletions(-) create mode 100644 lib/bumblebee/diffusion/stable_diffusion_controlnet.ex create mode 100644 test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index c93f2e09..765640e8 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -144,7 +144,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do def input_template(spec) do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} - controlnet_conditioning_shape = {1, 512, 512, 3} + + cond_size = spec.sample_size * 2 ** 3 + controlnet_conditioning_shape = {1, cond_size, cond_size, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} %{ @@ -164,7 +166,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - controlnet_conditioning_shape = {nil, 512, 512, 3} + + cond_size = spec.sample_size * 2 ** 2 + controlnet_conditioning_shape = {nil, cond_size, cond_size, 3} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex new file mode 100644 index 00000000..4f497e17 --- /dev/null +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -0,0 +1,521 @@ +defmodule Bumblebee.Diffusion.StableDiffusionControlNet do + @moduledoc """ + High-level tasks based on Stable Diffusion. + """ + + import Nx.Defn + + alias Bumblebee.Utils + alias Bumblebee.Shared + + @type text_to_image_input :: + String.t() + | %{ + :prompt => String.t(), + :controlnet_conditioning => Nx.Tensor, + optional(:negative_prompt) => String.t(), + optional(:seed) => integer() + } + @type text_to_image_output :: %{results: list(text_to_image_result())} + @type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()} + + @doc ~S""" + Build serving for prompt-driven image generation. + + The serving accepts `t:text_to_image_input/0` and returns `t:text_to_image_output/0`. + A list of inputs is also supported. + + You can specify `:safety_checker` model to automatically detect + when a generated image is offensive or harmful and filter it out. + + ## Options + + * `:safety_checker` - the safety checker model info map. When a + safety checker is used, each output entry has an additional + `:is_safe` property and unsafe images are automatically zeroed. + Make sure to also set `:safety_checker_featurizer` + + * `:safety_checker_featurizer` - the featurizer to use to preprocess + the safety checker input images + + * `:num_steps` - the number of denoising steps. More denoising + steps usually lead to higher image quality at the expense of + slower inference. Defaults to `50` + + * `:num_images_per_prompt` - the number of images to generate for + each prompt. Defaults to `1` + + * `:guidance_scale` - the scale used for classifier-free diffusion + guidance. Higher guidance scale makes the generated images more + closely reflect the text prompt. This parameter corresponds to + $\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf). + Defaults to `7.5` + + * `:compile` - compiles all computations for predefined input shapes + during serving initialization. Should be a keyword list with the + following keys: + + * `:batch_size` - the maximum batch size of the input. Inputs + are optionally padded to always match this batch size + + * `:sequence_length` - the maximum input sequence length. Input + sequences are always padded/truncated to match that length + + It is advised to set this option in production and also configure + a defn compiler using `:defn_options` to maximally reduce inference + time. + + * `:defn_options` - the options for JIT compilation. Defaults to `[]` + + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. When using this option, you should first + load the parameters into the host. This can be done by passing + `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends. + Defaults to `false` + + ## Examples + + repository_id = "CompVis/stable-diffusion-v1-4" + + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) + {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) + {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) + {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) + {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) + {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) + {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) + + serving = + Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, + num_steps: 20, + num_images_per_prompt: 2, + safety_checker: safety_checker, + safety_checker_featurizer: featurizer, + compile: [batch_size: 1, sequence_length: 60], + defn_options: [compiler: EXLA] + ) + + prompt = "numbat in forest, detailed, digital art" + Nx.Serving.run(serving, prompt) + #=> %{ + #=> results: [ + #=> %{ + #=> image: #Nx.Tensor< + #=> u8[512][512][3] + #=> ... + #=> >, + #=> is_safe: true + #=> }, + #=> %{ + #=> image: #Nx.Tensor< + #=> u8[512][512][3] + #=> ... + #=> >, + #=> is_safe: true + #=> } + #=> ] + #=> } + + """ + @spec text_to_image( + Bumblebee.model_info(), + Bumblebee.model_info(), + Bumblebee.model_info(), + Bumblebee.model_info(), + Bumblebee.Tokenizer.t(), + Bumblebee.Scheduler.t(), + keyword() + ) :: Nx.Serving.t() + def text_to_image(encoder, unet, vae, controlnet, tokenizer, scheduler, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :safety_checker, + :safety_checker_featurizer, + :compile, + num_steps: 50, + num_images_per_prompt: 1, + guidance_scale: 7.5, + defn_options: [], + preallocate_params: false + ]) + + safety_checker = opts[:safety_checker] + safety_checker_featurizer = opts[:safety_checker_featurizer] + num_steps = opts[:num_steps] + num_images_per_prompt = opts[:num_images_per_prompt] + preallocate_params = opts[:preallocate_params] + defn_options = opts[:defn_options] + + if safety_checker != nil and safety_checker_featurizer == nil do + raise ArgumentError, "got :safety_checker but no :safety_checker_featurizer was specified" + end + + safety_checker? = safety_checker != nil + + compile = + if compile = opts[:compile] do + compile + |> Keyword.validate!([:batch_size, :sequence_length]) + |> Shared.require_options!([:batch_size, :sequence_length]) + end + + batch_size = compile[:batch_size] + sequence_length = compile[:sequence_length] + + tokenizer = + Bumblebee.configure(tokenizer, + length: sequence_length, + return_token_type_ids: false, + return_attention_mask: false + ) + + {_, encoder_predict} = Axon.build(encoder.model) + {_, vae_predict} = Axon.build(vae.model) + {_, unet_predict} = Axon.build(unet.model) + {_, controlnet_predict} = Axon.build(controlnet.model) + + scheduler_init = &Bumblebee.scheduler_init(scheduler, num_steps, &1, &2) + scheduler_step = &Bumblebee.scheduler_step(scheduler, &1, &2, &3) + + image_fun = + &text_to_image_impl( + encoder_predict, + &1, + unet_predict, + &2, + vae_predict, + &3, + controlnet_predict, + &4, + scheduler_init, + scheduler_step, + &5, + num_images_per_prompt: opts[:num_images_per_prompt], + latents_sample_size: unet.spec.sample_size, + latents_channels: unet.spec.in_channels, + guidance_scale: opts[:guidance_scale] + ) + + safety_checker_fun = + if safety_checker do + {_, predict_fun} = Axon.build(safety_checker.model) + predict_fun + end + + # Note that all of these are copied when using serving as a process + init_args = [ + {image_fun, safety_checker_fun}, + encoder.params, + unet.params, + vae.params, + controlnet.params, + {safety_checker?, safety_checker[:spec], safety_checker[:params]}, + safety_checker_featurizer, + {compile != nil, batch_size, sequence_length}, + num_images_per_prompt, + preallocate_params + ] + + Nx.Serving.new( + fn defn_options -> apply(&init/11, init_args ++ [defn_options]) end, + defn_options + ) + |> Nx.Serving.batch_size(batch_size) + |> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer)) + |> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker)) + end + + defp init( + {image_fun, safety_checker_fun}, + encoder_params, + unet_params, + vae_params, + controlnet_params, + {safety_checker?, safety_checker_spec, safety_checker_params}, + safety_checker_featurizer, + {compile?, batch_size, sequence_length}, + num_images_per_prompt, + preallocate_params, + defn_options + ) do + encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options) + unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options) + vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options) + + controlnet_params = + Shared.maybe_preallocate(controlnet_params, preallocate_params, defn_options) + + image_fun = + Shared.compile_or_jit(image_fun, defn_options, compile?, fn -> + inputs = %{ + "conditional_and_unconditional" => %{ + "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32) + }, + "seed" => Nx.template({batch_size}, :s64), + "controlnet_conditioning" => Nx.template({batch_size, 512, 512, 3}, :f32) + } + + [encoder_params, unet_params, vae_params, controlnet_params, inputs] + end) + + safety_checker_fun = + safety_checker_fun && + Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn -> + inputs = %{ + "pixel_values" => + Shared.input_template(safety_checker_spec, "pixel_values", [ + batch_size * num_images_per_prompt + ]) + } + + [safety_checker_params, inputs] + end) + + safety_checker_params = + safety_checker_params && + Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + + image = image_fun.(encoder_params, unet_params, vae_params, controlnet_params, inputs) + + output = + if safety_checker? do + inputs = Bumblebee.apply_featurizer(safety_checker_featurizer, image) + outputs = safety_checker_fun.(safety_checker_params, inputs) + %{image: image, is_unsafe: outputs.is_unsafe} + else + %{image: image} + end + + output + |> Utils.Nx.composite_unflatten_batch(Utils.Nx.batch_size(inputs)) + |> Shared.serving_post_computation() + end + end + + defp preprocess_image(image) do + image + end + + defp client_preprocessing(input, tokenizer) do + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1) + + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend) + + # Note: we need to tokenize all sequences together, so that + # they are padded to the same length (if not specified) + prompts = Enum.flat_map(inputs, &[&1.prompt, &1.negative_prompt]) + + prompt_pairs = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + inputs = Bumblebee.apply_tokenizer(tokenizer, prompts) + Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0)) + end) + + controlnet_conditioning = + Enum.map(inputs, & &1.controlnet_conditioning) + |> Nx.stack() + |> preprocess_image() + + inputs = %{ + "conditional_and_unconditional" => prompt_pairs, + "seed" => seed, + "controlnet_conditioning" => controlnet_conditioning + } + + {Nx.Batch.concatenate([inputs]), multi?} + end + + defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do + for outputs <- Utils.Nx.batch_to_list(outputs) do + results = + for outputs = %{image: image} <- Utils.Nx.batch_to_list(outputs) do + if safety_checker? do + if Nx.to_number(outputs.is_unsafe) == 1 do + %{image: zeroed(image), is_safe: false} + else + %{image: image, is_safe: true} + end + else + %{image: image} + end + end + + %{results: results} + end + |> Shared.normalize_output(multi?) + end + + defp zeroed(tensor) do + 0 + |> Nx.tensor(type: Nx.type(tensor), backend: Nx.BinaryBackend) + |> Nx.broadcast(Nx.shape(tensor)) + end + + defnp text_to_image_impl( + encoder_predict, + encoder_params, + unet_predict, + unet_params, + vae_predict, + vae_params, + controlnet_predict, + controlnet_params, + scheduler_init, + scheduler_step, + inputs, + opts \\ [] + ) do + num_images_per_prompt = opts[:num_images_per_prompt] + latents_sample_size = opts[:latents_sample_size] + latents_in_channels = opts[:latents_channels] + guidance_scale = opts[:guidance_scale] + + seed = inputs["seed"] + controlnet_conditioning = inputs["controlnet_conditioning"] + + inputs = + inputs["conditional_and_unconditional"] + # Transpose conditional and unconditional to separate blocks + |> composite_transpose_leading() + |> Utils.Nx.composite_flatten_batch() + + %{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs) + + {_twice_batch_size, sequence_length, hidden_size} = Nx.shape(text_embeddings) + + text_embeddings = + text_embeddings + |> Nx.new_axis(1) + |> Nx.tile([1, num_images_per_prompt, 1, 1]) + |> Nx.reshape({:auto, sequence_length, hidden_size}) + + prng_key = + seed + |> Nx.vectorize(:batch) + |> Nx.Random.key() + |> Nx.Random.split(parts: num_images_per_prompt) + |> Nx.devectorize() + |> Nx.flatten(axes: [0, 1]) + |> Nx.vectorize(:batch) + + {latents, prng_key} = + Nx.Random.normal(prng_key, + shape: {latents_sample_size, latents_sample_size, latents_in_channels} + ) + + {scheduler_state, timesteps} = scheduler_init.(Nx.to_template(latents), prng_key) + + latents = Nx.devectorize(latents) + + {latents, _} = + while {latents, + {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, + controlnet_params}}, + timestep <- timesteps do + controlnet_inputs = %{ + "controlnet_conditioning" => controlnet_conditioning, + "sample" => Nx.concatenate([latents, latents]), + "timestep" => timestep, + "encoder_hidden_state" => text_embeddings + } + + %{down_blocks_residuals: down_blocks_residuals, mid_block_residual: mid_block_residual} = + controlnet_predict.(controlnet_params, controlnet_inputs) + + # {down_residual_map, _} = + # while {down_residual_map, down_blocks_residuals}, i <- 0..10 do + # key = "down_block_residual_#{i}" + # value = elem(down_blocks_residuals, i) + + # {Map.put(down_residual_map, key, value), down_blocks_residuals} + # end + + unet_inputs = + %{ + "sample" => Nx.concatenate([latents, latents]), + "timestep" => timestep, + "encoder_hidden_state" => text_embeddings, + "controlnet_mid_residual" => mid_block_residual, + "controlnet_down_residual_0" => elem(down_blocks_residuals, 0), + "controlnet_down_residual_1" => elem(down_blocks_residuals, 1), + "controlnet_down_residual_2" => elem(down_blocks_residuals, 2), + "controlnet_down_residual_3" => elem(down_blocks_residuals, 3), + "controlnet_down_residual_4" => elem(down_blocks_residuals, 4), + "controlnet_down_residual_5" => elem(down_blocks_residuals, 5), + "controlnet_down_residual_6" => elem(down_blocks_residuals, 6), + "controlnet_down_residual_7" => elem(down_blocks_residuals, 7), + "controlnet_down_residual_8" => elem(down_blocks_residuals, 8), + "controlnet_down_residual_9" => elem(down_blocks_residuals, 9), + "controlnet_down_residual_10" => elem(down_blocks_residuals, 10), + "controlnet_down_residual_11" => elem(down_blocks_residuals, 11) + } + + %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs) + + {noise_pred_conditional, noise_pred_unconditional} = + split_conditional_and_unconditional(noise_pred) + + noise_pred = + noise_pred_unconditional + + guidance_scale * (noise_pred_conditional - noise_pred_unconditional) + + {scheduler_state, latents} = + scheduler_step.( + scheduler_state, + Nx.vectorize(latents, :batch), + Nx.vectorize(noise_pred, :batch) + ) + + latents = Nx.devectorize(latents) + + {latents, + {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, + controlnet_params}} + end + + latents = latents * (1 / 0.18215) + + %{sample: image} = vae_predict.(vae_params, latents) + + NxImage.from_continuous(image, -1, 1) + end + + deftransformp composite_transpose_leading(container) do + Utils.Nx.map(container, fn tensor -> + [first, second | rest] = Nx.axes(tensor) + Nx.transpose(tensor, axes: [second, first | rest]) + end) + end + + defnp split_conditional_and_unconditional(tensor) do + batch_size = Nx.axis_size(tensor, 0) + half_size = div(batch_size, 2) + {tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]} + end + + defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt}) + + defp validate_input(%{prompt: prompt, controlnet_conditioning: controlnet_conditioning} = input) do + {:ok, + %{ + prompt: prompt, + controlnet_conditioning: controlnet_conditioning, + negative_prompt: input[:negative_prompt] || "", + seed: input[:seed] || :erlang.system_time() + }} + end + + defp validate_input(%{} = input) do + {:error, + "expected the input map to have :prompt and :controlnet_conditioning key, got: #{inspect(input)}"} + end + + defp validate_input(input) do + {:error, "expected either a string or a map, got: #{inspect(input)}"} + end +end diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs new file mode 100644 index 00000000..3c7992ed --- /dev/null +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -0,0 +1,98 @@ +defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + # @moduletag serving_test_tags() + + @tag timeout: :infinity + describe "text_to_image/6" do + test "generates image for a text prompt with controlnet" do + # Since we don't assert on the result in this case, we use + # a tiny random checkpoint. This test is basically to verify + # the whole generation computation end-to-end + + repository_id = "CompVis/stable-diffusion-v1-4" + # repository_id = "bumblebee-testing/tiny-stable-diffusion" + + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) + {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) + + {:ok, unet} = + Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, architecture: :with_controlnet) + + {:ok, controlnet} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}) + # {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) + + {:ok, vae} = + Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) + + {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) + + {:ok, featurizer} = + Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) + + {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) + + serving = + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3, + safety_checker: safety_checker, + safety_checker_featurizer: featurizer, + compile: [batch_size: 1, sequence_length: 60] + ) + + prompt = "numbat in forest, detailed, digital art" + + cond_size = (unet.spec.sample_size * 2 ** 3) + + controlnet_conditioning = Nx.broadcast(0.5, {cond_size, cond_size, 3}) + + assert %{ + results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] + } = + Nx.Serving.run(serving, %{ + prompt: prompt, + controlnet_conditioning: controlnet_conditioning + }) + + # Without safety checker + + # serving = + # Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + # clip, + # unet, + # vae, + # tokenizer, + # scheduler, + # num_steps: 3 + # ) + + # prompt = "numbat in forest, detailed, digital art" + + # assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, prompt) + + # With compilation + + # serving = + # Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, + # num_steps: 3, + # safety_checker: safety_checker, + # safety_checker_featurizer: featurizer, + # defn_options: [compiler: EXLA] + # ) + + # prompt = "numbat in forest, detailed, digital art" + + # assert %{ + # results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] + # } = Nx.Serving.run(serving, prompt) + end + end +end From c109b1a131f6e23d6130092df887a991068c7d3c Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 21 Feb 2024 14:17:08 +0100 Subject: [PATCH 10/42] ControlNet input size --- lib/bumblebee/diffusion/stable_diffusion/control_net.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index 765640e8..e1b68892 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -167,7 +167,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - cond_size = spec.sample_size * 2 ** 2 + cond_size = spec.sample_size * 2 ** 3 controlnet_conditioning_shape = {nil, cond_size, cond_size, 3} Bumblebee.Utils.Model.inputs_to_map([ From 68957d166737bc5877342213c8e47fa3610ffd17 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 23 Feb 2024 13:54:53 +0100 Subject: [PATCH 11/42] Tuple as input for unet --- .../diffusion/unet_2d_conditional.ex | 33 ++++++++----------- .../diffusion/unet_2d_conditional_test.exs | 29 ++++++++-------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index d6c5daf5..e0b4f800 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -154,9 +154,10 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do end) down_residuals = - for {shape, i} <- Enum.with_index(out_shapes), into: %{} do - {"controlnet_down_residual_#{i}", Nx.template(shape, :f32)} + for {shape, i} <- Enum.with_index(out_shapes) do + Nx.template(shape, :f32) end + |> List.to_tuple() mid_dim = List.last(spec.hidden_sizes) @@ -166,9 +167,9 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32), - "controlnet_mid_residual" => Nx.template(mid_residual_shape, :f32) + "controlnet_mid_residual" => Nx.template(mid_residual_shape, :f32), + "controlnet_down_residuals" => down_residuals } - |> Map.merge(down_residuals) end @impl true @@ -221,25 +222,19 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do defp inputs_with_controlnet(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec) - - down_residuals = - for {shape, i} <- Enum.with_index(out_shapes) do - Axon.input("controlnet_down_residual_#{i}", shape: shape) - end + {mid_spatial, _} = mid_spatial_and_residual_shapes(spec) mid_dim = List.last(spec.hidden_sizes) mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_dim} - Bumblebee.Utils.Model.inputs_to_map( - [ - Axon.input("sample", shape: sample_shape), - Axon.input("timestep", shape: {}), - Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("controlnet_mid_residual", shape: mid_residual_shape) - ] ++ down_residuals - ) + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("sample", shape: sample_shape), + Axon.input("timestep", shape: {}), + Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), + Axon.input("controlnet_mid_residual", shape: mid_residual_shape), + Axon.input("controlnet_down_residuals") + ]) end defp core(inputs, spec) do @@ -310,7 +305,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do controlnet_down_residuals = for i <- 0..(num_down_residuals - 1) do - inputs["controlnet_down_residual_#{i}"] + Axon.nx(inputs["controlnet_down_residuals"], &elem(&1, i)) end sample = diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index 751bae46..b8b1e413 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -68,9 +68,10 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do out_shapes = Enum.drop(out_shapes, -1) down_residuals = - for {shape, i} <- Enum.with_index(out_shapes), into: %{} do - {"controlnet_down_residual_#{i}", Nx.broadcast(0.5, shape)} + for {shape, i} <- Enum.with_index(out_shapes) do + Nx.broadcast(0.5, shape) end + |> List.to_tuple() mid_dim = List.last(spec.hidden_sizes) @@ -81,31 +82,31 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do "sample" => Nx.broadcast(0.5, {1, spec.sample_size, spec.sample_size, 4}), "timestep" => Nx.tensor(1), "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, spec.cross_attention_size}), - "controlnet_mid_residual" => Nx.broadcast(0.5, mid_residual_shape) + "controlnet_mid_residual" => Nx.broadcast(0.5, mid_residual_shape), + "controlnet_down_residuals" => down_residuals } - |> Map.merge(down_residuals) outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.sample) == {1, spec.sample_size, spec.sample_size, spec.in_channels} assert_all_close( - to_channels_first(outputs.sample)[[.., 1..3, 1..3, 1..3]], + outputs.sample[[.., 1..3, 1..3, 1..3]], Nx.tensor([ [ - [-2.1599538326263428, -0.6707635521888733, 0.16482116281986237], - [-1.13632333278656, -1.7593439817428589, -0.9655789136886597], - [-1.3559075593948364, -2.5425026416778564, -1.8208105564117432] + [-2.1599538326263428, -0.06256292015314102, 1.7675844430923462], + [-0.6707635521888733, -0.6823181509971619, 1.0919926166534424], + [0.16482116281986237, -1.2743796110153198, -0.03096655011177063] ], [ - [-0.06256292015314102, -0.6823181509971619, -1.2743796110153198], - [-1.3499518632888794, -1.599103569984436, 0.8080697655677795], - [-1.177065134048462, -1.2682275772094727, 0.9214832186698914] + [-1.13632333278656, -1.3499518632888794, 0.597271203994751], + [-1.7593439817428589, -1.599103569984436, -0.1870473176240921], + [-0.9655789136886597, 0.8080697655677795, 1.1974149942398071] ], [ - [1.7675844430923462, 1.0919926166534424, -0.03096655011177063], - [0.597271203994751, -0.1870473176240921, 1.1974149942398071], - [-0.5016229152679443, -0.6805112957954407, 0.5924324989318848] + [-1.3559075593948364, -1.177065134048462, -0.5016229152679443], + [-2.5425026416778564, -1.2682275772094727, -0.6805112957954407], + [-1.8208105564117432, 0.9214832186698914, 0.5924324989318848] ] ]), atol: 1.0e-4 From 211c2097915c173e49b38df729d9238844305df7 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 23 Feb 2024 14:55:00 +0100 Subject: [PATCH 12/42] Tuple as input inside stable diffusion --- .../diffusion/stable_diffusion_controlnet.ex | 21 +------------------ .../stable_diffusion_controlnet_test.exs | 7 ++++--- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 4f497e17..dcea3edf 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -427,32 +427,13 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do %{down_blocks_residuals: down_blocks_residuals, mid_block_residual: mid_block_residual} = controlnet_predict.(controlnet_params, controlnet_inputs) - # {down_residual_map, _} = - # while {down_residual_map, down_blocks_residuals}, i <- 0..10 do - # key = "down_block_residual_#{i}" - # value = elem(down_blocks_residuals, i) - - # {Map.put(down_residual_map, key, value), down_blocks_residuals} - # end - unet_inputs = %{ "sample" => Nx.concatenate([latents, latents]), "timestep" => timestep, "encoder_hidden_state" => text_embeddings, "controlnet_mid_residual" => mid_block_residual, - "controlnet_down_residual_0" => elem(down_blocks_residuals, 0), - "controlnet_down_residual_1" => elem(down_blocks_residuals, 1), - "controlnet_down_residual_2" => elem(down_blocks_residuals, 2), - "controlnet_down_residual_3" => elem(down_blocks_residuals, 3), - "controlnet_down_residual_4" => elem(down_blocks_residuals, 4), - "controlnet_down_residual_5" => elem(down_blocks_residuals, 5), - "controlnet_down_residual_6" => elem(down_blocks_residuals, 6), - "controlnet_down_residual_7" => elem(down_blocks_residuals, 7), - "controlnet_down_residual_8" => elem(down_blocks_residuals, 8), - "controlnet_down_residual_9" => elem(down_blocks_residuals, 9), - "controlnet_down_residual_10" => elem(down_blocks_residuals, 10), - "controlnet_down_residual_11" => elem(down_blocks_residuals, 11) + "controlnet_down_residuals" => down_blocks_residuals } %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs) diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 3c7992ed..a61d4e48 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -12,7 +12,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do # a tiny random checkpoint. This test is basically to verify # the whole generation computation end-to-end - repository_id = "CompVis/stable-diffusion-v1-4" + repository_id = "runwayml/stable-diffusion-v1-5" # repository_id = "bumblebee-testing/tiny-stable-diffusion" {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) @@ -45,12 +45,13 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do num_steps: 3, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60] + compile: [batch_size: 1, sequence_length: 60], + defn_options: [compiler: EXLA] ) prompt = "numbat in forest, detailed, digital art" - cond_size = (unet.spec.sample_size * 2 ** 3) + cond_size = unet.spec.sample_size * 2 ** 3 controlnet_conditioning = Nx.broadcast(0.5, {cond_size, cond_size, 3}) From edfbd0657ee7f6e0972224fcf81b090ce80ce448 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 23 Feb 2024 15:06:05 +0100 Subject: [PATCH 13/42] Determine conditioning size from spec --- lib/bumblebee/diffusion/stable_diffusion/control_net.ex | 4 ++-- test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index e1b68892..3b5647dc 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -145,7 +145,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} - cond_size = spec.sample_size * 2 ** 3 + cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) controlnet_conditioning_shape = {1, cond_size, cond_size, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} @@ -167,7 +167,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - cond_size = spec.sample_size * 2 ** 3 + cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) controlnet_conditioning_shape = {nil, cond_size, cond_size, 3} Bumblebee.Utils.Model.inputs_to_map([ diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index a61d4e48..d02dc360 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -51,7 +51,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do prompt = "numbat in forest, detailed, digital art" - cond_size = unet.spec.sample_size * 2 ** 3 + cond_size = unet.spec.sample_size * 2 ** (length(unet.spec.hidden_sizes) - 1) controlnet_conditioning = Nx.broadcast(0.5, {cond_size, cond_size, 3}) From bfdd5a35eaaa01d0cd6644d53757b0642b790be9 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 14:42:32 +0100 Subject: [PATCH 14/42] Add docs --- lib/bumblebee/diffusion/stable_diffusion/control_net.ex | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index 3b5647dc..43763676 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -117,6 +117,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do The conditional state (context) to use with cross-attention. + * `"controlnet_conditioning"` - `{batch_size, conditioning_size, conditioning_size, 3}` + + The conditional input + ## Configuration #{Shared.options_doc(options)} @@ -229,13 +233,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do down_blocks_residuals = for residual <- Tuple.to_list(down_blocks_residuals) do - Axon.multiply(residual, conditioning_scale, name: "conditioning_scale") + Axon.multiply(residual, conditioning_scale, name: "down_conditioning_scale") end |> List.to_tuple() mid_block_residual = control_net_mid_block(sample, spec, name: "controlnet_mid_block") - |> Axon.multiply(conditioning_scale) + |> Axon.multiply(conditioning_scale, name: "mid_conditioning_scale") %{ down_blocks_residuals: down_blocks_residuals, From 95d0dffb44d07c19746d1aeca021107edc94f515 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 14:43:04 +0100 Subject: [PATCH 15/42] Rename with_controlnet to with_additional_residuals --- .../diffusion/unet_2d_conditional.ex | 113 ++++++++---------- .../diffusion/unet_2d_conditional_test.exs | 18 ++- 2 files changed, 60 insertions(+), 71 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index e0b4f800..232b669d 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -101,6 +101,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ## Architectures * `:base` - the U-Net model + * `:with_additional_residuals` - with additional residuals ## Inputs @@ -133,42 +134,57 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do alias Bumblebee.Diffusion @impl true - def architectures(), do: [:base, :with_controlnet] + def architectures(), do: [:base, :with_additional_residuals] @impl true def config(spec, opts) do Shared.put_config_attrs(spec, opts) end + defp down_residuals_templates(spec) do + first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} + + state = {spec.sample_size, [first]} + + {_, down_shapes} = + for block_out_channels <- spec.hidden_sizes, reduce: state do + {spatial_size, acc} -> + residuals = + for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channels} + + downsampled_spatial = div(spatial_size, 2) + downsample_shape = {1, downsampled_spatial, downsampled_spatial, block_out_channels} + + {div(spatial_size, 2), acc ++ residuals ++ [downsample_shape]} + end + + # no downsampling in last block + down_shapes = Enum.drop(down_shapes, -1) + + for shape <- down_shapes do + Nx.template(shape, :f32) + end + |> List.to_tuple() + end + @impl true def input_template(spec) do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} - {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec) - - out_shapes = - Enum.map(out_shapes, fn {_, spatial, spatial, channels} -> - {1, spatial, spatial, channels} - end) - - down_residuals = - for {shape, i} <- Enum.with_index(out_shapes) do - Nx.template(shape, :f32) - end - |> List.to_tuple() - - mid_dim = List.last(spec.hidden_sizes) + mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) + mid_channels = List.last(spec.hidden_sizes) + mid_residual_shape = {1, mid_spatial, mid_spatial, mid_channels} - mid_residual_shape = {1, mid_spatial, mid_spatial, mid_dim} + down_residuals = down_residuals_templates(spec) %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32), - "controlnet_mid_residual" => Nx.template(mid_residual_shape, :f32), - "controlnet_down_residuals" => down_residuals + "additional_mid_residual" => Nx.template(mid_residual_shape, :f32), + "additional_down_residuals" => down_residuals } end @@ -180,9 +196,9 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do end @impl true - def model(%__MODULE__{architecture: :with_controlnet} = spec) do - inputs = inputs_with_controlnet(spec) - sample = core_with_controlnet(inputs, spec) + def model(%__MODULE__{architecture: :with_additional_residuals} = spec) do + inputs = inputs_with_additional_residuals(spec) + sample = core_with_additional_residuals(inputs, spec) Layers.output(%{sample: sample}) end @@ -196,44 +212,19 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ]) end - defp mid_spatial_and_residual_shapes(spec) do - first = {nil, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} - - state = {spec.sample_size, [first]} - - {mid_spatial, out_shapes} = - for block_out_channel <- spec.hidden_sizes, reduce: state do - {spatial_size, acc} -> - residuals = - for _ <- 1..spec.depth, do: {nil, spatial_size, spatial_size, block_out_channel} - - downsampled_spatial = div(spatial_size, 2) - downsample = {nil, downsampled_spatial, downsampled_spatial, block_out_channel} - - {div(spatial_size, 2), acc ++ residuals ++ [downsample]} - end - - mid_spatial = 2 * mid_spatial - out_shapes = Enum.drop(out_shapes, -1) - - {mid_spatial, out_shapes} - end - - defp inputs_with_controlnet(spec) do + defp inputs_with_additional_residuals(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - {mid_spatial, _} = mid_spatial_and_residual_shapes(spec) - - mid_dim = List.last(spec.hidden_sizes) - - mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_dim} + mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) + mid_channels = List.last(spec.hidden_sizes) + mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_channels} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("controlnet_mid_residual", shape: mid_residual_shape), - Axon.input("controlnet_down_residuals") + Axon.input("additional_mid_residual", shape: mid_residual_shape), + Axon.input("additional_down_residuals") ]) end @@ -295,17 +286,17 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ) end - defp core_with_controlnet(inputs, spec) do + defp core_with_additional_residuals(inputs, spec) do sample = inputs["sample"] timestep = inputs["timestep"] encoder_hidden_state = inputs["encoder_hidden_state"] - controlnet_mid_residual = inputs["controlnet_mid_residual"] + additional_mid_residual = inputs["additional_mid_residual"] num_down_residuals = length(spec.hidden_sizes) * (spec.depth + 1) - controlnet_down_residuals = + additional_down_residuals = for i <- 0..(num_down_residuals - 1) do - Axon.nx(inputs["controlnet_down_residuals"], &elem(&1, i)) + Axon.nx(inputs["additional_down_residuals"], &elem(&1, i)) end sample = @@ -344,20 +335,20 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do {sample, down_block_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") - down_residual_zip = Enum.zip(Tuple.to_list(down_block_residuals), controlnet_down_residuals) + down_residual_zip = Enum.zip(Tuple.to_list(down_block_residuals), additional_down_residuals) down_block_residuals = - for {{{down_residual, out_channel}, controlnet_down_residual}, i} <- + for {{{down_residual, out_channels}, additional_down_residual}, i} <- Enum.with_index(down_residual_zip) do - {Axon.add(down_residual, controlnet_down_residual, name: "add_controlnet_down_#{i}"), - out_channel} + {Axon.add(down_residual, additional_down_residual, name: join("add_additional_down", i)), + out_channels} end |> List.to_tuple() mid_block_residual = sample |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - |> Axon.add(controlnet_mid_residual, name: "add_controlnet_mid") + |> Axon.add(additional_mid_residual, name: "add_additional_mid") mid_block_residual |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index b8b1e413..ee655643 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -36,23 +36,22 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do end @tag timeout: :infinity - test ":with_controlnet" do - compvis = "CompVis/stable-diffusion-v1-4" + test ":with_additional_residuals" do tiny = "bumblebee-testing/tiny-stable-diffusion" assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( {:hf, tiny, subdir: "unet"}, - architecture: :with_controlnet + architecture: :with_additional_residuals ) - assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :with_controlnet} = spec + assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :with_additional_residuals} = spec first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} state = {spec.sample_size, [first]} - {mid_spatial, out_shapes} = + {_, out_shapes} = for block_out_channel <- spec.hidden_sizes, reduce: state do {spatial_size, acc} -> residuals = @@ -64,17 +63,16 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do {div(spatial_size, 2), acc ++ residuals ++ [downsample]} end - mid_spatial = 2 * mid_spatial out_shapes = Enum.drop(out_shapes, -1) down_residuals = - for {shape, i} <- Enum.with_index(out_shapes) do + for shape <- out_shapes do Nx.broadcast(0.5, shape) end |> List.to_tuple() + mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) mid_dim = List.last(spec.hidden_sizes) - mid_residual_shape = {1, mid_spatial, mid_spatial, mid_dim} inputs = @@ -82,8 +80,8 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do "sample" => Nx.broadcast(0.5, {1, spec.sample_size, spec.sample_size, 4}), "timestep" => Nx.tensor(1), "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, spec.cross_attention_size}), - "controlnet_mid_residual" => Nx.broadcast(0.5, mid_residual_shape), - "controlnet_down_residuals" => down_residuals + "additional_mid_residual" => Nx.broadcast(0.5, mid_residual_shape), + "additional_down_residuals" => down_residuals } outputs = Axon.predict(model, params, inputs) From 67b54d78de01fd46c3942b06a875837c360b7d7c Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 16:25:06 +0100 Subject: [PATCH 16/42] Update tests --- test/bumblebee/diffusion/stable_diffusion/control_net_test.ex | 2 +- test/bumblebee/diffusion/unet_2d_conditional_test.exs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex index 6c32077f..fcc39462 100644 --- a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex +++ b/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex @@ -23,7 +23,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do "encoder_hidden_state" => Nx.broadcast(0.8, {1, 1, 768}) } - outputs = Axon.predict(model, params, inputs, debug: true) + outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.mid_block_residual) == {1, 8, 8, 1280} diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index ee655643..6740c5ef 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -35,7 +35,6 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do ) end - @tag timeout: :infinity test ":with_additional_residuals" do tiny = "bumblebee-testing/tiny-stable-diffusion" From 8ab0b28a2ee2c2a0b1cb9c45d16e31e5ef5d149c Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 19:14:06 +0100 Subject: [PATCH 17/42] Update docs, rename with_controlnet to with_additional_residuals --- .../diffusion/stable_diffusion_controlnet.ex | 12 +++++++----- .../diffusion/stable_diffusion_controlnet_test.exs | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index dcea3edf..b4f41aa7 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -12,7 +12,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do String.t() | %{ :prompt => String.t(), - :controlnet_conditioning => Nx.Tensor, + :controlnet_conditioning => Nx.Tensor.t(), optional(:negative_prompt) => String.t(), optional(:seed) => integer() } @@ -82,13 +82,14 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) + {:ok, controlnet} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}) {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) serving = - Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, + Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, controlnet, tokenizer, scheduler, num_steps: 20, num_images_per_prompt: 2, safety_checker: safety_checker, @@ -98,7 +99,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do ) prompt = "numbat in forest, detailed, digital art" - Nx.Serving.run(serving, prompt) + controlnet_conditioning = Nx.tensor() + Nx.Serving.run(serving, %{prompt: prompt, controlnet_conditioning: controlnet_conditioning}) #=> %{ #=> results: [ #=> %{ @@ -432,8 +434,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do "sample" => Nx.concatenate([latents, latents]), "timestep" => timestep, "encoder_hidden_state" => text_embeddings, - "controlnet_mid_residual" => mid_block_residual, - "controlnet_down_residuals" => down_blocks_residuals + "additional_mid_residual" => mid_block_residual, + "additional_down_residuals" => down_blocks_residuals } %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs) diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index d02dc360..147e8c8e 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -19,7 +19,9 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) {:ok, unet} = - Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, architecture: :with_controlnet) + Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, + architecture: :with_additional_residuals + ) {:ok, controlnet} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}) # {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) From 32c21912ca0a9fdf02db03cce841632bd99850a5 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 19:38:44 +0100 Subject: [PATCH 18/42] Compile with conditioning size --- .../diffusion/stable_diffusion_controlnet.ex | 15 ++++++++++----- .../stable_diffusion_controlnet_test.exs | 6 +++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index b4f41aa7..9f89287e 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -159,12 +159,13 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do compile = if compile = opts[:compile] do compile - |> Keyword.validate!([:batch_size, :sequence_length]) - |> Shared.require_options!([:batch_size, :sequence_length]) + |> Keyword.validate!([:batch_size, :sequence_length, :controlnet_conditioning_size]) + |> Shared.require_options!([:batch_size, :sequence_length, :controlnet_conditioning_size]) end batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] + controlnet_conditioning_size = compile[:controlnet_conditioning_size] tokenizer = Bumblebee.configure(tokenizer, @@ -215,7 +216,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do controlnet.params, {safety_checker?, safety_checker[:spec], safety_checker[:params]}, safety_checker_featurizer, - {compile != nil, batch_size, sequence_length}, + {compile != nil, batch_size, sequence_length, controlnet_conditioning_size}, num_images_per_prompt, preallocate_params ] @@ -237,7 +238,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do controlnet_params, {safety_checker?, safety_checker_spec, safety_checker_params}, safety_checker_featurizer, - {compile?, batch_size, sequence_length}, + {compile?, batch_size, sequence_length, controlnet_conditioning_size}, num_images_per_prompt, preallocate_params, defn_options @@ -256,7 +257,11 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32) }, "seed" => Nx.template({batch_size}, :s64), - "controlnet_conditioning" => Nx.template({batch_size, 512, 512, 3}, :f32) + "controlnet_conditioning" => + Nx.template( + {batch_size, controlnet_conditioning_size, controlnet_conditioning_size, 3}, + :f32 + ) } [encoder_params, unet_params, vae_params, controlnet_params, inputs] diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 147e8c8e..f00b2d67 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -36,6 +36,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) + cond_size = unet.spec.sample_size * 2 ** (length(unet.spec.hidden_sizes) - 1) + serving = Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( clip, @@ -47,14 +49,12 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do num_steps: 3, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60], + compile: [batch_size: 1, sequence_length: 60, controlnet_conditioning_size: cond_size], defn_options: [compiler: EXLA] ) prompt = "numbat in forest, detailed, digital art" - cond_size = unet.spec.sample_size * 2 ** (length(unet.spec.hidden_sizes) - 1) - controlnet_conditioning = Nx.broadcast(0.5, {cond_size, cond_size, 3}) assert %{ From 44e2b6379465bda5412d18d043ebdd499d456e6a Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 4 Mar 2024 20:01:29 +0100 Subject: [PATCH 19/42] Preprocess image in stable diffusion task --- lib/bumblebee/diffusion/stable_diffusion_controlnet.ex | 2 +- test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 9f89287e..98fec3ee 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -305,7 +305,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do end defp preprocess_image(image) do - image + NxImage.to_continuous(image, 0, 1) end defp client_preprocessing(input, tokenizer) do diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index f00b2d67..4ab7f98f 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -55,7 +55,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do prompt = "numbat in forest, detailed, digital art" - controlnet_conditioning = Nx.broadcast(0.5, {cond_size, cond_size, 3}) + controlnet_conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {cond_size, cond_size, 3}) assert %{ results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] From 061fcd44a5e922d61d9523a29a49fc34afc517a5 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 13 Mar 2024 19:36:35 +0100 Subject: [PATCH 20/42] Optional conditioning_scale input --- .../diffusion/stable_diffusion/control_net.ex | 10 ++++++++-- .../diffusion/stable_diffusion_controlnet.ex | 18 ++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index 43763676..b7900324 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -157,6 +157,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), "controlnet_conditioning" => Nx.template(controlnet_conditioning_shape, :f32), + "conditioning_scale" => Nx.template({}, :f32), "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) } end @@ -178,6 +179,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("controlnet_conditioning", shape: controlnet_conditioning_shape), + Axon.input("conditioning_scale", optional: true), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) ]) end @@ -186,6 +188,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do sample = inputs["sample"] timestep = inputs["timestep"] controlnet_conditioning = inputs["controlnet_conditioning"] + + conditioning_scale = + Bumblebee.Layers.default inputs["conditioning_scale"] do + Axon.constant(1) + end + encoder_hidden_state = inputs["encoder_hidden_state"] timestep = @@ -226,8 +234,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do sample = mid_block(sample, timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - conditioning_scale = Axon.constant(1) - down_blocks_residuals = control_net_down_blocks(down_blocks_residuals, name: "controlnet_down_blocks") diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 98fec3ee..2f019f49 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -13,6 +13,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do | %{ :prompt => String.t(), :controlnet_conditioning => Nx.Tensor.t(), + optional(:conditioning_scale) => integer(), optional(:negative_prompt) => String.t(), optional(:seed) => integer() } @@ -261,7 +262,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do Nx.template( {batch_size, controlnet_conditioning_size, controlnet_conditioning_size, 3}, :f32 - ) + ), + "conditioning_scale" => Nx.template({batch_size}, :f32) } [encoder_params, unet_params, vae_params, controlnet_params, inputs] @@ -328,10 +330,15 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do |> Nx.stack() |> preprocess_image() + conditioning_scale = + Enum.map(inputs, & &1.conditioning_scale) + |> Nx.tensor(type: :f32, backend: Nx.BinaryBackend) + inputs = %{ "conditional_and_unconditional" => prompt_pairs, "seed" => seed, - "controlnet_conditioning" => controlnet_conditioning + "controlnet_conditioning" => controlnet_conditioning, + "conditioning_scale" => conditioning_scale } {Nx.Batch.concatenate([inputs]), multi?} @@ -384,6 +391,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do seed = inputs["seed"] controlnet_conditioning = inputs["controlnet_conditioning"] + conditioning_scale = inputs["conditioning_scale"] inputs = inputs["conditional_and_unconditional"] @@ -422,10 +430,11 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {latents, _} = while {latents, {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, - controlnet_params}}, + conditioning_scale, controlnet_params}}, timestep <- timesteps do controlnet_inputs = %{ "controlnet_conditioning" => controlnet_conditioning, + "conditioning_scale" => conditioning_scale, "sample" => Nx.concatenate([latents, latents]), "timestep" => timestep, "encoder_hidden_state" => text_embeddings @@ -463,7 +472,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {latents, {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, - controlnet_params}} + conditioning_scale, controlnet_params}} end latents = latents * (1 / 0.18215) @@ -493,6 +502,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do %{ prompt: prompt, controlnet_conditioning: controlnet_conditioning, + conditioning_scale: input[:conditioning_scale] || 1.0, negative_prompt: input[:negative_prompt] || "", seed: input[:seed] || :erlang.system_time() }} From de5dbcbf561f3f56278c878a1cbe059ce8fd8b65 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 14 Mar 2024 09:35:38 +0100 Subject: [PATCH 21/42] Test with tiny models --- .../diffusion/stable_diffusion_controlnet_test.exs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 4ab7f98f..238df542 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -3,17 +3,15 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do import Bumblebee.TestHelpers - # @moduletag serving_test_tags() + @moduletag serving_test_tags() - @tag timeout: :infinity describe "text_to_image/6" do test "generates image for a text prompt with controlnet" do # Since we don't assert on the result in this case, we use # a tiny random checkpoint. This test is basically to verify # the whole generation computation end-to-end - repository_id = "runwayml/stable-diffusion-v1-5" - # repository_id = "bumblebee-testing/tiny-stable-diffusion" + repository_id = "bumblebee-testing/tiny-stable-diffusion" {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) @@ -23,8 +21,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do architecture: :with_additional_residuals ) - {:ok, controlnet} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}) - # {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) + {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) From 50071096d5a6b04ffb381370880eed3b26189754 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 14 Mar 2024 10:31:08 +0100 Subject: [PATCH 22/42] Rename test file to .exs --- .../{control_net_test.ex => control_net_test.exs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/bumblebee/diffusion/stable_diffusion/{control_net_test.ex => control_net_test.exs} (100%) diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.ex b/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs similarity index 100% rename from test/bumblebee/diffusion/stable_diffusion/control_net_test.ex rename to test/bumblebee/diffusion/stable_diffusion/control_net_test.exs From 19ab327c7a5c745c4be4247d593355f97f0124b4 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 14 Mar 2024 10:32:08 +0100 Subject: [PATCH 23/42] Add optional additional residuals in :base architecture --- .../diffusion/unet_2d_conditional.ex | 149 ++++++------------ .../stable_diffusion_controlnet_test.exs | 5 +- .../diffusion/unet_2d_conditional_test.exs | 11 +- 3 files changed, 57 insertions(+), 108 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 232b669d..51e12047 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -101,7 +101,6 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ## Architectures * `:base` - the U-Net model - * `:with_additional_residuals` - with additional residuals ## Inputs @@ -134,7 +133,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do alias Bumblebee.Diffusion @impl true - def architectures(), do: [:base, :with_additional_residuals] + def architectures(), do: [:base] @impl true def config(spec, opts) do @@ -189,32 +188,15 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do end @impl true - def model(%__MODULE__{architecture: :base} = spec) do + def model(%__MODULE__{} = spec) do inputs = inputs(spec) sample = core(inputs, spec) Layers.output(%{sample: sample}) end - @impl true - def model(%__MODULE__{architecture: :with_additional_residuals} = spec) do - inputs = inputs_with_additional_residuals(spec) - sample = core_with_additional_residuals(inputs, spec) - Layers.output(%{sample: sample}) - end - defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - Bumblebee.Utils.Model.inputs_to_map([ - Axon.input("sample", shape: sample_shape), - Axon.input("timestep", shape: {}), - Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) - ]) - end - - defp inputs_with_additional_residuals(spec) do - sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) mid_channels = List.last(spec.hidden_sizes) mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_channels} @@ -223,8 +205,8 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("additional_mid_residual", shape: mid_residual_shape), - Axon.input("additional_down_residuals") + Axon.input("additional_mid_residual", shape: mid_residual_shape, optional: true), + Axon.input("additional_down_residuals", optional: true) ]) end @@ -269,86 +251,16 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do {sample, down_block_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") - sample - |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, - name: "up_blocks" - ) - |> Axon.group_norm(spec.group_norm_num_groups, - epsilon: spec.group_norm_epsilon, - name: "output_norm" - ) - |> Axon.activation(:silu) - |> Axon.conv(spec.out_channels, - kernel_size: 3, - padding: [{1, 1}, {1, 1}], - name: "output_conv" - ) - end - - defp core_with_additional_residuals(inputs, spec) do - sample = inputs["sample"] - timestep = inputs["timestep"] - encoder_hidden_state = inputs["encoder_hidden_state"] - additional_mid_residual = inputs["additional_mid_residual"] - - num_down_residuals = length(spec.hidden_sizes) * (spec.depth + 1) - - additional_down_residuals = - for i <- 0..(num_down_residuals - 1) do - Axon.nx(inputs["additional_down_residuals"], &elem(&1, i)) - end - - sample = - if spec.center_input_sample do - Axon.nx(sample, fn sample -> 2 * sample - 1.0 end, op_name: :center) - else - sample - end - - timestep = - Axon.layer( - fn sample, timestep, _opts -> - Nx.broadcast(timestep, {Nx.axis_size(sample, 0)}) - end, - [sample, timestep], - op_name: :broadcast - ) - - timestep_embedding = - timestep - |> Diffusion.Layers.timestep_sinusoidal_embedding(hd(spec.hidden_sizes), - flip_sin_to_cos: spec.embedding_flip_sin_to_cos, - frequency_correction_term: spec.embedding_frequency_correction_term - ) - |> Diffusion.Layers.UNet.timestep_embedding_mlp(hd(spec.hidden_sizes) * 4, - name: "time_embedding" - ) - - sample = - Axon.conv(sample, hd(spec.hidden_sizes), - kernel_size: 3, - padding: [{1, 1}, {1, 1}], - name: "input_conv" - ) - - {sample, down_block_residuals} = - down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") - - down_residual_zip = Enum.zip(Tuple.to_list(down_block_residuals), additional_down_residuals) - down_block_residuals = - for {{{down_residual, out_channels}, additional_down_residual}, i} <- - Enum.with_index(down_residual_zip) do - {Axon.add(down_residual, additional_down_residual, name: join("add_additional_down", i)), - out_channels} - end - |> List.to_tuple() + add_optional_additional_down_residuals( + down_block_residuals, + inputs["additional_down_residuals"] + ) mid_block_residual = sample |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - |> Axon.add(additional_mid_residual, name: "add_additional_mid") + |> add_optional_additional_mid_residual(inputs["additional_mid_residual"]) mid_block_residual |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, @@ -489,6 +401,49 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do sample end + def add_optional_additional_mid_residual( + %Axon{} = mid_block_residual, + %Axon{} = additional_mid_block_residual + ) do + Axon.layer( + fn mid_block_residual, additional_mid_block_residual, _opts -> + case additional_mid_block_residual do + %Axon.None{} -> + mid_block_residual + + additional_mid_block_residual -> + Nx.add(mid_block_residual, additional_mid_block_residual) + end + end, + [mid_block_residual, Axon.optional(additional_mid_block_residual)], + name: "add_additional_mid" + ) + end + + def add_optional_additional_down_residuals( + down_block_residuals, + %Axon{} = additional_down_residuals + ) do + down_residuals = Tuple.to_list(down_block_residuals) + + for {{down_residual, out_channels}, i} <- Enum.with_index(down_residuals) do + {Axon.layer( + fn down_residual, additional_down_residuals, _opts -> + case additional_down_residuals do + %Axon.None{} -> + down_residual + + additional_down_residuals -> + Nx.add(down_residual, elem(additional_down_residuals, i)) + end + end, + [down_residual, Axon.optional(additional_down_residuals)], + name: join("add_additional_down", i) + ), out_channels} + end + |> List.to_tuple() + end + defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do spec.num_attention_heads end diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 238df542..b731e6f7 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -16,10 +16,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) - {:ok, unet} = - Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, - architecture: :with_additional_residuals - ) + {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index 6740c5ef..c55a57bc 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -19,7 +19,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) } - outputs = Axon.predict(model, params, inputs) + outputs = Axon.predict(model, params, inputs, debug: true) assert Nx.shape(outputs.sample) == {1, 32, 32, 4} @@ -35,16 +35,13 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do ) end - test ":with_additional_residuals" do + test ":base with additional residuals" do tiny = "bumblebee-testing/tiny-stable-diffusion" assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model( - {:hf, tiny, subdir: "unet"}, - architecture: :with_additional_residuals - ) + Bumblebee.load_model({:hf, tiny, subdir: "unet"}) - assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :with_additional_residuals} = spec + assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} From 0204cd6cf0d35edf85ace81b52006e56e94f58b6 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 14 Mar 2024 10:35:16 +0100 Subject: [PATCH 24/42] Remove debug from test --- test/bumblebee/diffusion/unet_2d_conditional_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index c55a57bc..cc46385e 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -19,7 +19,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) } - outputs = Axon.predict(model, params, inputs, debug: true) + outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.sample) == {1, 32, 32, 4} From 7bfdfe7d856a411bff84cf1f24ce3ec005800cad Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 14 Mar 2024 10:38:41 +0100 Subject: [PATCH 25/42] Specify architecture --- lib/bumblebee/diffusion/unet_2d_conditional.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 51e12047..fcd4700f 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -188,7 +188,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do end @impl true - def model(%__MODULE__{} = spec) do + def model(%__MODULE__{architecture: :base} = spec) do inputs = inputs(spec) sample = core(inputs, spec) Layers.output(%{sample: sample}) From c7b95302f73d307240f92c56bdda65bedd5a51f3 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 3 Apr 2024 14:17:28 +0200 Subject: [PATCH 26/42] Remove residual templates --- .../diffusion/unet_2d_conditional.ex | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index fcd4700f..ad56aa45 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -140,50 +140,16 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Shared.put_config_attrs(spec, opts) end - defp down_residuals_templates(spec) do - first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} - - state = {spec.sample_size, [first]} - - {_, down_shapes} = - for block_out_channels <- spec.hidden_sizes, reduce: state do - {spatial_size, acc} -> - residuals = - for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channels} - - downsampled_spatial = div(spatial_size, 2) - downsample_shape = {1, downsampled_spatial, downsampled_spatial, block_out_channels} - - {div(spatial_size, 2), acc ++ residuals ++ [downsample_shape]} - end - - # no downsampling in last block - down_shapes = Enum.drop(down_shapes, -1) - - for shape <- down_shapes do - Nx.template(shape, :f32) - end - |> List.to_tuple() - end - @impl true def input_template(spec) do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} - mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) - mid_channels = List.last(spec.hidden_sizes) - mid_residual_shape = {1, mid_spatial, mid_spatial, mid_channels} - - down_residuals = down_residuals_templates(spec) - %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), - "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32), - "additional_mid_residual" => Nx.template(mid_residual_shape, :f32), - "additional_down_residuals" => down_residuals + "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) } end @@ -197,15 +163,11 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) - mid_channels = List.last(spec.hidden_sizes) - mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_channels} - Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("additional_mid_residual", shape: mid_residual_shape, optional: true), + Axon.input("additional_mid_residual", optional: true), Axon.input("additional_down_residuals", optional: true) ]) end From 33ba7f8910b8bbc0d6a45838fce22fc9c8f3fb1f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 3 Apr 2024 14:40:53 +0200 Subject: [PATCH 27/42] Rename mid_block_residual to sample --- lib/bumblebee/diffusion/unet_2d_conditional.ex | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index ad56aa45..ef4b2aa7 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -219,12 +219,12 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do inputs["additional_down_residuals"] ) - mid_block_residual = + sample = sample |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") |> add_optional_additional_mid_residual(inputs["additional_mid_residual"]) - mid_block_residual + sample |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, name: "up_blocks" ) From 1d1492f9c8f25d895ed3d3eb34cfd1d654c35b2f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 3 Apr 2024 15:00:35 +0200 Subject: [PATCH 28/42] Uncomment and update stable diffusion controlnet test --- .../stable_diffusion_controlnet_test.exs | 64 +++++++++++-------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index b731e6f7..280ce454 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -42,9 +42,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do scheduler, num_steps: 3, safety_checker: safety_checker, - safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60, controlnet_conditioning_size: cond_size], - defn_options: [compiler: EXLA] + safety_checker_featurizer: featurizer ) prompt = "numbat in forest, detailed, digital art" @@ -61,35 +59,51 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do # Without safety checker - # serving = - # Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( - # clip, - # unet, - # vae, - # tokenizer, - # scheduler, - # num_steps: 3 - # ) + serving = + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3 + ) - # prompt = "numbat in forest, detailed, digital art" + prompt = "numbat in forest, detailed, digital art" - # assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, prompt) + assert %{results: [%{image: %Nx.Tensor{}}]} = + Nx.Serving.run(serving, %{ + prompt: prompt, + controlnet_conditioning: controlnet_conditioning + }) # With compilation - # serving = - # Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, - # num_steps: 3, - # safety_checker: safety_checker, - # safety_checker_featurizer: featurizer, - # defn_options: [compiler: EXLA] - # ) + serving = + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, + num_steps: 3, + safety_checker: safety_checker, + safety_checker_featurizer: featurizer, + compile: [batch_size: 1, sequence_length: 60, controlnet_conditioning_size: cond_size], + defn_options: [compiler: EXLA] + ) - # prompt = "numbat in forest, detailed, digital art" + prompt = "numbat in forest, detailed, digital art" - # assert %{ - # results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] - # } = Nx.Serving.run(serving, prompt) + assert %{ + results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] + } = + Nx.Serving.run(serving, %{ + prompt: prompt, + controlnet_conditioning: controlnet_conditioning + }) end end end From 120b75859c7784de70bd8353b71c18f43a001ea1 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 5 Apr 2024 15:59:10 +0200 Subject: [PATCH 29/42] Update docs --- .../diffusion/stable_diffusion_controlnet.ex | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 2f019f49..4e4d825a 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -1,6 +1,6 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do @moduledoc """ - High-level tasks based on Stable Diffusion. + High-level tasks based on Stable Diffusion with ControlNet. """ import Nx.Defn @@ -90,17 +90,32 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) serving = - Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, controlnet, tokenizer, scheduler, + Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( + clip, + unet, + vae, + controlnet, + tokenizer, + scheduler, num_steps: 20, num_images_per_prompt: 2, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60], + compile: [batch_size: 1, sequence_length: 60, conditioning_size: 512], defn_options: [compiler: EXLA] ) prompt = "numbat in forest, detailed, digital art" - controlnet_conditioning = Nx.tensor() + + controlnet_conditioning = + Nx.tensor( + [for(_ <- 1..8, do: [255]) ++ for(_ <- 1..24, do: [0])], + type: :u8 + ) + |> Nx.tile([256, 8, 3]) + |> Nx.pad(0, [{192, 64, 0}, {192, 64, 0}, {0, 0, 0}]) + |> Nx.transpose(axes: [1, 0, 2]) + Nx.Serving.run(serving, %{prompt: prompt, controlnet_conditioning: controlnet_conditioning}) #=> %{ #=> results: [ From a5ca864da0f48ece2c4336500a8d9fdbc189f68b Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 5 Apr 2024 16:17:12 +0200 Subject: [PATCH 30/42] Rename controlnet_conditioning to conditioning --- .../diffusion/stable_diffusion/control_net.ex | 14 +++--- .../diffusion/stable_diffusion_controlnet.ex | 44 +++++++++---------- .../stable_diffusion/control_net_test.exs | 2 +- .../stable_diffusion_controlnet_test.exs | 10 ++--- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex index b7900324..d5794690 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/stable_diffusion/control_net.ex @@ -117,7 +117,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do The conditional state (context) to use with cross-attention. - * `"controlnet_conditioning"` - `{batch_size, conditioning_size, conditioning_size, 3}` + * `"conditioning"` - `{batch_size, conditioning_size, conditioning_size, 3}` The conditional input @@ -150,13 +150,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do timestep_shape = {} cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) - controlnet_conditioning_shape = {1, cond_size, cond_size, 3} + conditioning_shape = {1, cond_size, cond_size, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} %{ "sample" => Nx.template(sample_shape, :f32), "timestep" => Nx.template(timestep_shape, :u32), - "controlnet_conditioning" => Nx.template(controlnet_conditioning_shape, :f32), + "conditioning" => Nx.template(conditioning_shape, :f32), "conditioning_scale" => Nx.template({}, :f32), "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32) } @@ -173,12 +173,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) - controlnet_conditioning_shape = {nil, cond_size, cond_size, 3} + conditioning_shape = {nil, cond_size, cond_size, 3} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), - Axon.input("controlnet_conditioning", shape: controlnet_conditioning_shape), + Axon.input("conditioning", shape: conditioning_shape), Axon.input("conditioning_scale", optional: true), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}) ]) @@ -187,7 +187,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do defp core(inputs, spec) do sample = inputs["sample"] timestep = inputs["timestep"] - controlnet_conditioning = inputs["controlnet_conditioning"] + conditioning = inputs["conditioning"] conditioning_scale = Bumblebee.Layers.default inputs["conditioning_scale"] do @@ -223,7 +223,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do ) control_net_cond_embeddings = - control_net_embeddings(controlnet_conditioning, spec, name: "controlnet_cond_embedding") + control_net_embeddings(conditioning, spec, name: "controlnet_cond_embedding") sample = Axon.add(sample, control_net_cond_embeddings, name: "add_sample_control_net_embeddings") diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 4e4d825a..153a6e7b 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -12,7 +12,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do String.t() | %{ :prompt => String.t(), - :controlnet_conditioning => Nx.Tensor.t(), + :conditioning => Nx.Tensor.t(), optional(:conditioning_scale) => integer(), optional(:negative_prompt) => String.t(), optional(:seed) => integer() @@ -107,7 +107,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do prompt = "numbat in forest, detailed, digital art" - controlnet_conditioning = + conditioning = Nx.tensor( [for(_ <- 1..8, do: [255]) ++ for(_ <- 1..24, do: [0])], type: :u8 @@ -116,7 +116,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do |> Nx.pad(0, [{192, 64, 0}, {192, 64, 0}, {0, 0, 0}]) |> Nx.transpose(axes: [1, 0, 2]) - Nx.Serving.run(serving, %{prompt: prompt, controlnet_conditioning: controlnet_conditioning}) + Nx.Serving.run(serving, %{prompt: prompt, conditioning: conditioning}) #=> %{ #=> results: [ #=> %{ @@ -175,13 +175,13 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do compile = if compile = opts[:compile] do compile - |> Keyword.validate!([:batch_size, :sequence_length, :controlnet_conditioning_size]) - |> Shared.require_options!([:batch_size, :sequence_length, :controlnet_conditioning_size]) + |> Keyword.validate!([:batch_size, :sequence_length, :conditioning_size]) + |> Shared.require_options!([:batch_size, :sequence_length, :conditioning_size]) end batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] - controlnet_conditioning_size = compile[:controlnet_conditioning_size] + conditioning_size = compile[:conditioning_size] tokenizer = Bumblebee.configure(tokenizer, @@ -232,7 +232,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do controlnet.params, {safety_checker?, safety_checker[:spec], safety_checker[:params]}, safety_checker_featurizer, - {compile != nil, batch_size, sequence_length, controlnet_conditioning_size}, + {compile != nil, batch_size, sequence_length, conditioning_size}, num_images_per_prompt, preallocate_params ] @@ -254,7 +254,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do controlnet_params, {safety_checker?, safety_checker_spec, safety_checker_params}, safety_checker_featurizer, - {compile?, batch_size, sequence_length, controlnet_conditioning_size}, + {compile?, batch_size, sequence_length, conditioning_size}, num_images_per_prompt, preallocate_params, defn_options @@ -273,9 +273,9 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32) }, "seed" => Nx.template({batch_size}, :s64), - "controlnet_conditioning" => + "conditioning" => Nx.template( - {batch_size, controlnet_conditioning_size, controlnet_conditioning_size, 3}, + {batch_size, conditioning_size, conditioning_size, 3}, :f32 ), "conditioning_scale" => Nx.template({batch_size}, :f32) @@ -340,8 +340,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0)) end) - controlnet_conditioning = - Enum.map(inputs, & &1.controlnet_conditioning) + conditioning = + Enum.map(inputs, & &1.conditioning) |> Nx.stack() |> preprocess_image() @@ -352,7 +352,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do inputs = %{ "conditional_and_unconditional" => prompt_pairs, "seed" => seed, - "controlnet_conditioning" => controlnet_conditioning, + "conditioning" => conditioning, "conditioning_scale" => conditioning_scale } @@ -405,7 +405,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do guidance_scale = opts[:guidance_scale] seed = inputs["seed"] - controlnet_conditioning = inputs["controlnet_conditioning"] + conditioning = inputs["conditioning"] conditioning_scale = inputs["conditioning_scale"] inputs = @@ -444,11 +444,11 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {latents, _} = while {latents, - {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, - conditioning_scale, controlnet_params}}, + {scheduler_state, text_embeddings, unet_params, conditioning, conditioning_scale, + controlnet_params}}, timestep <- timesteps do controlnet_inputs = %{ - "controlnet_conditioning" => controlnet_conditioning, + "conditioning" => conditioning, "conditioning_scale" => conditioning_scale, "sample" => Nx.concatenate([latents, latents]), "timestep" => timestep, @@ -486,8 +486,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do latents = Nx.devectorize(latents) {latents, - {scheduler_state, text_embeddings, unet_params, controlnet_conditioning, - conditioning_scale, controlnet_params}} + {scheduler_state, text_embeddings, unet_params, conditioning, conditioning_scale, + controlnet_params}} end latents = latents * (1 / 0.18215) @@ -512,11 +512,11 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt}) - defp validate_input(%{prompt: prompt, controlnet_conditioning: controlnet_conditioning} = input) do + defp validate_input(%{prompt: prompt, conditioning: conditioning} = input) do {:ok, %{ prompt: prompt, - controlnet_conditioning: controlnet_conditioning, + conditioning: conditioning, conditioning_scale: input[:conditioning_scale] || 1.0, negative_prompt: input[:negative_prompt] || "", seed: input[:seed] || :erlang.system_time() @@ -525,7 +525,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do defp validate_input(%{} = input) do {:error, - "expected the input map to have :prompt and :controlnet_conditioning key, got: #{inspect(input)}"} + "expected the input map to have :prompt and :conditioning key, got: #{inspect(input)}"} end defp validate_input(input) do diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs b/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs index fcc39462..efb7540f 100644 --- a/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs @@ -18,7 +18,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do inputs = %{ "sample" => Nx.broadcast(0.5, {1, 64, 64, 4}), - "controlnet_conditioning" => Nx.broadcast(0.8, {1, 512, 512, 3}), + "conditioning" => Nx.broadcast(0.8, {1, 512, 512, 3}), "timestep" => Nx.tensor(0), "encoder_hidden_state" => Nx.broadcast(0.8, {1, 1, 768}) } diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 280ce454..96ce2ff5 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -47,14 +47,14 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do prompt = "numbat in forest, detailed, digital art" - controlnet_conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {cond_size, cond_size, 3}) + conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {cond_size, cond_size, 3}) assert %{ results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] } = Nx.Serving.run(serving, %{ prompt: prompt, - controlnet_conditioning: controlnet_conditioning + conditioning: conditioning }) # Without safety checker @@ -75,7 +75,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, %{ prompt: prompt, - controlnet_conditioning: controlnet_conditioning + conditioning: conditioning }) # With compilation @@ -91,7 +91,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do num_steps: 3, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60, controlnet_conditioning_size: cond_size], + compile: [batch_size: 1, sequence_length: 60, conditioning_size: cond_size], defn_options: [compiler: EXLA] ) @@ -102,7 +102,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do } = Nx.Serving.run(serving, %{ prompt: prompt, - controlnet_conditioning: controlnet_conditioning + conditioning: conditioning }) end end From b023fb84f2359879853e7e414305c7fb292e5748 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 5 Apr 2024 16:49:37 +0200 Subject: [PATCH 31/42] Share optional add logic --- .../diffusion/unet_2d_conditional.ex | 51 +++++++------------ 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index ef4b2aa7..af166c1c 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -363,49 +363,32 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do sample end - def add_optional_additional_mid_residual( - %Axon{} = mid_block_residual, - %Axon{} = additional_mid_block_residual - ) do - Axon.layer( - fn mid_block_residual, additional_mid_block_residual, _opts -> - case additional_mid_block_residual do - %Axon.None{} -> - mid_block_residual - - additional_mid_block_residual -> - Nx.add(mid_block_residual, additional_mid_block_residual) - end - end, - [mid_block_residual, Axon.optional(additional_mid_block_residual)], - name: "add_additional_mid" - ) + defp add_optional_additional_mid_residual(mid_block_residual, additional_mid_block_residual) do + maybe_add(mid_block_residual, additional_mid_block_residual) end - def add_optional_additional_down_residuals( - down_block_residuals, - %Axon{} = additional_down_residuals - ) do + defp add_optional_additional_down_residuals(down_block_residuals, additional_down_residuals) do down_residuals = Tuple.to_list(down_block_residuals) for {{down_residual, out_channels}, i} <- Enum.with_index(down_residuals) do - {Axon.layer( - fn down_residual, additional_down_residuals, _opts -> - case additional_down_residuals do - %Axon.None{} -> - down_residual - - additional_down_residuals -> - Nx.add(down_residual, elem(additional_down_residuals, i)) - end - end, - [down_residual, Axon.optional(additional_down_residuals)], - name: join("add_additional_down", i) - ), out_channels} + additional_down_residual = Axon.nx(additional_down_residuals, &elem(&1, i)) + {maybe_add(down_residual, additional_down_residual), out_channels} end |> List.to_tuple() end + defp maybe_add(left, maybe_right) do + Axon.layer( + fn left, maybe_right, _opts -> + case maybe_right do + %Axon.None{} -> left + right -> Nx.add(left, right) + end + end, + [left, Axon.optional(maybe_right)] + ) + end + defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do spec.num_attention_heads end From 83513d702e2fe7a9f537c017b1bd20df98c86670 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 5 Apr 2024 17:00:17 +0200 Subject: [PATCH 32/42] Rename and move to Bumblebee.Diffusion.ControlNet --- lib/bumblebee.ex | 2 +- .../diffusion/{stable_diffusion => }/control_net.ex | 2 +- .../diffusion/{stable_diffusion => }/control_net_test.exs | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) rename lib/bumblebee/diffusion/{stable_diffusion => }/control_net.ex (99%) rename test/bumblebee/diffusion/{stable_diffusion => }/control_net_test.exs (85%) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index d38fa329..f289a480 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -116,7 +116,7 @@ defmodule Bumblebee do "CLIPModel" => {Bumblebee.Multimodal.Clip, :base}, "CLIPTextModel" => {Bumblebee.Text.ClipText, :base}, "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base}, - "ControlNetModel" => {Bumblebee.Diffusion.StableDiffusion.ControlNet, :base}, + "ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base}, "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification}, "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base}, "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification}, diff --git a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex b/lib/bumblebee/diffusion/control_net.ex similarity index 99% rename from lib/bumblebee/diffusion/stable_diffusion/control_net.ex rename to lib/bumblebee/diffusion/control_net.ex index d5794690..ffea4f86 100644 --- a/lib/bumblebee/diffusion/stable_diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/control_net.ex @@ -1,4 +1,4 @@ -defmodule Bumblebee.Diffusion.StableDiffusion.ControlNet do +defmodule Bumblebee.Diffusion.ControlNet do alias Bumblebee.Shared options = [ diff --git a/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs b/test/bumblebee/diffusion/control_net_test.exs similarity index 85% rename from test/bumblebee/diffusion/stable_diffusion/control_net_test.exs rename to test/bumblebee/diffusion/control_net_test.exs index efb7540f..8ea5dbf7 100644 --- a/test/bumblebee/diffusion/stable_diffusion/control_net_test.exs +++ b/test/bumblebee/diffusion/control_net_test.exs @@ -1,4 +1,4 @@ -defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do +defmodule Bumblebee.Diffusion.ControlNetTest do use ExUnit.Case, async: true import Bumblebee.TestHelpers @@ -8,11 +8,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion.ControlNetTest do test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}, - module: Bumblebee.Diffusion.StableDiffusion.ControlNet, + module: Bumblebee.Diffusion.ControlNet, architecture: :base ) - assert %Bumblebee.Diffusion.StableDiffusion.ControlNet{ + assert %Bumblebee.Diffusion.ControlNet{ architecture: :base } = spec From 2821a805f8ffa69b45f04c4e06bea636b5ae1f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Sun, 7 Apr 2024 17:46:43 +0800 Subject: [PATCH 33/42] Fix mapping for layer names with trailing substitutions --- lib/bumblebee/conversion/pytorch.ex | 14 ++++++++++++-- lib/bumblebee/diffusion/control_net.ex | 6 +++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/lib/bumblebee/conversion/pytorch.ex b/lib/bumblebee/conversion/pytorch.ex index 60b49213..5d344810 100644 --- a/lib/bumblebee/conversion/pytorch.ex +++ b/lib/bumblebee/conversion/pytorch.ex @@ -297,8 +297,18 @@ defmodule Bumblebee.Conversion.PyTorch do defp match_template(name, template), do: match_template(name, template, %{}) defp match_template(<<_, _::binary>> = name, <<"{", template::binary>>, substitutes) do - [value, name] = String.split(name, ".", parts: 2) - [key, template] = String.split(template, "}.", parts: 2) + {value, name} = + case String.split(name, ".", parts: 2) do + [value] -> {value, ""} + [value, name] -> {value, name} + end + + {key, template} = + case String.split(template, "}", parts: 2) do + [key, ""] -> {key, ""} + [key, "." <> template] -> {key, template} + end + match_template(name, template, put_in(substitutes[key], value)) end diff --git a/lib/bumblebee/diffusion/control_net.ex b/lib/bumblebee/diffusion/control_net.ex index ffea4f86..471ce982 100644 --- a/lib/bumblebee/diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/control_net.ex @@ -302,14 +302,14 @@ defmodule Bumblebee.Diffusion.ControlNet do |> Axon.conv(in_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], - name: name |> join(2 * i) |> join("conv"), + name: name |> join("inner_convs") |> join(2 * i), activation: :silu ) |> Axon.conv(out_channels, kernel_size: 3, padding: [{1, 1}, {1, 1}], strides: 2, - name: name |> join(2 * i + 1) |> join("conv"), + name: name |> join("inner_convs") |> join(2 * i + 1), activation: :silu ) end @@ -488,7 +488,7 @@ defmodule Bumblebee.Diffusion.ControlNet do controlnet = %{ "controlnet_down_blocks.{m}.zero_conv" => "controlnet_down_blocks.{m}", "controlnet_cond_embedding.input_conv" => "controlnet_cond_embedding.conv_in", - "controlnet_cond_embedding.{m}.conv" => "controlnet_cond_embedding.blocks.{m}", + "controlnet_cond_embedding.inner_convs.{m}" => "controlnet_cond_embedding.blocks.{m}", "controlnet_cond_embedding.output_conv" => "controlnet_cond_embedding.conv_out", "controlnet_mid_block.zero_conv" => "controlnet_mid_block" } From f002258838b39f570aae1d34509a2d4283c58b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Sun, 7 Apr 2024 17:55:34 +0800 Subject: [PATCH 34/42] Wrap output in a container --- lib/bumblebee/diffusion/control_net.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/diffusion/control_net.ex b/lib/bumblebee/diffusion/control_net.ex index 471ce982..abcf5353 100644 --- a/lib/bumblebee/diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/control_net.ex @@ -248,7 +248,7 @@ defmodule Bumblebee.Diffusion.ControlNet do |> Axon.multiply(conditioning_scale, name: "mid_conditioning_scale") %{ - down_blocks_residuals: down_blocks_residuals, + down_blocks_residuals: Axon.container(down_blocks_residuals), mid_block_residual: mid_block_residual } end From 29c6f1d721671e4c58bd7a54ba766994db24cc6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 18:13:58 +0800 Subject: [PATCH 35/42] Always infer conditioning size --- lib/bumblebee/diffusion/control_net.ex | 8 ++++---- lib/bumblebee/diffusion/stable_diffusion_controlnet.ex | 10 ++++++---- .../diffusion/stable_diffusion_controlnet_test.exs | 6 +++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lib/bumblebee/diffusion/control_net.ex b/lib/bumblebee/diffusion/control_net.ex index abcf5353..cbdf2d24 100644 --- a/lib/bumblebee/diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/control_net.ex @@ -149,8 +149,8 @@ defmodule Bumblebee.Diffusion.ControlNet do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} - cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) - conditioning_shape = {1, cond_size, cond_size, 3} + conditioning_size = spec.sample_size * 2 ** (length(spec.conditioning_embedding_out_channels) - 1) + conditioning_shape = {1, conditioning_size, conditioning_size, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} %{ @@ -172,8 +172,8 @@ defmodule Bumblebee.Diffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - cond_size = spec.sample_size * 2 ** (length(spec.hidden_sizes) - 1) - conditioning_shape = {nil, cond_size, cond_size, 3} + conditioning_size = spec.sample_size * 2 ** (length(spec.conditioning_embedding_out_channels) - 1) + conditioning_shape = {nil, conditioning_size, conditioning_size, 3} Bumblebee.Utils.Model.inputs_to_map([ Axon.input("sample", shape: sample_shape), diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 153a6e7b..596325e6 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -101,7 +101,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do num_images_per_prompt: 2, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60, conditioning_size: 512], + compile: [batch_size: 1, sequence_length: 60], defn_options: [compiler: EXLA] ) @@ -175,13 +175,15 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do compile = if compile = opts[:compile] do compile - |> Keyword.validate!([:batch_size, :sequence_length, :conditioning_size]) - |> Shared.require_options!([:batch_size, :sequence_length, :conditioning_size]) + |> Keyword.validate!([:batch_size, :sequence_length]) + |> Shared.require_options!([:batch_size, :sequence_length]) end batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] - conditioning_size = compile[:conditioning_size] + + conditioning_size = + controlnet.spec.sample_size * 2 ** (length(controlnet.spec.conditioning_embedding_out_channels) - 1) tokenizer = Bumblebee.configure(tokenizer, diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 96ce2ff5..483e1bf1 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -30,7 +30,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) - cond_size = unet.spec.sample_size * 2 ** (length(unet.spec.hidden_sizes) - 1) + conditioning_size = controlnet.spec.sample_size * 2 ** (length(controlnet.spec.conditioning_embedding_out_channels) - 1) serving = Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( @@ -47,7 +47,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do prompt = "numbat in forest, detailed, digital art" - conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {cond_size, cond_size, 3}) + conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3}) assert %{ results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] @@ -91,7 +91,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do num_steps: 3, safety_checker: safety_checker, safety_checker_featurizer: featurizer, - compile: [batch_size: 1, sequence_length: 60, conditioning_size: cond_size], + compile: [batch_size: 1, sequence_length: 60], defn_options: [compiler: EXLA] ) From 02382dd9df403374c30eb4a7d408623bd81b7598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 18:24:53 +0800 Subject: [PATCH 36/42] Naming --- lib/bumblebee/diffusion/control_net.ex | 20 +++++++++++-------- .../diffusion/stable_diffusion_controlnet.ex | 3 ++- .../stable_diffusion_controlnet_test.exs | 7 +++++-- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/lib/bumblebee/diffusion/control_net.ex b/lib/bumblebee/diffusion/control_net.ex index cbdf2d24..8c893815 100644 --- a/lib/bumblebee/diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/control_net.ex @@ -89,9 +89,9 @@ defmodule Bumblebee.Diffusion.ControlNet do default: 1.0e-5, doc: "the epsilon used by the group normalization layers" ], - conditioning_embedding_out_channels: [ + conditioning_embedding_hidden_sizes: [ default: [16, 32, 96, 256], - doc: "the dimensionality of conditioning embedding" + doc: "the dimensionality of hidden layers in the conditioning input embedding" ] ] @@ -149,7 +149,9 @@ defmodule Bumblebee.Diffusion.ControlNet do sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels} timestep_shape = {} - conditioning_size = spec.sample_size * 2 ** (length(spec.conditioning_embedding_out_channels) - 1) + conditioning_size = + spec.sample_size * 2 ** (length(spec.conditioning_embedding_hidden_sizes) - 1) + conditioning_shape = {1, conditioning_size, conditioning_size, 3} encoder_hidden_state_shape = {1, 1, spec.cross_attention_size} @@ -172,7 +174,9 @@ defmodule Bumblebee.Diffusion.ControlNet do defp inputs(spec) do sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels} - conditioning_size = spec.sample_size * 2 ** (length(spec.conditioning_embedding_out_channels) - 1) + conditioning_size = + spec.sample_size * 2 ** (length(spec.conditioning_embedding_hidden_sizes) - 1) + conditioning_shape = {nil, conditioning_size, conditioning_size, 3} Bumblebee.Utils.Model.inputs_to_map([ @@ -282,15 +286,15 @@ defmodule Bumblebee.Diffusion.ControlNet do name = opts[:name] state = - Axon.conv(sample, hd(spec.conditioning_embedding_out_channels), + Axon.conv(sample, hd(spec.conditioning_embedding_hidden_sizes), kernel_size: 3, padding: [{1, 1}, {1, 1}], name: join(name, "input_conv"), activation: :silu ) - block_in_channels = Enum.drop(spec.conditioning_embedding_out_channels, -1) - block_out_channels = Enum.drop(spec.conditioning_embedding_out_channels, 1) + block_in_channels = Enum.drop(spec.conditioning_embedding_hidden_sizes, -1) + block_out_channels = Enum.drop(spec.conditioning_embedding_hidden_sizes, 1) channels = Enum.zip(block_in_channels, block_out_channels) @@ -429,7 +433,7 @@ defmodule Bumblebee.Diffusion.ControlNet do activation: {"act_fn", activation()}, group_norm_num_groups: {"norm_num_groups", number()}, group_norm_epsilon: {"norm_eps", number()}, - conditioning_embedding_out_channels: + conditioning_embedding_hidden_sizes: {"conditioning_embedding_out_channels", list(number())} ) diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 596325e6..61fa31db 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -183,7 +183,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do sequence_length = compile[:sequence_length] conditioning_size = - controlnet.spec.sample_size * 2 ** (length(controlnet.spec.conditioning_embedding_out_channels) - 1) + controlnet.spec.sample_size * + 2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1) tokenizer = Bumblebee.configure(tokenizer, diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index 483e1bf1..bdcb906b 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -30,7 +30,9 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) - conditioning_size = controlnet.spec.sample_size * 2 ** (length(controlnet.spec.conditioning_embedding_out_channels) - 1) + conditioning_size = + controlnet.spec.sample_size * + 2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1) serving = Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( @@ -47,7 +49,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do prompt = "numbat in forest, detailed, digital art" - conditioning = Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3}) + conditioning = + Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3}) assert %{ results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] From 812cac31dbef6b68ee11876bc17136ed9e3d587e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 18:27:17 +0800 Subject: [PATCH 37/42] Unify control_net -> controlnet --- .../diffusion/{control_net.ex => controlnet.ex} | 17 ++++++++--------- ...control_net_test.exs => controlnet_test.exs} | 0 2 files changed, 8 insertions(+), 9 deletions(-) rename lib/bumblebee/diffusion/{control_net.ex => controlnet.ex} (96%) rename test/bumblebee/diffusion/{control_net_test.exs => controlnet_test.exs} (100%) diff --git a/lib/bumblebee/diffusion/control_net.ex b/lib/bumblebee/diffusion/controlnet.ex similarity index 96% rename from lib/bumblebee/diffusion/control_net.ex rename to lib/bumblebee/diffusion/controlnet.ex index 8c893815..93c14fc5 100644 --- a/lib/bumblebee/diffusion/control_net.ex +++ b/lib/bumblebee/diffusion/controlnet.ex @@ -226,11 +226,10 @@ defmodule Bumblebee.Diffusion.ControlNet do name: "input_conv" ) - control_net_cond_embeddings = - control_net_embeddings(conditioning, spec, name: "controlnet_cond_embedding") + controlnet_cond_embeddings = + controlnet_embeddings(conditioning, spec, name: "controlnet_cond_embedding") - sample = - Axon.add(sample, control_net_cond_embeddings, name: "add_sample_control_net_embeddings") + sample = Axon.add(sample, controlnet_cond_embeddings) {sample, down_blocks_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") @@ -239,7 +238,7 @@ defmodule Bumblebee.Diffusion.ControlNet do mid_block(sample, timestep_embedding, encoder_hidden_state, spec, name: "mid_block") down_blocks_residuals = - control_net_down_blocks(down_blocks_residuals, name: "controlnet_down_blocks") + controlnet_down_blocks(down_blocks_residuals, name: "controlnet_down_blocks") down_blocks_residuals = for residual <- Tuple.to_list(down_blocks_residuals) do @@ -248,7 +247,7 @@ defmodule Bumblebee.Diffusion.ControlNet do |> List.to_tuple() mid_block_residual = - control_net_mid_block(sample, spec, name: "controlnet_mid_block") + controlnet_mid_block(sample, spec, name: "controlnet_mid_block") |> Axon.multiply(conditioning_scale, name: "mid_conditioning_scale") %{ @@ -257,7 +256,7 @@ defmodule Bumblebee.Diffusion.ControlNet do } end - defp control_net_down_blocks(down_block_residuals, opts) do + defp controlnet_down_blocks(down_block_residuals, opts) do name = opts[:name] residuals = @@ -272,7 +271,7 @@ defmodule Bumblebee.Diffusion.ControlNet do List.to_tuple(residuals) end - defp control_net_mid_block(input, spec, opts) do + defp controlnet_mid_block(input, spec, opts) do name = opts[:name] Axon.conv(input, List.last(spec.hidden_sizes), @@ -282,7 +281,7 @@ defmodule Bumblebee.Diffusion.ControlNet do ) end - defp control_net_embeddings(sample, spec, opts) do + defp controlnet_embeddings(sample, spec, opts) do name = opts[:name] state = diff --git a/test/bumblebee/diffusion/control_net_test.exs b/test/bumblebee/diffusion/controlnet_test.exs similarity index 100% rename from test/bumblebee/diffusion/control_net_test.exs rename to test/bumblebee/diffusion/controlnet_test.exs From 5e6f75553561fdf048dfe7a9e408e4d200683b80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 18:30:42 +0800 Subject: [PATCH 38/42] Naming --- lib/bumblebee/diffusion/controlnet.ex | 26 ++++++++++--------- .../diffusion/unet_2d_conditional.ex | 4 +-- lib/bumblebee/shared.ex | 4 +-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/diffusion/controlnet.ex b/lib/bumblebee/diffusion/controlnet.ex index 93c14fc5..ab06a088 100644 --- a/lib/bumblebee/diffusion/controlnet.ex +++ b/lib/bumblebee/diffusion/controlnet.ex @@ -96,7 +96,7 @@ defmodule Bumblebee.Diffusion.ControlNet do ] @moduledoc """ - ControlNet model with two spatial dimensions and conditional state. + ControlNet model with two spatial dimensions and conditioning state. ## Architectures @@ -115,11 +115,11 @@ defmodule Bumblebee.Diffusion.ControlNet do * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}` - The conditional state (context) to use with cross-attention. + The conditioning state (context) to use with cross-attention. * `"conditioning"` - `{batch_size, conditioning_size, conditioning_size, 3}` - The conditional input + The conditioning spatial input. ## Configuration @@ -226,10 +226,10 @@ defmodule Bumblebee.Diffusion.ControlNet do name: "input_conv" ) - controlnet_cond_embeddings = - controlnet_embeddings(conditioning, spec, name: "controlnet_cond_embedding") + controlnet_conditioning_embeddings = + controlnet_embedding(conditioning, spec, name: "controlnet_conditioning_embedding") - sample = Axon.add(sample, controlnet_cond_embeddings) + sample = Axon.add(sample, controlnet_conditioning_embeddings) {sample, down_blocks_residuals} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") @@ -247,8 +247,9 @@ defmodule Bumblebee.Diffusion.ControlNet do |> List.to_tuple() mid_block_residual = - controlnet_mid_block(sample, spec, name: "controlnet_mid_block") - |> Axon.multiply(conditioning_scale, name: "mid_conditioning_scale") + sample + |> controlnet_mid_block(spec, name: "controlnet_mid_block") + |> Axon.multiply(conditioning_scale) %{ down_blocks_residuals: Axon.container(down_blocks_residuals), @@ -281,7 +282,7 @@ defmodule Bumblebee.Diffusion.ControlNet do ) end - defp controlnet_embeddings(sample, spec, opts) do + defp controlnet_embedding(sample, spec, opts) do name = opts[:name] state = @@ -489,10 +490,11 @@ defmodule Bumblebee.Diffusion.ControlNet do |> Enum.reduce(&Map.merge/2) controlnet = %{ + "controlnet_conditioning_embedding.input_conv" => "controlnet_cond_embedding.conv_in", + "controlnet_conditioning_embedding.inner_convs.{m}" => + "controlnet_cond_embedding.blocks.{m}", + "controlnet_conditioning_embedding.output_conv" => "controlnet_cond_embedding.conv_out", "controlnet_down_blocks.{m}.zero_conv" => "controlnet_down_blocks.{m}", - "controlnet_cond_embedding.input_conv" => "controlnet_cond_embedding.conv_in", - "controlnet_cond_embedding.inner_convs.{m}" => "controlnet_cond_embedding.blocks.{m}", - "controlnet_cond_embedding.output_conv" => "controlnet_cond_embedding.conv_out", "controlnet_mid_block.zero_conv" => "controlnet_mid_block" } diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index af166c1c..96a02b8e 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -96,7 +96,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do ] @moduledoc """ - U-Net model with two spatial dimensions and conditional state. + U-Net model with two spatial dimensions and conditioning state. ## Architectures @@ -115,7 +115,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}` - The conditional state (context) to use with cross-attention. + The conditioning state (context) to use with cross-attention. ## Configuration diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index b3a0d6fd..a2e4151c 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -492,7 +492,7 @@ defmodule Bumblebee.Shared do def featurizer_resize_size(_images, %{height: height, width: width}), do: {height, width} def featurizer_resize_size(images, %{shortest_edge: size}) do - {height, width} = images_spacial_sizes(images) + {height, width} = images_spatial_sizes(images) {short, long} = if height < width, do: {height, width}, else: {width, height} @@ -502,7 +502,7 @@ defmodule Bumblebee.Shared do if height < width, do: {out_short, out_long}, else: {out_long, out_short} end - defp images_spacial_sizes(images) do + defp images_spatial_sizes(images) do height = Nx.axis_size(images, -3) width = Nx.axis_size(images, -2) {height, width} From b8e0e54d698e3856e488a4026afa0444f7aeb144 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 19:52:34 +0800 Subject: [PATCH 39/42] Residuals -> states --- lib/bumblebee/diffusion/controlnet.ex | 41 ++++++------ lib/bumblebee/diffusion/layers/unet.ex | 18 ++--- .../diffusion/stable_diffusion_controlnet.ex | 6 +- .../diffusion/unet_2d_conditional.ex | 66 ++++++++++--------- test/bumblebee/diffusion/controlnet_test.exs | 10 +-- .../diffusion/unet_2d_conditional_test.exs | 14 ++-- 6 files changed, 80 insertions(+), 75 deletions(-) diff --git a/lib/bumblebee/diffusion/controlnet.ex b/lib/bumblebee/diffusion/controlnet.ex index ab06a088..ba97ccd8 100644 --- a/lib/bumblebee/diffusion/controlnet.ex +++ b/lib/bumblebee/diffusion/controlnet.ex @@ -231,45 +231,44 @@ defmodule Bumblebee.Diffusion.ControlNet do sample = Axon.add(sample, controlnet_conditioning_embeddings) - {sample, down_blocks_residuals} = + {sample, down_block_states} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") sample = mid_block(sample, timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - down_blocks_residuals = - controlnet_down_blocks(down_blocks_residuals, name: "controlnet_down_blocks") + down_block_states = controlnet_down_blocks(down_block_states, name: "controlnet_down_blocks") - down_blocks_residuals = - for residual <- Tuple.to_list(down_blocks_residuals) do - Axon.multiply(residual, conditioning_scale, name: "down_conditioning_scale") + down_block_states = + for down_block_state <- Tuple.to_list(down_block_states) do + Axon.multiply(down_block_state, conditioning_scale) end |> List.to_tuple() - mid_block_residual = + mid_block_state = sample |> controlnet_mid_block(spec, name: "controlnet_mid_block") |> Axon.multiply(conditioning_scale) %{ - down_blocks_residuals: Axon.container(down_blocks_residuals), - mid_block_residual: mid_block_residual + down_block_states: Axon.container(down_block_states), + mid_block_state: mid_block_state } end - defp controlnet_down_blocks(down_block_residuals, opts) do + defp controlnet_down_blocks(down_block_states, opts) do name = opts[:name] - residuals = - for {{residual, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_residuals)) do - Axon.conv(residual, out_channels, + states = + for {{state, out_channels}, i} <- Enum.with_index(Tuple.to_list(down_block_states)) do + Axon.conv(state, out_channels, kernel_size: 1, name: name |> join(i) |> join("zero_conv"), kernel_initializer: :zeros ) end - List.to_tuple(residuals) + List.to_tuple(states) end defp controlnet_mid_block(input, spec, opts) do @@ -333,17 +332,17 @@ defmodule Bumblebee.Diffusion.ControlNet do Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)]) in_channels = hd(spec.hidden_sizes) - down_block_residuals = [{sample, in_channels}] + down_block_states = [{sample, in_channels}] - state = {sample, down_block_residuals, in_channels} + state = {sample, down_block_states, in_channels} - {sample, down_block_residuals, _} = + {sample, down_block_states, _} = for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks), reduce: state do - {sample, down_block_residuals, in_channels} -> + {sample, down_block_states, in_channels} -> last_block? = idx == length(spec.hidden_sizes) - 1 - {sample, residuals} = + {sample, states} = Diffusion.Layers.UNet.down_block_2d( block_type, sample, @@ -362,10 +361,10 @@ defmodule Bumblebee.Diffusion.ControlNet do name: join(name, idx) ) - {sample, down_block_residuals ++ Tuple.to_list(residuals), out_channels} + {sample, down_block_states ++ Tuple.to_list(states), out_channels} end - {sample, List.to_tuple(down_block_residuals)} + {sample, List.to_tuple(down_block_states)} end defp mid_block(hidden_state, timesteps_embedding, encoder_hidden_state, spec, opts) do diff --git a/lib/bumblebee/diffusion/layers/unet.ex b/lib/bumblebee/diffusion/layers/unet.ex index 0cd72fcb..8f0ad953 100644 --- a/lib/bumblebee/diffusion/layers/unet.ex +++ b/lib/bumblebee/diffusion/layers/unet.ex @@ -51,22 +51,22 @@ defmodule Bumblebee.Diffusion.Layers.UNet do :cross_attention_up_block, sample, timestep_embedding, - residuals, + down_block_states, encoder_hidden_state, opts ) do - up_block_2d(sample, timestep_embedding, residuals, encoder_hidden_state, opts) + up_block_2d(sample, timestep_embedding, down_block_states, encoder_hidden_state, opts) end def up_block_2d( :up_block, sample, timestep_embedding, - residuals, + down_block_states, _encoder_hidden_state, opts ) do - up_block_2d(sample, timestep_embedding, residuals, nil, opts) + up_block_2d(sample, timestep_embedding, down_block_states, nil, opts) end @doc """ @@ -147,7 +147,7 @@ defmodule Bumblebee.Diffusion.Layers.UNet do def up_block_2d( hidden_state, timestep_embedding, - residuals, + down_block_states, encoder_hidden_state, opts ) do @@ -164,18 +164,18 @@ defmodule Bumblebee.Diffusion.Layers.UNet do add_upsample = Keyword.get(opts, :add_upsample, true) name = opts[:name] - ^depth = length(residuals) + ^depth = length(down_block_states) hidden_state = - for {{residual, residual_channels}, idx} <- Enum.with_index(residuals), + for {{down_block_state, down_block_channels}, idx} <- Enum.with_index(down_block_states), reduce: hidden_state do hidden_state -> in_channels = if(idx == 0, do: in_channels, else: out_channels) hidden_state = - Axon.concatenate([hidden_state, residual], axis: -1) + Axon.concatenate([hidden_state, down_block_state], axis: -1) |> Diffusion.Layers.residual_block( - in_channels + residual_channels, + in_channels + down_block_channels, out_channels, timestep_embedding: timestep_embedding, norm_epsilon: norm_epsilon, diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 61fa31db..5e0ef850 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -458,7 +458,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do "encoder_hidden_state" => text_embeddings } - %{down_blocks_residuals: down_blocks_residuals, mid_block_residual: mid_block_residual} = + %{down_block_states: down_block_states, mid_block_state: mid_block_state} = controlnet_predict.(controlnet_params, controlnet_inputs) unet_inputs = @@ -466,8 +466,8 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do "sample" => Nx.concatenate([latents, latents]), "timestep" => timestep, "encoder_hidden_state" => text_embeddings, - "additional_mid_residual" => mid_block_residual, - "additional_down_residuals" => down_blocks_residuals + "additional_down_block_states" => down_block_states, + "additional_mid_block_state" => mid_block_state } %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs) diff --git a/lib/bumblebee/diffusion/unet_2d_conditional.ex b/lib/bumblebee/diffusion/unet_2d_conditional.ex index 96a02b8e..8045eca9 100644 --- a/lib/bumblebee/diffusion/unet_2d_conditional.ex +++ b/lib/bumblebee/diffusion/unet_2d_conditional.ex @@ -117,6 +117,15 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do The conditioning state (context) to use with cross-attention. + * `"additional_down_block_states"` + + Optional outputs matching the structure of down blocks, added as + part of the encoder-decoder skip connections. + + * `"additional_mid_block_state"` + + Optional output added to the mid block result. + ## Configuration #{Shared.options_doc(options)} @@ -167,8 +176,8 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Axon.input("sample", shape: sample_shape), Axon.input("timestep", shape: {}), Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}), - Axon.input("additional_mid_residual", optional: true), - Axon.input("additional_down_residuals", optional: true) + Axon.input("additional_down_block_states", optional: true), + Axon.input("additional_mid_block_state", optional: true) ]) end @@ -210,22 +219,19 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do name: "input_conv" ) - {sample, down_block_residuals} = + {sample, down_block_states} = down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks") - down_block_residuals = - add_optional_additional_down_residuals( - down_block_residuals, - inputs["additional_down_residuals"] - ) + down_block_states = + maybe_add_down_block_states(down_block_states, inputs["additional_down_block_states"]) sample = sample |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block") - |> add_optional_additional_mid_residual(inputs["additional_mid_residual"]) + |> maybe_add_mid_block_state(inputs["additional_mid_block_state"]) sample - |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec, + |> up_blocks(timestep_embedding, down_block_states, encoder_hidden_state, spec, name: "up_blocks" ) |> Axon.group_norm(spec.group_norm_num_groups, @@ -247,17 +253,17 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)]) in_channels = hd(spec.hidden_sizes) - down_block_residuals = [{sample, in_channels}] + down_block_states = [{sample, in_channels}] - state = {sample, down_block_residuals, in_channels} + state = {sample, down_block_states, in_channels} - {sample, down_block_residuals, _} = + {sample, down_block_states, _} = for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks), reduce: state do - {sample, down_block_residuals, in_channels} -> + {sample, down_block_states, in_channels} -> last_block? = idx == length(spec.hidden_sizes) - 1 - {sample, residuals} = + {sample, states} = Diffusion.Layers.UNet.down_block_2d( block_type, sample, @@ -276,10 +282,10 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do name: join(name, idx) ) - {sample, down_block_residuals ++ Tuple.to_list(residuals), out_channels} + {sample, down_block_states ++ Tuple.to_list(states), out_channels} end - {sample, List.to_tuple(down_block_residuals)} + {sample, List.to_tuple(down_block_states)} end defp mid_block(hidden_state, timesteps_embedding, encoder_hidden_state, spec, opts) do @@ -301,15 +307,15 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do defp up_blocks( sample, timestep_embedding, - down_block_residuals, + down_block_states, encoder_hidden_state, spec, opts ) do name = opts[:name] - down_block_residuals = - down_block_residuals + down_block_states = + down_block_states |> Tuple.to_list() |> Enum.reverse() |> Enum.chunk_every(spec.depth + 1) @@ -327,13 +333,13 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do reversed_hidden_sizes, spec.up_block_types, num_attention_heads_per_block, - down_block_residuals + down_block_states ] |> Enum.zip() |> Enum.with_index() {sample, _} = - for {{out_channels, block_type, num_attention_heads, residuals}, idx} <- blocks_and_chunks, + for {{out_channels, block_type, num_attention_heads, states}, idx} <- blocks_and_chunks, reduce: {sample, in_channels} do {sample, in_channels} -> last_block? = idx == length(spec.hidden_sizes) - 1 @@ -343,7 +349,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do block_type, sample, timestep_embedding, - residuals, + states, encoder_hidden_state, depth: spec.depth + 1, in_channels: in_channels, @@ -363,16 +369,16 @@ defmodule Bumblebee.Diffusion.UNet2DConditional do sample end - defp add_optional_additional_mid_residual(mid_block_residual, additional_mid_block_residual) do - maybe_add(mid_block_residual, additional_mid_block_residual) + defp maybe_add_mid_block_state(mid_block_state, additional_mid_block_state) do + maybe_add(mid_block_state, additional_mid_block_state) end - defp add_optional_additional_down_residuals(down_block_residuals, additional_down_residuals) do - down_residuals = Tuple.to_list(down_block_residuals) + defp maybe_add_down_block_states(down_block_states, additional_down_block_states) do + down_block_states = Tuple.to_list(down_block_states) - for {{down_residual, out_channels}, i} <- Enum.with_index(down_residuals) do - additional_down_residual = Axon.nx(additional_down_residuals, &elem(&1, i)) - {maybe_add(down_residual, additional_down_residual), out_channels} + for {{down_block_state, out_channels}, i} <- Enum.with_index(down_block_states) do + additional_down_block_state = Axon.nx(additional_down_block_states, &elem(&1, i)) + {maybe_add(down_block_state, additional_down_block_state), out_channels} end |> List.to_tuple() end diff --git a/test/bumblebee/diffusion/controlnet_test.exs b/test/bumblebee/diffusion/controlnet_test.exs index 8ea5dbf7..855fbfe9 100644 --- a/test/bumblebee/diffusion/controlnet_test.exs +++ b/test/bumblebee/diffusion/controlnet_test.exs @@ -25,18 +25,18 @@ defmodule Bumblebee.Diffusion.ControlNetTest do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.mid_block_residual) == {1, 8, 8, 1280} + assert Nx.shape(outputs.mid_block_state) == {1, 8, 8, 1280} assert_all_close( - outputs.mid_block_residual[[0, 0, 0, 1..3]], + outputs.mid_block_state[[0, 0, 0, 1..3]], Nx.tensor([-1.2827045917510986, -0.6995724439620972, -0.610561192035675]) ) - first_down_residual = elem(outputs.down_blocks_residuals, 0) - assert Nx.shape(first_down_residual) == {1, 64, 64, 320} + first_down_block_state = elem(outputs.down_block_states, 0) + assert Nx.shape(first_down_block_state) == {1, 64, 64, 320} assert_all_close( - first_down_residual[[0, 0, 0, 1..3]], + first_down_block_state[[0, 0, 0, 1..3]], Nx.tensor([-0.029463158920407295, 0.04885300621390343, -0.12834328413009644]) ) end diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index cc46385e..4d0c412a 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -35,7 +35,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do ) end - test ":base with additional residuals" do + test ":base with additional states for skip connection" do tiny = "bumblebee-testing/tiny-stable-diffusion" assert {:ok, %{model: model, params: params, spec: spec}} = @@ -50,18 +50,18 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do {_, out_shapes} = for block_out_channel <- spec.hidden_sizes, reduce: state do {spatial_size, acc} -> - residuals = + states = for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channel} downsampled_spatial = div(spatial_size, 2) downsample = {1, downsampled_spatial, downsampled_spatial, block_out_channel} - {div(spatial_size, 2), acc ++ residuals ++ [downsample]} + {div(spatial_size, 2), acc ++ states ++ [downsample]} end out_shapes = Enum.drop(out_shapes, -1) - down_residuals = + down_block_states = for shape <- out_shapes do Nx.broadcast(0.5, shape) end @@ -69,15 +69,15 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) mid_dim = List.last(spec.hidden_sizes) - mid_residual_shape = {1, mid_spatial, mid_spatial, mid_dim} + mid_block_shape = {1, mid_spatial, mid_spatial, mid_dim} inputs = %{ "sample" => Nx.broadcast(0.5, {1, spec.sample_size, spec.sample_size, 4}), "timestep" => Nx.tensor(1), "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, spec.cross_attention_size}), - "additional_mid_residual" => Nx.broadcast(0.5, mid_residual_shape), - "additional_down_residuals" => down_residuals + "additional_mid_block_state" => Nx.broadcast(0.5, mid_block_shape), + "additional_down_block_states" => down_block_states } outputs = Axon.predict(model, params, inputs) From cd860defb2f82fc0fd05e883546aaedf6b88ae90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 20:48:31 +0800 Subject: [PATCH 40/42] Use tiny checkpoint and reference outputs in tests --- test/bumblebee/diffusion/controlnet_test.exs | 57 ++++++++++++----- .../stable_diffusion_controlnet_test.exs | 2 +- .../diffusion/unet_2d_conditional_test.exs | 63 ++++++------------- 3 files changed, 60 insertions(+), 62 deletions(-) diff --git a/test/bumblebee/diffusion/controlnet_test.exs b/test/bumblebee/diffusion/controlnet_test.exs index 855fbfe9..ef1512d7 100644 --- a/test/bumblebee/diffusion/controlnet_test.exs +++ b/test/bumblebee/diffusion/controlnet_test.exs @@ -7,37 +7,60 @@ defmodule Bumblebee.Diffusion.ControlNetTest do test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "lllyasviel/sd-controlnet-scribble"}, - module: Bumblebee.Diffusion.ControlNet, - architecture: :base - ) + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) - assert %Bumblebee.Diffusion.ControlNet{ - architecture: :base - } = spec + assert %Bumblebee.Diffusion.ControlNet{architecture: :base} = spec inputs = %{ - "sample" => Nx.broadcast(0.5, {1, 64, 64, 4}), - "conditioning" => Nx.broadcast(0.8, {1, 512, 512, 3}), - "timestep" => Nx.tensor(0), - "encoder_hidden_state" => Nx.broadcast(0.8, {1, 1, 768}) + "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), + "timestep" => Nx.tensor(1), + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), + "conditioning" => Nx.broadcast(0.5, {1, 64, 64, 3}) } outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.mid_block_state) == {1, 8, 8, 1280} + assert Nx.shape(outputs.mid_block_state) == {1, 16, 16, 64} assert_all_close( - outputs.mid_block_state[[0, 0, 0, 1..3]], - Nx.tensor([-1.2827045917510986, -0.6995724439620972, -0.610561192035675]) + outputs.mid_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-0.2818, 1.6207, -0.7002], [0.2391, 1.1387, 0.9682], [-0.6386, 0.7026, -0.4218]], + [[1.0681, 1.8418, -1.0586], [0.9387, 0.5971, 1.2284], [1.2914, 0.4060, -0.9559]], + [[0.5841, 1.2935, 0.0081], [0.7306, 0.2915, 0.7736], [0.0875, 0.9619, 0.4108]] + ] + ]) ) + assert tuple_size(outputs.down_block_states) == 6 + first_down_block_state = elem(outputs.down_block_states, 0) - assert Nx.shape(first_down_block_state) == {1, 64, 64, 320} + assert Nx.shape(first_down_block_state) == {1, 32, 32, 32} + + assert_all_close( + first_down_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-0.1423, 0.2804, -0.0497], [-0.1425, 0.2798, -0.0485], [-0.1426, 0.2794, -0.0488]], + [[-0.1419, 0.2810, -0.0493], [-0.1427, 0.2803, -0.0479], [-0.1427, 0.2800, -0.0486]], + [[-0.1417, 0.2812, -0.0494], [-0.1427, 0.2807, -0.0480], [-0.1426, 0.2804, -0.0486]] + ] + ]) + ) + + last_down_block_state = elem(outputs.down_block_states, 5) + assert Nx.shape(last_down_block_state) == {1, 16, 16, 64} assert_all_close( - first_down_block_state[[0, 0, 0, 1..3]], - Nx.tensor([-0.029463158920407295, 0.04885300621390343, -0.12834328413009644]) + last_down_block_state[[.., 1..3, 1..3, 1..3]], + Nx.tensor([ + [ + [[-1.1169, 0.8087, 0.1024], [0.4832, 0.0686, 1.0149], [-0.3314, 0.1486, 0.4445]], + [[0.5770, 0.3195, -0.2008], [1.5692, -0.1771, 0.7669], [0.4908, 0.1258, 0.0694]], + [[0.4694, -0.3723, 0.1505], [1.7356, -0.4214, 0.8929], [0.4702, 0.2400, 0.1213]] + ] + ]) ) end end diff --git a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs index bdcb906b..5db38905 100644 --- a/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs +++ b/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs @@ -18,7 +18,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) - {:ok, controlnet} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-controlnet"}) + {:ok, controlnet} = Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) diff --git a/test/bumblebee/diffusion/unet_2d_conditional_test.exs b/test/bumblebee/diffusion/unet_2d_conditional_test.exs index 4d0c412a..ce6fc5be 100644 --- a/test/bumblebee/diffusion/unet_2d_conditional_test.exs +++ b/test/bumblebee/diffusion/unet_2d_conditional_test.exs @@ -43,67 +43,42 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec - first = {1, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)} - - state = {spec.sample_size, [first]} - - {_, out_shapes} = - for block_out_channel <- spec.hidden_sizes, reduce: state do - {spatial_size, acc} -> - states = - for _ <- 1..spec.depth, do: {1, spatial_size, spatial_size, block_out_channel} - - downsampled_spatial = div(spatial_size, 2) - downsample = {1, downsampled_spatial, downsampled_spatial, block_out_channel} - - {div(spatial_size, 2), acc ++ states ++ [downsample]} - end - - out_shapes = Enum.drop(out_shapes, -1) - down_block_states = - for shape <- out_shapes do - Nx.broadcast(0.5, shape) - end + [ + {1, 32, 32, 32}, + {1, 32, 32, 32}, + {1, 32, 32, 32}, + {1, 16, 16, 32}, + {1, 16, 16, 64}, + {1, 16, 16, 64} + ] + |> Enum.map(&Nx.broadcast(0.5, &1)) |> List.to_tuple() - mid_spatial = div(spec.sample_size, 2 ** (length(spec.hidden_sizes) - 1)) - mid_dim = List.last(spec.hidden_sizes) - mid_block_shape = {1, mid_spatial, mid_spatial, mid_dim} + mid_block_state = Nx.broadcast(0.5, {1, 16, 16, 64}) inputs = %{ - "sample" => Nx.broadcast(0.5, {1, spec.sample_size, spec.sample_size, 4}), + "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), "timestep" => Nx.tensor(1), - "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, spec.cross_attention_size}), - "additional_mid_block_state" => Nx.broadcast(0.5, mid_block_shape), - "additional_down_block_states" => down_block_states + "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), + "additional_down_block_states" => down_block_states, + "additional_mid_block_state" => mid_block_state } outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.sample) == {1, spec.sample_size, spec.sample_size, spec.in_channels} + assert Nx.shape(outputs.sample) == {1, 32, 32, 4} assert_all_close( outputs.sample[[.., 1..3, 1..3, 1..3]], Nx.tensor([ [ - [-2.1599538326263428, -0.06256292015314102, 1.7675844430923462], - [-0.6707635521888733, -0.6823181509971619, 1.0919926166534424], - [0.16482116281986237, -1.2743796110153198, -0.03096655011177063] - ], - [ - [-1.13632333278656, -1.3499518632888794, 0.597271203994751], - [-1.7593439817428589, -1.599103569984436, -0.1870473176240921], - [-0.9655789136886597, 0.8080697655677795, 1.1974149942398071] - ], - [ - [-1.3559075593948364, -1.177065134048462, -0.5016229152679443], - [-2.5425026416778564, -1.2682275772094727, -0.6805112957954407], - [-1.8208105564117432, 0.9214832186698914, 0.5924324989318848] + [[-0.9457, -0.2378, 1.4223], [-0.5736, -0.2456, 0.7603], [-0.4346, -1.1370, -0.1988]], + [[-0.5274, -1.0902, 0.5937], [-1.2290, -0.7996, 0.0264], [-0.3006, -0.1181, 0.7059]], + [[-0.8336, -1.1615, -0.1906], [-1.0489, -0.3815, -0.5497], [-0.6255, 0.0863, 0.3285]] ] - ]), - atol: 1.0e-4 + ]) ) end end From 361926cf058b034627b7d988d8346ecd518adf58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 21:51:48 +0800 Subject: [PATCH 41/42] Update docs sidebar --- mix.exs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mix.exs b/mix.exs index b36c9748..b0677f3e 100644 --- a/mix.exs +++ b/mix.exs @@ -71,10 +71,12 @@ defmodule Bumblebee.MixProject do Bumblebee.Audio, Bumblebee.Text, Bumblebee.Vision, - Bumblebee.Diffusion.StableDiffusion + Bumblebee.Diffusion.StableDiffusion, + Bumblebee.Diffusion.StableDiffusionControlNet ], Models: [ Bumblebee.Audio.Whisper, + Bumblebee.Diffusion.ControlNet, Bumblebee.Diffusion.StableDiffusion.SafetyChecker, Bumblebee.Diffusion.UNet2DConditional, Bumblebee.Diffusion.VaeKl, From cba15011ea1487d83a1f0841f88c7243762f8cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 8 Apr 2024 21:55:33 +0800 Subject: [PATCH 42/42] Up --- lib/bumblebee/diffusion/controlnet.ex | 12 ++++-------- .../diffusion/stable_diffusion_controlnet.ex | 11 ++++++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/bumblebee/diffusion/controlnet.ex b/lib/bumblebee/diffusion/controlnet.ex index ba97ccd8..0e1ff85f 100644 --- a/lib/bumblebee/diffusion/controlnet.ex +++ b/lib/bumblebee/diffusion/controlnet.ex @@ -292,23 +292,19 @@ defmodule Bumblebee.Diffusion.ControlNet do activation: :silu ) - block_in_channels = Enum.drop(spec.conditioning_embedding_hidden_sizes, -1) - block_out_channels = Enum.drop(spec.conditioning_embedding_hidden_sizes, 1) - - channels = Enum.zip(block_in_channels, block_out_channels) + size_pairs = Enum.chunk_every(spec.conditioning_embedding_hidden_sizes, 2, 1) sample = - for {{in_channels, out_channels}, i} <- Enum.with_index(channels), - reduce: state do + for {[in_size, out_size], i} <- Enum.with_index(size_pairs), reduce: state do input -> input - |> Axon.conv(in_channels, + |> Axon.conv(in_size, kernel_size: 3, padding: [{1, 1}, {1, 1}], name: name |> join("inner_convs") |> join(2 * i), activation: :silu ) - |> Axon.conv(out_channels, + |> Axon.conv(out_size, kernel_size: 3, padding: [{1, 1}, {1, 1}], strides: 2, diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index 5e0ef850..fee08e9f 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -107,9 +107,12 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do prompt = "numbat in forest, detailed, digital art" + # The conditioning image matching the given ControlNet condition, + # such as edges, pose or depth. Here we use a simple handcrafted + # tensor conditioning = Nx.tensor( - [for(_ <- 1..8, do: [255]) ++ for(_ <- 1..24, do: [0])], + [List.duplicate(255, 8) ++ List.duplicate(0, 24)], type: :u8 ) |> Nx.tile([256, 8, 3]) @@ -450,10 +453,12 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do {scheduler_state, text_embeddings, unet_params, conditioning, conditioning_scale, controlnet_params}}, timestep <- timesteps do + sample = Nx.concatenate([latents, latents]) + controlnet_inputs = %{ "conditioning" => conditioning, "conditioning_scale" => conditioning_scale, - "sample" => Nx.concatenate([latents, latents]), + "sample" => sample, "timestep" => timestep, "encoder_hidden_state" => text_embeddings } @@ -463,7 +468,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do unet_inputs = %{ - "sample" => Nx.concatenate([latents, latents]), + "sample" => sample, "timestep" => timestep, "encoder_hidden_state" => text_embeddings, "additional_down_block_states" => down_block_states,