Skip to content

Commit

Permalink
Fix vae script CLI and revert precision changes to sd3 text encoders …
Browse files Browse the repository at this point in the history
…export
  • Loading branch information
eagarvey-amd committed Aug 17, 2024
1 parent 7ecfece commit 18bffdb
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 21 deletions.
1 change: 1 addition & 0 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ def load_map(self):
if not self.map[submodel]["load"]:
self.printer.print(f"Skipping load for {submodel}")
continue
breakpoint()
self.load_submodel(submodel)

def load_submodel(self, submodel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.ops.iree import trace_tensor
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
Expand Down Expand Up @@ -54,10 +55,9 @@ class TextEncoderModule(torch.nn.Module):
@torch.no_grad()
def __init__(
self,
precision,
):
super().__init__()
self.dtype = torch.float16 if precision == "fp16" else torch.float32
self.dtype = torch.float16
self.clip_l = SDClipModel(
layer="hidden",
layer_idx=-2,
Expand All @@ -66,25 +66,21 @@ def __init__(
layer_norm_hidden_state=False,
return_projected_pooled=False,
textmodel_json_config=CLIPL_CONFIG,
)
if precision == "fp16":
self.clip_l = self.clip_l.half()
).half()
clip_l_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/clip_l.safetensors",
)
with safe_open(clip_l_weights, framework="pt", device="cpu") as f:
load_into(f, self.clip_l.transformer, "", "cpu", self.dtype)
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype)
if precision == "fp16":
self.clip_l = self.clip_g.half()
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half()
clip_g_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/clip_g.safetensors",
)
with safe_open(clip_g_weights, framework="pt", device="cpu") as f:
load_into(f, self.clip_g.transformer, "", "cpu", self.dtype)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float16)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half()
t5_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/t5xxl_fp16.safetensors",
Expand Down Expand Up @@ -155,8 +151,7 @@ def export_text_encoders(
attn_spec=attn_spec,
)
return vmfb_path
model = TextEncoderModule(precision)
mapper = {}
model = TextEncoderModule()

assert (
".safetensors" not in external_weight_path
Expand Down
10 changes: 0 additions & 10 deletions models/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,7 @@ class CompiledVae(CompiledModule):

if __name__ == "__main__":
from turbine_models.custom_models.sd_inference.sd_cmd_opts import args

if args.input_mlir:
vae_model = None
else:
vae_model = VaeModel(
args.hf_model_name,
custom_vae=None,
)
mod_str = export_vae_model(
vae_model,
args.hf_model_name,
args.batch_size,
height=args.height,
Expand All @@ -286,7 +277,6 @@ class CompiledVae(CompiledModule):
device=args.device,
target=args.iree_target_triple,
ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags,
variant=args.vae_variant,
decomp_attn=args.decomp_attn,
attn_spec=args.attn_spec,
input_mlir=args.input_mlir,
Expand Down

0 comments on commit 18bffdb

Please sign in to comment.