From 6c8ff719f36f4f8d909ae2b2707129650698fe0b Mon Sep 17 00:00:00 2001 From: um3 Date: Wed, 24 Jan 2024 18:36:14 +0800 Subject: [PATCH] fix DDP test --- test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test.py b/test.py index cb3c815..49f6b69 100644 --- a/test.py +++ b/test.py @@ -157,7 +157,16 @@ #--------------------------- def load_network(network): save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch) - network.load_state_dict(torch.load(save_path)) + try: + network.load_state_dict(torch.load(save_path)) + except: + if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 + print("Compiling model...") + # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0 + torch.set_float32_matmul_precision('high') + network = torch.compile(network, mode="default", dynamic=True) # pytorch 2.0 + network.load_state_dict(torch.load(save_path)) + return network @@ -281,11 +290,6 @@ def get_id(img_path): #if opt.fp16: # model_structure = network_to_half(model_structure) -if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 - print("Compiling model...") - # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0 - torch.set_float32_matmul_precision('high') - model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0 model = load_network(model_structure)