diff --git a/speed.py b/speed.py index 20a3482d..27d37ff0 100644 --- a/speed.py +++ b/speed.py @@ -102,13 +102,14 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input Measure inference times for each model """ times = {name: [] for name in compiled_models.keys()} - times['Stitching and Retargeting Modules'] = [] + times =torch.cuda.Stream() overall_times = [] - + with torch.no_grad(): + torch.cuda.synchronize() for _ in range(100): - torch.cuda.synchronize() + overall_start = time.time() start = time.time() @@ -139,7 +140,7 @@ def measure_inference_times(compiled_models, stitching_retargeting_module, input times['Stitching and Retargeting Modules'].append(time.time() - start) overall_times.append(time.time() - overall_start) - + torch.cuda.synchronize() return times, overall_times