From a72a159c46e8b7937c1b0c7458ff595e7fb491a0 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 16 Sep 2024 16:11:21 +0100
Subject: [PATCH] Upgraded the sources to Python 3.10

The diff was generated via

    pyupgrade --py310-plus --keep-percent-format **/*.py

I also removed the redundant `setuptools` dependency from pyproject.toml
as a drive-by change.
---
 examples/block_map.py                      |  3 +-
 examples/pallas/blocksparse_matmul.py      |  3 +-
 examples/pallas/lstm.py                    |  4 +--
 jax_triton/experimental/fusion/fusion.py   |  4 +--
 .../experimental/fusion/jaxpr_rewriter.py  | 25 ++++++-------
 jax_triton/experimental/fusion/lowering.py |  4 +--
 jax_triton/triton_lib.py                   | 35 ++++++++++---------
 pyproject.toml                             |  1 -
 8 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/examples/block_map.py b/examples/block_map.py
index 8877a1d1..8e25ae47 100644
--- a/examples/block_map.py
+++ b/examples/block_map.py
@@ -14,7 +14,6 @@
 
 
 import functools
-from typing import Optional
 
 import jax
 import jax.numpy as jnp
@@ -113,7 +112,7 @@ def mha(q, k, v, *,
         sm_scale: float = 1.0,
         block_q: int = 128,
         block_k: int = 128,
-        num_warps: Optional[int] = None,
+        num_warps: int | None = None,
         num_stages: int = 1,
         grid=None,
 ):
diff --git a/examples/pallas/blocksparse_matmul.py b/examples/pallas/blocksparse_matmul.py
index a7f08239..540c7c4c 100644
--- a/examples/pallas/blocksparse_matmul.py
+++ b/examples/pallas/blocksparse_matmul.py
@@ -15,7 +15,6 @@
 
 import functools
 import timeit
-from typing import Tuple
 
 import jax.numpy as jnp
 from jax import random
@@ -63,7 +62,7 @@ class BlockELL:
   blocks: jnp.ndarray  # float32[n_rows, n_blocks, *block_size]
   blocks_per_row: jnp.ndarray  # int32[n_rows, n_blocks]
   indices: jnp.ndarray  # int32[n_rows, max_num_blocks_per_row, 2]
-  shape: Tuple[int, int]  # (n_rows * block_size[0], n_cols * block_size[1])
+  shape: tuple[int, int]  # (n_rows * block_size[0], n_cols * block_size[1])
 
   ndim: int = property(lambda self: len(self.shape))
   num_blocks = property(lambda self: self.blocks.shape[0])
diff --git a/examples/pallas/lstm.py b/examples/pallas/lstm.py
index 7eae81a5..10495510 100644
--- a/examples/pallas/lstm.py
+++ b/examples/pallas/lstm.py
@@ -67,7 +67,7 @@ def body(k, acc_refs):
   accs = for_loop.for_loop(num_k_blocks, body, [acc_i, acc_f, acc_o, acc_g])
   bs = [pl.load(b_ref, (idx_n,))
         for b_ref in [b_hi_ref, b_hf_ref, b_hg_ref, b_ho_ref]]
-  acc_i, acc_f, acc_g, acc_o = [acc + b for acc, b in zip(accs, bs)]
+  acc_i, acc_f, acc_g, acc_o = (acc + b for acc, b in zip(accs, bs))
   i_gate, f_gate, o_gate = (
       jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
   cell = jnp.tanh(acc_g)
@@ -124,7 +124,7 @@ def lstm_cell_reference(weights, x, h, c):
   xs = [jnp.dot(x, w) for w in ws]
   hs = [jnp.dot(h, u) for u in us]
   accs = [x + h for x, h in zip(xs, hs)]
-  acc_i, acc_f, acc_g, acc_o = [acc + b[None] for acc, b in zip(accs, bs)]
+  acc_i, acc_f, acc_g, acc_o = (acc + b[None] for acc, b in zip(accs, bs))
   i_gate, f_gate, o_gate = (
       jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
   cell = jnp.tanh(acc_g)
diff --git a/jax_triton/experimental/fusion/fusion.py b/jax_triton/experimental/fusion/fusion.py
index c6e73e1b..b7baf621 100644
--- a/jax_triton/experimental/fusion/fusion.py
+++ b/jax_triton/experimental/fusion/fusion.py
@@ -17,7 +17,7 @@
 
 import functools
 import os
-from typing import Any, Tuple
+from typing import Any
 
 import jax
 from jax import lax
@@ -204,7 +204,7 @@ def make_elementwise(shape, dtype, *args):
 class MatmulElementwise(jax_rewrite.JaxExpression):
   x: jax_rewrite.JaxExpression
   y: jax_rewrite.JaxExpression
-  elem_ops: Tuple[core.Primitive]
+  elem_ops: tuple[core.Primitive]
 
   def match(self, expr, bindings, succeed):
     if not isinstance(expr, MatmulElementwise):
diff --git a/jax_triton/experimental/fusion/jaxpr_rewriter.py b/jax_triton/experimental/fusion/jaxpr_rewriter.py
index aa0277b4..9ed0ef0e 100644
--- a/jax_triton/experimental/fusion/jaxpr_rewriter.py
+++ b/jax_triton/experimental/fusion/jaxpr_rewriter.py
@@ -19,7 +19,8 @@
 
 import dataclasses
 import itertools as it
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any
+from collections.abc import Callable
 
 from jax._src import core as jax_core
 import jax.numpy as jnp
@@ -35,7 +36,7 @@
 class Node(matcher.Pattern, metaclass=abc.ABCMeta):
 
   @abc.abstractproperty
-  def parents(self) -> List[Node]:
+  def parents(self) -> list[Node]:
     ...
 
 
@@ -51,9 +52,9 @@ def map_parents(self, fn: Callable[[Node], Node]) -> Node:
 class Eqn(Node):
   primitive: jax_core.Primitive
   params: jr.Params
-  invars: List[Node]
-  shape: Union[Tuple[int, ...], List[Tuple[int, ...]]]
-  dtype: Union[jnp.dtype, List[jnp.dtype]]
+  invars: list[Node]
+  shape: tuple[int, ...] | list[tuple[int, ...]]
+  dtype: jnp.dtype | list[jnp.dtype]
 
   @property
   def parents(self):
@@ -77,7 +78,7 @@ def match(self, expr, bindings, succeed):
 
 @dataclasses.dataclass(frozen=True, eq=False)
 class JaxprVar(Node):
-  shape: Tuple[int, ...]
+  shape: tuple[int, ...]
   dtype: jnp.dtype
 
   def match(self, expr, bindings, succeed):
@@ -131,7 +132,7 @@ def from_literal(cls, var: jax_core.Literal) -> Literal:
 @dataclasses.dataclass(eq=False)
 class Part(Node):
   index: int
-  shape: Tuple[int, ...]
+  shape: tuple[int, ...]
   dtype: jnp.dtype
   parent: Node
 
@@ -153,9 +154,9 @@ def map_parents(self, fn):
 
 @dataclasses.dataclass(eq=True)
 class JaxprGraph(matcher.Pattern):
-  constvars: List[Node]
-  invars: List[Node]
-  outvars: List[Node]
+  constvars: list[Node]
+  invars: list[Node]
+  outvars: list[Node]
 
   def get_nodes(self):
     nodes = set(self.outvars)
@@ -167,7 +168,7 @@
         queue.append(p)
     return nodes
 
-  def get_children(self, node) -> List[Node]:
+  def get_children(self, node) -> list[Node]:
     nodes = self.get_nodes()
     return [n for n in nodes if node in n.parents]
 
@@ -274,7 +275,7 @@ def to_jaxpr(self) -> jax_core.Jaxpr:
     outvars = [env[n] for n in self.outvars]
     return jax_core.Jaxpr(constvars, invars, outvars, eqns, jax_core.no_effects)
 
-  def toposort(self) -> List[Node]:
+  def toposort(self) -> list[Node]:
     node_stack = list(self.outvars)
     child_counts = {}
     while node_stack:
diff --git a/jax_triton/experimental/fusion/lowering.py b/jax_triton/experimental/fusion/lowering.py
index d3631c5d..c05edcaa 100644
--- a/jax_triton/experimental/fusion/lowering.py
+++ b/jax_triton/experimental/fusion/lowering.py
@@ -15,7 +15,7 @@
 """Contains lowering passes for jaxprs to pallas."""
 
 import functools
-from typing import Any, Dict
+from typing import Any
 
 import jax
 from jax import api_util
@@ -317,7 +317,7 @@ def read(v: core.Atom) -> Any:
   def write(v: Var, val: Any) -> None:
     env[v] = val
 
-  env: Dict[Var, Any] = {}
+  env: dict[Var, Any] = {}
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)
   for eqn in jaxpr.eqns:
diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index dc542e73..dc6c9b9e 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -25,7 +25,8 @@
 import pprint
 import tempfile
 import types
-from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
+from typing import Any, Protocol, Union
+from collections.abc import Callable, Sequence
 import zlib
 from functools import partial
 
@@ -102,11 +103,11 @@
     jnp.dtype("bool"): "B",
 }
 
-Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
-GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]
+Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
+GridOrLambda = Union[Grid, Callable[[dict[str, Any]], Grid]]
 
 
-def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
+def normalize_grid(grid: GridOrLambda, metaparams) -> tuple[int, int, int]:
   if callable(grid):
     grid = grid(metaparams)
   if isinstance(grid, int):
@@ -186,8 +187,8 @@ class CompilationResult:
   name: str
   shared_mem_bytes: int
   cluster_dims: tuple
-  ttgir: Optional[str]
-  llir: Optional[str]
+  ttgir: str | None
+  llir: str | None
 
 def compile_ttir_inplace(
     ttir,
@@ -375,7 +376,7 @@ def get_or_create_triton_kernel(
     enable_fp_fusion,
     metaparams,
    dump: bool,
-) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
+) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
   if num_warps is None:
     num_warps = 4
   if num_stages is None:
@@ -730,7 +731,7 @@ def prune_configs(configs, named_args, **kwargs):
 class ShapeDtype(Protocol):
 
   @property
-  def shape(self) -> Tuple[int, ...]:
+  def shape(self) -> tuple[int, ...]:
     ...
 
   @property
@@ -739,21 +740,21 @@ def dtype(self) -> np.dtype:
     ...
 
 def triton_call(
-    *args: Union[jax.Array, bool, int, float, np.float32],
+    *args: jax.Array | bool | int | float | np.float32,
     kernel: triton.JITFunction,
-    out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
+    out_shape: ShapeDtype | Sequence[ShapeDtype],
     grid: GridOrLambda,
     name: str = "",
     custom_call_target_name: str = "triton_kernel_call",
-    num_warps: Optional[int] = None,
-    num_stages: Optional[int] = None,
+    num_warps: int | None = None,
+    num_stages: int | None = None,
     num_ctas: int = 1,  # TODO(giorgioa): Add support for dimensions tuple.
-    compute_capability: Optional[int] = None,
+    compute_capability: int | None = None,
     enable_fp_fusion: bool = True,
-    input_output_aliases: Optional[Dict[int, int]] = None,
-    zeroed_outputs: Union[
-        Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
-    ] = (),
+    input_output_aliases: dict[int, int] | None = None,
+    zeroed_outputs: (
+        Sequence[int] | Callable[[dict[str, Any]], Sequence[int]]
+    ) = (),
     debug: bool = False,
     serialized_metadata: bytes = b"",
     **metaparams: Any,
diff --git a/pyproject.toml b/pyproject.toml
index 59c280d9..a24b8651 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,6 @@ dependencies = [
   "absl-py>=1.4.0",
   "jax>=0.4.31",
   "triton>=3.0",
-  "setuptools",  # triton seems to need this when installing itself.
 ]
 
 [project.optional-dependencies]
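
Note for reviewers (not part of the commit): the snippet below condenses
the rewrite patterns pyupgrade applied in the hunks above, so they can be
sanity-checked in one self-contained place. It is a minimal sketch; the
function `f`, its parameters, and all variable names are illustrative
rather than taken from the repository. It runs unchanged on Python 3.10+.

    # PEP 585: builtin containers work as generic types, replacing
    # typing.List/Dict/Tuple; PEP 604: `X | Y` replaces Union[X, Y] and
    # `X | None` replaces Optional[X]. Callable and Sequence now come
    # from collections.abc instead of typing.
    from collections.abc import Callable, Sequence
    from typing import Any

    def f(
        grid: int | tuple[int, int],     # was Union[int, Tuple[int, int]]
        num_warps: int | None = None,    # was Optional[int]
        zeroed: Sequence[int] | Callable[[dict[str, Any]], Sequence[int]] = (),
    ) -> list[int]:                      # was List[int]
        del num_warps, zeroed  # unused; kept to mirror the annotations above
        return [grid] if isinstance(grid, int) else list(grid)

    # pyupgrade also rewrites a list comprehension that is unpacked on the
    # spot into a generator expression, skipping the temporary list:
    accs, bs = [1, 2], [10, 20]
    acc_i, acc_f = (acc + b for acc, b in zip(accs, bs))  # was [acc + b ...]
    assert f((3, 4)) == [3, 4] and (acc_i, acc_f) == (11, 22)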