Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
rishab-partha committed Jul 17, 2024
1 parent caf1e61 commit 51c1105
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion diffusion/evaluation/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self,
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
self.model = AutoPipelineForText2Image.from_pretrained(
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
else:
self.model = model
Expand Down
18 changes: 8 additions & 10 deletions diffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm
from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm
from composer.core import Precision
from composer.utils import reproducibility, dist, get_device
from composer.utils import dist, get_device, reproducibility
from datasets import load_dataset
from omegaconf import DictConfig
from torch.utils.data import Dataset
Expand All @@ -26,7 +26,7 @@ def generate(config: DictConfig) -> None:
config (DictConfig): Configuration composed by Hydra
"""
reproducibility.seed_all(config.seed)
device = get_device()
device = get_device() # type: ignore
dist.initialize_dist(device, config.dist_timeout)

# The model to evaluate
Expand All @@ -43,7 +43,7 @@ def generate(config: DictConfig) -> None:
if dist.get_local_rank() == 0:
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
dataset = load_dataset(config.dataset.name, split = config.dataset.split)
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
elif tokenizer:
dataset = hydra.utils.instantiate(config.dataset)
Expand Down Expand Up @@ -79,13 +79,11 @@ def generate(config: DictConfig) -> None:
optimizers=None,
)

image_generator: ImageGenerator = hydra.utils.instantiate(
config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset
)
image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)

def generate_from_model():
image_generator.generate()
Expand Down

0 comments on commit 51c1105

Please sign in to comment.