From b58bec516599c41155cd3ab095c13a2ca152b19d Mon Sep 17 00:00:00 2001 From: umeshp07 <120558572+umeshp07@users.noreply.github.com> Date: Thu, 5 Jun 2025 08:28:22 +0530 Subject: [PATCH] Update speed.py --- speed.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/speed.py b/speed.py index 20a3482d..27d37ff0 100644 --- a/speed.py +++ b/speed.py @@ -102,13 +102,14 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input Measure inference times for each model """ times = {name: [] for name in compiled_models.keys()} - times['Stitching and Retargeting Modules'] = [] + times =torch.cuda.Stream() overall_times = [] - + with torch.no_grad(): + torch.cuda.synchronize() for _ in range(100): - torch.cuda.synchronize() + overall_start = time.time() start = time.time() @@ -139,7 +140,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input times['Stitching and Retargeting Modules'].append(time.time() - start) overall_times.append(time.time() - overall_start) - + torch.cuda.synchronize() return times, overall_times