diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 14ea7301..9499e518 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -437,6 +437,7 @@ def prepare_all( vmfbs: dict = {}, weights: dict = {}, interactive: bool = False, + num_steps: int = 20, ): ready = self.is_prepared(vmfbs, weights) match ready: @@ -463,7 +464,7 @@ def prepare_all( if not self.map[submodel].get("weights") and self.map[submodel][ "export_args" ].get("external_weights"): - self.export_submodel(submodel, weights_only=True) + self.export_submodel(submodel, weights_only=True, num_steps=num_steps) return self.prepare_all(mlirs, vmfbs, weights, interactive) def is_prepared(self, vmfbs, weights): @@ -581,6 +582,7 @@ def export_submodel( submodel: str, input_mlir: str = None, weights_only: bool = False, + num_steps: int = 20, ): if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) @@ -671,7 +673,8 @@ def export_submodel( self.map[submodel]["export_args"]["batch_size"], self.map[submodel]["export_args"]["max_length"], "produce_img_split", - unet_module_name = self.map["unet"]["module_name"], + unet_module_name=self.map["unet"]["module_name"], + num_steps=num_steps, ) dims = [ self.map[submodel]["export_args"]["width"], @@ -722,15 +725,24 @@ def export_submodel( # LOAD def load_map(self): - for submodel in self.map.keys(): + # Make sure fullpipeline is imported last + submodels = list(self.map.keys() - {"fullpipeline"}) + submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else [] + for submodel in submodels: if not self.map[submodel]["load"]: self.printer.print(f"Skipping load for {submodel}") continue elif self.map[submodel].get("wraps"): + vmfbs = [] + weights = [] for wrapped in self.map[submodel]["wraps"]: - self.map[submodel]["vmfb"].append(self.map[wrapped]["vmfb"]) - self.map[submodel]["weights"].append(self.map[wrapped]["weights"]) + vmfbs.append(self.map[wrapped]["vmfb"]) + if "weights" in self.map[wrapped]: + weights.append(self.map[wrapped]["weights"]) + self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"] + self.map[submodel]["weights"] = weights + self.map[submodel]["weights"] + print(f"Loading {submodel}") self.load_submodel(submodel) def load_submodel(self, submodel): diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index fa8009cb..bd8369bf 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -465,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt): text_input_ids_list += text_inputs.input_ids.unsqueeze(0) uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0) - if self.compiled_pipeline: - return text_input_ids_list, uncond_input_ids_list - else: - prompt_embeds, add_text_embeds = self.text_encoder( - "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] - ) - return prompt_embeds, add_text_embeds + prompt_embeds, add_text_embeds = self.text_encoder( + "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] + ) + return prompt_embeds, add_text_embeds def prepare_latents( self, @@ -615,6 +612,7 @@ def _produce_latents_sdxl( return latents def produce_images_compiled( + self, sample, prompt_embeds, text_embeds, @@ -624,9 +622,11 @@ def produce_images_compiled( sample, prompt_embeds, text_embeds, - guidance_scale, + torch.as_tensor([guidance_scale], dtype=sample.dtype), ] - image = self.compiled_pipeline("produce_img_latents", pipe_inputs) + #image = self.compiled_pipeline("produce_img_latents", pipe_inputs) + image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs) + return image def prepare_sampling_inputs( self, @@ -726,12 +726,7 @@ def generate_images( for i in range(batch_count): if self.compiled_pipeline: - image = produce_images_compiled( - samples[i], - prompt_embeds, - negative_embeds, - guidance_scale - ) + image = self.produce_images_compiled(samples[i], prompt_embeds, negative_embeds, guidance_scale).to_host() else: produce_latents_input = [ samples[i], @@ -833,7 +828,7 @@ def numpy_to_pil_image(images): False, args.compiled_pipeline, ) - sd_pipe.prepare_all() + sd_pipe.prepare_all(num_steps=args.num_inference_steps) sd_pipe.load_map() sd_pipe.generate_images( args.prompt, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 448b6791..78f20a12 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -36,13 +36,13 @@ ], "unet": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "clip": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", ], "vae": [ @@ -61,7 +61,7 @@ "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", diff --git a/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py index 2a8cc4ff..f0cec20b 100644 --- a/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py +++ b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py @@ -38,7 +38,7 @@ %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ %step_64 = arith.index_cast %arg0 : index to i64 %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> }} return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}> @@ -48,28 +48,27 @@ produce_img_split = r""" module @sdxl_compiled_pipeline {{ - func.func private @{scheduler_module}.run_initialize(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],f16>, !torch.vtensor<[{num_steps}],f32>) attributes {{torch.assume_strict_symbolic_shapes}} - func.func private @{scheduler_module}.run_scale(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>) attributes {{torch.assume_strict_symbolic_shapes}} - func.func private @{scheduler_module}.run_step(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}} - func.func private @{unet_module}.{unet_function}(%arg0: !torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %arg3: !torch.vtensor<[{bd},1280],{precision}>, %arg4: !torch.vtensor<[{bd},6],{precision}>, %arg5: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}} - func.func private @{vae_module}.decode(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}} - - func.func @produce_image_latents(%sample: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %p_embeds: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %t_embeds: !torch.vtensor<[{bd},1280],{precision}>, %guidance_scale: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> {{ - %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{num_steps}],f32>) + func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + + func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{ + %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<{num_steps}xf32>) %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %n_steps = arith.constant {num_steps} : index - %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) {{ + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{ %step_64 = arith.index_cast %arg0 : index to i64 %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %step_torch = torch_c.from_builtin_tensor %this_step : tensor<1xi64> -> !torch.vtensor<[1],si64> - %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %step_torch, %timesteps) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],si64>, !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>) - %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{bd},{max_length},2048],{precision}>, !torch.vtensor<[{bd},1280],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> - %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> - scf.yield %pred : !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> + %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) + %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> + %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> + scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}> }} - %image = func.call @{vae_module}.decode(%res): (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> - return %image : !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> + %image = func.call @{vae_module}.decode(%res): (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> + return %image : tensor<{batch_size}x3x{height}x{width}x{precision}> }} }} """ @@ -128,4 +127,4 @@ def get_pipeline_ir( scheduler_module=scheduler_module_name, vae_module=vae_module_name, num_steps=num_steps, - ) \ No newline at end of file + ) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py index ec88c525..3f0aaf7e 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -480,6 +480,7 @@ def export_submodel( self.hf_model_name, None, self.max_length, + self.batch_size, self.precision, "vmfb", self.external_weights, @@ -494,7 +495,6 @@ def export_submodel( input_mlir=input_mlir["prompt_encoder"], attn_spec=self.attn_spec, weights_only=weights_only, - batchsize=self.batch_size, batch_input=self.batch_prompt_input, ) return prompt_encoder_vmfb, prompt_encoder_external_weight_path diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 40ce6c2e..8f7668d5 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -231,7 +231,7 @@ def export_prompt_encoder( ) if weights_only: - return None, external_weight_path + return external_weight_path class CompiledClip(CompiledModule): if external_weights: @@ -277,7 +277,7 @@ def encode_prompts_turbo( module_str = str(module) if compile_to != "vmfb": - return module_str + return module_str, None else: vmfb_path = utils.compile_to_vmfb( module_str,