diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py
index 55a3e3cc51..658d76fdb4 100644
--- a/thunder/tests/distributed/helper.py
+++ b/thunder/tests/distributed/helper.py
@@ -2,7 +2,6 @@
 from functools import partial
 from functools import wraps
 from typing import ClassVar, TYPE_CHECKING
-import inspect
 import math
 import os
 import sys
@@ -112,6 +111,11 @@ def world_size(self) -> int:
     def init_method(self):
         return f"{common_utils.FILE_SCHEMA}{self.file_name}"
 
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
     @classmethod
     def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False):
         assert not fake_pg, "Not yet supported here..."
@@ -130,14 +134,10 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False):
         local_rank = self.rank % torch.cuda.device_count()
         torch.cuda.set_device(local_rank)
         os.environ["LOCAL_RANK"] = str(local_rank)
-        if "destroy_process_group" in inspect.signature(self.run_test).parameters:
-            run_test_kwargs = {"destroy_process_group": False}
-        else:
-            run_test_kwargs = {}
 
         torch.distributed.barrier()
         try:
-            self.run_test(test_name, pipe, **run_test_kwargs)
+            self.run_test(test_name, pipe)
         except Exception:
             raise
         finally: