diff --git a/submitit/helpers.py b/submitit/helpers.py index 23cfeab..374b551 100644 --- a/submitit/helpers.py +++ b/submitit/helpers.py @@ -293,10 +293,16 @@ def monitor_jobs( @contextlib.contextmanager -def clean_env() -> tp.Iterator[None]: +def clean_env(extra_names: tp.Sequence[str] = ()) -> tp.Iterator[None]: """Removes slurm and submitit related environment variables so as to avoid interferences when submiting a new job from a job. + Parameters + ---------- + extra_names: Sequence[str] + Additional environment variables to hide inside the context, + e.g. TRITON_CACHE_DIR and TORCHINDUCTOR_CACHE_DIR when using torch.compile. + Note ---- A slurm job submitted from within a slurm job inherits some of its attributes, which may @@ -312,7 +318,11 @@ def clean_env() -> tp.Iterator[None]: cluster_env = { x: os.environ.pop(x) for x in os.environ - if x.startswith(("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_")) or x in distrib_names + if ( + x.startswith(("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_")) + or x in distrib_names + or x in extra_names + ) } try: yield diff --git a/submitit/test_helpers.py b/submitit/test_helpers.py index c23722e..c7bab72 100644 --- a/submitit/test_helpers.py +++ b/submitit/test_helpers.py @@ -144,3 +144,8 @@ def test_clean_env() -> None: assert not _get_env() assert len(_get_env()) == len(base) + 2 assert _get_env() == base + + with utils.environment_variables(MASTER_PORT=42, BLABLA=314): + with helpers.clean_env(extra_names=("BLABLA",)): + assert "MASTER_PORT" not in os.environ + assert "BLABLA" not in os.environ