Skip to content

Commit

Permalink
Changes so that no external downloads are required. (#781)
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri authored Jul 19, 2024
1 parent 1f19c7f commit 39c0c00
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def __init__(
benchmark: bool | dict[bool] = False,
verbose: bool = False,
batch_prompts: bool = False,
punet_quant_paths: dict[str] = None,
vae_weight_path: str = None,
):
common_export_args = {
"hf_model_name": None,
Expand Down Expand Up @@ -304,6 +306,7 @@ def __init__(
self.cpu_scheduling = cpu_scheduling
self.scheduler_id = scheduler_id
self.num_inference_steps = num_inference_steps
self.punet_quant_paths = punet_quant_paths

self.text_encoder = None
self.unet = None
Expand Down Expand Up @@ -340,6 +343,7 @@ def __init__(
self.scheduler_device = self.map["unet"]["device"]
self.scheduler_driver = self.map["unet"]["driver"]
self.scheduler_target = self.map["unet"]["target"]
self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path
elif not self.is_sd3:
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
Expand All @@ -364,6 +368,7 @@ def setup_punet(self):
self.map["unet"]["export_args"]["external_weight_path"] = (
utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
)
self.map["unet"]["export_args"]["quant_paths"] = self.punet_quant_paths
for idx, word in enumerate(self.map["unet"]["keywords"]):
if word in ["fp32", "fp16"]:
self.map["unet"]["keywords"][idx] = "i8"
Expand Down
26 changes: 19 additions & 7 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def forward(
return noise_pred


def get_punet_model(hf_model_name, external_weight_path, precision="i8"):
def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision="i8"):
from sharktank.models.punet.model import (
Unet2DConditionModel as sharktank_unet2d,
ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel,
Expand All @@ -103,14 +103,23 @@ def download(filename):
repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision
)

results = {
"config.json": download("config.json"),
"params.safetensors": download("params.safetensors"),
}
if not quant_paths:
results = {
"config.json": download("config.json"),
"params.safetensors": download("params.safetensors"),
}
else:
results = {
"config.json": quant_paths["config"],
"params.safetensors": quant_paths["params"],
}
output_dir = os.path.dirname(external_weight_path)

if precision == "i8":
results["quant_params.json"] = download("quant_params.json")
if quant_paths:
results["quant_params.json"] = quant_paths["quant_params"]
else:
results["quant_params.json"] = download("quant_params.json")
ds_filename = os.path.basename(external_weight_path)
output_path = os.path.join(output_dir, ds_filename)
ds = get_punet_dataset(
Expand Down Expand Up @@ -177,6 +186,7 @@ def export_unet_model(
input_mlir=None,
weights_only=False,
use_punet=False,
quant_paths=None,
):
if use_punet:
submodel_name = "punet"
Expand Down Expand Up @@ -213,7 +223,9 @@ def export_unet_model(
)
return vmfb_path
elif use_punet:
unet_model = get_punet_model(hf_model_name, external_weight_path, precision)
unet_model = get_punet_model(
hf_model_name, external_weight_path, quant_paths, precision
)
else:
unet_model = UnetModel(hf_model_name, hf_auth_token, precision)

Expand Down
7 changes: 4 additions & 3 deletions models/turbine_models/custom_models/sdxl_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@ def export_vae_model(

mapper = {}

utils.save_external_weights(
mapper, vae_model, external_weights, external_weight_path
)
if not os.path.exists(external_weight_path):
utils.save_external_weights(
mapper, vae_model, external_weights, external_weight_path
)
if weights_only:
return external_weight_path

Expand Down

0 comments on commit 39c0c00

Please sign in to comment.