diff --git a/test.py b/test.py index cb3c815..49f6b69 100644 --- a/test.py +++ b/test.py @@ -157,7 +157,16 @@ #--------------------------- def load_network(network): save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch) - network.load_state_dict(torch.load(save_path)) + try: + network.load_state_dict(torch.load(save_path)) + except: + if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 + print("Compiling model...") + # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0 + torch.set_float32_matmul_precision('high') + network = torch.compile(network, mode="default", dynamic=True) # pytorch 2.0 + network.load_state_dict(torch.load(save_path)) + return network @@ -281,11 +290,6 @@ def get_id(img_path): #if opt.fp16: # model_structure = network_to_half(model_structure) -if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 - print("Compiling model...") - # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0 - torch.set_float32_matmul_precision('high') - model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0 model = load_network(model_structure)