diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba3c3427517..680904d132d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + const char * suffix = ""; + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); + snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); @@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res.smem = 32*sizeof(float)*nsg; + // Shared memory layout: + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 77f2e98cfe8..0a8b9211a76 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 30109f83e10..8944b07e907 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -77,6 +77,7 @@ #define FC_MUL_MV 600 #define FC_MUL_MM 700 #define FC_ROPE 800 +#define FC_SSM_CONV 900 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 392addb8d1f..e99c1763f63 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + const bool use_batched = (ne1 > 1); + + if (use_batched) { + // Determine the smallest power of 2 that's >= ne1, but <= 256 + int BATCH_SIZE; + if (ne1 > 128) BATCH_SIZE = 256; + else if (ne1 > 64 ) BATCH_SIZE = 128; + else if (ne1 > 32 ) BATCH_SIZE = 64; + else if (ne1 > 16 ) BATCH_SIZE = 32; + else if (ne1 > 8 ) BATCH_SIZE = 16; + else if (ne1 > 4 ) BATCH_SIZE = 8; + else BATCH_SIZE = 2; + + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b78d5a2bad..51bcbae309f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4( x[0] = sumf; } +constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]]; + +// Batched version: each threadgroup processes multiple tokens for better efficiency +// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens +kernel void kernel_ssm_conv_f32_f32_batched( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_batched_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// Optimized version: reduces redundant memory loads by having one thread load shared values kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32( uint3 tgpg[[threadgroups_per_grid]]) { constexpr short NW = N_SIMDWIDTH; - shared[tpitg.x] = 0.0f; + // Shared memory layout: + // [0..sgptg*NW-1]: partial sums for reduction (existing) + // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch + // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; const int32_t i0 = tpitg.x; const int32_t i1 = tgpig.x; @@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32( for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { - const float dt0 = dt[0]; + // Pre-compute x_dt and dA for this batch of tokens + // Only first sgptg threads do the loads and expensive math + if (i0 < sgptg && i2 + i0 < n_t) { + // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) + device const float * x_t = x + i0 * args.ns12; + device const float * dt_t = dt + i0 * args.ns21; + + const float dt0 = dt_t[0]; const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); s = (s0 * dA) + (B[i0] * x_dt); const float sumf = simd_sum(s * C[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } // recurse s0 = s; - x += args.ns12; - dt += args.ns21; B += args.ns42; C += args.ns52; } + // Advance pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e94a53da25..376b6ce4ee3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8164,6 +8164,13 @@ static std::vector> make_test_cases_perf() { } } + // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate + + return test_cases; }