diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py index 324533c618c5..7fceb07cda5a 100644 --- a/ignite/distributed/launcher.py +++ b/ignite/distributed/launcher.py @@ -232,6 +232,7 @@ def __init__( self.backend = backend self._spawn_params = None + self._pg_init_kwargs = None self.init_method = init_method if self.backend is not None: @@ -239,6 +240,9 @@ def __init__( self._spawn_params = self._setup_spawn_params( nproc_per_node, nnodes, node_rank, master_addr, master_port, init_method, **spawn_kwargs ) + else: + self._pg_init_kwargs = spawn_kwargs + # The logger will be setup after the idist.initialize() call self._logger = None @@ -319,7 +323,7 @@ def training(local_rank, config, **kwargs): def __enter__(self) -> "Parallel": if self.backend is not None and self._spawn_params is None: - idist.initialize(self.backend, init_method=self.init_method) + idist.initialize(self.backend, init_method=self.init_method, **(self._pg_init_kwargs or dict())) # The logger can be setup from now since idist.initialize() has been called (if needed) self._logger = setup_logger(__name__ + "." + self.__class__.__name__)