From 6bc8cb48515819f70abdcea3f3adc5d4fb17b628 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 15:14:52 -0500 Subject: [PATCH] Don't redundantly use tqdm progress if we're printing benchmarks --- models/turbine_models/custom_models/pipeline_base.py | 3 ++- .../turbine_models/custom_models/sd_inference/sd_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 7dc5cc5d1..cb336e021 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -343,7 +343,8 @@ def __init__( common_export_args: dict = {}, ): self.map = model_map - self.printer = Printer(verbose, time.time(), True) + self.verbose = verbose + self.printer = Printer(self.verbose, time.time(), True) if isinstance(device, dict): assert isinstance( target, dict diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index d482be00f..d2a65cf3a 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -565,7 +565,7 @@ def _produce_latents_sdxl( [guidance_scale], dtype=self.map["unet"]["np_dtype"], ) - for i, t in tqdm(enumerate(timesteps)): + for i, t in tqdm(enumerate(timesteps), disable=(self.benchmark and self.verbose)): if self.cpu_scheduling: latent_model_input, t = self.scheduler.scale_model_input( latents,