Skip to content

Commit 27e959e

Browse files
Add stable diffusion 3.5 medium model (#2033)
1 parent 96b8685 commit 27e959e

8 files changed

+348
-65
lines changed

keras_hub/src/models/stable_diffusion_3/mmdit.py

Lines changed: 254 additions & 58 deletions
Large diffs are not rendered by default.

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ class StableDiffusion3Backbone(Backbone):
205205
mmdit_qk_norm: Optional str. Whether to normalize the query and key
206206
tensors for each transformer in MMDiT. Available options are `None`
207207
and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
208-
and to `"rms_norm" for 3.5 version.
208+
and to `"rms_norm"` for 3.5 version.
209+
mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
210+
the blocks that serve as dual attention blocks. Typically, this is
211+
used for the 3.5 version. Defaults to `None`.
209212
vae: The VAE used for transformations between pixel space and latent
210213
space.
211214
clip_l: The CLIP text encoder for encoding the inputs.
@@ -253,6 +256,7 @@ class StableDiffusion3Backbone(Backbone):
253256
mmdit_depth=4,
254257
mmdit_position_size=192,
255258
mmdit_qk_norm=None,
259+
mmdit_dual_attention_indices=None,
256260
vae=vae,
257261
clip_l=clip_l,
258262
clip_g=clip_g,
@@ -268,6 +272,7 @@ def __init__(
268272
mmdit_num_heads,
269273
mmdit_position_size,
270274
mmdit_qk_norm,
275+
mmdit_dual_attention_indices,
271276
vae,
272277
clip_l,
273278
clip_g,
@@ -319,6 +324,7 @@ def __init__(
319324
context_shape=context_shape,
320325
pooled_projection_shape=pooled_projection_shape,
321326
qk_norm=mmdit_qk_norm,
327+
dual_attention_indices=mmdit_dual_attention_indices,
322328
data_format=data_format,
323329
dtype=dtype,
324330
name="diffuser",
@@ -454,6 +460,7 @@ def __init__(
454460
self.mmdit_num_heads = mmdit_num_heads
455461
self.mmdit_position_size = mmdit_position_size
456462
self.mmdit_qk_norm = mmdit_qk_norm
463+
self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
457464
self.latent_channels = latent_channels
458465
self.output_channels = output_channels
459466
self.num_train_timesteps = num_train_timesteps
@@ -590,6 +597,9 @@ def get_config(self):
590597
"mmdit_num_heads": self.mmdit_num_heads,
591598
"mmdit_position_size": self.mmdit_position_size,
592599
"mmdit_qk_norm": self.mmdit_qk_norm,
600+
"mmdit_dual_attention_indices": (
601+
self.mmdit_dual_attention_indices
602+
),
593603
"vae": layers.serialize(self.vae),
594604
"clip_l": layers.serialize(self.clip_l),
595605
"clip_g": layers.serialize(self.clip_g),
@@ -638,7 +648,10 @@ def from_config(cls, config, custom_objects=None):
638648
)
639649

640650
# To maintain backward compatibility, we need to ensure that
641-
# `mmdit_qk_norm` is included in the config.
651+
# `mmdit_qk_norm` and `mmdit_dual_attention_indices` are included in the
652+
# config.
642653
if "mmdit_qk_norm" not in config:
643654
config["mmdit_qk_norm"] = None
655+
if "mmdit_dual_attention_indices" not in config:
656+
config["mmdit_dual_attention_indices"] = None
644657
return cls(**config)

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def setUp(self):
3535
"mmdit_num_heads": 2,
3636
"mmdit_position_size": 192,
3737
"mmdit_qk_norm": None,
38+
"mmdit_dual_attention_indices": None,
3839
"vae": vae,
3940
"clip_l": clip_l,
4041
"clip_g": clip_g,
@@ -67,10 +68,15 @@ def test_backbone_basics(self):
6768
run_quantization_check=False,
6869
)
6970

70-
# Test `mmdit_qk_norm="rms_norm"`.
71+
def test_backbone_basics_mmditx(self):
72+
# MMDiT-X includes `mmdit_qk_norm` and `mmdit_dual_attention_indices`.
7173
self.run_backbone_test(
7274
cls=StableDiffusion3Backbone,
73-
init_kwargs={**self.init_kwargs, "mmdit_qk_norm": "rms_norm"},
75+
init_kwargs={
76+
**self.init_kwargs,
77+
"mmdit_qk_norm": "rms_norm",
78+
"mmdit_dual_attention_indices": (0,),
79+
},
7480
input_data=self.input_data,
7581
expected_output_shape={
7682
"images": (2, 64, 64, 3),

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def setUp(self):
4141
mmdit_num_heads=2,
4242
mmdit_position_size=192,
4343
mmdit_qk_norm=None,
44+
mmdit_dual_attention_indices=None,
4445
vae=VAEBackbone(
4546
[32, 32, 32, 32],
4647
[1, 1, 1, 1],

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def setUp(self):
4141
mmdit_num_heads=2,
4242
mmdit_position_size=192,
4343
mmdit_qk_norm=None,
44+
mmdit_dual_attention_indices=None,
4445
vae=VAEBackbone(
4546
[32, 32, 32, 32],
4647
[1, 1, 1, 1],

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@
1313
},
1414
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
1515
},
16+
"stable_diffusion_3.5_medium": {
17+
"metadata": {
18+
"description": (
19+
"3 billion parameter, including CLIP L and CLIP G text "
20+
"encoders, MMDiT-X generative model, and VAE autoencoder. "
21+
"Developed by Stability AI."
22+
),
23+
"params": 3371793763,
24+
"path": "stable_diffusion_3",
25+
},
26+
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
27+
},
1628
"stable_diffusion_3.5_large": {
1729
"metadata": {
1830
"description": (

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def setUp(self):
4141
mmdit_num_heads=2,
4242
mmdit_position_size=192,
4343
mmdit_qk_norm=None,
44+
mmdit_dual_attention_indices=None,
4445
vae=VAEBackbone(
4546
[32, 32, 32, 32],
4647
[1, 1, 1, 1],

tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \
77
--preset stable_diffusion_3_medium \
88
--upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium
9+
python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \
10+
--preset stable_diffusion_3.5_medium \
11+
--upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_medium \
12+
--dtype bfloat16
913
python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \
1014
--preset stable_diffusion_3.5_large \
1115
--upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_large \
@@ -56,6 +60,17 @@
5660
# Tokenizer
5761
"clip_tokenizer": "hf://openai/clip-vit-large-patch14",
5862
},
63+
"stable_diffusion_3.5_medium": {
64+
# HF root
65+
"root": "hf://stabilityai/stable-diffusion-3.5-medium",
66+
# Model <-> Path
67+
"clip_l": "text_encoder/model.safetensors",
68+
"clip_g": "text_encoder_2/model.safetensors",
69+
"diffuser": "sd3.5_medium.safetensors",
70+
"vae": "sd3.5_medium.safetensors",
71+
# Tokenizer
72+
"clip_tokenizer": "hf://openai/clip-vit-large-patch14",
73+
},
5974
"stable_diffusion_3.5_large": {
6075
# HF root
6176
"root": "hf://stabilityai/stable-diffusion-3.5-large",
@@ -148,11 +163,27 @@ def convert_model(preset, height, width):
148163
24,
149164
192,
150165
None, # qk_norm
166+
None, # dual_attention_indices
167+
vae,
168+
clip_l,
169+
clip_g,
170+
image_shape=(height, width, 3),
171+
name="stable_diffusion_3_medium_backbone",
172+
)
173+
elif preset == "stable_diffusion_3.5_medium":
174+
backbone = StableDiffusion3Backbone(
175+
2,
176+
64 * 24,
177+
24,
178+
24,
179+
384, # position_size is larger than SD3
180+
"rms_norm", # qk_norm
181+
(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), # dual_attn_indices
151182
vae,
152183
clip_l,
153184
clip_g,
154185
image_shape=(height, width, 3),
155-
name="stable_diffusion_3_backbone",
186+
name="stable_diffusion_3.5_medium_backbone",
156187
)
157188
elif preset in (
158189
"stable_diffusion_3.5_large",
@@ -165,11 +196,12 @@ def convert_model(preset, height, width):
165196
38,
166197
192,
167198
"rms_norm", # qk_norm
199+
None, # dual_attention_indices
168200
vae,
169201
clip_l,
170202
clip_g,
171203
image_shape=(height, width, 3),
172-
name="stable_diffusion_3.5_backbone",
204+
name="stable_diffusion_3.5_large_backbone",
173205
)
174206
else:
175207
raise ValueError(f"Unknown preset={preset}.")
@@ -418,6 +450,24 @@ def port_diffuser(preset, filename, model):
418450
port_dense(loader, block.mlp.dense1, f"{prefix}.mlp.fc1")
419451
port_dense(loader, block.mlp.dense2, f"{prefix}.mlp.fc2")
420452

453+
# Dual attention
454+
if block.use_dual_attention:
455+
port_dense(
456+
loader, block.attention_qkv2, f"{prefix}.attn2.qkv"
457+
)
458+
if block.qk_norm is not None:
459+
port_ln_or_gn(
460+
loader, block.q_norm2, f"{prefix}.attn2.ln_q"
461+
)
462+
port_ln_or_gn(
463+
loader, block.k_norm2, f"{prefix}.attn2.ln_k"
464+
)
465+
port_dense(
466+
loader,
467+
block.attention_proj2,
468+
f"{prefix}.attn2.proj",
469+
)
470+
421471
# Output layer
422472
port_dense(
423473
loader,
@@ -562,7 +612,10 @@ def validate_output(preset, keras_model, keras_preprocessor, output_dir):
562612
if preset == "stable_diffusion_3_medium":
563613
num_steps = 28
564614
guidance_scale = 7.0
565-
elif preset == "stable_diffusion_3.5_large":
615+
elif preset in (
616+
"stable_diffusion_3.5_medium",
617+
"stable_diffusion_3.5_large",
618+
):
566619
num_steps = 40
567620
guidance_scale = 4.5
568621
elif preset == "stable_diffusion_3.5_large_turbo":

0 commit comments

Comments (0)