From 087c0904eadcb0268adbc2a8e7002704409d902a Mon Sep 17 00:00:00 2001 From: saienduri Date: Thu, 18 Jul 2024 19:28:31 -0700 Subject: [PATCH 1/4] changes so no external downloads --- .../custom_models/sd_inference/sd_pipeline.py | 8 ++++--- .../custom_models/sdxl_inference/unet.py | 24 +++++++++++++------ .../custom_models/sdxl_inference/vae.py | 7 +++--- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 256bcd8ee..1e8151ad3 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -234,6 +234,8 @@ def __init__( benchmark: bool | dict[bool] = False, verbose: bool = False, batch_prompts: bool = False, + punet_quant_paths: dict[str] = None, + vae_weight_path: str = None, ): common_export_args = { "hf_model_name": None, @@ -304,6 +306,7 @@ def __init__( self.cpu_scheduling = cpu_scheduling self.scheduler_id = scheduler_id self.num_inference_steps = num_inference_steps + self.punet_quant_paths = punet_quant_paths self.text_encoder = None self.unet = None @@ -340,6 +343,7 @@ def __init__( self.scheduler_device = self.map["unet"]["device"] self.scheduler_driver = self.map["unet"]["driver"] self.scheduler_target = self.map["unet"]["target"] + self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -361,9 +365,7 @@ def __init__( def setup_punet(self): if self.use_i8_punet: self.map["unet"]["export_args"]["precision"] = "i8" - self.map["unet"]["export_args"]["external_weight_path"] = ( - utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" - ) + self.map["unet"]["export_args"]["external_weight_path"] = self.punet_weight_path for idx, word in enumerate(self.map["unet"]["keywords"]): if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 38f743066..6cc6effac 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -82,7 +82,7 @@ def forward( return noise_pred -def get_punet_model(hf_model_name, external_weight_path, precision="i8"): +def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision="i8"): from sharktank.models.punet.model import ( Unet2DConditionModel as sharktank_unet2d, ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel, @@ -103,14 +103,23 @@ def download(filename): repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision ) - results = { - "config.json": download("config.json"), - "params.safetensors": download("params.safetensors"), - } + if not quant_paths: + results = { + "config.json": download("config.json"), + "params.safetensors": download("params.safetensors"), + } + else: + results = { + "config.json": quant_paths["config"], + "params.safetensors": quant_paths["params"], + } output_dir = os.path.dirname(external_weight_path) if precision == "i8": - results["quant_params.json"] = download("quant_params.json") + if quant_paths: + results["quant_params.json"] = quant_paths["quant_params"] + else: + results["quant_params.json"] = download("quant_params.json") ds_filename = os.path.basename(external_weight_path) 
output_path = os.path.join(output_dir, ds_filename) ds = get_punet_dataset( @@ -177,6 +186,7 @@ def export_unet_model( input_mlir=None, weights_only=False, use_punet=False, + quant_paths=None, ): if use_punet: submodel_name = "punet" @@ -213,7 +223,7 @@ def export_unet_model( ) return vmfb_path elif use_punet: - unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + unet_model = get_punet_model(hf_model_name, external_weight_path, quant_paths, precision) else: unet_model = UnetModel(hf_model_name, hf_auth_token, precision) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index 753cbb9e7..8a02dc192 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -119,9 +119,10 @@ def export_vae_model( mapper = {} - utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path - ) + if not os.path.exists(external_weight_path): + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_path + ) if weights_only: return external_weight_path From f7c2875ea067025de98d1425aa9a4b11071355c3 Mon Sep 17 00:00:00 2001 From: saienduri Date: Thu, 18 Jul 2024 21:07:28 -0700 Subject: [PATCH 2/4] revert old attempt --- .../turbine_models/custom_models/sd_inference/sd_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 1e8151ad3..7541e8cc0 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -365,7 +365,9 @@ def __init__( def setup_punet(self): if self.use_i8_punet: self.map["unet"]["export_args"]["precision"] = "i8" - self.map["unet"]["export_args"]["external_weight_path"] = self.punet_weight_path + self.map["unet"]["export_args"]["external_weight_path"] = ( + utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" + ) for idx, word in enumerate(self.map["unet"]["keywords"]): if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" From 496e4bfbac99fdb0082bc1193be00a89eb1403d0 Mon Sep 17 00:00:00 2001 From: saienduri Date: Thu, 18 Jul 2024 21:41:32 -0700 Subject: [PATCH 3/4] add unet export quant paths param --- models/turbine_models/custom_models/sd_inference/sd_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 7541e8cc0..71f1b0db3 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -368,6 +368,7 @@ def setup_punet(self): self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" ) + self.map["unet"]["export_args"]["quant_paths"] = self.punet_quant_paths for idx, word in enumerate(self.map["unet"]["keywords"]): if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" From c90a7bd31262563935f4ccb2f0c739fad81b7d56 Mon Sep 17 00:00:00 2001 From: saienduri Date: Fri, 19 Jul 2024 13:46:34 -0700 Subject: [PATCH 4/4] formatting --- models/turbine_models/custom_models/sdxl_inference/unet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 6cc6effac..a62678ac7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -223,7 +223,9 @@ def export_unet_model( ) return vmfb_path elif use_punet: - unet_model = get_punet_model(hf_model_name, external_weight_path, quant_paths, precision) + unet_model = get_punet_model( + hf_model_name, external_weight_path, quant_paths, precision + ) else: unet_model = UnetModel(hf_model_name, hf_auth_token, precision)
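Note on usage (not part of the patches above): taken together, these changes let a caller point the SDXL export path at local artifacts so nothing is pulled from the Hugging Face hub. Below is a minimal sketch, assuming a pipeline class named SharkSDPipeline and illustrative local paths; the "config"/"params"/"quant_params" keys mirror the lookups get_punet_model() performs on quant_paths, and "quant_params" is only consulted when precision == "i8".

    # Sketch only: the class name and file paths are assumptions, not part of the diff.
    punet_quant_paths = {
        "config": "/local/punet/config.json",             # used instead of download("config.json")
        "params": "/local/punet/params.safetensors",      # used instead of download("params.safetensors")
        "quant_params": "/local/punet/quant_params.json", # read only for i8 precision
    }

    pipeline_kwargs = dict(
        # Stored as self.punet_quant_paths, then forwarded by setup_punet() to
        # map["unet"]["export_args"]["quant_paths"] (patch 3).
        punet_quant_paths=punet_quant_paths,
        # Set as map["vae"]["export_args"]["external_weight_path"] (patch 1).
        vae_weight_path="/local/vae/vae_decode.safetensors",
    )
    # pipeline = SharkSDPipeline(<usual arguments>, **pipeline_kwargs)  # class name assumed

Because vae.py now skips utils.save_external_weights() when external_weight_path already exists, a pre-exported VAE weight file supplied this way is reused as-is rather than regenerated at export time.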