diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 2fb4bfbbc5..29e422255b 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -875,8 +875,11 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None from jsonargparse import CLI - CLI(benchmark_main) - - # ref: https://github.com/pytorch/pytorch/blob/3af12447/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1110-L1116 - if world_size > 1: - torch_dist.destroy_process_group() + try: + CLI(benchmark_main) + except Exception: + raise + finally: + # ref: https://github.com/pytorch/pytorch/blob/3af12447/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1110-L1116 + if world_size > 1: + torch_dist.destroy_process_group()