diff --git a/megakernels/demos/latency/scheduler.py b/megakernels/demos/latency/scheduler.py index 1954abed..7afe196c 100644 --- a/megakernels/demos/latency/scheduler.py +++ b/megakernels/demos/latency/scheduler.py @@ -48,7 +48,10 @@ def make_buffer(shape, buffer_dtype=dtype): stacked_params = model.stacked_params - max_attn_partitions = get_sm_count(device) + def align_sm_to_partition_cacheline(x: int) -> int: + return int(math.ceil(x / 16) * 16) + + max_attn_partitions = align_sm_to_partition_cacheline(get_sm_count(device)) barriers = torch.zeros( [ @@ -60,6 +63,7 @@ def make_buffer(shape, buffer_dtype=dtype): device=device, ) + return Globals( # model params qkv_proj_weights=stacked_params.qkv_proj,