Skip to content

Commit

Permalink
fix DDP test
Browse files Browse the repository at this point in the history
  • Loading branch information
um3 committed Jan 24, 2024
1 parent c2ef853 commit 6c8ff71
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,16 @@
#---------------------------
def load_network(network):
save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path))
except:
if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7
print("Compiling model...")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
torch.set_float32_matmul_precision('high')
network = torch.compile(network, mode="default", dynamic=True) # pytorch 2.0
network.load_state_dict(torch.load(save_path))

return network


Expand Down Expand Up @@ -281,11 +290,6 @@ def get_id(img_path):
#if opt.fp16:
# model_structure = network_to_half(model_structure)

if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7
print("Compiling model...")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
torch.set_float32_matmul_precision('high')
model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0

model = load_network(model_structure)

Expand Down

0 comments on commit 6c8ff71

Please sign in to comment.