Skip to content

Commit

Permalink
Add functionality to SD pipeline and abstracted components for saving…
Browse files Browse the repository at this point in the history
… output .npys (#792)
  • Loading branch information
monorimet authored Jul 25, 2024
1 parent 15dbd93 commit 0c02652
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
26 changes: 24 additions & 2 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,21 @@ class PipelineComponent:
"""

def __init__(
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
self,
printer,
dest_type="devicearray",
dest_dtype="float16",
benchmark=False,
save_outputs=False,
):
self.runner = None
self.module_name = None
self.device = None
self.metadata = None
self.printer = printer
self.benchmark = benchmark
self.save_outputs = save_outputs
self.output_counter = 0
self.dest_type = dest_type
self.dest_dtype = dest_dtype

Expand Down Expand Up @@ -218,6 +225,16 @@ def _output_cast(self, output):
case _:
return output

def save_output(self, function_name, output):
if isinstance(output, tuple) or isinstance(output, list):
for i in output:
self.save_output(function_name, i)
else:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1

def _run(self, function_name, inputs: list):
return self.module[function_name](*inputs)

Expand All @@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list):
output = self._run_and_benchmark(function_name, inputs)
else:
output = self._run(function_name, inputs)
if self.save_outputs:
self.save_output(function_name, output)
output = self._output_cast(output)
return output

Expand Down Expand Up @@ -340,6 +359,7 @@ def __init__(
hf_model_name: str | dict[str] = None,
benchmark: bool | dict[bool] = False,
verbose: bool = False,
save_outputs: bool | dict[bool] = False,
common_export_args: dict = {},
):
self.map = model_map
Expand Down Expand Up @@ -374,6 +394,7 @@ def __init__(
"external_weights": external_weights,
"hf_model_name": hf_model_name,
"benchmark": benchmark,
"save_outputs": save_outputs,
}
for arg in map_arguments.keys():
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
Expand All @@ -391,7 +412,7 @@ def __init__(
)
for submodel in self.map.keys():
for key, value in map_arguments.items():
if key != "benchmark":
if key not in ["benchmark", "save_outputs"]:
self.map = merge_export_arg(self.map, value, key)
for key, value in self.map[submodel].get("export_args", {}).items():
if key == "hf_model_name":
Expand Down Expand Up @@ -744,6 +765,7 @@ def load_submodel(self, submodel):
printer=self.printer,
dest_type=dest_type,
benchmark=self.map[submodel].get("benchmark", False),
save_outputs=self.map[submodel].get("save_outputs", False),
)
self.map[submodel]["runner"].load(
self.map[submodel]["driver"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def is_valid_file(arg):
help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
)

p.add_argument(
"--save_outputs",
type=str,
default=None,
help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.",
)
##############################################################################
# SDXL Modelling Options
# These options are used to control model defining parameters for SDXL.
Expand Down
13 changes: 12 additions & 1 deletion models/turbine_models/custom_models/sd_inference/sd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def __init__(
batch_prompts: bool = False,
punet_quant_paths: dict[str] = None,
vae_weight_path: str = None,
vae_harness: bool = False,
vae_harness: bool = True,
add_tk_kernels: bool = False,
save_outputs: bool | dict[bool] = False,
):
common_export_args = {
"hf_model_name": None,
Expand Down Expand Up @@ -286,6 +287,7 @@ def __init__(
hf_model_name,
benchmark,
verbose,
save_outputs,
common_export_args,
)
for submodel in sd_model_map:
Expand Down Expand Up @@ -742,6 +744,14 @@ def numpy_to_pil_image(images):
benchmark[i] = True
else:
benchmark = False
if args.save_outputs:
if args.save_outputs.lower() == "all":
save_outputs = True
else:
for i in args.save_outputs.split(","):
save_outputs[i] = True
else:
save_outputs = False
if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
args.decomp_attn = {
"text_encoder": args.decomp_attn,
Expand Down Expand Up @@ -772,6 +782,7 @@ def numpy_to_pil_image(images):
args.use_i8_punet,
benchmark,
args.verbose,
save_outputs=save_outputs,
)
sd_pipe.prepare_all()
sd_pipe.load_map()
Expand Down

0 comments on commit 0c02652

Please sign in to comment.