From 2d1c5ed2ce4409edb3f87777fedf04bb90ed99e2 Mon Sep 17 00:00:00 2001 From: rturnbull Date: Fri, 7 Jul 2023 12:49:55 +1000 Subject: [PATCH] :zap: fixing outputting of results --- supercat/apps.py | 73 ++++++++++++++++++++++++------------------- supercat/diffusion.py | 3 +- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/supercat/apps.py b/supercat/apps.py index 2f99bcc..58ca8d0 100644 --- a/supercat/apps.py +++ b/supercat/apps.py @@ -221,15 +221,23 @@ def output_results( self, results, return_data:bool=False, - output_dir: Path = ta.Param("./outputs", help="The location of the output directory."), + output_dir: Path = ta.Param(None, help="The location of the output directory. If not given then it uses the directory of the item."), + suffix:str = ta.Param("", help="The file extension for the output file."), **kwargs ): list_to_return = [] + if output_dir: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + for item, result in zip(self.items, results[0]): - extension = item.name[item.name.rfind(".")+1:].lower() - stub = item.name[:-len(extension)] - new_name = f"{stub}upscaled.{extension}" - new_path = item.parent/new_name + my_suffix = suffix or item.suffix + if my_suffix[0] != ".": + my_suffix = "." + my_suffix + + new_name = item.with_suffix("").name + f".upscaled{my_suffix}" + my_output_dir = output_dir or item.parent + new_path = my_output_dir/new_name dim = len(result.shape) - 1 @@ -261,35 +269,36 @@ def extra_callbacks(self): def inference_callbacks(self): return [DDPMSamplerCallback()] - def output_results( - self, - results, - output_dir: Path = ta.Param("./outputs", help="The location of the output directory."), - diffusion_gif:bool=False, - diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."), - **kwargs, - ): - final_results = [[result[-1] for result in results[0][0]]] - to_return = super().output_results(final_results, output_dir=output_dir, **kwargs) - - if diffusion_gif: - assert self.dim == 2 - - output_dir = Path(output_dir) - print(f"Saving {len(results[0])} generated images:") - - transform = T.ToPILImage() - output_dir.mkdir(exist_ok=True, parents=True) - images = [] - for index, image in enumerate(results[0][0]): - path = output_dir/f"image.{index}.png" + # def output_results( + # self, + # results, + # output_dir: Path = ta.Param("./outputs", help="The location of the output directory."), + # diffusion_gif:bool=False, + # diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."), + # **kwargs, + # ): + # breakpoint() + # # final_results = [[result[-1] for result in results[0][0]]] + # to_return = super().output_results(results, output_dir=output_dir, **kwargs) + + # if diffusion_gif: + # assert self.dim == 2 + + # output_dir = Path(output_dir) + # print(f"Saving {len(results[0])} generated images:") + + # transform = T.ToPILImage() + # output_dir.mkdir(exist_ok=True, parents=True) + # images = [] + # for index, image in enumerate(results[0][0]): + # path = output_dir/f"image.{index}.png" - image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0)) - images.append(image) - print(f"\t{path}") - images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps) + # image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0)) + # images.append(image) + # print(f"\t{path}") + # images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps) - return to_return + # return to_return if __name__ == "__main__": SupercatDiffusion.main() diff --git a/supercat/diffusion.py b/supercat/diffusion.py index 7dae402..e77a675 100644 --- a/supercat/diffusion.py +++ b/supercat/diffusion.py @@ -68,7 +68,8 @@ def before_batch(self): xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * predicted_noise) + sigma_t*z outputs.append(xt) - self.learn.pred = (torch.stack(outputs, dim=1),) + # self.learn.pred = (torch.stack(outputs, dim=1),) + self.learn.pred = (xt,) raise CancelBatchException