From cd687540319343f27b8f31b4dfde0bf9a66c8635 Mon Sep 17 00:00:00 2001 From: rishab-partha Date: Tue, 16 Jul 2024 19:05:10 -0700 Subject: [PATCH] fix? --- diffusion/evaluation/generate_images.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py index 21c2fabf..cff8d674 100644 --- a/diffusion/evaluation/generate_images.py +++ b/diffusion/evaluation/generate_images.py @@ -63,8 +63,13 @@ def __init__(self, self.hf_model = hf_model if hf_model or isinstance(model, str): print(f'LOCALRANK{dist.get_local_rank()}') + if dist.get_local_rank() == 0: + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() self.model = AutoPipelineForText2Image.from_pretrained( - model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() else: self.model = model self.dataset = dataset