From 4c768e638c2b81501421c1f7970e20dac17657b4 Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Sun, 24 Aug 2025 08:21:56 +0000 Subject: [PATCH 1/3] best --- src/__pycache__/args_config.cpython-312.pyc | Bin 556 -> 546 bytes src/agents/__pycache__/Base.cpython-312.pyc | Bin 4986 -> 4976 bytes .../__pycache__/Reflexion.cpython-312.pyc | Bin 2883 -> 2873 bytes .../reflexion_oneshot.cpython-312.pyc | Bin 9006 -> 8996 bytes .../TB_eval/__pycache__/utils.cpython-312.pyc | Bin 13229 -> 13219 bytes ...TritonBench_G_comp_alpac_v1_hackathon.json | 73 +++++- .../__pycache__/ProblemState.cpython-312.pyc | Bin 1187 -> 1177 bytes .../__pycache__/TritonBench.cpython-312.pyc | Bin 16573 -> 16563 bytes .../__pycache__/Memory.cpython-312.pyc | Bin 1368 -> 1358 bytes src/models/__pycache__/Base.cpython-312.pyc | Bin 631 -> 621 bytes src/models/__pycache__/KimiK2.cpython-312.pyc | Bin 2183 -> 2173 bytes .../prompt_for_generation.cpython-312.pyc | Bin 9969 -> 10094 bytes .../prompt_for_reflection.cpython-312.pyc | Bin 14265 -> 14679 bytes .../__pycache__/retriever.cpython-312.pyc | Bin 3354 -> 3344 bytes src/temp/embedding_triton_kernel.py | 131 +++++++++++ src/temp/flash_decode2_phi.py | 147 ++++++++++++ src/temp/int4_matmul.py | 194 ++++++++++++++++ src/temp/l2_norm_bwd.py | 94 ++++++++ src/temp/l2_norm_triton1.py | 83 +++++++ src/temp/matrix_transpose.py | 71 ++++++ src/temp/matrix_vector_multip.py | 83 +++++++ src/temp/rotary_transform.py | 212 ++++++++++++++++++ src/temp/sin_kernel.py | 69 ++++++ src/temp/tmp.py | 119 ++++++++++ src/temp/triton_matmul.py | 118 ++++++++++ src/utils/__pycache__/utils.cpython-312.pyc | Bin 2442 -> 2432 bytes 26 files changed, 1393 insertions(+), 1 deletion(-) create mode 100644 src/temp/embedding_triton_kernel.py create mode 100644 src/temp/flash_decode2_phi.py create mode 100644 src/temp/int4_matmul.py create mode 100644 src/temp/l2_norm_bwd.py create mode 100644 src/temp/l2_norm_triton1.py create mode 100644 src/temp/matrix_transpose.py create mode 100644 src/temp/matrix_vector_multip.py create mode 100644 src/temp/rotary_transform.py create mode 100644 src/temp/sin_kernel.py create mode 100644 src/temp/tmp.py create mode 100644 src/temp/triton_matmul.py diff --git a/src/__pycache__/args_config.cpython-312.pyc b/src/__pycache__/args_config.cpython-312.pyc index ed62ea94b9178a5147963fc8bfe6e3c473ed5043..d9f21e52b3a3ff2d938cf36eab4c625fe3f9d5b3 100644 GIT binary patch delta 26 gcmZ3(vWSKIG%qg~0}woDT)C0khLKTuauDM!09=3vH~;_u delta 36 qcmZ3)vWA8GG%qg~0}vF%&DqFp!^o|xU!Gr-U0jfuoI2Tq@fHB9a|)mU diff --git a/src/agents/__pycache__/Base.cpython-312.pyc b/src/agents/__pycache__/Base.cpython-312.pyc index f0a272db34755c9ad5ac90fc0aa6f7c474ac611d..b737dc97e5f9d466eb4d2512e3e05947d0342b4d 100644 GIT binary patch delta 27 hcmeyR_CbyNG%qg~0}woDT)C0kgNadja~zY5AOLz02c-Z2 delta 37 rcmeyM_DhZXG%qg~0}vF%&DqH9!NjerU!Gr-U0jfuoVq!LNk$L=)vXH+ diff --git a/src/agents/__pycache__/Reflexion.cpython-312.pyc b/src/agents/__pycache__/Reflexion.cpython-312.pyc index 54cab4bd97f26a4a5622127942cb5f6a38e1bf70..f3cd5b1661e699515c8acbe0e354a883e92ea2d8 100644 GIT binary patch delta 27 hcmX>swo{DzG%qg~0}woDT)C0^4I`t*W+o;_P5^Ll2Y>(o delta 37 rcmdlfc36!2G%qg~0}vF%&DqHPhLPJyzdXMvySN}RId$_7Mn_Hn(GCme diff --git a/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc index db3379e04a5f6150b194bba0a6cbd55a6310e256..2ce4e493cd2da34bd46f6fae9f8d8f0c3bfee1f0 100644 GIT binary patch delta 101 zcmZ4Iw#1F^G%qg~0}woDT$$0Vx{9Rxw_IFM6B5iG?sS7Ekt;kZ1B@ z+MFZ7&C0lTbA_w}Gvme0I~9Z(*`gU)jX&3IzNF~MDD20~7|b|>^D6_GQKSqs7Xb5r B8|?r9 delta 111 zcmZ4Dw$6?3G%qg~0}vF%&B^$xxRI}imD@_cJijQrxF9h(b@OaiA3;`TUV$$ho8O3q zFf#g1c9f83N@3ibB*D$fIBRpBtOGOSzRl|ugc;d97+H-!M{GW!=*cKt&BPeYID_*m L1DH{y4733NO|&06 diff --git a/src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc b/src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc index ef7453502bf33942c59ea2b7da0a5190dcfc03a2..4bbd6279c7465fe3169da5fc9dc6fdb425a483e2 100644 GIT binary patch delta 26 gcmZ3RzBrxxG%qg~0}woDT)B{&iP3Pg7}In^0B&IimH+?% delta 36 qcmZ3SzBZlvG%qg~0}vF%%~{CJ#BHr#o?nz*T#%TYx|xS*x*-6@$_oAf diff --git a/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json b/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json index 8171858..370aae5 100644 --- a/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json +++ b/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json @@ -1 +1,72 @@ -[{"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", "input": "", "output": "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": n}, num_stages=s, num_warps=w)\n for m in [32, 64, 128]\n for n in [1, 2, 4, 8]\n for s in [3, 4]\n for w in [4, 8]\n ],\n key=[\"M\", \"N\"],\n)\n@triton.jit\ndef mv_kernel(\n A,\n B,\n C,\n N,\n M,\n stride_an,\n stride_am,\n stride_bm,\n stride_cn,\n BLOCK_N: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]\n offset_m = tl.arange(0, BLOCK_M)[None, :]\n n_mask = offset_n < N\n A_ptrs = A + offset_n * stride_an + offset_m * stride_am\n B_ptrs = B + offset_m * stride_bm\n acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n for m in range(0, M, BLOCK_M):\n m_mask = m + offset_m < M\n a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)\n b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)\n acc += a * b\n A_ptrs += BLOCK_M * stride_am\n B_ptrs += BLOCK_M * stride_bm\n\n acc = tl.sum(acc, axis=1)\n C_ptrs = C + offset_n * stride_cn\n tl.store(C_ptrs, acc[:, None], mask=n_mask)\n\n\ndef mv(inp, vec):\n logging.debug(\"GEMS MV\")\n assert inp.shape[1] == vec.shape[0], \"incompatible dimensions\"\n N, M = inp.shape\n out = torch.empty((N,), device=inp.device, dtype=inp.dtype)\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_N\"]),)\n with torch.cuda.device(inp.device):\n mv_kernel[grid](\n inp,\n vec,\n out,\n N,\n M,\n inp.stride(0),\n inp.stride(1),\n vec.stride(0),\n out.stride(0),\n )\n return out\n\n\n\n\n", "file": "matrix_vector_multip.py", "difficulty": "4"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef _matmul_launch_metadata(grid, kernel, args):\n ret = {}\n M, N, K = args[\"M\"], args[\"N\"], args[\"K\"]\n ret[\"name\"] = f\"{kernel.name} [M={M}, N={N}, K={K}]\"\n if \"c_ptr\" in args:\n bytes_per_elem = args[\"c_ptr\"].element_size()\n else:\n bytes_per_elem = 1 if args[\"FP8_OUTPUT\"] else 2\n ret[f\"flops{bytes_per_elem * 8}\"] = 2. * M * N * K\n ret[\"bytes\"] = bytes_per_elem * (M * K + N * K + M * N)\n return ret\n\n\n@triton.jit(launch_metadata=_matmul_launch_metadata)\ndef matmul_kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, #\n BLOCK_SIZE_N: tl.constexpr, #\n BLOCK_SIZE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, #\n ):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n start_m = pid_m * BLOCK_SIZE_M\n start_n = pid_n * BLOCK_SIZE_N\n\n offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n offs_am = tl.where(offs_am < M, offs_am, 0)\n offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n\n offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if (c_ptr.dtype.element_ty == tl.float8e4nv):\n c = accumulator.to(tl.float8e4nv)\n else:\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(a, b):\n configs = {\n torch.float8_e4m3fn: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 128, \"GROUP_SIZE_M\": 8, \"num_stages\": 4,\n \"num_warps\": 8\n }, torch.float16: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64, \"GROUP_SIZE_M\": 8, \"num_stages\": 3,\n \"num_warps\": 8\n }\n }\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.dtype == b.dtype, \"Incompatible dtypes\"\n M, K = a.shape\n K, N = b.shape\n dtype = a.dtype\n\n c = torch.empty((M, N), device=a.device, dtype=dtype)\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]), )\n matmul_kernel[grid](\n a, b, c, #\n M, N, K, #\n a.stride(0), a.stride(1), #\n b.stride(0), b.stride(1), #\n c.stride(0), c.stride(1), #\n BLOCK_SIZE_M=configs[dtype][\"BLOCK_SIZE_M\"], #\n BLOCK_SIZE_N=configs[dtype][\"BLOCK_SIZE_N\"], #\n BLOCK_SIZE_K=configs[dtype][\"BLOCK_SIZE_K\"], #\n GROUP_SIZE_M=configs[dtype][\"GROUP_SIZE_M\"], #\n num_stages=configs[dtype][\"num_stages\"], #\n num_warps=configs[dtype][\"num_warps\"], #\n )\n return c\n\n\n\n\n", "file": "triton_matmul.py", "difficulty": "3"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef embedding_kernel(\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n stride_weight_seq,\n stride_out_seq,\n n_ctx,\n hiden_size: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n):\n start_n = tl.program_id(0) * BLOCK_N\n\n offs_nn = start_n + tl.arange(0, BLOCK_NN)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n for start_nn in range(0, BLOCK_N, BLOCK_NN):\n start_nn = tl.multiple_of(start_nn, BLOCK_NN)\n offs_seq = start_nn + offs_nn\n n_ctx_mask = offs_seq < n_ctx\n token_ids = tl.load(input_ids + offs_seq, mask=n_ctx_mask, other=vob_end_id)\n id_mask = (token_ids >= vob_start_id) & (token_ids < vob_end_id)\n token_ids = token_ids - vob_start_id\n dim_mask = offs_d < hiden_size\n load_mask = id_mask[:, None] & dim_mask[None, :]\n store_mask = n_ctx_mask[:, None] & dim_mask[None, :]\n vecs = tl.load(weight + token_ids[:, None] * stride_weight_seq + offs_d[None, :], mask=load_mask, other=0.0)\n tl.store(out + offs_seq[:, None] * stride_out_seq + offs_d[None, :], vecs, mask=store_mask)\n\n@torch.no_grad()\ndef embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: torch.Tensor):\n BLOCK_N = 64\n BLOCK_NN = 1\n BLOCK_DMODEL = triton.next_power_of_2(weight.shape[1])\n n_ctx = input_ids.shape[0]\n\n grid = (triton.cdiv(n_ctx, BLOCK_N), 1, 1)\n\n embedding_kernel[grid](\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n weight.stride(0),\n out.stride(0),\n n_ctx=n_ctx,\n hiden_size=weight.shape[1],\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n num_warps=1,\n num_stages=1,\n )\n\n\n\n\n", "file": "embedding_triton_kernel.py", "difficulty": "3"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", "input": "", "output": "import time\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n\tconfigs=[\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t \n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\t\n ],\n\tkey=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n bs_ptr, bzp_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n group_size,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr\n ):\n \"\"\"\n assert K % (BLOCK_SIZE_K * SPLIT_K) == 0\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n # [BLOCK_M, BLOCK_K]\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n # [BLOCK_K, BLOCK_N] but repeated 8 times in N\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn\n # tl.static_print(\"shape\", a_ptrs, b_ptrs, bs_ptrs, bzp_ptrs)\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n # Load the next block of A and B.\n # [BLOCK_K, BLOCK_N] but repeated group_size times in K \n bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \\\n + offs_bn[None, :] * stride_bsn\n # [BLOCK_K, BLOCK_N] but repeated in K and N\n bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \\\n + (offs_bn[None, :] // 8) * stride_bzpn\n b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0\n bzp_shift_bits = (offs_bn[None, :] % 8) * 4\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n bs = tl.load(bs_ptrs)\n bzp = tl.load(bzp_ptrs)\n # We accumulate along the K dimension.\n int_b = (b >> b_shift_bits) & 0xF\n int_bzp = (bzp >> bzp_shift_bits) & 0xF\n b = ((int_b - int_bzp) * bs).to(a.dtype)\n accumulator += tl.dot(a, b.to(a.dtype))\n # Advance the ptrs to the next K block.\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0\n # You can fuse arbitrary activation functions here\n # while the accumulator is still in FP32!\n c = accumulator.to(c_ptr.dtype.element_ty)\n # -----------------------------------------------------------\n # Write back the block of the output matrix C with masks.\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor:\n \"\"\"\n \"\"\"\n assert x.is_contiguous(), \"A must be contiguous\"\n assert qweight.is_contiguous(), \"B must be contiguous\" \n M, K = x.shape\n N = scales.shape[1]\n if output is None:\n output = torch.zeros((M, N), device=x.device, dtype=x.dtype) \n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n )\n return output\n\ndef quantize_int4(weight, group_size=128, tp_rank=0):\n # Weight shape: [H1 // 8, H2]\n # Scale shape: [H1 // group_size, H2]\n # zero_pint shape: [H1 // group_size, H2 // 8]\n\n weight = weight.transpose(1, 0)\n h1, h2 = weight.shape\n assert h1 % 8 == 0 and h2 % 8 == 0, \"H1 {} H2 {}\".format(h1, h2)\n assert h2 % group_size == 0, \"H1 {} H2 {}\".format(h1, h2)\n weight = weight.contiguous().view(-1, group_size).cuda(tp_rank)\n weight_max = weight.amax(-1, keepdim=True)\n weight_max = torch.where(weight_max < 0, 0, weight_max)\n weight_min = weight.amin(-1, keepdim=True)\n weight_min = torch.where(weight_min > 0, 0, weight_min)\n weight_range = weight_max - weight_min \n scale = weight_range / (2 ** 4 - 1)\n zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32)\n weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2)\n int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device)\n int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device)\n zero_point = zero_point.view(h1, -1)\n scale = scale.view(h1, -1)\n # pack 8 int4 in an int32 number.\n # Weight pack in row.\n for pack in range(0, h2, 8):\n for i in range(8):\n int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4)\n # zero point pack in col.\n for pack in range(0, h1, 8):\n for i in range(8):\n int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4)\n '''\n fp_weight = torch.zeros(h1, h2).half().to(weight.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_weight[pack * 8 + i, :] = \\\n ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16\n print((fp_weight - weight).abs().sum())\n\n fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zp[pack * 8 + i, :] = \\\n (int_zero_point[pack, :] >> (i * 4)) & 15\n\n print((fp_zp - zero_point).abs().sum())\n '''\n weight = None\n return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size\n\n\ndef unpack_int4(weight, scale, zp):\n \"\"\"\n Test function to verify quantize int4 is correct.\n Will not be used in model inference.\n \"\"\"\n weight = weight.transpose(1, 0)\n scale = scale.transpose(1, 0)\n zp = zp.transpose(1, 0)\n h1, h2 = weight.shape\n group_size = h2 * 8 // scale.shape[1]\n group_num = scale.shape[1]\n fp_weight = torch.zeros(h1, h2 * 8).half().to(weight.device)\n fp_zero_point = torch.zeros(h1, group_num).to(weight.device)\n for pack in range(0, h2):\n for i in range(8):\n fp_weight[:, pack * 8 + i] = (weight[:, pack] >> (i * 4)) & 0xF\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF\n for g in range(group_num):\n fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \\\n fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1)\n return fp_weight.transpose(1, 0)\n\n\n\n", "file": "int4_matmul.py", "difficulty": "5"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n Out, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_obs,\n stride_oh,\n stride_od,\n head_dim,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n\n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n\n tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim)\n return\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq):\n Lk = mid_out.shape[-1]\n head_dim = Lk\n batch, head_num = mid_out.shape[0], mid_out.shape[1]\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n mid_out,\n mid_out_logexpsum,\n Out,\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logexpsum.stride(0),\n mid_out_logexpsum.stride(1),\n mid_out_logexpsum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n head_dim,\n BLOCK_SEQ=block_seq,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=2,\n )\n return\n\n\n\n\n", "file": "flash_decode2_phi.py", "difficulty": "2"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", "input": "", "output": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n # transpose\n matrix_ptr = M + d_head_arange[None, :] * matrix_stridey + size_m_arange[:, None] * matrix_stridex\n out_ptr = Out + d_head_arange[None, :] * out_stridex + size_m_arange[:, None] * out_stridey\n matrix = tl.load(matrix_ptr)\n tl.store(out_ptr, matrix)\n\ndef wrapper(size_m, d_head):\n matrix = torch.randn((size_m, d_head), dtype=torch.float16, device=\"cuda\")\n out = torch.zeros((d_head, size_m), dtype=torch.float16, device=\"cuda\")\n\n grid = (1,)\n kernel[grid](\n matrix,\n out,\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n )\n return out\n\n\n\n", "file": "matrix_transpose.py", "difficulty": "2"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n ", "input": "", "output": "from typing import Optional, Union\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro,\n CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads,\n stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads,\n stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x1 = tl.load(X + rotary_dim_half * stride_x_headdim, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n tl.store(OUT + rotary_dim_half * stride_out_headdim, o1, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n else:\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)\n x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\ndef apply_rotary(\n x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None,\n interleaved=False, inplace=False, conjugate=False\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n total_seqlen, nheads, headdim = x.shape\n batch = cu_seqlens.shape[0] - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n rotary_dim *= 2\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n seqlen_offsets += seqlen\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads)\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim,\n seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3),\n output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0,\n x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M\n )\n return output\n\n\n\n\n", "file": "rotary_transform.py", "difficulty": "4"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Kernel function using Triton\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # x_ptr: pointer to input data\n # output_ptr: pointer to output data\n # n_elements: number of elements to process\n # BLOCK_SIZE: block size for Triton kernel\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = tl.math.sin(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef call_kernel(x):\n # x: input tensor\n n_elements = x.numel()\n output = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n kernel_function[grid](x, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n\n\n\n", "file": "sin_kernel.py", "difficulty": "1"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\ndef _l2_norm_bwd(\n x, dy, eps=1e-5,\n):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n\n\n\n\n", "file": "l2_norm_bwd.py", "difficulty": "3"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", "input": "", "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n y = x * rstd\n tl.store(Y + cols, y, mask=mask)\n\ndef _l2_norm_fwd(\n x, eps=1e-6\n):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n y = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\n\n\n\n", "file": "l2_norm_triton1.py", "difficulty": "2"}] \ No newline at end of file +[ + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", + "input": "", + "output": "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": n}, num_stages=s, num_warps=w)\n for m in [32, 64, 128]\n for n in [1, 2, 4, 8]\n for s in [3, 4]\n for w in [4, 8]\n ],\n key=[\"M\", \"N\"],\n)\n@triton.jit\ndef mv_kernel(\n A,\n B,\n C,\n N,\n M,\n stride_an,\n stride_am,\n stride_bm,\n stride_cn,\n BLOCK_N: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]\n offset_m = tl.arange(0, BLOCK_M)[None, :]\n n_mask = offset_n < N\n A_ptrs = A + offset_n * stride_an + offset_m * stride_am\n B_ptrs = B + offset_m * stride_bm\n acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n for m in range(0, M, BLOCK_M):\n m_mask = m + offset_m < M\n a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)\n b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)\n acc += a * b\n A_ptrs += BLOCK_M * stride_am\n B_ptrs += BLOCK_M * stride_bm\n\n acc = tl.sum(acc, axis=1)\n C_ptrs = C + offset_n * stride_cn\n tl.store(C_ptrs, acc[:, None], mask=n_mask)\n\n\ndef mv(inp, vec):\n logging.debug(\"GEMS MV\")\n assert inp.shape[1] == vec.shape[0], \"incompatible dimensions\"\n N, M = inp.shape\n out = torch.empty((N,), device=inp.device, dtype=inp.dtype)\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_N\"]),)\n with torch.cuda.device(inp.device):\n mv_kernel[grid](\n inp,\n vec,\n out,\n N,\n M,\n inp.stride(0),\n inp.stride(1),\n vec.stride(0),\n out.stride(0),\n )\n return out\n\n\n\n\n", + "file": "matrix_vector_multip.py", + "difficulty": "4" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef _matmul_launch_metadata(grid, kernel, args):\n ret = {}\n M, N, K = args[\"M\"], args[\"N\"], args[\"K\"]\n ret[\"name\"] = f\"{kernel.name} [M={M}, N={N}, K={K}]\"\n if \"c_ptr\" in args:\n bytes_per_elem = args[\"c_ptr\"].element_size()\n else:\n bytes_per_elem = 1 if args[\"FP8_OUTPUT\"] else 2\n ret[f\"flops{bytes_per_elem * 8}\"] = 2. * M * N * K\n ret[\"bytes\"] = bytes_per_elem * (M * K + N * K + M * N)\n return ret\n\n\n@triton.jit(launch_metadata=_matmul_launch_metadata)\ndef matmul_kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, #\n BLOCK_SIZE_N: tl.constexpr, #\n BLOCK_SIZE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, #\n ):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n start_m = pid_m * BLOCK_SIZE_M\n start_n = pid_n * BLOCK_SIZE_N\n\n offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n offs_am = tl.where(offs_am < M, offs_am, 0)\n offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n\n offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if (c_ptr.dtype.element_ty == tl.float8e4nv):\n c = accumulator.to(tl.float8e4nv)\n else:\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(a, b):\n configs = {\n torch.float8_e4m3fn: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 128, \"GROUP_SIZE_M\": 8, \"num_stages\": 4,\n \"num_warps\": 8\n }, torch.float16: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64, \"GROUP_SIZE_M\": 8, \"num_stages\": 3,\n \"num_warps\": 8\n }\n }\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.dtype == b.dtype, \"Incompatible dtypes\"\n M, K = a.shape\n K, N = b.shape\n dtype = a.dtype\n\n c = torch.empty((M, N), device=a.device, dtype=dtype)\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]), )\n matmul_kernel[grid](\n a, b, c, #\n M, N, K, #\n a.stride(0), a.stride(1), #\n b.stride(0), b.stride(1), #\n c.stride(0), c.stride(1), #\n BLOCK_SIZE_M=configs[dtype][\"BLOCK_SIZE_M\"], #\n BLOCK_SIZE_N=configs[dtype][\"BLOCK_SIZE_N\"], #\n BLOCK_SIZE_K=configs[dtype][\"BLOCK_SIZE_K\"], #\n GROUP_SIZE_M=configs[dtype][\"GROUP_SIZE_M\"], #\n num_stages=configs[dtype][\"num_stages\"], #\n num_warps=configs[dtype][\"num_warps\"], #\n )\n return c\n\n\n\n\n", + "file": "triton_matmul.py", + "difficulty": "3" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef embedding_kernel(\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n stride_weight_seq,\n stride_out_seq,\n n_ctx,\n hiden_size: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n):\n start_n = tl.program_id(0) * BLOCK_N\n\n offs_nn = start_n + tl.arange(0, BLOCK_NN)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n for start_nn in range(0, BLOCK_N, BLOCK_NN):\n start_nn = tl.multiple_of(start_nn, BLOCK_NN)\n offs_seq = start_nn + offs_nn\n n_ctx_mask = offs_seq < n_ctx\n token_ids = tl.load(input_ids + offs_seq, mask=n_ctx_mask, other=vob_end_id)\n id_mask = (token_ids >= vob_start_id) & (token_ids < vob_end_id)\n token_ids = token_ids - vob_start_id\n dim_mask = offs_d < hiden_size\n load_mask = id_mask[:, None] & dim_mask[None, :]\n store_mask = n_ctx_mask[:, None] & dim_mask[None, :]\n vecs = tl.load(weight + token_ids[:, None] * stride_weight_seq + offs_d[None, :], mask=load_mask, other=0.0)\n tl.store(out + offs_seq[:, None] * stride_out_seq + offs_d[None, :], vecs, mask=store_mask)\n\n@torch.no_grad()\ndef embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: torch.Tensor):\n BLOCK_N = 64\n BLOCK_NN = 1\n BLOCK_DMODEL = triton.next_power_of_2(weight.shape[1])\n n_ctx = input_ids.shape[0]\n\n grid = (triton.cdiv(n_ctx, BLOCK_N), 1, 1)\n\n embedding_kernel[grid](\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n weight.stride(0),\n out.stride(0),\n n_ctx=n_ctx,\n hiden_size=weight.shape[1],\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n num_warps=1,\n num_stages=1,\n )\n\n\n\n\n", + "file": "embedding_triton_kernel.py", + "difficulty": "3" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", + "input": "", + "output": "import time\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n\tconfigs=[\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t \n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\t\n ],\n\tkey=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n bs_ptr, bzp_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n group_size,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr\n ):\n \"\"\"\n assert K % (BLOCK_SIZE_K * SPLIT_K) == 0\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n # [BLOCK_M, BLOCK_K]\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n # [BLOCK_K, BLOCK_N] but repeated 8 times in N\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn\n # tl.static_print(\"shape\", a_ptrs, b_ptrs, bs_ptrs, bzp_ptrs)\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n # Load the next block of A and B.\n # [BLOCK_K, BLOCK_N] but repeated group_size times in K \n bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \\\n + offs_bn[None, :] * stride_bsn\n # [BLOCK_K, BLOCK_N] but repeated in K and N\n bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \\\n + (offs_bn[None, :] // 8) * stride_bzpn\n b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0\n bzp_shift_bits = (offs_bn[None, :] % 8) * 4\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n bs = tl.load(bs_ptrs)\n bzp = tl.load(bzp_ptrs)\n # We accumulate along the K dimension.\n int_b = (b >> b_shift_bits) & 0xF\n int_bzp = (bzp >> bzp_shift_bits) & 0xF\n b = ((int_b - int_bzp) * bs).to(a.dtype)\n accumulator += tl.dot(a, b.to(a.dtype))\n # Advance the ptrs to the next K block.\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0\n # You can fuse arbitrary activation functions here\n # while the accumulator is still in FP32!\n c = accumulator.to(c_ptr.dtype.element_ty)\n # -----------------------------------------------------------\n # Write back the block of the output matrix C with masks.\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor:\n \"\"\"\n \"\"\"\n assert x.is_contiguous(), \"A must be contiguous\"\n assert qweight.is_contiguous(), \"B must be contiguous\" \n M, K = x.shape\n N = scales.shape[1]\n if output is None:\n output = torch.zeros((M, N), device=x.device, dtype=x.dtype) \n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n )\n return output\n\ndef quantize_int4(weight, group_size=128, tp_rank=0):\n # Weight shape: [H1 // 8, H2]\n # Scale shape: [H1 // group_size, H2]\n # zero_pint shape: [H1 // group_size, H2 // 8]\n\n weight = weight.transpose(1, 0)\n h1, h2 = weight.shape\n assert h1 % 8 == 0 and h2 % 8 == 0, \"H1 {} H2 {}\".format(h1, h2)\n assert h2 % group_size == 0, \"H1 {} H2 {}\".format(h1, h2)\n weight = weight.contiguous().view(-1, group_size).cuda(tp_rank)\n weight_max = weight.amax(-1, keepdim=True)\n weight_max = torch.where(weight_max < 0, 0, weight_max)\n weight_min = weight.amin(-1, keepdim=True)\n weight_min = torch.where(weight_min > 0, 0, weight_min)\n weight_range = weight_max - weight_min \n scale = weight_range / (2 ** 4 - 1)\n zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32)\n weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2)\n int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device)\n int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device)\n zero_point = zero_point.view(h1, -1)\n scale = scale.view(h1, -1)\n # pack 8 int4 in an int32 number.\n # Weight pack in row.\n for pack in range(0, h2, 8):\n for i in range(8):\n int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4)\n # zero point pack in col.\n for pack in range(0, h1, 8):\n for i in range(8):\n int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4)\n '''\n fp_weight = torch.zeros(h1, h2).half().to(weight.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_weight[pack * 8 + i, :] = \\\n ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16\n print((fp_weight - weight).abs().sum())\n\n fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zp[pack * 8 + i, :] = \\\n (int_zero_point[pack, :] >> (i * 4)) & 15\n\n print((fp_zp - zero_point).abs().sum())\n '''\n weight = None\n return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size\n\n\ndef unpack_int4(weight, scale, zp):\n \"\"\"\n Test function to verify quantize int4 is correct.\n Will not be used in model inference.\n \"\"\"\n weight = weight.transpose(1, 0)\n scale = scale.transpose(1, 0)\n zp = zp.transpose(1, 0)\n h1, h2 = weight.shape\n group_size = h2 * 8 // scale.shape[1]\n group_num = scale.shape[1]\n fp_weight = torch.zeros(h1, h2 * 8).half().to(weight.device)\n fp_zero_point = torch.zeros(h1, group_num).to(weight.device)\n for pack in range(0, h2):\n for i in range(8):\n fp_weight[:, pack * 8 + i] = (weight[:, pack] >> (i * 4)) & 0xF\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF\n for g in range(group_num):\n fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \\\n fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1)\n return fp_weight.transpose(1, 0)\n\n\n\n", + "file": "int4_matmul.py", + "difficulty": "5" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n Out, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_obs,\n stride_oh,\n stride_od,\n head_dim,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n\n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n\n tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim)\n return\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq):\n Lk = mid_out.shape[-1]\n head_dim = Lk\n batch, head_num = mid_out.shape[0], mid_out.shape[1]\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n mid_out,\n mid_out_logexpsum,\n Out,\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logexpsum.stride(0),\n mid_out_logexpsum.stride(1),\n mid_out_logexpsum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n head_dim,\n BLOCK_SEQ=block_seq,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=2,\n )\n return\n\n\n\n\n", + "file": "flash_decode2_phi.py", + "difficulty": "2" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", + "input": "", + "output": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n # transpose\n matrix_ptr = M + d_head_arange[None, :] * matrix_stridey + size_m_arange[:, None] * matrix_stridex\n out_ptr = Out + d_head_arange[None, :] * out_stridex + size_m_arange[:, None] * out_stridey\n matrix = tl.load(matrix_ptr)\n tl.store(out_ptr, matrix)\n\ndef wrapper(size_m, d_head):\n matrix = torch.randn((size_m, d_head), dtype=torch.float16, device=\"cuda\")\n out = torch.zeros((d_head, size_m), dtype=torch.float16, device=\"cuda\")\n\n grid = (1,)\n kernel[grid](\n matrix,\n out,\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n )\n return out\n\n\n\n", + "file": "matrix_transpose.py", + "difficulty": "2" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n Key changes in this optimized version: 1.Added BLOCK_Kfor better control over memory access patterns \n 2.More aggressive vectorization of cos/sin loads \n 3.Better mask reuse \n 4.More efficient loop structure. \n ", + "input": "", + "output": "from typing import Optional, Union\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro,\n CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads,\n stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads,\n stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x1 = tl.load(X + rotary_dim_half * stride_x_headdim, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n tl.store(OUT + rotary_dim_half * stride_out_headdim, o1, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n else:\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)\n x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\ndef apply_rotary(\n x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None,\n interleaved=False, inplace=False, conjugate=False\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n total_seqlen, nheads, headdim = x.shape\n batch = cu_seqlens.shape[0] - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n rotary_dim *= 2\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n seqlen_offsets += seqlen\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads)\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim,\n seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3),\n output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0,\n x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M\n )\n return output\n\n\n\n\n", + "file": "rotary_transform.py", + "difficulty": "4" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Kernel function using Triton\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # x_ptr: pointer to input data\n # output_ptr: pointer to output data\n # n_elements: number of elements to process\n # BLOCK_SIZE: block size for Triton kernel\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = tl.math.sin(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef call_kernel(x):\n # x: input tensor\n n_elements = x.numel()\n output = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n kernel_function[grid](x, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n\n\n\n", + "file": "sin_kernel.py", + "difficulty": "1" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\ndef _l2_norm_bwd(\n x, dy, eps=1e-5,\n):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n\n\n\n\n", + "file": "l2_norm_bwd.py", + "difficulty": "3" + }, + { + "instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", + "input": "", + "output": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n y = x * rstd\n tl.store(Y + cols, y, mask=mask)\n\ndef _l2_norm_fwd(\n x, eps=1e-6\n):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n y = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\n\n\n\n", + "file": "l2_norm_triton1.py", + "difficulty": "2" + } +] \ No newline at end of file diff --git a/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc b/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc index 041258da8695d655df08fcd5067365004d082324..e8b72b964d457ebbd8283924d94ea64df5bb0d9d 100644 GIT binary patch delta 27 hcmZ3?Ig^w7G%qg~0}woDT)C0^1|y@v=68%+nE-2W2yy@b delta 37 rcmbQqxtNptG%qg~0}vF%&DqF(gOS@xzdXMvySN}RId$_B#;r^M#1#wL diff --git a/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc b/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc index ad4c954d6ce071bdc2f2335b6c38ed1a4e2e0521..65f49893fed011f63e962d74da116333b1f76fd2 100644 GIT binary patch delta 29 jcmdnn$hf(Yk^3|+FBbz4JZN0Gk=v1(QGatdbDAvxgbxT= delta 39 tcmdno$hfzWk^3|+FBbz46vWNh$nD6?ZK+?LUzA;3keHmh*_S!Z769T(3&8*Y diff --git a/src/memories/__pycache__/Memory.cpython-312.pyc b/src/memories/__pycache__/Memory.cpython-312.pyc index 09f82b91d9ce14030c7823c89dc54caace80ecb4..6a1ee1c55806baa5b9076b09c538b25ce6390c94 100644 GIT binary patch delta 27 hcmcb?b&iYsG%qg~0}woDT)C0kpNUa@b1KtkCID~32m=5B delta 37 rcmX@db%TrhG%qg~0}vF%&DqH9&%|x0U!Gr-U0jfuoVq!h=`#}m#`Ozp diff --git a/src/models/__pycache__/Base.cpython-312.pyc b/src/models/__pycache__/Base.cpython-312.pyc index 0ad2c7349939c7e2b124bb94e94bd8b3f3172b23..f0365b4c4166e0d0deef80d2d3adbb783754161f 100644 GIT binary patch delta 27 hcmey)@|K1BG%qg~0}woDT)B~3pOH~{vn!(!BLH!n2Q2^q delta 37 rcmaFM@|}hIG%qg~0}vF%&DqGU&&aK-U!Gr-U0jfuoVwYX(TEWM$x#ZG diff --git a/src/models/__pycache__/KimiK2.cpython-312.pyc b/src/models/__pycache__/KimiK2.cpython-312.pyc index 34c5a6765825d2537bf37a5e5d6d5131d020ec19..c7c5107d290233f52f2361243502d72a42b709c2 100644 GIT binary patch delta 26 gcmZn{{42nHnwOW00SF#6uH4A|iiuHm6LUK&0BGh1xc~qF delta 37 rcmew>&@RY*nwOW00SF4>=4|AC#l)?zU!Gr-U0jfuoVxiNQ#&gF(Q6DL diff --git a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc index 29e23cc138d0482512bb0b41a3380383a96f3f73..49ed1398e7da8cb5124d96b0924a8b008d7a51fc 100644 GIT binary patch delta 175 zcmez9`_7N|G%qg~0}$vOUzKrcBCihPp^aAadFn0o6g<-uN-|OvGV=;bOB6DT71DC@ z6H62l^HLO2@=KF)QgsxP^YhA5i%Jwq@)eTO49ye@ic*s^i!<}{6w>mG6p|Bjl1pz0qnu&*lW)Ws=nQ{zRw@)1q$X!3<|rhVl%(dBWaj57p^NB+rkxdJxNP;X|YS4aab-pM`rqFDr^>0}f05JPb;mIk&DMzVY?4O}1W mq$b!TyN+v}I-7iAY0Bqpa$ zzGM+%D8|Lo!1lpFhL5Fz>w}FnFG~aG2P->%mIm$*4oohL!YmE^AAFdc8HHFH`HDD! G?g0Q&wlQA- diff --git a/src/retrievers/__pycache__/retriever.cpython-312.pyc b/src/retrievers/__pycache__/retriever.cpython-312.pyc index dc463e336802936d3187aecc27bdae9c5fdf2a5e..b455f7510cc1b6430a6bd3fa844210a4e9c7ccde 100644 GIT binary patch delta 27 hcmbOwH9?B|G%qg~0}woDT)B~ZIU}Ra=Dm!u+yH5^2h9Kg delta 37 rcmbOrHA{;7G%qg~0}vF%&DqGkoRQm1zdXMvySN}RId$`9##n9uz~u`y diff --git a/src/temp/embedding_triton_kernel.py b/src/temp/embedding_triton_kernel.py new file mode 100644 index 0000000..d0c993b --- /dev/null +++ b/src/temp/embedding_triton_kernel.py @@ -0,0 +1,131 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + tokens_ptr, # int32* + out_ptr, # weight.dtype* + weight_ptr, # weight.dtype* + seq_len, + vocab_size, + n_dim, + stride_tokens, + stride_out_n, + stride_out_d, + stride_weight_vocab, + stride_weight_dim, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, +): + pid_seq = tl.program_id(0) # batch dimension + offs_d = tl.arange(0, n_dim) + + for block_start in range(0, seq_len, BLOCK_NN): + cur_block_size = tl.minimum(BLOCK_NN, seq_len - block_start) + block_token_offs = pid_seq * stride_tokens + block_start + tl.arange(0, BLOCK_N) + mask_n = tl.arange(0, BLOCK_N) < cur_block_size + block_tokens = tl.load(tokens_ptr + block_token_offs, mask=mask_n, other=0) + + offs_n = block_start + tl.arange(0, BLOCK_N)[:, None] # [BLOCK_N, 1] + offs_w = block_tokens[:, None] * stride_weight_vocab + offs_d[None, :] * stride_weight_dim # [BLOCK_N, n_dim] + + w_vec = tl.load(weight_ptr + offs_w, + mask=(offs_n < seq_len)[:, None] & (offs_d[None, :] < n_dim)) + + offs_out = pid_seq * stride_out_n + offs_n * stride_out_d + offs_d[None, :] + tl.store(out_ptr + offs_out, + w_vec, + mask=(offs_n < seq_len)[:, None] & (offs_d[None, :] < n_dim)) + + +def embedding(tokens: torch.Tensor, + weight: torch.Tensor) -> torch.Tensor: + assert tokens.dim() == 2, "Expected tokens shape (batch, seq)" + bsz, seq_len = tokens.shape + vocab_size, n_dim = weight.shape + assert tokens.dtype in [torch.int32, torch.int64], "tokens must be int32 or int64" + output = torch.empty((bsz, seq_len, n_dim), dtype=weight.dtype, device=weight.device) + + BLOCK_N = 64 + BLOCK_NN = BLOCK_N + grid = (bsz,) + embedding_kernel[grid]( + tokens, + output, + weight, + seq_len, + vocab_size, + n_dim, + tokens.stride(0), + output.stride(0), + output.stride(1), + weight.stride(0), + weight.stride(1), + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + ) + return output +################################################################################################################################################## + + + +import torch + +def test_embedding(): + # 参数定义 + vocab_size = 1000 # 词汇表大小 + embedding_dim = 512 # 嵌入维度 + sequence_length = 128 # 输入序列长度 + vob_start_id = 10 # 词汇表起始 ID + vob_end_id = 1000 # 词汇表结束 ID + + # 创建测试输入张量 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + + # 调用嵌入函数 + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + # 保存结果 + results = {} + results['test_case_1'] = out.clone() + + # 测试不同的输入 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_2'] = out.clone() + + # 测试不同的词汇表范围 + vob_start_id = 0 + vob_end_id = 500 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_3'] = out.clone() + + # 测试不同的嵌入维度 + embedding_dim = 256 + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_4'] = out.clone() + + return results + +result_gold = test_embedding() diff --git a/src/temp/flash_decode2_phi.py b/src/temp/flash_decode2_phi.py new file mode 100644 index 0000000..2f9421c --- /dev/null +++ b/src/temp/flash_decode2_phi.py @@ -0,0 +1,147 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_o_b, + stride_mid_o_h, + stride_mid_o_s, + stride_mid_o_d, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = tl.full([], 0.0, dtype=tl.float32) + max_logic = tl.full([], -float("inf"), dtype=tl.float32) + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_d = tl.arange(0, BLOCK_DMODEL) + ptr_mid = Mid_O + cur_batch * stride_mid_o_b + cur_head * stride_mid_o_h + block_id * stride_mid_o_s + offs_d * stride_mid_o_d + tv = tl.load(ptr_mid).to(tl.float32) + + ptr_lse = Mid_O_LogExpSum + cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_s + tlogic = tl.load(ptr_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + exp_di = tl.exp(tlogic - new_max) + sum_exp += exp_di + acc += tv * exp_di + max_logic = new_max + + acc_norm = acc / sum_exp + + offs_out_d = tl.arange(0, BLOCK_DMODEL) + ptr_out = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offs_out_d * stride_out_d + tl.store(ptr_out, acc_norm.to(Out.type.element_ty)) + + +@torch.no_grad() +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int +): + BLOCK_DMODEL = Mid_O.shape[-1] + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=1, + ) + +################################################################################################################################################## + + + +import torch + +# Define the test function +def test_flash_decode_stage2(): + # Define the parameters for different test cases + batch_size = 2 + head_num = 4 + seq_block_num = 3 + head_dim = 64 + block_seq = 16 + + test_cases = { + "test_case_1": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + }, + "test_case_2": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + 1 # Different block size + }, + "test_case_3": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq // 2 # Different block size + }, + "test_case_4": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq * 2 # Different block size + } + } + + # Execute the function for all test cases + results = {} + for key, test_case in test_cases.items(): + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + results[key] = test_case["Out"] + + return results + +# Run the test +result_gold = test_flash_decode_stage2() diff --git a/src/temp/int4_matmul.py b/src/temp/int4_matmul.py new file mode 100644 index 0000000..8741b86 --- /dev/null +++ b/src/temp/int4_matmul.py @@ -0,0 +1,194 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 2}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr = 1, +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = first_pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + scales_ptrs = scales_ptr + (offs_n * stride_scales_n) + zeros_ptrs = zeros_ptr + ((offs_n // 8) * stride_zeros_n) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + mask_k = offs_k < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + b_idx = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_raw = tl.load(b_ptr + b_idx, mask=mask_k[:, None], other=0) + + b_i4 = (b_raw >> (4 * (offs_k[:, None] % 8))) & 0xF + + group_idx = (k * BLOCK_SIZE_K * SPLIT_K + offs_k[:, None]) // groupsize + scales = tl.load(scales_ptrs + group_idx * stride_scales_g) + zeros = tl.load(zeros_ptrs + group_idx * stride_zeros_g) + + b_fp = (b_i4 - ((zeros >> (4 * (offs_n[None, :] % 8))) & 0xF)) * scales + + accumulator += tl.dot(a, b_fp.to(a.dtype)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + offs_k += BLOCK_SIZE_K * SPLIT_K + + offs_m_real = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n_real = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_m_real < M + mask_n = offs_n_real < N + + c_ptrs = c_ptr + stride_cm * offs_m_real[:, None] + stride_cn * offs_n_real[None, :] + c_mask = mask_m[:, None] & mask_n[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptrs, accumulator, mask=c_mask) + + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + groupsize: int = 128, + output=None +) -> torch.FloatTensor: + assert x.is_contiguous(), "x must be contiguous" + assert qweight.is_contiguous(), "qweight must be contiguous" + + M, K = x.shape + assert K == qweight.shape[0] * 8, "K must align with packed INT4 weight size" + N = qweight.shape[1] + + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META.get('SPLIT_K', 1), + ) + + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + groupsize, + GROUP_SIZE_M=8, + ) + return output + + +def quantize_int4(w: torch.Tensor, groupsize: int = 128): + assert w.dtype in (torch.float16, torch.float32, torch.bfloat16) + w = w.float() + + N, K = w.shape + assert K % groupsize == 0, "groupsize must evenly divide K" + + w_grouped = w.view(N, K // groupsize, groupsize) + mn, _ = w_grouped.min(dim=-1, keepdim=True) + mx, _ = w_grouped.max(dim=-1, keepdim=True) + scales = (mx - mn) / 15.0 + zeros = -mn / scales + + quantized = torch.round((w_grouped - mn) / scales).clamp(0, 15).to(torch.int8) + + quantized = quantized.view(N, K) + scales = scales.view(N, K // groupsize) + zeros = zeros.to(torch.int8).view(N, K // groupsize) + + packed = torch.zeros((N, K // 2), dtype=torch.int32, device=w.device) + for k in range(0, K, 2): + even = quantized[:, k] + odd = quantized[:, k + 1] if k + 1 < K else 0 + packed[:, k // 2] = (odd.int() << 4) | (even.int() & 0xF) + + return packed, scales.float(), zeros + + +def unpack_int4(b_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, groupsize: int = 128): + assert b_packed.dtype == torch.int32 + assert b_packed.dim() == 2 + N, K_half = b_packed.shape + K = K_half * 2 + + unpacked = torch.empty((N, K), dtype=torch.int8, device=b_packed.device) + for k in range(0, K // 2): + val = b_packed[:, k] + unpacked[:, 2 * k] = val.int() & 0xF + unpacked[:, 2 * k + 1] = (val.int() >> 4) & 0xF + + group_idx = torch.arange(K, device=b_packed.device) // groupsize + zeros_exp = zeros.gather(1, group_idx.view(1, -1).expand(N, K)) + scales_exp = scales.gather(1, group_idx.view(1, -1).expand(N, K)) + + return (unpacked.float() - zeros_exp) * scales_exp + +################################################################################################################################################## + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + group_size = 128 + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + # Test case + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + results = { + "test_case_1": triton_output + } + + return results + +result_gold = test_correct_int4_s2() diff --git a/src/temp/l2_norm_bwd.py b/src/temp/l2_norm_bwd.py new file mode 100644 index 0000000..56886ce --- /dev/null +++ b/src/temp/l2_norm_bwd.py @@ -0,0 +1,94 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, + M, N, + eps, + stride_x_row, + BLOCK_N: tl.constexpr, +): + row_id = tl.program_id(0) + if row_id >= M: + return + + offsets_n = tl.arange(0, BLOCK_N) + mask_n = offsets_n < N + + offsets_x = row_id * stride_x_row + offsets_n + x = tl.load(X + offsets_x, mask=mask_n, other=0.0).to(tl.float32) + dy = tl.load(DY + offsets_x, mask=mask_n, other=0.0).to(tl.float32) + + squares = x * x + var = tl.sum(tl.where(mask_n, squares, 0.0), axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dot = tl.sum(tl.where(mask_n, dy * x, 0.0), axis=0) + + dx = dy * rstd - dot * (1.0 / (var + eps)) * rstd * x + + tl.store(DX + offsets_x, dx, mask=mask_n) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + original_shape = x.shape + x = x.reshape(-1, x.size(-1)) + dy = dy.reshape(-1, dy.size(-1)).contiguous() + x = x.contiguous() + M, N = x.shape + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError(f"N ({N}) exceeds maximum BLOCK_N ({BLOCK_N})") + + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, dy, dx, + M, N, + eps, + x.stride(0), + BLOCK_N=BLOCK_N, + ) + return dx.view(*original_shape) + +################################################################################################################################################## + + + +import torch + +# Test the backward L2 normalization +def test_l2_norm_bwd(): + results = {} + + # Test case 1: Default case + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_1'] = dx + + # Test case 2: Different shape + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_2'] = dx + + # Test case 3: Larger tensor + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_3'] = dx + + # Test case 4: Edge case with small tensor + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_4'] = dx + + return results + +# Run the tests +result_gold = test_l2_norm_bwd() diff --git a/src/temp/l2_norm_triton1.py b/src/temp/l2_norm_triton1.py new file mode 100644 index 0000000..c25201e --- /dev/null +++ b/src/temp/l2_norm_triton1.py @@ -0,0 +1,83 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + mask = cols < N + y = x * rstd + tl.store(Y + cols, y, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + original_shape = x.shape + x = x.reshape(-1, original_shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + y = torch.empty_like(x) + + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("L2 norm Triton kernel requires N <= BLOCK_N.") + + _l2_norm_fwd_1pass_kernel[(M,)]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N + ) + + return y.reshape(original_shape) + +################################################################################################################################################## + + + +import torch + +# Test the forward L2 normalization +def test_l2_norm_fwd(): + results = {} + + # Test case 1 + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + y1 = _l2_norm_fwd(x1) + results['test_case_1'] = y1 + + # Test case 2: Different batch size + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + y2 = _l2_norm_fwd(x2) + results['test_case_2'] = y2 + + # Test case 3: Different feature size + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + y3 = _l2_norm_fwd(x3) + results['test_case_3'] = y3 + + # Test case 4: Larger tensor + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + y4 = _l2_norm_fwd(x4) + results['test_case_4'] = y4 + + return results + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/matrix_transpose.py b/src/temp/matrix_transpose.py new file mode 100644 index 0000000..b6fdd5e --- /dev/null +++ b/src/temp/matrix_transpose.py @@ -0,0 +1,71 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid = tl.program_id(0) + num_blocks_h = (D_HEAD + 127) // 128 + block_m = pid // num_blocks_h + block_d = pid % num_blocks_h + + BLOCK_M = 64 + BLOCK_D = 128 + + rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M) + rd = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = rm < SIZE_M + mask_d = rd < D_HEAD + mask = mask_m[:, None] & mask_d[None, :] + + ptrs_in = M + rm[:, None] * matrix_stridex + rd[None, :] * matrix_stridey + vals = tl.load(ptrs_in, mask=mask) + ptrs_out = Out + rd[None, :] * out_stridex + rm[:, None] * out_stridey + tl.store(ptrs_out, vals, mask=mask) + +def wrapper(size_m: int, d_head: int): + torch.manual_seed(0) + M = torch.randn(size_m, d_head, dtype=torch.float16, device='cuda') + Out = torch.zeros(d_head, size_m, dtype=torch.float16, device='cuda') + + grid = lambda meta: [((size_m + 63) // 64) * ((d_head + 127) // 128)] + + kernel[grid]( + M, + Out, + M.stride(0), + M.stride(1), + Out.stride(0), + Out.stride(1), + size_m, + d_head, + ) + return Out + +################################################################################################################################################## + + + +import torch + +def test_triton_vs_torch(): + results = {} + + # 测试用例 1: 基本矩阵转置 (小矩阵) + size_m, d_head = 16, 16 + out = wrapper(size_m, d_head) + results["test_case_1"] = out.clone() + + # 测试用例 2: 非方形矩阵 + size_m, d_head = 32, 64 + out = wrapper(size_m, d_head) + results["test_case_2"] = out.clone() + + return results + + +# 运行测试 +result_gold = test_triton_vs_torch() +# print(result_gold) \ No newline at end of file diff --git a/src/temp/matrix_vector_multip.py b/src/temp/matrix_vector_multip.py new file mode 100644 index 0000000..b55b68c --- /dev/null +++ b/src/temp/matrix_vector_multip.py @@ -0,0 +1,83 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel( + A, B, C, + N, M, + stride_an, stride_am, + stride_b, + stride_c, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_n = tl.program_id(0) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + for k in range(0, M, BLOCK_M): + offs_am = k + offs_m + mask_a = (offs_n[:, None] < N) & (offs_am[None, :] < M) + mask_b = offs_am < M + a_ptrs = A + (offs_n[:, None] * stride_an + offs_am[None, :] * stride_am) + b_ptrs = B + offs_am * stride_b + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + acc += tl.sum(a * b[None, :], axis=1) + + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_c = offs_cn < N + c_ptrs = C + offs_cn * stride_c + tl.store(c_ptrs, acc, mask=mask_c) + + +def mv(A: torch.Tensor, B: torch.Tensor): + assert A.dim() == 2, "A must be 2-D" + assert B.dim() == 1, "B must be 1-D" + N, M = A.shape + assert B.shape[0] == M, "Incompatible shapes for matrix-vector multiplication" + C = torch.empty(N, dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 64 + + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, + N, M, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, + ) + return C + +################################################################################################################################################## + + + +def test_mv(): + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + triton_result_2 = mv(A, B) + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + A = torch.randn(32, 16, device='cuda') + B = torch.randn(16, device='cuda') + triton_result_3 = mv(A, B) + + return { + "test_case_2": triton_result_2, + "test_case_3": triton_result_3, + } + +result_gold = test_mv() diff --git a/src/temp/rotary_transform.py b/src/temp/rotary_transform.py new file mode 100644 index 0000000..e146a6c --- /dev/null +++ b/src/temp/rotary_transform.py @@ -0,0 +1,212 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, + CU_SEQLENS, + stride_x_batch, stride_x_head, stride_x_m, stride_x_k, + stride_out_batch, stride_out_head, stride_out_m, stride_out_k, + stride_cos_m, stride_cos_k, + stride_sin_m, stride_sin_k, + SEQLEN_OFFSETS, + TOTAL_HEADS, + INITIAL_HEAD_INDEX, + BLOCK_K: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + IS_VARLEN: tl.constexpr, + BLOCK_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_head = tl.program_id(1) + pid_batch = tl.program_id(2) + + initial_head = tl.load(INITIAL_HEAD_INDEX) + total_heads = tl.load(TOTAL_HEADS) + cur_head = initial_head + pid_head + if cur_head >= total_heads: + return + + head_dim = BLOCK_K * 2 + curr_dtype = X.type.element_ty + + offs_k = tl.arange(0, BLOCK_K) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + if IS_VARLEN: + seq_beg = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_beg + mask_m = offs_m < seqlen_i + else: + seq_beg = 0 + seqlen_i = stride_x_m + mask_m = offs_m < seqlen_i + + base_x = pid_batch * stride_x_batch + cur_head * stride_x_head + base_out = pid_batch * stride_out_batch + cur_head * stride_out_head + base_cos = offs_m + seqlen_offset = tl.load(SEQLEN_OFFSETS + pid_batch) if SEQLEN_OFFSETS else 0 + base_cos = base_cos + seqlen_offset + + mask_k = offs_k < (head_dim // 2) + cos = tl.load( + COS + base_cos[:, None] * stride_cos_m + offs_k[None, :] * stride_cos_k, + mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ).to(tl.float32) + sin = tl.load( + SIN + base_cos[:, None] * stride_sin_m + offs_k[None, :] * stride_sin_k, + mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ).to(tl.float32) + + for m_step in range(BLOCK_M): + if not mask_m[m_step]: + continue + curr_m = pid_m * BLOCK_M + m_step + c = cos[m_step, :] + s = sin[m_step, :] + + if INTERLEAVED: + offs_x0 = base_x + curr_m * stride_x_m + offs_k * 2 * stride_x_k + offs_x1 = offs_x0 + stride_x_k + x0 = tl.load(X + offs_x0, mask=mask_k, other=0.0).to(tl.float32) + x1 = tl.load(X + offs_x1, mask=mask_k, other=0.0).to(tl.float32) + else: + offs_x0 = base_x + curr_m * stride_x_m + offs_k * stride_x_k + offs_x1 = base_x + curr_m * stride_x_m + (BLOCK_K + offs_k) * stride_x_k + x0 = tl.load(X + offs_x0, mask=mask_k, other=0.0).to(tl.float32) + x1 = tl.load(X + offs_x1, mask=mask_k, other=0.0).to(tl.float32) + + if CONJUGATE: + x1 = -x1 + + out0 = x0 * c - x1 * s + out1 = x0 * s + x1 * c + + tl.store(OUT + offs_x0, out0.to(curr_dtype), mask=mask_k) + tl.store(OUT + offs_x1, out1.to(curr_dtype), mask=mask_k) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + assert x.dtype in {torch.float16, torch.bfloat16, torch.float32} + + if cu_seqlens is None: + batch, seqlen, nheads, headdim = x.shape + is_varlen = False + max_seqlen = seqlen + else: + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.size(0) - 1 + seqlen = max_seqlen + is_varlen = True + + seqlen_ro, rotary_dim_over2 = cos.shape + assert sin.shape == cos.shape and headdim % 2 == 0 + rotary_dim = rotary_dim_over2 * 2 + assert rotary_dim <= headdim + + BLOCK_K = headdim // 2 + BLOCK_M = 64 + + if isinstance(seqlen_offsets, int): + seqlen_offsets_tensor = torch.tensor([seqlen_offsets], dtype=torch.int32, device=x.device) + else: + seqlen_offsets_tensor = seqlen_offsets.to(torch.int32) + + if inplace: + out = x + else: + out = torch.empty_like(x) + + if rotary_dim < headdim and not inplace: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + total_heads = torch.tensor([nheads], dtype=torch.int32, device=x.device) + initial_head = torch.tensor([0], dtype=torch.int32, device=x.device) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), nheads, batch) + + rotary_kernel[grid]( + out, x, cos, sin, + cu_seqlens, + x.stride(0), x.stride(2), x.stride(1), x.stride(3), + out.stride(0), out.stride(2), out.stride(1), out.stride(3), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + seqlen_offsets_tensor, + total_heads, initial_head, + BLOCK_K=BLOCK_K, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + IS_VARLEN=is_varlen, + BLOCK_M=BLOCK_M, + EVEN_K=True, + ) + return out + +################################################################################################################################################## + + + +import torch + +def test_apply_rotary(): + results = {} + + # Test case 1: Basic test with fixed sequence length and no interleaving + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin) + results['test_case_1'] = output.shape + + # Test case 2: Variable length sequences with interleaving + total_seqlen, nheads, headdim = 256, 4, 64 + batch = 3 + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + max_seqlen = 128 + rotary_dim = 32 + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + results['test_case_2'] = output.shape + + # Test case 3: Conjugate flag enabled + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, conjugate=True) + results['test_case_3'] = output.shape + + # Test case 4: Inplace operation + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, inplace=True) + results['test_case_4'] = output.shape + + return results + +result_gold = test_apply_rotary() diff --git a/src/temp/sin_kernel.py b/src/temp/sin_kernel.py new file mode 100644 index 0000000..c8f8c64 --- /dev/null +++ b/src/temp/sin_kernel.py @@ -0,0 +1,69 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 1024) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x, dtype=x.dtype, device=x.device) + grid = lambda meta: (triton.cdiv(n_elements, BLOCK_SIZE),) + kernel_function[grid]( + x_ptr=x, + output_ptr=output, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + +################################################################################################################################################## + + + +import torch + +# Function to test the Triton kernel +def test_call_kernel(): + results = {} + + # Test case 1: Small input tensor + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + output1 = call_kernel(x1) + results['test_case_1'] = output1 + + # Test case 2: Larger input tensor + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + output2 = call_kernel(x2) + results['test_case_2'] = output2 + + # Test case 3: Edge case with zero elements + x3 = torch.tensor([], dtype=torch.float32).cuda() + output3 = call_kernel(x3) + results['test_case_3'] = output3 + + # Test case 4: Input tensor with negative values + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + output4 = call_kernel(x4) + results['test_case_4'] = output4 + + return results + +# Run the test function +result_gold = test_call_kernel() diff --git a/src/temp/tmp.py b/src/temp/tmp.py new file mode 100644 index 0000000..9d41e77 --- /dev/null +++ b/src/temp/tmp.py @@ -0,0 +1,119 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + ids, weight, out, + stride_ids_n, stride_ids_nn, + stride_weight_t, stride_weight_d, + stride_out_n, stride_out_d, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + ids_ptr = ids + pid * stride_ids_n + out_ptr = out + pid * stride_out_n + + for start_n in range(0, BLOCK_N, BLOCK_NN): + offs_n = start_n + tl.arange(0, BLOCK_NN) + mask_n = offs_n < BLOCK_N + token_ids = tl.load(ids_ptr + offs_n * stride_ids_nn, mask=mask_n, other=0) + + for start_d in range(0, BLOCK_DMODEL, BLOCK_DMODEL): + offs_d = start_d + tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < BLOCK_DMODEL + weight_ptrs = weight + token_ids[:, None] * stride_weight_t + offs_d[None, :] * stride_weight_d + out_ptrs = out_ptr + offs_n[:, None] * stride_out_n + offs_d[None, :] * stride_out_d + embed = tl.load(weight_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + tl.store(out_ptrs, embed, mask=mask_n[:, None] & mask_d[None, :]) + + +def embedding(ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + assert ids.dim() == 2, "ids must be 2-D tensor" + assert weight.dim() == 2, "weight must be 2-D tensor" + batch, seq_len = ids.shape + vocab_size, d_model = weight.shape + out = torch.empty((batch, seq_len, d_model), dtype=weight.dtype, device=weight.device) + + BLOCK_N = seq_len + BLOCK_NN = min(64, seq_len) + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = (batch,) + embedding_kernel[grid]( + ids, weight, out, + ids.stride(0), ids.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(2), + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + +################################################################################################################################################## + + + +import torch + +def test_embedding(): + # 参数定义 + vocab_size = 1000 # 词汇表大小 + embedding_dim = 512 # 嵌入维度 + sequence_length = 128 # 输入序列长度 + vob_start_id = 10 # 词汇表起始 ID + vob_end_id = 1000 # 词汇表结束 ID + + # 创建测试输入张量 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + + # 调用嵌入函数 + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + # 保存结果 + results = {} + results['test_case_1'] = out.clone() + + # 测试不同的输入 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_2'] = out.clone() + + # 测试不同的词汇表范围 + vob_start_id = 0 + vob_end_id = 500 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_3'] = out.clone() + + # 测试不同的嵌入维度 + embedding_dim = 256 + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_4'] = out.clone() + + return results + +result_gold = test_embedding() diff --git a/src/temp/triton_matmul.py b/src/temp/triton_matmul.py new file mode 100644 index 0000000..93f6708 --- /dev/null +++ b/src/temp/triton_matmul.py @@ -0,0 +1,118 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + acc_dtype = tl.float32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + mask_k = offs_k + for k in range(0, K, BLOCK_SIZE_K): + if IS_EVEN_K: + a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0) + b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0) + else: + a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (mask_k[None, :] < (K - k)), other=0.0) + b = tl.load(b_ptrs, mask=(mask_k[:, None] < (K - k)) & (offs_n[None, :] < N), other=0.0) + accumulator += tl.dot(a, b).to(acc_dtype) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator.to(c_ptr.type.element_ty), mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "a must be contiguous" + assert b.is_contiguous(), "b must be contiguous" + M, K = a.shape + _K, N = b.shape + assert K == _K, "K dimensions must match" + + def decide_dtype(dt): + if dt in (torch.float8_e4m3fn, torch.float8_e5m2): + return torch.float16 + elif dt in (torch.float16, torch.bfloat16, torch.float32): + return torch.float16 + elif dt is torch.int8: + return torch.int32 + else: + return torch.float32 + + c_dtype = decide_dtype(a.dtype) + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + if a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + num_stages, num_warps = 3, 8 + else: + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32 + num_stages, num_warps = 2, 4 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_M, + BLOCK_SIZE_N=BLOCK_N, + BLOCK_SIZE_K=BLOCK_K, + IS_EVEN_K=K % BLOCK_K == 0, + num_stages=num_stages, + num_warps=num_warps, + ) + return c + +################################################################################################################################################## + + + +import torch + +# Test for matmul +def test_matmul(): + results = {} + M, K, N = 256, 128, 256 + + # Test case 1: torch.float16 + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + c = matmul(a, b) + results['test_case_1'] = c + + return results + +# Run all tests +result_gold = test_matmul() \ No newline at end of file diff --git a/src/utils/__pycache__/utils.cpython-312.pyc b/src/utils/__pycache__/utils.cpython-312.pyc index 5240a44343db32ccd23713863f8a830c080e8631..37aabeaa3d42d901632cf458067b16c233b2ce63 100644 GIT binary patch delta 27 hcmeAYZV=`^&CAQh00a*jS8n9K$H=I>`3vJTb^vG>2sr=% delta 37 rcmZn=?h@uc&CAQh00aebb2f6{W8~J=FV8Q^E-pw+PTl;9aT+@Sze)?C From 62e92c389950d289e99a1d40c53da425b741b598 Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Sun, 24 Aug 2025 08:33:38 +0000 Subject: [PATCH 2/3] best json --- reflexion_oneshot_tritonbench_2.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 reflexion_oneshot_tritonbench_2.json diff --git a/reflexion_oneshot_tritonbench_2.json b/reflexion_oneshot_tritonbench_2.json new file mode 100644 index 0000000..ea3afbf --- /dev/null +++ b/reflexion_oneshot_tritonbench_2.json @@ -0,0 +1 @@ +[{"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", "label": null, "file": "matrix_vector_multip.py", "test_code": "\n\n\ndef test_mv():\n # \u6d4b\u8bd5\u7528\u4f8b 2: 4x3 \u77e9\u9635\u4e0e 3x1 \u5411\u91cf\u76f8\u4e58\n A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda')\n B = torch.tensor([1.0, 2.0, 3.0], device='cuda')\n triton_result_2 = mv(A, B)\n\n # \u6d4b\u8bd5\u7528\u4f8b 3: 32x16 \u77e9\u9635\u4e0e 16x1 \u5411\u91cf\u76f8\u4e58\n A = torch.randn(32, 16, device='cuda')\n B = torch.randn(16, device='cuda')\n triton_result_3 = mv(A, B)\n\n return {\n \"test_case_2\": triton_result_2,\n \"test_case_3\": triton_result_3,\n }\n\nresult_gold = test_mv()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mv_kernel(A, B, C, M, N,\n stride_am: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr):\n pid_m = tl.program_id(0)\n\n offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n mask_m = offs_m < N\n acc = tl.full((BLOCK_M,), 0.0, dtype=tl.float32)\n\n for n_start in range(0, M, BLOCK_N):\n offs_n_block = n_start + offs_n\n mask_n = offs_n_block < M\n\n a_offs = offs_m[:, None] * stride_am + offs_n_block[None, :]\n a = tl.load(A + a_offs, mask=mask_m[:, None] & mask_n[None, :], other=0.0)\n b = tl.load(B + offs_n_block, mask=mask_n, other=0.0)\n\n acc += tl.sum(a * b[None, :], axis=1)\n\n tl.store(C + offs_m, acc.to(C.dtype.element_ty), mask=mask_m)\n\ndef mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n assert A.dim() == 2, \"Matrix A should be 2D (NxM)\"\n assert B.dim() == 1, \"Vector B should be 1D (M)\"\n N, M = A.shape\n assert B.size(0) == M, f\"Cannot multiply A ({N}x{M}) and B ({B.size(0)}), inner dimensions must match\"\n\n C = torch.empty((N,), device=A.device, dtype=A.dtype)\n\n def grid(meta):\n return (triton.cdiv(N, meta['BLOCK_M']),)\n\n mv_kernel[grid](\n A, B, C, M, N,\n stride_am=A.stride(0),\n BLOCK_M=64,\n BLOCK_N=64\n )\n return C\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", "label": null, "file": "triton_matmul.py", "test_code": "\n\n\nimport torch\n\n# Test for matmul\ndef test_matmul():\n results = {}\n M, K, N = 256, 128, 256\n\n # Test case 1: torch.float16\n a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n b = torch.randn((K, N), dtype=torch.float16, device='cuda')\n c = matmul(a, b)\n results['test_case_1'] = c\n\n return results\n\n# Run all tests\nresult_gold = test_matmul()", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n offs_am = tl.max_contiguous(tl.multiple_of(offs_am % M, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % N, BLOCK_SIZE_N), BLOCK_SIZE_N)\n offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_SIZE_K), BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n k_offset = k * BLOCK_SIZE_K\n a_mask = (offs_k[None, :] + k_offset < K)\n b_mask = (offs_k[:, None] + k_offset < K)\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b = tl.load(b_ptrs, mask=b_mask, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cm = tl.max_contiguous(tl.multiple_of(offs_cm % M, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_cn = tl.max_contiguous(tl.multiple_of(offs_cn % N, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n c = accumulator.to(tl.float16)\n ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(ptrs, c, mask=mask)\n\ndef matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n assert a.dim() == 2 and b.dim() == 2, \"Both tensors must be 2D\"\n M, K = a.shape\n K_b, N = b.shape\n assert K == K_b, \"Inner dimensions must match\"\n assert a.is_contiguous(), \"A must be contiguous\"\n assert b.is_contiguous(), \"B must be contiguous\"\n\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)\n\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n BLOCK_SIZE_M=128,\n BLOCK_SIZE_N=128,\n BLOCK_SIZE_K=32,\n GROUP_SIZE_M=8,\n )\n return c\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", "label": null, "file": "embedding_triton_kernel.py", "test_code": "\n\n\nimport torch\n\ndef test_embedding():\n # \u53c2\u6570\u5b9a\u4e49\n vocab_size = 1000 # \u8bcd\u6c47\u8868\u5927\u5c0f\n embedding_dim = 512 # \u5d4c\u5165\u7ef4\u5ea6\n sequence_length = 128 # \u8f93\u5165\u5e8f\u5217\u957f\u5ea6\n vob_start_id = 10 # \u8bcd\u6c47\u8868\u8d77\u59cb ID\n vob_end_id = 1000 # \u8bcd\u6c47\u8868\u7ed3\u675f ID\n\n # \u521b\u5efa\u6d4b\u8bd5\u8f93\u5165\u5f20\u91cf\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n\n # \u8c03\u7528\u5d4c\u5165\u51fd\u6570\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n\n # \u4fdd\u5b58\u7ed3\u679c\n results = {}\n results['test_case_1'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8f93\u5165\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_2'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8bcd\u6c47\u8868\u8303\u56f4\n vob_start_id = 0\n vob_end_id = 500\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_3'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u5d4c\u5165\u7ef4\u5ea6\n embedding_dim = 256\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_4'] = out.clone()\n\n return results\n\nresult_gold = test_embedding()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef embedding_kernel(\n ids,\n weight,\n out,\n stride_idsb,\n stride_idss,\n stride_idsn,\n stride_wb,\n stride_wn,\n stride_wd,\n stride_ob,\n stride_os,\n stride_od,\n N: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_b = pid\n pid_s = 0\n offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n mask = offset_n < N\n\n offset_ids = pid_b * stride_idsb + pid_s * stride_idss\n ids_ptr = ids + offset_ids + offset_n * stride_idsn\n curr_ids = tl.load(ids_ptr, mask=mask).to(tl.int32)\n\n offset_d = tl.arange(0, BLOCK_DMODEL)\n for i in range(BLOCK_NN):\n offset_n_i = offset_n + i * BLOCK_N\n mask_i = offset_n_i < N\n\n ids_i = tl.load(ids + offset_ids + offset_n_i * stride_idsn, mask=mask_i).to(tl.int32)\n weight_ptr = weight + ids_i * stride_wn + offset_d[None, :] * stride_wd\n out_ptr = out + pid_b * stride_ob + offset_n_i * stride_os + offset_d[None, :] * stride_od\n tl.store(out_ptr, tl.load(weight_ptr), mask=mask_i[:, None])\n\n\ndef embedding(ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:\n assert ids.dim() == 2\n B, S = ids.shape\n D = weight.size(-1)\n out = torch.empty((B, S, D), dtype=weight.dtype, device=weight.device)\n\n BLOCK_N = 16\n BLOCK_NN = 4\n BLOCK_DMODEL = triton.next_power_of_2(D)\n\n grid = lambda META: (triton.cdiv(B * S, META[\"BLOCK_N\"]),)\n embedding_kernel[grid](\n ids,\n weight,\n out,\n ids.stride(0),\n ids.stride(1),\n 1,\n weight.stride(0),\n weight.stride(1),\n 1,\n out.stride(0),\n out.stride(1),\n out.stride(2),\n B * S,\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n BLOCK_DMODEL=BLOCK_DMODEL,\n )\n return out\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", "label": null, "file": "int4_matmul.py", "test_code": "\n\n\ndef test_correct_int4_s2(M=32, K=4096, N=4096):\n group_size = 128\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size)\n \n # Test case\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n \n results = {\n \"test_case_1\": triton_output\n }\n \n return results\n\nresult_gold = test_correct_int4_s2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scale_g, stride_scale_n,\n stride_zero_g, stride_zero_n,\n groupsize,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n IS_EVEN_K: tl.constexpr,\n):\n pid = tl.program_id(0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = 8 * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * 8\n group_size_m = min(num_pid_m - first_pid_m, 8)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n\n scales_offs_m = offs_n\n zeros_offs_m = (offs_n // 8)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k_start in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n k_idx = k_start * BLOCK_SIZE_K * SPLIT_K + offs_k\n mask_k = k_idx < K if not IS_EVEN_K else None\n\n a = tl.load(a_ptrs, mask=mask_k, other=0.0).to(tl.float32)\n\n packed = tl.load(b_ptrs, mask=mask_k, other=0)\n shift = (k_idx % 8) * 4\n b_int4 = (packed >> shift)[:, None] & 0xF\n b_int4 = b_int4.squeeze(-1)\n\n group_id_k = tl.min(k_idx, axis=0) // groupsize\n scale = tl.load(scales_ptr + group_id_k * stride_scale_g + scales_offs_m * stride_scale_n)\n zero = tl.load(zeros_ptr + group_id_k * stride_zero_g + zeros_offs_m * stride_zero_n)\n zero = (zero >> ((offs_n % 8) * 4)) & 0xF\n zero = zero * scale\n\n b = (b_int4.to(tl.float32) - zero[None, :]) * scale[None, :]\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n\n if SPLIT_K > 1:\n tl.atomic_add(c_ptrs, accumulator, mask=mask_c)\n else:\n tl.store(c_ptrs, accumulator.to(c_ptr.dtype.element_ty), mask=mask_c)\n\n\ndef matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor,\n zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor:\n assert x.dtype in (torch.float16, torch.float32)\n assert qweight.dtype == torch.int32\n assert scales.dtype in (torch.float16, torch.float32)\n assert zeros.dtype in (torch.float16, torch.float32)\n\n M, K = x.shape\n K_packed, N = qweight.shape\n assert K_packed * 8 == K\n\n output = torch.empty((M, N), dtype=x.dtype, device=x.device)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n\n matmul_kernel[grid](\n x, qweight, output,\n scales, zeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0) if scales.dim() > 1 else 0,\n scales.stride(1) if scales.dim() > 1 else 1,\n zeros.stride(0) if zeros.dim() > 1 else 0,\n zeros.stride(1) if zeros.dim() > 1 else 1,\n group_size,\n )\n return output\n\n\ndef quantize_int4(w: torch.Tensor, group_size: int = 128):\n assert w.dtype in [torch.float16, torch.float32]\n K, N = w.shape\n assert K % group_size == 0, \"K must be divisible by groupsize\"\n\n w_reshaped = w.view(-1, group_size, N)\n wmin = w_reshaped.min(dim=1, keepdim=True)[0]\n wmax = w_reshaped.max(dim=1, keepdim=True)[0]\n scale = (wmax - wmin) / 15.0\n zero = torch.round(-wmin / scale).clamp(0, 15)\n q = torch.round(w_reshaped / scale + zero).clamp(0, 15).to(torch.uint8)\n q = q.view(K, N)\n\n packed = torch.zeros(K // 8, N, dtype=torch.int32, device=w.device)\n for k in range(8):\n packed |= (q[k::8, :].to(torch.int32) << (k * 4))\n scale = scale.squeeze(1).to(torch.float16)\n zero = zero.squeeze(1).to(torch.int32)\n\n return packed, scale, zero\n\n\ndef unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128):\n assert packed.dtype == torch.int32\n K_packed, N = packed.shape\n K = K_packed * 8\n\n unpacked = torch.zeros(K, N, dtype=torch.float32, device=packed.device)\n for k in range(8):\n unpacked[k::8, :] = ((packed >> (k * 4)) & 0xF).float()\n\n unpacked = unpacked.view(-1, group_size, N)\n scale_exp = scale.view(-1, 1, N)\n zero_exp = zero.view(-1, 1, N)\n unpacked = (unpacked - zero_exp.float()) * scale_exp\n return unpacked.view(K, N).to(torch.float16)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", "label": null, "file": "flash_decode2_phi.py", "test_code": "\n\n\nimport torch\n\n# Define the test function\ndef test_flash_decode_stage2():\n # Define the parameters for different test cases\n batch_size = 2\n head_num = 4\n seq_block_num = 3\n head_dim = 64\n block_seq = 16\n\n test_cases = {\n \"test_case_1\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq\n },\n \"test_case_2\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq + 1 # Different block size\n },\n \"test_case_3\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq // 2 # Different block size\n },\n \"test_case_4\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq * 2 # Different block size\n }\n }\n\n # Execute the function for all test cases\n results = {}\n for key, test_case in test_cases.items():\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n results[key] = test_case[\"Out\"]\n\n return results\n\n# Run the test\nresult_gold = test_flash_decode_stage2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n Out,\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_lseb,\n stride_mid_lseh,\n stride_mid_lses,\n stride_ob,\n stride_oh,\n stride_od,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_seq_len = tl.load(B_Seqlen + cur_batch)\n block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = float(\"-inf\")\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for block_i in range(0, block_n_size):\n tv = tl.load(\n Mid_O\n + cur_batch * stride_mid_ob\n + cur_head * stride_mid_oh\n + block_i * stride_mid_os\n + offs_d * stride_mid_od\n ).to(tl.float32)\n\n tlogic = tl.load(\n Mid_O_LogExpSum\n + cur_batch * stride_mid_lseb\n + cur_head * stride_mid_lseh\n + block_i * stride_mid_lses\n ).to(tl.float32)\n\n old_max = max_logic\n new_max = tl.maximum(max_logic, tlogic)\n max_logic = new_max\n\n sum_exp = sum_exp * tl.exp(old_max - new_max) + tl.exp(tlogic - new_max)\n acc = acc * tl.exp(old_max - new_max) + tv * tl.exp(tlogic - new_max)\n\n acc = acc / sum_exp\n\n offs_out = (\n cur_batch * stride_ob\n + cur_head * stride_oh\n + offs_d * stride_od\n )\n tl.store(Out + offs_out, acc.to(Out.dtype.element_ty))\n\n\n@torch.no_grad()\ndef flash_decode_stage2(\n Mid_O: torch.Tensor,\n Mid_O_LogExpSum: torch.Tensor,\n B_Seqlen: torch.Tensor,\n Out: torch.Tensor,\n block_seq: int,\n):\n batch, head_num = Mid_O.shape[0], Mid_O.shape[1]\n dim = Mid_O.shape[-1]\n\n assert Mid_O_LogExpSum.shape[0] == batch and Mid_O_LogExpSum.shape[1] == head_num\n assert Out.shape[0] == batch and Out.shape[1] == head_num and Out.shape[2] == dim\n\n BLOCK_SEQ = block_seq\n BLOCK_DMODEL = dim\n\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n Out,\n Mid_O.stride(0),\n Mid_O.stride(1),\n Mid_O.stride(2),\n Mid_O.stride(3),\n Mid_O_LogExpSum.stride(0),\n Mid_O_LogExpSum.stride(1),\n Mid_O_LogExpSum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=1,\n )\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", "label": null, "file": "matrix_transpose.py", "test_code": "\n\n\nimport torch\n\ndef test_triton_vs_torch():\n results = {}\n\n # \u6d4b\u8bd5\u7528\u4f8b 1: \u57fa\u672c\u77e9\u9635\u8f6c\u7f6e (\u5c0f\u77e9\u9635)\n size_m, d_head = 16, 16\n out = wrapper(size_m, d_head)\n results[\"test_case_1\"] = out.clone()\n\n # \u6d4b\u8bd5\u7528\u4f8b 2: \u975e\u65b9\u5f62\u77e9\u9635\n size_m, d_head = 32, 64\n out = wrapper(size_m, d_head)\n results[\"test_case_2\"] = out.clone()\n\n return results\n\n\n# \u8fd0\u884c\u6d4b\u8bd5\nresult_gold = test_triton_vs_torch()\n# print(result_gold)", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M,\n D_HEAD,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n\n offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n mask_read = (offs_m[:, None] < SIZE_M) & (offs_n[None, :] < D_HEAD)\n m_ptrs = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey\n x = tl.load(m_ptrs, mask=mask_read, other=0.0)\n\n mask_write = (offs_n[:, None] < D_HEAD) & (offs_m[None, :] < SIZE_M)\n o_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey\n tl.store(o_ptrs, x.trans(), mask=mask_write)\n\ndef wrapper(\n SIZE_M: int,\n D_HEAD: int,\n) -> torch.Tensor:\n matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda')\n out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda')\n\n BLOCK_M = 32\n BLOCK_N = 32\n grid = ((SIZE_M + BLOCK_M - 1) // BLOCK_M, (D_HEAD + BLOCK_N - 1) // BLOCK_N)\n\n kernel[grid](\n matrix, out,\n matrix.stride(0),\n matrix.stride(1),\n out.stride(0),\n out.stride(1),\n SIZE_M,\n D_HEAD,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n )\n\n return out\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n ", "label": null, "file": "rotary_transform.py", "test_code": "\n\n\nimport torch\n\ndef test_apply_rotary():\n results = {}\n \n # Test case 1: Basic test with fixed sequence length and no interleaving\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin)\n results['test_case_1'] = output.shape\n\n # Test case 2: Variable length sequences with interleaving\n total_seqlen, nheads, headdim = 256, 4, 64\n batch = 3\n cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda')\n max_seqlen = 128\n rotary_dim = 32\n x = torch.randn(total_seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True)\n results['test_case_2'] = output.shape\n\n # Test case 3: Conjugate flag enabled\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, conjugate=True)\n results['test_case_3'] = output.shape\n\n # Test case 4: Inplace operation\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, inplace=True)\n results['test_case_4'] = output.shape\n\n return results\n\nresult_gold = test_apply_rotary()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Union, Optional\n\n\n@triton.jit\ndef rotary_kernel(\n OUT,\n X,\n COS,\n SIN,\n CU_SEQLENS,\n SEQLEN_OFFSETS,\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n CACHE_KEY_SEQLEN,\n stride_out_batch,\n stride_out_seqlen,\n stride_out_nheads,\n stride_out_headdim,\n stride_x_batch,\n stride_x_seqlen,\n stride_x_nheads,\n stride_x_headdim,\n BLOCK_K: tl.constexpr,\n IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr,\n INTERLEAVED: tl.constexpr,\n CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_batch = tl.program_id(1)\n pid_head = tl.program_id(2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n else_seqlen = seqlen if not IS_VARLEN else cur_seqlen\n\n if pid_m * BLOCK_M >= else_seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rm_cs = tl.where(rm < else_seqlen, rm_cs, 0)\n\n if not INTERLEAVED:\n rk_half = tl.arange(0, BLOCK_K // 2)\n mask_half = (rm[:, None] < else_seqlen) & (rk_half[None, :] < rotary_dim_half)\n cos_mask = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half)\n\n cos = tl.load(\n COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :],\n mask=cos_mask,\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :],\n mask=cos_mask,\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(\n X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim,\n mask=mask_half,\n other=0.0,\n ).to(tl.float32)\n x1 = tl.load(\n X + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim,\n mask=mask_half,\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n tl.store(\n OUT + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim,\n o0,\n mask=mask_half,\n )\n tl.store(\n OUT + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim,\n o1,\n mask=mask_half,\n )\n else:\n rk = tl.arange(0, BLOCK_K)\n mask = (rm[:, None] < else_seqlen) & (rk[None, :] < rotary_dim)\n rk_half = rk // 2\n cos_mask = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half)\n\n cos = tl.load(\n COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :],\n mask=cos_mask,\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :],\n mask=cos_mask,\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(\n X + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim,\n mask=mask,\n other=0.0,\n ).to(tl.float32)\n x1 = tl.load(\n X + rm[:, None] * stride_x_seqlen + (rk[None, :] ^ 1) * stride_x_headdim,\n mask=mask,\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(\n rk[None, :] % 2 == 0,\n x0_cos - x1_sin,\n x0_cos + x1_sin,\n )\n tl.store(\n OUT + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim,\n out,\n mask=mask,\n )\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved: bool = False,\n inplace: bool = False,\n conjugate: bool = False,\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n if x.ndim == 3:\n x = x.unsqueeze(1)\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n if x.ndim == 2:\n x = x.unsqueeze(1)\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim_half = cos.shape\n rotary_dim = rotary_dim_half * 2\n assert sin.shape == cos.shape\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert cos.dtype == sin.dtype, f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert x.dtype == cos.dtype, f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.numel() == batch\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert isinstance(seqlen_offsets, int)\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n grid = lambda META: (\n triton.cdiv(seqlen, META[\"BLOCK_M\"]),\n batch,\n nheads,\n )\n\n rotary_kernel[grid](\n output,\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128,\n output.stride(0) if not is_varlen else 1,\n output.stride(-2) if is_varlen else output.stride(-3),\n output.stride(-2),\n output.stride(-1),\n x.stride(0) if not is_varlen else 1,\n x.stride(-2) if is_varlen else x.stride(-3),\n x.stride(-2),\n x.stride(-1),\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", "label": null, "file": "sin_kernel.py", "test_code": "\n\n\nimport torch\n\n# Function to test the Triton kernel\ndef test_call_kernel():\n results = {}\n \n # Test case 1: Small input tensor\n x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda()\n output1 = call_kernel(x1)\n results['test_case_1'] = output1\n \n # Test case 2: Larger input tensor\n x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda()\n output2 = call_kernel(x2)\n results['test_case_2'] = output2\n\n # Test case 3: Edge case with zero elements\n x3 = torch.tensor([], dtype=torch.float32).cuda()\n output3 = call_kernel(x3)\n results['test_case_3'] = output3\n\n # Test case 4: Input tensor with negative values\n x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda()\n output4 = call_kernel(x4)\n results['test_case_4'] = output4\n \n return results\n\n# Run the test function\nresult_gold = test_call_kernel()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef kernel_function(\n x_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.math.sin(x.to(tl.float32))\n tl.store(output_ptr + offsets, y, mask=mask)\n\n\ndef call_kernel(x: torch.Tensor, output: torch.Tensor = None, BLOCK_SIZE: int = 128):\n assert x.is_cuda, \"Input tensor must reside on CUDA (ROCm) device\"\n assert x.is_contiguous(), \"Input tensor must be contiguous\"\n assert (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0, \"BLOCK_SIZE must be a power of two\"\n n_elements = x.numel()\n if n_elements == 0:\n if output is None:\n output = torch.empty_like(x)\n return output\n\n x = x.view(-1)\n if output is None:\n output = torch.empty_like(x)\n else:\n output = output.reshape(-1)\n assert output.numel() >= n_elements\n output = output.to(x.device)\n\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n kernel_function[grid](\n x_ptr=x,\n output_ptr=output,\n n_elements=n_elements,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", "label": null, "file": "l2_norm_bwd.py", "test_code": "\n\n\nimport torch\n\n# Test the backward L2 normalization\ndef test_l2_norm_bwd():\n results = {}\n \n # Test case 1: Default case\n x = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_1'] = dx\n\n # Test case 2: Different shape\n x = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dy = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_2'] = dx\n\n # Test case 3: Larger tensor\n x = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_3'] = dx\n\n # Test case 4: Edge case with small tensor\n x = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_4'] = dx\n\n return results\n\n# Run the tests\nresult_gold = test_l2_norm_bwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input tensor\n DY, # pointer to the output gradient tensor\n DX, # pointer to the input gradient tensor\n stride_x_row, # stride between rows in X\n N, # number of columns (feature dim)\n eps, # small constant for numerical stability\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0) # row index\n cols = tl.arange(0, BLOCK_N) # column indices\n mask = cols < N # valid mask for this row\n\n # Offset to row start\n X_row = X + row * stride_x_row\n DY_row = DY + row * stride_x_row\n DX_row = DX + row * stride_x_row\n\n # Load data\n x = tl.load(X_row + cols, mask=mask, other=0.0).to(tl.float32)\n dy = tl.load(DY_row + cols, mask=mask, other=0.0).to(tl.float32)\n\n var = tl.sum(x * x, axis=0)\n rstd = tl.math.rsqrt(var + eps)\n\n # Compute gradient\n dx = dy * rstd - tl.sum(dy * x, axis=0) * rstd / (var + eps) * x\n tl.store(DX_row + cols, dx.to(DX.dtype.element_ty), mask=mask)\n\n\ndef _l2_norm_bwd(\n x: torch.Tensor,\n dy: torch.Tensor,\n eps: float = 1e-5\n) -> torch.Tensor:\n # Flatten and ensure contiguous\n shape_og = x.shape\n x = x.reshape(-1, x.size(-1)).contiguous()\n dy = dy.reshape(-1, dy.size(-1)).contiguous()\n\n M, N = x.shape\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"Feature dimension too large for the fused kernel.\")\n\n dx = torch.empty_like(x)\n\n grid = (M,)\n _l2_norm_bwd_kernel[grid](\n x, dy, dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n\n return dx.view(shape_og)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", "label": null, "file": "l2_norm_triton1.py", "test_code": "\n\n\nimport torch\n\n# Test the forward L2 normalization\ndef test_l2_norm_fwd():\n results = {}\n \n # Test case 1\n x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n y1 = _l2_norm_fwd(x1)\n results['test_case_1'] = y1\n\n # Test case 2: Different batch size\n x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32)\n y2 = _l2_norm_fwd(x2)\n results['test_case_2'] = y2\n\n # Test case 3: Different feature size\n x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32)\n y3 = _l2_norm_fwd(x3)\n results['test_case_3'] = y3\n\n # Test case 4: Larger tensor\n x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n y4 = _l2_norm_fwd(x4)\n results['test_case_4'] = y4\n\n return results\n\nresult_gold = test_l2_norm_fwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X,\n Y,\n stride_x_row,\n stride_y_row,\n N,\n eps,\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n x_ptrs = X + row * stride_x_row\n y_ptrs = Y + row * stride_y_row\n\n acc = tl.zeros([], dtype=tl.float32)\n\n for col_offset in tl.range(0, N, BLOCK_N):\n cols = col_offset + tl.arange(0, BLOCK_N)\n mask = cols < N\n x = tl.load(x_ptrs + cols, mask=mask, other=0.0).to(tl.float32)\n acc += tl.sum(x * x)\n\n rstd = tl.math.rsqrt(acc + eps)\n\n for col_offset in tl.range(0, N, BLOCK_N):\n cols = col_offset + tl.arange(0, BLOCK_N)\n mask = cols < N\n x = tl.load(x_ptrs + cols, mask=mask, other=0.0).to(tl.float32)\n x_norm = x * rstd\n tl.store(y_ptrs + cols, x_norm, mask=mask)\n\n\ndef _l2_norm_fwd(\n x: torch.Tensor,\n eps: float = 1e-6,\n) -> torch.Tensor:\n x_orig_shape = x.shape\n x = x.view(-1, x.shape[-1]).contiguous()\n M, N = x.shape\n y = torch.empty_like(x)\n\n element_size = x.element_size()\n max_block_bytes = 65536\n BLOCK_N = min(triton.next_power_of_2(N), max_block_bytes // element_size)\n BLOCK_N = max(BLOCK_N, 16)\n\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n y.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n\n return y.view(x_orig_shape)\n"}] \ No newline at end of file From 05c7501cd5f557edb4b06983cb25328e0d4ee078 Mon Sep 17 00:00:00 2001 From: bingoo <1575938147@qq.com> Date: Sun, 24 Aug 2025 08:56:52 +0000 Subject: [PATCH 3/3] update prompts --- src/configs/tritonbench_oneshot_config.yaml | 2 +- src/prompts/prompt_for_generation.py | 2 +- src/prompts/prompt_for_reflection.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/configs/tritonbench_oneshot_config.yaml b/src/configs/tritonbench_oneshot_config.yaml index b54bc4c..336e279 100644 --- a/src/configs/tritonbench_oneshot_config.yaml +++ b/src/configs/tritonbench_oneshot_config.yaml @@ -1,7 +1,7 @@ # LLM model api_key: "" model_id: "Kimi-K2-Instruct" -temperature: 1.0 +temperature: 0.1 # TritonBench statis_path: "/hackathon-agent/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json" diff --git a/src/prompts/prompt_for_generation.py b/src/prompts/prompt_for_generation.py index 6d07461..2b96f60 100644 --- a/src/prompts/prompt_for_generation.py +++ b/src/prompts/prompt_for_generation.py @@ -38,7 +38,7 @@ * **`tl.arange`:** Arguments `start` and `end` **must be `tl.constexpr`**. * **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`. 8. **Triton Version:** Assume Triton version 3.1.0 or later. - +9. If the input is float and double, convert to bf16 precision for calculation; if the input is int, use int8 precision for calculation **FINAL VERIFICATION:** Before completing, verify: 1. ALL functions defined in the code have EXACT signatures matching the required function signatures above. diff --git a/src/prompts/prompt_for_reflection.py b/src/prompts/prompt_for_reflection.py index fe3f936..a8f341d 100644 --- a/src/prompts/prompt_for_reflection.py +++ b/src/prompts/prompt_for_reflection.py @@ -17,6 +17,7 @@ **Important Instructions:** - Think before writing the reflection and no more explanation is required after the reflection. - You should not suggest changes to the name of the function. +- Please check all variable names. Pay special attention not to write Mid_O when using the MID_O variable - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" @@ -47,6 +48,7 @@ **Important Instructions:** - Think before writing the reflection and no more explanation is required after the reflection. - You should not suggest changes to the name of the function. +- Please check all variable names. Pay special attention not to write Mid_O when using the MID_O variable - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" @@ -100,6 +102,7 @@ **Important Instructions:** - Think before writing the reflection and no more explanation is required after the reflection. - You should not suggest changes to the name of the function. +- Please check all variable names. Pay special attention not to write Mid_O when using the MID_O variable - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" @@ -239,6 +242,7 @@ def grid(args: dict[str, Any]) -> tuple[int]: **Important Instructions:** - Think before writing the optimization and no more explanation is required after the reflection. - You should not suggest changes to the name of the function and parameter names, counts, or order. +- Please check all variable names. Pay special attention not to write Mid_O when using the MID_O variable - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```"