Commit 711403c

SD3 updates, CLI arguments for multi-device
1 parent 493f260 commit 711403c

File tree

7 files changed: +226 −34 lines changed

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 7 additions & 0 deletions

@@ -101,6 +101,7 @@ def __init__(
         self.output_counter = 0
         self.dest_type = dest_type
         self.dest_dtype = dest_dtype
+        self.validate = False

     def load(
         self,
@@ -252,6 +253,10 @@ def __call__(self, function_name, inputs: list):
         if not isinstance(inputs, list):
            inputs = [inputs]
         inputs = self._validate_or_convert_inputs(function_name, inputs)
+
+        if self.validate:
+            self.save_torch_inputs(inputs)
+
         if self.benchmark:
            output = self._run_and_benchmark(function_name, inputs)
         else:
@@ -261,6 +266,8 @@ def __call__(self, function_name, inputs: list):
         output = self._output_cast(output)
         return output

+    # def _run_and_validate(self, iree_fn, torch_fn, inputs: list)
+

 class Printer:
     def __init__(self, verbose, start_time, print_time):
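The commented-out `_run_and_validate` stub pairs with the new `self.validate` flag and the `save_torch_inputs` call in `__call__`. A minimal sketch of how such a method could be filled in, assuming an eager torch reference function is available and the IREE function returns a single tensor-like output (the body below is illustrative, not part of the commit):

```python
import numpy as np
import torch


def _run_and_validate(self, iree_fn, torch_fn, inputs: list, rtol=1e-2, atol=1e-2):
    # Run the compiled IREE function and the eager torch reference on the
    # same (already converted) inputs, then assert the outputs agree within
    # a loose tolerance. Assumes a single output; multi-output functions
    # would need to iterate over the result list.
    iree_output = iree_fn(*inputs)
    with torch.no_grad():
        torch_output = torch_fn(*[torch.as_tensor(np.asarray(x)) for x in inputs])
    np.testing.assert_allclose(
        np.asarray(iree_output),
        torch_output.detach().cpu().numpy(),
        rtol=rtol,
        atol=atol,
    )
    return iree_output
```

Feeding both paths the same post-`_validate_or_convert_inputs` list keeps the comparison consistent with what `__call__` actually dispatches.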
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+from diffusers import StableDiffusion3Pipeline
+import torch
+from datetime import datetime as dt
+
+
+def run_diffusers_cpu(
+    hf_model_name,
+    prompt,
+    negative_prompt,
+    guidance_scale,
+    seed,
+    height,
+    width,
+    num_inference_steps,
+):
+    from diffusers import StableDiffusion3Pipeline
+
+    pipe = StableDiffusion3Pipeline.from_pretrained(
+        hf_model_name, torch_dtype=torch.float32
+    )
+    pipe = pipe.to("cpu")
+    generator = torch.Generator().manual_seed(int(seed))
+
+    image = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        height=height,
+        width=width,
+        generator=generator,
+    ).images[0]
+    timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
+    image.save(f"diffusers_reference_output_{timestamp}.png")
+
+
+if __name__ == "__main__":
+    from turbine_models.custom_models.sd_inference.sd_cmd_opts import args
+
+    run_diffusers_cpu(
+        args.hf_model_name,
+        args.prompt,
+        args.negative_prompt,
+        args.guidance_scale,
+        args.seed,
+        args.height,
+        args.width,
+        args.num_inference_steps,
+    )
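This new script is a CPU-only diffusers baseline for checking turbine SD3 outputs against a known-good reference. Since it reads the shared `sd_cmd_opts` arguments, it can also be called directly from Python; a sketch with placeholder values (the model id and prompt below are not from the commit, and the new file's module path is not shown in this diff):

```python
# Hypothetical direct call into the reference entry point; adjust the import
# to wherever this new file lands in your checkout.
run_diffusers_cpu(
    hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder id
    prompt="a photo of a corgi wearing sunglasses",  # placeholder prompt
    negative_prompt="",
    guidance_scale=7.0,
    seed=42,
    height=512,
    width=512,
    num_inference_steps=28,
)
```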

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 5 additions & 4 deletions

@@ -83,8 +83,8 @@ def prepare_model_input(self, sample, t, timesteps):
         latent_model_input = sample
         return latent_model_input.type(self.dtype), t.type(self.dtype)

-    def step(self, noise_pred, t, sample, guidance_scale):
-        self.model._step_index = self.index_for_timestep(t)
+    def step(self, noise_pred, t, sample, guidance_scale, i):
+        self.model._step_index = i

         if self.do_classifier_free_guidance:
             noise_preds = noise_pred.chunk(2)
@@ -299,6 +299,7 @@ def export_scheduler_model(
         torch.empty(1, dtype=dtype),
         torch.empty(sample, dtype=dtype),
         torch.empty(1, dtype=dtype),
+        torch.empty([1], dtype=torch.int64),
     ]

     fxb = FxProgramsBuilder(scheduler_module)
@@ -361,8 +362,8 @@ class CompiledScheduler(CompiledModule):
     }
     model_metadata_run_step = {
         "model_name": "sd3_scheduler_FlowEulerDiscrete",
-        "input_shapes": [noise_pred_shape, (1,), sample, (1,)],
-        "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype],
+        "input_shapes": [noise_pred_shape, (1,), sample, (1,), (1,)],
+        "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype, "int64"],
         "output_shapes": [sample],
         "output_dtypes": [np_dtype],
     }
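Passing the loop index `i` straight into `step` removes the `index_for_timestep` search from the exported graph and matches the extra `(1,)` int64 input registered in the run_step metadata above. A sketch of the corresponding call site, assuming a conventional denoising loop over precomputed timesteps (`scheduler`, `mmdit`, `sample`, `timesteps`, and `guidance_scale` are assumed to be set up elsewhere in the pipeline):

```python
# Illustrative denoising loop against the updated step() signature.
for i, t in enumerate(timesteps):
    latent_model_input, t_cast = scheduler.prepare_model_input(sample, t, timesteps)
    noise_pred = mmdit(latent_model_input, t_cast)
    # The loop index is now passed through explicitly as the scheduler's
    # step index (exported as a (1,) int64 input) instead of being
    # recovered from the timestep value inside the scheduler.
    sample = scheduler.step(noise_pred, t, sample, guidance_scale, i)
```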

models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py

Lines changed: 1 addition & 5 deletions

@@ -54,7 +54,6 @@ class TextEncoderModule(torch.nn.Module):
     @torch.no_grad()
     def __init__(
         self,
-        batch_size=1,
     ):
         super().__init__()
         self.dtype = torch.float16
@@ -89,7 +88,6 @@ def __init__(
             load_into(f, self.t5xxl.transformer, "", "cpu", self.dtype)

         self.do_classifier_free_guidance = True
-        self.batch_size = batch_size

     def get_cond(self, tokens_l, tokens_g, tokens_t5xxl):
         l_out, l_pooled = self.clip_l.forward(tokens_l)
@@ -152,9 +150,7 @@ def export_text_encoders(
             attn_spec=attn_spec,
         )
         return vmfb_path
-    model = TextEncoderModule(
-        batch_size=batch_size,
-    )
+    model = TextEncoderModule(hf_model_name)
     mapper = {}

     assert (

models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py

Lines changed: 105 additions & 17 deletions

@@ -177,10 +177,43 @@ def is_valid_file(arg):
     default="fp16",
     help="Precision of Stable Diffusion weights and graph.",
 )
+
+p.add_argument(
+    "--clip_precision",
+    type=str,
+    default=None,
+    help="Precision of CLIP weights and graph.",
+)
+p.add_argument(
+    "--unet_precision",
+    type=str,
+    default=None,
+    help="Precision of UNet weights and graph.",
+)
+p.add_argument(
+    "--mmdit_precision",
+    type=str,
+    default=None,
+    help="Precision of MMDiT weights and graph.",
+)
+p.add_argument(
+    "--vae_precision",
+    type=str,
+    default=None,
+    help="Precision of VAE weights and graph.",
+)
+
 p.add_argument(
     "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion"
 )

+p.add_argument(
+    "--decomp_attn",
+    default=False,
+    action="store_true",
+    help="Decompose attention at fx graph level",
+)
+
 p.add_argument(
     "--clip_decomp_attn",
     action="store_true",
@@ -205,12 +238,6 @@ def is_valid_file(arg):
     help="Decompose attention for unet only at fx graph level",
 )

-p.add_argument(
-    "--decomp_attn",
-    default=False,
-    action="store_true",
-    help="Decompose attention at fx graph level",
-)

 p.add_argument(
     "--use_i8_punet",
@@ -270,21 +297,81 @@ def is_valid_file(arg):
 ##############################################################################
 # IREE Compiler Options
 ##############################################################################

-p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
-
 p.add_argument(
-    "--rt_device",
+    "--device",
     type=str,
     default="local-task",
     help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
 )

+p.add_argument(
+    "--clip_device",
+    type=str,
+    default=None,
+    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
+)
+p.add_argument(
+    "--unet_device",
+    type=str,
+    default=None,
+    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
+)
+p.add_argument(
+    "--mmdit_device",
+    type=str,
+    default=None,
+    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
+)
+p.add_argument(
+    "--vae_device",
+    type=str,
+    default=None,
+    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
+)
+p.add_argument(
+    "--scheduler_device",
+    type=str,
+    default=None,
+    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
+)
+
 # TODO: Bring in detection for target triple
 p.add_argument(
     "--iree_target_triple",
     type=str,
     default="x86_64-linux-gnu",
-    help="Specify vulkan target triple or rocm/cuda target device.",
+    help="Specify vulkan target triple or rocm/cuda target chip.",
+)
+
+p.add_argument(
+    "--clip_target",
+    type=str,
+    default=None,
+    help="Specify vulkan target triple or rocm/cuda target chip.",
+)
+p.add_argument(
+    "--unet_target",
+    type=str,
+    default=None,
+    help="Specify vulkan target triple or rocm/cuda target chip.",
+)
+p.add_argument(
+    "--mmdit_target",
+    type=str,
+    default=None,
+    help="Specify vulkan target triple or rocm/cuda target chip.",
+)
+p.add_argument(
+    "--vae_target",
+    type=str,
+    default=None,
+    help="Specify vulkan target triple or rocm/cuda target chip.",
+)
+p.add_argument(
+    "--scheduler_target",
+    type=str,
+    default=None,
+    help="Specify vulkan target triple or rocm/cuda target chip.",
 )

 p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options")
@@ -296,13 +383,6 @@ def is_valid_file(arg):
     help="extra iree-compile options for models with iree_linalg_ext.attention ops.",
 )

-p.add_argument(
-    "--attn_spec",
-    type=str,
-    default=None,
-    help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.",
-)
-
 p.add_argument(
     "--clip_flags",
     type=str,
@@ -331,4 +411,12 @@ def is_valid_file(arg):
     help="extra iree-compile options to send for compiling mmdit. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
 )

+p.add_argument(
+    "--attn_spec",
+    type=str,
+    default=None,
+    help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.",
+)
+
 args, unknown = p.parse_known_args()
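All of the new per-submodel options default to `None`, which implies a fallback pattern: each submodel uses its own device/target/precision when set and otherwise inherits the global `--device`, `--iree_target_triple`, or `--precision`. A minimal sketch of that resolution logic (the helper and the `sd3_pipeline.py` entry point in the comment are hypothetical; only the flag names come from this commit):

```python
# Hypothetical resolution of per-submodel overrides against global defaults.
def resolve_submodel_option(args, submodel: str, option: str, global_value):
    # e.g. resolve_submodel_option(args, "vae", "device", args.device)
    # returns args.vae_device if the user set it, else the global --device.
    override = getattr(args, f"{submodel}_{option}", None)
    return override if override is not None else global_value


# Usage sketch: run the VAE on a second GPU while everything else follows
# --device, e.g.  python sd3_pipeline.py --device=rocm://0 --vae_device=rocm://1
vae_device = resolve_submodel_option(args, "vae", "device", args.device)
mmdit_target = resolve_submodel_option(args, "mmdit", "target", args.iree_target_triple)
```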
