Skip to content

Commit

Permalink
nit: Use frozen dataclasses rather than unsafe_hash.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687267707
  • Loading branch information
tomhennigan authored and Google-ML-Automation committed Oct 18, 2024
1 parent 4094564 commit 8615556
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def from_head_minor(vals: tuple[Any, ...], layout: QKVLayout):
return (*vals[:-2], vals[-1], vals[-2])


@dataclasses.dataclass(unsafe_hash=True)
@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
"""Tile sizes parameterizing SplashAttention kernels.
Expand Down Expand Up @@ -498,9 +498,9 @@ class BlockSizes:

def __post_init__(self):
if self.block_kv_compute is None:
self.block_kv_compute = self.block_kv
object.__setattr__(self, "block_kv_compute", self.block_kv)
if self.block_kv_dkv_compute is None:
self.block_kv_dkv_compute = self.block_kv_dkv
object.__setattr__(self, "block_kv_dkv_compute", self.block_kv_dkv)
if self.use_fused_bwd_kernel:
if self.block_q_dq is not None or self.block_kv_dq is not None:
raise ValueError(
Expand Down

0 comments on commit 8615556

Please sign in to comment.