Add CI workflow to check JAX distributed initialize within K8s jobsets #24197

Open
wants to merge 1 commit into
base: main
105 changes: 105 additions & 0 deletions .github/workflows/k8s.yaml
@@ -0,0 +1,105 @@
name: Distributed run using K8s Jobset

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

permissions:
  contents: read
  pull-requests: read
  actions: write  # to cancel previous workflows

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -ex -o pipefail {0}

jobs:

  distributed-initialize:
    runs-on: ubuntu-22.04
    outputs:
      TAG: ${{ steps.metadata.outputs.tags }}
    steps:
      - name: Checkout
        uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871  # ratchet:actions/checkout@v4
        with:
          path: jax

      - name: Start Minikube cluster
        uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533  # ratchet:medyagh/setup-minikube@v0.0.18

      - name: Install K8s Jobset
        run: |
          kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml

      - name: Build image
        run: |
          cat > Dockerfile <<EOF
          FROM ubuntu:22.04
          ADD jax /opt/jax
          RUN apt-get update && apt-get install -y python-is-python3 python3-pip
          RUN pip install -e /opt/jax[k8s]
          EOF

          minikube image build -t local/jax:latest .

      - name: Create service account for K8s job introspection
        run: |
          kubectl apply -f jax/examples/k8s/svc-acct.yaml

      - name: Prepare test job
        run: |
          export VERSION=v4.44.3
          export BINARY=yq_linux_amd64
          wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq

          cat jax/examples/k8s/example.yaml |\
            yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
            yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
            tee example.yaml

      - name: Submit test job
        run: |
          kubectl apply -f example.yaml

      - name: Check job status
        shell: bash -e -o pipefail {0}
        run: |
          while true; do
            # Quote the yq expression so the shell never treats [0] as a glob.
            status=$(kubectl get jobset example -o yaml | yq '.status.conditions[0].type')
            timestamp=$(date +"%Y-%m-%d %H:%M:%S")
            echo "[$timestamp] Checking job status..."

            if [ "$status" == "Completed" ]; then
              echo "[$timestamp] Job has completed successfully!"
              exit 0
            elif [ "$status" == "Failed" ]; then
              echo "[$timestamp] Job has failed!"
              exit 1
            else
              echo "[$timestamp] Job is still running. Current pod status:"
              kubectl get pods --no-headers
              echo "[$timestamp] Waiting for 3 seconds before checking again..."
              sleep 3
            fi
          done

      - name: Examine individual pod outputs
        if: "!cancelled()"
        run: |
          set +x
          # read -r keeps pod names verbatim (no backslash processing).
          kubectl get pods --no-headers | awk '{print $1}' | while read -r pod; do
            echo "========================================"
            echo "Pod $pod output:"
            echo "----------------------------------------"
            kubectl logs "$pod"
            echo "========================================"
          done
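Review note: the "Check job status" step polls until the JobSet reports `Completed` or `Failed`, so a job that never reaches either condition would keep looping until the workflow-level timeout. The same polling logic with an explicit deadline might look like the Python sketch below. `wait_for_status`, `get_status`, and the timeout values are hypothetical names chosen for illustration and are not part of this PR.

```python
import time
from typing import Callable


def wait_for_status(
    get_status: Callable[[], str],
    timeout_s: float = 600.0,
    interval_s: float = 3.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll get_status() until it returns "Completed" or "Failed".

    Returns True on "Completed" and False on "Failed". Raises TimeoutError
    if neither status is seen before the deadline passes.
    """
    deadline = clock() + timeout_s
    while True:
        status = get_status()
        if status == "Completed":
            return True
        if status == "Failed":
            return False
        if clock() >= deadline:
            raise TimeoutError(
                f"no terminal status after {timeout_s}s (last: {status!r})"
            )
        sleep(interval_s)


if __name__ == "__main__":
    # Simulated status sequence standing in for the output of
    # `kubectl get jobset example -o yaml | yq '.status.conditions[0].type'`.
    statuses = iter(["null", "null", "Completed"])
    print(wait_for_status(lambda: next(statuses), sleep=lambda _: None))
```

Injecting `sleep` and `clock` keeps the loop testable without real waiting; in CI the defaults would be used as-is.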
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -15,6 +15,7 @@ repos:
   - id: check-merge-conflict
   - id: check-toml
   - id: check-yaml
+    exclude: examples/k8s/svc-acct.yaml
   - id: end-of-file-fixer
     # only include python files
     files: \.py$
40 changes: 40 additions & 0 deletions examples/k8s/example.yaml
@@ -0,0 +1,40 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: example
spec:
  replicatedJobs:
    - name: workers
      template:
        spec:
          parallelism: 2
          completions: 2
          backoffLimit: 0
          template:
            spec:
              serviceAccountName: training-job-sa
              restartPolicy: Never
              imagePullSecrets:
                - name: null
              containers:
                - name: main
                  image: PLACEHOLDER
                  imagePullPolicy: IfNotPresent
                  resources:
                    requests:
                      cpu: 900m
                      nvidia.com/gpu: null
                    limits:
                      cpu: 1
                      nvidia.com/gpu: null
                  command:
                    - python
                  args:
                    - -c
                    - |
                      import jax
                      jax.distributed.initialize()
                      print(jax.devices())
                      print(jax.local_devices())
                      assert jax.process_count() > 1
                      assert len(jax.devices()) > len(jax.local_devices())
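Review note: the CI workflow rewrites two fields of this manifest with `yq` before submitting it, replacing the `PLACEHOLDER` image and forcing `imagePullPolicy: Never` so the locally built image is used. The same edit on an already-parsed manifest could be sketched in Python as follows. `patch_image` is a hypothetical helper, and the dict is a trimmed stand-in for the parsed YAML rather than the full manifest.

```python
import copy


def patch_image(jobset: dict, image: str, pull_policy: str = "Never") -> dict:
    """Return a copy of a JobSet manifest with the first container's image replaced."""
    patched = copy.deepcopy(jobset)  # leave the caller's manifest untouched
    # Same path as the yq expression:
    # .spec.replicatedJobs[0].template.spec.template.spec.containers[0]
    container = (
        patched["spec"]["replicatedJobs"][0]["template"]["spec"]
        ["template"]["spec"]["containers"][0]
    )
    container["image"] = image
    container["imagePullPolicy"] = pull_policy
    return patched


# Trimmed stand-in for examples/k8s/example.yaml after parsing (assumption:
# only the fields on the patched path are kept here).
manifest = {
    "spec": {
        "replicatedJobs": [{
            "name": "workers",
            "template": {"spec": {"template": {"spec": {"containers": [{
                "name": "main",
                "image": "PLACEHOLDER",
                "imagePullPolicy": "IfNotPresent",
            }]}}}},
        }],
    },
}

patched = patch_image(manifest, "local/jax:latest")
container = patched["spec"]["replicatedJobs"][0]["template"]["spec"]["template"]["spec"]["containers"][0]
print(container["image"], container["imagePullPolicy"])
```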
31 changes: 31 additions & 0 deletions examples/k8s/svc-acct.yaml
@@ -0,0 +1,31 @@
apiVersion: v1
kind: ServiceAccount
metadata:
  name: training-job-sa
  namespace: default
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: pod-reader
rules:
  - apiGroups: [""]
    resources: ["pods"]
    verbs: ["get", "list", "watch"]
  - apiGroups: ["batch"]
    resources: ["jobs"]
    verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: pod-reader-binding
  namespace: default
subjects:
  - kind: ServiceAccount
    name: training-job-sa
    namespace: default
roleRef:
  kind: Role
  name: pod-reader
  apiGroup: rbac.authorization.k8s.io
20 changes: 11 additions & 9 deletions jax/_src/clusters/k8s_cluster.py
@@ -35,15 +35,17 @@ def is_env_present(cls) -> bool:
       try:
         import kubernetes as k8s  # pytype: disable=import-error
       except ImportError as e:
-        warnings.warn(textwrap.fill(
-          "Kubernetes environment detected, but the `kubernetes` package is "
-          "not installed to enable automatic bootstrapping in this "
-          "environment. To enable automatic boostrapping, please install "
-          "jax with the [k8s] extra. For example:"
-          " pip install jax[k8s]"
-          " OR"
-          " pip install jax[k8s,<MORE-EXTRAS...>]"
-        ))
+        warnings.warn(
+          '\n'.join([
+            textwrap.fill(
+              "Kubernetes environment detected, but the `kubernetes` package "
+              "is not installed to enable automatic bootstrapping in this "
+              "environment. To enable automatic bootstrapping, please install "
+              "jax with the [k8s] extra. For example:"),
+            " pip install jax[k8s]",
+            " pip install jax[k8s,<MORE-EXTRAS...>]",
+          ])
+        )
         return False
 
       k8s.config.load_incluster_config()
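Review note: the change to the warning matters because `textwrap.fill` re-wraps its whole input as one paragraph, so the `pip install` lines in the old message were folded into the prose. Wrapping only the prose and joining the command lines with explicit newlines keeps each command on its own line. The standalone sketch below shows the difference with the standard library only; the variable names and the shortened message text are illustrative assumptions, not the exact strings from the PR.

```python
import textwrap

intro = (
    "Kubernetes environment detected, but the `kubernetes` package is "
    "not installed to enable automatic bootstrapping in this environment. "
    "To enable automatic bootstrapping, please install jax with the [k8s] "
    "extra. For example:"
)
commands = [
    "  pip install jax[k8s]",
    "  pip install jax[k8s,<MORE-EXTRAS...>]",
]

# Old approach: the commands are concatenated onto the prose and the whole
# string is re-wrapped, so they lose their own lines and indentation.
old_message = textwrap.fill(intro + "".join(commands))

# New approach: wrap only the prose, then append each command on its own line.
new_message = "\n".join([textwrap.fill(intro), *commands])

print(old_message)
print("-" * 40)
print(new_message)
```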