diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py index 21c2fabf..cff8d674 100644 --- a/diffusion/evaluation/generate_images.py +++ b/diffusion/evaluation/generate_images.py @@ -63,8 +63,13 @@ def __init__(self, self.hf_model = hf_model if hf_model or isinstance(model, str): print(f'LOCALRANK{dist.get_local_rank()}') + if dist.get_local_rank() == 0: + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() self.model = AutoPipelineForText2Image.from_pretrained( - model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() else: self.model = model self.dataset = dataset