metal: SSM kernel improvements #17876
Conversation
Mind adding some representative cases to […]?
```cpp
} else {
    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}
```
Is the old kernel faster for ne1 == 1? If not, we can remove it and always use the batched kernel?
Good question, I'll test that today.
It looks like non-batched is significantly faster for ne1 == 1, so I think we should keep both paths.
```cpp
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
constexpr int BATCH_SIZE = 256;
const bool use_batched = (ne1 > 1);

if (use_batched) {
    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

    // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
    // Each threadgroup has BATCH_SIZE threads, each handling one token
    const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
} else {
    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}
```
Interesting - I am quite surprised that this change makes such a big difference. I have to try this approach for all other kernels that launch threadgroups with just 1 thread: unary, binary, scale, clamp, fill, etc.
Honestly, I was very surprised too. All credit to Claude Code with Opus 4.5 for the insight.
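For illustration, a rough sketch of the kind of dispatch-side change being discussed, applied to a hypothetical elementwise op with `n` total elements. `n`, `NTH`, and `ntg` are illustrative names, not code from this PR, and the kernel body would also need a bounds check on its flat index:

```cpp
// Hypothetical elementwise op with n total elements (illustrative only).

// Before: one thread per threadgroup -> n threadgroup dispatches.
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);

// After: NTH threads per threadgroup -> ceil(n/NTH) dispatches.
// Inside the kernel, the flat element index becomes tgpig.x*NTH + tpitg.x,
// and threads whose index is >= n must return early.
const int NTH = 256;
const int ntg = (n + NTH - 1)/NTH;
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, NTH, 1, 1);
```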
ggml/src/ggml-metal/ggml-metal.metal (outdated)

```metal
// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
template<int BATCH_SIZE>
```
This parameter does not have to be a template. The better pattern is to make it a function constant. For example, see how this works:
llama.cpp/ggml/src/ggml-metal/ggml-metal.metal, lines 3108 to 3122 @ 86a3f0f:
```metal
constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];

template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;
```
Then observe how we pass the function constants during the construction of the pipeline:
llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp, lines 472 to 494 @ 86a3f0f:
```cpp
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
    snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}
```
This way, we can construct pipelines at runtime with different FC_ssm_conv_bs, instead of defining many template instantiations that would increase the compile time.
When you do that, add logic in ggml_metal_op_ssm_conv to determine the smallest power of 2 that is greater than or equal to ne1 and less than or equal to 256. Use that power of 2 to construct a pipeline with the respective batch size.
For example, if ne1 == 100, we want a pipeline with FC_ssm_conv_bs == 128. And so on.
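A minimal sketch of what that host-side selection could look like, following the `mul_mv_ext` pattern quoted above. `FC_SSM_CONV`, the batch-size constant slot, the kernel base name, and this getter's signature are all placeholders for illustration, not existing symbols in the tree:

```cpp
// Sketch only: FC_SSM_CONV and this getter's signature are placeholders.
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, int ne1) {
    // smallest power of 2 that is >= ne1, capped at 256 (e.g. ne1 == 100 -> bs == 128)
    int bs = 1;
    while (bs < ne1 && bs < 256) {
        bs *= 2;
    }

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_ssm_conv_f32_f32_batched");
    snprintf(name, 256, "%s_bs=%d", base, bs);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, bs, FC_SSM_CONV + 0); // hypothetical FC_ssm_conv_bs slot

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    return res;
}
```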
Test results for […]
You do want lower us/run and higher GB/s. The two values are the same data; GB/s is just computed by summing the tensor sizes and dividing by the runtime.
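For instance (illustrative numbers, not from these tests): if a case touches roughly 8 MB of tensor data per run and completes in 100 us, that is about 8e6 B / 1e-4 s ≈ 80 GB/s, and halving the runtime would double the reported bandwidth.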
@gabe-l-hart These perf numbers don't add up to the observed […]
Ok, I'm glad I'm not crazy! This does seem very fishy given the improved results with pp. I'll try to make these tests more representative.
This was done using Claude Code. It found a number of optimizations around how the threads were organized, resulting in a huge performance boost! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This used Claude Code and resulted in a modest performance improvement while maintaining correctness. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Oh boy, that makes much more sense!

```cpp
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
// d_inner == 3072
// d_conv  == 4
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}));
```

[Benchmark tables: Batched / Un-Batched]
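(The exact invocation isn't shown in this thread; numbers like these are typically gathered with `test-backend-ops` in perf mode, e.g. something along the lines of `./bin/test-backend-ops perf -o SSM_CONV`.)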
With a single-token (generate) example:

[Benchmark tables: Batch for prefill only / Batch for both / Non-Batch for both]

Given this, I think we should keep both versions, dispatched on […]
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Force-pushed from 427ae08 to da044cd.
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Some similar numbers for representative tests on […]:

[Benchmark tables: Without Optimizations / With Optimizations]
ggml/src/ggml-metal/ggml-metal.metal (outdated)

```metal
    x[0] = sumf;
}

// typedef decltype(kernel_ssm_conv_f32_f32_batched<1>) kernel_ssm_conv_batched_t;
```
Remove this
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Branch: SSMKernelImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Another small speedup with a […]
Hm, the failing test looks suspiciously related to this PR, but it's failing on […]
ggerganov left a comment
The failing test seems like a fluke - should be safe to ignore
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This brings in significant improvements to prefill performance for all models using the SSM_CONV and SSM_SCAN ops (granite4, jamba, falcon-h, nemotron-h, Qwen3 Next) on Apple Metal. See ggml-org/llama.cpp#17876 Branch: LlamaCPPMetalSSMImprovements Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Description
This branch adds performance improvements on `metal` for both `SSM_CONV` and `SSM_SCAN`. The kernels were heavily edited by Claude Code, but I've reviewed all changes.

Changes

- `SSM_CONV`: Implement a `batched` version that uses batches of `256` threads for multi-token prefill.
  - Split the `y` dim of the outer grid into `y / BATCH_SIZE` and use `BATCH_SIZE` as `x` for the threadgroup (inner grid)
  - Index each token via `tgpig.y` and `tpitg.x` (see the sketch after this list)
- `SSM_SCAN`: Reduce redundant `x_dt` / `dA` computations
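To make the indexing in the first bullet concrete, here is a tiny plain-C++ sketch of how a (threadgroup, thread) pair maps back to a token under the dispatch `(ne01, ceil(ne1/BATCH_SIZE), ne02)` with `BATCH_SIZE` threads per group. The variable names `tg_y` and `th_x` are illustrative stand-ins for the Metal built-ins `tgpig.y` and `tpitg.x`:

```cpp
#include <cstdio>

constexpr int BATCH_SIZE = 256;

// tg_y stands in for tgpig.y, th_x for tpitg.x.
int token_index(int tg_y, int th_x) {
    return tg_y*BATCH_SIZE + th_x;
}

int main() {
    const int ne1 = 515; // tokens in the prefill test case above

    // Threadgroup 2, thread 2 handles token 514 (the last valid one); thread 3
    // would compute 515, which is out of range, so the kernel must early-out
    // whenever the computed index is >= ne1.
    printf("token_index(2, 2) = %d (ne1 = %d)\n", token_index(2, 2), ne1);
    printf("token_index(2, 3) = %d -> skipped\n", token_index(2, 3));

    return 0;
}
```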
Performance

```sh
./bin/llama-batched-bench -m ~/models/ibm-granite/granite-4.0-h-1b/ggml-model-Q8_0.gguf -c 131072 -b 2048 -ub 512 -npp 1024,4096,8192 -ntg 128 -npl 1,4,8 -ngl 99
```

Baseline (c8554b6)
SSM_CONV improvements
SSM_CONV + SSM_SCAN improvements