From 0c02652c7337f7ef686c45e7a31e072d5966c6eb Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Thu, 25 Jul 2024 13:18:23 -0500
Subject: [PATCH] Add functionality to SD pipeline and abstracted components for saving output .npys (#792)

---
 .../custom_models/pipeline_base.py            | 26 +++++++++++++++++--
 .../custom_models/sd_inference/sd_cmd_opts.py |  6 +++++
 .../custom_models/sd_inference/sd_pipeline.py | 13 +++++++++-
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py
index b4755283..3102ac3e 100644
--- a/models/turbine_models/custom_models/pipeline_base.py
+++ b/models/turbine_models/custom_models/pipeline_base.py
@@ -84,7 +84,12 @@ class PipelineComponent:
     """

     def __init__(
-        self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
+        self,
+        printer,
+        dest_type="devicearray",
+        dest_dtype="float16",
+        benchmark=False,
+        save_outputs=False,
     ):
         self.runner = None
         self.module_name = None
@@ -92,6 +97,8 @@ def __init__(
         self.metadata = None
         self.printer = printer
         self.benchmark = benchmark
+        self.save_outputs = save_outputs
+        self.output_counter = 0
         self.dest_type = dest_type
         self.dest_dtype = dest_dtype

@@ -218,6 +225,16 @@ def _output_cast(self, output):
             case _:
                 return output

+    def save_output(self, function_name, output):
+        if isinstance(output, tuple) or isinstance(output, list):
+            for i in output:
+                self.save_output(function_name, i)
+        else:
+            np.save(
+                f"{function_name}_output_{self.output_counter}.npy", output.to_host()
+            )
+            self.output_counter += 1
+
     def _run(self, function_name, inputs: list):
         return self.module[function_name](*inputs)

@@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list):
             output = self._run_and_benchmark(function_name, inputs)
         else:
             output = self._run(function_name, inputs)
+        if self.save_outputs:
+            self.save_output(function_name, output)
         output = self._output_cast(output)
         return output

@@ -340,6 +359,7 @@ def __init__(
         hf_model_name: str | dict[str] = None,
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
+        save_outputs: bool | dict[bool] = False,
         common_export_args: dict = {},
     ):
         self.map = model_map
@@ -374,6 +394,7 @@ def __init__(
             "external_weights": external_weights,
             "hf_model_name": hf_model_name,
             "benchmark": benchmark,
+            "save_outputs": save_outputs,
         }
         for arg in map_arguments.keys():
             self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
@@ -391,7 +412,7 @@ def __init__(
         )
         for submodel in self.map.keys():
             for key, value in map_arguments.items():
-                if key != "benchmark":
+                if key not in ["benchmark", "save_outputs"]:
                     self.map = merge_export_arg(self.map, value, key)
             for key, value in self.map[submodel].get("export_args", {}).items():
                 if key == "hf_model_name":
@@ -744,6 +765,7 @@ def load_submodel(self, submodel):
             printer=self.printer,
             dest_type=dest_type,
             benchmark=self.map[submodel].get("benchmark", False),
+            save_outputs=self.map[submodel].get("save_outputs", False),
         )
         self.map[submodel]["runner"].load(
             self.map[submodel]["driver"],
diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
index 5e025a4d..a852bf46 100644
--- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
+++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
@@ -151,6 +151,12 @@ def is_valid_file(arg):
     help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
 )

+p.add_argument(
+    "--save_outputs",
+    type=str,
+    default=None,
+    help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.",
+)
 ##############################################################################
 # SDXL Modelling Options
 # These options are used to control model defining parameters for SDXL.
diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
index 2456ae4b..97369c53 100644
--- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
+++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -236,8 +236,9 @@ def __init__(
         batch_prompts: bool = False,
         punet_quant_paths: dict[str] = None,
         vae_weight_path: str = None,
-        vae_harness: bool = False,
+        vae_harness: bool = True,
         add_tk_kernels: bool = False,
+        save_outputs: bool | dict[bool] = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -286,6 +287,7 @@ def __init__(
             hf_model_name,
             benchmark,
             verbose,
+            save_outputs,
             common_export_args,
         )
         for submodel in sd_model_map:
@@ -742,6 +744,14 @@ def numpy_to_pil_image(images):
                 benchmark[i] = True
     else:
         benchmark = False
+    if args.save_outputs:
+        if args.save_outputs.lower() == "all":
+            save_outputs = True
+        else:
+            for i in args.save_outputs.split(","):
+                save_outputs[i] = True
+    else:
+        save_outputs = False
     if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
         args.decomp_attn = {
             "text_encoder": args.decomp_attn,
@@ -772,6 +782,7 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        save_outputs=save_outputs,
     )
     sd_pipe.prepare_all()
     sd_pipe.load_map()
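
Note (illustration only, not part of the patch): the sketch below shows how the new save-output path is expected to behave once a component is built with save_outputs enabled (requested via --save_outputs=all or a comma-separated list of submodel IDs). Plain numpy arrays stand in for the IREE device arrays that the real PipelineComponent converts with .to_host() before saving; the class name and function name used here are hypothetical.

# Minimal sketch of the .npy naming/numbering scheme used by
# PipelineComponent.save_output, using plain numpy arrays.
import numpy as np


class SaveOutputsDemo:
    def __init__(self):
        self.output_counter = 0

    def save_output(self, function_name, output):
        # Tuples and lists are flattened recursively; every leaf array is
        # written to its own sequentially numbered .npy file.
        if isinstance(output, (tuple, list)):
            for item in output:
                self.save_output(function_name, item)
        else:
            np.save(f"{function_name}_output_{self.output_counter}.npy", output)
            self.output_counter += 1


if __name__ == "__main__":
    demo = SaveOutputsDemo()
    # Writes run_forward_output_0.npy and run_forward_output_1.npy.
    demo.save_output("run_forward", (np.zeros((1, 4)), np.ones((2, 2))))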