Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
---
# 1. Headless Service: Required for distributed pods to discover each other.
apiVersion: v1
kind: Service
metadata:
  # Quoted so that a substituted name which looks like a boolean, null or
  # number (for example "on" or "null") is still parsed as a string.
  name: "${JOB_NAME}"
  namespace: default
spec:
  # "None" makes the Service headless: no cluster IP is allocated.
  clusterIP: None
  selector:
    # Matches the job-name label carried by the Job's pod template.
    job-name: "${JOB_NAME}"
---
# 2. Indexed Job: Manages the distributed workload and queues via Kueue.
apiVersion: batch/v1
kind: Job
metadata:
  # String placeholders are quoted so that a substituted value which looks
  # like a boolean, null or number is still parsed as a string.
  name: "${JOB_NAME}"
  namespace: default
  labels:
    # Routes the Job to the Kueue LocalQueue of this name.
    kueue.x-k8s.io/queue-name: multislice-queue
spec:
  # Deliberately unquoted: both fields must render as integers once
  # ${TOTAL_PODS} has been substituted.
  completions: ${TOTAL_PODS}
  parallelism: ${TOTAL_PODS}
  completionMode: Indexed
  template:
    metadata:
      labels:
        job-name: "${JOB_NAME}"
    spec:
      # Uses the same name as the headless Service; the coordinator address
      # below is built as <pod>.<subdomain>.<namespace>.svc.cluster.local.
      subdomain: "${JOB_NAME}"
      restartPolicy: Never
      containers:
        - name: benchmark
          image: "${IMAGE}"

          # ---> IMPORTANT: UPDATE THIS COMMAND <---
          command:
            - "python3"
            - "/path/to/your/benchmark_script.py"
            - "--config_file=${FULL_CONFIG_PATH}"
            - "--output_directory=${OUTPUT_DIR}"

          # 3. Distributed Setup: Injecting JAX environment variables natively
          env:
            # DNS name of the pod with completion index 0.
            - name: JAX_COORDINATOR_ADDRESS
              value: "${JOB_NAME}-0.${JOB_NAME}.default.svc.cluster.local"
            - name: JAX_COORDINATOR_PORT
              value: "1234"
            # Env values must be strings, hence the quotes around the count.
            - name: JAX_PROCESS_COUNT
              value: "${TOTAL_PODS}"
            # Each pod reads its own completion index from the annotation
            # on the pod.
            - name: JAX_PROCESS_INDEX
              valueFrom:
                fieldRef:
                  fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']

          # 4. Resource constraint tailored to your cluster
          resources:
            requests:
              cpu: "1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
# ResourceFlavor referenced by name from the ClusterQueue's resource group.
apiVersion: kueue.x-k8s.io/v1beta1
kind: ResourceFlavor
metadata:
  name: spot-flavor
---
# ClusterQueue: cpu and memory quota, provided through spot-flavor.
apiVersion: kueue.x-k8s.io/v1beta1
kind: ClusterQueue
metadata:
  name: xpk-cluster-queue
spec:
  # Empty selector: allows jobs from any namespace.
  namespaceSelector: {}
  resourceGroups:
    - coveredResources:
        - cpu
        - memory
      flavors:
        - name: spot-flavor
          resources:
            - name: cpu
              # Set artificially high to allow scaling.
              nominalQuota: 1000
            - name: memory
              nominalQuota: 4000Gi
---
# LocalQueue: namespaced entry point that forwards to xpk-cluster-queue.
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
  # XPK strictly looks for this name by default.
  name: multislice-queue
  namespace: default
spec:
  clusterQueue: xpk-cluster-queue
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Name reported for the whole test-suite run.
# Assumed topology: n2-standard-32-32 (32 machines) x 16 replicas.
suite_name: 'llama-70b_replicas_16'
num_repeats: 1

mesh_config:
  mesh_axes:
    - replica
    - model
  # Should match reference_sharding_path.
  ici_parallelism:
    replica: 1
    model: 32
  dcn_parallelism:
    replica: 16

# Note: the checkpoint_config field is not specified.

benchmarks:
  - generator: 'orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark'
    options:
      # --- Generator options ---
      # Keys must match the attributes of the
      # `RestoreAndBroadcastBenchmarkOptions` class associated with the
      # `RestoreAndBroadcastBenchmark` generator.
      async_enabled: true
      use_ocdbt: true
      use_zarr3: true
      use_replica_parallel: false
      use_compression: true
      reference_checkpoint_path: 'gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt'
      reference_sharding_path: 'gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json'
      use_load_and_broadcast: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Name reported for the whole test-suite run.
# Assumed topology: n2-standard-2-64 (64 machines) x 2 replicas.
suite_name: 'llama-70b_replicas_2'
num_repeats: 20

mesh_config:
  mesh_axes:
    - replica
    - model
  # Should match reference_sharding_path.
  ici_parallelism:
    replica: 1
    model: 64
  dcn_parallelism:
    replica: 2

# Note: the checkpoint_config field is not specified.

benchmarks:
  - generator: 'orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark'
    options:
      # --- Generator options ---
      # Keys must match the attributes of the
      # `RestoreAndBroadcastBenchmarkOptions` class associated with the
      # `RestoreAndBroadcastBenchmark` generator.
      async_enabled: true
      use_ocdbt: true
      use_zarr3: true
      use_replica_parallel: false
      use_compression: true
      reference_checkpoint_path: 'gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt'
      reference_sharding_path: 'gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json'
      use_load_and_broadcast: true
      # Tracing option is left commented out:
      # enable_trace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# The name for the entire test suite run.
# Assumes n2-standard-2-64 (64 machines) X 2 replicas
suite_name: "llama-70b_replicas_2_no_broadcast"
num_repeats: 20

mesh_config:
  mesh_axes: ["replica", "model"]
  # Should match reference_sharding_path.
  ici_parallelism: {"replica": 1, "model": 64}
  dcn_parallelism: {"replica": 2}

# Note: checkpoint_config field not specified.

benchmarks:
  - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark"
    options:
      # --- Generator Options ---
      # These keys must match the attributes of the
      # `RestoreAndBroadcastBenchmarkOptions` class associated with the
      # `RestoreAndBroadcastBenchmark` generator.
      async_enabled: true
      use_ocdbt: true
      use_zarr3: true
      use_replica_parallel: false
      use_compression: true
      reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
      reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
      # Canonical lowercase boolean, consistent with the other options here.
      use_load_and_broadcast: false
      # enable_trace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Name reported for the whole test-suite run.
# Assumed topology: v5p-128 (64 devices) x 4 replicas.
suite_name: 'llama-70b_replicas_4'
num_repeats: 20

mesh_config:
  mesh_axes:
    - replica
    - model
  # Should match reference_sharding_path.
  ici_parallelism:
    replica: 1
    model: 64
  dcn_parallelism:
    replica: 4

# Note: the checkpoint_config field is not specified.

benchmarks:
  - generator: 'orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark'
    options:
      # --- Generator options ---
      # Keys must match the attributes of the
      # `RestoreAndBroadcastBenchmarkOptions` class associated with the
      # `RestoreAndBroadcastBenchmark` generator.
      async_enabled: true
      use_ocdbt: true
      use_zarr3: true
      use_replica_parallel: false
      use_compression: true
      reference_checkpoint_path: 'gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt'
      reference_sharding_path: 'gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json'
      use_load_and_broadcast: true
      # NOTE(review): tracing is switched on in this config; confirm that
      # this is intended for all num_repeats runs.
      enable_trace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Name reported for the whole test-suite run.
# Assumed topology: v5p-128 (64 devices) x 4 replicas.
suite_name: 'llama-70b_replicas_4_no_broadcast'
num_repeats: 20

mesh_config:
  mesh_axes:
    - replica
    - model
  # Should match reference_sharding_path.
  ici_parallelism:
    replica: 1
    model: 64
  dcn_parallelism:
    replica: 4

# Note: the checkpoint_config field is not specified.

benchmarks:
  - generator: 'orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark'
    options:
      # --- Generator options ---
      # Keys must match the attributes of the
      # `RestoreAndBroadcastBenchmarkOptions` class associated with the
      # `RestoreAndBroadcastBenchmark` generator.
      async_enabled: true
      use_ocdbt: true
      use_zarr3: true
      use_replica_parallel: false
      use_compression: true
      reference_checkpoint_path: 'gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt'
      reference_sharding_path: 'gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json'
      use_load_and_broadcast: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
# Minimal smoke-test Job submitted through the multislice-queue LocalQueue.
apiVersion: batch/v1
kind: Job
metadata:
  name: gke-test-job
  namespace: default
  labels:
    # Presence of this label makes Kueue intercept and manage the Job.
    kueue.x-k8s.io/queue-name: multislice-queue
spec:
  template:
    spec:
      restartPolicy: Never
      containers:
        - name: test-container
          image: ubuntu
          # Prints a greeting, then sleeps for 30 seconds.
          command:
            - /bin/sh
            - -c
            - "echo 'Hello from GKE and Kueue!'; sleep 30"
          resources:
            requests:
              cpu: "1"
              memory: 1Gi
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh:
logging.info('Creating hybrid mesh.')
dcn_shape = [dcn_parallelism.get(axis, 1) for axis in config.mesh_axes]

if jax.default_backend() == 'cpu':
devices = jax.devices()
# Sort devices by process index to ensure a predictable global grid
devices = sorted(devices, key=lambda d: d.process_index)
global_shape = tuple(d * i for d, i in zip(dcn_shape, ici_shape))
devices_array = np.array(devices).reshape(global_shape)
logging.info(
'Creating CPU-only hybrid mesh with axes: %s',
{axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)},
)
return jax.sharding.Mesh(devices_array, config.mesh_axes)

# --- Validation ---
if config.process_is_granule:
process_count = jax.process_count()
Expand Down
Loading
Loading