
Conversation

@gabe-l-hart
Collaborator

Description

This branch adds performance improvements on Metal for both SSM_CONV and SSM_SCAN. The kernels were heavily edited by Claude Code, but I've reviewed all changes.

Changes

  • SSM_CONV: Implement a batched version that uses threadgroups of 256 threads for multi-token prefill (indexing sketched below).
    • Split what was the y dim of the outer grid into y / BATCH_SIZE and use BATCH_SIZE as the x dim of the threadgroup (inner grid)
    • Recompute the offsets from tgpig.y and tpitg.x
  • SSM_SCAN: Reduce redundant x_dt / dA computations
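
As a rough illustration of the new index split, here is a small host-side sketch (my own, not the kernel code); BATCH_SIZE, tgpig and tpitg mirror the names used in the kernel and the dispatch below:

// Sketch: how a (threadgroup, thread) pair maps back to an absolute token index.
// The dispatch is (ne01, ceil(ne1 / BATCH_SIZE), ne02) threadgroups of BATCH_SIZE threads.
#include <cstdio>

enum { BATCH_SIZE = 256 };

int main() {
    const int ne1 = 515; // e.g. a 515-column prefill input

    const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; // ceil division

    for (int tgpig_y = 0; tgpig_y < n_token_batches; ++tgpig_y) {  // outer grid y
        for (int tpitg_x = 0; tpitg_x < BATCH_SIZE; ++tpitg_x) {   // thread x within the threadgroup
            const int i2 = tgpig_y * BATCH_SIZE + tpitg_x;         // absolute token index
            if (i2 >= ne1) {
                continue; // tail threadgroup: this thread has no token to process
            }
            // the kernel recomputes the src/dst offsets from i2 (plus tgpig.x / tgpig.z) here
        }
    }

    std::printf("%d threadgroups of %d threads cover %d tokens\n", n_token_batches, BATCH_SIZE, ne1);
    return 0;
}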

Performance

./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/ggml-model-Q8_0.gguf -c 131072 -b 2048 -ub 512 -npp 1024,4096,8192 -ntg 128 -npl 1,4,8 -ngl 99

Baseline (c8554b6)

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
1024 128 1 1152 0.719 1423.52 1.407 91.00 2.126 541.87
1024 128 4 4608 2.877 1423.62 2.947 173.72 5.824 791.15
1024 128 8 9216 5.748 1425.13 5.269 194.33 11.018 836.48
4096 128 1 4224 2.887 1418.90 1.415 90.49 4.301 982.03
4096 128 4 16896 11.537 1420.09 2.990 171.24 14.527 1163.05
4096 128 8 33792 23.169 1414.28 6.213 164.80 29.383 1150.06
8192 128 1 8320 6.458 1268.54 1.437 89.11 7.894 1053.92
8192 128 4 33280 23.539 1392.09 3.236 158.24 26.774 1242.98
8192 128 8 66560 47.488 1380.06 5.973 171.43 53.461 1245.02

SSM_CONV improvements

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
1024 128 1 1152 0.457 2240.26 1.407 91.00 1.864 618.12
1024 128 4 4608 1.821 2249.29 2.948 173.68 4.769 966.24
1024 128 8 9216 3.641 2250.21 5.265 194.47 8.906 1034.80
4096 128 1 4224 1.834 2233.20 1.410 90.76 3.244 1301.91
4096 128 4 16896 7.332 2234.56 2.997 170.86 10.329 1635.82
4096 128 8 33792 14.683 2231.72 5.347 191.51 20.030 1687.09
8192 128 1 8320 3.723 2200.12 1.425 89.81 5.149 1615.96
8192 128 4 33280 15.000 2184.60 3.074 166.54 18.074 1841.32
8192 128 8 66560 33.971 1929.15 6.071 168.67 40.043 1662.23

SSM_CONV + SSM_SCAN improvements

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
1024 128 1 1152 0.437 2345.00 1.411 90.75 1.847 623.65
1024 128 4 4608 1.732 2364.53 2.977 171.98 4.709 978.47
1024 128 8 9216 3.487 2349.39 5.374 190.53 8.861 1040.03
4096 128 1 4224 1.753 2336.90 1.425 89.80 3.178 1329.05
4096 128 4 16896 7.007 2338.09 3.020 169.55 10.027 1685.03
4096 128 8 33792 14.042 2333.49 5.412 189.22 19.454 1737.01
8192 128 1 8320 3.572 2293.39 1.434 89.25 5.006 1661.94
8192 128 4 33280 14.758 2220.40 3.208 159.62 17.965 1852.45
8192 128 8 66560 33.036 1983.79 6.053 169.17 39.089 1702.79

@github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) and Apple Metal (https://en.wikipedia.org/wiki/Metal_(API)) labels on Dec 9, 2025
@jeffbolznv
Collaborator

Mind adding some representative cases to test-backend-ops perf?

Comment on lines +1385 to +1404
} else {
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}
Member

Is the old kernel faster for ne1 == 1? If not, we can remove it and always use the batched kernel?

Collaborator Author

Good question, I'll test that today.

Collaborator Author

It looks like non-batched is significantly faster for ne1 == 1, so I think we should keep both paths.

Comment on lines 1368 to 1404
+// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
+constexpr int BATCH_SIZE = 256;
+const bool use_batched = (ne1 > 1);
+
-ggml_metal_encoder_set_pipeline(enc, pipeline);
-ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
-ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
-
-ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+if (use_batched) {
+    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+    // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
+    // Each threadgroup has BATCH_SIZE threads, each handling one token
+    const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
+} else {
+    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
+}
Member

Interesting - I am quite surprised that this change makes such a big difference. I have to try this approach for all other kernels that launch threadgroups with just 1 thread: unary, binary, scale, clamp, fill, etc.

Collaborator Author

Honestly, I was very surprised too. All credit to Claude Code with Opus 4.5 for the insight.


// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
template<int BATCH_SIZE>
Member

This parameter does not have to be a template. The better pattern is to make it a function constant. For example, see how this works:

constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];

template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

Then observe how we pass the function constants during the construction of the pipeline:

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}

This way, we can construct pipelines at runtime with different FC_ssm_conv_bs, instead of defining many template instantiations that would increase the compile time.

When you do that, add logic in ggml_metal_op_ssm_conv to determine the smallest power of 2 that is larger than or equal to ne1 and less than or equal to 256. Use that power to construct a pipeline with the respective batch size.

For example, if ne1 == 100, we want a pipeline with FC_ssm_conv_bs == 128. And so on.
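
A minimal standalone sketch of that batch-size selection (not the actual ggml code):

// Sketch: smallest power of two that is >= ne1, clamped to at most 256.
// The result would be passed as the FC_ssm_conv_bs function constant when
// constructing the batched pipeline.
#include <cstdio>
#include <initializer_list>

static int ssm_conv_batch_size(int ne1) {
    int bs = 1;
    while (bs < ne1 && bs < 256) {
        bs <<= 1;
    }
    return bs;
}

int main() {
    // ne1 == 100 -> 128, ne1 == 1 -> 1, ne1 == 512 -> 256 (clamped)
    for (int ne1 : {1, 3, 100, 256, 512}) {
        std::printf("ne1 = %3d -> bs = %d\n", ne1, ssm_conv_batch_size(ne1));
    }
    return 0;
}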

@gabe-l-hart
Collaborator Author

Test results for test-backend-ops perf -o SSM_CONV -b Metal

@jeffbolznv I've honestly never quite known how to interpret the results of test-backend-ops perf. Intuitively, I would think a higher us/run would mean slower and a higher GB/s would mean more efficient throughput, but these results seem to show I have that exactly backwards.

@ggerganov Assuming my intuition is backwards, it looks like the batched implementation is indeed always faster. Would it be worth also adding a float4 variant of the batched implementation?

Test results for test-backend-ops perf:

Baseline (non-batch)

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32', name = 'kernel_ssm_conv_f32_f32'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32                       0x12010a0a0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):             139264 runs -     7.19 us/run -       36 kB/run -    4.77 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              65536 runs -    16.12 us/run -       68 kB/run -    4.02 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    20.51 us/run -      108 kB/run -    5.02 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):             114688 runs -     9.31 us/run -       54 kB/run -    5.53 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              49152 runs -    22.91 us/run -      102 kB/run -    4.25 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              40960 runs -    29.60 us/run -      162 kB/run -    5.22 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              90112 runs -    11.61 us/run -       72 kB/run -    5.92 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              40960 runs -    29.68 us/run -      136 kB/run -    4.37 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    38.45 us/run -      216 kB/run -    5.36 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x11df051c0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):             204800 runs -     5.04 us/run -       36 kB/run -    6.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              73728 runs -    14.26 us/run -       68 kB/run -    4.55 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              90112 runs -    11.88 us/run -       96 kB/run -    7.71 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):             163840 runs -     6.13 us/run -       54 kB/run -    8.40 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              57344 runs -    19.96 us/run -      102 kB/run -    4.87 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              65536 runs -    16.36 us/run -      144 kB/run -    8.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):             139264 runs -     7.38 us/run -       72 kB/run -    9.30 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              40960 runs -    25.68 us/run -      136 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              49152 runs -    20.84 us/run -      192 kB/run -    8.78 GB/s

Batched for n_t > 1

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x107505090 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.91 us/run -       36 kB/run -    1.72 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    20.04 us/run -       68 kB/run -    3.24 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    21.20 us/run -      108 kB/run -    4.86 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.42 us/run -       54 kB/run -    1.81 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.68 us/run -      102 kB/run -    3.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              32768 runs -    30.60 us/run -      162 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.49 us/run -       72 kB/run -    1.83 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.75 us/run -      136 kB/run -    3.44 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    40.05 us/run -      216 kB/run -    5.14 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x301505560 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):             204800 runs -     5.03 us/run -       36 kB/run -    6.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.56 us/run -       68 kB/run -    3.31 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              90112 runs -    11.93 us/run -       96 kB/run -    7.68 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):             163840 runs -     6.17 us/run -       54 kB/run -    8.35 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.45 us/run -      102 kB/run -    3.42 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              65536 runs -    16.32 us/run -      144 kB/run -    8.41 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):             139264 runs -     7.36 us/run -       72 kB/run -    9.33 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    37.16 us/run -      136 kB/run -    3.49 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              49152 runs -    20.85 us/run -      192 kB/run -    8.78 GB/s

Batched always

  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.91 us/run -       36 kB/run -    1.72 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]):              57344 runs -    19.81 us/run -       68 kB/run -    3.27 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]):              49152 runs -    21.16 us/run -      108 kB/run -    4.87 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.41 us/run -       54 kB/run -    1.81 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]):              40960 runs -    28.69 us/run -      102 kB/run -    3.39 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]):              32768 runs -    30.57 us/run -      162 kB/run -    5.05 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.20 us/run -       72 kB/run -    1.85 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]):              32768 runs -    37.75 us/run -      136 kB/run -    3.44 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]):              32768 runs -    39.89 us/run -      216 kB/run -    5.16 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.49 us/run -       36 kB/run -    1.76 GB/s
  SSM_CONV(type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]):              57344 runs -    19.60 us/run -       68 kB/run -    3.31 GB/s
  SSM_CONV(type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]):              49152 runs -    20.92 us/run -       96 kB/run -    4.38 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.35 us/run -       54 kB/run -    1.82 GB/s
  SSM_CONV(type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]):              40960 runs -    28.43 us/run -      102 kB/run -    3.42 GB/s
  SSM_CONV(type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]):              40960 runs -    29.87 us/run -      144 kB/run -    4.60 GB/s
  # NOTE: n_t == 1
  SSM_CONV(type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    36.92 us/run -       72 kB/run -    1.86 GB/s
  SSM_CONV(type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]):              32768 runs -    37.20 us/run -      136 kB/run -    3.49 GB/s
  SSM_CONV(type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]):              32768 runs -    38.95 us/run -      192 kB/run -    4.70 GB/s

@jeffbolznv
Collaborator

You do want lower us/run and higher GB/s. The two values reflect the same data; GB/s is just computed by summing the tensor sizes and dividing by the runtime.
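
For example, taking the first baseline SSM_CONV row above: 36 kB per run over 7.19 us is 36 * 1024 B / 7.19e-6 s ≈ 5.1e9 B/s, i.e. roughly 4.8 GiB/s, which lines up with the reported 4.77 GB/s (the GB/s column appears to be 1024-based).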

@ggerganov
Member

@gabe-l-hart These perf numbers don't line up with the observed llama-batched-bench results. Ideally, note down the actual shapes that are used during prompt processing and text generation and add tests that correspond to those shapes.

@gabe-l-hart
Collaborator Author

gabe-l-hart commented Dec 9, 2025

Ok, I'm glad I'm not crazy! This does seem very fishy given the improved prompt-processing (pp) results. I'll try to make these tests more representative.

This was done using Claude Code. It found a number of optimizations around
how the threads were organized, resulting in a huge performance boost!

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This used Claude Code and resulted in a modest performance improvement
while maintaining correctness.

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Oh boy, that makes much more sense!

    // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
    // d_inner == 3072
    // d_conv == 4
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}));
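
(If I'm reading the shapes right: 3328 = d_inner + 2 * n_group * d_state = 3072 + 2 * 128, since the conv input covers x together with B and C, and 515 = 512 prompt tokens + d_conv - 1 columns of carried-over conv state.)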

Batched

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x136b083d0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.64 us/run -    13403 kB/run -  180.96 GB/s

Un-Batched

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12320be70 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                     2504 runs -  3801.15 us/run -    13403 kB/run -    3.36 GB/s

@gabe-l-hart
Collaborator Author

With a single-token (generate) example:

Batch for prefill only

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x1292076d0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.69 us/run -    13403 kB/run -  180.82 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12910b550 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.22 us/run -      117 kB/run -   10.91 GB/s
  Backend Metal: OK

Batch for both

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_b256', name = 'kernel_ssm_conv_f32_f32_b256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_b256                  0x157b070e0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    70.58 us/run -    13403 kB/run -  181.11 GB/s
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              24576 runs -    59.17 us/run -      117 kB/run -    1.89 GB/s
  Backend Metal: OK

Non-Batch for both

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x135e075e0 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                     2504 runs -  3801.51 us/run -    13403 kB/run -    3.36 GB/s
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.26 us/run -      117 kB/run -   10.88 GB/s
  Backend Metal: OK

Given this, I think we should keep both versions dispatched on n_t like it is currently.

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart force-pushed the SSMKernelImprovements branch from 427ae08 to da044cd on December 9, 2025 18:04
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Some similar numbers for representative tests on SSM_SCAN:

Without Optimizations

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_scan_f32', name = 'kernel_ssm_scan_f32_nsg=4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_scan_f32_nsg=4                     0x149906a50 | th_max = 1024 | th_width =   32
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=512,n_seqs=1):                     2102 runs -  2178.75 us/run -    15968 kB/run -    6.99 GB/s
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=1,n_seqs=1):              40960 runs -    24.85 us/run -     3097 kB/run -  118.88 GB/s
  Backend Metal: OK

With Optimizations

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_scan_f32', name = 'kernel_ssm_scan_f32_nsg=4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_scan_f32_nsg=4                     0x146408690 | th_max = 1024 | th_width =   32
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=512,n_seqs=1):                     2102 runs -  1909.71 us/run -    15968 kB/run -    7.97 GB/s
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=48,n_group=1,n_seq_tokens=1,n_seqs=1):              40960 runs -    26.50 us/run -     3097 kB/run -  111.48 GB/s
  Backend Metal: OK

x[0] = sumf;
}

// typedef decltype(kernel_ssm_conv_f32_f32_batched<1>) kernel_ssm_conv_batched_t;
Collaborator Author

Remove this

Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@gabe-l-hart
Collaborator Author

Another small speedup with a float4 version of the SSM_CONV batched impl:

ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_batched_4', name = 'kernel_ssm_conv_f32_f32_batched_4_ssm_conv_bs=256'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_batched_4_ssm_conv_bs=256      0x126304280 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[515,3328,1,1],ne_b=[4,3328,1,1]):                    15024 runs -    68.79 us/run -    13403 kB/run -  185.80 GB/s
ggml_metal_library_compile_pipeline: compiling pipeline: base = 'kernel_ssm_conv_f32_f32_4', name = 'kernel_ssm_conv_f32_f32_4'
ggml_metal_library_compile_pipeline: loaded kernel_ssm_conv_f32_f32_4                     0x12610d700 | th_max = 1024 | th_width =   32
  SSM_CONV(type=f32,ne_a=[4,3328,1,1],ne_b=[4,3328,1,1]):              98304 runs -    10.33 us/run -      117 kB/run -   10.80 GB/s
  Backend Metal: OK

@gabe-l-hart
Collaborator Author

Hm, the failing test looks suspiciously related to this PR, but it's failing on CONV_2D which seems to pass just fine on my machine.

2025-12-09T18:44:28.7929820Z Failing tests:
2025-12-09T18:44:28.7930170Z   CONV_2D(ne_input=[1,1,25,2],ne_kernel=[2,1,25,1],type_kernel=f16,stride0=3,stride1=5,padding0=5,padding1=5,dilation0=2,dilation1=4,cwhn=0)
2025-12-09T18:44:28.7930420Z   Backend Metal: FAIL

Member

@ggerganov left a comment

The failing test seems like a fluke - should be safe to ignore

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@ggerganov merged commit 086a63e into ggml-org:master on Dec 9, 2025
62 of 69 checks passed
@gabe-l-hart deleted the SSMKernelImprovements branch on December 9, 2025 19:30
@github-actions bot added the testing (Everything test related) label on Dec 9, 2025
gabe-l-hart added a commit to gabe-l-hart/ollama that referenced this pull request Dec 10, 2025
This brings in significant improvements to prefill performance for all
models using the SSM_CONV and SSM_SCAN ops (granite4, jamba, falcon-h,
nemotron-h, Qwen3 Next) on Apple Metal.

See ggml-org/llama.cpp#17876

Branch: LlamaCPPMetalSSMImprovements

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>