diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index e79b89c0f0d..72fce10872f 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -62,6 +62,13 @@ def __call__(self, parser, namespace, values, option_string=None):
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index be4301aa4e3..d4acd8950ca 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -482,6 +482,21 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_dtype(device=None):
+    if args.fp8_e4m3fn_text_enc:
+        return torch.float8_e4m3fn
+    elif args.fp8_e5m2_text_enc:
+        return torch.float8_e5m2
+    elif args.fp16_text_enc:
+        return torch.float16
+    elif args.fp32_text_enc:
+        return torch.float32
+
+    if should_use_fp16(device, prioritize_performance=False):
+        return torch.float16
+    else:
+        return torch.float32
+
 def vae_device():
     return get_torch_device()
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 65d94f46ecc..c3cc8e72080 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -95,10 +95,7 @@ def __init__(self, target=None, embedding_directory=None, no_init=False):
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        if model_management.should_use_fp16(load_device, prioritize_performance=False):
-            params['dtype'] = torch.float16
-        else:
-            params['dtype'] = torch.float32
+        params['dtype'] = model_management.text_encoder_dtype(load_device)
 
         self.cond_stage_model = clip(**(params))
 
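
A minimal usage sketch (not part of the diff) of how the new dtype selection behaves once these changes are applied. It assumes `comfy.cli_args` and `comfy.model_management` are importable from a ComfyUI checkout and that the installed torch build exposes the fp8 dtypes:

```python
# Sketch only: exercise text_encoder_dtype() with the new flag set programmatically,
# mirroring what passing --fp8_e4m3fn-text-enc on the command line would do.
import torch
from comfy.cli_args import args
from comfy import model_management

args.fp8_e4m3fn_text_enc = True  # argparse dest for --fp8_e4m3fn-text-enc
dtype = model_management.text_encoder_dtype()
assert dtype == torch.float8_e4m3fn  # an explicit flag overrides the should_use_fp16 heuristic
```

With no text-encoder flag set, `text_encoder_dtype()` falls back to `should_use_fp16(device, prioritize_performance=False)`, which is the same behavior `CLIP.__init__` had before this change.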