Skip to content

Commit

Permalink
⚡ fixing outputting of results
Browse files Browse the repository at this point in the history
  • Loading branch information
rbturnbull committed Jul 7, 2023
1 parent 8848d45 commit 2d1c5ed
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 33 deletions.
73 changes: 41 additions & 32 deletions supercat/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,23 @@ def output_results(
self,
results,
return_data:bool=False,
output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
output_dir: Path = ta.Param(None, help="The location of the output directory. If not given then it uses the directory of the item."),
suffix:str = ta.Param("", help="The file extension for the output file."),
**kwargs
):
list_to_return = []
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)

for item, result in zip(self.items, results[0]):
extension = item.name[item.name.rfind(".")+1:].lower()
stub = item.name[:-len(extension)]
new_name = f"{stub}upscaled.{extension}"
new_path = item.parent/new_name
my_suffix = suffix or item.suffix
if my_suffix[0] != ".":
my_suffix = "." + my_suffix

new_name = item.with_suffix("").name + f".upscaled{my_suffix}"
my_output_dir = output_dir or item.parent
new_path = my_output_dir/new_name

dim = len(result.shape) - 1

Expand Down Expand Up @@ -261,35 +269,36 @@ def extra_callbacks(self):
def inference_callbacks(self):
return [DDPMSamplerCallback()]

def output_results(
self,
results,
output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
diffusion_gif:bool=False,
diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."),
**kwargs,
):
final_results = [[result[-1] for result in results[0][0]]]
to_return = super().output_results(final_results, output_dir=output_dir, **kwargs)

if diffusion_gif:
assert self.dim == 2

output_dir = Path(output_dir)
print(f"Saving {len(results[0])} generated images:")

transform = T.ToPILImage()
output_dir.mkdir(exist_ok=True, parents=True)
images = []
for index, image in enumerate(results[0][0]):
path = output_dir/f"image.{index}.png"
# def output_results(
# self,
# results,
# output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
# diffusion_gif:bool=False,
# diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."),
# **kwargs,
# ):
# breakpoint()
# # final_results = [[result[-1] for result in results[0][0]]]
# to_return = super().output_results(results, output_dir=output_dir, **kwargs)

# if diffusion_gif:
# assert self.dim == 2

# output_dir = Path(output_dir)
# print(f"Saving {len(results[0])} generated images:")

# transform = T.ToPILImage()
# output_dir.mkdir(exist_ok=True, parents=True)
# images = []
# for index, image in enumerate(results[0][0]):
# path = output_dir/f"image.{index}.png"

image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
images.append(image)
print(f"\t{path}")
images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps)
# image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
# images.append(image)
# print(f"\t{path}")
# images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps)

return to_return
# return to_return

if __name__ == "__main__":
SupercatDiffusion.main()
3 changes: 2 additions & 1 deletion supercat/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def before_batch(self):
xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * predicted_noise) + sigma_t*z
outputs.append(xt)

self.learn.pred = (torch.stack(outputs, dim=1),)
# self.learn.pred = (torch.stack(outputs, dim=1),)
self.learn.pred = (xt,)

raise CancelBatchException

0 comments on commit 2d1c5ed

Please sign in to comment.