Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added pre-commit-run and enabled it on GitHub actions #293

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ on:
push:
branches:
- main
permissions:
contents: write
pull_request:
branches:
- main

jobs:
deploy:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
with:
python-version: 3.x
- run: pip install -r docs/requirements.txt
- run: mkdocs gh-deploy --force
python-version: '3.10'
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
20 changes: 20 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: docs
on:
push:
branches:
- main

permissions:
contents: write

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
with:
python-version: '3.10'
- run: pip install -r docs/requirements.txt
- run: mkdocs gh-deploy --force
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Install the pre-commit hooks below with
# 'pre-commit install'

# Auto-update the version of the hooks with
# 'pre-commit autoupdate'

# Run the hooks on all files with
# 'pre-commit run --all'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0
hooks:
- id: check-ast
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
# only include python files
files: \.py$
- id: debug-statements
# only include python files
files: \.py$
- id: trailing-whitespace
# only include python files
files: \.py$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
hooks:
- id: ruff
5 changes: 2 additions & 3 deletions examples/block_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import jax
import jax.numpy as jnp
from jax import lax
from jax import random
import jax_triton as jt
from jax_triton import pallas as pl
Expand Down Expand Up @@ -178,8 +177,8 @@ def main(unused_argv):
k = random.normal(k_key, shape, dtype=dtype)
v = random.normal(v_key, shape, dtype=dtype)

o = mha(q, k, v).block_until_ready()
o_ref = mha_reference(q, k, v).block_until_ready()
mha(q, k, v).block_until_ready()
mha_reference(q, k, v).block_until_ready()

if __name__ == "__main__":
from absl import app
Expand Down
5 changes: 2 additions & 3 deletions examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jax import random
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

import jax_triton as jt
Expand Down Expand Up @@ -87,7 +86,7 @@ def tree_unflatten(cls, data, xs):
return BlockELL(blocks, blocks_per_row, indices, shape=shape)

def _validate(self):
nblocks, n, m = self.blocks.shape
_nblocks, n, m = self.blocks.shape
nrows = self.blocks_per_row.shape[0]
assert self.indices.shape[0] == nrows
assert len(self.shape) == 2
Expand Down Expand Up @@ -168,7 +167,7 @@ def sdd_matmul(x_ell, y, num_warps: int = 8, num_stages: int = 3, bn: int = 64,
grid = (jt.cdiv(m, bm), jt.cdiv(n, bn))

kernel = functools.partial(sdd_kernel, bm=bm, bn=bn)
out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x.dtype)
out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x_ell.dtype)
return pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=grid, out_shape=out_shape,
debug=debug)(x_ell.blocks, x_ell.indices,
Expand Down
19 changes: 9 additions & 10 deletions examples/pallas/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
import functools
import timeit

from typing import Optional, Tuple

import jax.numpy as jnp
from jax import random
import jax
from jax import lax
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np

import jax_triton as jt
Expand Down Expand Up @@ -188,13 +185,15 @@ def main(unused_argv):
x = random.normal(x_key, (batch_size, feature_size), dtype)
h = random.normal(h_key, (batch_size, hidden_size), dtype)
c = random.normal(c_key, (batch_size, hidden_size), dtype)
lstm_cell = jax.jit(functools.partial(lstm_cell,
block_batch=block_batch,
block_hidden=block_hidden,
block_features=block_features,
num_warps=num_warps,
num_stages=num_stages))
y, c_next = jax.block_until_ready(lstm_cell(weights, x, h, c))
lstm_cell_fn = jax.jit(functools.partial(
lstm_cell,
block_batch=block_batch,
block_hidden=block_hidden,
block_features=block_features,
num_warps=num_warps,
num_stages=num_stages,
))
y, c_next = jax.block_until_ready(lstm_cell_fn(weights, x, h, c))
y_ref, c_next_ref = lstm_cell_reference(weights, x, h, c)
np.testing.assert_allclose(y, y_ref, atol=0.05, rtol=0.05)
np.testing.assert_allclose(c_next, c_next_ref, atol=0.05, rtol=0.05)
Expand Down
11 changes: 11 additions & 0 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
# limitations under the License.

"""Library for JAX-Triton integrations."""

__all__ = [
"utils",
"triton_call",
"cdiv",
"next_power_of_2",
"strides_from_shape",
"__version__",
"__version_info__",
]

import jaxlib
from jax._src.lib import gpu_triton
from jax_triton import utils
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
jax.nn.sigmoid = sigmoid
del sigmoid, oryx, jax

from jax_triton.experimental.fusion.lowering import jit
from jax_triton.experimental.fusion.lowering import jit as jit
3 changes: 1 addition & 2 deletions jax_triton/experimental/fusion/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from jax import lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import core
from jax._src import util
from jax._src.lax.control_flow import for_loop
Expand Down Expand Up @@ -251,6 +250,7 @@ def _matmul_elementwise_lowering_rule(x, y, *args, left_ops, right_ops, out_ops,
bias, = args
else:
bias = None
del bias # TODO(sharadmv): Please fix or remove `bias` above.
lhs_dim, rhs_dim = contract_dims
M, N, K = x.shape[1 - lhs_dim], y.shape[1 - rhs_dim], x.shape[lhs_dim]
assert x.shape[lhs_dim] == y.shape[rhs_dim]
Expand Down Expand Up @@ -340,4 +340,3 @@ def _dot_general_lowering_rule(x, y, dimension_numbers, **_):
out_ops=[], contract_dims=(lhs_dim,
rhs_dim))
lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule

2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import itertools as it

from typing import Any, Callable, Dict, List, Set, Tuple, Union
from typing import Any, Callable, List, Tuple, Union

from jax._src import core as jax_core
import jax.numpy as jnp
Expand Down
2 changes: 0 additions & 2 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,6 @@ def compile_ttir_to_hsaco_inplace(
amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)

if hip_options.debug:
print(x)
name = metadata["name"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
Expand Down
7 changes: 6 additions & 1 deletion jax_triton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
# limitations under the License.

"""Contains utilities for writing and calling Triton functions."""


__all__ = ["cdiv", "strides_from_shape", "next_power_of_2"]


from jax.experimental.pallas import cdiv
from jax.experimental.pallas import strides_from_shape
from jax.experimental.pallas import next_power_of_2
from jax.experimental.pallas import next_power_of_2
22 changes: 22 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,25 @@ packages = ["jax_triton"]

[tool.setuptools.dynamic]
version = {attr = "jax_triton.version.__version__"}

[tool.ruff]
preview = true
exclude = [
".git",
"build",
"__pycache__",
"*.ipynb",
]
line-length = 88
indent-width = 2
target-version = "py310"

[tool.ruff.lint]
ignore = [
# Do not assign a `lambda` expression, use a `def`
"E731",
# Module level import not at top of file
"E402",
# Ambiguous variable name
"E741",
]