# Bumped jax-triton version along with JAX/Triton requirements #382
# Workflow file for this run
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# CI workflow: lints the repository with pre-commit and runs the test suite
# under tests/ on a GPU runner.
name: ci

# Run for pushes to main and for pull requests targeting main.
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

# Restrict the workflow token to read-only repository contents.
permissions:
  contents: read # to fetch code

# One active run per workflow and branch/PR: a newer run in the same group
# cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  # Runs the repository's pre-commit hooks on a GitHub-hosted runner.
  # Actions are pinned to full commit SHAs; the trailing "ratchet:" comments
  # record the tag each SHA was resolved from.
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
        with:
          # Quoted so YAML keeps the string "3.10" instead of the float 3.1.
          python-version: '3.10'
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1

  # Installs released JAX and Torch, then runs pytest inside a digest-pinned
  # Ubuntu container.
  # NOTE(review): the runner label suggests a 4x L4 GPU machine; the job calls
  # nvidia-smi, so a visible NVIDIA driver is required -- confirm the runner
  # pool exposes GPUs to the container.
  test:
    runs-on: linux-x86-g2-48-l4-4gpu
    container:
      image: index.docker.io/library/ubuntu@sha256:0e5e4a57c2499249aafc3b40fcd541e9a456aab7296681a3994d631587203f97 # ratchet:ubuntu:22.04
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Setup Released JAX and Torch
        run: |
          # triton needs a compiler
          apt-get update && apt-get install -y build-essential
          pip install torch
          pip install -U "jax[cuda12]"
          pip install pytest
      - name: Test JAX Triton
        run: |
          echo "Running JAX Triton GPU Tests"
          nvidia-smi
          pip install .
          # Need newer ml-dtypes because we install newer numpy
          pip install --upgrade ml-dtypes
          pytest -v --tb=short tests/