Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local_coordinator bool option #20260

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')
Expand Down Expand Up @@ -68,7 +67,7 @@ def get_tpu_env_value_from_metadata(key):
def is_gce_env():
worker_number_string = get_metadata('agent-worker-number')
try:
worker_number = int(worker_number_string)
int(worker_number_string)
return True
except:
return False
Expand All @@ -82,7 +81,9 @@ def is_gke_env():
def get_gce_worker_endpoints() -> str:
return get_metadata('worker-network-endpoints').split(',')


class SingleSliceGceTpuCluster(clusters.ClusterEnv):

@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()
Expand All @@ -103,7 +104,9 @@ def get_process_id(cls) -> int:
def get_local_process_id(cls) -> int | None:
return None


class MultisliceGceTpuCluster(clusters.ClusterEnv):

@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_multislice_gce_env()
Expand Down Expand Up @@ -159,7 +162,9 @@ def _get_process_count_per_slice() -> int:
def _get_process_id_in_slice() -> int:
return int(get_metadata('agent-worker-number'))


class GkeTpuCluster(MultisliceGceTpuCluster):

# This class handles both single and multislice GKE as the environment
# variables are set the same in both cases.
@classmethod
Expand Down
28 changes: 12 additions & 16 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,13 @@ def __init_subclass__(cls, **kwargs):
cls._cluster_types.append(cls)

@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:
if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)
def auto_detect_distributed_params(
cls, coordinator_address: str | None, num_processes: int | None,
process_id: int | None, local_device_ids: Sequence[int] | None
) -> tuple[str | None, int | None, int | None, Sequence[int] | None]:
all_args = (coordinator_address, num_processes, process_id, local_device_ids)
if all(p is not None for p in all_args):
return coordinator_address, num_processes, process_id, local_device_ids
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
Expand All @@ -67,10 +62,11 @@ def auto_detect_unset_distributed_params(cls,
env.get_local_process_id() is not None):
local_device_ids = [env.get_local_process_id()] # type: ignore[list-item]
else:
logger.debug('Could not find a known environment for initializing distributed JAX. '
'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
return (coordinator_address, num_processes, process_id, local_device_ids)
# pytype: enable=bad-return-type
logger.debug(
"Could not find a known environment for initializing distributed JAX."
" Known environments: %s",
", ".join(e.__name__ for e in cls._cluster_types))
return coordinator_address, num_processes, process_id, local_device_ids

@classmethod
def is_env_present(cls) -> bool:
Expand Down
23 changes: 15 additions & 8 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,22 @@ def initialize(self,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
local_coordinator: bool = False):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))

if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]

(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids
)
)
if local_coordinator:
coordinator_address = '127.0.0.1:8080'
num_processes = 1
process_id = 0
else:
(coordinator_address, num_processes, process_id,
local_device_ids) = clusters.ClusterEnv.auto_detect_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids)

if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
Expand Down Expand Up @@ -114,7 +119,8 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
local_coordinator: bool = False):
"""Initializes the JAX distributed system.

Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
Expand Down Expand Up @@ -174,7 +180,8 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout)
local_device_ids, initialization_timeout,
local_coordinator)
atexit.register(shutdown)


Expand Down