diff --git a/out.json_perf_0.json b/out.json_perf_0.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/out.json_perf_0.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4.json b/reflexion_oneshot_tritonbench_4.json new file mode 100644 index 0000000..6ae2c2f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4.json @@ -0,0 +1 @@ +[{"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", "label": null, "file": "matrix_vector_multip.py", "test_code": "\n\n\ndef test_mv():\n # \u6d4b\u8bd5\u7528\u4f8b 2: 4x3 \u77e9\u9635\u4e0e 3x1 \u5411\u91cf\u76f8\u4e58\n A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda')\n B = torch.tensor([1.0, 2.0, 3.0], device='cuda')\n triton_result_2 = mv(A, B)\n\n # \u6d4b\u8bd5\u7528\u4f8b 3: 32x16 \u77e9\u9635\u4e0e 16x1 \u5411\u91cf\u76f8\u4e58\n A = torch.randn(32, 16, device='cuda')\n B = torch.randn(16, device='cuda')\n triton_result_3 = mv(A, B)\n\n return {\n \"test_case_2\": triton_result_2,\n \"test_case_3\": triton_result_3,\n }\n\nresult_gold = test_mv()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr,\n stride_bk: tl.constexpr, stride_cm: tl.constexpr,\n N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):\n pid_n = tl.program_id(0) # one program per N-tile\n start_n = pid_n * BLOCK_N\n offsets_n = start_n + tl.arange(0, BLOCK_N)\n mask_n = offsets_n < N\n\n # allocate register tile\n accum = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n for start_m in range(0, M, BLOCK_M):\n offsets_m = start_m + tl.arange(0, BLOCK_M)\n mask_m = offsets_m < M\n # A tile: [BLOCK_N, BLOCK_M]\n offs_a = A + (offsets_n[:, None] * stride_am +\n offsets_m[None, :] * stride_ak)\n tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0)\n # B segment: [BLOCK_M]\n offs_b = B + offsets_m * stride_bk\n seg_b = tl.load(offs_b, mask=mask_m, other=0.0)\n\n accum += tl.sum(tile_a * seg_b[None, :], axis=1)\n\n # write back\n offs_c = C + offsets_n * stride_cm\n tl.store(offs_c, accum, mask=mask_n)\n\ndef mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n assert A.dtype == B.dtype\n assert A.size(1) == B.size(0)\n N, M = A.shape\n C = torch.empty((N,), dtype=A.dtype, device=A.device)\n BLOCK_N = 128\n BLOCK_M = 64\n grid = (triton.cdiv(N, BLOCK_N),)\n mv_kernel[grid](\n A, B, C,\n stride_am=A.stride(0), stride_ak=A.stride(1),\n stride_bk=B.stride(0), stride_cm=C.stride(0),\n N=N, M=M,\n BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M\n )\n return C\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", "label": null, "file": "triton_matmul.py", "test_code": "\n\n\nimport torch\n\n# Test for matmul\ndef test_matmul():\n results = {}\n M, K, N = 256, 128, 256\n\n # Test case 1: torch.float16\n a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n b = torch.randn((K, N), dtype=torch.float16, device='cuda')\n c = matmul(a, b)\n results['test_case_1'] = c\n\n return results\n\n# Run all tests\nresult_gold = test_matmul()", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K)\n b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N)\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b = tl.load(b_ptrs, mask=b_mask, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c = accumulator.to(c_ptr.dtype.element_ty)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n assert a.dim() == 2, \"Input a must be 2-D\"\n assert b.dim() == 2, \"Input b must be 2-D\"\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K2, N = b.shape\n assert K == K2, \"Inner dimensions must match\"\n\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n\n matmul_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n )\n return c\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", "label": null, "file": "embedding_triton_kernel.py", "test_code": "\n\n\nimport torch\n\ndef test_embedding():\n # \u53c2\u6570\u5b9a\u4e49\n vocab_size = 1000 # \u8bcd\u6c47\u8868\u5927\u5c0f\n embedding_dim = 512 # \u5d4c\u5165\u7ef4\u5ea6\n sequence_length = 128 # \u8f93\u5165\u5e8f\u5217\u957f\u5ea6\n vob_start_id = 10 # \u8bcd\u6c47\u8868\u8d77\u59cb ID\n vob_end_id = 1000 # \u8bcd\u6c47\u8868\u7ed3\u675f ID\n\n # \u521b\u5efa\u6d4b\u8bd5\u8f93\u5165\u5f20\u91cf\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n\n # \u8c03\u7528\u5d4c\u5165\u51fd\u6570\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n\n # \u4fdd\u5b58\u7ed3\u679c\n results = {}\n results['test_case_1'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8f93\u5165\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_2'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8bcd\u6c47\u8868\u8303\u56f4\n vob_start_id = 0\n vob_end_id = 500\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_3'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u5d4c\u5165\u7ef4\u5ea6\n embedding_dim = 256\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_4'] = out.clone()\n\n return results\n\nresult_gold = test_embedding()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef embedding_kernel(\n weight,\n out,\n indices,\n vocab_size,\n d_model,\n vob_start_id,\n stride_out_0,\n stride_weight_0,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n for k in range(0, BLOCK_NN, BLOCK_N):\n seq_off = pid * BLOCK_NN + k + tl.arange(0, BLOCK_N)\n mask_seq = seq_off < out.shape[0]\n\n token_idx_raw = tl.load(indices + seq_off, mask=mask_seq, other=0).to(tl.int32)\n token_idx = token_idx_raw - vob_start_id\n token_idx = tl.where(token_idx >= 0, token_idx, 0)\n token_idx = tl.where(token_idx < vocab_size, token_idx, vocab_size - 1)\n\n w_offs = (\n token_idx[:, None] * d_model\n + offs_d[None, :]\n )\n vec = tl.load(weight + w_offs)\n\n o_offs = (\n seq_off[:, None] * stride_out_0\n + offs_d[None, :]\n )\n tl.store(out + o_offs, vec, mask=mask_seq[:, None])\n\n\ndef embedding(\n weight: torch.Tensor,\n indices: torch.Tensor,\n vob_start_id: int,\n vob_end_id: int,\n out: torch.Tensor = None\n) -> torch.Tensor:\n shape_2d = indices.shape\n B, SEQ_LEN = shape_2d\n d_model = weight.size(-1)\n weight = weight.view(-1, d_model) # Ensure [vocab_size, d_model]\n\n if out is None:\n out = torch.empty(\n (B * SEQ_LEN, d_model),\n dtype=weight.dtype,\n device=weight.device\n )\n else:\n out = out.view(B * SEQ_LEN, d_model)\n\n vocab_sz = vob_end_id - vob_start_id if vob_end_id > vob_start_id else weight.size(0)\n\n BLOCK_N = 16\n BLOCK_NN = 64\n BLOCK_DMODEL = triton.next_power_of_2(d_model)\n\n grid = lambda META: (triton.cdiv(out.shape[0], META[\"BLOCK_NN\"]), )\n\n embedding_kernel[grid](\n weight,\n out,\n indices.view(-1),\n vocab_size=vocab_sz,\n d_model=d_model,\n vob_start_id=vob_start_id,\n stride_out_0=out.stride(0),\n stride_weight_0=weight.stride(0),\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n BLOCK_DMODEL=BLOCK_DMODEL,\n )\n\n return out.view(B, SEQ_LEN, d_model)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", "label": null, "file": "int4_matmul.py", "test_code": "\n\n\ndef test_correct_int4_s2(M=32, K=4096, N=4096):\n group_size = 128\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size)\n \n # Test case\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n \n results = {\n \"test_case_1\": triton_output\n }\n \n return results\n\nresult_gold = test_correct_int4_s2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n# ===========================================================\n# Triton Kernel (Direct INT4 matmul)\n# -----------------------------------------------------------\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128,'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn, # stride_bk in ints (K//8, N)\n stride_cm, stride_cn,\n stride_scale_g, stride_scale_n,\n stride_zp_g, stride_zp_n,\n group_size,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_k = tl.program_id(1)\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)\n\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n start_k = pid_k * BLOCK_SIZE_K\n offs_k = start_k + tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n k_curr = k_step * BLOCK_SIZE_K * SPLIT_K + pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n mask_k = k_curr[None, :] < K\n\n a = tl.load(a_ptrs, mask=mask_k, other=0.0)\n\n b_raw = tl.load(b_ptrs, mask=mask_k, other=0)\n group_idx = k_curr[:, None] // group_size\n shift = (k_curr[:, None] % 8) * 4\n bits = (b_raw >> shift) & 0xF\n\n scales = tl.load(\n scales_ptr + group_idx * stride_scale_g + offs_n[None, :] * stride_scale_n,\n mask=mask_k, other=0.0\n )\n zeros = tl.load(\n zeros_ptr + group_idx * stride_zp_g + (offs_n[None, :] // 8) * stride_zp_n,\n mask=mask_k, other=0.0\n )\n zeros_bits = (zeros >> ((offs_n[None, :] % 8) * 4)) & 0xF\n\n b_deq = (bits - zeros_bits) * scales\n accumulator += tl.dot(a, b_deq)\n\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn\n\n out_vals = accumulator.to(c_ptr.dtype.element_ty)\n if SPLIT_K > 1:\n tl.atomic_add(c_ptrs, out_vals, mask=mask_c)\n else:\n tl.store(c_ptrs, out_vals, mask=mask_c)\n\n# ===========================================================\n# Python wrapper\n# -----------------------------------------------------------\ndef matmul_dequantize_int4_s2(\n x: torch.FloatTensor,\n qweight: torch.FloatTensor,\n scales: torch.FloatTensor,\n zeros: torch.FloatTensor,\n split_k: int = 1,\n) -> torch.FloatTensor:\n assert x.dim() == 2 and qweight.dim() == 2 and scales.dim() == 2 and zeros.dim() == 2\n M, K = x.shape\n assert K == qweight.shape[0] * 8\n N = qweight.shape[1]\n assert scales.shape == (K // scales.shape[0], N)\n assert zeros.shape == (K // zeros.shape[0], N)\n\n output = torch.empty((M, N), dtype=x.dtype, device=x.device)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n split_k,\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, zeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n zeros.stride(0), zeros.stride(1),\n scales.shape[0],\n )\n return output\n\n# ===========================================================\n# Quantise / Unpack helpers (correct 8\u00d7 packing)\n# -----------------------------------------------------------\ndef quantize_int4(w: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:\n \"\"\"\n w: [OC, K] (fp16/fp32 weights)\n returns: (qpacked, scales, zerospacked, group_size)\n qpacked : int32, shape [OC, K//8]\n scales : fp16/fp32 [OC, K//group_size]\n zerospacked : int32 [K//group_size, OC//8]\n \"\"\"\n assert w.dim() == 2\n OC, K = w.shape\n assert K % group_size == 0\n\n w = w.view(-1, K)\n OC_total, K_ = w.shape\n groups_per_row = K_ // group_size\n w = w.view(OC_total, groups_per_row, group_size)\n\n w_min = w.min(dim=2, keepdim=True).values\n w_max = w.max(dim=2, keepdim=True).values\n scales = (w_max - w_min) / 15.0\n zeros = torch.round(-w_min / scales).clamp(0, 15).to(torch.int32)\n\n quant = torch.clamp(torch.round((w / scales) + zeros), 0, 15).to(torch.int32)\n\n # === pack 8\u00d7OC along the OC axis (group row dimension of zero) ===\n OC_pack = OC_total // 8\n zeros_pack = zeros.permute(1, 0, 2).contiguous().view(groups_per_row, OC_pack, 8)\n zeros_packed = zeros_pack[..., 0]\n for p in range(1, 8):\n zeros_packed |= zeros_pack[..., p] << (p * 4)\n zeros_packed = zeros_packed.view(groups_per_row, OC_pack).contiguous()\n\n # === pack 8\u00d7K along the K axis (group column dimension of weight) ===\n K_pack = K_ // 8\n quant = quant.view(OC_total, K_)\n qpacked = torch.empty((OC_total, K_pack), dtype=torch.int32, device=w.device)\n\n for p in range(8):\n qpacked |= quant[:, p::8] << (p * 4)\n\n qpacked = qpacked.contiguous()\n scales = scales.view(OC_total, groups_per_row).contiguous()\n return qpacked, scales, zeros_packed, group_size\n\n\n# -----------------------------------------------------------\ndef unpack_int4(\n qweight: torch.Tensor,\n scales: torch.Tensor,\n zeros: torch.Tensor,\n group_size: int = 128\n) -> torch.Tensor:\n \"\"\"\n qweight : int32 [OC, K//8]\n scales : [OC, K//group_size]\n zeros : int32 [K//group_size, OC//8]\n returns : fp16/fp32 tensor [OC, K]\n \"\"\"\n OC, K8 = qweight.shape\n K = K8 * 8\n group_dim = K // group_size\n assert scales.shape == (OC, group_dim)\n assert zeros.shape == (group_dim, OC//8)\n\n # --- weight ---\n quant = torch.empty((OC, K), dtype=torch.int32, device=qweight.device)\n for p in range(8):\n mask = 0xF << (p * 4)\n quant[:, p::8] = (qweight & mask) >> (p * 4)\n\n # --- zero (grouped) ---\n OC8 = OC // 8\n zeros_ext = torch.empty((group_dim, OC), dtype=torch.int32, device=zeros.device)\n for p in range(8):\n mask = 0xF << (p * 4)\n zeros_ext[:, p::8] = (zeros & mask) >> (p * 4)\n zeros_ext = zeros_ext.permute(1, 0).contiguous()\n\n # --- dequant ---\n scales_mat = scales.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K)\n zeros_mat = zeros_ext.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K)\n\n return (quant.float() - zeros_mat) * scales_mat\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", "label": null, "file": "flash_decode2_phi.py", "test_code": "\n\n\nimport torch\n\n# Define the test function\ndef test_flash_decode_stage2():\n # Define the parameters for different test cases\n batch_size = 2\n head_num = 4\n seq_block_num = 3\n head_dim = 64\n block_seq = 16\n\n test_cases = {\n \"test_case_1\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq\n },\n \"test_case_2\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq + 1 # Different block size\n },\n \"test_case_3\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq // 2 # Different block size\n },\n \"test_case_4\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq * 2 # Different block size\n }\n }\n\n # Execute the function for all test cases\n results = {}\n for key, test_case in test_cases.items():\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n results[key] = test_case[\"Out\"]\n\n return results\n\n# Run the test\nresult_gold = test_flash_decode_stage2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum,# [batch, head, seq_block_num]\n Out, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_out_b,\n stride_out_h,\n stride_out_d,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_seq_len = tl.load(B_Seqlen + cur_batch)\n block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for block_seq_id in range(0, block_n_size):\n ptr_v = (Mid_O\n + cur_batch * stride_mid_ob\n + cur_head * stride_mid_oh\n + block_seq_id * stride_mid_os\n + offs_d * stride_mid_od)\n tv = tl.load(ptr_v).to(tl.float32)\n\n ptr_logic = (Mid_O_LogExpSum\n + cur_batch * stride_mid_o_eb\n + cur_head * stride_mid_o_eh\n + block_seq_id * stride_mid_o_es)\n tlogic = tl.load(ptr_logic).to(tl.float32)\n\n new_max = tl.maximum(max_logic, tlogic)\n scale = tl.exp(max_logic - new_max)\n acc = acc * scale\n sum_exp = sum_exp * scale\n\n delta = tl.exp(tlogic - new_max)\n acc += delta * tv\n sum_exp += delta\n max_logic = new_max\n\n acc = acc / sum_exp\n\n ptr_out = (Out\n + cur_batch * stride_out_b\n + cur_head * stride_out_h\n + offs_d * stride_out_d)\n tl.store(ptr_out, acc.to(Out.dtype.element_ty))\n\n\n@torch.no_grad()\ndef flash_decode_stage2(\n B_Seqlen: torch.Tensor,\n Mid_O: torch.Tensor,\n Mid_O_LogExpSum: torch.Tensor,\n Out: torch.Tensor,\n BLOCK_SEQ: int = 64,\n):\n assert Mid_O.dim() == 4\n assert Mid_O_LogExpSum.dim() == 3 # [batch, head, seq_block_num]\n assert B_Seqlen.dim() == 1\n assert Out.dim() == 3 # [batch, head, head_dim]\n assert Mid_O.shape[0] == Out.shape[0]\n assert Mid_O.shape[1] == Out.shape[1]\n assert Mid_O.shape[3] == Out.shape[2]\n assert Mid_O.shape[0] == Mid_O_LogExpSum.shape[0]\n assert Mid_O.shape[1] == Mid_O_LogExpSum.shape[1]\n assert Mid_O.shape[2] == Mid_O_LogExpSum.shape[2]\n\n BLOCK_DMODEL = Out.shape[-1]\n batch, head_num = Out.shape[0], Out.shape[1]\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n Out,\n Mid_O.stride(0),\n Mid_O.stride(1),\n Mid_O.stride(2),\n Mid_O.stride(3),\n Mid_O_LogExpSum.stride(0),\n Mid_O_LogExpSum.stride(1),\n Mid_O_LogExpSum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=2,\n )\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", "label": null, "file": "matrix_transpose.py", "test_code": "\n\n\nimport torch\n\ndef test_triton_vs_torch():\n results = {}\n\n # \u6d4b\u8bd5\u7528\u4f8b 1: \u57fa\u672c\u77e9\u9635\u8f6c\u7f6e (\u5c0f\u77e9\u9635)\n size_m, d_head = 16, 16\n out = wrapper(size_m, d_head)\n results[\"test_case_1\"] = out.clone()\n\n # \u6d4b\u8bd5\u7528\u4f8b 2: \u975e\u65b9\u5f62\u77e9\u9635\n size_m, d_head = 32, 64\n out = wrapper(size_m, d_head)\n results[\"test_case_2\"] = out.clone()\n\n return results\n\n\n# \u8fd0\u884c\u6d4b\u8bd5\nresult_gold = test_triton_vs_torch()\n# print(result_gold)", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n\n rows = size_m_arange[:, None]\n cols = d_head_arange[None, :]\n matrix_ptr = M + rows * matrix_stridex + cols * matrix_stridey\n matrix_value = tl.load(matrix_ptr)\n\n out_rows = d_head_arange[:, None]\n out_cols = size_m_arange[None, :]\n out_ptr = Out + out_rows * out_stridex + out_cols * out_stridey\n\n tl.store(out_ptr, tl.trans(matrix_value))\n\ndef wrapper(SIZE_M: int, D_HEAD: int):\n device = \"cuda\"\n matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device)\n out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device)\n grid = (1,)\n kernel[grid](\n matrix,\n out,\n matrix.stride(1),\n matrix.stride(0),\n out.stride(1),\n out.stride(0),\n SIZE_M=SIZE_M,\n D_HEAD=D_HEAD,\n )\n return out\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n ", "label": null, "file": "rotary_transform.py", "test_code": "\n\n\nimport torch\n\ndef test_apply_rotary():\n results = {}\n \n # Test case 1: Basic test with fixed sequence length and no interleaving\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin)\n results['test_case_1'] = output.shape\n\n # Test case 2: Variable length sequences with interleaving\n total_seqlen, nheads, headdim = 256, 4, 64\n batch = 3\n cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda')\n max_seqlen = 128\n rotary_dim = 32\n x = torch.randn(total_seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True)\n results['test_case_2'] = output.shape\n\n # Test case 3: Conjugate flag enabled\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, conjugate=True)\n results['test_case_3'] = output.shape\n\n # Test case 4: Inplace operation\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, inplace=True)\n results['test_case_4'] = output.shape\n\n return results\n\nresult_gold = test_apply_rotary()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\nimport logging\nfrom typing import Optional, Union\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS,\n seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN,\n stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim,\n stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim,\n BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_batch = tl.program_id(1)\n pid_head = tl.program_id(2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n cur_seqlen = seqlen\n x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n end_idx = tl.load(CU_SEQLENS + pid_batch + 1)\n cur_seqlen = end_idx - start_idx\n x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n mask_m = rm < cur_seqlen\n\n if IS_SEQLEN_OFFSETS_TENSOR:\n offset = tl.load(SEQLEN_OFFSETS + pid_batch)\n else:\n offset = SEQLEN_OFFSETS\n rm_cs = rm + offset\n mask_m_cs = rm_cs < seqlen_ro\n\n if not INTERLEAVED:\n rk_half = tl.arange(0, BLOCK_K)\n mask_k_half = rk_half < rotary_dim_half\n\n cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :]\n sin_offs = cos_offs\n cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32)\n sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n\n x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim\n x1_offs = x0_offs + rotary_dim_half * stride_x_headdim\n\n x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n\n if CONJUGATE:\n sin = -sin\n y0 = x0 * cos - x1 * sin\n y1 = x0 * sin + x1 * cos\n\n out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim\n out1_offs = out0_offs + rotary_dim_half * stride_out_headdim\n\n tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :])\n tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :])\n else:\n rk = tl.arange(0, BLOCK_K)\n mask_k = rk < rotary_dim\n\n rk_repeat = rk // 2\n cs_mask = rk_repeat[None, :] < rotary_dim_half\n\n cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]\n sin_offs = cos_offs\n cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32)\n sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32)\n\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim\n x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim\n\n x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32)\n x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32)\n\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n\n out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim\n tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :])\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved: bool = False,\n inplace: bool = False,\n conjugate: bool = False,\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch = cu_seqlens.shape[0] - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n assert rotary_dim <= headdim\n rotary_dim = rotary_dim * 2\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n assert cos.dtype == sin.dtype\n assert x.dtype == cos.dtype\n\n cos = cos.contiguous()\n sin = sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n if interleaved:\n BLOCK_M = 4\n else:\n BLOCK_M = 8 if rotary_dim <= 64 else 4\n BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256\n\n grid = lambda META: (\n triton.cdiv(seqlen, META[\"BLOCK_M\"]),\n batch,\n nheads,\n )\n\n rotary_kernel[grid](\n output,\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128,\n output.stride(0) if not is_varlen else 0,\n output.stride(-3),\n output.stride(-2),\n output.stride(-1),\n x.stride(0) if not is_varlen else 0,\n x.stride(-3),\n x.stride(-2),\n x.stride(-1),\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", "label": null, "file": "sin_kernel.py", "test_code": "\n\n\nimport torch\n\n# Function to test the Triton kernel\ndef test_call_kernel():\n results = {}\n \n # Test case 1: Small input tensor\n x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda()\n output1 = call_kernel(x1)\n results['test_case_1'] = output1\n \n # Test case 2: Larger input tensor\n x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda()\n output2 = call_kernel(x2)\n results['test_case_2'] = output2\n\n # Test case 3: Edge case with zero elements\n x3 = torch.tensor([], dtype=torch.float32).cuda()\n output3 = call_kernel(x3)\n results['test_case_3'] = output3\n\n # Test case 4: Input tensor with negative values\n x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda()\n output4 = call_kernel(x4)\n results['test_case_4'] = output4\n \n return results\n\n# Run the test function\nresult_gold = test_call_kernel()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.math.sin(x)\n tl.store(output_ptr + offsets, y, mask=mask)\n\ndef call_kernel(x: torch.Tensor) -> torch.Tensor:\n n_elements = x.numel()\n output = torch.empty_like(x)\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n kernel_function[grid](\n x,\n output,\n n_elements,\n BLOCK_SIZE=BLOCK_SIZE\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", "label": null, "file": "l2_norm_bwd.py", "test_code": "\n\n\nimport torch\n\n# Test the backward L2 normalization\ndef test_l2_norm_bwd():\n results = {}\n \n # Test case 1: Default case\n x = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_1'] = dx\n\n # Test case 2: Different shape\n x = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dy = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_2'] = dx\n\n # Test case 3: Larger tensor\n x = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_3'] = dx\n\n # Test case 4: Edge case with small tensor\n x = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_4'] = dx\n\n return results\n\n# Run the tests\nresult_gold = test_l2_norm_bwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, DY, DX,\n stride_x_row,\n N,\n eps,\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DY += row * stride_x_row\n DX += row * stride_x_row\n\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)\n\n var = tl.sum(x * x)\n rstd = 1 / tl.sqrt(var + eps)\n\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\n\ndef _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n\n if x.stride(-1) != 1:\n x = x.contiguous()\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n\n M, N = x.shape\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n dx = torch.empty_like(x)\n\n # ROCm-compatible grid launch\n grid = (M,)\n _l2_norm_bwd_kernel[grid](\n x, dy, dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", "label": null, "file": "l2_norm_triton1.py", "test_code": "\n\n\nimport torch\n\n# Test the forward L2 normalization\ndef test_l2_norm_fwd():\n results = {}\n \n # Test case 1\n x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n y1 = _l2_norm_fwd(x1)\n results['test_case_1'] = y1\n\n # Test case 2: Different batch size\n x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32)\n y2 = _l2_norm_fwd(x2)\n results['test_case_2'] = y2\n\n # Test case 3: Different feature size\n x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32)\n y3 = _l2_norm_fwd(x3)\n results['test_case_3'] = y3\n\n # Test case 4: Larger tensor\n x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n y4 = _l2_norm_fwd(x4)\n results['test_case_4'] = y4\n\n return results\n\nresult_gold = test_l2_norm_fwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n\n cols = tl.arange(0, BLOCK_N)\n\n var = tl.zeros([], dtype=tl.float32)\n for base in range(0, N, BLOCK_N):\n mask = cols < (N - base)\n data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32)\n var += tl.sum(data * data)\n\n rstd = tl.rsqrt(var + eps)\n\n for base in range(0, N, BLOCK_N):\n mask = cols < (N - base)\n data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32)\n y = data * rstd\n tl.store(Y + base + cols, y, mask=mask)\n\ndef _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5):\n x = x if x.is_contiguous() else x.contiguous()\n original_shape = x.shape\n if x.ndim > 2:\n x = x.view(-1, x.shape[-1])\n M, N = x.shape\n y = torch.empty_like(x)\n\n element_size = x.element_size()\n BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N))\n\n grid = (M,)\n _l2_norm_fwd_1pass_kernel[grid](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n return y.view(original_shape)\n"}] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp b/reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp new file mode 100644 index 0000000..43cb1a8 Binary files /dev/null and b/reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp differ diff --git a/reflexion_oneshot_tritonbench_4/exec/flash_decode2_phi.py b/reflexion_oneshot_tritonbench_4/exec/flash_decode2_phi.py new file mode 100644 index 0000000..9ccee70 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/flash_decode2_phi.py @@ -0,0 +1,146 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_obs, stride_oh, stride_od, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d * stride_mid_od + offs_l = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_id in range(0, block_n_size): + tv = tl.load(Mid_O + offs_v + block_id * stride_mid_os) + l_cur = tl.load(Mid_O_LogExpSum + offs_l + block_id * stride_mid_o_es) + new_max = tl.maximum(max_logic, l_cur) + old_scale = tl.exp(max_logic - new_max) + acc *= old_scale + sum_exp *= old_scale + cur_exp = tl.exp(l_cur - new_max) + acc += tv * cur_exp + sum_exp += cur_exp + max_logic = new_max + tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od, (acc / sum_exp).to(Out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(mid_out: torch.Tensor, mid_out_logexpsum: torch.Tensor, B_Seqlen: torch.Tensor, out: torch.Tensor, block_seq: int, BLOCK_DMODEL: int=None) -> None: + if BLOCK_DMODEL is None: + BLOCK_DMODEL = out.shape[-1] + batch, head_num = out.shape[:2] + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, mid_out, mid_out_logexpsum, out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), out.stride(0), out.stride(1), out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json new file mode 100644 index 0000000..5ba3a59 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json @@ -0,0 +1,418 @@ +[ + { + "input_size": [ + [ + 4, + 8, + 16, + 64 + ], + [ + 4, + 8, + 16 + ], + [ + 4 + ], + [ + 4, + 8, + 64 + ] + ], + "ms": 0.00685499981045723, + "GB/s": 20.61677664591699, + "TFLOPS": 0.009560321197970791 + }, + { + "input_size": [ + [ + 8, + 8, + 16, + 64 + ], + [ + 8, + 8, + 16 + ], + [ + 8 + ], + [ + 8, + 8, + 64 + ] + ], + "ms": 0.010343999601900578, + "GB/s": 27.325600432937524, + "TFLOPS": 0.012671307525564598 + }, + { + "input_size": [ + [ + 16, + 8, + 16, + 64 + ], + [ + 16, + 8, + 16 + ], + [ + 16 + ], + [ + 16, + 8, + 64 + ] + ], + "ms": 0.009943000040948391, + "GB/s": 56.85527483373911, + "TFLOPS": 0.026364678559831923 + }, + { + "input_size": [ + [ + 32, + 8, + 16, + 64 + ], + [ + 32, + 8, + 16 + ], + [ + 32 + ], + [ + 32, + 8, + 64 + ] + ], + "ms": 0.010905000381171703, + "GB/s": 103.67940948925656, + "TFLOPS": 0.048077760813765975 + }, + { + "input_size": [ + [ + 64, + 8, + 16, + 64 + ], + [ + 64, + 8, + 16 + ], + [ + 64 + ], + [ + 64, + 8, + 64 + ] + ], + "ms": 0.011105000041425228, + "GB/s": 203.6243126127705, + "TFLOPS": 0.09442377272296025 + }, + { + "input_size": [ + [ + 128, + 8, + 16, + 64 + ], + [ + 128, + 8, + 16 + ], + [ + 128 + ], + [ + 128, + 8, + 64 + ] + ], + "ms": 0.011587000451982021, + "GB/s": 390.3077434700886, + "TFLOPS": 0.18099179409639796 + }, + { + "input_size": [ + [ + 256, + 8, + 16, + 64 + ], + [ + 256, + 8, + 16 + ], + [ + 256 + ], + [ + 256, + 8, + 64 + ] + ], + "ms": 0.012749000452458858, + "GB/s": 709.4667565295695, + "TFLOPS": 0.3289908111338296 + }, + { + "input_size": [ + [ + 512, + 8, + 16, + 64 + ], + [ + 512, + 8, + 16 + ], + [ + 512 + ], + [ + 512, + 8, + 64 + ] + ], + "ms": 0.01828099973499775, + "GB/s": 989.5511329923569, + "TFLOPS": 0.4588703091516692 + }, + { + "input_size": [ + [ + 1024, + 8, + 16, + 64 + ], + [ + 1024, + 8, + 16 + ], + [ + 1024 + ], + [ + 1024, + 8, + 64 + ] + ], + "ms": 0.02782299928367138, + "GB/s": 1300.361892372729, + "TFLOPS": 0.6029981106259139 + }, + { + "input_size": [ + [ + 2048, + 8, + 16, + 64 + ], + [ + 2048, + 8, + 16 + ], + [ + 2048 + ], + [ + 2048, + 8, + 64 + ] + ], + "ms": 0.046505000442266464, + "GB/s": 1555.960333552326, + "TFLOPS": 0.7215230981807231 + }, + { + "input_size": [ + [ + 4096, + 8, + 16, + 64 + ], + [ + 4096, + 8, + 16 + ], + [ + 4096 + ], + [ + 4096, + 8, + 64 + ] + ], + "ms": 0.08667699992656708, + "GB/s": 1669.6456052079209, + "TFLOPS": 0.7742407334916386 + }, + { + "input_size": [ + [ + 8192, + 8, + 16, + 64 + ], + [ + 8192, + 8, + 16 + ], + [ + 8192 + ], + [ + 8192, + 8, + 64 + ] + ], + "ms": 0.16365200281143188, + "GB/s": 1768.6294028036255, + "TFLOPS": 0.8201410657628948 + }, + { + "input_size": [ + [ + 16384, + 8, + 16, + 64 + ], + [ + 16384, + 8, + 16 + ], + [ + 16384 + ], + [ + 16384, + 8, + 64 + ] + ], + "ms": 0.31447499990463257, + "GB/s": 1840.7806285890786, + "TFLOPS": 0.8535987155780443 + }, + { + "input_size": [ + [ + 32768, + 8, + 16, + 64 + ], + [ + 32768, + 8, + 16 + ], + [ + 32768 + ], + [ + 32768, + 8, + 64 + ] + ], + "ms": 0.5913429856300354, + "GB/s": 1957.8468065643615, + "TFLOPS": 0.9078841299318041 + }, + { + "input_size": [ + [ + 65536, + 8, + 16, + 64 + ], + [ + 65536, + 8, + 16 + ], + [ + 65536 + ], + [ + 65536, + 8, + 64 + ] + ], + "ms": 1.1500940322875977, + "GB/s": 2013.3292469958415, + "TFLOPS": 0.9336122037467414 + }, + { + "input_size": [ + [ + 131072, + 8, + 16, + 64 + ], + [ + 131072, + 8, + 16 + ], + [ + 131072 + ], + [ + 131072, + 8, + 64 + ] + ], + "ms": 2.2799609899520874, + "GB/s": 2031.1908512510645, + "TFLOPS": 0.9418949084936443 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json new file mode 100644 index 0000000..ee02aa5 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json @@ -0,0 +1,6 @@ +{ + "flash_decode2_phi.json": { + "ms": 3.6883, + "efficiency": 99.617 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json new file mode 100644 index 0000000..74f2e7b --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json @@ -0,0 +1,145 @@ +[ + { + "input_size": [ + [ + 16 + ], + [ + 16 + ] + ], + "ms": 0.006012999918311834, + "GB/s": 0.03193081699789953, + "TFLOPS": 5.321802832983255e-06 + }, + { + "input_size": [ + [ + 32 + ], + [ + 32 + ] + ], + "ms": 0.0059739998541772366, + "GB/s": 0.06427854191049123, + "TFLOPS": 1.0713090318415206e-05 + }, + { + "input_size": [ + [ + 64 + ], + [ + 64 + ] + ], + "ms": 0.0059739998541772366, + "GB/s": 0.12855708382098247, + "TFLOPS": 2.1426180636830412e-05 + }, + { + "input_size": [ + [ + 128 + ], + [ + 128 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 0.25540405045773634, + "TFLOPS": 4.256734174295606e-05 + }, + { + "input_size": [ + [ + 256 + ], + [ + 256 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 0.5108081009154727, + "TFLOPS": 8.513468348591212e-05 + }, + { + "input_size": [ + [ + 512 + ], + [ + 512 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 1.0216162018309454, + "TFLOPS": 0.00017026936697182423 + }, + { + "input_size": [ + [ + 1024 + ], + [ + 1024 + ] + ], + "ms": 0.006012999918311834, + "GB/s": 2.0435722878655698, + "TFLOPS": 0.0003405953813109283 + }, + { + "input_size": [ + [ + 2048 + ], + [ + 2048 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 4.0864648073237815, + "TFLOPS": 0.0006810774678872969 + }, + { + "input_size": [ + [ + 4096 + ], + [ + 4096 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 8.172929614647563, + "TFLOPS": 0.0013621549357745938 + }, + { + "input_size": [ + [ + 8192 + ], + [ + 8192 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 16.345859229295126, + "TFLOPS": 0.0027243098715491877 + }, + { + "input_size": [ + [ + 16384 + ], + [ + 16384 + ] + ], + "ms": 0.006696000229567289, + "GB/s": 29.362006161804633, + "TFLOPS": 0.004893667693634106 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json new file mode 100644 index 0000000..9a4f2e2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json @@ -0,0 +1,6 @@ +{ + "l2_norm_bwd.json": { + "ms": 2.0181, + "efficiency": 1.44 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json new file mode 100644 index 0000000..8e264de --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json @@ -0,0 +1,6 @@ +{ + "l2_norm_triton1.json": { + "ms": 1.7399, + "efficiency": 0.7672 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log new file mode 100644 index 0000000..798ca3d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log @@ -0,0 +1,16 @@ +{'input_size': [torch.Size([4, 8, 16, 64]), torch.Size([4, 8, 16]), torch.Size([4]), torch.Size([4, 8, 64])], 'ms': 0.00685499981045723, 'GB/s': 20.61677664591699, 'TFLOPS': 0.009560321197970791} +{'input_size': [torch.Size([8, 8, 16, 64]), torch.Size([8, 8, 16]), torch.Size([8]), torch.Size([8, 8, 64])], 'ms': 0.010343999601900578, 'GB/s': 27.325600432937524, 'TFLOPS': 0.012671307525564598} +{'input_size': [torch.Size([16, 8, 16, 64]), torch.Size([16, 8, 16]), torch.Size([16]), torch.Size([16, 8, 64])], 'ms': 0.009943000040948391, 'GB/s': 56.85527483373911, 'TFLOPS': 0.026364678559831923} +{'input_size': [torch.Size([32, 8, 16, 64]), torch.Size([32, 8, 16]), torch.Size([32]), torch.Size([32, 8, 64])], 'ms': 0.010905000381171703, 'GB/s': 103.67940948925656, 'TFLOPS': 0.048077760813765975} +{'input_size': [torch.Size([64, 8, 16, 64]), torch.Size([64, 8, 16]), torch.Size([64]), torch.Size([64, 8, 64])], 'ms': 0.011105000041425228, 'GB/s': 203.6243126127705, 'TFLOPS': 0.09442377272296025} +{'input_size': [torch.Size([128, 8, 16, 64]), torch.Size([128, 8, 16]), torch.Size([128]), torch.Size([128, 8, 64])], 'ms': 0.011587000451982021, 'GB/s': 390.3077434700886, 'TFLOPS': 0.18099179409639796} +{'input_size': [torch.Size([256, 8, 16, 64]), torch.Size([256, 8, 16]), torch.Size([256]), torch.Size([256, 8, 64])], 'ms': 0.012749000452458858, 'GB/s': 709.4667565295695, 'TFLOPS': 0.3289908111338296} +{'input_size': [torch.Size([512, 8, 16, 64]), torch.Size([512, 8, 16]), torch.Size([512]), torch.Size([512, 8, 64])], 'ms': 0.01828099973499775, 'GB/s': 989.5511329923569, 'TFLOPS': 0.4588703091516692} +{'input_size': [torch.Size([1024, 8, 16, 64]), torch.Size([1024, 8, 16]), torch.Size([1024]), torch.Size([1024, 8, 64])], 'ms': 0.02782299928367138, 'GB/s': 1300.361892372729, 'TFLOPS': 0.6029981106259139} +{'input_size': [torch.Size([2048, 8, 16, 64]), torch.Size([2048, 8, 16]), torch.Size([2048]), torch.Size([2048, 8, 64])], 'ms': 0.046505000442266464, 'GB/s': 1555.960333552326, 'TFLOPS': 0.7215230981807231} +{'input_size': [torch.Size([4096, 8, 16, 64]), torch.Size([4096, 8, 16]), torch.Size([4096]), torch.Size([4096, 8, 64])], 'ms': 0.08667699992656708, 'GB/s': 1669.6456052079209, 'TFLOPS': 0.7742407334916386} +{'input_size': [torch.Size([8192, 8, 16, 64]), torch.Size([8192, 8, 16]), torch.Size([8192]), torch.Size([8192, 8, 64])], 'ms': 0.16365200281143188, 'GB/s': 1768.6294028036255, 'TFLOPS': 0.8201410657628948} +{'input_size': [torch.Size([16384, 8, 16, 64]), torch.Size([16384, 8, 16]), torch.Size([16384]), torch.Size([16384, 8, 64])], 'ms': 0.31447499990463257, 'GB/s': 1840.7806285890786, 'TFLOPS': 0.8535987155780443} +{'input_size': [torch.Size([32768, 8, 16, 64]), torch.Size([32768, 8, 16]), torch.Size([32768]), torch.Size([32768, 8, 64])], 'ms': 0.5913429856300354, 'GB/s': 1957.8468065643615, 'TFLOPS': 0.9078841299318041} +{'input_size': [torch.Size([65536, 8, 16, 64]), torch.Size([65536, 8, 16]), torch.Size([65536]), torch.Size([65536, 8, 64])], 'ms': 1.1500940322875977, 'GB/s': 2013.3292469958415, 'TFLOPS': 0.9336122037467414} +{'input_size': [torch.Size([131072, 8, 16, 64]), torch.Size([131072, 8, 16]), torch.Size([131072]), torch.Size([131072, 8, 64])], 'ms': 2.2799609899520874, 'GB/s': 2031.1908512510645, 'TFLOPS': 0.9418949084936443} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log new file mode 100644 index 0000000..3d811ce --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log @@ -0,0 +1,11 @@ +{'input_size': [torch.Size([16]), torch.Size([16])], 'ms': 0.006012999918311834, 'GB/s': 0.03193081699789953, 'TFLOPS': 5.321802832983255e-06} +{'input_size': [torch.Size([32]), torch.Size([32])], 'ms': 0.0059739998541772366, 'GB/s': 0.06427854191049123, 'TFLOPS': 1.0713090318415206e-05} +{'input_size': [torch.Size([64]), torch.Size([64])], 'ms': 0.0059739998541772366, 'GB/s': 0.12855708382098247, 'TFLOPS': 2.1426180636830412e-05} +{'input_size': [torch.Size([128]), torch.Size([128])], 'ms': 0.006014000158756971, 'GB/s': 0.25540405045773634, 'TFLOPS': 4.256734174295606e-05} +{'input_size': [torch.Size([256]), torch.Size([256])], 'ms': 0.006014000158756971, 'GB/s': 0.5108081009154727, 'TFLOPS': 8.513468348591212e-05} +{'input_size': [torch.Size([512]), torch.Size([512])], 'ms': 0.006014000158756971, 'GB/s': 1.0216162018309454, 'TFLOPS': 0.00017026936697182423} +{'input_size': [torch.Size([1024]), torch.Size([1024])], 'ms': 0.006012999918311834, 'GB/s': 2.0435722878655698, 'TFLOPS': 0.0003405953813109283} +{'input_size': [torch.Size([2048]), torch.Size([2048])], 'ms': 0.006014000158756971, 'GB/s': 4.0864648073237815, 'TFLOPS': 0.0006810774678872969} +{'input_size': [torch.Size([4096]), torch.Size([4096])], 'ms': 0.006014000158756971, 'GB/s': 8.172929614647563, 'TFLOPS': 0.0013621549357745938} +{'input_size': [torch.Size([8192]), torch.Size([8192])], 'ms': 0.006014000158756971, 'GB/s': 16.345859229295126, 'TFLOPS': 0.0027243098715491877} +{'input_size': [torch.Size([16384]), torch.Size([16384])], 'ms': 0.006696000229567289, 'GB/s': 29.362006161804633, 'TFLOPS': 0.004893667693634106} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log new file mode 100644 index 0000000..550e319 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log @@ -0,0 +1,11 @@ +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log new file mode 100644 index 0000000..9f23be3 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log @@ -0,0 +1,18 @@ +{'input_size': [torch.Size([128, 256]), torch.Size([256])], 'ms': 0.019283000379800797, 'GB/s': 6.8769381003025165, 'TFLOPS': 0.0033986412233155293} +{'input_size': [torch.Size([256, 384]), torch.Size([384])], 'ms': 0.026660999283194542, 'GB/s': 14.844754909448309, 'TFLOPS': 0.007374367251265395} +{'input_size': [torch.Size([384, 512]), torch.Size([512])], 'ms': 0.0357620008289814, 'GB/s': 22.09093399941353, 'TFLOPS': 0.010995357946564867} +{'input_size': [torch.Size([512, 640]), torch.Size([640])], 'ms': 0.04333899915218353, 'GB/s': 30.349754856619256, 'TFLOPS': 0.015121715148490715} +{'input_size': [torch.Size([640, 768]), torch.Size([768])], 'ms': 0.051676999777555466, 'GB/s': 38.154536998805426, 'TFLOPS': 0.01902277617182717} +{'input_size': [torch.Size([768, 896]), torch.Size([896])], 'ms': 0.060015998780727386, 'GB/s': 45.97387456769338, 'TFLOPS': 0.022931485403221343} +{'input_size': [torch.Size([896, 1024]), torch.Size([1024])], 'ms': 0.07028000056743622, 'GB/s': 52.32919707322878, 'TFLOPS': 0.026109959948552405} +{'input_size': [torch.Size([1024, 1152]), torch.Size([1152])], 'ms': 0.07629299908876419, 'GB/s': 61.9623826099687, 'TFLOPS': 0.030924148063114457} +{'input_size': [torch.Size([1152, 1280]), torch.Size([1280])], 'ms': 0.08503299951553345, 'GB/s': 69.47853225994645, 'TFLOPS': 0.03468206480780756} +{'input_size': [torch.Size([1280, 1408]), torch.Size([1408])], 'ms': 0.0923290029168129, 'GB/s': 78.19549406923474, 'TFLOPS': 0.03903952047708762} +{'input_size': [torch.Size([1408, 1536]), torch.Size([1536])], 'ms': 0.10183099657297134, 'GB/s': 85.06769344825665, 'TFLOPS': 0.0424760254300415} +{'input_size': [torch.Size([1536, 1664]), torch.Size([1664])], 'ms': 0.10796499997377396, 'GB/s': 94.81235587909558, 'TFLOPS': 0.04734689946965889} +{'input_size': [torch.Size([1664, 1792]), torch.Size([1792])], 'ms': 0.1157429963350296, 'GB/s': 103.17147799970981, 'TFLOPS': 0.05152602048366933} +{'input_size': [torch.Size([1792, 1920]), torch.Size([1920])], 'ms': 0.12308000028133392, 'GB/s': 111.93864127809444, 'TFLOPS': 0.05590900214714739} +{'input_size': [torch.Size([1920, 2048]), torch.Size([2048])], 'ms': 0.13025599718093872, 'GB/s': 120.87360536750782, 'TFLOPS': 0.06037587650628989} +{'input_size': [torch.Size([2048, 2176]), torch.Size([2176])], 'ms': 0.13975700736045837, 'GB/s': 127.66936225230204, 'TFLOPS': 0.06377423335212126} +{'input_size': [torch.Size([2176, 2304]), torch.Size([2304])], 'ms': 0.14849700033664703, 'GB/s': 135.16728253430261, 'TFLOPS': 0.06752330334800354} +{'input_size': [torch.Size([2304, 2432]), torch.Size([2432])], 'ms': 0.15607500076293945, 'GB/s': 143.72741240009412, 'TFLOPS': 0.07180301742891973} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.log new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log new file mode 100644 index 0000000..45caff8 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log @@ -0,0 +1,14 @@ +{'input_size': [torch.Size([4, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.009139999747276306, 'GB/s': 116.516414600269, 'TFLOPS': 0.057361927187824735} +{'input_size': [torch.Size([8, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.009500999934971333, 'GB/s': 222.4540589902001, 'TFLOPS': 0.11036480446025432} +{'input_size': [torch.Size([16, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.01006300002336502, 'GB/s': 418.4326731812891, 'TFLOPS': 0.20840226524204278} +{'input_size': [torch.Size([32, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.012788999825716019, 'GB/s': 657.204794318576, 'TFLOPS': 0.32796184667749595} +{'input_size': [torch.Size([64, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.019202999770641327, 'GB/s': 874.5300317961281, 'TFLOPS': 0.4368384158825538} +{'input_size': [torch.Size([128, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.030789999291300774, 'GB/s': 1090.3155820950246, 'TFLOPS': 0.5448917306321646} +{'input_size': [torch.Size([256, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.06005600094795227, 'GB/s': 1117.7109188168276, 'TFLOPS': 0.5587190533895199} +{'input_size': [torch.Size([512, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.11085200309753418, 'GB/s': 1210.930865019127, 'TFLOPS': 0.6053915321760459} +{'input_size': [torch.Size([1024, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.20907600224018097, 'GB/s': 1283.9916447781015, 'TFLOPS': 0.6419566404651943} +{'input_size': [torch.Size([2048, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.3862780034542084, 'GB/s': 1389.8987030040546, 'TFLOPS': 0.6949281439781023} +{'input_size': [torch.Size([4096, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.7548360228538513, 'GB/s': 1422.5052534461477, 'TFLOPS': 0.7112417740348714} +{'input_size': [torch.Size([8192, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 1.4495340585708618, 'GB/s': 1481.510571829739, 'TFLOPS': 0.7407496344436594} +{'input_size': [torch.Size([16384, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 2.838752031326294, 'GB/s': 1512.983040647386, 'TFLOPS': 0.7564886345485674} +{'input_size': [torch.Size([32768, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 5.886918544769287, 'GB/s': 1459.159135747248, 'TFLOPS': 0.7295781763136869} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log new file mode 100644 index 0000000..8ceeb1d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log @@ -0,0 +1,16 @@ +{'input_size': [torch.Size([4096])], 'ms': 0.006053999997675419, 'GB/s': 5.412619757611839, 'TFLOPS': 0.0006765774697014799} +{'input_size': [torch.Size([8192])], 'ms': 0.006093999836593866, 'GB/s': 10.754184732080693, 'TFLOPS': 0.0013442730915100868} +{'input_size': [torch.Size([16384])], 'ms': 0.006093999836593866, 'GB/s': 21.508369464161387, 'TFLOPS': 0.0026885461830201737} +{'input_size': [torch.Size([32768])], 'ms': 0.006093999836593866, 'GB/s': 43.01673892832277, 'TFLOPS': 0.005377092366040347} +{'input_size': [torch.Size([65536])], 'ms': 0.006093999836593866, 'GB/s': 86.03347785664555, 'TFLOPS': 0.010754184732080695} +{'input_size': [torch.Size([131072])], 'ms': 0.006134000141173601, 'GB/s': 170.94489335948708, 'TFLOPS': 0.021368111669935887} +{'input_size': [torch.Size([262144])], 'ms': 0.006134000141173601, 'GB/s': 341.88978671897416, 'TFLOPS': 0.04273622333987177} +{'input_size': [torch.Size([524288])], 'ms': 0.0062939999625086784, 'GB/s': 666.3972076555627, 'TFLOPS': 0.08329965095694535} +{'input_size': [torch.Size([1048576])], 'ms': 0.00685499981045723, 'GB/s': 1223.7211133402611, 'TFLOPS': 0.15296513916753265} +{'input_size': [torch.Size([2097152])], 'ms': 0.008860000409185886, 'GB/s': 1893.5908832019568, 'TFLOPS': 0.2366988604002446} +{'input_size': [torch.Size([4194304])], 'ms': 0.014072000049054623, 'GB/s': 2384.482083785541, 'TFLOPS': 0.2980602604731926} +{'input_size': [torch.Size([8388608])], 'ms': 0.0230919998139143, 'GB/s': 2906.15211072204, 'TFLOPS': 0.363269013840255} +{'input_size': [torch.Size([16777216])], 'ms': 0.05051499977707863, 'GB/s': 2656.987599570411, 'TFLOPS': 0.3321234499463014} +{'input_size': [torch.Size([33554432])], 'ms': 0.08868200331926346, 'GB/s': 3026.943979080033, 'TFLOPS': 0.3783679973850041} +{'input_size': [torch.Size([67108864])], 'ms': 0.1559540033340454, 'GB/s': 3442.4952263011182, 'TFLOPS': 0.4303119032876398} +{'input_size': [torch.Size([134217728])], 'ms': 0.2989189922809601, 'GB/s': 3592.0829780891536, 'TFLOPS': 0.4490103722611442} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log new file mode 100644 index 0000000..a440b79 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log @@ -0,0 +1,31 @@ +{'input_size': [torch.Size([256, 256]), torch.Size([256, 256])], 'ms': 0.009019999764859676, 'GB/s': 43.59379271071602, 'TFLOPS': 3.720003644647768} +{'input_size': [torch.Size([384, 384]), torch.Size([384, 384])], 'ms': 0.010223000310361385, 'GB/s': 86.54367339725967, 'TFLOPS': 11.077590194849238} +{'input_size': [torch.Size([512, 512]), torch.Size([512, 512])], 'ms': 0.012949000112712383, 'GB/s': 121.46605809786634, 'TFLOPS': 20.73020724870252} +{'input_size': [torch.Size([640, 640]), torch.Size([640, 640])], 'ms': 0.014112000353634357, 'GB/s': 174.14965549990777, 'TFLOPS': 37.151926506646994} +{'input_size': [torch.Size([768, 768]), torch.Size([768, 768])], 'ms': 0.01720000058412552, 'GB/s': 205.7525511520165, 'TFLOPS': 52.67265309491623} +{'input_size': [torch.Size([896, 896]), torch.Size([896, 896])], 'ms': 0.018681999295949936, 'GB/s': 257.836215690483, 'TFLOPS': 77.00708308622427} +{'input_size': [torch.Size([1024, 1024]), torch.Size([1024, 1024])], 'ms': 0.02353299967944622, 'GB/s': 267.3461133599119, 'TFLOPS': 91.25414002684992} +{'input_size': [torch.Size([1152, 1152]), torch.Size([1152, 1152])], 'ms': 0.0257789995521307, 'GB/s': 308.8802567336973, 'TFLOPS': 118.61001858573978} +{'input_size': [torch.Size([1280, 1280]), torch.Size([1280, 1280])], 'ms': 0.029627999290823936, 'GB/s': 331.7942566255078, 'TFLOPS': 141.56554949355} +{'input_size': [torch.Size([1408, 1408]), torch.Size([1408, 1408])], 'ms': 0.030910000205039978, 'GB/s': 384.8199262729385, 'TFLOPS': 180.60881873076582} +{'input_size': [torch.Size([1536, 1536]), torch.Size([1536, 1536])], 'ms': 0.03672400116920471, 'GB/s': 385.4638805498806, 'TFLOPS': 197.35750684153888} +{'input_size': [torch.Size([1664, 1664]), torch.Size([1664, 1664])], 'ms': 0.05195799842476845, 'GB/s': 319.7462662857386, 'TFLOPS': 177.352595699823} +{'input_size': [torch.Size([1792, 1792]), torch.Size([1792, 1792])], 'ms': 0.06085899844765663, 'GB/s': 316.59383971906124, 'TFLOPS': 189.1120535921859} +{'input_size': [torch.Size([1920, 1920]), torch.Size([1920, 1920])], 'ms': 0.06326399743556976, 'GB/s': 349.6206515012925, 'TFLOPS': 223.7572169608272} +{'input_size': [torch.Size([2048, 2048]), torch.Size([2048, 2048])], 'ms': 0.07392799854278564, 'GB/s': 340.40991905705846, 'TFLOPS': 232.38650474295193} +{'input_size': [torch.Size([2176, 2176]), torch.Size([2176, 2176])], 'ms': 0.07533200085163116, 'GB/s': 377.1286528809203, 'TFLOPS': 273.54398288962756} +{'input_size': [torch.Size([2304, 2304]), torch.Size([2304, 2304])], 'ms': 0.14252400398254395, 'GB/s': 223.47460855717318, 'TFLOPS': 171.628499371909} +{'input_size': [torch.Size([2432, 2432]), torch.Size([2432, 2432])], 'ms': 0.14476899802684784, 'GB/s': 245.13358856996922, 'TFLOPS': 198.72162913405504} +{'input_size': [torch.Size([2560, 2560]), torch.Size([2560, 2560])], 'ms': 0.16501599550247192, 'GB/s': 238.28962689505434, 'TFLOPS': 203.34048161711303} +{'input_size': [torch.Size([2688, 2688]), torch.Size([2688, 2688])], 'ms': 0.1670999974012375, 'GB/s': 259.4378496362499, 'TFLOPS': 232.45631327407992} +{'input_size': [torch.Size([2816, 2816]), torch.Size([2816, 2816])], 'ms': 0.18959100544452667, 'GB/s': 250.95671542245927, 'TFLOPS': 235.5647035432151} +{'input_size': [torch.Size([2944, 2944]), torch.Size([2944, 2944])], 'ms': 0.18975099921226501, 'GB/s': 274.05819318941786, 'TFLOPS': 268.94244024988205} +{'input_size': [torch.Size([3072, 3072]), torch.Size([3072, 3072])], 'ms': 0.21536900103092194, 'GB/s': 262.9120427218319, 'TFLOPS': 269.22193174715585} +{'input_size': [torch.Size([3200, 3200]), torch.Size([3200, 3200])], 'ms': 0.297435998916626, 'GB/s': 206.56544676430437, 'TFLOPS': 220.33647654859135} +{'input_size': [torch.Size([3328, 3328]), torch.Size([3328, 3328])], 'ms': 0.3247379958629608, 'GB/s': 204.6372917447065, 'TFLOPS': 227.01096897546105} +{'input_size': [torch.Size([3456, 3456]), torch.Size([3456, 3456])], 'ms': 0.32375550270080566, 'GB/s': 221.35103620532738, 'TFLOPS': 254.9963937085371} +{'input_size': [torch.Size([3584, 3584]), torch.Size([3584, 3584])], 'ms': 0.35669100284576416, 'GB/s': 216.0703112360975, 'TFLOPS': 258.1319984900578} +{'input_size': [torch.Size([3712, 3712]), torch.Size([3712, 3712])], 'ms': 0.36759498715400696, 'GB/s': 224.90422037600632, 'TFLOPS': 278.2814886785785} +{'input_size': [torch.Size([3840, 3840]), torch.Size([3840, 3840])], 'ms': 0.4137600064277649, 'GB/s': 213.82830294267677, 'TFLOPS': 273.70022776662626} +{'input_size': [torch.Size([3968, 3968]), torch.Size([3968, 3968])], 'ms': 0.4952859878540039, 'GB/s': 190.73857592726222, 'TFLOPS': 252.2835564264588} +{'input_size': [torch.Size([4096, 4096]), torch.Size([4096, 4096])], 'ms': 0.5376620292663574, 'GB/s': 187.22411202694673, 'TFLOPS': 255.62332095412455} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json new file mode 100644 index 0000000..3c6ddcd --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json @@ -0,0 +1,254 @@ +[ + { + "input_size": [ + [ + 128, + 256 + ], + [ + 256 + ] + ], + "ms": 0.019283000379800797, + "GB/s": 6.8769381003025165, + "TFLOPS": 0.0033986412233155293 + }, + { + "input_size": [ + [ + 256, + 384 + ], + [ + 384 + ] + ], + "ms": 0.026660999283194542, + "GB/s": 14.844754909448309, + "TFLOPS": 0.007374367251265395 + }, + { + "input_size": [ + [ + 384, + 512 + ], + [ + 512 + ] + ], + "ms": 0.0357620008289814, + "GB/s": 22.09093399941353, + "TFLOPS": 0.010995357946564867 + }, + { + "input_size": [ + [ + 512, + 640 + ], + [ + 640 + ] + ], + "ms": 0.04333899915218353, + "GB/s": 30.349754856619256, + "TFLOPS": 0.015121715148490715 + }, + { + "input_size": [ + [ + 640, + 768 + ], + [ + 768 + ] + ], + "ms": 0.051676999777555466, + "GB/s": 38.154536998805426, + "TFLOPS": 0.01902277617182717 + }, + { + "input_size": [ + [ + 768, + 896 + ], + [ + 896 + ] + ], + "ms": 0.060015998780727386, + "GB/s": 45.97387456769338, + "TFLOPS": 0.022931485403221343 + }, + { + "input_size": [ + [ + 896, + 1024 + ], + [ + 1024 + ] + ], + "ms": 0.07028000056743622, + "GB/s": 52.32919707322878, + "TFLOPS": 0.026109959948552405 + }, + { + "input_size": [ + [ + 1024, + 1152 + ], + [ + 1152 + ] + ], + "ms": 0.07629299908876419, + "GB/s": 61.9623826099687, + "TFLOPS": 0.030924148063114457 + }, + { + "input_size": [ + [ + 1152, + 1280 + ], + [ + 1280 + ] + ], + "ms": 0.08503299951553345, + "GB/s": 69.47853225994645, + "TFLOPS": 0.03468206480780756 + }, + { + "input_size": [ + [ + 1280, + 1408 + ], + [ + 1408 + ] + ], + "ms": 0.0923290029168129, + "GB/s": 78.19549406923474, + "TFLOPS": 0.03903952047708762 + }, + { + "input_size": [ + [ + 1408, + 1536 + ], + [ + 1536 + ] + ], + "ms": 0.10183099657297134, + "GB/s": 85.06769344825665, + "TFLOPS": 0.0424760254300415 + }, + { + "input_size": [ + [ + 1536, + 1664 + ], + [ + 1664 + ] + ], + "ms": 0.10796499997377396, + "GB/s": 94.81235587909558, + "TFLOPS": 0.04734689946965889 + }, + { + "input_size": [ + [ + 1664, + 1792 + ], + [ + 1792 + ] + ], + "ms": 0.1157429963350296, + "GB/s": 103.17147799970981, + "TFLOPS": 0.05152602048366933 + }, + { + "input_size": [ + [ + 1792, + 1920 + ], + [ + 1920 + ] + ], + "ms": 0.12308000028133392, + "GB/s": 111.93864127809444, + "TFLOPS": 0.05590900214714739 + }, + { + "input_size": [ + [ + 1920, + 2048 + ], + [ + 2048 + ] + ], + "ms": 0.13025599718093872, + "GB/s": 120.87360536750782, + "TFLOPS": 0.06037587650628989 + }, + { + "input_size": [ + [ + 2048, + 2176 + ], + [ + 2176 + ] + ], + "ms": 0.13975700736045837, + "GB/s": 127.66936225230204, + "TFLOPS": 0.06377423335212126 + }, + { + "input_size": [ + [ + 2176, + 2304 + ], + [ + 2304 + ] + ], + "ms": 0.14849700033664703, + "GB/s": 135.16728253430261, + "TFLOPS": 0.06752330334800354 + }, + { + "input_size": [ + [ + 2304, + 2432 + ], + [ + 2432 + ] + ], + "ms": 0.15607500076293945, + "GB/s": 143.72741240009412, + "TFLOPS": 0.07180301742891973 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json new file mode 100644 index 0000000..b78c381 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json @@ -0,0 +1,6 @@ +{ + "matrix_vector_multip.json": { + "ms": 0.3142, + "efficiency": 7.0489 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json new file mode 100644 index 0000000..55e39be --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json @@ -0,0 +1,296 @@ +[ + { + "input_size": [ + [ + 4, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.009139999747276306, + "GB/s": 116.516414600269, + "TFLOPS": 0.057361927187824735 + }, + { + "input_size": [ + [ + 8, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.009500999934971333, + "GB/s": 222.4540589902001, + "TFLOPS": 0.11036480446025432 + }, + { + "input_size": [ + [ + 16, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.01006300002336502, + "GB/s": 418.4326731812891, + "TFLOPS": 0.20840226524204278 + }, + { + "input_size": [ + [ + 32, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.012788999825716019, + "GB/s": 657.204794318576, + "TFLOPS": 0.32796184667749595 + }, + { + "input_size": [ + [ + 64, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.019202999770641327, + "GB/s": 874.5300317961281, + "TFLOPS": 0.4368384158825538 + }, + { + "input_size": [ + [ + 128, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.030789999291300774, + "GB/s": 1090.3155820950246, + "TFLOPS": 0.5448917306321646 + }, + { + "input_size": [ + [ + 256, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.06005600094795227, + "GB/s": 1117.7109188168276, + "TFLOPS": 0.5587190533895199 + }, + { + "input_size": [ + [ + 512, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.11085200309753418, + "GB/s": 1210.930865019127, + "TFLOPS": 0.6053915321760459 + }, + { + "input_size": [ + [ + 1024, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.20907600224018097, + "GB/s": 1283.9916447781015, + "TFLOPS": 0.6419566404651943 + }, + { + "input_size": [ + [ + 2048, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.3862780034542084, + "GB/s": 1389.8987030040546, + "TFLOPS": 0.6949281439781023 + }, + { + "input_size": [ + [ + 4096, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.7548360228538513, + "GB/s": 1422.5052534461477, + "TFLOPS": 0.7112417740348714 + }, + { + "input_size": [ + [ + 8192, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 1.4495340585708618, + "GB/s": 1481.510571829739, + "TFLOPS": 0.7407496344436594 + }, + { + "input_size": [ + [ + 16384, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 2.838752031326294, + "GB/s": 1512.983040647386, + "TFLOPS": 0.7564886345485674 + }, + { + "input_size": [ + [ + 32768, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 5.886918544769287, + "GB/s": 1459.159135747248, + "TFLOPS": 0.7295781763136869 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json new file mode 100644 index 0000000..f298582 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json @@ -0,0 +1,6 @@ +{ + "rotary_transform.json": { + "ms": 5.4456, + "efficiency": 74.2022 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json new file mode 100644 index 0000000..2e0ab59 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json @@ -0,0 +1,162 @@ +[ + { + "input_size": [ + [ + 4096 + ] + ], + "ms": 0.006053999997675419, + "GB/s": 5.412619757611839, + "TFLOPS": 0.0006765774697014799 + }, + { + "input_size": [ + [ + 8192 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 10.754184732080693, + "TFLOPS": 0.0013442730915100868 + }, + { + "input_size": [ + [ + 16384 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 21.508369464161387, + "TFLOPS": 0.0026885461830201737 + }, + { + "input_size": [ + [ + 32768 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 43.01673892832277, + "TFLOPS": 0.005377092366040347 + }, + { + "input_size": [ + [ + 65536 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 86.03347785664555, + "TFLOPS": 0.010754184732080695 + }, + { + "input_size": [ + [ + 131072 + ] + ], + "ms": 0.006134000141173601, + "GB/s": 170.94489335948708, + "TFLOPS": 0.021368111669935887 + }, + { + "input_size": [ + [ + 262144 + ] + ], + "ms": 0.006134000141173601, + "GB/s": 341.88978671897416, + "TFLOPS": 0.04273622333987177 + }, + { + "input_size": [ + [ + 524288 + ] + ], + "ms": 0.0062939999625086784, + "GB/s": 666.3972076555627, + "TFLOPS": 0.08329965095694535 + }, + { + "input_size": [ + [ + 1048576 + ] + ], + "ms": 0.00685499981045723, + "GB/s": 1223.7211133402611, + "TFLOPS": 0.15296513916753265 + }, + { + "input_size": [ + [ + 2097152 + ] + ], + "ms": 0.008860000409185886, + "GB/s": 1893.5908832019568, + "TFLOPS": 0.2366988604002446 + }, + { + "input_size": [ + [ + 4194304 + ] + ], + "ms": 0.014072000049054623, + "GB/s": 2384.482083785541, + "TFLOPS": 0.2980602604731926 + }, + { + "input_size": [ + [ + 8388608 + ] + ], + "ms": 0.0230919998139143, + "GB/s": 2906.15211072204, + "TFLOPS": 0.363269013840255 + }, + { + "input_size": [ + [ + 16777216 + ] + ], + "ms": 0.05051499977707863, + "GB/s": 2656.987599570411, + "TFLOPS": 0.3321234499463014 + }, + { + "input_size": [ + [ + 33554432 + ] + ], + "ms": 0.08868200331926346, + "GB/s": 3026.943979080033, + "TFLOPS": 0.3783679973850041 + }, + { + "input_size": [ + [ + 67108864 + ] + ], + "ms": 0.1559540033340454, + "GB/s": 3442.4952263011182, + "TFLOPS": 0.4303119032876398 + }, + { + "input_size": [ + [ + 134217728 + ] + ], + "ms": 0.2989189922809601, + "GB/s": 3592.0829780891536, + "TFLOPS": 0.4490103722611442 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json new file mode 100644 index 0000000..507bcf4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json @@ -0,0 +1,6 @@ +{ + "sin_kernel.json": { + "ms": 1.4189, + "efficiency": 90.2883 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py new file mode 100644 index 0000000..3469c2a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py @@ -0,0 +1,59 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from flash_decode2_phi import flash_decode_stage2 +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('flash_decode2_phi', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 18): # Adjust the range as needed for your testing + batch_size = 2 ** i + head_num = 8 # Example head number, adjust as needed + seq_block_num = 16 # Example sequence block number, adjust as needed + head_dim = 64 # Example head dimension, adjust as needed + + mid_out = torch.rand(batch_size, head_num, seq_block_num, head_dim, dtype=torch.float32) + mid_out_logexpsum = torch.rand(batch_size, head_num, seq_block_num, dtype=torch.float32) + B_Seqlen = torch.randint(1, seq_block_num * 32, (batch_size,), dtype=torch.int32) + Out = torch.empty(batch_size, head_num, head_dim, dtype=torch.float32) + + self.input_tensors.append((mid_out, mid_out_logexpsum, B_Seqlen, Out)) + + def to_cuda(self, input_tensor): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + return (mid_out.cuda(), mid_out_logexpsum.cuda(), B_Seqlen.cuda(), Out.cuda()) + + def call_op(self, input_tensor): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + block_seq = 32 # Example block sequence size, adjust as needed + flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq) + return Out + + def get_gbps(self, input_tensor, runtime): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + total_bytes = mid_out.numel() * mid_out.element_size() + mid_out_logexpsum.numel() * mid_out_logexpsum.element_size() + B_Seqlen.numel() * B_Seqlen.element_size() + Out.numel() * Out.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + mid_out, _, _, _ = input_tensor + FLOPS = 2 * mid_out.numel() # Example calculation, adjust based on actual operations + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py new file mode 100644 index 0000000..0eb17f4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py @@ -0,0 +1,51 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from l2_norm_bwd import _l2_norm_bwd +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('l2_norm_bwd', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(4, 15): + size = 2 ** i + x = torch.rand(size, dtype=torch.float32) + dy = torch.rand(size, dtype=torch.float32) + self.input_tensors.append((x, dy)) + + def to_cuda(self, input_tensor): + x, dy = input_tensor + return x.cuda(), dy.cuda() + + def call_op(self, input_tensor): + x, dy = input_tensor + return _l2_norm_bwd(x, dy) + + def get_gbps(self, input_tensor, runtime): + x, dy = input_tensor + total_bytes = (x.numel() + dy.numel() + x.numel()) * x.element_size() # x, dy, and dx + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x, dy = input_tensor + # Assuming each element involves a few FLOPs, e.g., multiplication, addition + FLOPS = 2 * x.numel() # Simplified estimation + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py new file mode 100644 index 0000000..c568861 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py @@ -0,0 +1,75 @@ +import sys +import os +import json + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from l2_norm_triton1 import _l2_norm_fwd +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('l2_norm_triton1', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(4, 15): + size = 2 ** i + input_tensor = torch.rand(size, dtype=torch.float32) + self.input_tensors.append(input_tensor) + + def to_cuda(self, input_tensor): + return input_tensor.cuda() + + def call_op(self, input_tensor): + return _l2_norm_fwd(input_tensor) + + def get_gbps(self, input_tensor, runtime): + x = input_tensor + total_bytes = 2 * x.numel() * x.element_size() # Read and write + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x = input_tensor + FLOPS = 2 * x.numel() # Each element involves a multiply and an add + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [input_tensor.shape], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py new file mode 100644 index 0000000..8901aea --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py @@ -0,0 +1,52 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from matrix_vector_multip import mv +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('matrix_vector_multip', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 20): # Adjust the range for different sizes + M = 128 * i + N = 128 * (i - 1) # Example: N is half of M + matrix = torch.rand((N, M), dtype=torch.float32) + vector = torch.rand((M,), dtype=torch.float32) + self.input_tensors.append((matrix, vector)) + + def to_cuda(self, input_tensor): + matrix, vector = input_tensor + return (matrix.cuda(), vector.cuda()) + + def call_op(self, input_tensor): + matrix, vector = input_tensor + return mv(matrix, vector) + + def get_gbps(self, input_tensor, runtime): + matrix, vector = input_tensor + total_bytes = (matrix.numel() + vector.numel() + matrix.size(0)) * matrix.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + matrix, vector = input_tensor + N, M = matrix.shape + FLOPS = 2 * N * M # Each element in the output involves M multiplications and M-1 additions + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py new file mode 100644 index 0000000..a8ceb06 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py @@ -0,0 +1,145 @@ +# Modifications Copyright(C)[2025] Advanced Micro Devices, Inc. All rights reserved. +# https://github.com/thunlp/TritonBench - Apache License 2.0 +import torch +import triton +import triton.language as tl + +from typing import Callable +import json +import os + +class do_bench_config(): + def __init__( + self, + warm_up=25, + repetition=100, + grad_to_none=None, + quantiles=[0.5, 0.8, 0.2], + return_mode="median" + ): + self.warm_up = warm_up + self.repetition = repetition + self.grad_to_none = grad_to_none + self.quantiles = quantiles + self.return_mode = return_mode + +class Performance_Metrics: + def __init__( + self, + op_name, + dtype=None, + is_backward=False, + **kwargs + ): + self.op_name = op_name + self.dtype = dtype + if is_backward: + self.op_name += 'backward' + self.kwargs = kwargs + + self.input_tensors = [] + self.do_bench_config = do_bench_config() + + def get_input_tensors(self): + raise NotImplementedError("You must implement this method to get input tensors") + + def to_cuda(self, input_tensor): + raise NotImplementedError("You must implement this method to get input tensors") + + def call_op(self, input_tensor): + raise NotImplementedError("You must implement this method to call the op") + + def get_do_bench_config(self, warmup=None, rep=None): + if warmup != None and rep != None: + self.do_bench_config = do_bench_config( + warm_up=warmup, + repetition=rep, + ) + return + + if self.input_tensors == []: + raise NotImplementedError("You must implement this method to get input_tensors") + + previous_ms = None + epsilon = 1e-4 + stable_count = 0 + max_stable_count = 3 + input_tensor = self.to_cuda(self.input_tensors[-1]) + + for t in range(1, 11): + warmup = 100 * t + rep = 1000 * t + + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: self.call_op(input_tensor), + warmup=warmup, + rep=rep, + quantiles=[0.5, 0.8, 0.2], + return_mode="median" + ) + + print("warmup time:", warmup, "rep time:", rep, "runtime:", ms) + + if previous_ms is not None: + relative_change = abs(ms - previous_ms) / abs(previous_ms) if previous_ms != 0 else float('inf') + + if relative_change < epsilon: + stable_count += 1 + else: + stable_count = 0 + + if stable_count >= max_stable_count: + print(f"MS stabilized with warmup={warmup} and rep={rep}") + self.do_bench_config = do_bench_config( + warm_up=warmup, + repetition=rep, + ) + return + + previous_ms = ms + + print("MS did not stabilize. Returning default config.") + raise NotImplementedError("You must implement this method to make the runtime stable") + + def get_runtime(self, op: Callable): + ms, min_ms, max_ms = triton.testing.do_bench( + op, + warmup=self.do_bench_config.warm_up, + rep=self.do_bench_config.repetition, + quantiles=self.do_bench_config.quantiles, + return_mode=self.do_bench_config.return_mode + ) + return ms + + def get_gbps(self, input_tensor, runtime): + raise NotImplementedError("You must implement this method to get the method to calculate GBPS") + + def get_tflops(self, input_tensor, runtime): + raise NotImplementedError("You must implement this method to get the method to calculate TFLOPS") + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [item.shape if type(item)==torch.Tensor else item for item in input_tensor], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py new file mode 100644 index 0000000..97d9cde --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py @@ -0,0 +1,56 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from rotary_transform import apply_rotary +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('rotary_transform', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 16): # Choose a reasonable range for testing + batch_size = 2 ** i + seqlen = 128 # Fixed sequence length + nheads = 8 # Number of attention heads + headdim = 64 # Dimension of each head + rotary_dim = 32 # Rotary dimension + x = torch.rand(batch_size, seqlen, nheads, headdim, dtype=torch.float32) + cos = torch.rand(seqlen, rotary_dim // 2, dtype=torch.float32) + sin = torch.rand(seqlen, rotary_dim // 2, dtype=torch.float32) + self.input_tensors.append((x, cos, sin)) + + def to_cuda(self, input_tensor): + x, cos, sin = input_tensor + return (x.cuda(), cos.cuda(), sin.cuda()) + + def call_op(self, input_tensor): + x, cos, sin = input_tensor + return apply_rotary(x, cos, sin) + + def get_gbps(self, input_tensor, runtime): + x, cos, sin = input_tensor + total_bytes = x.numel() * x.element_size() + cos.numel() * cos.element_size() + sin.numel() * sin.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x, cos, sin = input_tensor + # Assuming each element in x is involved in a few operations (e.g., multiply and add) + FLOPS = 2 * x.numel() # Simplified estimation + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py new file mode 100644 index 0000000..d6eeda2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py @@ -0,0 +1,75 @@ +import sys +import os +import json +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# Correctly import the kernel function +from sin_kernel import call_kernel +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('sin_kernel', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(12, 28): + size = 2 ** i + input_tensor = torch.rand(size, dtype=torch.float32) + self.input_tensors.append(input_tensor) + + def to_cuda(self, input_tensor): + return input_tensor.cuda() + + def call_op(self, input_tensor): + return call_kernel(input_tensor) + + def get_gbps(self, input_tensor, runtime): + x = input_tensor + total_bytes = 2 * x.numel() * x.element_size() # Read and write + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x = input_tensor + FLOPS = x.numel() # One sin operation per element + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [input_tensor.shape], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py new file mode 100644 index 0000000..6044dc4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py @@ -0,0 +1,54 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from triton_matmul import matmul # Correctly import the matmul function +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('triton_matmul', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 33): # Define a range for matrix sizes + M = N = K = 128 * i + a = torch.rand((M, K), dtype=torch.float16) # Use float16 for compatibility + b = torch.rand((K, N), dtype=torch.float16) + self.input_tensors.append((a, b)) + + def to_cuda(self, input_tensor): + a, b = input_tensor + return (a.cuda(), b.cuda()) + + def call_op(self, input_tensor): + a, b = input_tensor + return matmul(a, b) + + def get_gbps(self, input_tensor, runtime): + a, b = input_tensor + M, K = a.shape + K, N = b.shape + total_bytes = (M * K + K * N + M * N) * a.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + a, b = input_tensor + M, K = a.shape + K, N = b.shape + FLOPS = 2 * M * N * K + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json new file mode 100644 index 0000000..12df2d7 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json @@ -0,0 +1,467 @@ +[ + { + "input_size": [ + [ + 256, + 256 + ], + [ + 256, + 256 + ] + ], + "ms": 0.009019999764859676, + "GB/s": 43.59379271071602, + "TFLOPS": 3.720003644647768 + }, + { + "input_size": [ + [ + 384, + 384 + ], + [ + 384, + 384 + ] + ], + "ms": 0.010223000310361385, + "GB/s": 86.54367339725967, + "TFLOPS": 11.077590194849238 + }, + { + "input_size": [ + [ + 512, + 512 + ], + [ + 512, + 512 + ] + ], + "ms": 0.012949000112712383, + "GB/s": 121.46605809786634, + "TFLOPS": 20.73020724870252 + }, + { + "input_size": [ + [ + 640, + 640 + ], + [ + 640, + 640 + ] + ], + "ms": 0.014112000353634357, + "GB/s": 174.14965549990777, + "TFLOPS": 37.151926506646994 + }, + { + "input_size": [ + [ + 768, + 768 + ], + [ + 768, + 768 + ] + ], + "ms": 0.01720000058412552, + "GB/s": 205.7525511520165, + "TFLOPS": 52.67265309491623 + }, + { + "input_size": [ + [ + 896, + 896 + ], + [ + 896, + 896 + ] + ], + "ms": 0.018681999295949936, + "GB/s": 257.836215690483, + "TFLOPS": 77.00708308622427 + }, + { + "input_size": [ + [ + 1024, + 1024 + ], + [ + 1024, + 1024 + ] + ], + "ms": 0.02353299967944622, + "GB/s": 267.3461133599119, + "TFLOPS": 91.25414002684992 + }, + { + "input_size": [ + [ + 1152, + 1152 + ], + [ + 1152, + 1152 + ] + ], + "ms": 0.0257789995521307, + "GB/s": 308.8802567336973, + "TFLOPS": 118.61001858573978 + }, + { + "input_size": [ + [ + 1280, + 1280 + ], + [ + 1280, + 1280 + ] + ], + "ms": 0.029627999290823936, + "GB/s": 331.7942566255078, + "TFLOPS": 141.56554949355 + }, + { + "input_size": [ + [ + 1408, + 1408 + ], + [ + 1408, + 1408 + ] + ], + "ms": 0.030910000205039978, + "GB/s": 384.8199262729385, + "TFLOPS": 180.60881873076582 + }, + { + "input_size": [ + [ + 1536, + 1536 + ], + [ + 1536, + 1536 + ] + ], + "ms": 0.03672400116920471, + "GB/s": 385.4638805498806, + "TFLOPS": 197.35750684153888 + }, + { + "input_size": [ + [ + 1664, + 1664 + ], + [ + 1664, + 1664 + ] + ], + "ms": 0.05195799842476845, + "GB/s": 319.7462662857386, + "TFLOPS": 177.352595699823 + }, + { + "input_size": [ + [ + 1792, + 1792 + ], + [ + 1792, + 1792 + ] + ], + "ms": 0.06085899844765663, + "GB/s": 316.59383971906124, + "TFLOPS": 189.1120535921859 + }, + { + "input_size": [ + [ + 1920, + 1920 + ], + [ + 1920, + 1920 + ] + ], + "ms": 0.06326399743556976, + "GB/s": 349.6206515012925, + "TFLOPS": 223.7572169608272 + }, + { + "input_size": [ + [ + 2048, + 2048 + ], + [ + 2048, + 2048 + ] + ], + "ms": 0.07392799854278564, + "GB/s": 340.40991905705846, + "TFLOPS": 232.38650474295193 + }, + { + "input_size": [ + [ + 2176, + 2176 + ], + [ + 2176, + 2176 + ] + ], + "ms": 0.07533200085163116, + "GB/s": 377.1286528809203, + "TFLOPS": 273.54398288962756 + }, + { + "input_size": [ + [ + 2304, + 2304 + ], + [ + 2304, + 2304 + ] + ], + "ms": 0.14252400398254395, + "GB/s": 223.47460855717318, + "TFLOPS": 171.628499371909 + }, + { + "input_size": [ + [ + 2432, + 2432 + ], + [ + 2432, + 2432 + ] + ], + "ms": 0.14476899802684784, + "GB/s": 245.13358856996922, + "TFLOPS": 198.72162913405504 + }, + { + "input_size": [ + [ + 2560, + 2560 + ], + [ + 2560, + 2560 + ] + ], + "ms": 0.16501599550247192, + "GB/s": 238.28962689505434, + "TFLOPS": 203.34048161711303 + }, + { + "input_size": [ + [ + 2688, + 2688 + ], + [ + 2688, + 2688 + ] + ], + "ms": 0.1670999974012375, + "GB/s": 259.4378496362499, + "TFLOPS": 232.45631327407992 + }, + { + "input_size": [ + [ + 2816, + 2816 + ], + [ + 2816, + 2816 + ] + ], + "ms": 0.18959100544452667, + "GB/s": 250.95671542245927, + "TFLOPS": 235.5647035432151 + }, + { + "input_size": [ + [ + 2944, + 2944 + ], + [ + 2944, + 2944 + ] + ], + "ms": 0.18975099921226501, + "GB/s": 274.05819318941786, + "TFLOPS": 268.94244024988205 + }, + { + "input_size": [ + [ + 3072, + 3072 + ], + [ + 3072, + 3072 + ] + ], + "ms": 0.21536900103092194, + "GB/s": 262.9120427218319, + "TFLOPS": 269.22193174715585 + }, + { + "input_size": [ + [ + 3200, + 3200 + ], + [ + 3200, + 3200 + ] + ], + "ms": 0.297435998916626, + "GB/s": 206.56544676430437, + "TFLOPS": 220.33647654859135 + }, + { + "input_size": [ + [ + 3328, + 3328 + ], + [ + 3328, + 3328 + ] + ], + "ms": 0.3247379958629608, + "GB/s": 204.6372917447065, + "TFLOPS": 227.01096897546105 + }, + { + "input_size": [ + [ + 3456, + 3456 + ], + [ + 3456, + 3456 + ] + ], + "ms": 0.32375550270080566, + "GB/s": 221.35103620532738, + "TFLOPS": 254.9963937085371 + }, + { + "input_size": [ + [ + 3584, + 3584 + ], + [ + 3584, + 3584 + ] + ], + "ms": 0.35669100284576416, + "GB/s": 216.0703112360975, + "TFLOPS": 258.1319984900578 + }, + { + "input_size": [ + [ + 3712, + 3712 + ], + [ + 3712, + 3712 + ] + ], + "ms": 0.36759498715400696, + "GB/s": 224.90422037600632, + "TFLOPS": 278.2814886785785 + }, + { + "input_size": [ + [ + 3840, + 3840 + ], + [ + 3840, + 3840 + ] + ], + "ms": 0.4137600064277649, + "GB/s": 213.82830294267677, + "TFLOPS": 273.70022776662626 + }, + { + "input_size": [ + [ + 3968, + 3968 + ], + [ + 3968, + 3968 + ] + ], + "ms": 0.4952859878540039, + "GB/s": 190.73857592726222, + "TFLOPS": 252.2835564264588 + }, + { + "input_size": [ + [ + 4096, + 4096 + ], + [ + 4096, + 4096 + ] + ], + "ms": 0.5376620292663574, + "GB/s": 187.22411202694673, + "TFLOPS": 255.62332095412455 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json new file mode 100644 index 0000000..953bb99 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json @@ -0,0 +1,6 @@ +{ + "triton_matmul.json": { + "ms": 3.5604, + "efficiency": 89.1928 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py b/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py new file mode 100644 index 0000000..8b39a49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + dx = torch.empty_like(x) + grid = (M,) + _l2_norm_bwd_kernel[grid](x, dy, dx, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py b/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py new file mode 100644 index 0000000..efc3d1a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py @@ -0,0 +1,98 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(data * data) + rstd = tl.rsqrt(var + eps) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + y = data * rstd + tl.store(Y + base + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-05): + x = x if x.is_contiguous() else x.contiguous() + original_shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + element_size = x.element_size() + BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N)) + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.view(original_shape) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py b/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py new file mode 100644 index 0000000..0f93fd4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr, stride_bk: tl.constexpr, stride_cm: tl.constexpr, N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + start_n = pid_n * BLOCK_N + offsets_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offsets_n < N + accum = tl.zeros((BLOCK_N,), dtype=tl.float32) + for start_m in range(0, M, BLOCK_M): + offsets_m = start_m + tl.arange(0, BLOCK_M) + mask_m = offsets_m < M + offs_a = A + (offsets_n[:, None] * stride_am + offsets_m[None, :] * stride_ak) + tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0) + offs_b = B + offsets_m * stride_bk + seg_b = tl.load(offs_b, mask=mask_m, other=0.0) + accum += tl.sum(tile_a * seg_b[None, :], axis=1) + offs_c = C + offsets_n * stride_cm + tl.store(offs_c, accum, mask=mask_n) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dtype == B.dtype + assert A.size(1) == B.size(0) + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + BLOCK_N = 128 + BLOCK_M = 64 + grid = (triton.cdiv(N, BLOCK_N),) + mv_kernel[grid](A, B, C, stride_am=A.stride(0), stride_ak=A.stride(1), stride_bk=B.stride(0), stride_cm=C.stride(0), N=N, M=M, BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt b/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt new file mode 100644 index 0000000..63dd0da --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt @@ -0,0 +1,3 @@ +Performance analysis for /workspace/reflexion_oneshot_tritonbench_4/exec: +Error processing l2_norm_triton1.json, skipping... +Error processing sin_kernel.json, skipping... \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py b/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py new file mode 100644 index 0000000..7212c74 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py @@ -0,0 +1,207 @@ +import torch +import triton +import triton.language as tl +import logging +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + end_idx = tl.load(CU_SEQLENS + pid_batch + 1) + cur_seqlen = end_idx - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < cur_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + mask_m_cs = rm_cs < seqlen_ro + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_offs = x0_offs + rotary_dim_half * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_offs = out0_offs + rotary_dim_half * stride_out_headdim + tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + rk_repeat = rk // 2 + cs_mask = rk_repeat[None, :] < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32) + rk_swap = rk + (rk + 1) % 2 * 2 - 1 + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'If cu_seqlens is passed in, then max_seqlen must be passed' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + assert rotary_dim <= headdim + rotary_dim = rotary_dim * 2 + assert headdim <= 256, 'Only support headdim <= 256' + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + assert cos.dtype == sin.dtype + assert x.dtype == cos.dtype + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + if interleaved: + BLOCK_M = 4 + else: + BLOCK_M = 8 if rotary_dim <= 64 else 4 + BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py b/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py new file mode 100644 index 0000000..786f954 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py @@ -0,0 +1,87 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py b/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py new file mode 100644 index 0000000..a59f8d1 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K) + b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2, 'Input a must be 2-D' + assert b.dim() == 2, 'Input b must be 2-D' + assert a.shape[1] == b.shape[0], 'Incompatible dimensions' + M, K = a.shape + K2, N = b.shape + assert K == K2, 'Inner dimensions must match' + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1)) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json b/reflexion_oneshot_tritonbench_4/out.json new file mode 100644 index 0000000..455a27c --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json @@ -0,0 +1,42 @@ +2025-08-23_16-23-45 => File: matrix_vector_multip.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-24-10 => File: triton_matmul.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-24-19 => File: embedding_triton_kernel.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 158, in + result_gold = test_embedding() + ^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 88, in test_embedding + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 29, in embedding + out = out.view(B * SEQ_LEN, d_model) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: shape '[512000, 128]' is invalid for input of size 65536 +2025-08-23_16-24-28 => File: int4_matmul.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 160, in + result_gold = test_correct_int4_s2() + ^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 144, in test_correct_int4_s2 + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 54, in matmul_dequantize_int4_s2 + assert K == qweight.shape[0] * 8 + ^^^^^^^^^^^^^^^^^^^^^^^^^ +AssertionError +2025-08-23_16-24-38 => File: flash_decode2_phi.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 157, in + result_gold = test_flash_decode_stage2() + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 145, in test_flash_decode_stage2 + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 34, in flash_decode_stage2 + assert Mid_O.dim() == 4 + ^^^^^^^^^^^^^^^^ +AssertionError +2025-08-23_16-24-57 => File: matrix_transpose.py, Call Status: True, Exec Status: False, difficulty: -1, stderr: Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py +2025-08-23_16-25-17 => File: rotary_transform.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-25-37 => File: sin_kernel.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-25-56 => File: l2_norm_bwd.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-26-15 => File: l2_norm_triton1.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-26-15 => File: /workspace/reflexion_oneshot_tritonbench_4.json, Call Accuracy: 0.7, Exec Accuracy: 0.6 diff --git a/reflexion_oneshot_tritonbench_4/out.json_all_passes.json b/reflexion_oneshot_tritonbench_4/out.json_all_passes.json new file mode 100644 index 0000000..31186e0 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_all_passes.json @@ -0,0 +1,92 @@ +[ + { + "pass_num": 0, + "file_name": "matrix_vector_multip.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "triton_matmul.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "embedding_triton_kernel.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 158, in \n result_gold = test_embedding()\n ^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 88, in test_embedding\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 29, in embedding\n out = out.view(B * SEQ_LEN, d_model)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRuntimeError: shape '[512000, 128]' is invalid for input of size 65536", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "int4_matmul.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 160, in \n result_gold = test_correct_int4_s2()\n ^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 144, in test_correct_int4_s2\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 54, in matmul_dequantize_int4_s2\n assert K == qweight.shape[0] * 8\n ^^^^^^^^^^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "flash_decode2_phi.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 157, in \n result_gold = test_flash_decode_stage2()\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 145, in test_flash_decode_stage2\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n File \"/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py\", line 116, in decorate_context\n return func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 34, in flash_decode_stage2\n assert Mid_O.dim() == 4\n ^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "matrix_transpose.py", + "call_status": 1, + "exec_status": 0, + "stdout": "None", + "stderr": "Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "rotary_transform.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "sin_kernel.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_bwd.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_triton1.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json_perf_0.json b/reflexion_oneshot_tritonbench_4/out.json_perf_0.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_perf_0.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json_results_0.json b/reflexion_oneshot_tritonbench_4/out.json_results_0.json new file mode 100644 index 0000000..31186e0 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_results_0.json @@ -0,0 +1,92 @@ +[ + { + "pass_num": 0, + "file_name": "matrix_vector_multip.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "triton_matmul.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "embedding_triton_kernel.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 158, in \n result_gold = test_embedding()\n ^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 88, in test_embedding\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 29, in embedding\n out = out.view(B * SEQ_LEN, d_model)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRuntimeError: shape '[512000, 128]' is invalid for input of size 65536", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "int4_matmul.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 160, in \n result_gold = test_correct_int4_s2()\n ^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 144, in test_correct_int4_s2\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 54, in matmul_dequantize_int4_s2\n assert K == qweight.shape[0] * 8\n ^^^^^^^^^^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "flash_decode2_phi.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 157, in \n result_gold = test_flash_decode_stage2()\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 145, in test_flash_decode_stage2\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n File \"/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py\", line 116, in decorate_context\n return func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 34, in flash_decode_stage2\n assert Mid_O.dim() == 4\n ^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "matrix_transpose.py", + "call_status": 1, + "exec_status": 0, + "stdout": "None", + "stderr": "Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "rotary_transform.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "sin_kernel.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_bwd.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_triton1.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt b/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt new file mode 100644 index 0000000..c0b21ee --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt @@ -0,0 +1,4 @@ +Call Accuracy: 70.0 +Exec Accuracy: 60.0 +Call Accuracy: 70.0 +Exec Accuracy: 60.0 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py new file mode 100644 index 0000000..a2d3d9f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py @@ -0,0 +1,158 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weight, out, indices, vocab_size, d_model, vob_start_id, stride_out_0, stride_weight_0, BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + for k in range(0, BLOCK_NN, BLOCK_N): + seq_off = pid * BLOCK_NN + k + tl.arange(0, BLOCK_N) + mask_seq = seq_off < out.shape[0] + token_idx_raw = tl.load(indices + seq_off, mask=mask_seq, other=0).to(tl.int32) + token_idx = token_idx_raw - vob_start_id + token_idx = tl.where(token_idx >= 0, token_idx, 0) + token_idx = tl.where(token_idx < vocab_size, token_idx, vocab_size - 1) + w_offs = token_idx[:, None] * d_model + offs_d[None, :] + vec = tl.load(weight + w_offs) + o_offs = seq_off[:, None] * stride_out_0 + offs_d[None, :] + tl.store(out + o_offs, vec, mask=mask_seq[:, None]) + +def embedding(weight: torch.Tensor, indices: torch.Tensor, vob_start_id: int, vob_end_id: int, out: torch.Tensor=None) -> torch.Tensor: + shape_2d = indices.shape + B, SEQ_LEN = shape_2d + d_model = weight.size(-1) + weight = weight.view(-1, d_model) + if out is None: + out = torch.empty((B * SEQ_LEN, d_model), dtype=weight.dtype, device=weight.device) + else: + out = out.view(B * SEQ_LEN, d_model) + vocab_sz = vob_end_id - vob_start_id if vob_end_id > vob_start_id else weight.size(0) + BLOCK_N = 16 + BLOCK_NN = 64 + BLOCK_DMODEL = triton.next_power_of_2(d_model) + grid = lambda META: (triton.cdiv(out.shape[0], META['BLOCK_NN']),) + embedding_kernel[grid](weight, out, indices.view(-1), vocab_size=vocab_sz, d_model=d_model, vob_start_id=vob_start_id, stride_out_0=out.stride(0), stride_weight_0=weight.stride(0), BLOCK_N=BLOCK_N, BLOCK_NN=BLOCK_NN, BLOCK_DMODEL=BLOCK_DMODEL) + return out.view(B, SEQ_LEN, d_model) + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py new file mode 100644 index 0000000..96f48fb --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py @@ -0,0 +1,163 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weights_ptr, indices_ptr, out_ptr, v_stride_0, v_stride_1, out_stride_0, out_stride_1, vocab_size: tl.constexpr, hidden_size: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, NUM_SEQ: tl.constexpr, MAX_LEN: tl.constexpr): + seq_id = tl.program_id(0) + if seq_id >= NUM_SEQ: + return + len_offsets = tl.program_id(1) + offset_start = len_offsets * BLOCK_NN + dim_offsets = tl.arange(0, BLOCK_SIZE) + indices_base = indices_ptr + seq_id * out_stride_0 + out_base = out_ptr + seq_id * out_stride_0 + for i in range(0, BLOCK_NN, BLOCK_N): + cur_offset = offset_start + i + cur_mask = cur_offset < MAX_LEN + token_id = tl.load(indices_base + cur_offset * out_stride_1, mask=cur_mask, other=0) + token_mask = token_id < vocab_size + valid_mask = cur_mask & token_mask + selector = token_id[:, None] * v_stride_0 + dim_offsets[None, :] * v_stride_1 + w_ptrs = weights_ptr + selector + vec = tl.load(w_ptrs, mask=valid_mask[:, None], other=0.0) + out_ptrs = out_base + cur_offset[:, None] * out_stride_1 + dim_offsets[None, :] + tl.store(out_ptrs, vec, mask=valid_mask[:, None]) + +def embedding(weights: torch.Tensor, indices: torch.Tensor, vob_start_id=None, vob_end_id=None, out=None) -> torch.Tensor: + assert weights.dim() == 2, 'weights must be 2-D [vocab_size, hidden_size]' + vocab_size, hidden_size = weights.shape + assert indices.dim() == 2, 'indices must be 2-D: [batch, seq]' + batch, seq_len = indices.shape + if out is None: + out = torch.empty((batch, seq_len, hidden_size), dtype=weights.dtype, device=weights.device) + else: + assert out.shape == (batch, seq_len, hidden_size), f'out tensor shape mismatch: expected {(batch, seq_len, hidden_size)}, got {out.shape}' + BLOCK_SIZE = triton.next_power_of_2(hidden_size) + BLOCK_N = 1 + BLOCK_NN = triton.next_power_of_2(seq_len) if seq_len > 0 else 1 + grid = (batch, triton.cdiv(seq_len, BLOCK_NN)) + embedding_kernel[grid](weights, indices, out, weights.stride(0), weights.stride(1), out.stride(0), out.stride(1), vocab_size=vocab_size, hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE, BLOCK_N=BLOCK_N, BLOCK_NN=BLOCK_NN, NUM_SEQ=batch, MAX_LEN=seq_len) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py new file mode 100644 index 0000000..9ccee70 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py @@ -0,0 +1,146 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_obs, stride_oh, stride_od, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d * stride_mid_od + offs_l = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_id in range(0, block_n_size): + tv = tl.load(Mid_O + offs_v + block_id * stride_mid_os) + l_cur = tl.load(Mid_O_LogExpSum + offs_l + block_id * stride_mid_o_es) + new_max = tl.maximum(max_logic, l_cur) + old_scale = tl.exp(max_logic - new_max) + acc *= old_scale + sum_exp *= old_scale + cur_exp = tl.exp(l_cur - new_max) + acc += tv * cur_exp + sum_exp += cur_exp + max_logic = new_max + tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od, (acc / sum_exp).to(Out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(mid_out: torch.Tensor, mid_out_logexpsum: torch.Tensor, B_Seqlen: torch.Tensor, out: torch.Tensor, block_seq: int, BLOCK_DMODEL: int=None) -> None: + if BLOCK_DMODEL is None: + BLOCK_DMODEL = out.shape[-1] + batch, head_num = out.shape[:2] + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, mid_out, mid_out_logexpsum, out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), out.stride(0), out.stride(1), out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py new file mode 100644 index 0000000..721898a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py @@ -0,0 +1,157 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_out_b, stride_out_h, stride_out_d, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for block_seq_id in range(0, block_n_size): + ptr_v = Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_seq_id * stride_mid_os + offs_d * stride_mid_od + tv = tl.load(ptr_v).to(tl.float32) + ptr_logic = Mid_O_LogExpSum + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + block_seq_id * stride_mid_o_es + tlogic = tl.load(ptr_logic).to(tl.float32) + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + delta = tl.exp(tlogic - new_max) + acc += delta * tv + sum_exp += delta + max_logic = new_max + acc = acc / sum_exp + ptr_out = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offs_d * stride_out_d + tl.store(ptr_out, acc.to(Out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(B_Seqlen: torch.Tensor, Mid_O: torch.Tensor, Mid_O_LogExpSum: torch.Tensor, Out: torch.Tensor, BLOCK_SEQ: int=64): + assert Mid_O.dim() == 4 + assert Mid_O_LogExpSum.dim() == 3 + assert B_Seqlen.dim() == 1 + assert Out.dim() == 3 + assert Mid_O.shape[0] == Out.shape[0] + assert Mid_O.shape[1] == Out.shape[1] + assert Mid_O.shape[3] == Out.shape[2] + assert Mid_O.shape[0] == Mid_O_LogExpSum.shape[0] + assert Mid_O.shape[1] == Mid_O_LogExpSum.shape[1] + assert Mid_O.shape[2] == Mid_O_LogExpSum.shape[2] + BLOCK_DMODEL = Out.shape[-1] + batch, head_num = (Out.shape[0], Out.shape[1]) + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3), Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2), Out.stride(0), Out.stride(1), Out.stride(2), BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py new file mode 100644 index 0000000..9eb06c3 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py @@ -0,0 +1,160 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scale_g, stride_scale_n, stride_zp_g, stride_zp_n, group_size, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + start_k = pid_k * BLOCK_SIZE_K + offs_k = start_k + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] // 8 * stride_bk + offs_n[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_curr = k_step * BLOCK_SIZE_K * SPLIT_K + pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = k_curr[None, :] < K + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + b_raw = tl.load(b_ptrs, mask=mask_k, other=0) + group_idx = k_curr[:, None] // group_size + shift = k_curr[:, None] % 8 * 4 + bits = b_raw >> shift & 15 + scales = tl.load(scales_ptr + group_idx * stride_scale_g + offs_n[None, :] * stride_scale_n, mask=mask_k, other=0.0) + zeros = tl.load(zeros_ptr + group_idx * stride_zp_g + offs_n[None, :] // 8 * stride_zp_n, mask=mask_k, other=0.0) + zeros_bits = zeros >> offs_n[None, :] % 8 * 4 & 15 + b_deq = (bits - zeros_bits) * scales + accumulator += tl.dot(a, b_deq) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K // 8 * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + out_vals = accumulator.to(c_ptr.dtype.element_ty) + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, out_vals, mask=mask_c) + else: + tl.store(c_ptrs, out_vals, mask=mask_c) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.FloatTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, split_k: int=1) -> torch.FloatTensor: + assert x.dim() == 2 and qweight.dim() == 2 and (scales.dim() == 2) and (zeros.dim() == 2) + M, K = x.shape + assert K == qweight.shape[0] * 8 + N = qweight.shape[1] + assert scales.shape == (K // scales.shape[0], N) + assert zeros.shape == (K // zeros.shape[0], N) + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), split_k) + matmul_kernel[grid](x, qweight, output, scales, zeros, M, N, K, x.stride(0), x.stride(1), qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), scales.shape[0]) + return output + +def quantize_int4(w: torch.Tensor, group_size: int=128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + w: [OC, K] (fp16/fp32 weights) + returns: (qpacked, scales, zerospacked, group_size) + qpacked : int32, shape [OC, K//8] + scales : fp16/fp32 [OC, K//group_size] + zerospacked : int32 [K//group_size, OC//8] + """ + assert w.dim() == 2 + OC, K = w.shape + assert K % group_size == 0 + w = w.view(-1, K) + OC_total, K_ = w.shape + groups_per_row = K_ // group_size + w = w.view(OC_total, groups_per_row, group_size) + w_min = w.min(dim=2, keepdim=True).values + w_max = w.max(dim=2, keepdim=True).values + scales = (w_max - w_min) / 15.0 + zeros = torch.round(-w_min / scales).clamp(0, 15).to(torch.int32) + quant = torch.clamp(torch.round(w / scales + zeros), 0, 15).to(torch.int32) + OC_pack = OC_total // 8 + zeros_pack = zeros.permute(1, 0, 2).contiguous().view(groups_per_row, OC_pack, 8) + zeros_packed = zeros_pack[..., 0] + for p in range(1, 8): + zeros_packed |= zeros_pack[..., p] << p * 4 + zeros_packed = zeros_packed.view(groups_per_row, OC_pack).contiguous() + K_pack = K_ // 8 + quant = quant.view(OC_total, K_) + qpacked = torch.empty((OC_total, K_pack), dtype=torch.int32, device=w.device) + for p in range(8): + qpacked |= quant[:, p::8] << p * 4 + qpacked = qpacked.contiguous() + scales = scales.view(OC_total, groups_per_row).contiguous() + return (qpacked, scales, zeros_packed, group_size) + +def unpack_int4(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int=128) -> torch.Tensor: + """ + qweight : int32 [OC, K//8] + scales : [OC, K//group_size] + zeros : int32 [K//group_size, OC//8] + returns : fp16/fp32 tensor [OC, K] + """ + OC, K8 = qweight.shape + K = K8 * 8 + group_dim = K // group_size + assert scales.shape == (OC, group_dim) + assert zeros.shape == (group_dim, OC // 8) + quant = torch.empty((OC, K), dtype=torch.int32, device=qweight.device) + for p in range(8): + mask = 15 << p * 4 + quant[:, p::8] = (qweight & mask) >> p * 4 + OC8 = OC // 8 + zeros_ext = torch.empty((group_dim, OC), dtype=torch.int32, device=zeros.device) + for p in range(8): + mask = 15 << p * 4 + zeros_ext[:, p::8] = (zeros & mask) >> p * 4 + zeros_ext = zeros_ext.permute(1, 0).contiguous() + scales_mat = scales.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K) + zeros_mat = zeros_ext.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K) + return (quant.float() - zeros_mat) * scales_mat + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py new file mode 100644 index 0000000..5dbe6fd --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py @@ -0,0 +1,186 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _pack_int4x2(x0, x1): + x0 = x0 & 15 + x1 = x1 & 15 + return x1 | x0 << 4 + +@triton.jit +def _unpack_int4x2(b): + low = b >> 0 & 15 + high = b >> 4 & 15 + return (low, high) + +@triton.jit +def matmul_kernel(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K2: tl.constexpr, SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_k = tl.cdiv(K, BLOCK_K2 * 2) + grid_m = pid // num_pid_n + grid_n = pid % num_pid_n + rm = grid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = grid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk_packed = pid_z * BLOCK_K2 + tl.arange(0, BLOCK_K2) + rk = rk_packed * 2 + a_ptrs = a_ptr + rm[:, None] * stride_a_m + rk[None, :] * stride_a_k + a_mask = (rm[:, None] < M) & (rk[None, :] < K) + A = tl.load(a_ptrs, mask=a_mask, other=0.0) + q_ptrs = q_ptr + (rk_packed[:, None] * stride_q_k2 + rn[None, :] * stride_q_n) + packed = tl.load(q_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0) + lo, hi = _unpack_int4x2(packed) + q_vals = tl.interleave(lo, hi) + group_idx = rk[:, None] // group_size + s_ptrs = scales_ptr + group_idx * stride_s_g + rn[None, :] * stride_s_n + z_ptrs = zeros_ptr + group_idx * stride_z_g + rn[None, :] * stride_z_n + scales = tl.load(s_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + zeros = tl.load(z_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + b_fp = (q_vals - zeros) * scales + acc = tl.dot(A, b_fp) + c_offs = c_ptr + rm[:, None] * stride_c_m + rn[None, :] * stride_c_n + c_mask = (rm < M)[:, None] & (rn < N)[None, :] + if SPLIT_K > 1: + tl.atomic_add(c_offs, acc, mask=c_mask) + else: + tl.store(c_offs, acc, mask=c_mask) + +@triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K2': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K2': 32, 'SPLIT_K': 2}, num_stages=4, num_warps=8)], key=['M', 'N', 'K']) +@triton.jit +def matmul_dequantize_int4_s2(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K2: tl.constexpr, SPLIT_K: tl.constexpr): + matmul_kernel(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M, BLOCK_N, BLOCK_K2, SPLIT_K) + +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int=128) -> torch.Tensor: + """ + Python-launchable GEMM with INT4 quantized weights. + + Memory layout expected: + qweight – [ K//2 , N ] int32 column-major + scales – [ G , N ] float column-major, G = K // group_size + qzeros – same shape as scales + x – [ M , K ] fp16/fp32 row-major + Output: + c – [ M , N ] fp32 + """ + x = x.contiguous() + qweight = qweight.contiguous() + scales = scales.contiguous() + qzeros = qzeros.contiguous() + M, K = x.shape + N = qweight.size(1) + device = x.device + out = torch.empty((M, N), dtype=torch.float32, device=device) + + def grid(meta): + grid_m = triton.cdiv(M, meta['BLOCK_M']) + grid_n = triton.cdiv(N, meta['BLOCK_N']) + grid_z = meta['SPLIT_K'] + return (grid_m * grid_n, grid_z) + matmul_dequantize_int4_s2[grid](qweight, x, out, scales, qzeros, M, N, K, x.stride(0), x.stride(1), qweight.stride(0), qweight.stride(1), scales.stride(0), scales.stride(1), qzeros.stride(0), qzeros.stride(1), out.stride(0), out.stride(1), group_size) + return out + +def quantize_int4(x: torch.Tensor, group_size: int=128, transpose: bool=True): + """ + Quantise weight matrix (row-vector row-major) into INT4. + + Returns tensors that are column-major (as expected by Triton kernel). + + Args + ---- + x : [K_orig, N] float + Returns + ------- + packed : [ K_orig//2 , N ] int32 column-major + scales : [ G , N ] float column-major, G = K_orig//group_size + zeros : [ G , N ] float column-major + """ + K_orig, N = x.shape + assert K_orig % group_size == 0 + G = K_orig // group_size + x = x.view(G, group_size, N) + x_min = x.min(dim=1, keepdim=True)[0] + x_max = x.max(dim=1, keepdim=True)[0] + scales = (x_max - x_min) / 15.0 + zeros = torch.round(-x_min / scales).clamp(0, 15) + q = torch.round(x / scales + zeros).clamp(0, 15).to(torch.uint8) + q = q.view(K_orig, N) + if transpose: + q = q.T.contiguous() + scales = scales.squeeze(1).T.contiguous() + zeros = zeros.squeeze(1).T.contiguous() + else: + scales = scales.squeeze(1).contiguous() + zeros = zeros.squeeze(1).contiguous() + packed = torch.zeros((N, K_orig // 2), dtype=torch.int32, device=x.device) + for k in range(0, K_orig, 2): + low = q[:, k].to(torch.int32) + high = q[:, k + 1].to(torch.int32) + packed[:, k // 2] = (high & 15) << 4 | low & 15 + return (packed.view(K_orig // 2, N).contiguous(), scales, zeros) + +def unpack_int4(packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int=128): + """ + De-quantize the output of quantize_int4 back to float for validation. + Assumes column-major layout (same as the kernel). + Returns tensor of shape [K, N] – float32, column-major + """ + packed = packed.contiguous() + scales = scales.contiguous() + zeros = zeros.contiguous() + K2, N = packed.shape + K = K2 * 2 + device = packed.device + unpacked = torch.zeros((K, N), dtype=torch.float32, device=device) + for col in range(N): + pack = packed[:, col].clone() + even = pack & 15 + odd = pack >> 4 & 15 + int_vec = torch.empty(K, dtype=torch.float32, device=device) + int_vec[0::2] = even.float() + int_vec[1::2] = odd.float() + group_idx = torch.arange(K, device=device) // group_size + s = scales[group_idx, col] + z = zeros[group_idx, col] + unpacked[:, col] = (int_vec - z) * s + return unpacked + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py new file mode 100644 index 0000000..8b39a49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + dx = torch.empty_like(x) + grid = (M,) + _l2_norm_bwd_kernel[grid](x, dy, dx, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py new file mode 100644 index 0000000..a863581 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py @@ -0,0 +1,114 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + x_square = tl.where(mask, x * x, 0.0) + var = tl.sum(x_square, axis=0) + rstd = 1.0 / tl.sqrt(var + eps) + dy_x = tl.where(mask, dy * x, 0.0) + sum_dy_x = tl.sum(dy_x, axis=0) + term = 1.0 / (var + eps) * rstd * sum_dy_x + dx = dy * rstd - term * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(x.shape[0], dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + _l2_norm_bwd_kernel[M,](x, dy, dx, x.stride(0), N, eps, BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py new file mode 100644 index 0000000..efc3d1a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py @@ -0,0 +1,98 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(data * data) + rstd = tl.rsqrt(var + eps) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + y = data * rstd + tl.store(Y + base + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-05): + x = x if x.is_contiguous() else x.contiguous() + original_shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + element_size = x.element_size() + BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N)) + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.view(original_shape) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py new file mode 100644 index 0000000..83eb1a5 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py @@ -0,0 +1,99 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for off in range(0, N, BLOCK_N): + mask = off + cols < N + x = tl.load(X + off + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_N): + mask = off + cols < N + x = tl.load(X + off + cols, mask=mask, other=0.0) + y = x * rstd + tl.store(Y + off + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-12): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]).contiguous() + y = torch.empty_like(x) + M, N = x.shape + element_size = x.element_size() + max_block_n = 65536 // element_size + BLOCK_N = min(max_block_n, triton.next_power_of_2(N)) + if N > max_block_n: + raise RuntimeError('This L2 norm only supports feature dim < 64KB.') + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py new file mode 100644 index 0000000..4750b49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M: tl.constexpr, D_HEAD: tl.constexpr): + size_m_arange = tl.arange(0, SIZE_M) + d_head_arange = tl.arange(0, D_HEAD) + rows = size_m_arange[:, None] + cols = d_head_arange[None, :] + matrix_ptr = M + rows * matrix_stridex + cols * matrix_stridey + matrix_value = tl.load(matrix_ptr) + out_rows = d_head_arange[:, None] + out_cols = size_m_arange[None, :] + out_ptr = Out + out_rows * out_stridex + out_cols * out_stridey + tl.store(out_ptr, tl.trans(matrix_value)) + +def wrapper(SIZE_M: int, D_HEAD: int): + device = 'cuda' + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + grid = (1,) + kernel[grid](matrix, out, matrix.stride(1), matrix.stride(0), out.stride(1), out.stride(0), SIZE_M=SIZE_M, D_HEAD=D_HEAD) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout new file mode 100644 index 0000000..15547cc --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py new file mode 100644 index 0000000..b993c3d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M: tl.constexpr, D_HEAD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + mask = mask_m[:, None] & mask_n[None, :] + src_ptrs = M + offs_m[:, None] * matrix_stridey + offs_n[None, :] * matrix_stridex + out_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + original_data = tl.load(src_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, tl.trans(original_data), mask=tl.trans(mask)) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + BLOCK_M = 64 + BLOCK_N = 64 + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_N)) + kernel[grid](matrix, out, matrix.stride(1), matrix.stride(0), out.stride(1), out.stride(0), SIZE_M=SIZE_M, D_HEAD=D_HEAD, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout new file mode 100644 index 0000000..09a97cf --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_602994.py \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py new file mode 100644 index 0000000..0f93fd4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr, stride_bk: tl.constexpr, stride_cm: tl.constexpr, N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + start_n = pid_n * BLOCK_N + offsets_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offsets_n < N + accum = tl.zeros((BLOCK_N,), dtype=tl.float32) + for start_m in range(0, M, BLOCK_M): + offsets_m = start_m + tl.arange(0, BLOCK_M) + mask_m = offsets_m < M + offs_a = A + (offsets_n[:, None] * stride_am + offsets_m[None, :] * stride_ak) + tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0) + offs_b = B + offsets_m * stride_bk + seg_b = tl.load(offs_b, mask=mask_m, other=0.0) + accum += tl.sum(tile_a * seg_b[None, :], axis=1) + offs_c = C + offsets_n * stride_cm + tl.store(offs_c, accum, mask=mask_n) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dtype == B.dtype + assert A.size(1) == B.size(0) + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + BLOCK_N = 128 + BLOCK_M = 64 + grid = (triton.cdiv(N, BLOCK_N),) + mv_kernel[grid](A, B, C, stride_am=A.stride(0), stride_ak=A.stride(1), stride_bk=B.stride(0), stride_cm=C.stride(0), N=N, M=M, BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py new file mode 100644 index 0000000..74e1d21 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py @@ -0,0 +1,71 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + acc = tl.zeros((BLOCK_M,), dtype=tl.float32) + for start_n in range(0, N, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_a = (offs_m[:, None] < M) & (offs_n[None, :] < N) + mask_b = offs_n < N + a_blk = tl.load(A + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an, mask=mask_a, other=0.0) + b_vec = tl.load(B + offs_n * stride_b, mask=mask_b, other=0.0) + product = a_blk * b_vec[None, :] + acc += tl.sum(product, axis=1) + offs_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_out = offs_out < M + tl.store(C + offs_out * stride_c, acc.to(C.dtype.element_ty), mask=mask_out) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1 + M, N = A.shape + assert B.size(0) == N + C = torch.empty(M, dtype=A.dtype, device=A.device) + BLOCK_M = 64 + BLOCK_N = 64 + grid = (triton.cdiv(M, BLOCK_M),) + mv_kernel[grid](A, B, C, M, N, A.stride(0), A.stride(1), B.stride(0), C.stride(0), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py new file mode 100644 index 0000000..6de6c16 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py @@ -0,0 +1,228 @@ +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + x_batch_stride = stride_x_batch + ox_batch_stride = stride_out_batch + offset_b = pid_batch + actual_seqlen = seqlen + else: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + actual_seqlen = seq_end - seq_start + x_batch_stride = stride_x_seqlen + ox_batch_stride = stride_out_seqlen + offset_b = seq_start + if pid_m * BLOCK_M >= actual_seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < actual_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + off = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + off = SEQLEN_OFFSETS + rm_cos = rm + off + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + X_ptr = X + offset_b * x_batch_stride + pid_head * stride_x_nheads + OUT_ptr = OUT + offset_b * ox_batch_stride + pid_head * stride_out_nheads + if not INTERLEAVED: + x0_ptr = X_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim + cos_ptr = COS + rm_cos[:, None] * rotary_dim_half + rk_half[None, :] + sin_ptr = SIN + rm_cos[:, None] * rotary_dim_half + rk_half[None, :] + cos = tl.load(cos_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & mask_k_half[None, :], other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, o0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, o1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + BLOCK_P = BLOCK_K + rk = tl.arange(0, BLOCK_P) + mask_k = rk < rotary_dim + cos_sin_idx = rk // 2 + cos_sin_mask = cos_sin_idx < rotary_dim_half + cos_ptr = COS + rm_cos[:, None] * rotary_dim_half + cos_sin_idx[None, :] + sin_ptr = SIN + rm_cos[:, None] * rotary_dim_half + cos_sin_idx[None, :] + cos_val = tl.load(cos_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & cos_sin_mask[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & cos_sin_mask[None, :], other=0.0).to(tl.float32) + x_even_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (2 * (rk // 2))[None, :] * stride_x_headdim + x_odd_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (2 * (rk // 2) + 1)[None, :] * stride_x_headdim + x_even = tl.load(x_even_ptr, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x_odd = tl.load(x_odd_ptr, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + is_even = rk % 2 == 0 + rot_even = x_even * cos_val - x_odd * sin_val + rot_odd = x_even * sin_val + x_odd * cos_val + final = tl.where(is_even[None, :], rot_even, rot_odd) + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, final, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'max_seqlen must be provided when cu_seqlens is given.' + assert x.dim() == 3, 'x must be 3-D for variable-length case (total_seqlen, nheads, headdim)' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype, f'Mismatched dtypes cos={cos.dtype}, sin={sin.dtype}' + assert x.dtype == cos.dtype, f'Mismatched dtypes x={x.dtype}, cos={cos.dtype}' + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in {torch.int32, torch.int64} + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + BLOCK_K = max(32, min(256, triton.next_power_of_2(rotary_dim_half))) + BLOCK_M = 4 if interleaved else 8 if rotary_dim <= 64 else 4 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + if not is_varlen: + stride_x_b = x.stride(0) + stride_x_s = x.stride(1) + stride_x_n = x.stride(2) + stride_x_h = x.stride(3) + stride_o_b = output.stride(0) + stride_o_s = output.stride(1) + stride_o_n = output.stride(2) + stride_o_h = output.stride(3) + else: + stride_x_b = 0 + stride_x_s = x.stride(0) + stride_x_n = x.stride(1) + stride_x_h = x.stride(2) + stride_o_b = 0 + stride_o_s = output.stride(0) + stride_o_n = output.stride(1) + stride_o_h = output.stride(2) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, 0, stride_o_b, stride_o_s, stride_o_n, stride_o_h, stride_x_b, stride_x_s, stride_x_n, stride_x_h, BLOCK_K=BLOCK_K, IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), IS_VARLEN=is_varlen, INTERLEAVED=interleaved, CONJUGATE=conjugate, BLOCK_M=BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py new file mode 100644 index 0000000..7212c74 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py @@ -0,0 +1,207 @@ +import torch +import triton +import triton.language as tl +import logging +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + end_idx = tl.load(CU_SEQLENS + pid_batch + 1) + cur_seqlen = end_idx - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < cur_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + mask_m_cs = rm_cs < seqlen_ro + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_offs = x0_offs + rotary_dim_half * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_offs = out0_offs + rotary_dim_half * stride_out_headdim + tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + rk_repeat = rk // 2 + cs_mask = rk_repeat[None, :] < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32) + rk_swap = rk + (rk + 1) % 2 * 2 - 1 + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'If cu_seqlens is passed in, then max_seqlen must be passed' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + assert rotary_dim <= headdim + rotary_dim = rotary_dim * 2 + assert headdim <= 256, 'Only support headdim <= 256' + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + assert cos.dtype == sin.dtype + assert x.dtype == cos.dtype + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + if interleaved: + BLOCK_M = 4 + else: + BLOCK_M = 8 if rotary_dim <= 64 else 4 + BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py new file mode 100644 index 0000000..786f954 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py @@ -0,0 +1,87 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py new file mode 100644 index 0000000..7836ee2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor, output: torch.Tensor=None, BLOCK_SIZE: int=128): + n_elements = x.numel() + assert x.is_contiguous() + if output is None: + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py new file mode 100644 index 0000000..a59f8d1 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K) + b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2, 'Input a must be 2-D' + assert b.dim() == 2, 'Input b must be 2-D' + assert a.shape[1] == b.shape[0], 'Incompatible dimensions' + M, K = a.shape + K2, N = b.shape + assert K == K2, 'Inner dimensions must match' + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1)) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py new file mode 100644 index 0000000..e999a3f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py @@ -0,0 +1,124 @@ +import torch +import triton +import triton.language as tl + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == 'cuda' + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == 'hip' + +def get_cuda_autotune_config(): + return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)] + +def get_hip_autotune_config(): + return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4)] + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + +@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr=''): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % num_pid_in_group % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + if ACTIVATION == 'leaky_relu': + accumulator = leaky_relu(accumulator) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + +def matmul(a: torch.Tensor, b: torch.Tensor, activation: str=''): + assert a.dim() == 2, f'Expected 2-D matrix for a, got shape {a.shape}' + assert b.dim() == 2, f'Expected 2-D matrix for b, got shape {b.shape}' + assert a.shape[1] == b.shape[0], f'Incompatible dimensions: a={a.shape}, b={b.shape}' + if a.dtype not in (torch.float16, torch.bfloat16): + raise ValueError('Only float16/bfloat inputs supported.') + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), ACTIVATION=activation) + return c +if __name__ == '__main__': + torch.manual_seed(0) + device = 'cuda' + dtype = torch.float16 + a = torch.randn((512, 512), device=device, dtype=dtype) + b = torch.randn((512, 512), device=device, dtype=dtype) + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f'triton_output={triton_output}') + print(f'torch_output={torch_output}') + if torch.allclose(triton_output, torch_output, atol=0.01, rtol=0): + print('✅ Triton and Torch match') + else: + print('❌ Triton and Torch differ') + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/src/__pycache__/args_config.cpython-312.pyc b/src/__pycache__/args_config.cpython-312.pyc deleted file mode 100644 index ed62ea9..0000000 Binary files a/src/__pycache__/args_config.cpython-312.pyc and /dev/null differ diff --git a/src/agents/__pycache__/Base.cpython-312.pyc b/src/agents/__pycache__/Base.cpython-312.pyc deleted file mode 100644 index f0a272d..0000000 Binary files a/src/agents/__pycache__/Base.cpython-312.pyc and /dev/null differ diff --git a/src/agents/__pycache__/Reflexion.cpython-312.pyc b/src/agents/__pycache__/Reflexion.cpython-312.pyc deleted file mode 100644 index 54cab4b..0000000 Binary files a/src/agents/__pycache__/Reflexion.cpython-312.pyc and /dev/null differ diff --git a/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc deleted file mode 100644 index db3379e..0000000 Binary files a/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc and /dev/null differ diff --git a/src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc deleted file mode 100644 index 72bcb61..0000000 Binary files a/src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc and /dev/null differ diff --git a/src/agents/reflexion_oneshot.py b/src/agents/reflexion_oneshot.py index d10a345..02beb52 100644 --- a/src/agents/reflexion_oneshot.py +++ b/src/agents/reflexion_oneshot.py @@ -40,7 +40,9 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", "reflection", "function_signatures", "oneshot", - "pass_call", + "pass_call", + "pass_exe", + "code" ]): pass @@ -63,6 +65,8 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", function_signatures=fs_mem, oneshot=os_mem["code"], pass_call=False, + pass_exe=False, + code=os_mem["code"] ) else: input_mem = input_mems[ps.filename] @@ -72,6 +76,8 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", function_signatures=fs_mem, oneshot=input_mem["oneshot"], pass_call=input_mem["pass_call"], + pass_exe=input_mem["pass_exe"], + code=input_mem["code"] ) self.memories.append(tmp_mem) @@ -103,26 +109,35 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, """ Run the scripts to verify whether the generated kernels can execute without errors. To check for correctness against expected outputs, use the test_opt_correctness method from TritonBench: + """ + if output_path is not None: + root, extension = os.path.splitext(output_path) + tmp_dir = f"{root}_tmp_{iter}" + exe_dir = f"{root}_pass_exe_{iter}" + perf_result_dir = f"{root}_perf_results_{iter}" + perf_log_dir = f"{root}_perf_logs_{iter}" - if self.config.agent.output_path is not None: - root, extension = os.path.splitext(self.config.agent.output_path) - tmp_dir = f"{root}_tmp_{n}" - exe_dir = f"{root}_pass_exe_{n}" - perf_result_dir = f"{root}_perf_results_{n}" - perf_log_dir = f"{root}_perf_logs_{n}" + else: + tmp_dir = f"tmp_{iter}" + exe_dir = f"pass_exe_{iter}" + perf_result_dir = f"perf_results_{iter}" + perf_log_dir = f"perf_logs_{iter}" + for mem in tqdm(self.memories[:data_len]): + if mem.pass_exe: + continue + try: + pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.code, mem.ps.filename, tmp_dir, exe_dir=exe_dir) + except Exception as e: + logger.info(f"failed to test the code due to : {e}") + mem.err_msg = f"failed to test the code due to: {e}" + continue + if not pass_call: + mem.err_msg = call_stderr + elif not pass_exe: + mem.err_msg = exe_stderr else: - tmp_dir = f"tmp_{n}" - exe_dir = f"pass_exe_{n}" - perf_result_dir = f"perf_results_{n}" - perf_log_dir = f"perf_logs_{n}" - - for fn, mems in tqdm(current_memories.items()): - mem = mems[n] - try: - pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.code, mem.ps.filename, tmp_dir, exe_dir=exe_dir) - - """ + mem.pass_exe = True logger.info(f"\nrun scripts on gpu") for mem in tqdm(self.memories[:data_len]): if mem.pass_call: @@ -132,9 +147,9 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, mem.err_msg = err_msg """ To measure kernel latency, follow these steps: - - self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=script_dir) - self.dataset.run_perf_scripts(gpu_id=gpu_id, script_dir=script_dir, log_dir=perf_log_dir) + """ + self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=tmp_dir) + self.dataset.run_perf_scripts(script_dir=tmp_dir, log_dir=perf_log_dir) for mem in self.memories[:data_len]: path_gen = os.path.join(perf_result_dir, mem.ps.filename[:-3] + ".json") @@ -150,8 +165,6 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, except Exception as e: print(f"{mem.ps.filename} failed due to {e}") - """ - # generate reflections logger.info(f"\ngenerate reflections") with tqdm(total=data_len) as pbar: @@ -220,4 +233,17 @@ def generate_reflexion(self, mem, temperature): "content": reflect_txt } ] - mem.reflection = self.model.generate(reflect_msg, temperature=temperature) \ No newline at end of file + mem.reflection = self.model.generate(reflect_msg, temperature=temperature) + + def save_memory(self, save_path="memory", datalen=None): + data_len = datalen if datalen else len(self.dataset) + for i, mem in enumerate(self.memories[:data_len]): + mem_dict = {} + mem_dict["err_msg"] = mem.err_msg + mem_dict["reflection"] = mem.reflection + mem_dict["function_signatures"] = mem.function_signatures + mem_dict["oneshot"] = mem.oneshot + mem_dict["pass_call"] = mem.pass_call + mem_dict["pass_exe"] = mem.pass_exe + mem_dict["code"] = mem.code + json.dump(f"{save_path}_file_{i}".json) diff --git a/src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc b/src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc deleted file mode 100644 index ef74535..0000000 Binary files a/src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc and /dev/null differ diff --git a/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc b/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc deleted file mode 100644 index 041258d..0000000 Binary files a/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc and /dev/null differ diff --git a/src/dataloaders/__pycache__/ROCm.cpython-312.pyc b/src/dataloaders/__pycache__/ROCm.cpython-312.pyc deleted file mode 100644 index b207f84..0000000 Binary files a/src/dataloaders/__pycache__/ROCm.cpython-312.pyc and /dev/null differ diff --git a/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc b/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc deleted file mode 100644 index ad4c954..0000000 Binary files a/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc and /dev/null differ diff --git a/src/main_reflexion_oneshot.py b/src/main_reflexion_oneshot.py index b5d4f9f..ac7fc1f 100644 --- a/src/main_reflexion_oneshot.py +++ b/src/main_reflexion_oneshot.py @@ -30,6 +30,8 @@ def main(): # run the agent agent.run(output_path=args.output_path, multi_thread=args.multi_thread, iteration_num=args.max_iteration, temperature=args.temperature, datalen=None) + # save memory for further fine-tune + # agent.save_memory(save_path="memory") if __name__ == "__main__": main() diff --git a/src/memories/__pycache__/Memory.cpython-312.pyc b/src/memories/__pycache__/Memory.cpython-312.pyc deleted file mode 100644 index 09f82b9..0000000 Binary files a/src/memories/__pycache__/Memory.cpython-312.pyc and /dev/null differ diff --git a/src/models/__pycache__/Base.cpython-312.pyc b/src/models/__pycache__/Base.cpython-312.pyc deleted file mode 100644 index 0ad2c73..0000000 Binary files a/src/models/__pycache__/Base.cpython-312.pyc and /dev/null differ diff --git a/src/models/__pycache__/KimiK2.cpython-312.pyc b/src/models/__pycache__/KimiK2.cpython-312.pyc deleted file mode 100644 index 34c5a67..0000000 Binary files a/src/models/__pycache__/KimiK2.cpython-312.pyc and /dev/null differ diff --git a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc deleted file mode 100644 index 29e23cc..0000000 Binary files a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc and /dev/null differ diff --git a/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc deleted file mode 100644 index 71813ca..0000000 Binary files a/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc and /dev/null differ diff --git a/src/prompts/prompt_for_generation.py b/src/prompts/prompt_for_generation.py index 6d07461..5c5f637 100644 --- a/src/prompts/prompt_for_generation.py +++ b/src/prompts/prompt_for_generation.py @@ -4,6 +4,15 @@ Your task is to generate a Python code snippet containing a Triton kernel based on the following request: **Target Platform:** AMD GPU (ROCm) +Here are basic AMD GPU hardware information: +1. Graphic architecture: gfx942 (CDNA 3) +2. Computer Unit Number: 304 +3. SIMD number per CU: 4 +4. Wave Front size: 64 +5. Avaliable GPU Device: 1 +6. Vector General Purpose Register (VGPR) size: 256 KB +7. Scalar General Purpose Register (SGPR) size : 64 KB +8. Local Data Share (LDS) size: 64 KB **Request:** {instruction} @@ -89,15 +98,15 @@ * **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`. 8. **Triton Version:** Assume Triton version 3.2.0 or later. 9. Maximize performance by exploring the following: -i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. -ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. +i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. +ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. Primary Autotuning Fields (Mandatory) 1. BLOCK_M, BLOCK_N, BLOCK_K * Tile sizes for GEMM or other tensor contractions. * Larger blocks improve compute density, but reduce grid-level parallelism. * Explore wide range of values like: - * BLOCK: [32, ..., 128, ..., 2048, ...] + * BLOCK: [32, ..., 128, ..., 2048, ...] * Adjust based on memory reuse and L2 cache locality. 2. num_stages=n * Controls pipeline depth for kernel execution. @@ -111,7 +120,7 @@ * If it is too low then underutilization -> kernel runs slow. * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. * You must choose a sweet spot by trying out integer range of 1 to 16. - * You MUST NOT try the range beyond 16, it is NOT VALID. + * You MUST NOT try the range beyond 16, it is NOT VALID. Examples of Autotuning Setup Here's how Triton kernels should be decorated to allow autotuning: * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. diff --git a/src/prompts/prompt_for_reflection.py b/src/prompts/prompt_for_reflection.py index fe3f936..d81f815 100644 --- a/src/prompts/prompt_for_reflection.py +++ b/src/prompts/prompt_for_reflection.py @@ -20,6 +20,57 @@ - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" +**Target Platform:** AMD GPU (ROCm) +Here are basic AMD GPU hardware information: +1. Graphic architecture: gfx942 (CDNA 3) +2. Computer Unit Number: 304 +3. SIMD number per CU: 4 +4. Wave Front size: 64 +5. Avaliable GPU Device: 1 +6. Vector General Purpose Register (VGPR) size: 256 KB +7. Scalar General Purpose Register (SGPR) size : 64 KB +8. Local Data Share (LDS) size: 64 KB + +Maximize performance by exploring the following: +i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. +ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. +iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. +Primary Autotuning Fields (Mandatory) +1. BLOCK_M, BLOCK_N, BLOCK_K + * Tile sizes for GEMM or other tensor contractions. + * Larger blocks improve compute density, but reduce grid-level parallelism. + * Explore wide range of values like: + * BLOCK: [32, ..., 128, ..., 2048, ...] + * Adjust based on memory reuse and L2 cache locality. +2. num_stages=n + * Controls pipeline depth for kernel execution. + * Rules for setting this: + * 1 if no GEMM. + * 2 if a single GEMM (e.g., GEMM + ReLU). + * 1 if two GEMMs are fused (e.g., Flash Attention). + * Optimize for latency and execution overlap. +3. num_warps + * Controls number of warps (groups of 64 threads) to launch per block. + * If it is too low then underutilization -> kernel runs slow. + * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. + * You must choose a sweet spot by trying out integer range of 1 to 16. + * You MUST NOT try the range beyond 16, it is NOT VALID. +Examples of Autotuning Setup +Here's how Triton kernels should be decorated to allow autotuning: + * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. + * BLOCK_M refers to the chunk of variable M that will be used for compute by a thread at a time. + * You must ensure that variables used in the triton.Config should not be passed as arguments to the triton kernel. +For example: the following autotune config receives BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, and num_stages as input arguments. Hence the triton kernel must not receive these arguments as inputs in the wrapper function. You must comment/delete any such instances. + +NOTE: If you face kernel timeout issues, check if Grid and Program ID Mismatch exists or not for example The kernel is launched with a 1-dimensional (1D) grid, but inside the kernel, it attempts to read program IDs from a 2-dimensional (2D) grid etc. + +def grid(args: dict[str, Any]) -> tuple[int]: + # This creates a 1D grid of size (C * D, ) + return (triton.cdiv(M, args["BLOCK_SIZE_M"]) * triton.cdiv(N, args["BLOCK_SIZE_N"]), ) + +The grid is calculated as a single integer, creating a 1D grid, however the kernel might try to get two separate program IDs, pid_m and pid_n, as if it were a 2D grid: +pid_m = tl.program_id(0) # Gets the ID for the first dimension +pid_n = tl.program_id(1) # Tries to get ID for a non-existent second dimension """ prompt_exe = """ diff --git a/src/retrievers/__pycache__/retriever.cpython-312.pyc b/src/retrievers/__pycache__/retriever.cpython-312.pyc deleted file mode 100644 index dc463e3..0000000 Binary files a/src/retrievers/__pycache__/retriever.cpython-312.pyc and /dev/null differ diff --git a/src/utils/__pycache__/utils.cpython-312.pyc b/src/utils/__pycache__/utils.cpython-312.pyc deleted file mode 100644 index 5240a44..0000000 Binary files a/src/utils/__pycache__/utils.cpython-312.pyc and /dev/null differ