diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 24af34a6..1caba490 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -252,7 +252,7 @@ def compile_to_vmfb( flags.extend(MI_flags["pad_attention"]) elif "punet" in flagset_keywords: flags.extend(MI_flags["punet"]) - elif "vae_preprocess" in flagset_keywords: + elif "vae" in safe_name: flags.extend(MI_flags["vae_preprocess"]) else: flags.extend(MI_flags["preprocess_default"])