From 6f6585b6da94a3e0eeb6bd0b0a48912a397edd6e Mon Sep 17 00:00:00 2001 From: George Petterson Date: Mon, 19 Aug 2024 11:25:18 -0500 Subject: [PATCH] Fix an issue with hardcoded iterations --- models/turbine_models/custom_models/pipeline_base.py | 5 ++++- .../turbine_models/custom_models/sd_inference/sd_pipeline.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 3d6f9ea16..9499e5184 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -437,6 +437,7 @@ def prepare_all( vmfbs: dict = {}, weights: dict = {}, interactive: bool = False, + num_steps: int = 20, ): ready = self.is_prepared(vmfbs, weights) match ready: @@ -463,7 +464,7 @@ def prepare_all( if not self.map[submodel].get("weights") and self.map[submodel][ "export_args" ].get("external_weights"): - self.export_submodel(submodel, weights_only=True) + self.export_submodel(submodel, weights_only=True, num_steps=num_steps) return self.prepare_all(mlirs, vmfbs, weights, interactive) def is_prepared(self, vmfbs, weights): @@ -581,6 +582,7 @@ def export_submodel( submodel: str, input_mlir: str = None, weights_only: bool = False, + num_steps: int = 20, ): if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) @@ -672,6 +674,7 @@ def export_submodel( self.map[submodel]["export_args"]["max_length"], "produce_img_split", unet_module_name=self.map["unet"]["module_name"], + num_steps=num_steps, ) dims = [ self.map[submodel]["export_args"]["width"], diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 74e871b67..c74b877b2 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -831,7 +831,7 @@ def numpy_to_pil_image(images): False, args.compiled_pipeline, ) - sd_pipe.prepare_all() + sd_pipe.prepare_all(num_steps=args.num_inference_steps) sd_pipe.load_map() sd_pipe.generate_images( args.prompt,