diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py index 1e016e1a1b..d36ce63d8e 100644 --- a/fbgemm_gpu/bench/bench_utils.py +++ b/fbgemm_gpu/bench/bench_utils.py @@ -40,25 +40,11 @@ def warmup( ) -> None: indices, offsets, weights = request.unpack_3() if warmup_ms: - if torch.cuda.is_available(): - elapsed_time_ms = 0 - torch.cuda.synchronize() - start_events = torch.cuda.Event(enable_timing=True) - end_events = torch.cuda.Event(enable_timing=True) - while elapsed_time_ms < warmup_ms: - start_events.record() - out = func(indices, offsets, weights) - if bwd_only: - out.backward(grad) - end_events.record() - torch.cuda.synchronize() - elapsed_time_ms += start_events.elapsed_time(end_events) - else: - start_time_ms = time.time() * 1000 - while time.time() * 1000 - start_time_ms < warmup_ms: - out = func(indices, offsets, weights) - if bwd_only: - out.backward(grad) + start_time_ms = time.time() * 1000 + while time.time() * 1000 - start_time_ms < warmup_ms: + out = func(indices, offsets, weights) + if bwd_only: + out.backward(grad) else: for _ in range(warmup_runs): out = func(indices, offsets, weights)