Skip to content

Commit

Permalink
Fix env check order by splitting TPU envs into separate files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687394503
  • Loading branch information
Google-ML-Automation committed Oct 18, 2024
1 parent bbcc3ee commit babfebf
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 107 deletions.
8 changes: 6 additions & 2 deletions jax/_src/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from .ompi_cluster import OmpiCluster as OmpiCluster
from .slurm_cluster import SlurmCluster as SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
from .k8s_cluster import K8sCluster as K8sCluster
from .cloud_gke_cluster import GkeTpuCluster as GkeTpuCluster
# This environment check will query the GCE metadata server, so we put it near
# the end of the list to avoid unnecessary queries.
from .cloud_gce_cluster import GceTpuCluster as GceTpuCluster
# This is an abstract environment, so it should be at the end of the list.
from .cloud_tpu_cluster import BaseTpuCluster as BaseTpuCluster
99 changes: 99 additions & 0 deletions jax/_src/clusters/cloud_gce_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import os
import re
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

logger = logging.getLogger(__name__)

# We use an arbitrarily chosen port for the coordinator since we cannot
# rely on communication to choose one in real time.
coordinator_port = '8476'

metadata_response_code_success = 200

def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')

retry_count = 0
retrySeconds = 0.500
api_resp = None

while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'}, timeout=60)
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)

if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text, api_resp.status_code


class GceTpuCluster(clusters.BaseTpuCluster):

name: str = "gcetpu"

@classmethod
def is_env_present(cls) -> bool:
if not running_in_cloud_tpu_vm:
logger.debug("Did not detect cloud TPU VM")
return False
metadata_response, metadata_code = get_metadata('agent-worker-number')
if metadata_code == metadata_response_code_success:
logger.debug("Gce Tpu Cluster detected for Jax Distributed System")
return True
else:
logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata")
logger.debug("Metadata code: %s", metadata_code)
logger.debug("Metadata response: %s", metadata_response)
return False

@staticmethod
def _get_process_id_in_slice() -> int:
return int(get_metadata('agent-worker-number')[0])

@staticmethod
def _get_worker_list_in_slice() -> list[str]:
workers = get_metadata('worker-network-endpoints')[0].split(',')
return [worker.split(':')[2] for worker in workers]

@staticmethod
def _get_tpu_env_value(key):
def get_tpu_env_value_from_metadata(key):
tpu_env_data = get_metadata('tpu-env')[0]
key_value_pairs = tpu_env_data.split('\n')
for key_value_pair in key_value_pairs:
# Typical line is MEGASCALE_NUM_SLICES: '2'
if ':' in key_value_pair:
row_key, value = re.split(':', key_value_pair, 1)
row_key = row_key.strip()
if row_key == key:
return value.strip().strip("'")
return None

value = os.environ.get(key, None)
return value if value is not None else get_tpu_env_value_from_metadata(key)
51 changes: 51 additions & 0 deletions jax/_src/clusters/cloud_gke_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import os
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

logger = logging.getLogger(__name__)


class GkeTpuCluster(clusters.BaseTpuCluster):

name: str = "gketpu"

@classmethod
def is_env_present(cls) -> bool:
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
logger.debug("Gke Tpu Cluster detected for Jax Distributed System")
return True
else:
if not running_in_cloud_tpu_vm:
logger.debug("Did not detect cloud TPU VM")
else:
logger.debug("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set")
return False

@staticmethod
def _get_process_id_in_slice() -> int:
return int(str(os.environ.get('TPU_WORKER_ID')))

@staticmethod
def _get_worker_list_in_slice() -> list[str]:
return str(os.environ.get('TPU_WORKER_HOSTNAMES', None)).split(',')

@staticmethod
def _get_tpu_env_value(key):
return os.environ.get(key, None)
122 changes: 17 additions & 105 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
from __future__ import annotations

import logging
import os
import re
import socket
import time
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm

logger = logging.getLogger(__name__)

Expand All @@ -30,49 +27,6 @@

metadata_response_code_success = 200

def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')

retry_count = 0
retrySeconds = 0.500
api_resp = None

while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'}, timeout=60)
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)

if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text, api_resp.status_code

def get_tpu_env_value(key):
def get_tpu_env_value_from_metadata(key):
tpu_env_data = get_metadata('tpu-env')[0]
key_value_pairs = tpu_env_data.split('\n')
for key_value_pair in key_value_pairs:
# Typical line is MEGASCALE_NUM_SLICES: '2'
if ':' in key_value_pair:
row_key, value = re.split(':', key_value_pair, 1)
row_key = row_key.strip()
if row_key == key:
return value.strip().strip("'")
return None

value = os.environ.get(key, None)
return value if value is not None else get_tpu_env_value_from_metadata(key)

def has_megascale_address():
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

class BaseTpuCluster(clusters.ClusterEnv):

name: str = "tpu"
Expand All @@ -94,11 +48,11 @@ def is_env_present(cls) -> bool:

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
if has_megascale_address():
if cls._has_megascale_address():
# For both GCE via QueuedResources and GKE via JobSet, the
# Megascale coordinator address is set as the host with process id = 0,
# so can be used as the jax distributed system coordinator.
coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
coordinator_address = cls._get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
else:
# For both GCE (QueuedResources and TPUVM create) and GKE via Job API,
# the workers lists are sorted by process ID so the first one can
Expand Down Expand Up @@ -147,20 +101,24 @@ def get_process_id(cls) -> int:
logger.debug("Process ID of %s generated by within-slice id %s and slice id %s", process_id, process_id_in_slice, slice_id)
return process_id

@staticmethod
def _get_num_slices() -> int:
if has_megascale_address():
return int(get_tpu_env_value('MEGASCALE_NUM_SLICES'))
@classmethod
def _get_num_slices(cls) -> int:
if cls._has_megascale_address():
return int(cls._get_tpu_env_value('MEGASCALE_NUM_SLICES'))
else:
return 1

@staticmethod
def _get_slice_id() -> int:
if has_megascale_address():
return int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
@classmethod
def _get_slice_id(cls) -> int:
if cls._has_megascale_address():
return int(cls._get_tpu_env_value('MEGASCALE_SLICE_ID'))
else:
return 0

@classmethod
def _has_megascale_address(cls):
return cls._get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

@staticmethod
def _get_process_id_in_slice() -> int:
"""Returns a process ID that is unique within slice."""
Expand All @@ -171,54 +129,8 @@ def _get_worker_list_in_slice() -> list[str]:
"""Returns a list of worker endpoints/hostnames within slice."""
raise NotImplementedError()

class GceTpuCluster(BaseTpuCluster):

name: str = "gcetpu"

@classmethod
def is_env_present(cls) -> bool:
if not running_in_cloud_tpu_vm:
logger.debug("Did not detect cloud TPU VM")
return False
metadata_response, metadata_code = get_metadata('agent-worker-number')
if metadata_code == metadata_response_code_success:
logger.debug("Gce Tpu Cluster detected for Jax Distributed System")
return True
else:
logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata")
logger.debug("Metadata code: %s", metadata_code)
logger.debug("Metadata response: %s", metadata_response)
return False

@staticmethod
def _get_process_id_in_slice() -> int:
return int(get_metadata('agent-worker-number')[0])

@staticmethod
def _get_worker_list_in_slice() -> list[str]:
workers = get_metadata('worker-network-endpoints')[0].split(',')
return [worker.split(':')[2] for worker in workers]

class GkeTpuCluster(BaseTpuCluster):

name: str = "gketpu"

@classmethod
def is_env_present(cls) -> bool:
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
logger.debug("Gke Tpu Cluster detected for Jax Distributed System")
return True
else:
if not running_in_cloud_tpu_vm:
logger.debug("Did not detect cloud TPU VM")
else:
logger.debug("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set")
return False

@staticmethod
def _get_process_id_in_slice() -> int:
return int(str(os.environ.get('TPU_WORKER_ID')))
def _get_tpu_env_value(key):
"""Returns the value of a TPU environment variable."""
raise NotImplementedError()

@staticmethod
def _get_worker_list_in_slice() -> list[str]:
return str(os.environ.get('TPU_WORKER_HOSTNAMES', None)).split(',')

0 comments on commit babfebf

Please sign in to comment.