Add some command line arguments to store text encoder weights in fp8.

PyTorch supports two variants of fp8:
--fp8_e4m3fn-text-enc (the one that seems to give better results)
--fp8_e5m2-text-enc
comfyanonymous committed Nov 17, 2023
1 parent 107e78b commit 0cf4e86
Showing 3 changed files with 23 additions and 4 deletions.
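As background for the two variants named in the commit message: PyTorch 2.1+ exposes both fp8 formats as tensor dtypes, and torch.finfo shows the tradeoff directly. A minimal sketch, independent of this commit:

    import torch

    # The two fp8 formats behind the new flags (available in PyTorch >= 2.1):
    # e4m3fn: 4 exponent bits, 3 mantissa bits, no infinities ("fn" = finite)
    # e5m2:   5 exponent bits, 2 mantissa bits, IEEE-style, with inf/nan
    for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
        info = torch.finfo(dt)
        print(dt, "max:", info.max, "eps:", info.eps)

e4m3fn gives up dynamic range for an extra mantissa bit, i.e. more precision per value, which lines up with the commit's note that it tends to give better results for weight storage.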
7 changes: 7 additions & 0 deletions comfy/cli_args.py
@@ -62,6 +62,13 @@ def __call__(self, parser, namespace, values, option_string=None):
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
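Two details of the hunk above are easy to miss: the four options form a mutually exclusive group, so argparse rejects any combination of them, and argparse rewrites the dashes in an option name to underscores when building the namespace attribute, which is why model_management.py below reads args.fp8_e4m3fn_text_enc. A standalone sketch of both behaviors:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--fp8_e4m3fn-text-enc", action="store_true")
    group.add_argument("--fp8_e5m2-text-enc", action="store_true")

    # Dashes in the option string become underscores in the parsed namespace.
    args = parser.parse_args(["--fp8_e4m3fn-text-enc"])
    print(args.fp8_e4m3fn_text_enc)  # True

    # Supplying both flags at once exits with a usage error, because the
    # group is mutually exclusive:
    # parser.parse_args(["--fp8_e4m3fn-text-enc", "--fp8_e5m2-text-enc"])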
15 changes: 15 additions & 0 deletions comfy/model_management.py
@@ -482,6 +482,21 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_dtype(device=None):
+    if args.fp8_e4m3fn_text_enc:
+        return torch.float8_e4m3fn
+    elif args.fp8_e5m2_text_enc:
+        return torch.float8_e5m2
+    elif args.fp16_text_enc:
+        return torch.float16
+    elif args.fp32_text_enc:
+        return torch.float32
+
+    if should_use_fp16(device, prioritize_performance=False):
+        return torch.float16
+    else:
+        return torch.float32
+
 def vae_device():
     return get_torch_device()
 
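text_encoder_dtype() checks the explicit flags first and otherwise falls back to the old fp16/fp32 heuristic, so launching without any of the new flags changes nothing. What an fp8 storage dtype buys is memory: one byte per weight instead of two or four, with values cast back up to a compute dtype before use, since most kernels do not run on fp8 directly. A hypothetical illustration of that store-low/compute-high pattern (not ComfyUI code):

    import torch

    w_fp32 = torch.randn(4096, 4096)            # 4 bytes per element
    w_fp8 = w_fp32.to(torch.float8_e4m3fn)      # 1 byte per element, lossy

    x = torch.randn(1, 4096)
    y = x @ w_fp8.to(torch.float32)             # upcast just in time for the matmul

    print(w_fp32.element_size(), w_fp8.element_size())  # 4 1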
5 changes: 1 addition & 4 deletions comfy/sd.py
@@ -95,10 +95,7 @@ def __init__(self, target=None, embedding_directory=None, no_init=False):
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        if model_management.should_use_fp16(load_device, prioritize_performance=False):
-            params['dtype'] = torch.float16
-        else:
-            params['dtype'] = torch.float32
+        params['dtype'] = model_management.text_encoder_dtype(load_device)
 
         self.cond_stage_model = clip(**(params))
 
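With sd.py delegating the decision, enabling fp8 text encoder storage is purely a launch flag; assuming ComfyUI's usual entry point, something like:

    python main.py --fp8_e4m3fn-text-enc

Omitting the flag keeps the previous behavior, since text_encoder_dtype() falls through to the same should_use_fp16 check that sd.py used to perform inline.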
