diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py
index c5f550e5..eaf90a56 100644
--- a/models/turbine_models/custom_models/pipeline_base.py
+++ b/models/turbine_models/custom_models/pipeline_base.py
@@ -758,6 +758,7 @@ def load_map(self):
             if not self.map[submodel]["load"]:
                 self.printer.print(f"Skipping load for {submodel}")
                 continue
+            breakpoint()
             self.load_submodel(submodel)
 
     def load_submodel(self, submodel):
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py
index 3edf6b40..8acf4fe3 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py
@@ -13,6 +13,7 @@
 from iree.compiler.ir import Context
 import numpy as np
 from shark_turbine.aot import *
+from shark_turbine.ops.iree import trace_tensor
 from shark_turbine.transforms.general.add_metadata import AddMetadataPass
 from turbine_models.custom_models.sd_inference import utils
 import torch
@@ -54,10 +55,9 @@ class TextEncoderModule(torch.nn.Module):
     @torch.no_grad()
     def __init__(
         self,
-        precision,
     ):
         super().__init__()
-        self.dtype = torch.float16 if precision == "fp16" else torch.float32
+        self.dtype = torch.float16
         self.clip_l = SDClipModel(
             layer="hidden",
             layer_idx=-2,
@@ -66,25 +66,21 @@ def __init__(
             layer_norm_hidden_state=False,
             return_projected_pooled=False,
             textmodel_json_config=CLIPL_CONFIG,
-        )
-        if precision == "fp16":
-            self.clip_l = self.clip_l.half()
+        ).half()
         clip_l_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/clip_l.safetensors",
         )
         with safe_open(clip_l_weights, framework="pt", device="cpu") as f:
             load_into(f, self.clip_l.transformer, "", "cpu", self.dtype)
-        self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype)
-        if precision == "fp16":
-            self.clip_l = self.clip_g.half()
+        self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half()
         clip_g_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/clip_g.safetensors",
         )
         with safe_open(clip_g_weights, framework="pt", device="cpu") as f:
             load_into(f, self.clip_g.transformer, "", "cpu", self.dtype)
-        self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float16)
+        self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half()
         t5_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/t5xxl_fp16.safetensors",
@@ -155,8 +151,7 @@ def export_text_encoders(
             attn_spec=attn_spec,
         )
         return vmfb_path
-    model = TextEncoderModule(precision)
-    mapper = {}
+    model = TextEncoderModule()
 
     assert (
         ".safetensors" not in external_weight_path
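Note: with the precision switch removed, TextEncoderModule is fp16-only. A minimal sketch of the new call site, assuming the imports in sd3_text_encoders.py above (the assert mirrors the hardcoded dtype; this is illustrative, not part of the patch):

    import torch

    from turbine_models.custom_models.sd3_inference.sd3_text_encoders import (
        TextEncoderModule,
    )

    # No precision argument anymore; all three encoders (clip_l, clip_g, t5xxl)
    # are built and .half()'d to torch.float16 at construction time.
    model = TextEncoderModule()
    assert model.dtype == torch.float16
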
diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py
index 14422ae1..add39b35 100644
--- a/models/turbine_models/custom_models/sd_inference/vae.py
+++ b/models/turbine_models/custom_models/sd_inference/vae.py
@@ -265,16 +265,7 @@ class CompiledVae(CompiledModule):
 
 if __name__ == "__main__":
     from turbine_models.custom_models.sd_inference.sd_cmd_opts import args
-
-    if args.input_mlir:
-        vae_model = None
-    else:
-        vae_model = VaeModel(
-            args.hf_model_name,
-            custom_vae=None,
-        )
     mod_str = export_vae_model(
-        vae_model,
         args.hf_model_name,
         args.batch_size,
         height=args.height,
@@ -286,7 +277,6 @@ class CompiledVae(CompiledModule):
         device=args.device,
         target=args.iree_target_triple,
         ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags,
-        variant=args.vae_variant,
         decomp_attn=args.decomp_attn,
         attn_spec=args.attn_spec,
         input_mlir=args.input_mlir,
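Note: the __main__ block no longer builds a VaeModel or passes variant; export_vae_model is presumably expected to resolve the model internally from hf_model_name, including the args.input_mlir case. A sketch of the new call shape, using only the arguments visible in the hunks above (the call's middle arguments are not shown in the diff and are elided here; treat this as illustrative, not the full signature):

    from turbine_models.custom_models.sd_inference.sd_cmd_opts import args
    from turbine_models.custom_models.sd_inference.vae import export_vae_model

    # vae_model and variant are no longer passed; model construction (and the
    # input_mlir short-circuit) now happens inside export_vae_model.
    mod_str = export_vae_model(
        args.hf_model_name,
        args.batch_size,
        height=args.height,
        # ... unchanged arguments between height and device elided (not shown in the diff) ...
        device=args.device,
        target=args.iree_target_triple,
        ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags,
        decomp_attn=args.decomp_attn,
        attn_spec=args.attn_spec,
        input_mlir=args.input_mlir,
    )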