From 86155561fb5c9d79116d38a91f2f21b59f5e0ad2 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Fri, 18 Oct 2024 05:33:05 -0700 Subject: [PATCH] nit: Use `frozen` dataclasses rather than `unsafe_hash`. PiperOrigin-RevId: 687267707 --- .../ops/tpu/splash_attention/splash_attention_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 536c32e574b2..7124d64f085a 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -468,7 +468,7 @@ def from_head_minor(vals: tuple[Any, ...], layout: QKVLayout): return (*vals[:-2], vals[-1], vals[-2]) -@dataclasses.dataclass(unsafe_hash=True) +@dataclasses.dataclass(frozen=True, slots=True) class BlockSizes: """Tile sizes parameterizing SplashAttention kernels. @@ -498,9 +498,9 @@ class BlockSizes: def __post_init__(self): if self.block_kv_compute is None: - self.block_kv_compute = self.block_kv + object.__setattr__(self, "block_kv_compute", self.block_kv) if self.block_kv_dkv_compute is None: - self.block_kv_dkv_compute = self.block_kv_dkv + object.__setattr__(self, "block_kv_dkv_compute", self.block_kv_dkv) if self.use_fused_bwd_kernel: if self.block_q_dq is not None or self.block_kv_dq is not None: raise ValueError(