Skip to content

Commit

Permalink
Add CI workflow for JAX distributed initialize in K8s jobsets
Browse files Browse the repository at this point in the history
  • Loading branch information
yhtang committed Oct 8, 2024
1 parent e5fa965 commit e71ae50
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 9 deletions.
105 changes: 105 additions & 0 deletions .github/workflows/k8s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# CI workflow: spins up a single-node Minikube cluster, installs the JobSet
# controller, builds a local image containing this checkout of JAX, and runs
# examples/k8s/example.yaml to verify that `jax.distributed.initialize()`
# works across multiple pods.
name: Distributed run using K8s Jobset

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

permissions:
  contents: read
  pull-requests: read
  actions: write  # to cancel previous workflows

# Only one run per PR branch (or ref, for pushes); newer runs cancel older ones.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

defaults:
  run:
    # -e: stop on first error, -x: trace commands, pipefail: fail on any
    # failing stage of a pipeline.
    shell: bash -ex -o pipefail {0}

jobs:

  distributed-initialize:
    runs-on: ubuntu-22.04
    steps:
      # Check the repository out into ./jax so it can be ADDed to the image.
      - name: Checkout
        uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871  # ratchet:actions/checkout@v4
        with:
          path: jax

      - name: Start Minikube cluster
        uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533  # ratchet:medyagh/setup-minikube@v0.0.18

      - name: Install K8s Jobset
        run: |
          kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml

      # Build the image inside Minikube so the pods can use it without a
      # registry (the test job sets imagePullPolicy to "Never").
      - name: Build image
        run: |
          cat > Dockerfile <<EOF
          FROM ubuntu:22.04
          ADD jax /opt/jax
          RUN apt-get update && apt-get install -y python-is-python3 python3-pip
          RUN pip install -e /opt/jax[k8s]
          EOF
          minikube image build -t local/jax:latest .

      - name: Create service account for K8s job introspection
        run: |
          kubectl apply -f jax/examples/k8s/svc-acct.yaml

      # Download a pinned yq and use it to point the example manifest at the
      # locally built image; the patched manifest is written to ./example.yaml.
      - name: Prepare test job
        run: |
          export VERSION=v4.44.3
          export BINARY=yq_linux_amd64
          wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq
          cat jax/examples/k8s/example.yaml |\
            yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
            yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
            tee example.yaml

      - name: Submit test job
        run: |
          kubectl apply -f example.yaml

      # Poll the JobSet's first status condition every 3 seconds until it
      # reports Completed (success) or Failed (failure).
      - name: Check job status
        # No -x here: tracing every poll iteration would drown the log.
        shell: bash -e -o pipefail {0}
        # Bound the polling loop so a JobSet that never reaches a terminal
        # condition fails this step instead of holding the runner.
        timeout-minutes: 10
        run: |
          while true; do
            # Quote the yq expression so the shell never glob-expands "[0]".
            status=$(kubectl get jobset example -o yaml | yq '.status.conditions[0].type')
            timestamp=$(date +"%Y-%m-%d %H:%M:%S")
            echo "[$timestamp] Checking job status..."
            if [ "$status" == "Completed" ]; then
              echo "[$timestamp] Job has completed successfully!"
              exit 0
            elif [ "$status" == "Failed" ]; then
              echo "[$timestamp] Job has failed!"
              exit 1
            else
              echo "[$timestamp] Job is still running. Current pod status:"
              kubectl get pods --no-headers
              echo "[$timestamp] Waiting for 3 seconds before checking again..."
              sleep 3
            fi
          done

      # Runs even when an earlier step failed, so pod logs are available for
      # debugging; skipped only when the run was cancelled.
      - name: Examine individual pod outputs
        if: "!cancelled()"
        run: |
          set +x
          kubectl get pods --no-headers | awk '{print $1}' | while read -r pod; do
            echo "========================================"
            echo "Pod $pod output:"
            echo "----------------------------------------"
            kubectl logs "$pod"
            echo "========================================"
          done
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ repos:
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
exclude: examples/k8s/svc-acct.yaml
- id: end-of-file-fixer
# only include python files
files: \.py$
Expand Down
40 changes: 40 additions & 0 deletions examples/k8s/example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Example JobSet that launches two worker pods and checks that
# `jax.distributed.initialize()` connects them into one multi-process job.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: example
spec:
  replicatedJobs:
    - name: workers
      template:
        spec:
          # Two pods run in parallel and both must finish; no retries.
          parallelism: 2
          completions: 2
          backoffLimit: 0
          template:
            spec:
              # Service account defined in examples/k8s/svc-acct.yaml.
              serviceAccountName: training-job-sa
              restartPolicy: Never
              # Placeholder — presumably set to a registry pull secret when
              # the image is private; confirm for your cluster.
              imagePullSecrets:
                - name: null
              containers:
                - name: main
                  # Placeholder — substitute an image with jax[k8s] installed.
                  image: PLACEHOLDER
                  imagePullPolicy: IfNotPresent
                  resources:
                    requests:
                      cpu: 900m
                      # Placeholder — presumably a GPU count on GPU nodes.
                      nvidia.com/gpu: null
                    limits:
                      cpu: 1
                      nvidia.com/gpu: null
                  command:
                    - python
                  args:
                    - -c
                    # Inline test script: initialize, print the device
                    # lists, then assert more than one process is present
                    # and that global devices outnumber local devices.
                    - |
                      import jax
                      jax.distributed.initialize()
                      print(jax.devices())
                      print(jax.local_devices())
                      assert jax.process_count() > 1
                      assert len(jax.devices()) > len(jax.local_devices())
31 changes: 31 additions & 0 deletions examples/k8s/svc-acct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Service account used by the training pods, plus the RBAC objects that give
# it read-only access to pods and batch jobs in the `default` namespace.
apiVersion: v1
kind: ServiceAccount
metadata:
  name: training-job-sa
  namespace: default
---
# Read-only access to pods (core API group) and jobs (batch API group).
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: pod-reader
  # Pinned explicitly: a RoleBinding can only reference a Role in its own
  # namespace, and the binding below lives in `default`.
  namespace: default
rules:
  - apiGroups: [""]
    resources: ["pods"]
    verbs: ["get", "list", "watch"]
  - apiGroups: ["batch"]
    resources: ["jobs"]
    verbs: ["get", "list", "watch"]
---
# Grants the pod-reader Role to the training-job-sa service account.
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: pod-reader-binding
  namespace: default
subjects:
  - kind: ServiceAccount
    name: training-job-sa
    namespace: default
roleRef:
  kind: Role
  name: pod-reader
  apiGroup: rbac.authorization.k8s.io
20 changes: 11 additions & 9 deletions jax/_src/clusters/k8s_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ def is_env_present(cls) -> bool:
try:
import kubernetes as k8s # pytype: disable=import-error
except ImportError as e:
warnings.warn(textwrap.fill(
"Kubernetes environment detected, but the `kubernetes` package is "
"not installed to enable automatic bootstrapping in this "
"environment. To enable automatic boostrapping, please install "
"jax with the [k8s] extra. For example:"
" pip install jax[k8s]"
" OR"
" pip install jax[k8s,<MORE-EXTRAS...>]"
))
warnings.warn(
'\n'.join([
textwrap.fill(
"Kubernetes environment detected, but the `kubernetes` package "
"is not installed to enable automatic bootstrapping in this "
"environment. To enable automatic boostrapping, please install "
"jax with the [k8s] extra. For example:"),
" pip install jax[k8s]",
" pip install jax[k8s,<MORE-EXTRAS...>]",
])
)
return False

k8s.config.load_incluster_config()
Expand Down

0 comments on commit e71ae50

Please sign in to comment.