From 89f956362e850cbe9e17ecc8aa4ab07606029642 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sat, 23 Aug 2025 16:44:25 +0000 Subject: [PATCH 1/6] Update test results with call status and lantecy info for agent code gen --- out.json_perf_0.json | 16 + reflexion_oneshot_tritonbench_4.json | 1 + .../.out.jsonpassk.txt.swp | Bin 0 -> 4096 bytes .../exec/flash_decode2_phi.py | 146 ++++++ .../exec/gen_perf/efficiency.json | 16 + .../exec/gen_perf/flash_decode2_phi.json | 418 ++++++++++++++++ .../gen_perf/flash_decode2_phi_perf_data.json | 6 + .../exec/gen_perf/l2_norm_bwd.json | 145 ++++++ .../exec/gen_perf/l2_norm_bwd_perf_data.json | 6 + .../exec/gen_perf/l2_norm_triton1.json | 1 + .../gen_perf/l2_norm_triton1_perf_data.json | 6 + .../logs/flash_decode2_phi_perf.py.err | 0 .../logs/flash_decode2_phi_perf.py.log | 16 + .../gen_perf/logs/l2_norm_bwd_perf.py.err | 0 .../gen_perf/logs/l2_norm_bwd_perf.py.log | 11 + .../gen_perf/logs/l2_norm_triton1_perf.py.err | 0 .../gen_perf/logs/l2_norm_triton1_perf.py.log | 11 + .../logs/matrix_vector_multip_perf.py.err | 0 .../logs/matrix_vector_multip_perf.py.log | 18 + .../gen_perf/logs/performance_utils.py.err | 0 .../gen_perf/logs/performance_utils.py.log | 0 .../logs/rotary_transform_perf.py.err | 0 .../logs/rotary_transform_perf.py.log | 14 + .../exec/gen_perf/logs/sin_kernel_perf.py.err | 0 .../exec/gen_perf/logs/sin_kernel_perf.py.log | 16 + .../gen_perf/logs/triton_matmul_perf.py.err | 0 .../gen_perf/logs/triton_matmul_perf.py.log | 31 ++ .../exec/gen_perf/matrix_vector_multip.json | 254 ++++++++++ .../matrix_vector_multip_perf_data.json | 6 + .../exec/gen_perf/rotary_transform.json | 296 +++++++++++ .../gen_perf/rotary_transform_perf_data.json | 6 + .../exec/gen_perf/sin_kernel.json | 162 ++++++ .../exec/gen_perf/sin_kernel_perf_data.json | 6 + .../gen_perf/tmp/flash_decode2_phi_perf.py | 59 +++ .../exec/gen_perf/tmp/l2_norm_bwd_perf.py | 51 ++ .../exec/gen_perf/tmp/l2_norm_triton1_perf.py | 75 +++ .../gen_perf/tmp/matrix_vector_multip_perf.py | 52 ++ .../exec/gen_perf/tmp/performance_utils.py | 145 ++++++ .../gen_perf/tmp/rotary_transform_perf.py | 56 +++ .../exec/gen_perf/tmp/sin_kernel_perf.py | 75 +++ .../exec/gen_perf/tmp/triton_matmul_perf.py | 54 ++ .../exec/gen_perf/triton_matmul.json | 467 ++++++++++++++++++ .../gen_perf/triton_matmul_perf_data.json | 6 + .../exec/l2_norm_bwd.py | 110 +++++ .../exec/l2_norm_triton1.py | 98 ++++ .../exec/matrix_vector_multip.py | 72 +++ .../exec/performance_analysis.txt | 3 + .../exec/rotary_transform.py | 207 ++++++++ .../exec/sin_kernel.py | 87 ++++ .../exec/triton_matmul.py | 88 ++++ reflexion_oneshot_tritonbench_4/out.json | 42 ++ .../out.json_all_passes.json | 92 ++++ .../out.json_perf_0.json | 16 + .../out.json_results_0.json | 92 ++++ .../out.jsonpassk.txt | 4 + ...triton_kernel.py_gen_triton_code_296397.py | 158 ++++++ ..._triton_kernel.py_gen_triton_code_97456.py | 163 ++++++ ...h_decode2_phi.py_gen_triton_code_473638.py | 146 ++++++ ...e2_phi.py_gen_triton_code_473638.py.stderr | 0 ...e2_phi.py_gen_triton_code_473638.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_577640.py | 157 ++++++ .../int4_matmul.py_gen_triton_code_126776.py | 160 ++++++ .../int4_matmul.py_gen_triton_code_506673.py | 186 +++++++ .../l2_norm_bwd.py_gen_triton_code_240873.py | 110 +++++ ...rm_bwd.py_gen_triton_code_240873.py.stderr | 0 ...rm_bwd.py_gen_triton_code_240873.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_419958.py | 114 +++++ ...rm_bwd.py_gen_triton_code_419958.py.stderr | 0 ...rm_bwd.py_gen_triton_code_419958.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_414289.py | 98 ++++ ...riton1.py_gen_triton_code_414289.py.stderr | 0 ...riton1.py_gen_triton_code_414289.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_485284.py | 99 ++++ ...riton1.py_gen_triton_code_485284.py.stderr | 0 ...riton1.py_gen_triton_code_485284.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_179599.py | 72 +++ ...nspose.py_gen_triton_code_179599.py.stderr | 0 ...nspose.py_gen_triton_code_179599.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_602994.py | 74 +++ ...nspose.py_gen_triton_code_602994.py.stderr | 0 ...nspose.py_gen_triton_code_602994.py.stdout | 1 + ...vector_multip.py_gen_triton_code_432557.py | 72 +++ ...multip.py_gen_triton_code_432557.py.stderr | 0 ...multip.py_gen_triton_code_432557.py.stdout | 1 + ...vector_multip.py_gen_triton_code_450542.py | 71 +++ ...multip.py_gen_triton_code_450542.py.stderr | 0 ...multip.py_gen_triton_code_450542.py.stdout | 1 + ...ary_transform.py_gen_triton_code_571970.py | 228 +++++++++ ...tary_transform.py_gen_triton_code_82100.py | 207 ++++++++ ...ansform.py_gen_triton_code_82100.py.stderr | 0 ...ansform.py_gen_triton_code_82100.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_186414.py | 87 ++++ ...kernel.py_gen_triton_code_186414.py.stderr | 0 ...kernel.py_gen_triton_code_186414.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_412027.py | 88 ++++ ...kernel.py_gen_triton_code_412027.py.stderr | 0 ...kernel.py_gen_triton_code_412027.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_700160.py | 88 ++++ ...matmul.py_gen_triton_code_700160.py.stderr | 0 ...matmul.py_gen_triton_code_700160.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_890291.py | 124 +++++ ...matmul.py_gen_triton_code_890291.py.stderr | 0 ...matmul.py_gen_triton_code_890291.py.stdout | 1 + src/agents/reflexion_oneshot.py | 64 ++- src/agents/reflexion_oneshot_ROCm.py | 10 +- 105 files changed, 6119 insertions(+), 30 deletions(-) create mode 100644 out.json_perf_0.json create mode 100644 reflexion_oneshot_tritonbench_4.json create mode 100644 reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp create mode 100644 reflexion_oneshot_tritonbench_4/exec/flash_decode2_phi.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.err create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json create mode 100644 reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt create mode 100644 reflexion_oneshot_tritonbench_4/exec/rotary_transform.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/sin_kernel.py create mode 100644 reflexion_oneshot_tritonbench_4/exec/triton_matmul.py create mode 100644 reflexion_oneshot_tritonbench_4/out.json create mode 100644 reflexion_oneshot_tritonbench_4/out.json_all_passes.json create mode 100644 reflexion_oneshot_tritonbench_4/out.json_perf_0.json create mode 100644 reflexion_oneshot_tritonbench_4/out.json_results_0.json create mode 100644 reflexion_oneshot_tritonbench_4/out.jsonpassk.txt create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stderr create mode 100644 reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout diff --git a/out.json_perf_0.json b/out.json_perf_0.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/out.json_perf_0.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4.json b/reflexion_oneshot_tritonbench_4.json new file mode 100644 index 0000000..6ae2c2f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4.json @@ -0,0 +1 @@ +[{"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", "label": null, "file": "matrix_vector_multip.py", "test_code": "\n\n\ndef test_mv():\n # \u6d4b\u8bd5\u7528\u4f8b 2: 4x3 \u77e9\u9635\u4e0e 3x1 \u5411\u91cf\u76f8\u4e58\n A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda')\n B = torch.tensor([1.0, 2.0, 3.0], device='cuda')\n triton_result_2 = mv(A, B)\n\n # \u6d4b\u8bd5\u7528\u4f8b 3: 32x16 \u77e9\u9635\u4e0e 16x1 \u5411\u91cf\u76f8\u4e58\n A = torch.randn(32, 16, device='cuda')\n B = torch.randn(16, device='cuda')\n triton_result_3 = mv(A, B)\n\n return {\n \"test_case_2\": triton_result_2,\n \"test_case_3\": triton_result_3,\n }\n\nresult_gold = test_mv()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr,\n stride_bk: tl.constexpr, stride_cm: tl.constexpr,\n N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):\n pid_n = tl.program_id(0) # one program per N-tile\n start_n = pid_n * BLOCK_N\n offsets_n = start_n + tl.arange(0, BLOCK_N)\n mask_n = offsets_n < N\n\n # allocate register tile\n accum = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n for start_m in range(0, M, BLOCK_M):\n offsets_m = start_m + tl.arange(0, BLOCK_M)\n mask_m = offsets_m < M\n # A tile: [BLOCK_N, BLOCK_M]\n offs_a = A + (offsets_n[:, None] * stride_am +\n offsets_m[None, :] * stride_ak)\n tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0)\n # B segment: [BLOCK_M]\n offs_b = B + offsets_m * stride_bk\n seg_b = tl.load(offs_b, mask=mask_m, other=0.0)\n\n accum += tl.sum(tile_a * seg_b[None, :], axis=1)\n\n # write back\n offs_c = C + offsets_n * stride_cm\n tl.store(offs_c, accum, mask=mask_n)\n\ndef mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n assert A.dtype == B.dtype\n assert A.size(1) == B.size(0)\n N, M = A.shape\n C = torch.empty((N,), dtype=A.dtype, device=A.device)\n BLOCK_N = 128\n BLOCK_M = 64\n grid = (triton.cdiv(N, BLOCK_N),)\n mv_kernel[grid](\n A, B, C,\n stride_am=A.stride(0), stride_ak=A.stride(1),\n stride_bk=B.stride(0), stride_cm=C.stride(0),\n N=N, M=M,\n BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M\n )\n return C\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", "label": null, "file": "triton_matmul.py", "test_code": "\n\n\nimport torch\n\n# Test for matmul\ndef test_matmul():\n results = {}\n M, K, N = 256, 128, 256\n\n # Test case 1: torch.float16\n a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n b = torch.randn((K, N), dtype=torch.float16, device='cuda')\n c = matmul(a, b)\n results['test_case_1'] = c\n\n return results\n\n# Run all tests\nresult_gold = test_matmul()", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K)\n b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N)\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b = tl.load(b_ptrs, mask=b_mask, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c = accumulator.to(c_ptr.dtype.element_ty)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n assert a.dim() == 2, \"Input a must be 2-D\"\n assert b.dim() == 2, \"Input b must be 2-D\"\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K2, N = b.shape\n assert K == K2, \"Inner dimensions must match\"\n\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n\n matmul_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n )\n return c\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", "label": null, "file": "embedding_triton_kernel.py", "test_code": "\n\n\nimport torch\n\ndef test_embedding():\n # \u53c2\u6570\u5b9a\u4e49\n vocab_size = 1000 # \u8bcd\u6c47\u8868\u5927\u5c0f\n embedding_dim = 512 # \u5d4c\u5165\u7ef4\u5ea6\n sequence_length = 128 # \u8f93\u5165\u5e8f\u5217\u957f\u5ea6\n vob_start_id = 10 # \u8bcd\u6c47\u8868\u8d77\u59cb ID\n vob_end_id = 1000 # \u8bcd\u6c47\u8868\u7ed3\u675f ID\n\n # \u521b\u5efa\u6d4b\u8bd5\u8f93\u5165\u5f20\u91cf\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n\n # \u8c03\u7528\u5d4c\u5165\u51fd\u6570\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n\n # \u4fdd\u5b58\u7ed3\u679c\n results = {}\n results['test_case_1'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8f93\u5165\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_2'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u8bcd\u6c47\u8868\u8303\u56f4\n vob_start_id = 0\n vob_end_id = 500\n input_ids = torch.randint(\n vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_3'] = out.clone()\n\n # \u6d4b\u8bd5\u4e0d\u540c\u7684\u5d4c\u5165\u7ef4\u5ea6\n embedding_dim = 256\n weight = torch.randn(\n vocab_size, embedding_dim, dtype=torch.float32, device='cuda'\n )\n out = torch.zeros(\n sequence_length, embedding_dim, dtype=torch.float32, device='cuda'\n )\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n results['test_case_4'] = out.clone()\n\n return results\n\nresult_gold = test_embedding()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef embedding_kernel(\n weight,\n out,\n indices,\n vocab_size,\n d_model,\n vob_start_id,\n stride_out_0,\n stride_weight_0,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n for k in range(0, BLOCK_NN, BLOCK_N):\n seq_off = pid * BLOCK_NN + k + tl.arange(0, BLOCK_N)\n mask_seq = seq_off < out.shape[0]\n\n token_idx_raw = tl.load(indices + seq_off, mask=mask_seq, other=0).to(tl.int32)\n token_idx = token_idx_raw - vob_start_id\n token_idx = tl.where(token_idx >= 0, token_idx, 0)\n token_idx = tl.where(token_idx < vocab_size, token_idx, vocab_size - 1)\n\n w_offs = (\n token_idx[:, None] * d_model\n + offs_d[None, :]\n )\n vec = tl.load(weight + w_offs)\n\n o_offs = (\n seq_off[:, None] * stride_out_0\n + offs_d[None, :]\n )\n tl.store(out + o_offs, vec, mask=mask_seq[:, None])\n\n\ndef embedding(\n weight: torch.Tensor,\n indices: torch.Tensor,\n vob_start_id: int,\n vob_end_id: int,\n out: torch.Tensor = None\n) -> torch.Tensor:\n shape_2d = indices.shape\n B, SEQ_LEN = shape_2d\n d_model = weight.size(-1)\n weight = weight.view(-1, d_model) # Ensure [vocab_size, d_model]\n\n if out is None:\n out = torch.empty(\n (B * SEQ_LEN, d_model),\n dtype=weight.dtype,\n device=weight.device\n )\n else:\n out = out.view(B * SEQ_LEN, d_model)\n\n vocab_sz = vob_end_id - vob_start_id if vob_end_id > vob_start_id else weight.size(0)\n\n BLOCK_N = 16\n BLOCK_NN = 64\n BLOCK_DMODEL = triton.next_power_of_2(d_model)\n\n grid = lambda META: (triton.cdiv(out.shape[0], META[\"BLOCK_NN\"]), )\n\n embedding_kernel[grid](\n weight,\n out,\n indices.view(-1),\n vocab_size=vocab_sz,\n d_model=d_model,\n vob_start_id=vob_start_id,\n stride_out_0=out.stride(0),\n stride_weight_0=weight.stride(0),\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n BLOCK_DMODEL=BLOCK_DMODEL,\n )\n\n return out.view(B, SEQ_LEN, d_model)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", "label": null, "file": "int4_matmul.py", "test_code": "\n\n\ndef test_correct_int4_s2(M=32, K=4096, N=4096):\n group_size = 128\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size)\n \n # Test case\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n \n results = {\n \"test_case_1\": triton_output\n }\n \n return results\n\nresult_gold = test_correct_int4_s2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n# ===========================================================\n# Triton Kernel (Direct INT4 matmul)\n# -----------------------------------------------------------\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128,'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn, # stride_bk in ints (K//8, N)\n stride_cm, stride_cn,\n stride_scale_g, stride_scale_n,\n stride_zp_g, stride_zp_n,\n group_size,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_k = tl.program_id(1)\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)\n\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n start_k = pid_k * BLOCK_SIZE_K\n offs_k = start_k + tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n k_curr = k_step * BLOCK_SIZE_K * SPLIT_K + pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n mask_k = k_curr[None, :] < K\n\n a = tl.load(a_ptrs, mask=mask_k, other=0.0)\n\n b_raw = tl.load(b_ptrs, mask=mask_k, other=0)\n group_idx = k_curr[:, None] // group_size\n shift = (k_curr[:, None] % 8) * 4\n bits = (b_raw >> shift) & 0xF\n\n scales = tl.load(\n scales_ptr + group_idx * stride_scale_g + offs_n[None, :] * stride_scale_n,\n mask=mask_k, other=0.0\n )\n zeros = tl.load(\n zeros_ptr + group_idx * stride_zp_g + (offs_n[None, :] // 8) * stride_zp_n,\n mask=mask_k, other=0.0\n )\n zeros_bits = (zeros >> ((offs_n[None, :] % 8) * 4)) & 0xF\n\n b_deq = (bits - zeros_bits) * scales\n accumulator += tl.dot(a, b_deq)\n\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn\n\n out_vals = accumulator.to(c_ptr.dtype.element_ty)\n if SPLIT_K > 1:\n tl.atomic_add(c_ptrs, out_vals, mask=mask_c)\n else:\n tl.store(c_ptrs, out_vals, mask=mask_c)\n\n# ===========================================================\n# Python wrapper\n# -----------------------------------------------------------\ndef matmul_dequantize_int4_s2(\n x: torch.FloatTensor,\n qweight: torch.FloatTensor,\n scales: torch.FloatTensor,\n zeros: torch.FloatTensor,\n split_k: int = 1,\n) -> torch.FloatTensor:\n assert x.dim() == 2 and qweight.dim() == 2 and scales.dim() == 2 and zeros.dim() == 2\n M, K = x.shape\n assert K == qweight.shape[0] * 8\n N = qweight.shape[1]\n assert scales.shape == (K // scales.shape[0], N)\n assert zeros.shape == (K // zeros.shape[0], N)\n\n output = torch.empty((M, N), dtype=x.dtype, device=x.device)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n split_k,\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, zeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n zeros.stride(0), zeros.stride(1),\n scales.shape[0],\n )\n return output\n\n# ===========================================================\n# Quantise / Unpack helpers (correct 8\u00d7 packing)\n# -----------------------------------------------------------\ndef quantize_int4(w: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:\n \"\"\"\n w: [OC, K] (fp16/fp32 weights)\n returns: (qpacked, scales, zerospacked, group_size)\n qpacked : int32, shape [OC, K//8]\n scales : fp16/fp32 [OC, K//group_size]\n zerospacked : int32 [K//group_size, OC//8]\n \"\"\"\n assert w.dim() == 2\n OC, K = w.shape\n assert K % group_size == 0\n\n w = w.view(-1, K)\n OC_total, K_ = w.shape\n groups_per_row = K_ // group_size\n w = w.view(OC_total, groups_per_row, group_size)\n\n w_min = w.min(dim=2, keepdim=True).values\n w_max = w.max(dim=2, keepdim=True).values\n scales = (w_max - w_min) / 15.0\n zeros = torch.round(-w_min / scales).clamp(0, 15).to(torch.int32)\n\n quant = torch.clamp(torch.round((w / scales) + zeros), 0, 15).to(torch.int32)\n\n # === pack 8\u00d7OC along the OC axis (group row dimension of zero) ===\n OC_pack = OC_total // 8\n zeros_pack = zeros.permute(1, 0, 2).contiguous().view(groups_per_row, OC_pack, 8)\n zeros_packed = zeros_pack[..., 0]\n for p in range(1, 8):\n zeros_packed |= zeros_pack[..., p] << (p * 4)\n zeros_packed = zeros_packed.view(groups_per_row, OC_pack).contiguous()\n\n # === pack 8\u00d7K along the K axis (group column dimension of weight) ===\n K_pack = K_ // 8\n quant = quant.view(OC_total, K_)\n qpacked = torch.empty((OC_total, K_pack), dtype=torch.int32, device=w.device)\n\n for p in range(8):\n qpacked |= quant[:, p::8] << (p * 4)\n\n qpacked = qpacked.contiguous()\n scales = scales.view(OC_total, groups_per_row).contiguous()\n return qpacked, scales, zeros_packed, group_size\n\n\n# -----------------------------------------------------------\ndef unpack_int4(\n qweight: torch.Tensor,\n scales: torch.Tensor,\n zeros: torch.Tensor,\n group_size: int = 128\n) -> torch.Tensor:\n \"\"\"\n qweight : int32 [OC, K//8]\n scales : [OC, K//group_size]\n zeros : int32 [K//group_size, OC//8]\n returns : fp16/fp32 tensor [OC, K]\n \"\"\"\n OC, K8 = qweight.shape\n K = K8 * 8\n group_dim = K // group_size\n assert scales.shape == (OC, group_dim)\n assert zeros.shape == (group_dim, OC//8)\n\n # --- weight ---\n quant = torch.empty((OC, K), dtype=torch.int32, device=qweight.device)\n for p in range(8):\n mask = 0xF << (p * 4)\n quant[:, p::8] = (qweight & mask) >> (p * 4)\n\n # --- zero (grouped) ---\n OC8 = OC // 8\n zeros_ext = torch.empty((group_dim, OC), dtype=torch.int32, device=zeros.device)\n for p in range(8):\n mask = 0xF << (p * 4)\n zeros_ext[:, p::8] = (zeros & mask) >> (p * 4)\n zeros_ext = zeros_ext.permute(1, 0).contiguous()\n\n # --- dequant ---\n scales_mat = scales.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K)\n zeros_mat = zeros_ext.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K)\n\n return (quant.float() - zeros_mat) * scales_mat\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", "label": null, "file": "flash_decode2_phi.py", "test_code": "\n\n\nimport torch\n\n# Define the test function\ndef test_flash_decode_stage2():\n # Define the parameters for different test cases\n batch_size = 2\n head_num = 4\n seq_block_num = 3\n head_dim = 64\n block_seq = 16\n\n test_cases = {\n \"test_case_1\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq\n },\n \"test_case_2\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq + 1 # Different block size\n },\n \"test_case_3\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq // 2 # Different block size\n },\n \"test_case_4\": {\n \"B_Seqlen\": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),\n \"mid_out\": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),\n \"mid_out_logexpsum\": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),\n \"Out\": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),\n \"block_seq\": block_seq * 2 # Different block size\n }\n }\n\n # Execute the function for all test cases\n results = {}\n for key, test_case in test_cases.items():\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n results[key] = test_case[\"Out\"]\n\n return results\n\n# Run the test\nresult_gold = test_flash_decode_stage2()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum,# [batch, head, seq_block_num]\n Out, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_out_b,\n stride_out_h,\n stride_out_d,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_seq_len = tl.load(B_Seqlen + cur_batch)\n block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for block_seq_id in range(0, block_n_size):\n ptr_v = (Mid_O\n + cur_batch * stride_mid_ob\n + cur_head * stride_mid_oh\n + block_seq_id * stride_mid_os\n + offs_d * stride_mid_od)\n tv = tl.load(ptr_v).to(tl.float32)\n\n ptr_logic = (Mid_O_LogExpSum\n + cur_batch * stride_mid_o_eb\n + cur_head * stride_mid_o_eh\n + block_seq_id * stride_mid_o_es)\n tlogic = tl.load(ptr_logic).to(tl.float32)\n\n new_max = tl.maximum(max_logic, tlogic)\n scale = tl.exp(max_logic - new_max)\n acc = acc * scale\n sum_exp = sum_exp * scale\n\n delta = tl.exp(tlogic - new_max)\n acc += delta * tv\n sum_exp += delta\n max_logic = new_max\n\n acc = acc / sum_exp\n\n ptr_out = (Out\n + cur_batch * stride_out_b\n + cur_head * stride_out_h\n + offs_d * stride_out_d)\n tl.store(ptr_out, acc.to(Out.dtype.element_ty))\n\n\n@torch.no_grad()\ndef flash_decode_stage2(\n B_Seqlen: torch.Tensor,\n Mid_O: torch.Tensor,\n Mid_O_LogExpSum: torch.Tensor,\n Out: torch.Tensor,\n BLOCK_SEQ: int = 64,\n):\n assert Mid_O.dim() == 4\n assert Mid_O_LogExpSum.dim() == 3 # [batch, head, seq_block_num]\n assert B_Seqlen.dim() == 1\n assert Out.dim() == 3 # [batch, head, head_dim]\n assert Mid_O.shape[0] == Out.shape[0]\n assert Mid_O.shape[1] == Out.shape[1]\n assert Mid_O.shape[3] == Out.shape[2]\n assert Mid_O.shape[0] == Mid_O_LogExpSum.shape[0]\n assert Mid_O.shape[1] == Mid_O_LogExpSum.shape[1]\n assert Mid_O.shape[2] == Mid_O_LogExpSum.shape[2]\n\n BLOCK_DMODEL = Out.shape[-1]\n batch, head_num = Out.shape[0], Out.shape[1]\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n Out,\n Mid_O.stride(0),\n Mid_O.stride(1),\n Mid_O.stride(2),\n Mid_O.stride(3),\n Mid_O_LogExpSum.stride(0),\n Mid_O_LogExpSum.stride(1),\n Mid_O_LogExpSum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=2,\n )\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", "label": null, "file": "matrix_transpose.py", "test_code": "\n\n\nimport torch\n\ndef test_triton_vs_torch():\n results = {}\n\n # \u6d4b\u8bd5\u7528\u4f8b 1: \u57fa\u672c\u77e9\u9635\u8f6c\u7f6e (\u5c0f\u77e9\u9635)\n size_m, d_head = 16, 16\n out = wrapper(size_m, d_head)\n results[\"test_case_1\"] = out.clone()\n\n # \u6d4b\u8bd5\u7528\u4f8b 2: \u975e\u65b9\u5f62\u77e9\u9635\n size_m, d_head = 32, 64\n out = wrapper(size_m, d_head)\n results[\"test_case_2\"] = out.clone()\n\n return results\n\n\n# \u8fd0\u884c\u6d4b\u8bd5\nresult_gold = test_triton_vs_torch()\n# print(result_gold)", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n\n rows = size_m_arange[:, None]\n cols = d_head_arange[None, :]\n matrix_ptr = M + rows * matrix_stridex + cols * matrix_stridey\n matrix_value = tl.load(matrix_ptr)\n\n out_rows = d_head_arange[:, None]\n out_cols = size_m_arange[None, :]\n out_ptr = Out + out_rows * out_stridex + out_cols * out_stridey\n\n tl.store(out_ptr, tl.trans(matrix_value))\n\ndef wrapper(SIZE_M: int, D_HEAD: int):\n device = \"cuda\"\n matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device)\n out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device)\n grid = (1,)\n kernel[grid](\n matrix,\n out,\n matrix.stride(1),\n matrix.stride(0),\n out.stride(1),\n out.stride(0),\n SIZE_M=SIZE_M,\n D_HEAD=D_HEAD,\n )\n return out\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n ", "label": null, "file": "rotary_transform.py", "test_code": "\n\n\nimport torch\n\ndef test_apply_rotary():\n results = {}\n \n # Test case 1: Basic test with fixed sequence length and no interleaving\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin)\n results['test_case_1'] = output.shape\n\n # Test case 2: Variable length sequences with interleaving\n total_seqlen, nheads, headdim = 256, 4, 64\n batch = 3\n cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda')\n max_seqlen = 128\n rotary_dim = 32\n x = torch.randn(total_seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True)\n results['test_case_2'] = output.shape\n\n # Test case 3: Conjugate flag enabled\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, conjugate=True)\n results['test_case_3'] = output.shape\n\n # Test case 4: Inplace operation\n batch, seqlen, nheads, headdim = 2, 128, 4, 64\n rotary_dim = 32\n x = torch.randn(batch, seqlen, nheads, headdim, device='cuda')\n cos = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n sin = torch.randn(seqlen, rotary_dim // 2, device='cuda')\n output = apply_rotary(x, cos, sin, inplace=True)\n results['test_case_4'] = output.shape\n\n return results\n\nresult_gold = test_apply_rotary()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\nimport logging\nfrom typing import Optional, Union\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS,\n seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN,\n stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim,\n stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim,\n BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_batch = tl.program_id(1)\n pid_head = tl.program_id(2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n cur_seqlen = seqlen\n x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n end_idx = tl.load(CU_SEQLENS + pid_batch + 1)\n cur_seqlen = end_idx - start_idx\n x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n mask_m = rm < cur_seqlen\n\n if IS_SEQLEN_OFFSETS_TENSOR:\n offset = tl.load(SEQLEN_OFFSETS + pid_batch)\n else:\n offset = SEQLEN_OFFSETS\n rm_cs = rm + offset\n mask_m_cs = rm_cs < seqlen_ro\n\n if not INTERLEAVED:\n rk_half = tl.arange(0, BLOCK_K)\n mask_k_half = rk_half < rotary_dim_half\n\n cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :]\n sin_offs = cos_offs\n cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32)\n sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n\n x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim\n x1_offs = x0_offs + rotary_dim_half * stride_x_headdim\n\n x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32)\n\n if CONJUGATE:\n sin = -sin\n y0 = x0 * cos - x1 * sin\n y1 = x0 * sin + x1 * cos\n\n out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim\n out1_offs = out0_offs + rotary_dim_half * stride_out_headdim\n\n tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :])\n tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :])\n else:\n rk = tl.arange(0, BLOCK_K)\n mask_k = rk < rotary_dim\n\n rk_repeat = rk // 2\n cs_mask = rk_repeat[None, :] < rotary_dim_half\n\n cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]\n sin_offs = cos_offs\n cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32)\n sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32)\n\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim\n x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim\n\n x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32)\n x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32)\n\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n\n out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim\n tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :])\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved: bool = False,\n inplace: bool = False,\n conjugate: bool = False,\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch = cu_seqlens.shape[0] - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n assert rotary_dim <= headdim\n rotary_dim = rotary_dim * 2\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n assert cos.dtype == sin.dtype\n assert x.dtype == cos.dtype\n\n cos = cos.contiguous()\n sin = sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n if interleaved:\n BLOCK_M = 4\n else:\n BLOCK_M = 8 if rotary_dim <= 64 else 4\n BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256\n\n grid = lambda META: (\n triton.cdiv(seqlen, META[\"BLOCK_M\"]),\n batch,\n nheads,\n )\n\n rotary_kernel[grid](\n output,\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128,\n output.stride(0) if not is_varlen else 0,\n output.stride(-3),\n output.stride(-2),\n output.stride(-1),\n x.stride(0) if not is_varlen else 0,\n x.stride(-3),\n x.stride(-2),\n x.stride(-1),\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", "label": null, "file": "sin_kernel.py", "test_code": "\n\n\nimport torch\n\n# Function to test the Triton kernel\ndef test_call_kernel():\n results = {}\n \n # Test case 1: Small input tensor\n x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda()\n output1 = call_kernel(x1)\n results['test_case_1'] = output1\n \n # Test case 2: Larger input tensor\n x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda()\n output2 = call_kernel(x2)\n results['test_case_2'] = output2\n\n # Test case 3: Edge case with zero elements\n x3 = torch.tensor([], dtype=torch.float32).cuda()\n output3 = call_kernel(x3)\n results['test_case_3'] = output3\n\n # Test case 4: Input tensor with negative values\n x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda()\n output4 = call_kernel(x4)\n results['test_case_4'] = output4\n \n return results\n\n# Run the test function\nresult_gold = test_call_kernel()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.math.sin(x)\n tl.store(output_ptr + offsets, y, mask=mask)\n\ndef call_kernel(x: torch.Tensor) -> torch.Tensor:\n n_elements = x.numel()\n output = torch.empty_like(x)\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n kernel_function[grid](\n x,\n output,\n n_elements,\n BLOCK_SIZE=BLOCK_SIZE\n )\n return output\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", "label": null, "file": "l2_norm_bwd.py", "test_code": "\n\n\nimport torch\n\n# Test the backward L2 normalization\ndef test_l2_norm_bwd():\n results = {}\n \n # Test case 1: Default case\n x = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_1'] = dx\n\n # Test case 2: Different shape\n x = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dy = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_2'] = dx\n\n # Test case 3: Larger tensor\n x = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_3'] = dx\n\n # Test case 4: Edge case with small tensor\n x = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_4'] = dx\n\n return results\n\n# Run the tests\nresult_gold = test_l2_norm_bwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, DY, DX,\n stride_x_row,\n N,\n eps,\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DY += row * stride_x_row\n DX += row * stride_x_row\n\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)\n\n var = tl.sum(x * x)\n rstd = 1 / tl.sqrt(var + eps)\n\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\n\ndef _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n\n if x.stride(-1) != 1:\n x = x.contiguous()\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n\n M, N = x.shape\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n dx = torch.empty_like(x)\n\n # ROCm-compatible grid launch\n grid = (M,)\n _l2_norm_bwd_kernel[grid](\n x, dy, dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n"}, {"instruction": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", "label": null, "file": "l2_norm_triton1.py", "test_code": "\n\n\nimport torch\n\n# Test the forward L2 normalization\ndef test_l2_norm_fwd():\n results = {}\n \n # Test case 1\n x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n y1 = _l2_norm_fwd(x1)\n results['test_case_1'] = y1\n\n # Test case 2: Different batch size\n x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32)\n y2 = _l2_norm_fwd(x2)\n results['test_case_2'] = y2\n\n # Test case 3: Different feature size\n x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32)\n y3 = _l2_norm_fwd(x3)\n results['test_case_3'] = y3\n\n # Test case 4: Larger tensor\n x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n y4 = _l2_norm_fwd(x4)\n results['test_case_4'] = y4\n\n return results\n\nresult_gold = test_l2_norm_fwd()\n", "predict": "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n\n cols = tl.arange(0, BLOCK_N)\n\n var = tl.zeros([], dtype=tl.float32)\n for base in range(0, N, BLOCK_N):\n mask = cols < (N - base)\n data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32)\n var += tl.sum(data * data)\n\n rstd = tl.rsqrt(var + eps)\n\n for base in range(0, N, BLOCK_N):\n mask = cols < (N - base)\n data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32)\n y = data * rstd\n tl.store(Y + base + cols, y, mask=mask)\n\ndef _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5):\n x = x if x.is_contiguous() else x.contiguous()\n original_shape = x.shape\n if x.ndim > 2:\n x = x.view(-1, x.shape[-1])\n M, N = x.shape\n y = torch.empty_like(x)\n\n element_size = x.element_size()\n BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N))\n\n grid = (M,)\n _l2_norm_fwd_1pass_kernel[grid](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N=BLOCK_N,\n )\n return y.view(original_shape)\n"}] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp b/reflexion_oneshot_tritonbench_4/.out.jsonpassk.txt.swp new file mode 100644 index 0000000000000000000000000000000000000000..43cb1a8921cc31bf1d5c6c9e535f9f337f1bc44c GIT binary patch literal 4096 zcmeIuK@P!C5C-76fJAVBVy*3?DiRm45G%G#TO+N?o94YIdL@E8jIg{n2$BI|x9aM~k@%TZidn> None: + if BLOCK_DMODEL is None: + BLOCK_DMODEL = out.shape[-1] + batch, head_num = out.shape[:2] + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, mid_out, mid_out_logexpsum, out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), out.stride(0), out.stride(1), out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/efficiency.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json new file mode 100644 index 0000000..5ba3a59 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi.json @@ -0,0 +1,418 @@ +[ + { + "input_size": [ + [ + 4, + 8, + 16, + 64 + ], + [ + 4, + 8, + 16 + ], + [ + 4 + ], + [ + 4, + 8, + 64 + ] + ], + "ms": 0.00685499981045723, + "GB/s": 20.61677664591699, + "TFLOPS": 0.009560321197970791 + }, + { + "input_size": [ + [ + 8, + 8, + 16, + 64 + ], + [ + 8, + 8, + 16 + ], + [ + 8 + ], + [ + 8, + 8, + 64 + ] + ], + "ms": 0.010343999601900578, + "GB/s": 27.325600432937524, + "TFLOPS": 0.012671307525564598 + }, + { + "input_size": [ + [ + 16, + 8, + 16, + 64 + ], + [ + 16, + 8, + 16 + ], + [ + 16 + ], + [ + 16, + 8, + 64 + ] + ], + "ms": 0.009943000040948391, + "GB/s": 56.85527483373911, + "TFLOPS": 0.026364678559831923 + }, + { + "input_size": [ + [ + 32, + 8, + 16, + 64 + ], + [ + 32, + 8, + 16 + ], + [ + 32 + ], + [ + 32, + 8, + 64 + ] + ], + "ms": 0.010905000381171703, + "GB/s": 103.67940948925656, + "TFLOPS": 0.048077760813765975 + }, + { + "input_size": [ + [ + 64, + 8, + 16, + 64 + ], + [ + 64, + 8, + 16 + ], + [ + 64 + ], + [ + 64, + 8, + 64 + ] + ], + "ms": 0.011105000041425228, + "GB/s": 203.6243126127705, + "TFLOPS": 0.09442377272296025 + }, + { + "input_size": [ + [ + 128, + 8, + 16, + 64 + ], + [ + 128, + 8, + 16 + ], + [ + 128 + ], + [ + 128, + 8, + 64 + ] + ], + "ms": 0.011587000451982021, + "GB/s": 390.3077434700886, + "TFLOPS": 0.18099179409639796 + }, + { + "input_size": [ + [ + 256, + 8, + 16, + 64 + ], + [ + 256, + 8, + 16 + ], + [ + 256 + ], + [ + 256, + 8, + 64 + ] + ], + "ms": 0.012749000452458858, + "GB/s": 709.4667565295695, + "TFLOPS": 0.3289908111338296 + }, + { + "input_size": [ + [ + 512, + 8, + 16, + 64 + ], + [ + 512, + 8, + 16 + ], + [ + 512 + ], + [ + 512, + 8, + 64 + ] + ], + "ms": 0.01828099973499775, + "GB/s": 989.5511329923569, + "TFLOPS": 0.4588703091516692 + }, + { + "input_size": [ + [ + 1024, + 8, + 16, + 64 + ], + [ + 1024, + 8, + 16 + ], + [ + 1024 + ], + [ + 1024, + 8, + 64 + ] + ], + "ms": 0.02782299928367138, + "GB/s": 1300.361892372729, + "TFLOPS": 0.6029981106259139 + }, + { + "input_size": [ + [ + 2048, + 8, + 16, + 64 + ], + [ + 2048, + 8, + 16 + ], + [ + 2048 + ], + [ + 2048, + 8, + 64 + ] + ], + "ms": 0.046505000442266464, + "GB/s": 1555.960333552326, + "TFLOPS": 0.7215230981807231 + }, + { + "input_size": [ + [ + 4096, + 8, + 16, + 64 + ], + [ + 4096, + 8, + 16 + ], + [ + 4096 + ], + [ + 4096, + 8, + 64 + ] + ], + "ms": 0.08667699992656708, + "GB/s": 1669.6456052079209, + "TFLOPS": 0.7742407334916386 + }, + { + "input_size": [ + [ + 8192, + 8, + 16, + 64 + ], + [ + 8192, + 8, + 16 + ], + [ + 8192 + ], + [ + 8192, + 8, + 64 + ] + ], + "ms": 0.16365200281143188, + "GB/s": 1768.6294028036255, + "TFLOPS": 0.8201410657628948 + }, + { + "input_size": [ + [ + 16384, + 8, + 16, + 64 + ], + [ + 16384, + 8, + 16 + ], + [ + 16384 + ], + [ + 16384, + 8, + 64 + ] + ], + "ms": 0.31447499990463257, + "GB/s": 1840.7806285890786, + "TFLOPS": 0.8535987155780443 + }, + { + "input_size": [ + [ + 32768, + 8, + 16, + 64 + ], + [ + 32768, + 8, + 16 + ], + [ + 32768 + ], + [ + 32768, + 8, + 64 + ] + ], + "ms": 0.5913429856300354, + "GB/s": 1957.8468065643615, + "TFLOPS": 0.9078841299318041 + }, + { + "input_size": [ + [ + 65536, + 8, + 16, + 64 + ], + [ + 65536, + 8, + 16 + ], + [ + 65536 + ], + [ + 65536, + 8, + 64 + ] + ], + "ms": 1.1500940322875977, + "GB/s": 2013.3292469958415, + "TFLOPS": 0.9336122037467414 + }, + { + "input_size": [ + [ + 131072, + 8, + 16, + 64 + ], + [ + 131072, + 8, + 16 + ], + [ + 131072 + ], + [ + 131072, + 8, + 64 + ] + ], + "ms": 2.2799609899520874, + "GB/s": 2031.1908512510645, + "TFLOPS": 0.9418949084936443 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json new file mode 100644 index 0000000..ee02aa5 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/flash_decode2_phi_perf_data.json @@ -0,0 +1,6 @@ +{ + "flash_decode2_phi.json": { + "ms": 3.6883, + "efficiency": 99.617 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json new file mode 100644 index 0000000..74f2e7b --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd.json @@ -0,0 +1,145 @@ +[ + { + "input_size": [ + [ + 16 + ], + [ + 16 + ] + ], + "ms": 0.006012999918311834, + "GB/s": 0.03193081699789953, + "TFLOPS": 5.321802832983255e-06 + }, + { + "input_size": [ + [ + 32 + ], + [ + 32 + ] + ], + "ms": 0.0059739998541772366, + "GB/s": 0.06427854191049123, + "TFLOPS": 1.0713090318415206e-05 + }, + { + "input_size": [ + [ + 64 + ], + [ + 64 + ] + ], + "ms": 0.0059739998541772366, + "GB/s": 0.12855708382098247, + "TFLOPS": 2.1426180636830412e-05 + }, + { + "input_size": [ + [ + 128 + ], + [ + 128 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 0.25540405045773634, + "TFLOPS": 4.256734174295606e-05 + }, + { + "input_size": [ + [ + 256 + ], + [ + 256 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 0.5108081009154727, + "TFLOPS": 8.513468348591212e-05 + }, + { + "input_size": [ + [ + 512 + ], + [ + 512 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 1.0216162018309454, + "TFLOPS": 0.00017026936697182423 + }, + { + "input_size": [ + [ + 1024 + ], + [ + 1024 + ] + ], + "ms": 0.006012999918311834, + "GB/s": 2.0435722878655698, + "TFLOPS": 0.0003405953813109283 + }, + { + "input_size": [ + [ + 2048 + ], + [ + 2048 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 4.0864648073237815, + "TFLOPS": 0.0006810774678872969 + }, + { + "input_size": [ + [ + 4096 + ], + [ + 4096 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 8.172929614647563, + "TFLOPS": 0.0013621549357745938 + }, + { + "input_size": [ + [ + 8192 + ], + [ + 8192 + ] + ], + "ms": 0.006014000158756971, + "GB/s": 16.345859229295126, + "TFLOPS": 0.0027243098715491877 + }, + { + "input_size": [ + [ + 16384 + ], + [ + 16384 + ] + ], + "ms": 0.006696000229567289, + "GB/s": 29.362006161804633, + "TFLOPS": 0.004893667693634106 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json new file mode 100644 index 0000000..9a4f2e2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_bwd_perf_data.json @@ -0,0 +1,6 @@ +{ + "l2_norm_bwd.json": { + "ms": 2.0181, + "efficiency": 1.44 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json new file mode 100644 index 0000000..8e264de --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/l2_norm_triton1_perf_data.json @@ -0,0 +1,6 @@ +{ + "l2_norm_triton1.json": { + "ms": 1.7399, + "efficiency": 0.7672 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log new file mode 100644 index 0000000..798ca3d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/flash_decode2_phi_perf.py.log @@ -0,0 +1,16 @@ +{'input_size': [torch.Size([4, 8, 16, 64]), torch.Size([4, 8, 16]), torch.Size([4]), torch.Size([4, 8, 64])], 'ms': 0.00685499981045723, 'GB/s': 20.61677664591699, 'TFLOPS': 0.009560321197970791} +{'input_size': [torch.Size([8, 8, 16, 64]), torch.Size([8, 8, 16]), torch.Size([8]), torch.Size([8, 8, 64])], 'ms': 0.010343999601900578, 'GB/s': 27.325600432937524, 'TFLOPS': 0.012671307525564598} +{'input_size': [torch.Size([16, 8, 16, 64]), torch.Size([16, 8, 16]), torch.Size([16]), torch.Size([16, 8, 64])], 'ms': 0.009943000040948391, 'GB/s': 56.85527483373911, 'TFLOPS': 0.026364678559831923} +{'input_size': [torch.Size([32, 8, 16, 64]), torch.Size([32, 8, 16]), torch.Size([32]), torch.Size([32, 8, 64])], 'ms': 0.010905000381171703, 'GB/s': 103.67940948925656, 'TFLOPS': 0.048077760813765975} +{'input_size': [torch.Size([64, 8, 16, 64]), torch.Size([64, 8, 16]), torch.Size([64]), torch.Size([64, 8, 64])], 'ms': 0.011105000041425228, 'GB/s': 203.6243126127705, 'TFLOPS': 0.09442377272296025} +{'input_size': [torch.Size([128, 8, 16, 64]), torch.Size([128, 8, 16]), torch.Size([128]), torch.Size([128, 8, 64])], 'ms': 0.011587000451982021, 'GB/s': 390.3077434700886, 'TFLOPS': 0.18099179409639796} +{'input_size': [torch.Size([256, 8, 16, 64]), torch.Size([256, 8, 16]), torch.Size([256]), torch.Size([256, 8, 64])], 'ms': 0.012749000452458858, 'GB/s': 709.4667565295695, 'TFLOPS': 0.3289908111338296} +{'input_size': [torch.Size([512, 8, 16, 64]), torch.Size([512, 8, 16]), torch.Size([512]), torch.Size([512, 8, 64])], 'ms': 0.01828099973499775, 'GB/s': 989.5511329923569, 'TFLOPS': 0.4588703091516692} +{'input_size': [torch.Size([1024, 8, 16, 64]), torch.Size([1024, 8, 16]), torch.Size([1024]), torch.Size([1024, 8, 64])], 'ms': 0.02782299928367138, 'GB/s': 1300.361892372729, 'TFLOPS': 0.6029981106259139} +{'input_size': [torch.Size([2048, 8, 16, 64]), torch.Size([2048, 8, 16]), torch.Size([2048]), torch.Size([2048, 8, 64])], 'ms': 0.046505000442266464, 'GB/s': 1555.960333552326, 'TFLOPS': 0.7215230981807231} +{'input_size': [torch.Size([4096, 8, 16, 64]), torch.Size([4096, 8, 16]), torch.Size([4096]), torch.Size([4096, 8, 64])], 'ms': 0.08667699992656708, 'GB/s': 1669.6456052079209, 'TFLOPS': 0.7742407334916386} +{'input_size': [torch.Size([8192, 8, 16, 64]), torch.Size([8192, 8, 16]), torch.Size([8192]), torch.Size([8192, 8, 64])], 'ms': 0.16365200281143188, 'GB/s': 1768.6294028036255, 'TFLOPS': 0.8201410657628948} +{'input_size': [torch.Size([16384, 8, 16, 64]), torch.Size([16384, 8, 16]), torch.Size([16384]), torch.Size([16384, 8, 64])], 'ms': 0.31447499990463257, 'GB/s': 1840.7806285890786, 'TFLOPS': 0.8535987155780443} +{'input_size': [torch.Size([32768, 8, 16, 64]), torch.Size([32768, 8, 16]), torch.Size([32768]), torch.Size([32768, 8, 64])], 'ms': 0.5913429856300354, 'GB/s': 1957.8468065643615, 'TFLOPS': 0.9078841299318041} +{'input_size': [torch.Size([65536, 8, 16, 64]), torch.Size([65536, 8, 16]), torch.Size([65536]), torch.Size([65536, 8, 64])], 'ms': 1.1500940322875977, 'GB/s': 2013.3292469958415, 'TFLOPS': 0.9336122037467414} +{'input_size': [torch.Size([131072, 8, 16, 64]), torch.Size([131072, 8, 16]), torch.Size([131072]), torch.Size([131072, 8, 64])], 'ms': 2.2799609899520874, 'GB/s': 2031.1908512510645, 'TFLOPS': 0.9418949084936443} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log new file mode 100644 index 0000000..3d811ce --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_bwd_perf.py.log @@ -0,0 +1,11 @@ +{'input_size': [torch.Size([16]), torch.Size([16])], 'ms': 0.006012999918311834, 'GB/s': 0.03193081699789953, 'TFLOPS': 5.321802832983255e-06} +{'input_size': [torch.Size([32]), torch.Size([32])], 'ms': 0.0059739998541772366, 'GB/s': 0.06427854191049123, 'TFLOPS': 1.0713090318415206e-05} +{'input_size': [torch.Size([64]), torch.Size([64])], 'ms': 0.0059739998541772366, 'GB/s': 0.12855708382098247, 'TFLOPS': 2.1426180636830412e-05} +{'input_size': [torch.Size([128]), torch.Size([128])], 'ms': 0.006014000158756971, 'GB/s': 0.25540405045773634, 'TFLOPS': 4.256734174295606e-05} +{'input_size': [torch.Size([256]), torch.Size([256])], 'ms': 0.006014000158756971, 'GB/s': 0.5108081009154727, 'TFLOPS': 8.513468348591212e-05} +{'input_size': [torch.Size([512]), torch.Size([512])], 'ms': 0.006014000158756971, 'GB/s': 1.0216162018309454, 'TFLOPS': 0.00017026936697182423} +{'input_size': [torch.Size([1024]), torch.Size([1024])], 'ms': 0.006012999918311834, 'GB/s': 2.0435722878655698, 'TFLOPS': 0.0003405953813109283} +{'input_size': [torch.Size([2048]), torch.Size([2048])], 'ms': 0.006014000158756971, 'GB/s': 4.0864648073237815, 'TFLOPS': 0.0006810774678872969} +{'input_size': [torch.Size([4096]), torch.Size([4096])], 'ms': 0.006014000158756971, 'GB/s': 8.172929614647563, 'TFLOPS': 0.0013621549357745938} +{'input_size': [torch.Size([8192]), torch.Size([8192])], 'ms': 0.006014000158756971, 'GB/s': 16.345859229295126, 'TFLOPS': 0.0027243098715491877} +{'input_size': [torch.Size([16384]), torch.Size([16384])], 'ms': 0.006696000229567289, 'GB/s': 29.362006161804633, 'TFLOPS': 0.004893667693634106} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log new file mode 100644 index 0000000..550e319 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/l2_norm_triton1_perf.py.log @@ -0,0 +1,11 @@ +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) +Failed to run benchmark for input tensor. Error: not enough values to unpack (expected 2, got 1) diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log new file mode 100644 index 0000000..9f23be3 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/matrix_vector_multip_perf.py.log @@ -0,0 +1,18 @@ +{'input_size': [torch.Size([128, 256]), torch.Size([256])], 'ms': 0.019283000379800797, 'GB/s': 6.8769381003025165, 'TFLOPS': 0.0033986412233155293} +{'input_size': [torch.Size([256, 384]), torch.Size([384])], 'ms': 0.026660999283194542, 'GB/s': 14.844754909448309, 'TFLOPS': 0.007374367251265395} +{'input_size': [torch.Size([384, 512]), torch.Size([512])], 'ms': 0.0357620008289814, 'GB/s': 22.09093399941353, 'TFLOPS': 0.010995357946564867} +{'input_size': [torch.Size([512, 640]), torch.Size([640])], 'ms': 0.04333899915218353, 'GB/s': 30.349754856619256, 'TFLOPS': 0.015121715148490715} +{'input_size': [torch.Size([640, 768]), torch.Size([768])], 'ms': 0.051676999777555466, 'GB/s': 38.154536998805426, 'TFLOPS': 0.01902277617182717} +{'input_size': [torch.Size([768, 896]), torch.Size([896])], 'ms': 0.060015998780727386, 'GB/s': 45.97387456769338, 'TFLOPS': 0.022931485403221343} +{'input_size': [torch.Size([896, 1024]), torch.Size([1024])], 'ms': 0.07028000056743622, 'GB/s': 52.32919707322878, 'TFLOPS': 0.026109959948552405} +{'input_size': [torch.Size([1024, 1152]), torch.Size([1152])], 'ms': 0.07629299908876419, 'GB/s': 61.9623826099687, 'TFLOPS': 0.030924148063114457} +{'input_size': [torch.Size([1152, 1280]), torch.Size([1280])], 'ms': 0.08503299951553345, 'GB/s': 69.47853225994645, 'TFLOPS': 0.03468206480780756} +{'input_size': [torch.Size([1280, 1408]), torch.Size([1408])], 'ms': 0.0923290029168129, 'GB/s': 78.19549406923474, 'TFLOPS': 0.03903952047708762} +{'input_size': [torch.Size([1408, 1536]), torch.Size([1536])], 'ms': 0.10183099657297134, 'GB/s': 85.06769344825665, 'TFLOPS': 0.0424760254300415} +{'input_size': [torch.Size([1536, 1664]), torch.Size([1664])], 'ms': 0.10796499997377396, 'GB/s': 94.81235587909558, 'TFLOPS': 0.04734689946965889} +{'input_size': [torch.Size([1664, 1792]), torch.Size([1792])], 'ms': 0.1157429963350296, 'GB/s': 103.17147799970981, 'TFLOPS': 0.05152602048366933} +{'input_size': [torch.Size([1792, 1920]), torch.Size([1920])], 'ms': 0.12308000028133392, 'GB/s': 111.93864127809444, 'TFLOPS': 0.05590900214714739} +{'input_size': [torch.Size([1920, 2048]), torch.Size([2048])], 'ms': 0.13025599718093872, 'GB/s': 120.87360536750782, 'TFLOPS': 0.06037587650628989} +{'input_size': [torch.Size([2048, 2176]), torch.Size([2176])], 'ms': 0.13975700736045837, 'GB/s': 127.66936225230204, 'TFLOPS': 0.06377423335212126} +{'input_size': [torch.Size([2176, 2304]), torch.Size([2304])], 'ms': 0.14849700033664703, 'GB/s': 135.16728253430261, 'TFLOPS': 0.06752330334800354} +{'input_size': [torch.Size([2304, 2432]), torch.Size([2432])], 'ms': 0.15607500076293945, 'GB/s': 143.72741240009412, 'TFLOPS': 0.07180301742891973} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/performance_utils.py.log new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log new file mode 100644 index 0000000..45caff8 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/rotary_transform_perf.py.log @@ -0,0 +1,14 @@ +{'input_size': [torch.Size([4, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.009139999747276306, 'GB/s': 116.516414600269, 'TFLOPS': 0.057361927187824735} +{'input_size': [torch.Size([8, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.009500999934971333, 'GB/s': 222.4540589902001, 'TFLOPS': 0.11036480446025432} +{'input_size': [torch.Size([16, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.01006300002336502, 'GB/s': 418.4326731812891, 'TFLOPS': 0.20840226524204278} +{'input_size': [torch.Size([32, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.012788999825716019, 'GB/s': 657.204794318576, 'TFLOPS': 0.32796184667749595} +{'input_size': [torch.Size([64, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.019202999770641327, 'GB/s': 874.5300317961281, 'TFLOPS': 0.4368384158825538} +{'input_size': [torch.Size([128, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.030789999291300774, 'GB/s': 1090.3155820950246, 'TFLOPS': 0.5448917306321646} +{'input_size': [torch.Size([256, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.06005600094795227, 'GB/s': 1117.7109188168276, 'TFLOPS': 0.5587190533895199} +{'input_size': [torch.Size([512, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.11085200309753418, 'GB/s': 1210.930865019127, 'TFLOPS': 0.6053915321760459} +{'input_size': [torch.Size([1024, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.20907600224018097, 'GB/s': 1283.9916447781015, 'TFLOPS': 0.6419566404651943} +{'input_size': [torch.Size([2048, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.3862780034542084, 'GB/s': 1389.8987030040546, 'TFLOPS': 0.6949281439781023} +{'input_size': [torch.Size([4096, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 0.7548360228538513, 'GB/s': 1422.5052534461477, 'TFLOPS': 0.7112417740348714} +{'input_size': [torch.Size([8192, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 1.4495340585708618, 'GB/s': 1481.510571829739, 'TFLOPS': 0.7407496344436594} +{'input_size': [torch.Size([16384, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 2.838752031326294, 'GB/s': 1512.983040647386, 'TFLOPS': 0.7564886345485674} +{'input_size': [torch.Size([32768, 128, 8, 64]), torch.Size([128, 16]), torch.Size([128, 16])], 'ms': 5.886918544769287, 'GB/s': 1459.159135747248, 'TFLOPS': 0.7295781763136869} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log new file mode 100644 index 0000000..8ceeb1d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/sin_kernel_perf.py.log @@ -0,0 +1,16 @@ +{'input_size': [torch.Size([4096])], 'ms': 0.006053999997675419, 'GB/s': 5.412619757611839, 'TFLOPS': 0.0006765774697014799} +{'input_size': [torch.Size([8192])], 'ms': 0.006093999836593866, 'GB/s': 10.754184732080693, 'TFLOPS': 0.0013442730915100868} +{'input_size': [torch.Size([16384])], 'ms': 0.006093999836593866, 'GB/s': 21.508369464161387, 'TFLOPS': 0.0026885461830201737} +{'input_size': [torch.Size([32768])], 'ms': 0.006093999836593866, 'GB/s': 43.01673892832277, 'TFLOPS': 0.005377092366040347} +{'input_size': [torch.Size([65536])], 'ms': 0.006093999836593866, 'GB/s': 86.03347785664555, 'TFLOPS': 0.010754184732080695} +{'input_size': [torch.Size([131072])], 'ms': 0.006134000141173601, 'GB/s': 170.94489335948708, 'TFLOPS': 0.021368111669935887} +{'input_size': [torch.Size([262144])], 'ms': 0.006134000141173601, 'GB/s': 341.88978671897416, 'TFLOPS': 0.04273622333987177} +{'input_size': [torch.Size([524288])], 'ms': 0.0062939999625086784, 'GB/s': 666.3972076555627, 'TFLOPS': 0.08329965095694535} +{'input_size': [torch.Size([1048576])], 'ms': 0.00685499981045723, 'GB/s': 1223.7211133402611, 'TFLOPS': 0.15296513916753265} +{'input_size': [torch.Size([2097152])], 'ms': 0.008860000409185886, 'GB/s': 1893.5908832019568, 'TFLOPS': 0.2366988604002446} +{'input_size': [torch.Size([4194304])], 'ms': 0.014072000049054623, 'GB/s': 2384.482083785541, 'TFLOPS': 0.2980602604731926} +{'input_size': [torch.Size([8388608])], 'ms': 0.0230919998139143, 'GB/s': 2906.15211072204, 'TFLOPS': 0.363269013840255} +{'input_size': [torch.Size([16777216])], 'ms': 0.05051499977707863, 'GB/s': 2656.987599570411, 'TFLOPS': 0.3321234499463014} +{'input_size': [torch.Size([33554432])], 'ms': 0.08868200331926346, 'GB/s': 3026.943979080033, 'TFLOPS': 0.3783679973850041} +{'input_size': [torch.Size([67108864])], 'ms': 0.1559540033340454, 'GB/s': 3442.4952263011182, 'TFLOPS': 0.4303119032876398} +{'input_size': [torch.Size([134217728])], 'ms': 0.2989189922809601, 'GB/s': 3592.0829780891536, 'TFLOPS': 0.4490103722611442} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.err b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.err new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log new file mode 100644 index 0000000..a440b79 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/logs/triton_matmul_perf.py.log @@ -0,0 +1,31 @@ +{'input_size': [torch.Size([256, 256]), torch.Size([256, 256])], 'ms': 0.009019999764859676, 'GB/s': 43.59379271071602, 'TFLOPS': 3.720003644647768} +{'input_size': [torch.Size([384, 384]), torch.Size([384, 384])], 'ms': 0.010223000310361385, 'GB/s': 86.54367339725967, 'TFLOPS': 11.077590194849238} +{'input_size': [torch.Size([512, 512]), torch.Size([512, 512])], 'ms': 0.012949000112712383, 'GB/s': 121.46605809786634, 'TFLOPS': 20.73020724870252} +{'input_size': [torch.Size([640, 640]), torch.Size([640, 640])], 'ms': 0.014112000353634357, 'GB/s': 174.14965549990777, 'TFLOPS': 37.151926506646994} +{'input_size': [torch.Size([768, 768]), torch.Size([768, 768])], 'ms': 0.01720000058412552, 'GB/s': 205.7525511520165, 'TFLOPS': 52.67265309491623} +{'input_size': [torch.Size([896, 896]), torch.Size([896, 896])], 'ms': 0.018681999295949936, 'GB/s': 257.836215690483, 'TFLOPS': 77.00708308622427} +{'input_size': [torch.Size([1024, 1024]), torch.Size([1024, 1024])], 'ms': 0.02353299967944622, 'GB/s': 267.3461133599119, 'TFLOPS': 91.25414002684992} +{'input_size': [torch.Size([1152, 1152]), torch.Size([1152, 1152])], 'ms': 0.0257789995521307, 'GB/s': 308.8802567336973, 'TFLOPS': 118.61001858573978} +{'input_size': [torch.Size([1280, 1280]), torch.Size([1280, 1280])], 'ms': 0.029627999290823936, 'GB/s': 331.7942566255078, 'TFLOPS': 141.56554949355} +{'input_size': [torch.Size([1408, 1408]), torch.Size([1408, 1408])], 'ms': 0.030910000205039978, 'GB/s': 384.8199262729385, 'TFLOPS': 180.60881873076582} +{'input_size': [torch.Size([1536, 1536]), torch.Size([1536, 1536])], 'ms': 0.03672400116920471, 'GB/s': 385.4638805498806, 'TFLOPS': 197.35750684153888} +{'input_size': [torch.Size([1664, 1664]), torch.Size([1664, 1664])], 'ms': 0.05195799842476845, 'GB/s': 319.7462662857386, 'TFLOPS': 177.352595699823} +{'input_size': [torch.Size([1792, 1792]), torch.Size([1792, 1792])], 'ms': 0.06085899844765663, 'GB/s': 316.59383971906124, 'TFLOPS': 189.1120535921859} +{'input_size': [torch.Size([1920, 1920]), torch.Size([1920, 1920])], 'ms': 0.06326399743556976, 'GB/s': 349.6206515012925, 'TFLOPS': 223.7572169608272} +{'input_size': [torch.Size([2048, 2048]), torch.Size([2048, 2048])], 'ms': 0.07392799854278564, 'GB/s': 340.40991905705846, 'TFLOPS': 232.38650474295193} +{'input_size': [torch.Size([2176, 2176]), torch.Size([2176, 2176])], 'ms': 0.07533200085163116, 'GB/s': 377.1286528809203, 'TFLOPS': 273.54398288962756} +{'input_size': [torch.Size([2304, 2304]), torch.Size([2304, 2304])], 'ms': 0.14252400398254395, 'GB/s': 223.47460855717318, 'TFLOPS': 171.628499371909} +{'input_size': [torch.Size([2432, 2432]), torch.Size([2432, 2432])], 'ms': 0.14476899802684784, 'GB/s': 245.13358856996922, 'TFLOPS': 198.72162913405504} +{'input_size': [torch.Size([2560, 2560]), torch.Size([2560, 2560])], 'ms': 0.16501599550247192, 'GB/s': 238.28962689505434, 'TFLOPS': 203.34048161711303} +{'input_size': [torch.Size([2688, 2688]), torch.Size([2688, 2688])], 'ms': 0.1670999974012375, 'GB/s': 259.4378496362499, 'TFLOPS': 232.45631327407992} +{'input_size': [torch.Size([2816, 2816]), torch.Size([2816, 2816])], 'ms': 0.18959100544452667, 'GB/s': 250.95671542245927, 'TFLOPS': 235.5647035432151} +{'input_size': [torch.Size([2944, 2944]), torch.Size([2944, 2944])], 'ms': 0.18975099921226501, 'GB/s': 274.05819318941786, 'TFLOPS': 268.94244024988205} +{'input_size': [torch.Size([3072, 3072]), torch.Size([3072, 3072])], 'ms': 0.21536900103092194, 'GB/s': 262.9120427218319, 'TFLOPS': 269.22193174715585} +{'input_size': [torch.Size([3200, 3200]), torch.Size([3200, 3200])], 'ms': 0.297435998916626, 'GB/s': 206.56544676430437, 'TFLOPS': 220.33647654859135} +{'input_size': [torch.Size([3328, 3328]), torch.Size([3328, 3328])], 'ms': 0.3247379958629608, 'GB/s': 204.6372917447065, 'TFLOPS': 227.01096897546105} +{'input_size': [torch.Size([3456, 3456]), torch.Size([3456, 3456])], 'ms': 0.32375550270080566, 'GB/s': 221.35103620532738, 'TFLOPS': 254.9963937085371} +{'input_size': [torch.Size([3584, 3584]), torch.Size([3584, 3584])], 'ms': 0.35669100284576416, 'GB/s': 216.0703112360975, 'TFLOPS': 258.1319984900578} +{'input_size': [torch.Size([3712, 3712]), torch.Size([3712, 3712])], 'ms': 0.36759498715400696, 'GB/s': 224.90422037600632, 'TFLOPS': 278.2814886785785} +{'input_size': [torch.Size([3840, 3840]), torch.Size([3840, 3840])], 'ms': 0.4137600064277649, 'GB/s': 213.82830294267677, 'TFLOPS': 273.70022776662626} +{'input_size': [torch.Size([3968, 3968]), torch.Size([3968, 3968])], 'ms': 0.4952859878540039, 'GB/s': 190.73857592726222, 'TFLOPS': 252.2835564264588} +{'input_size': [torch.Size([4096, 4096]), torch.Size([4096, 4096])], 'ms': 0.5376620292663574, 'GB/s': 187.22411202694673, 'TFLOPS': 255.62332095412455} diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json new file mode 100644 index 0000000..3c6ddcd --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip.json @@ -0,0 +1,254 @@ +[ + { + "input_size": [ + [ + 128, + 256 + ], + [ + 256 + ] + ], + "ms": 0.019283000379800797, + "GB/s": 6.8769381003025165, + "TFLOPS": 0.0033986412233155293 + }, + { + "input_size": [ + [ + 256, + 384 + ], + [ + 384 + ] + ], + "ms": 0.026660999283194542, + "GB/s": 14.844754909448309, + "TFLOPS": 0.007374367251265395 + }, + { + "input_size": [ + [ + 384, + 512 + ], + [ + 512 + ] + ], + "ms": 0.0357620008289814, + "GB/s": 22.09093399941353, + "TFLOPS": 0.010995357946564867 + }, + { + "input_size": [ + [ + 512, + 640 + ], + [ + 640 + ] + ], + "ms": 0.04333899915218353, + "GB/s": 30.349754856619256, + "TFLOPS": 0.015121715148490715 + }, + { + "input_size": [ + [ + 640, + 768 + ], + [ + 768 + ] + ], + "ms": 0.051676999777555466, + "GB/s": 38.154536998805426, + "TFLOPS": 0.01902277617182717 + }, + { + "input_size": [ + [ + 768, + 896 + ], + [ + 896 + ] + ], + "ms": 0.060015998780727386, + "GB/s": 45.97387456769338, + "TFLOPS": 0.022931485403221343 + }, + { + "input_size": [ + [ + 896, + 1024 + ], + [ + 1024 + ] + ], + "ms": 0.07028000056743622, + "GB/s": 52.32919707322878, + "TFLOPS": 0.026109959948552405 + }, + { + "input_size": [ + [ + 1024, + 1152 + ], + [ + 1152 + ] + ], + "ms": 0.07629299908876419, + "GB/s": 61.9623826099687, + "TFLOPS": 0.030924148063114457 + }, + { + "input_size": [ + [ + 1152, + 1280 + ], + [ + 1280 + ] + ], + "ms": 0.08503299951553345, + "GB/s": 69.47853225994645, + "TFLOPS": 0.03468206480780756 + }, + { + "input_size": [ + [ + 1280, + 1408 + ], + [ + 1408 + ] + ], + "ms": 0.0923290029168129, + "GB/s": 78.19549406923474, + "TFLOPS": 0.03903952047708762 + }, + { + "input_size": [ + [ + 1408, + 1536 + ], + [ + 1536 + ] + ], + "ms": 0.10183099657297134, + "GB/s": 85.06769344825665, + "TFLOPS": 0.0424760254300415 + }, + { + "input_size": [ + [ + 1536, + 1664 + ], + [ + 1664 + ] + ], + "ms": 0.10796499997377396, + "GB/s": 94.81235587909558, + "TFLOPS": 0.04734689946965889 + }, + { + "input_size": [ + [ + 1664, + 1792 + ], + [ + 1792 + ] + ], + "ms": 0.1157429963350296, + "GB/s": 103.17147799970981, + "TFLOPS": 0.05152602048366933 + }, + { + "input_size": [ + [ + 1792, + 1920 + ], + [ + 1920 + ] + ], + "ms": 0.12308000028133392, + "GB/s": 111.93864127809444, + "TFLOPS": 0.05590900214714739 + }, + { + "input_size": [ + [ + 1920, + 2048 + ], + [ + 2048 + ] + ], + "ms": 0.13025599718093872, + "GB/s": 120.87360536750782, + "TFLOPS": 0.06037587650628989 + }, + { + "input_size": [ + [ + 2048, + 2176 + ], + [ + 2176 + ] + ], + "ms": 0.13975700736045837, + "GB/s": 127.66936225230204, + "TFLOPS": 0.06377423335212126 + }, + { + "input_size": [ + [ + 2176, + 2304 + ], + [ + 2304 + ] + ], + "ms": 0.14849700033664703, + "GB/s": 135.16728253430261, + "TFLOPS": 0.06752330334800354 + }, + { + "input_size": [ + [ + 2304, + 2432 + ], + [ + 2432 + ] + ], + "ms": 0.15607500076293945, + "GB/s": 143.72741240009412, + "TFLOPS": 0.07180301742891973 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json new file mode 100644 index 0000000..b78c381 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/matrix_vector_multip_perf_data.json @@ -0,0 +1,6 @@ +{ + "matrix_vector_multip.json": { + "ms": 0.3142, + "efficiency": 7.0489 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json new file mode 100644 index 0000000..55e39be --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform.json @@ -0,0 +1,296 @@ +[ + { + "input_size": [ + [ + 4, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.009139999747276306, + "GB/s": 116.516414600269, + "TFLOPS": 0.057361927187824735 + }, + { + "input_size": [ + [ + 8, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.009500999934971333, + "GB/s": 222.4540589902001, + "TFLOPS": 0.11036480446025432 + }, + { + "input_size": [ + [ + 16, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.01006300002336502, + "GB/s": 418.4326731812891, + "TFLOPS": 0.20840226524204278 + }, + { + "input_size": [ + [ + 32, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.012788999825716019, + "GB/s": 657.204794318576, + "TFLOPS": 0.32796184667749595 + }, + { + "input_size": [ + [ + 64, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.019202999770641327, + "GB/s": 874.5300317961281, + "TFLOPS": 0.4368384158825538 + }, + { + "input_size": [ + [ + 128, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.030789999291300774, + "GB/s": 1090.3155820950246, + "TFLOPS": 0.5448917306321646 + }, + { + "input_size": [ + [ + 256, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.06005600094795227, + "GB/s": 1117.7109188168276, + "TFLOPS": 0.5587190533895199 + }, + { + "input_size": [ + [ + 512, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.11085200309753418, + "GB/s": 1210.930865019127, + "TFLOPS": 0.6053915321760459 + }, + { + "input_size": [ + [ + 1024, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.20907600224018097, + "GB/s": 1283.9916447781015, + "TFLOPS": 0.6419566404651943 + }, + { + "input_size": [ + [ + 2048, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.3862780034542084, + "GB/s": 1389.8987030040546, + "TFLOPS": 0.6949281439781023 + }, + { + "input_size": [ + [ + 4096, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 0.7548360228538513, + "GB/s": 1422.5052534461477, + "TFLOPS": 0.7112417740348714 + }, + { + "input_size": [ + [ + 8192, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 1.4495340585708618, + "GB/s": 1481.510571829739, + "TFLOPS": 0.7407496344436594 + }, + { + "input_size": [ + [ + 16384, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 2.838752031326294, + "GB/s": 1512.983040647386, + "TFLOPS": 0.7564886345485674 + }, + { + "input_size": [ + [ + 32768, + 128, + 8, + 64 + ], + [ + 128, + 16 + ], + [ + 128, + 16 + ] + ], + "ms": 5.886918544769287, + "GB/s": 1459.159135747248, + "TFLOPS": 0.7295781763136869 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json new file mode 100644 index 0000000..f298582 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/rotary_transform_perf_data.json @@ -0,0 +1,6 @@ +{ + "rotary_transform.json": { + "ms": 5.4456, + "efficiency": 74.2022 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json new file mode 100644 index 0000000..2e0ab59 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel.json @@ -0,0 +1,162 @@ +[ + { + "input_size": [ + [ + 4096 + ] + ], + "ms": 0.006053999997675419, + "GB/s": 5.412619757611839, + "TFLOPS": 0.0006765774697014799 + }, + { + "input_size": [ + [ + 8192 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 10.754184732080693, + "TFLOPS": 0.0013442730915100868 + }, + { + "input_size": [ + [ + 16384 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 21.508369464161387, + "TFLOPS": 0.0026885461830201737 + }, + { + "input_size": [ + [ + 32768 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 43.01673892832277, + "TFLOPS": 0.005377092366040347 + }, + { + "input_size": [ + [ + 65536 + ] + ], + "ms": 0.006093999836593866, + "GB/s": 86.03347785664555, + "TFLOPS": 0.010754184732080695 + }, + { + "input_size": [ + [ + 131072 + ] + ], + "ms": 0.006134000141173601, + "GB/s": 170.94489335948708, + "TFLOPS": 0.021368111669935887 + }, + { + "input_size": [ + [ + 262144 + ] + ], + "ms": 0.006134000141173601, + "GB/s": 341.88978671897416, + "TFLOPS": 0.04273622333987177 + }, + { + "input_size": [ + [ + 524288 + ] + ], + "ms": 0.0062939999625086784, + "GB/s": 666.3972076555627, + "TFLOPS": 0.08329965095694535 + }, + { + "input_size": [ + [ + 1048576 + ] + ], + "ms": 0.00685499981045723, + "GB/s": 1223.7211133402611, + "TFLOPS": 0.15296513916753265 + }, + { + "input_size": [ + [ + 2097152 + ] + ], + "ms": 0.008860000409185886, + "GB/s": 1893.5908832019568, + "TFLOPS": 0.2366988604002446 + }, + { + "input_size": [ + [ + 4194304 + ] + ], + "ms": 0.014072000049054623, + "GB/s": 2384.482083785541, + "TFLOPS": 0.2980602604731926 + }, + { + "input_size": [ + [ + 8388608 + ] + ], + "ms": 0.0230919998139143, + "GB/s": 2906.15211072204, + "TFLOPS": 0.363269013840255 + }, + { + "input_size": [ + [ + 16777216 + ] + ], + "ms": 0.05051499977707863, + "GB/s": 2656.987599570411, + "TFLOPS": 0.3321234499463014 + }, + { + "input_size": [ + [ + 33554432 + ] + ], + "ms": 0.08868200331926346, + "GB/s": 3026.943979080033, + "TFLOPS": 0.3783679973850041 + }, + { + "input_size": [ + [ + 67108864 + ] + ], + "ms": 0.1559540033340454, + "GB/s": 3442.4952263011182, + "TFLOPS": 0.4303119032876398 + }, + { + "input_size": [ + [ + 134217728 + ] + ], + "ms": 0.2989189922809601, + "GB/s": 3592.0829780891536, + "TFLOPS": 0.4490103722611442 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json new file mode 100644 index 0000000..507bcf4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/sin_kernel_perf_data.json @@ -0,0 +1,6 @@ +{ + "sin_kernel.json": { + "ms": 1.4189, + "efficiency": 90.2883 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py new file mode 100644 index 0000000..3469c2a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/flash_decode2_phi_perf.py @@ -0,0 +1,59 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from flash_decode2_phi import flash_decode_stage2 +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('flash_decode2_phi', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 18): # Adjust the range as needed for your testing + batch_size = 2 ** i + head_num = 8 # Example head number, adjust as needed + seq_block_num = 16 # Example sequence block number, adjust as needed + head_dim = 64 # Example head dimension, adjust as needed + + mid_out = torch.rand(batch_size, head_num, seq_block_num, head_dim, dtype=torch.float32) + mid_out_logexpsum = torch.rand(batch_size, head_num, seq_block_num, dtype=torch.float32) + B_Seqlen = torch.randint(1, seq_block_num * 32, (batch_size,), dtype=torch.int32) + Out = torch.empty(batch_size, head_num, head_dim, dtype=torch.float32) + + self.input_tensors.append((mid_out, mid_out_logexpsum, B_Seqlen, Out)) + + def to_cuda(self, input_tensor): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + return (mid_out.cuda(), mid_out_logexpsum.cuda(), B_Seqlen.cuda(), Out.cuda()) + + def call_op(self, input_tensor): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + block_seq = 32 # Example block sequence size, adjust as needed + flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq) + return Out + + def get_gbps(self, input_tensor, runtime): + mid_out, mid_out_logexpsum, B_Seqlen, Out = input_tensor + total_bytes = mid_out.numel() * mid_out.element_size() + mid_out_logexpsum.numel() * mid_out_logexpsum.element_size() + B_Seqlen.numel() * B_Seqlen.element_size() + Out.numel() * Out.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + mid_out, _, _, _ = input_tensor + FLOPS = 2 * mid_out.numel() # Example calculation, adjust based on actual operations + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py new file mode 100644 index 0000000..0eb17f4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_bwd_perf.py @@ -0,0 +1,51 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from l2_norm_bwd import _l2_norm_bwd +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('l2_norm_bwd', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(4, 15): + size = 2 ** i + x = torch.rand(size, dtype=torch.float32) + dy = torch.rand(size, dtype=torch.float32) + self.input_tensors.append((x, dy)) + + def to_cuda(self, input_tensor): + x, dy = input_tensor + return x.cuda(), dy.cuda() + + def call_op(self, input_tensor): + x, dy = input_tensor + return _l2_norm_bwd(x, dy) + + def get_gbps(self, input_tensor, runtime): + x, dy = input_tensor + total_bytes = (x.numel() + dy.numel() + x.numel()) * x.element_size() # x, dy, and dx + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x, dy = input_tensor + # Assuming each element involves a few FLOPs, e.g., multiplication, addition + FLOPS = 2 * x.numel() # Simplified estimation + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py new file mode 100644 index 0000000..c568861 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/l2_norm_triton1_perf.py @@ -0,0 +1,75 @@ +import sys +import os +import json + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from l2_norm_triton1 import _l2_norm_fwd +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('l2_norm_triton1', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(4, 15): + size = 2 ** i + input_tensor = torch.rand(size, dtype=torch.float32) + self.input_tensors.append(input_tensor) + + def to_cuda(self, input_tensor): + return input_tensor.cuda() + + def call_op(self, input_tensor): + return _l2_norm_fwd(input_tensor) + + def get_gbps(self, input_tensor, runtime): + x = input_tensor + total_bytes = 2 * x.numel() * x.element_size() # Read and write + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x = input_tensor + FLOPS = 2 * x.numel() # Each element involves a multiply and an add + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [input_tensor.shape], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py new file mode 100644 index 0000000..8901aea --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/matrix_vector_multip_perf.py @@ -0,0 +1,52 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from matrix_vector_multip import mv +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('matrix_vector_multip', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 20): # Adjust the range for different sizes + M = 128 * i + N = 128 * (i - 1) # Example: N is half of M + matrix = torch.rand((N, M), dtype=torch.float32) + vector = torch.rand((M,), dtype=torch.float32) + self.input_tensors.append((matrix, vector)) + + def to_cuda(self, input_tensor): + matrix, vector = input_tensor + return (matrix.cuda(), vector.cuda()) + + def call_op(self, input_tensor): + matrix, vector = input_tensor + return mv(matrix, vector) + + def get_gbps(self, input_tensor, runtime): + matrix, vector = input_tensor + total_bytes = (matrix.numel() + vector.numel() + matrix.size(0)) * matrix.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + matrix, vector = input_tensor + N, M = matrix.shape + FLOPS = 2 * N * M # Each element in the output involves M multiplications and M-1 additions + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py new file mode 100644 index 0000000..a8ceb06 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/performance_utils.py @@ -0,0 +1,145 @@ +# Modifications Copyright(C)[2025] Advanced Micro Devices, Inc. All rights reserved. +# https://github.com/thunlp/TritonBench - Apache License 2.0 +import torch +import triton +import triton.language as tl + +from typing import Callable +import json +import os + +class do_bench_config(): + def __init__( + self, + warm_up=25, + repetition=100, + grad_to_none=None, + quantiles=[0.5, 0.8, 0.2], + return_mode="median" + ): + self.warm_up = warm_up + self.repetition = repetition + self.grad_to_none = grad_to_none + self.quantiles = quantiles + self.return_mode = return_mode + +class Performance_Metrics: + def __init__( + self, + op_name, + dtype=None, + is_backward=False, + **kwargs + ): + self.op_name = op_name + self.dtype = dtype + if is_backward: + self.op_name += 'backward' + self.kwargs = kwargs + + self.input_tensors = [] + self.do_bench_config = do_bench_config() + + def get_input_tensors(self): + raise NotImplementedError("You must implement this method to get input tensors") + + def to_cuda(self, input_tensor): + raise NotImplementedError("You must implement this method to get input tensors") + + def call_op(self, input_tensor): + raise NotImplementedError("You must implement this method to call the op") + + def get_do_bench_config(self, warmup=None, rep=None): + if warmup != None and rep != None: + self.do_bench_config = do_bench_config( + warm_up=warmup, + repetition=rep, + ) + return + + if self.input_tensors == []: + raise NotImplementedError("You must implement this method to get input_tensors") + + previous_ms = None + epsilon = 1e-4 + stable_count = 0 + max_stable_count = 3 + input_tensor = self.to_cuda(self.input_tensors[-1]) + + for t in range(1, 11): + warmup = 100 * t + rep = 1000 * t + + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: self.call_op(input_tensor), + warmup=warmup, + rep=rep, + quantiles=[0.5, 0.8, 0.2], + return_mode="median" + ) + + print("warmup time:", warmup, "rep time:", rep, "runtime:", ms) + + if previous_ms is not None: + relative_change = abs(ms - previous_ms) / abs(previous_ms) if previous_ms != 0 else float('inf') + + if relative_change < epsilon: + stable_count += 1 + else: + stable_count = 0 + + if stable_count >= max_stable_count: + print(f"MS stabilized with warmup={warmup} and rep={rep}") + self.do_bench_config = do_bench_config( + warm_up=warmup, + repetition=rep, + ) + return + + previous_ms = ms + + print("MS did not stabilize. Returning default config.") + raise NotImplementedError("You must implement this method to make the runtime stable") + + def get_runtime(self, op: Callable): + ms, min_ms, max_ms = triton.testing.do_bench( + op, + warmup=self.do_bench_config.warm_up, + rep=self.do_bench_config.repetition, + quantiles=self.do_bench_config.quantiles, + return_mode=self.do_bench_config.return_mode + ) + return ms + + def get_gbps(self, input_tensor, runtime): + raise NotImplementedError("You must implement this method to get the method to calculate GBPS") + + def get_tflops(self, input_tensor, runtime): + raise NotImplementedError("You must implement this method to get the method to calculate TFLOPS") + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [item.shape if type(item)==torch.Tensor else item for item in input_tensor], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py new file mode 100644 index 0000000..97d9cde --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/rotary_transform_perf.py @@ -0,0 +1,56 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from rotary_transform import apply_rotary +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('rotary_transform', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 16): # Choose a reasonable range for testing + batch_size = 2 ** i + seqlen = 128 # Fixed sequence length + nheads = 8 # Number of attention heads + headdim = 64 # Dimension of each head + rotary_dim = 32 # Rotary dimension + x = torch.rand(batch_size, seqlen, nheads, headdim, dtype=torch.float32) + cos = torch.rand(seqlen, rotary_dim // 2, dtype=torch.float32) + sin = torch.rand(seqlen, rotary_dim // 2, dtype=torch.float32) + self.input_tensors.append((x, cos, sin)) + + def to_cuda(self, input_tensor): + x, cos, sin = input_tensor + return (x.cuda(), cos.cuda(), sin.cuda()) + + def call_op(self, input_tensor): + x, cos, sin = input_tensor + return apply_rotary(x, cos, sin) + + def get_gbps(self, input_tensor, runtime): + x, cos, sin = input_tensor + total_bytes = x.numel() * x.element_size() + cos.numel() * cos.element_size() + sin.numel() * sin.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x, cos, sin = input_tensor + # Assuming each element in x is involved in a few operations (e.g., multiply and add) + FLOPS = 2 * x.numel() # Simplified estimation + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py new file mode 100644 index 0000000..d6eeda2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/sin_kernel_perf.py @@ -0,0 +1,75 @@ +import sys +import os +import json +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# Correctly import the kernel function +from sin_kernel import call_kernel +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('sin_kernel', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(12, 28): + size = 2 ** i + input_tensor = torch.rand(size, dtype=torch.float32) + self.input_tensors.append(input_tensor) + + def to_cuda(self, input_tensor): + return input_tensor.cuda() + + def call_op(self, input_tensor): + return call_kernel(input_tensor) + + def get_gbps(self, input_tensor, runtime): + x = input_tensor + total_bytes = 2 * x.numel() * x.element_size() # Read and write + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + x = input_tensor + FLOPS = x.numel() # One sin operation per element + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + + def run_benchmark(self): + results = [] + for input_tensor_ in self.input_tensors: + try: + input_tensor = self.to_cuda(input_tensor_) + # print(input_tensor) + op = lambda : self.call_op(input_tensor) + ms = self.get_runtime(op) + gbps = self.get_gbps(input_tensor, ms) + tflops = self.get_tflops(input_tensor, ms) + result = { + "input_size": [input_tensor.shape], + "ms": ms, + "GB/s": gbps, + "TFLOPS": tflops + } + print(result) + results.append(result) + except Exception as e: + print(f"Failed to run benchmark for input tensor. Error: {e}") + input_tensor = None + folder_path = "/workspace/reflexion_oneshot_tritonbench_4/exec/gen_perf" + file_name = self.op_name + ".json" + file_path = os.path.join(folder_path, file_name) + with open(file_path, 'w', encoding='utf8') as f: + json.dump(results, f, indent=4) + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py new file mode 100644 index 0000000..6044dc4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/tmp/triton_matmul_perf.py @@ -0,0 +1,54 @@ +import sys +import os + +sys.path.append('/workspace/reflexion_oneshot_tritonbench_4/exec') +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from triton_matmul import matmul # Correctly import the matmul function +from performance_utils import Performance_Metrics, do_bench_config + +import torch +import triton +import triton.language as tl + +class performance_metrics(Performance_Metrics): + def __init__(self, dtype=None, is_backward=False, **kwargs): + super().__init__('triton_matmul', dtype=dtype, is_backward=is_backward, **kwargs) + + def get_input_tensors(self): + self.input_tensors = [] + for i in range(2, 33): # Define a range for matrix sizes + M = N = K = 128 * i + a = torch.rand((M, K), dtype=torch.float16) # Use float16 for compatibility + b = torch.rand((K, N), dtype=torch.float16) + self.input_tensors.append((a, b)) + + def to_cuda(self, input_tensor): + a, b = input_tensor + return (a.cuda(), b.cuda()) + + def call_op(self, input_tensor): + a, b = input_tensor + return matmul(a, b) + + def get_gbps(self, input_tensor, runtime): + a, b = input_tensor + M, K = a.shape + K, N = b.shape + total_bytes = (M * K + K * N + M * N) * a.element_size() + GBPS = total_bytes / (runtime / 1000) / 1e9 + return GBPS + + def get_tflops(self, input_tensor, runtime): + a, b = input_tensor + M, K = a.shape + K, N = b.shape + FLOPS = 2 * M * N * K + TFLOPS = FLOPS / (runtime / 1000) / 1e12 + return TFLOPS + +if __name__ == '__main__': + op_perf = performance_metrics() + op_perf.get_input_tensors() + op_perf.get_do_bench_config(warmup=100, rep=1000) + op_perf.run_benchmark() diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json new file mode 100644 index 0000000..12df2d7 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul.json @@ -0,0 +1,467 @@ +[ + { + "input_size": [ + [ + 256, + 256 + ], + [ + 256, + 256 + ] + ], + "ms": 0.009019999764859676, + "GB/s": 43.59379271071602, + "TFLOPS": 3.720003644647768 + }, + { + "input_size": [ + [ + 384, + 384 + ], + [ + 384, + 384 + ] + ], + "ms": 0.010223000310361385, + "GB/s": 86.54367339725967, + "TFLOPS": 11.077590194849238 + }, + { + "input_size": [ + [ + 512, + 512 + ], + [ + 512, + 512 + ] + ], + "ms": 0.012949000112712383, + "GB/s": 121.46605809786634, + "TFLOPS": 20.73020724870252 + }, + { + "input_size": [ + [ + 640, + 640 + ], + [ + 640, + 640 + ] + ], + "ms": 0.014112000353634357, + "GB/s": 174.14965549990777, + "TFLOPS": 37.151926506646994 + }, + { + "input_size": [ + [ + 768, + 768 + ], + [ + 768, + 768 + ] + ], + "ms": 0.01720000058412552, + "GB/s": 205.7525511520165, + "TFLOPS": 52.67265309491623 + }, + { + "input_size": [ + [ + 896, + 896 + ], + [ + 896, + 896 + ] + ], + "ms": 0.018681999295949936, + "GB/s": 257.836215690483, + "TFLOPS": 77.00708308622427 + }, + { + "input_size": [ + [ + 1024, + 1024 + ], + [ + 1024, + 1024 + ] + ], + "ms": 0.02353299967944622, + "GB/s": 267.3461133599119, + "TFLOPS": 91.25414002684992 + }, + { + "input_size": [ + [ + 1152, + 1152 + ], + [ + 1152, + 1152 + ] + ], + "ms": 0.0257789995521307, + "GB/s": 308.8802567336973, + "TFLOPS": 118.61001858573978 + }, + { + "input_size": [ + [ + 1280, + 1280 + ], + [ + 1280, + 1280 + ] + ], + "ms": 0.029627999290823936, + "GB/s": 331.7942566255078, + "TFLOPS": 141.56554949355 + }, + { + "input_size": [ + [ + 1408, + 1408 + ], + [ + 1408, + 1408 + ] + ], + "ms": 0.030910000205039978, + "GB/s": 384.8199262729385, + "TFLOPS": 180.60881873076582 + }, + { + "input_size": [ + [ + 1536, + 1536 + ], + [ + 1536, + 1536 + ] + ], + "ms": 0.03672400116920471, + "GB/s": 385.4638805498806, + "TFLOPS": 197.35750684153888 + }, + { + "input_size": [ + [ + 1664, + 1664 + ], + [ + 1664, + 1664 + ] + ], + "ms": 0.05195799842476845, + "GB/s": 319.7462662857386, + "TFLOPS": 177.352595699823 + }, + { + "input_size": [ + [ + 1792, + 1792 + ], + [ + 1792, + 1792 + ] + ], + "ms": 0.06085899844765663, + "GB/s": 316.59383971906124, + "TFLOPS": 189.1120535921859 + }, + { + "input_size": [ + [ + 1920, + 1920 + ], + [ + 1920, + 1920 + ] + ], + "ms": 0.06326399743556976, + "GB/s": 349.6206515012925, + "TFLOPS": 223.7572169608272 + }, + { + "input_size": [ + [ + 2048, + 2048 + ], + [ + 2048, + 2048 + ] + ], + "ms": 0.07392799854278564, + "GB/s": 340.40991905705846, + "TFLOPS": 232.38650474295193 + }, + { + "input_size": [ + [ + 2176, + 2176 + ], + [ + 2176, + 2176 + ] + ], + "ms": 0.07533200085163116, + "GB/s": 377.1286528809203, + "TFLOPS": 273.54398288962756 + }, + { + "input_size": [ + [ + 2304, + 2304 + ], + [ + 2304, + 2304 + ] + ], + "ms": 0.14252400398254395, + "GB/s": 223.47460855717318, + "TFLOPS": 171.628499371909 + }, + { + "input_size": [ + [ + 2432, + 2432 + ], + [ + 2432, + 2432 + ] + ], + "ms": 0.14476899802684784, + "GB/s": 245.13358856996922, + "TFLOPS": 198.72162913405504 + }, + { + "input_size": [ + [ + 2560, + 2560 + ], + [ + 2560, + 2560 + ] + ], + "ms": 0.16501599550247192, + "GB/s": 238.28962689505434, + "TFLOPS": 203.34048161711303 + }, + { + "input_size": [ + [ + 2688, + 2688 + ], + [ + 2688, + 2688 + ] + ], + "ms": 0.1670999974012375, + "GB/s": 259.4378496362499, + "TFLOPS": 232.45631327407992 + }, + { + "input_size": [ + [ + 2816, + 2816 + ], + [ + 2816, + 2816 + ] + ], + "ms": 0.18959100544452667, + "GB/s": 250.95671542245927, + "TFLOPS": 235.5647035432151 + }, + { + "input_size": [ + [ + 2944, + 2944 + ], + [ + 2944, + 2944 + ] + ], + "ms": 0.18975099921226501, + "GB/s": 274.05819318941786, + "TFLOPS": 268.94244024988205 + }, + { + "input_size": [ + [ + 3072, + 3072 + ], + [ + 3072, + 3072 + ] + ], + "ms": 0.21536900103092194, + "GB/s": 262.9120427218319, + "TFLOPS": 269.22193174715585 + }, + { + "input_size": [ + [ + 3200, + 3200 + ], + [ + 3200, + 3200 + ] + ], + "ms": 0.297435998916626, + "GB/s": 206.56544676430437, + "TFLOPS": 220.33647654859135 + }, + { + "input_size": [ + [ + 3328, + 3328 + ], + [ + 3328, + 3328 + ] + ], + "ms": 0.3247379958629608, + "GB/s": 204.6372917447065, + "TFLOPS": 227.01096897546105 + }, + { + "input_size": [ + [ + 3456, + 3456 + ], + [ + 3456, + 3456 + ] + ], + "ms": 0.32375550270080566, + "GB/s": 221.35103620532738, + "TFLOPS": 254.9963937085371 + }, + { + "input_size": [ + [ + 3584, + 3584 + ], + [ + 3584, + 3584 + ] + ], + "ms": 0.35669100284576416, + "GB/s": 216.0703112360975, + "TFLOPS": 258.1319984900578 + }, + { + "input_size": [ + [ + 3712, + 3712 + ], + [ + 3712, + 3712 + ] + ], + "ms": 0.36759498715400696, + "GB/s": 224.90422037600632, + "TFLOPS": 278.2814886785785 + }, + { + "input_size": [ + [ + 3840, + 3840 + ], + [ + 3840, + 3840 + ] + ], + "ms": 0.4137600064277649, + "GB/s": 213.82830294267677, + "TFLOPS": 273.70022776662626 + }, + { + "input_size": [ + [ + 3968, + 3968 + ], + [ + 3968, + 3968 + ] + ], + "ms": 0.4952859878540039, + "GB/s": 190.73857592726222, + "TFLOPS": 252.2835564264588 + }, + { + "input_size": [ + [ + 4096, + 4096 + ], + [ + 4096, + 4096 + ] + ], + "ms": 0.5376620292663574, + "GB/s": 187.22411202694673, + "TFLOPS": 255.62332095412455 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json new file mode 100644 index 0000000..953bb99 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/gen_perf/triton_matmul_perf_data.json @@ -0,0 +1,6 @@ +{ + "triton_matmul.json": { + "ms": 3.5604, + "efficiency": 89.1928 + } +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py b/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py new file mode 100644 index 0000000..8b39a49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/l2_norm_bwd.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + dx = torch.empty_like(x) + grid = (M,) + _l2_norm_bwd_kernel[grid](x, dy, dx, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py b/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py new file mode 100644 index 0000000..efc3d1a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/l2_norm_triton1.py @@ -0,0 +1,98 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(data * data) + rstd = tl.rsqrt(var + eps) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + y = data * rstd + tl.store(Y + base + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-05): + x = x if x.is_contiguous() else x.contiguous() + original_shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + element_size = x.element_size() + BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N)) + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.view(original_shape) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py b/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py new file mode 100644 index 0000000..0f93fd4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/matrix_vector_multip.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr, stride_bk: tl.constexpr, stride_cm: tl.constexpr, N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + start_n = pid_n * BLOCK_N + offsets_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offsets_n < N + accum = tl.zeros((BLOCK_N,), dtype=tl.float32) + for start_m in range(0, M, BLOCK_M): + offsets_m = start_m + tl.arange(0, BLOCK_M) + mask_m = offsets_m < M + offs_a = A + (offsets_n[:, None] * stride_am + offsets_m[None, :] * stride_ak) + tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0) + offs_b = B + offsets_m * stride_bk + seg_b = tl.load(offs_b, mask=mask_m, other=0.0) + accum += tl.sum(tile_a * seg_b[None, :], axis=1) + offs_c = C + offsets_n * stride_cm + tl.store(offs_c, accum, mask=mask_n) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dtype == B.dtype + assert A.size(1) == B.size(0) + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + BLOCK_N = 128 + BLOCK_M = 64 + grid = (triton.cdiv(N, BLOCK_N),) + mv_kernel[grid](A, B, C, stride_am=A.stride(0), stride_ak=A.stride(1), stride_bk=B.stride(0), stride_cm=C.stride(0), N=N, M=M, BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt b/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt new file mode 100644 index 0000000..63dd0da --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/performance_analysis.txt @@ -0,0 +1,3 @@ +Performance analysis for /workspace/reflexion_oneshot_tritonbench_4/exec: +Error processing l2_norm_triton1.json, skipping... +Error processing sin_kernel.json, skipping... \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py b/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py new file mode 100644 index 0000000..7212c74 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/rotary_transform.py @@ -0,0 +1,207 @@ +import torch +import triton +import triton.language as tl +import logging +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + end_idx = tl.load(CU_SEQLENS + pid_batch + 1) + cur_seqlen = end_idx - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < cur_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + mask_m_cs = rm_cs < seqlen_ro + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_offs = x0_offs + rotary_dim_half * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_offs = out0_offs + rotary_dim_half * stride_out_headdim + tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + rk_repeat = rk // 2 + cs_mask = rk_repeat[None, :] < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32) + rk_swap = rk + (rk + 1) % 2 * 2 - 1 + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'If cu_seqlens is passed in, then max_seqlen must be passed' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + assert rotary_dim <= headdim + rotary_dim = rotary_dim * 2 + assert headdim <= 256, 'Only support headdim <= 256' + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + assert cos.dtype == sin.dtype + assert x.dtype == cos.dtype + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + if interleaved: + BLOCK_M = 4 + else: + BLOCK_M = 8 if rotary_dim <= 64 else 4 + BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py b/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py new file mode 100644 index 0000000..786f954 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/sin_kernel.py @@ -0,0 +1,87 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py b/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py new file mode 100644 index 0000000..a59f8d1 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/exec/triton_matmul.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K) + b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2, 'Input a must be 2-D' + assert b.dim() == 2, 'Input b must be 2-D' + assert a.shape[1] == b.shape[0], 'Incompatible dimensions' + M, K = a.shape + K2, N = b.shape + assert K == K2, 'Inner dimensions must match' + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1)) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json b/reflexion_oneshot_tritonbench_4/out.json new file mode 100644 index 0000000..455a27c --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json @@ -0,0 +1,42 @@ +2025-08-23_16-23-45 => File: matrix_vector_multip.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-24-10 => File: triton_matmul.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-24-19 => File: embedding_triton_kernel.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 158, in + result_gold = test_embedding() + ^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 88, in test_embedding + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py", line 29, in embedding + out = out.view(B * SEQ_LEN, d_model) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: shape '[512000, 128]' is invalid for input of size 65536 +2025-08-23_16-24-28 => File: int4_matmul.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 160, in + result_gold = test_correct_int4_s2() + ^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 144, in test_correct_int4_s2 + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py", line 54, in matmul_dequantize_int4_s2 + assert K == qweight.shape[0] * 8 + ^^^^^^^^^^^^^^^^^^^^^^^^^ +AssertionError +2025-08-23_16-24-38 => File: flash_decode2_phi.py, Call Status: False, Exec Status: False, difficulty: -1, stderr: Traceback (most recent call last): + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 157, in + result_gold = test_flash_decode_stage2() + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 145, in test_flash_decode_stage2 + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py", line 34, in flash_decode_stage2 + assert Mid_O.dim() == 4 + ^^^^^^^^^^^^^^^^ +AssertionError +2025-08-23_16-24-57 => File: matrix_transpose.py, Call Status: True, Exec Status: False, difficulty: -1, stderr: Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py +2025-08-23_16-25-17 => File: rotary_transform.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-25-37 => File: sin_kernel.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-25-56 => File: l2_norm_bwd.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-26-15 => File: l2_norm_triton1.py, Call Status: True, Exec Status: True, difficulty: -1, stderr: None +2025-08-23_16-26-15 => File: /workspace/reflexion_oneshot_tritonbench_4.json, Call Accuracy: 0.7, Exec Accuracy: 0.6 diff --git a/reflexion_oneshot_tritonbench_4/out.json_all_passes.json b/reflexion_oneshot_tritonbench_4/out.json_all_passes.json new file mode 100644 index 0000000..31186e0 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_all_passes.json @@ -0,0 +1,92 @@ +[ + { + "pass_num": 0, + "file_name": "matrix_vector_multip.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "triton_matmul.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "embedding_triton_kernel.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 158, in \n result_gold = test_embedding()\n ^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 88, in test_embedding\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 29, in embedding\n out = out.view(B * SEQ_LEN, d_model)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRuntimeError: shape '[512000, 128]' is invalid for input of size 65536", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "int4_matmul.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 160, in \n result_gold = test_correct_int4_s2()\n ^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 144, in test_correct_int4_s2\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 54, in matmul_dequantize_int4_s2\n assert K == qweight.shape[0] * 8\n ^^^^^^^^^^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "flash_decode2_phi.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 157, in \n result_gold = test_flash_decode_stage2()\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 145, in test_flash_decode_stage2\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n File \"/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py\", line 116, in decorate_context\n return func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 34, in flash_decode_stage2\n assert Mid_O.dim() == 4\n ^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "matrix_transpose.py", + "call_status": 1, + "exec_status": 0, + "stdout": "None", + "stderr": "Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "rotary_transform.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "sin_kernel.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_bwd.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_triton1.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json_perf_0.json b/reflexion_oneshot_tritonbench_4/out.json_perf_0.json new file mode 100644 index 0000000..29d17a4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_perf_0.json @@ -0,0 +1,16 @@ +{ + "speed_up": [ + 3.6883, + 2.0181, + 5.4456, + 3.5604, + 0.3142 + ], + "efficiency": [ + 99.617, + 1.44, + 74.2022, + 89.1928, + 7.0489 + ] +} \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.json_results_0.json b/reflexion_oneshot_tritonbench_4/out.json_results_0.json new file mode 100644 index 0000000..31186e0 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.json_results_0.json @@ -0,0 +1,92 @@ +[ + { + "pass_num": 0, + "file_name": "matrix_vector_multip.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "triton_matmul.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "embedding_triton_kernel.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 158, in \n result_gold = test_embedding()\n ^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 88, in test_embedding\n embedding(input_ids, weight, vob_start_id, vob_end_id, out)\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py\", line 29, in embedding\n out = out.view(B * SEQ_LEN, d_model)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nRuntimeError: shape '[512000, 128]' is invalid for input of size 65536", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "int4_matmul.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 160, in \n result_gold = test_correct_int4_s2()\n ^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 144, in test_correct_int4_s2\n triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py\", line 54, in matmul_dequantize_int4_s2\n assert K == qweight.shape[0] * 8\n ^^^^^^^^^^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "flash_decode2_phi.py", + "call_status": 0, + "exec_status": 0, + "stdout": "", + "stderr": "Traceback (most recent call last):\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 157, in \n result_gold = test_flash_decode_stage2()\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 145, in test_flash_decode_stage2\n flash_decode_stage2(test_case[\"mid_out\"], test_case[\"mid_out_logexpsum\"], test_case[\"B_Seqlen\"], test_case[\"Out\"], test_case[\"block_seq\"])\n File \"/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py\", line 116, in decorate_context\n return func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^\n File \"/workspace/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py\", line 34, in flash_decode_stage2\n assert Mid_O.dim() == 4\n ^^^^^^^^^^^^^^^^\nAssertionError", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "matrix_transpose.py", + "call_status": 1, + "exec_status": 0, + "stdout": "None", + "stderr": "Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "rotary_transform.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "sin_kernel.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_bwd.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + }, + { + "pass_num": 0, + "file_name": "l2_norm_triton1.py", + "call_status": 1, + "exec_status": 1, + "stdout": "None", + "stderr": "None", + "difficulty": -1 + } +] \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt b/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt new file mode 100644 index 0000000..c0b21ee --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/out.jsonpassk.txt @@ -0,0 +1,4 @@ +Call Accuracy: 70.0 +Exec Accuracy: 60.0 +Call Accuracy: 70.0 +Exec Accuracy: 60.0 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py new file mode 100644 index 0000000..a2d3d9f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_296397.py @@ -0,0 +1,158 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weight, out, indices, vocab_size, d_model, vob_start_id, stride_out_0, stride_weight_0, BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + for k in range(0, BLOCK_NN, BLOCK_N): + seq_off = pid * BLOCK_NN + k + tl.arange(0, BLOCK_N) + mask_seq = seq_off < out.shape[0] + token_idx_raw = tl.load(indices + seq_off, mask=mask_seq, other=0).to(tl.int32) + token_idx = token_idx_raw - vob_start_id + token_idx = tl.where(token_idx >= 0, token_idx, 0) + token_idx = tl.where(token_idx < vocab_size, token_idx, vocab_size - 1) + w_offs = token_idx[:, None] * d_model + offs_d[None, :] + vec = tl.load(weight + w_offs) + o_offs = seq_off[:, None] * stride_out_0 + offs_d[None, :] + tl.store(out + o_offs, vec, mask=mask_seq[:, None]) + +def embedding(weight: torch.Tensor, indices: torch.Tensor, vob_start_id: int, vob_end_id: int, out: torch.Tensor=None) -> torch.Tensor: + shape_2d = indices.shape + B, SEQ_LEN = shape_2d + d_model = weight.size(-1) + weight = weight.view(-1, d_model) + if out is None: + out = torch.empty((B * SEQ_LEN, d_model), dtype=weight.dtype, device=weight.device) + else: + out = out.view(B * SEQ_LEN, d_model) + vocab_sz = vob_end_id - vob_start_id if vob_end_id > vob_start_id else weight.size(0) + BLOCK_N = 16 + BLOCK_NN = 64 + BLOCK_DMODEL = triton.next_power_of_2(d_model) + grid = lambda META: (triton.cdiv(out.shape[0], META['BLOCK_NN']),) + embedding_kernel[grid](weight, out, indices.view(-1), vocab_size=vocab_sz, d_model=d_model, vob_start_id=vob_start_id, stride_out_0=out.stride(0), stride_weight_0=weight.stride(0), BLOCK_N=BLOCK_N, BLOCK_NN=BLOCK_NN, BLOCK_DMODEL=BLOCK_DMODEL) + return out.view(B, SEQ_LEN, d_model) + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py new file mode 100644 index 0000000..96f48fb --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/embedding_triton_kernel.py_gen_triton_code_97456.py @@ -0,0 +1,163 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weights_ptr, indices_ptr, out_ptr, v_stride_0, v_stride_1, out_stride_0, out_stride_1, vocab_size: tl.constexpr, hidden_size: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, NUM_SEQ: tl.constexpr, MAX_LEN: tl.constexpr): + seq_id = tl.program_id(0) + if seq_id >= NUM_SEQ: + return + len_offsets = tl.program_id(1) + offset_start = len_offsets * BLOCK_NN + dim_offsets = tl.arange(0, BLOCK_SIZE) + indices_base = indices_ptr + seq_id * out_stride_0 + out_base = out_ptr + seq_id * out_stride_0 + for i in range(0, BLOCK_NN, BLOCK_N): + cur_offset = offset_start + i + cur_mask = cur_offset < MAX_LEN + token_id = tl.load(indices_base + cur_offset * out_stride_1, mask=cur_mask, other=0) + token_mask = token_id < vocab_size + valid_mask = cur_mask & token_mask + selector = token_id[:, None] * v_stride_0 + dim_offsets[None, :] * v_stride_1 + w_ptrs = weights_ptr + selector + vec = tl.load(w_ptrs, mask=valid_mask[:, None], other=0.0) + out_ptrs = out_base + cur_offset[:, None] * out_stride_1 + dim_offsets[None, :] + tl.store(out_ptrs, vec, mask=valid_mask[:, None]) + +def embedding(weights: torch.Tensor, indices: torch.Tensor, vob_start_id=None, vob_end_id=None, out=None) -> torch.Tensor: + assert weights.dim() == 2, 'weights must be 2-D [vocab_size, hidden_size]' + vocab_size, hidden_size = weights.shape + assert indices.dim() == 2, 'indices must be 2-D: [batch, seq]' + batch, seq_len = indices.shape + if out is None: + out = torch.empty((batch, seq_len, hidden_size), dtype=weights.dtype, device=weights.device) + else: + assert out.shape == (batch, seq_len, hidden_size), f'out tensor shape mismatch: expected {(batch, seq_len, hidden_size)}, got {out.shape}' + BLOCK_SIZE = triton.next_power_of_2(hidden_size) + BLOCK_N = 1 + BLOCK_NN = triton.next_power_of_2(seq_len) if seq_len > 0 else 1 + grid = (batch, triton.cdiv(seq_len, BLOCK_NN)) + embedding_kernel[grid](weights, indices, out, weights.stride(0), weights.stride(1), out.stride(0), out.stride(1), vocab_size=vocab_size, hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE, BLOCK_N=BLOCK_N, BLOCK_NN=BLOCK_NN, NUM_SEQ=batch, MAX_LEN=seq_len) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py new file mode 100644 index 0000000..9ccee70 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py @@ -0,0 +1,146 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_obs, stride_oh, stride_od, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d * stride_mid_od + offs_l = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_id in range(0, block_n_size): + tv = tl.load(Mid_O + offs_v + block_id * stride_mid_os) + l_cur = tl.load(Mid_O_LogExpSum + offs_l + block_id * stride_mid_o_es) + new_max = tl.maximum(max_logic, l_cur) + old_scale = tl.exp(max_logic - new_max) + acc *= old_scale + sum_exp *= old_scale + cur_exp = tl.exp(l_cur - new_max) + acc += tv * cur_exp + sum_exp += cur_exp + max_logic = new_max + tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od, (acc / sum_exp).to(Out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(mid_out: torch.Tensor, mid_out_logexpsum: torch.Tensor, B_Seqlen: torch.Tensor, out: torch.Tensor, block_seq: int, BLOCK_DMODEL: int=None) -> None: + if BLOCK_DMODEL is None: + BLOCK_DMODEL = out.shape[-1] + batch, head_num = out.shape[:2] + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, mid_out, mid_out_logexpsum, out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), out.stride(0), out.stride(1), out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_473638.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py new file mode 100644 index 0000000..721898a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/flash_decode2_phi.py_gen_triton_code_577640.py @@ -0,0 +1,157 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_out_b, stride_out_h, stride_out_d, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for block_seq_id in range(0, block_n_size): + ptr_v = Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_seq_id * stride_mid_os + offs_d * stride_mid_od + tv = tl.load(ptr_v).to(tl.float32) + ptr_logic = Mid_O_LogExpSum + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + block_seq_id * stride_mid_o_es + tlogic = tl.load(ptr_logic).to(tl.float32) + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + delta = tl.exp(tlogic - new_max) + acc += delta * tv + sum_exp += delta + max_logic = new_max + acc = acc / sum_exp + ptr_out = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offs_d * stride_out_d + tl.store(ptr_out, acc.to(Out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(B_Seqlen: torch.Tensor, Mid_O: torch.Tensor, Mid_O_LogExpSum: torch.Tensor, Out: torch.Tensor, BLOCK_SEQ: int=64): + assert Mid_O.dim() == 4 + assert Mid_O_LogExpSum.dim() == 3 + assert B_Seqlen.dim() == 1 + assert Out.dim() == 3 + assert Mid_O.shape[0] == Out.shape[0] + assert Mid_O.shape[1] == Out.shape[1] + assert Mid_O.shape[3] == Out.shape[2] + assert Mid_O.shape[0] == Mid_O_LogExpSum.shape[0] + assert Mid_O.shape[1] == Mid_O_LogExpSum.shape[1] + assert Mid_O.shape[2] == Mid_O_LogExpSum.shape[2] + BLOCK_DMODEL = Out.shape[-1] + batch, head_num = (Out.shape[0], Out.shape[1]) + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3), Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2), Out.stride(0), Out.stride(1), Out.stride(2), BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py new file mode 100644 index 0000000..9eb06c3 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_126776.py @@ -0,0 +1,160 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scale_g, stride_scale_n, stride_zp_g, stride_zp_n, group_size, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + start_k = pid_k * BLOCK_SIZE_K + offs_k = start_k + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] // 8 * stride_bk + offs_n[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_curr = k_step * BLOCK_SIZE_K * SPLIT_K + pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = k_curr[None, :] < K + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + b_raw = tl.load(b_ptrs, mask=mask_k, other=0) + group_idx = k_curr[:, None] // group_size + shift = k_curr[:, None] % 8 * 4 + bits = b_raw >> shift & 15 + scales = tl.load(scales_ptr + group_idx * stride_scale_g + offs_n[None, :] * stride_scale_n, mask=mask_k, other=0.0) + zeros = tl.load(zeros_ptr + group_idx * stride_zp_g + offs_n[None, :] // 8 * stride_zp_n, mask=mask_k, other=0.0) + zeros_bits = zeros >> offs_n[None, :] % 8 * 4 & 15 + b_deq = (bits - zeros_bits) * scales + accumulator += tl.dot(a, b_deq) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K // 8 * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + out_vals = accumulator.to(c_ptr.dtype.element_ty) + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, out_vals, mask=mask_c) + else: + tl.store(c_ptrs, out_vals, mask=mask_c) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.FloatTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, split_k: int=1) -> torch.FloatTensor: + assert x.dim() == 2 and qweight.dim() == 2 and (scales.dim() == 2) and (zeros.dim() == 2) + M, K = x.shape + assert K == qweight.shape[0] * 8 + N = qweight.shape[1] + assert scales.shape == (K // scales.shape[0], N) + assert zeros.shape == (K // zeros.shape[0], N) + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), split_k) + matmul_kernel[grid](x, qweight, output, scales, zeros, M, N, K, x.stride(0), x.stride(1), qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), scales.shape[0]) + return output + +def quantize_int4(w: torch.Tensor, group_size: int=128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + w: [OC, K] (fp16/fp32 weights) + returns: (qpacked, scales, zerospacked, group_size) + qpacked : int32, shape [OC, K//8] + scales : fp16/fp32 [OC, K//group_size] + zerospacked : int32 [K//group_size, OC//8] + """ + assert w.dim() == 2 + OC, K = w.shape + assert K % group_size == 0 + w = w.view(-1, K) + OC_total, K_ = w.shape + groups_per_row = K_ // group_size + w = w.view(OC_total, groups_per_row, group_size) + w_min = w.min(dim=2, keepdim=True).values + w_max = w.max(dim=2, keepdim=True).values + scales = (w_max - w_min) / 15.0 + zeros = torch.round(-w_min / scales).clamp(0, 15).to(torch.int32) + quant = torch.clamp(torch.round(w / scales + zeros), 0, 15).to(torch.int32) + OC_pack = OC_total // 8 + zeros_pack = zeros.permute(1, 0, 2).contiguous().view(groups_per_row, OC_pack, 8) + zeros_packed = zeros_pack[..., 0] + for p in range(1, 8): + zeros_packed |= zeros_pack[..., p] << p * 4 + zeros_packed = zeros_packed.view(groups_per_row, OC_pack).contiguous() + K_pack = K_ // 8 + quant = quant.view(OC_total, K_) + qpacked = torch.empty((OC_total, K_pack), dtype=torch.int32, device=w.device) + for p in range(8): + qpacked |= quant[:, p::8] << p * 4 + qpacked = qpacked.contiguous() + scales = scales.view(OC_total, groups_per_row).contiguous() + return (qpacked, scales, zeros_packed, group_size) + +def unpack_int4(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int=128) -> torch.Tensor: + """ + qweight : int32 [OC, K//8] + scales : [OC, K//group_size] + zeros : int32 [K//group_size, OC//8] + returns : fp16/fp32 tensor [OC, K] + """ + OC, K8 = qweight.shape + K = K8 * 8 + group_dim = K // group_size + assert scales.shape == (OC, group_dim) + assert zeros.shape == (group_dim, OC // 8) + quant = torch.empty((OC, K), dtype=torch.int32, device=qweight.device) + for p in range(8): + mask = 15 << p * 4 + quant[:, p::8] = (qweight & mask) >> p * 4 + OC8 = OC // 8 + zeros_ext = torch.empty((group_dim, OC), dtype=torch.int32, device=zeros.device) + for p in range(8): + mask = 15 << p * 4 + zeros_ext[:, p::8] = (zeros & mask) >> p * 4 + zeros_ext = zeros_ext.permute(1, 0).contiguous() + scales_mat = scales.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K) + zeros_mat = zeros_ext.view(OC, group_dim).unsqueeze(-1).expand(-1, -1, group_size).reshape(OC, K) + return (quant.float() - zeros_mat) * scales_mat + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py new file mode 100644 index 0000000..5dbe6fd --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/int4_matmul.py_gen_triton_code_506673.py @@ -0,0 +1,186 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _pack_int4x2(x0, x1): + x0 = x0 & 15 + x1 = x1 & 15 + return x1 | x0 << 4 + +@triton.jit +def _unpack_int4x2(b): + low = b >> 0 & 15 + high = b >> 4 & 15 + return (low, high) + +@triton.jit +def matmul_kernel(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K2: tl.constexpr, SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_k = tl.cdiv(K, BLOCK_K2 * 2) + grid_m = pid // num_pid_n + grid_n = pid % num_pid_n + rm = grid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = grid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk_packed = pid_z * BLOCK_K2 + tl.arange(0, BLOCK_K2) + rk = rk_packed * 2 + a_ptrs = a_ptr + rm[:, None] * stride_a_m + rk[None, :] * stride_a_k + a_mask = (rm[:, None] < M) & (rk[None, :] < K) + A = tl.load(a_ptrs, mask=a_mask, other=0.0) + q_ptrs = q_ptr + (rk_packed[:, None] * stride_q_k2 + rn[None, :] * stride_q_n) + packed = tl.load(q_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0) + lo, hi = _unpack_int4x2(packed) + q_vals = tl.interleave(lo, hi) + group_idx = rk[:, None] // group_size + s_ptrs = scales_ptr + group_idx * stride_s_g + rn[None, :] * stride_s_n + z_ptrs = zeros_ptr + group_idx * stride_z_g + rn[None, :] * stride_z_n + scales = tl.load(s_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + zeros = tl.load(z_ptrs, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + b_fp = (q_vals - zeros) * scales + acc = tl.dot(A, b_fp) + c_offs = c_ptr + rm[:, None] * stride_c_m + rn[None, :] * stride_c_n + c_mask = (rm < M)[:, None] & (rn < N)[None, :] + if SPLIT_K > 1: + tl.atomic_add(c_offs, acc, mask=c_mask) + else: + tl.store(c_offs, acc, mask=c_mask) + +@triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K2': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K2': 32, 'SPLIT_K': 2}, num_stages=4, num_warps=8)], key=['M', 'N', 'K']) +@triton.jit +def matmul_dequantize_int4_s2(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K2: tl.constexpr, SPLIT_K: tl.constexpr): + matmul_kernel(q_ptr, a_ptr, c_ptr, scales_ptr, zeros_ptr, M, N, K, stride_a_m, stride_a_k, stride_q_k2, stride_q_n, stride_s_g, stride_s_n, stride_z_g, stride_z_n, stride_c_m, stride_c_n, group_size, BLOCK_M, BLOCK_N, BLOCK_K2, SPLIT_K) + +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int=128) -> torch.Tensor: + """ + Python-launchable GEMM with INT4 quantized weights. + + Memory layout expected: + qweight – [ K//2 , N ] int32 column-major + scales – [ G , N ] float column-major, G = K // group_size + qzeros – same shape as scales + x – [ M , K ] fp16/fp32 row-major + Output: + c – [ M , N ] fp32 + """ + x = x.contiguous() + qweight = qweight.contiguous() + scales = scales.contiguous() + qzeros = qzeros.contiguous() + M, K = x.shape + N = qweight.size(1) + device = x.device + out = torch.empty((M, N), dtype=torch.float32, device=device) + + def grid(meta): + grid_m = triton.cdiv(M, meta['BLOCK_M']) + grid_n = triton.cdiv(N, meta['BLOCK_N']) + grid_z = meta['SPLIT_K'] + return (grid_m * grid_n, grid_z) + matmul_dequantize_int4_s2[grid](qweight, x, out, scales, qzeros, M, N, K, x.stride(0), x.stride(1), qweight.stride(0), qweight.stride(1), scales.stride(0), scales.stride(1), qzeros.stride(0), qzeros.stride(1), out.stride(0), out.stride(1), group_size) + return out + +def quantize_int4(x: torch.Tensor, group_size: int=128, transpose: bool=True): + """ + Quantise weight matrix (row-vector row-major) into INT4. + + Returns tensors that are column-major (as expected by Triton kernel). + + Args + ---- + x : [K_orig, N] float + Returns + ------- + packed : [ K_orig//2 , N ] int32 column-major + scales : [ G , N ] float column-major, G = K_orig//group_size + zeros : [ G , N ] float column-major + """ + K_orig, N = x.shape + assert K_orig % group_size == 0 + G = K_orig // group_size + x = x.view(G, group_size, N) + x_min = x.min(dim=1, keepdim=True)[0] + x_max = x.max(dim=1, keepdim=True)[0] + scales = (x_max - x_min) / 15.0 + zeros = torch.round(-x_min / scales).clamp(0, 15) + q = torch.round(x / scales + zeros).clamp(0, 15).to(torch.uint8) + q = q.view(K_orig, N) + if transpose: + q = q.T.contiguous() + scales = scales.squeeze(1).T.contiguous() + zeros = zeros.squeeze(1).T.contiguous() + else: + scales = scales.squeeze(1).contiguous() + zeros = zeros.squeeze(1).contiguous() + packed = torch.zeros((N, K_orig // 2), dtype=torch.int32, device=x.device) + for k in range(0, K_orig, 2): + low = q[:, k].to(torch.int32) + high = q[:, k + 1].to(torch.int32) + packed[:, k // 2] = (high & 15) << 4 | low & 15 + return (packed.view(K_orig // 2, N).contiguous(), scales, zeros) + +def unpack_int4(packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int=128): + """ + De-quantize the output of quantize_int4 back to float for validation. + Assumes column-major layout (same as the kernel). + Returns tensor of shape [K, N] – float32, column-major + """ + packed = packed.contiguous() + scales = scales.contiguous() + zeros = zeros.contiguous() + K2, N = packed.shape + K = K2 * 2 + device = packed.device + unpacked = torch.zeros((K, N), dtype=torch.float32, device=device) + for col in range(N): + pack = packed[:, col].clone() + even = pack & 15 + odd = pack >> 4 & 15 + int_vec = torch.empty(K, dtype=torch.float32, device=device) + int_vec[0::2] = even.float() + int_vec[1::2] = odd.float() + group_idx = torch.arange(K, device=device) // group_size + s = scales[group_idx, col] + z = zeros[group_idx, col] + unpacked[:, col] = (int_vec - z) * s + return unpacked + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py new file mode 100644 index 0000000..8b39a49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + dx = torch.empty_like(x) + grid = (M,) + _l2_norm_bwd_kernel[grid](x, dy, dx, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_240873.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py new file mode 100644 index 0000000..a863581 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py @@ -0,0 +1,114 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + x_square = tl.where(mask, x * x, 0.0) + var = tl.sum(x_square, axis=0) + rstd = 1.0 / tl.sqrt(var + eps) + dy_x = tl.where(mask, dy * x, 0.0) + sum_dy_x = tl.sum(dy_x, axis=0) + term = 1.0 / (var + eps) * rstd * sum_dy_x + dx = dy * rstd - term * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(x.shape[0], dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + _l2_norm_bwd_kernel[M,](x, dy, dx, x.stride(0), N, eps, BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_bwd.py_gen_triton_code_419958.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py new file mode 100644 index 0000000..efc3d1a --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py @@ -0,0 +1,98 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(data * data) + rstd = tl.rsqrt(var + eps) + for base in range(0, N, BLOCK_N): + mask = cols < N - base + data = tl.load(X + base + cols, mask=mask, other=0.0).to(tl.float32) + y = data * rstd + tl.store(Y + base + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-05): + x = x if x.is_contiguous() else x.contiguous() + original_shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + element_size = x.element_size() + BLOCK_N = min(65536 // element_size, triton.next_power_of_2(N)) + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.view(original_shape) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_414289.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py new file mode 100644 index 0000000..83eb1a5 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py @@ -0,0 +1,99 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + var = tl.zeros([], dtype=tl.float32) + for off in range(0, N, BLOCK_N): + mask = off + cols < N + x = tl.load(X + off + cols, mask=mask, other=0.0).to(tl.float32) + var += tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + for off in range(0, N, BLOCK_N): + mask = off + cols < N + x = tl.load(X + off + cols, mask=mask, other=0.0) + y = x * rstd + tl.store(Y + off + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-12): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]).contiguous() + y = torch.empty_like(x) + M, N = x.shape + element_size = x.element_size() + max_block_n = 65536 // element_size + BLOCK_N = min(max_block_n, triton.next_power_of_2(N)) + if N > max_block_n: + raise RuntimeError('This L2 norm only supports feature dim < 64KB.') + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return y.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/l2_norm_triton1.py_gen_triton_code_485284.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py new file mode 100644 index 0000000..4750b49 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M: tl.constexpr, D_HEAD: tl.constexpr): + size_m_arange = tl.arange(0, SIZE_M) + d_head_arange = tl.arange(0, D_HEAD) + rows = size_m_arange[:, None] + cols = d_head_arange[None, :] + matrix_ptr = M + rows * matrix_stridex + cols * matrix_stridey + matrix_value = tl.load(matrix_ptr) + out_rows = d_head_arange[:, None] + out_cols = size_m_arange[None, :] + out_ptr = Out + out_rows * out_stridex + out_cols * out_stridey + tl.store(out_ptr, tl.trans(matrix_value)) + +def wrapper(SIZE_M: int, D_HEAD: int): + device = 'cuda' + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + grid = (1,) + kernel[grid](matrix, out, matrix.stride(1), matrix.stride(0), out.stride(1), out.stride(0), SIZE_M=SIZE_M, D_HEAD=D_HEAD) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout new file mode 100644 index 0000000..15547cc --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_179599.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_179599.py \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py new file mode 100644 index 0000000..b993c3d --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M: tl.constexpr, D_HEAD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + mask = mask_m[:, None] & mask_n[None, :] + src_ptrs = M + offs_m[:, None] * matrix_stridey + offs_n[None, :] * matrix_stridex + out_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + original_data = tl.load(src_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, tl.trans(original_data), mask=tl.trans(mask)) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + BLOCK_M = 64 + BLOCK_N = 64 + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_N)) + kernel[grid](matrix, out, matrix.stride(1), matrix.stride(0), out.stride(1), out.stride(0), SIZE_M=SIZE_M, D_HEAD=D_HEAD, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout new file mode 100644 index 0000000..09a97cf --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_transpose.py_gen_triton_code_602994.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_602994.py \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py new file mode 100644 index 0000000..0f93fd4 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py @@ -0,0 +1,72 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, stride_am: tl.constexpr, stride_ak: tl.constexpr, stride_bk: tl.constexpr, stride_cm: tl.constexpr, N: int, M: int, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + start_n = pid_n * BLOCK_N + offsets_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offsets_n < N + accum = tl.zeros((BLOCK_N,), dtype=tl.float32) + for start_m in range(0, M, BLOCK_M): + offsets_m = start_m + tl.arange(0, BLOCK_M) + mask_m = offsets_m < M + offs_a = A + (offsets_n[:, None] * stride_am + offsets_m[None, :] * stride_ak) + tile_a = tl.load(offs_a, mask=mask_n[:, None] & mask_m[None, :], other=0.0) + offs_b = B + offsets_m * stride_bk + seg_b = tl.load(offs_b, mask=mask_m, other=0.0) + accum += tl.sum(tile_a * seg_b[None, :], axis=1) + offs_c = C + offsets_n * stride_cm + tl.store(offs_c, accum, mask=mask_n) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dtype == B.dtype + assert A.size(1) == B.size(0) + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + BLOCK_N = 128 + BLOCK_M = 64 + grid = (triton.cdiv(N, BLOCK_N),) + mv_kernel[grid](A, B, C, stride_am=A.stride(0), stride_ak=A.stride(1), stride_bk=B.stride(0), stride_cm=C.stride(0), N=N, M=M, BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_432557.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py new file mode 100644 index 0000000..74e1d21 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py @@ -0,0 +1,71 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + acc = tl.zeros((BLOCK_M,), dtype=tl.float32) + for start_n in range(0, N, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_a = (offs_m[:, None] < M) & (offs_n[None, :] < N) + mask_b = offs_n < N + a_blk = tl.load(A + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an, mask=mask_a, other=0.0) + b_vec = tl.load(B + offs_n * stride_b, mask=mask_b, other=0.0) + product = a_blk * b_vec[None, :] + acc += tl.sum(product, axis=1) + offs_out = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_out = offs_out < M + tl.store(C + offs_out * stride_c, acc.to(C.dtype.element_ty), mask=mask_out) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1 + M, N = A.shape + assert B.size(0) == N + C = torch.empty(M, dtype=A.dtype, device=A.device) + BLOCK_M = 64 + BLOCK_N = 64 + grid = (triton.cdiv(M, BLOCK_M),) + mv_kernel[grid](A, B, C, M, N, A.stride(0), A.stride(1), B.stride(0), C.stride(0), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/matrix_vector_multip.py_gen_triton_code_450542.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py new file mode 100644 index 0000000..6de6c16 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_571970.py @@ -0,0 +1,228 @@ +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + x_batch_stride = stride_x_batch + ox_batch_stride = stride_out_batch + offset_b = pid_batch + actual_seqlen = seqlen + else: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + actual_seqlen = seq_end - seq_start + x_batch_stride = stride_x_seqlen + ox_batch_stride = stride_out_seqlen + offset_b = seq_start + if pid_m * BLOCK_M >= actual_seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < actual_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + off = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + off = SEQLEN_OFFSETS + rm_cos = rm + off + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + X_ptr = X + offset_b * x_batch_stride + pid_head * stride_x_nheads + OUT_ptr = OUT + offset_b * ox_batch_stride + pid_head * stride_out_nheads + if not INTERLEAVED: + x0_ptr = X_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim + cos_ptr = COS + rm_cos[:, None] * rotary_dim_half + rk_half[None, :] + sin_ptr = SIN + rm_cos[:, None] * rotary_dim_half + rk_half[None, :] + cos = tl.load(cos_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & mask_k_half[None, :], other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, o0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, o1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + BLOCK_P = BLOCK_K + rk = tl.arange(0, BLOCK_P) + mask_k = rk < rotary_dim + cos_sin_idx = rk // 2 + cos_sin_mask = cos_sin_idx < rotary_dim_half + cos_ptr = COS + rm_cos[:, None] * rotary_dim_half + cos_sin_idx[None, :] + sin_ptr = SIN + rm_cos[:, None] * rotary_dim_half + cos_sin_idx[None, :] + cos_val = tl.load(cos_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & cos_sin_mask[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr, mask=mask_m[:, None] & (rm_cos[:, None] < seqlen_ro) & cos_sin_mask[None, :], other=0.0).to(tl.float32) + x_even_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (2 * (rk // 2))[None, :] * stride_x_headdim + x_odd_ptr = X_ptr + rm[:, None] * stride_x_seqlen + (2 * (rk // 2) + 1)[None, :] * stride_x_headdim + x_even = tl.load(x_even_ptr, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x_odd = tl.load(x_odd_ptr, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + is_even = rk % 2 == 0 + rot_even = x_even * cos_val - x_odd * sin_val + rot_odd = x_even * sin_val + x_odd * cos_val + final = tl.where(is_even[None, :], rot_even, rot_odd) + tl.store(OUT_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, final, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'max_seqlen must be provided when cu_seqlens is given.' + assert x.dim() == 3, 'x must be 3-D for variable-length case (total_seqlen, nheads, headdim)' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype, f'Mismatched dtypes cos={cos.dtype}, sin={sin.dtype}' + assert x.dtype == cos.dtype, f'Mismatched dtypes x={x.dtype}, cos={cos.dtype}' + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in {torch.int32, torch.int64} + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + BLOCK_K = max(32, min(256, triton.next_power_of_2(rotary_dim_half))) + BLOCK_M = 4 if interleaved else 8 if rotary_dim <= 64 else 4 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + if not is_varlen: + stride_x_b = x.stride(0) + stride_x_s = x.stride(1) + stride_x_n = x.stride(2) + stride_x_h = x.stride(3) + stride_o_b = output.stride(0) + stride_o_s = output.stride(1) + stride_o_n = output.stride(2) + stride_o_h = output.stride(3) + else: + stride_x_b = 0 + stride_x_s = x.stride(0) + stride_x_n = x.stride(1) + stride_x_h = x.stride(2) + stride_o_b = 0 + stride_o_s = output.stride(0) + stride_o_n = output.stride(1) + stride_o_h = output.stride(2) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, 0, stride_o_b, stride_o_s, stride_o_n, stride_o_h, stride_x_b, stride_x_s, stride_x_n, stride_x_h, BLOCK_K=BLOCK_K, IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), IS_VARLEN=is_varlen, INTERLEAVED=interleaved, CONJUGATE=conjugate, BLOCK_M=BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py new file mode 100644 index 0000000..7212c74 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py @@ -0,0 +1,207 @@ +import torch +import triton +import triton.language as tl +import logging +from typing import Optional, Union + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + end_idx = tl.load(CU_SEQLENS + pid_batch + 1) + cur_seqlen = end_idx - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < cur_seqlen + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + mask_m_cs = rm_cs < seqlen_ro + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k_half = rk_half < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_offs = x0_offs + rotary_dim_half * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + out0_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_offs = out0_offs + rotary_dim_half * stride_out_headdim + tl.store(out0_offs, y0, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(out1_offs, y1, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + rk_repeat = rk // 2 + cs_mask = rk_repeat[None, :] < rotary_dim_half + cos_offs = rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_offs = cos_offs + cos = tl.load(COS + cos_offs, mask=mask_m_cs[:, None] & cs_mask, other=1.0).to(tl.float32) + sin = tl.load(SIN + sin_offs, mask=mask_m_cs[:, None] & cs_mask, other=0.0).to(tl.float32) + rk_swap = rk + (rk + 1) % 2 * 2 - 1 + x0_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_offs = x_ptr + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + x0 = tl.load(x0_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_offs, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_offs = out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_offs, out, mask=mask_m[:, None] & mask_k[None, :]) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'If cu_seqlens is passed in, then max_seqlen must be passed' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + assert rotary_dim <= headdim + rotary_dim = rotary_dim * 2 + assert headdim <= 256, 'Only support headdim <= 256' + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + assert cos.dtype == sin.dtype + assert x.dtype == cos.dtype + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + if interleaved: + BLOCK_M = 4 + else: + BLOCK_M = 8 if rotary_dim <= 64 else 4 + BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/rotary_transform.py_gen_triton_code_82100.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py new file mode 100644 index 0000000..786f954 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py @@ -0,0 +1,87 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_186414.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py new file mode 100644 index 0000000..7836ee2 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor, output: torch.Tensor=None, BLOCK_SIZE: int=128): + n_elements = x.numel() + assert x.is_contiguous() + if output is None: + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/sin_kernel.py_gen_triton_code_412027.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py new file mode 100644 index 0000000..a59f8d1 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl + +@triton.autotune(configs=[triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)], key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K) + b_mask = (offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2, 'Input a must be 2-D' + assert b.dim() == 2, 'Input b must be 2-D' + assert a.shape[1] == b.shape[0], 'Incompatible dimensions' + M, K = a.shape + K2, N = b.shape + assert K == K2, 'Inner dimensions must match' + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1)) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_700160.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py new file mode 100644 index 0000000..e999a3f --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py @@ -0,0 +1,124 @@ +import torch +import triton +import triton.language as tl + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == 'cuda' + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == 'hip' + +def get_cuda_autotune_config(): + return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2)] + +def get_hip_autotune_config(): + return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4)] + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + +@triton.autotune(configs=get_autotune_config(), key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr=''): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % num_pid_in_group % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + if ACTIVATION == 'leaky_relu': + accumulator = leaky_relu(accumulator) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + +def matmul(a: torch.Tensor, b: torch.Tensor, activation: str=''): + assert a.dim() == 2, f'Expected 2-D matrix for a, got shape {a.shape}' + assert b.dim() == 2, f'Expected 2-D matrix for b, got shape {b.shape}' + assert a.shape[1] == b.shape[0], f'Incompatible dimensions: a={a.shape}, b={b.shape}' + if a.dtype not in (torch.float16, torch.bfloat16): + raise ValueError('Only float16/bfloat inputs supported.') + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), ACTIVATION=activation) + return c +if __name__ == '__main__': + torch.manual_seed(0) + device = 'cuda' + dtype = torch.float16 + a = torch.randn((512, 512), device=device, dtype=dtype) + b = torch.randn((512, 512), device=device, dtype=dtype) + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f'triton_output={triton_output}') + print(f'torch_output={torch_output}') + if torch.allclose(triton_output, torch_output, atol=0.01, rtol=0): + print('✅ Triton and Torch match') + else: + print('❌ Triton and Torch differ') + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stderr b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout new file mode 100644 index 0000000..10f2a81 --- /dev/null +++ b/reflexion_oneshot_tritonbench_4/tmp/tmp/gen/triton_matmul.py_gen_triton_code_890291.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None \ No newline at end of file diff --git a/src/agents/reflexion_oneshot.py b/src/agents/reflexion_oneshot.py index d10a345..6618305 100644 --- a/src/agents/reflexion_oneshot.py +++ b/src/agents/reflexion_oneshot.py @@ -40,7 +40,9 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", "reflection", "function_signatures", "oneshot", - "pass_call", + "pass_call", + "pass_exe", + "code" ]): pass @@ -63,6 +65,8 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", function_signatures=fs_mem, oneshot=os_mem["code"], pass_call=False, + pass_exe=False, + code=os_mem["code"] ) else: input_mem = input_mems[ps.filename] @@ -72,6 +76,8 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", function_signatures=fs_mem, oneshot=input_mem["oneshot"], pass_call=input_mem["pass_call"], + pass_exe=input_mem["pass_exe"], + code=os_mem["code"] ) self.memories.append(tmp_mem) @@ -103,26 +109,36 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, """ Run the scripts to verify whether the generated kernels can execute without errors. To check for correctness against expected outputs, use the test_opt_correctness method from TritonBench: + """ + if output_path is not None: + root, extension = os.path.splitext(output_path) + tmp_dir = f"{root}_tmp_{iter}" + exe_dir = f"{root}_pass_exe_{iter}" + perf_result_dir = f"{root}_perf_results_{iter}" + perf_log_dir = f"{root}_perf_logs_{iter}" - if self.config.agent.output_path is not None: - root, extension = os.path.splitext(self.config.agent.output_path) - tmp_dir = f"{root}_tmp_{n}" - exe_dir = f"{root}_pass_exe_{n}" - perf_result_dir = f"{root}_perf_results_{n}" - perf_log_dir = f"{root}_perf_logs_{n}" + else: + tmp_dir = f"tmp_{iter}" + exe_dir = f"pass_exe_{iter}" + perf_result_dir = f"perf_results_{iter}" + perf_log_dir = f"perf_logs_{iter}" + #for fn, mems in tqdm(current_memories.items()): + for mem in tqdm(self.memories[:data_len]): + if mem.pass_exe: + continue + try: + pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.code, mem.ps.filename, tmp_dir, exe_dir=exe_dir) + except Exception as e: + logger.info(f"failed to test the code due to : {e}") + mem.err_msg = f"failed to test the code due to: {e}" + continue + if not pass_call: + mem.err_msg = call_stderr + elif not pass_exe: + mem.err_msg = exe_stderr else: - tmp_dir = f"tmp_{n}" - exe_dir = f"pass_exe_{n}" - perf_result_dir = f"perf_results_{n}" - perf_log_dir = f"perf_logs_{n}" - - for fn, mems in tqdm(current_memories.items()): - mem = mems[n] - try: - pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.code, mem.ps.filename, tmp_dir, exe_dir=exe_dir) - - """ + mem.pass_exe = True logger.info(f"\nrun scripts on gpu") for mem in tqdm(self.memories[:data_len]): if mem.pass_call: @@ -132,9 +148,11 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, mem.err_msg = err_msg """ To measure kernel latency, follow these steps: - - self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=script_dir) - self.dataset.run_perf_scripts(gpu_id=gpu_id, script_dir=script_dir, log_dir=perf_log_dir) + """ + #self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=script_dir) + #self.dataset.run_perf_scripts(gpu_id=gpu_id, script_dir=script_dir, log_dir=perf_log_dir) + self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=tmp_dir) + self.dataset.run_perf_scripts(script_dir=tmp_dir, log_dir=perf_log_dir) for mem in self.memories[:data_len]: path_gen = os.path.join(perf_result_dir, mem.ps.filename[:-3] + ".json") @@ -150,8 +168,6 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, except Exception as e: print(f"{mem.ps.filename} failed due to {e}") - """ - # generate reflections logger.info(f"\ngenerate reflections") with tqdm(total=data_len) as pbar: @@ -220,4 +236,4 @@ def generate_reflexion(self, mem, temperature): "content": reflect_txt } ] - mem.reflection = self.model.generate(reflect_msg, temperature=temperature) \ No newline at end of file + mem.reflection = self.model.generate(reflect_msg, temperature=temperature) diff --git a/src/agents/reflexion_oneshot_ROCm.py b/src/agents/reflexion_oneshot_ROCm.py index 5e5b369..35b7cad 100644 --- a/src/agents/reflexion_oneshot_ROCm.py +++ b/src/agents/reflexion_oneshot_ROCm.py @@ -90,7 +90,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, logger.info(f"\ngenerate solution") with tqdm(total=data_len) as pbar: if multi_thread: - + with ThreadPoolExecutor(max_workers=thread_num) as executor: futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in self.memories[:data_len]} for future in as_completed(futures): @@ -114,7 +114,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, tmp_dir = f"{root}_tmp" exe_dir = f"{root}_pass_exe" perf_result_dir = f"{root}_perf_results" - + breakpoint() for mem in tqdm(self.memories[:data_len]): if mem.pass_exe: continue @@ -133,7 +133,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, mem.pass_exe = True """ To measure kernel speedup, follow these steps: - + """ root, extension = os.path.splitext(output_path) exe_dir = os.path.join(root,exe_dir) if output_path else exe_dir perf_result_dir = os.path.join(root, perf_result_dir) if output_path else perf_result_dir @@ -157,9 +157,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, assert len(passed_mems) == len(perf_results_list) for mem, perf_data in zip(passed_mems, perf_results_list): speedup = perf_data.get("ms") - efficiency = perf_data.get("efficiency") - """ - + efficiency = perf_data.get("efficiency") # generate reflections logger.info(f"\ngenerate reflections") From 9c9bb328e483e616629404dc0e9ca70455732ff8 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sun, 24 Aug 2025 05:21:31 +0000 Subject: [PATCH 2/6] Remove __pycache__ --- src/agents/__pycache__/Base.cpython-312.pyc | Bin 4986 -> 0 bytes .../__pycache__/Reflexion.cpython-312.pyc | Bin 2883 -> 0 bytes .../reflexion_oneshot.cpython-312.pyc | Bin 9006 -> 0 bytes .../reflexion_oneshot_ROCm.cpython-312.pyc | Bin 10394 -> 0 bytes 4 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/agents/__pycache__/Base.cpython-312.pyc delete mode 100644 src/agents/__pycache__/Reflexion.cpython-312.pyc delete mode 100644 src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc delete mode 100644 src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc diff --git a/src/agents/__pycache__/Base.cpython-312.pyc b/src/agents/__pycache__/Base.cpython-312.pyc deleted file mode 100644 index f0a272db34755c9ad5ac90fc0aa6f7c474ac611d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4986 zcmbtYTWlLu8a^|g@%@rG32~Z`xJi@7y(G|-wp_a1Y@8NI39M3~#nAP-_QbX08)wGk z!k88*4@C+h)q;{NRQ0~>x@sl*w)@y}k&sqe*`y_5NSBqC7v56Bs-jQ(|7UC`HGy4e zf6M3m=RarubNesf%-jHHg;mHXSvCNW)^(wEgB2dp{Hjvq+`m2Vr)IH7l9N%JN)W>Op@@garwK0*7MkVUw zHA^I|Qws1APBM)15q=iV zCt{c-0Wo5U2$B(G;;a}jMXV9iEEh0G%+P0OUI_PlS}S>D2xZZlzzQ;eN)+ z{dj_XnQ6@&WU7O~88?d9misC)7b^Bbah|L;?YJP!Sy!9ep|Y=f+EzThB~S0NXU_%U zZN~rz#p<#qSJEC!5;D|^Z>RO7IjccpiPu}6ee3XRw(+)y%oKihLmAhJ=;UYih z5@1A*8-tQ#+$5oOpvPxm$Duhhy2%7(v>BP&u#cyWP-%QdmT_jUmo`B~-B8RlZEkhxI4-{|$zWpDzU1?kK`9JshmtDgppFe-#UF+h3_nbwazvLP&numY*+%+r00%U9wRxk%2OKQa~7pKpK7^gs5}WHtf1RmS;%G`A*yKy0ZXuj z`50P)7c9Z1TH|<&H3S+UaMqG)Q;n4%rP_Ya&vgnUXV2LbG-*bvy|NBJIo$yI9jZBw zo>W5-p)!smb4>!VV6cT>B`|U%_4v>4;`DB|dJ+eFt|eHfK{Z|zD{JM*|F8<5@S`=T z+?%c2-%}j+R!f_n1oMe+y5UK5gc?%C&7Dt!v6el}Hc7h3qeSK=!Sh}s9Q@{%!?_GrD!}9 zccB^O?3rXtpuL%~!wshOp6{wllfc$PV|Y>>I_`-Y7-SZ4lBi$OBXSU3hWo^0}pR z1!bjwsMJ4nW9+lN%N;KkfAeywJEuQ2&V=ix;oME1+JC0l;*lJl;C zbRD@Tl81iD6`PK&TS=Su^3kQEg_c6^a_hkS&}yS+rLniv*js2QL~lG|G^$SAOJ{{gHp#8j7|KkgwP~ zO7@Pt`L^A+ZXkB=58t)mChi%^)8D^m0Yyc}&inFXa8iG&_jefn;`TR)H!a3wwVDW`S%8cU*+4I5&hS$mZLH zKL{EJ!L1Ejf?qYTEXITDV zHmoPi{h(>avuVB$I$S^W$|F#mCt$bpzQwUyV7b3uZEju2y_=githIFH4ez(k8&~V< z7fcsT9|)JNOV<1$z!HU1CC{U`>kix%iKiRx-r2e0*jd6$#~A{HV9N3#zt|^#y&wD z(AD6BVfa_6Vu~A>__sh+1He=+U;tPxXbof1t69Kpu~4FV)(_@x%!kZ~{{i~`&)i*& z_I@(Yo$dwWHgD$bZL>6|n%{?)kV{p==ZI1@?Cr90yqbNh=K$cjV(72H9I1QyV0 zK;5eKeR4^FtX<8^s_9zQO~5(r3}J{lhk_Q)NOfE@Ry4v0C%7*jaSc@whHDGLaH-BX zUV06I22YI$nyPCH!f=%lMqSROAc$k6)>YO~Mi}*~J&u-DLlIMD+|MG6pVpdi1FG83jRcES>Vs2wo6qT#$Cc+;kTf-B&xv=zV-YC|GaOtwPUd-&*gXJA1YY#uNOT77oJ(GZ(4l% zc71oD@ymve)t#O5&wkUg_pjF5Ekie^RwvRc6H}#$DF7Wa+z=1d>rmY#B*FRl4MMtx zD~)0#nm_XAV}Cq${ou;LiPFG{&l*0PSndiGUkR4Fg76?dv1kGYo>}WTaQ*NXJx?wk zUj>LVvNTfQE{y_u+5HrtmmPkt*fO+kC*EC`b4$6xONFz`9fubU>=kxj$+NF8c768^ zSp2?Qp0O?H2hfZ6$h{`w?pbdqb*-PfdRFTi7wi}9D|Nd|b-N3eFCb0%NR#h}@7%Zu z$X5DPh*aD?#rp31GTzzE$Loi?#E-1*;bzlEZQa9e>&ILd@E>;%^3d_ATZG1^&El}v z_Gz05C0*k1K8v2$SX$F_n?Q#^8}(x`j0MY9HHqp3c3^?VBO?TvCExSSl`Mokz6N>{ zD*6j7Fn6%8HCdX~{lZrLU;{7B_}Zw$26s4}3TG*WY}8lAwmL3nDBQ+Ak;r|RrC55^ zJg6p-4n+Xn_ig=pLoq>Ut_nEsdyakH(33Drvuq0wZc8&irKj>=&;UG*FAMBw*9{!U zeMKx^5!)TobceKjO?+RIV|U2TJLJ(jr01Twha0#{p!og;&ch8X`uYxNboRr$>kqOaU$}N48$bvr;Gk9wEuU(O3Q0TlV`F z#qUi6?D|_^OmG898(YLo=2IB3-U7*I3ywLJCnmvDHcmR}d=41;#|*m;$a$R2lM;}) zn)9&FPI{)jdBYF1ehBc}2&06MLJ2V?U};KBi3@PLl(c{nzLc7hu>v&Tf-j+#l$hrF zTUK0OmTPWmW4cB77`{FZ(`{51%E$!B8)XqNp@pGLGuBNKtPeI|kZkEKq5HyQOeR6# zxoXOCAWV)M%<@y3>0H`zHIAe)jY+Odu|36wNeAqpOpjq0|GYUene%=0yTOg75P@6v1<=3p+>mB z9zIwLA6yL&Rpc*stb+_nwoVR+0w0gl6aWw>l!3KCIB*kfz6hGU6(7fQZ zVptX)Y*_eOyrs(^F)v(%S>VttA|Q`q5|4@70ZDYl##xg(mV+sQzW{QP?1aS?3zRa; zhC#YuAv#QRy_?7168BDE(P5Y%0`WJ0vhUmV_}Jt4*s4BWi;vwCztQiCzjp4c#>Z;< zcvTz!`>{SwWsqO$ip#D9nv$(s1}ZxXkIKJbpJmVMMuiBaS4lH#ZDU?~g=TguZG-fe zrIhgc^53H$#5r~|%!byc3{cVq`o!&dFYVgD}>FVnjXZ_8dOC3_yEd|*qtBTftR zGGtijWo-=0u)A#=Wx1J0^FD_q1RrHzle&E@Hvd+z#L}j1WO~~OfNrm;Y3EbOfq7+K zEh}`ktd<3{b(2nCQ&%S3gf`bqbQ5ku+gr-CG&bd4xCqx@uqZW?mTkMibdEZt$XzFU7*oZBm{~gQ21u*-d5uAX zLXHVkGr|Jtl{ph}ogyoW}Uc=-{o?=&a@O1}D z6qwS$zd}9-L^VxK<{jC~WP=+nm<~v`5ZFQ44+<@3WAHNSZuN!fn0sThq;7GOvASV9 zy#&=9%4)D3cM=GEjYItlBeP3yoLKFsb7=k8_FR4?lP`u`==SWVL6k78>)Ir*59&%%cybM;$92 z)uU&tJ?CnnbHLR?)y~T`?K0Tvjoy=Mp~GvEs0S+2nu2=5w~u{&tloW~)_veEUF{yO zXip;h?|!lxiC3gW<%V)g2dREz`NYzRdh}>5%JeHM%4+m4r@6^P**Y9Kwn?VK>4XA@7uA1kn44Yw~&|yKurLKYaX6 zuz=QN$v6C6dL6;?XF%{x2+x&apMLlB3j~X|nQfVjh1{U`_=vM-A7L$!u-fDX;V-di zoXyo_m@a1&5SbcN2bEu&ZVIRY%C zl608XVd!9brvo&};Iy5_OqjXOubBS0aC6%U4S#g5ok%$sdg(MjntvuvCR3(Ay4E>5 zl4C26PCIkE8MgLw?X}n0d#!J;<@;4tb^f6KLv=Z{~SkaTYSeXAuC|o5x;h882 z;IARd#FzlXVA&Wo#Eb!>T4tiAfC<`6QFF`^u&8b3s5Qn0ShZ}4+G6&AJ?02FV$Oh5 zZL>zJV%33awai9cF?YbNmTl3RSZ$zIE!(4>SY4n_EjwU*eV|?~JEIM;#z14t8}Ks3 zK=u>9>H^`b1@?WT#(SV?l<~Q~#L5vLqnM=E_!#_JqKTm)fnq5n@{zCvCEN4DU{p98 zNyL3drA0U^(NI_l4yNK^3EF~UWGEhzQdAHXdpIhDXfT}M1*Kuhz;GxU6@gxTk|ttF zppc+JilkLK(BT>oVhK9BCmIsP0YM7k+Pg!dFaY$TP^sEI(EHT$f##PpGTxt|MBffLlt3p^(!IFLG@3JV;3s8m!ExdiP7CV;dk zaEU=~FcgWxN9^VVN)wce38EMp5{k?8`MHCGMQX6PI(Px2p?D}dng%{`W7>Q?B!apx zF$E_|gat#~VC1aE_C}3uU?~ji082?BggZYIk%kvq3ZK%JIHH(icy7Y_WyNoeUmE|<6Zj@qNruUgAhxS1V}Kg8YiWo; z&7jvLt6qA)Sb$!W-j*RFtX?kvjyTF?y_7NDV%{&V!H_al-f+vPw*zHMX*?Nm>gDq9 zk5dMYT;=h4DPz*tV8|%(negVs)~ zDN$ntmkJI>qJoHD5$_Q9WJ}MP1U)4t;e7QB!{-Ku6Y+I$R^w8SNW(p9N$gQI3o5N( zLc53leu_4U@36~ve!wWzIn`a$zEt&0YA=y7i(2ttnU50;FxKa zZ7Z@(<=Dmbp7IjRE%(mp&cA0ITI+tVHXW2cE zaH>~kz1O^-vTgIQ8VQ$98&uHvs3-M5p!KSz!WmHE1__6%X$VZr-!ivM?-$Qu){tS| zxRPN;aNeTCoAoa*It~%ul40H=y!9QUewzP`kxKU@!|?3;MP8mFS!4N>RyyUG(z(r| zfrEegTsYsQ8*sjK^2%*+Mk}0cozGKi;VX0kueP5$19oIzFt)U!x>Iv`rA0|Bf`4%X$fD;*D68 zbeg;a&uP|bAfUI*DEEvdW96OL%2$ClDe90c8~2?z0j=gmgrKTYuxS)BHoh9`@as~4 zEl_e7TWaukAx+rDvZeG!RxWMJusX*Kxio14z0*!M5Ro}!V90S$!weZ`F4dnk9w%oQ zU+t0bLJGPvp2u@iHq3xr3{TcAy}%0;i+@`*0eC~);U7>j#SwrmaQWk*m=Fvq_Fxb+ zTPliWXE69$DiqavZW7uKOVL)y6pJ>Mc0dLHi)aHbkhyA7Rih2-f(D;OF(gIB3T7`D z6NePLs(4hpu6QbeBx_M0DYhiwoFJkiKrFTx0G==+AH^{k5u&`xtC%+XN5n)ty-YZp z1f0Z$IIMypje?~#*p@ zY&~WhAoIB?-dKed;=G8qSTQA%LR>N7qKYY+fEA2UA+BlaYS8c}qoANg36P1Xn4_VS zLX@JEC>}P&{8~z&ql#mJflydv9feKtNCC=B!UDX{PFLVG4l|Uh!VD^Q%_6~96=^rN zSq8-*)D%lX`?X53q}pIBk3_wx+6gsNKO|iaZRzDdcwSb+5b+$Gx7W$n2CJ>+?ekv| z%k~}L5Gc>p5@%h`v9jQ4$$M7Go>dbs&3e}5oIlTP7`QilFOYjFnCtoZ4AedA9@Ny2 z4^J3pYgSD%50ccqUrzs}6;Ihx!3v)s>yW`|GZqNitu&*X;A!k|%RmjQmZ>^6Y4fEhW@ zrbo36SG%ut=WEx^)~*}dU2xT0K6~-(gQksB&wqIAgJUyIJFZwBH7y^v6kLrt*UCcE z@_bXD+|-wEdPZ(~X1eW;|F(ao>EOLqx#{q@<@084Vs~L}ccH7NuyIRat$*_L&FVsD z&({u1OWPIeoSQ6PGeO@w_wG4)S>L#=IDEsDPhPvN(6w>ugxs}Tomgz@?3?PBJO4;`MNgt?#{ z`_QhFx{3g_HnPxu!ca`Q?#Wsg*te1`+)Ks{ulxmQQTB&z9M5D} zup$Pqm&NrMz>aTpN$L|3*QsR5Dk88F`a1#KW^Gw}#s-v)NIzqLi@amXIx+?d?~SbU zM|djj0_4Z)x}dkZvMmBjVr>T7ci9y-L;#h%!U3nQ zOm#&YMJm~-j4NZ&ADS80FASI3OeE{hx}+ABN#xV5^8q#%tZ23-Q#FG3bWs8%Qj++k zdK0m!2xB3}h;RPkUAzLS8@MMUsBab}? zP^=Mr*NdXMClQ}PWvEszY3SPs!bsOI1#QrL1SX3A3ik7L^7)F@S2o@1o0NW;xtW>i z+BU{MsB5~mE?>8HZ0`eS?dASA`*Y2mlN%>Hrc62SQ?t&e3$EsztK&i2s`u8syC&b( zE4TGdZJlY`dBrtu8jloKwqF~Z*fqgVyf}Gqa%*ncrg7W9xmpU#SB%@gwvZ~1MseM& zvj?EwwF95lte&j<+_UnFrmkuGr%ikAo-M>)%g54kEDcyK!#sTNZn|bHc$ek9e%b4v z6z@86UjMB3Ff`q8Tz5?DdZ&8MOt{V8){xqcIj^zG{WV+Kpnul6{y}5QwY@h6t`AK5 zCO7AnZTzh9sc~b$({OeBUu`ecEi1IId2idh+w$!@<@TL-jUQWnZJBA`FT1&M>$OIu zaYdocH@@eu`{x|Q`}CZ%4#@jUSmC-DVhczuw%x!g0UZ3M22gn)^F#VN_pwLQSVa?lHJ(=qyE_*y2Q)Q6fRVmPT;S zqBOD)#uxj5RwajN2~L+u@XRfP4$!j3@0nM@jU;+9Mu62!CASX1Y(pA+msw-{IG9Oy zc{RN>0+4iZ`sr)yP>44n=H$(oXH6Ogg}NP;bHEKfH5BqkqRH7;Q5J; z`V~`36zEL@8s#Y1@uK6N9tJ7+Tq=a1=)_bk;Fy73fxjd{VdV_NZZ9}%!7jO1=UrX0t1Iv7m0i6BZ_5qK zbxYpcDSJB$TpQRici%iQy6s~ouz_yR<HK5Y1);lq_5tek1wb+=V++8YEG2>;+o0dziM{Kp~mi&$qN3XuB z#2Y}Ue<&qh`C3= zjPWY;W>OIr#}|VrLC^|Ar8EXD6pgWV)g0PYLkjdY%qV8)q13bz63{Lb-X$tw=7-E@ zUVM!lDwv00x&*hWRDrux4tHty=Ovh)UiAa_Mf?vT%@eE`OFVKUAu6#qU+|ZkE zct&n`rcl>-b>EeJ`MM6duA@-zz53#n7q14d1o2t{GMuxL#x3*2)My`bfY+{RWuc|x zvDsKvJ!dqlZBEj>;zslJ=6qAP+|)h!^4(LprtX=h=f{kHshvO{bHryka&+GjP3s~& z@G)u`9KHDRr4Za96yKtXH5lX*;b4%W(W7mcp<$p1HYn;1ikqjMn6+cpiWxd5)C@ge zBFg>(d2BG7*~is(bNjW(Jb|Bi?;3OG{PWBK#@si5%)psl^9Pt}^9zhd2x<75>2pX5 zkC*z#4=I{bpLlc=wm9%Pq+N=Q`l`<&|?Q3gCS^f#hsAN85xH89kD$m^$$tY?}+0eS@Do`{GPldlb0S6 h-$T;%khDG|eklK*?EIQNXJB4nuGzjJYC)Cde*=fm6HNdB diff --git a/src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot_ROCm.cpython-312.pyc deleted file mode 100644 index 72bcb61fc1f2eb6a7e57597873ef79ea6a02bfed..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10394 zcmd5iZEzdMb$9qYe2^dskOWCdJbs82_yJ0!Y|1h%RU-8*(u%Cu7EKv}a7PjpzI=C( zNC=3WRGm;NJ0hF9q$c)=X)_HyRwvA#hHhtSHAyE*GIjxzHsFgiQK$Zq{G&`~oH+f_ z_x5lAO~STjnwfSn*nRu;_U-O_``*62UltXaDF~4lUmfjfrl|kIj2iS=VC|m)n58(1 zqXX0+rkVg9qz7plVQoMY)DCJ1Ob2v>Iw;cx^g+X*ft2Y3#vn7u5ZDkf1e! zv<_NHnK4ilEFLT-FcYu^O9o2_YzmYH%LdB`Y!29ijzI^3EzrJvu$;iwKt-@}urlZz zbkdZDIzn+pS1GQTXWq~%v=3HI&@S6=F*D$zWu5dA7lhOh2#<{L0zzJq^ZO)#OsDzb z06*>zhg@2@h98#%uTSy}M?*dd$~>ZfB;=K%0x!yDUx4=t9$%Q_<%&%WecnJo1bpQY zAsmbVhOpof5H-OA6}Emp7#1da171<==Or(Web6iN{eUk3nWBUJosXU7CBe^MBs8wa zG{T2?!HX2RG;-OQQGxezr^4aDp>f_9mBNB-^@>1QFcRP;p7SZIpv2iFq1JjJKTGk{ zAk9&O8jc>+awbm8>xMO)?y7!J4>|o+deFd`IRm6d&cYcXWjN+4HE5d9xU6z%Zm~R1 zhj?)`EO`Z*qRcFWU2gV>{~{k^rBR;sheSz;lBtWVH^i~Na7cnR6Wd7j$&>wTc1pa+ z`k;ZI^GZA`h67QYr06CnJKgLVK&H0j^7X7&;8~tbp66I8%)+X3Q6JC3MnwaX$cBYB zAOcW}JR2TnhrRv)Y+@VB3qn|6gS;quNBHb8U2gW+a26X3P8zQx)En{!CSpJ*PE6^K zOv00SiBaGv5eD?K!~SuF?A;34K#~u}0g@uT7iWISFO6=TzCSd=`h(E)BFrNkP=6UYC*LOh~VnEghk#Bseu;O0hPUC-#f<2j~;&hzyU^xkK4dkF) z2_qEbG-^&Vs?Zx*0BUt=S)3YUR5+iGS@N(7#kIHTH?m{UR30m*xvf>p0ka|3o*J{N za6bKV%s`PX-(H2{I&}=1nxI@dl>9QncOl8+4@IJqC*qYxWgSW|)ah^qU5sD{1hNtM z57u0gnIMX(pBIH9EH7Wj$so^J>6DGx&6SH0*}4h~={^DD`6ZWL)`@&zn2aDe5uL>s z79vqmAq0o=4EqDTh`Wd~#D2QJ<5F0-AVz?{I!0l;f#7hc9e6b)b%=tmgFs>jIa%<~ z@`M#R^qhXWH|UN`$c%@K?eWB_Hwmj-t*e6Z#4(6orZVN#le#H$rg9r3){GN#MIUbO z{Cm%&ZiT7OFxE6vpJM9MOiPMsS-Oy9T2`3j_qMezR;+B>gKcU^o0=3;lV%!HOvB>k zB-60M^gh6@&J^REbH2yaufb@PFUN9_K@CLM^*2yDtDNIBoZ}h^o02dBdgkxw+qyTh z%$U%`>2J)$=`rlLA#i$i`$fxfiZjIN?@*laHLc3h-=w9Su*7MOc_T~9V^l($XVpTi z9?vmv0<|3c)V{#vIWf4BW50YEFlqq?SLwX419qVWXf-`*515g8UEADULg5zbToQ(mhNBX&>M!qw(H&~9d^%Q{hgAr0 zqF8JQDo)OVe3~$7C^)?IxbAU7+{jt6lq-TGC3}!0n2_s=4sca(Km=KW!K7e_o48`k zLuwNhzl5al$ONY{oZMySXgRs1(!G)PfEUp9X3dCDNBPG<6ovVw#$!r^?5* zXQ@lHt8Box9)o%!uOXl0wKJG6norv|3-G$r#oZGKgV7=GbN3U&VgT$H5dM%i$a_4p z+2euZEgC@B>hZi3^#+uh*D0X^h7#%^k_}2%p$Rhh7f}&hrB;inqEe+~FsyKFCHDH*eelFUTF*ztg#2CKys1f$1?){kr%_VWRbkSfOX?lCbO zif!S?BVa4BUKX>UOaq{;&qpP8BrLHphFBC}^vJXjGiIB{%2?5hCJ*RVwWWYkEmS~s zSZGJkjnPhwc0uGS5l~=dh7WNf>SI|Kj_@H_hl9$xKo~~Q2KbP2S{K79e?9<*v?zgf zBFg%J_dFjEaFGP$HCg{slouvs%Q__?Ge|nFwSXK2mSzM7;9O>58+Kzc!bO$oVcD!` zCD>F^Xu~qYu;_uDYzQlU0NJREg*a~H_tzXD4+o6eg8QAS*>SK-3 zRQl?bk0`@a`+i3Oyjn(C9Z5@V#$J=QH>T{3^Utl=+mqI>C3p4T9lbl4e9n{X_}Vh$ z?d|tU%cn=@wJW8K3-rCMb+^vmJUjo=cZR0TpJ=JF?H}4JreojU|2^MA;}6N1zHA~$;?YblVEdFot<=!*NgI`U4b!27qLNai1dGu0pWE>hz&<8Y->^-Of zWdS{s_U;E|6|-$KZRxW1m9qB9gBe@tjq&T__o{X;p8nDEKYV_Ty^YjU8?HjwBdtlcK%>yM_Z<~BeQ#NW`}#> z;_Jm3m-}Ojp{9PuxLQJOZJrmtfBCzYQ(L;GP1)wV7G0?wJ(Hl!yk@Wp+Fo?8k4kO5~b|;}lY41+B201vmuMCQJ!)+yt1lh(B)r4)vNYVTo%5 zxN{_|UqVxE7NBNJkp>KL<86aVBdG2oz!k|WsVmmo#&xyDg;fzK5@R)}z8f~5#t%lx zITk3j#fuB7D5{W(ireA_^|BeaeOohCucHzr37b?yh(tO|R60P#f)Y)X#*4;KPB#QF zL~<)WRj#8<#K2f@Vubd5@hr~4$qm#~V|dN2soHvBHZ45Q?hA;*u(bT^s6nyE&$MnYqUghonePt<$~_t@h#pf@ymIyi??vYs!`!l{+0 z`FahQ+CNpGOY|j58S(S!otjZ;`mDh|^@>vNsvL;f%uhY_6njj{I^futj)el5_WTyr zb*#*+`loXK8*nb@CGd@*Th9{oj(hM<&kN#j@uYtp$|OD*L3H4firLLVG{lNN!5=}t zRfrvlM8U)k`+^>@IVDlnf)PlJSJ?`khQZlDTwJow0~N9fdb@;0=L? z6|d{E&L0{M%bKt#>xkcl5hDRVIspY-Hvv63Srb`|o)7vZxl~<9kLm~$N)VAD8bR-% zY#jB9UP%&!0qlx0B0F$8o-p`Z!LcI1{QA%-(?fw4RlQ_%Vj|jZZjOr3Vc!74BYEvGqOsyLWQxM=0&^( z=Pw+aA>}Hm1ul-2Z@z*+F>aR#*8{kEe6Vf%O!w`s1?dOz*W=5rJ(J8mN7Y<=+OcD@ z@1C{n#))eulGQB>yBC@kbxG%AE7r#|w(6v<>0W)~_nW`loUZRo)pssFxm>@0#x|{+ z_GfAv=7#4F%yaY4E*x8UGP$LD+Vr2cn#|U1)8>y2RFPf5*uG-zfQ!rA(f3NXFE~E1 z*M3;lx@3N@s`t)#Ciqf17)u3XaPx@MhwuWDfaD=Xh3Lb05otK0W|G7H`90s?|K0tI zyVG6AQeDUHb|p`}u-q_|e9@a~@Ip`QGG_z;I^ApPTB+gnnJw(C6E{yR(60?V z)KbkopXjOTJ~~;|yK1J`##@P-iG{NZqsz73bJ~n^OWNs9Io%859ZS;bUU8m;qFa`m zmiYs(6|d?kcF!lJR9VxiQ&Uv3Zi4O=>rS##_20XsEbh|KO7ro%(tpH%9#20#{O;4k znd;hUW2U+hqWYaPHt5{2z3`fDt(!YM-#FI}EeYsKR<{AX-8FxBp*>l@XWI5*WzAgQ zt^S++3$BGd$t}CzuY7D;o3U5S_WX5E#<68pV<_F48Twj!Xf!o6njT7|h7zzSSLlP9 zPcR|NiwVgM;QC&n2XN_6X`aEtQ<`V75EGK~C^-){@oGjf5i}E;q#S3b zkIjuQ)GS`Ov-fW7p_OvhW$N3fdtW*6K@~u42t8gruM90WoJ)Gn zry9;fPwb*_$kG1=hjgTz9gEsKHA!d3inD(chuDxqDyWj0cWq6X6347{#+oi^PL(t- znBFhxe0WGpmAn4%(Kh5#5KG|2R8{Y%pBSN6s&xAwKR!sq>U~6MOSj&0RL`E6IgzaI zy>sABSCT!l;y9XV?7Q=oJH5&Jqtm^AbK<^Z%O6*b(Dc7Q8lr3l4xEJ8b=R`X4ux^*f9|er$Wc&HR(<27rIE?*I)IKWjTs-0#-?{0ZB@9{n$v zvVk_^FRHf>G@F01zX9N1HXNX#;%%EAy1ebs_iwYjUBy7Zx9jx%t){o#rh!J|+uij8 zwZ?a9^#H%qtcPyzwCM*Pv%Is*2z}q#qkoz<`B3+$0o4@d*WQP#>bDqtG8t3r)};bH z0dJt!8BSBk?>xy2)13Ay1AP(ShR{d7ku6cp7w|TN;jM&1m#Ko!A9TU%cL9GcJi+MH z%T+=LG!^ha7xMIq_PCDIL;Et7wtR}lR}5Z561vb?FdQG%xBJLrt1H79@CM5nF$Rw` z(xcCh4S8JSN1V^PesrM=yif-COn|R||M@e03;2{tUo-SAP`9qHjf=bYEsfCx5YZ{n|_HXG~8hqMCVD!`fs0Ht}CZ=g&2L!xZg1cGc-YIM0 zg&a&&?|Bg=iDh3_%mGsm+E?QdVPEX=%|=1a_EK%MyfZu3?_U6&h(6Qb|AzYg%T%&@ zAoY0vf;e}2UON|^e`fyBtNWqg4?<|hhnNO)Ntzcq;c`o6zEckxm;f=hz?~OH@&!FwBz8$mh!k>Sh z&34)wW6fl`#5?|(2i#nXlc(T20(_PxmSxTsN83LzK^}>r28fOX;N=+&R(Lrc_4`KI zOFTZ6vv^$q;|`47z=T`YMMYkCU4wRI08G*Bd#7*`R>4)OSbe&&us=lBMZgn7Hh^0j zCJ5fq#jPE|mWsVQ47U-q^5Lrs zj2UOZ`d5 z@s#!WwY~?HHR;NhRAo!La%ZY?=c4^b6+f)_QSA?Fmn#q4sY_M%-FY@ud2;e-w)gW( z$CHjDDeIBRK2ox0$^KTwn-y=>zFC{xb4)=x*|%z-thO6ZUVoCbd2*@utrKsaSb8ey zIF_;=o9uf4Q{8ucAF1qK(!N#vW^uCnNYZgMWjzYDj5S&EVv2cjQum;|WlafYO9_eYl}2 zcXr__ln8%;5s^;|kxv0J2-h&Wju9Tp>%@`B9}*(`B}Q(DT>6btXAzW7jXV;FTrL9w zm;batkuEQEk2QXQbP+`4HViGrOAxJ>u6q@&>59%&MQ6I=iB!cC8As*p;hDo}M^nnt zlqq-4K0EX5tY^l9Vg+}WRU=ioca73jnkOyrI#N}escCws*A^A8YW1YdN>y*WReiHM zUDcMVYFl{W&V^)E+j7* zF6{Go1mp&x9wYR(3zsp%Q$@hZ6Wx`60u|8SAyFv+BCh>a>Y+xjXC4-t^$m0WH42h7 zXS2R#?KI8N`mVK8T2^mcJ4zSppP>~*hy&ap&mk&gFY=*TREQEkh|rBC7IXuU*H8Ge z1Vu?MQ9k8x=k`KCvrx9c2T%dgO=9Jx5k*75gOYF-9NZ$|7}*9M0rI6v_5oOV*|d=F zfK2>+0as&Kp VYX8T~b6WZtdd~DaiU4GQ{{ghpdSL(n From 91993f2fe509c992b662dfe843e86785b29c0195 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sun, 24 Aug 2025 05:26:33 +0000 Subject: [PATCH 3/6] clean up __pycache_ --- src/__pycache__/args_config.cpython-312.pyc | Bin 556 -> 0 bytes .../TB_eval/__pycache__/utils.cpython-312.pyc | Bin 13229 -> 0 bytes .../__pycache__/ProblemState.cpython-312.pyc | Bin 1187 -> 0 bytes .../__pycache__/ROCm.cpython-312.pyc | Bin 9469 -> 0 bytes .../__pycache__/TritonBench.cpython-312.pyc | Bin 16573 -> 0 bytes src/memories/__pycache__/Memory.cpython-312.pyc | Bin 1368 -> 0 bytes src/models/__pycache__/Base.cpython-312.pyc | Bin 631 -> 0 bytes src/models/__pycache__/KimiK2.cpython-312.pyc | Bin 2183 -> 0 bytes .../prompt_for_generation.cpython-312.pyc | Bin 9969 -> 0 bytes .../prompt_for_reflection.cpython-312.pyc | Bin 14265 -> 0 bytes .../__pycache__/retriever.cpython-312.pyc | Bin 3354 -> 0 bytes src/utils/__pycache__/utils.cpython-312.pyc | Bin 2442 -> 0 bytes 12 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/__pycache__/args_config.cpython-312.pyc delete mode 100644 src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc delete mode 100644 src/dataloaders/__pycache__/ProblemState.cpython-312.pyc delete mode 100644 src/dataloaders/__pycache__/ROCm.cpython-312.pyc delete mode 100644 src/dataloaders/__pycache__/TritonBench.cpython-312.pyc delete mode 100644 src/memories/__pycache__/Memory.cpython-312.pyc delete mode 100644 src/models/__pycache__/Base.cpython-312.pyc delete mode 100644 src/models/__pycache__/KimiK2.cpython-312.pyc delete mode 100644 src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc delete mode 100644 src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc delete mode 100644 src/retrievers/__pycache__/retriever.cpython-312.pyc delete mode 100644 src/utils/__pycache__/utils.cpython-312.pyc diff --git a/src/__pycache__/args_config.cpython-312.pyc b/src/__pycache__/args_config.cpython-312.pyc deleted file mode 100644 index ed62ea94b9178a5147963fc8bfe6e3c473ed5043..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 556 zcmYLGy=&V*6n{@A*O7%oLyd<*jKPJXAu$=&YlD+*Sx+Emi#hH>NTfsvhQ{Sm_@{ad;zxQ?br&ufiQT+aU_c!S)?JUU1 zjCqBODR|&v08Oop00&0XK=9ZYnrwuQd85G%2bI;}17F76^+heJ~f+gWkjAEOXzWLFg=U5t94@=NUaXFn-`?d>%+2aIq_#?OSWv;vSZ7aSa!_FTuP!uk;*O= zS-j=hcL-4F&ZwmpsF)^cm?l8z#^|Xi(56?kxa*~Nxum&UT1us?9UzAqxCZ@4i4C|= z4rt%o<&u;_=X}2Y(GmFO?aZ66H?uSIesA@m*=!&nj9q+n=wdxV{0BZLfzB}Zeg@1O z!4M1?BKqW*>?5UTRi6rCRfrm;`)CsLsE~S8-KUm#I;0uZ_Gu-a4C(rGQW<@p9%6OK zFly{GX7WsZW=8W8(O1N1fm#?lqkD<$vofWO9^zug0I4>{2(*MT1?&SV#{3f9SIRh; zBFJ~Bh$94Jd5K`G@2KQn_c@tjkajUPpk+)6Q1>M1b$*W7Jzi4Ovi)Hu7>)ucIe{kztYai{uE0WAJ!9t9zO!BJ2n^*6$BA5B2*; z`lCaUaAW^qARKMxSbsCqAMFoC`k4UBHTNF$1zzkAHIGMwA+Bj`QdE!ibKD4Pg}Th} z<4yy)N-S%2Q(A_7t6Sy4Yq8z;2fw2%QWUorpjIn`%;AOEW8W88ZS zMMM3v&atH1nAUAfjHk$(b@Ns8FD`7^FL)2U%L%oI7Ihu>DZ*&Z75Ugu|CzB&pv&Z0 z&R>E!C(n(*x9)RiA*T8czG2p-;8Wro%doqdU;j0ye7B4hBZp7xtLco)JODX&H^uVq6!^u25e07vdP5tkI{22ys6488=`A#*RQGF_!zZoczw3^$LXD#H)hKl83Rwjn2oPdaT8;T zo1=JANP#ynW<_o%5ijD4hH=|m7nH3!Zi(W~Dh1x6Xs?5K&P@;(Ko@aqyqGT@E>UE1 zu~GwX9)Y_KM_zN2M3|%r-g3pt6ye#*5N;xF<4wc3N3RP^wk_U zLJRFbS4xx`rGDi1T)(#G59wE4kJf8JPeeB;BN^VPaJiVV^ZG$J2jQ>ursJK=nHMkR zONTu*n2S`>Q7=%k$Qko~e*#dd(_($Xq${ zM*U4I_R(uBd!!om%!Wtqqw|q!H2xknx*uL6%zGK`F{}!-x)IhPtORBgu7oP01$LeZ z6-k^UNcfu}ul?6$>Krja7RdR>hvs~*aL%=NACmscL(;$Vko0~fJqn|cLp8xR#LHn7 z73Zybk93Yz^5w9W=sbPE+GmL9R;4UouH+5Se$_cxBk}U@l%L7O6J)ePksNMNxLgb^ z+erdiwtTJS&96~P&8X4IstH2rxx#s?dqho?4Mrnjkpgq3_Xb#c(nZl18Hw#IWY_dc z(gy=!|Bz(&H0LGwj`&_|iJ8%Q^kwZIs}Cyn$&E-B;|-FE0o@j*i5BV^?A&&>^H2*Pd$d@)qWtup!EXG)njEqHtkuX;e_g#OG19o4` z(*Q;vUU~P`4UC6Ebq$<9=o^@WZN$ggwxW2n&fu4~~yP=}M`mtGZrDz-O;7GaN}Oy}jXARA!36(TK*51_3k>ZNyxjYObiG!lZIh5Z3h`(l8-6yXA5 zaTF}O%)2)b2?0U@)*d9`E17vVX0qivPPg}T_zoXE-id}e$bfc5MM8S8=DoCXw)N?CBNtL;c*44;3bo1EN~ioGN%+q@^?qE(RGqlCw(JCBR!7@pGan z(=%BJbj6K_qS$p%tfa>*54x0ib4ErkD4H{3d5xEt+6pEzt@)$w?{*8;%2db4bW7%OrYL3qm~L9p5GBXRtKE03wmI#rcFr_w zN}NwMFIab9J$jcmO`UsHrL-c{w7qLtaCI%xNAIl7NnV`yEx0-t>BDy|70L1BpwPHa zsNTP5IdHY>j?F$-H(NK?Fx!y$cB*s1wjCOCR;R{OgYy@JhJ%aFLjrBTYbi!t>N-Eo zCwfv{3znTw(pfe)F*`BG&+xNWbR z>R7gvC!%Rf)v~i3JDNN`->~330M(l+mrV6(Q$1+e)bxdxD6M|rBn+PW9>P#06m3jW zi~8ymx$LY+oJm}oi6w){^J(Wcp|vaR?3yy&wtJv%lll6Q=_4~8i5;`uADcX1P(*3v z12ti&Suqgi;+eMT6N#f~Q%#Can;KHlw5fHT`TOAC1cj&0F0_3sZTi;Ljys@?p#-#5 zkt6~1yRMzq&d}G)4=BR0Va2*lZN6K&Y3}0eMO;Bb^^@;c3un&>{=vo4p((?1>87R9 z+H`5{8WuEUa>!6wvi+8^<`awKj%exNXM`}j=;Ef#%dVOwS3}y>u!hB^%9e~(X=7FLg5vV1l_x>nm0t?1x?p=DOA2n41N2^XdcWUEcFo_pOAf^h<(L zZ~4S%yJISzIh@#;s1qu8%s)MUNU%MzXnJzlVxI{ldf#eIzVK$7VDT>Myemb}>{s`V z(kxb|jn!B~Lw*g(a>3TPXlh!vZ2e4US}{WQSFqsp=2h-CY(M|HYv;*Y>SvAZ7Kq+2 zu0PpI{i4Zsil9Eww>wYmBR;Hbw?Xu8J1syzqHU+l)JG1?_^8o#s*L*R0P??d*iO|@ zzpR(|1GZBw)US%okp3%=8fmTiQfG_VADdbP^Lb$Su1Rm z_!{-8CathV;#<_Gc58(;iQlI`?WEG0%BLHtbRCI&J$>3uq#Nsizhx(ZzvaO6TTTjj zx9W5SwFcycO;(2g7+}!#5+@mwCtf636%QbP4q)T6W^x?d35DDYVEzMW5T|&Orx+EY z4h)4xQw&qdYZib%5MZTyV$}1`KQC(ikuYAHSiHM?RV>1NaIr!s2onT%NC*H{FD>fe z5E_H)W`ONR0iz@!dr3sKbnJ0B20V1+^QB^qKLxMBzEEg1b5mo(AUg{`?g$WoT#c5i z9m`tNOjlZ4v21ZpS?<_sl0yr&=BwSy#T84%TNjJBPN|o5rWy52Fm3fDcciU#XJu za53r+ftLrY7`WP84lurnU~~v3cVHmshr7pQ(=Znavs<9Eq6Y3{(nVUj6&Z3e;122W zQtTcSqEQ}hJ>VPis^KEke<=`R(E-Gsld`dwA`Mr?Q4a5%xFj-bViV%-(iJbSd7r#C zq+0M^{|fvvNc^%b0_S($5qk&DQE($AihsaQ(YNg`JQ7Q0x~9LAEM2hF!m*e*498+` zsv%voOK|Q=TX#Ls%Uj;B?JhjbJg~Qz91n~Nx5@!jMqICIFQNX{2;?U^5^0ILy+SLk z^fhm3c0Gl_w7~TQZz#d;;ECbvs)E`wA?Hn!-@-!PC;$qwFnYm!buK>+@L)9SGAIbN z=k^8?ph^||sJs$rWUUK1w-2o7S8+w>LCSO#-)Jcen-ne=D{Gl^Kcq}iPFCTRGHea( zU0G09$tCi~XUnB39G?XoauC3^3lC}Cu9V8fO6vdwPvq(48(WI2d8$xPd39bN{(>eL zO(76Cw`(w33AidOm4!4=1|cH=KVZ)|2Ylx3*Y(8ouv^Ra@8|ymA!gXi2SA%s_9-YUYRAKJ8ROycupiCPj=%u>B-XBx4+1*M2Ea6vfeKEv4`k$g40z$fK?$M| z=`h-Y=4>7w9vr?*(%z8S984LbQaVi7X~=yGe%yT^=vlVSFlntjbCFJLh1WgFrY&2R zEZftT?elwoY+JB&ELo1HEyv$KCp_D?U^zcUuh2HN{f@;Z?HTqTT$!SmZT2bc@}^2b zR{{60QrBG9Y*(W9&w8dh?pVqa)o)R68{RN1TDIPdy<7br_4A^i6fHKN{LJc{E1E5O zm3#f$FMm6=V`0;_kF70tipvunZ&khB{6_P~#oO(;b& zYwEqO@woMp*4qKb#Ae%H(0_daCz_=Wd)6c6Bc_Vz&j|@&8_-Nwpou_MdaC^})s+)O! zUS9QB9xVMD4_40ShW;@cJXqNmEB-hhEZrIplqPP<+a3&zR@nl;gJR~*!&&bN=rI%C z$;9!ZXx1mi7b!10Y?JPc*U1vM^43CIgHi);TIY2_pN&r12GHvSKC8HG{j~tRE_#&U zO_H1;pBD7G*p!~3WSqzwt&z$4;xeNx;33OhC#COl)?88GxXF6OhVhtO7Z?+7M32HC zT%_RJ77eTp^E;UX3jAax!>)A&;E^dpoy6_Gqh<@%4E{9(&|;njYU9<=mabr1=;v{N zy)DVllf{{H`EjQnk5Z$O!W$W_A1k6GK%|L1ay-2<*%^6}{y>8yI)^Ef+hO=HUrFEFuEaHCwt_|_BxI11R zui)*x`}gWYIbSgVeOP}D2CvB(nx{3&OZoM_fJ2p)u{xC=&7oTs4@$#-GrUF&=o z($zd#rHqAlDen}%odcST|HAbhm*l^I82lI6>qQ|CT*iN)1^-2Ee&p6m*ZD6Tir1ow z$Z=mAVUse#ujSY>7leu)76;JFRs&^gk>Dh;;Hf60yS;?5OBVzAiqHrdF z0O!gaiyXTX+iODtOByZ}fV;C#00&^Zwn>H}Wd_W%plL~34Eq;oFa{5l(fm7#9~h)kqo7?$n_$5775x-Y#);ING>3G4#`CzqA`R2 zB328AHzKxhZ;Os|xg3}R(U`%7aXzF6M1}Ez9l)y47!{2&)GL=Z;6%xVM1_Yk9D@Gx z#q1$eHt;t+1VD!V`9lEHvvW|-Kf#YX1eWzx;v3N73_SdqtsKm~0uXKiptk^o3)i%V z0WX}_afPj@ZLCi z5Tk>tV;FU-PGWRSbqb?Xs%J4er+N;fXH^$5s`DYOeIDrxs&R~X)l(Ei$Ec?0#xeq^bwE8W)h{#N0c zOM>N+pu6;I36d1-&GW`oT(Iq1H0@t**}K$oFx_(S$88HO$I`l{DROF0!jL?Mb#RMN<1pv?|@@vTfpt~{v=o4Sb0O)V@wP}4VBC&=#WH=69 zrhjqnJX%Zr)NBU&UXkr+EA^hYy#aQmpVivBtErz`Y~3x?FKX3Cw`qUTsy@0;`vIZv zt^l~S>ezPbg9Z}$M!I_w@xk^w;6E%Qf&Z`^(?6`Bkl&>0uB1M!R(CgPKWtHV@6r09 z*=$gNj(gi6X^sdG_+2=B6HfL4*w_bfSRX(ZeHupH2jA<#_upFh_FLCyVDx>)NrTrY zYA^XCp$N;x)aB=&+&L;mJLITMj<(BDs~okA#_00%+axOUJ4SoFCQ;{;KI`=PM1v1q z+~e?_rw_dEz8A*(L+C_=Crwvpdq?MKQFF5U;PJDa7#}*<-Xm&Hb{{$2*$K=u?WcQ= z_8bxE!|g}Gk=NPNaY&>Oo;=d*E-(c?#ZI$6B>=DzJ^U&coyfa?)lGcYd^%Vy`* z!+E60b@j-yM2w{XY#Qwfv6g-Js9YD;um)U0&B8`W+`|J1zw>3SQQ+iI2| zL6&ET8zh2uQiqcu%Kjw~_!t5lws?GS*rH;54}uvV_C;79xMSJBf($c`t_+x~#3x$w z9gE}omFX*jZsT3xz;Pz%%4NJ2B(fnpgw2rjyUfuu1p79K0pu)o@J-;?san)aYXvb5 z_9kSPV3!byg2TjUyS{gNuRuG$j{4FdAt~}lz(|U0%u^H~ck4qA3(A>8F%ZftrCDDF zZX%GhW(x8R^Njc(QnRcbb>8$z7Z6ELIFW1$Nd6p}FGV6DIPu~23dRELUqiM7HHJ5x z{7Hm8>RWG?r!5UPCw|D^;M0~pf^JV{T3_`iYhM{jTf8@q{P6eC}P2KYtm+KOA6jlQ`@xXL*R_OW8r- zL=Ar60T>g`5{uIXOL7Rmt&u-cKs%Mi%}y+m&?m`pfX&Lz2{nAIFgD2|av=Q!g7l5O ze7zgQH1sX7myt*cn??@rBC;lNZA+_3QqcMpr`s#}<$szK=F_oTle?b8G(m|6IUz$kLbx%W+Pf9BPU-JD_iU0rr diff --git a/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc b/src/dataloaders/__pycache__/ProblemState.cpython-312.pyc deleted file mode 100644 index 041258da8695d655df08fcd5067365004d082324..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1187 zcmZux%Wl&^6dgZfCvEzmtx6+^Xw(WXAX|jWOBHVj2*kz2ob2D?#9e)`g&m-|PUw^9qQV{wM z!SFFMauNu#j{+1Z6orbbgsQ8CnyV?GsZR|n0R>sX` zSwm#a|FTv)=U6}BT}R{klSYzq>sd2x#1W=kyA!n?mD>SMu}^W5aAUneXy9kH<&TyH zCv`#g5kamJAXj~bh(`24SrKZWx&|?cc3yCyX_=NA(ezBSMl>VS>=Dfbt*tsaKK7i& zYm|g9Q=AfRtv4u%a7b$0pm>c?Zl@$kJwFZzFEpYgWi20$#jPZ!EnrTbTb>7NJdfu+ zFN}j01wQ6^Z(5jU8BHwZ6P|~oC>DPu7ABs@3St{f2N(kw2Pg{oj7|iiKi8{oW44(z zu}`XX>~G?<9!FR328q&Y!u%?IszwIotWt5i=nofaFjvoBQv(<3W zM%h}pAACB+Y=flUCSehw(kj{Rnr!twZa2kDdOoIuM4HPh8+dWGWlqn`hh{?}_gSN^T42Yl@;Aqvfw?`50a9n@BDERt=@nM*>bX5%y(RlnF~I^^pLX_^Ape G_xcxpJQd>r diff --git a/src/dataloaders/__pycache__/ROCm.cpython-312.pyc b/src/dataloaders/__pycache__/ROCm.cpython-312.pyc deleted file mode 100644 index b207f844bd052131ae5c7f8578b2b2572c84c949..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9469 zcmdTqTWlLwc6Z3(`zev8Dan#3OO{MqrW`-xN94#4Ikr^UvApqGNWjvZkwluW+!(M|MO4UPZNL3L!su|Kyq>nnHjp~MUQT>oUY8Wy|b#%lSH4T}hye49fT81oA zUK_DSZ9_IGuZ!5DHA6KNQlYbm(ceLg;l8Q}$x!V$4I>PdP_6GC_@NvvR%{KNXj{`y8 zKODjw?@x^LEF{B$P$Y@TEDkav!Q$bHVZkCe#kd3;TJMJPG-46K%#g}}$dA!8syh_m z#=xi{r5XAT8qzREMgt`+V`8+B>KGl=>S1qYktTo+1oagwXdthK|9U+|77!=_`ITlt z1tq9220;xa86Y?=D3Os;5K!14=gX;}yI&d&eaT8v5NIfA~2P6^V{G>%RB*y*2@d(3W(H@F%JWdArP&`&_G?MxJ z(mbLr5)Ux`1SeV(MZ)rvX;@Cw10DKtJkDc+4$+{{7baVW?i;8FlM@l?P;59ZYKP-E z8sNQpk>=RQFeZq`b`sT)Xbb6S1^7ZSp2Z2wLM)mv%OxW~lLUAxAHak_ud_JDMmUZD z;3Co>zJVPdUv62`vm=#wcy!HijAMwH$l->Xxcs(nT;%T=bDa^ z`i6qr1ND|CT2#MtL7T1X%GEtb#;rCzx7eO-I+ANTCiPp@Tko0fnljFPE6yWX=aHqh zfAsx>FXue>NPm7EmPWx!QZPYViUp(@!h6r*~ezF)c$2T@whD?f^Bfcy=8)uQCdqan9@cH#V8us(T_B@)dF>c!qi(A1tYEzG+{!$n}T4JcLKPB*^43U zlyl`Ytgwl49k78aEaSFX-dWoGj5({Z7PAZyXFQJt_1lOsy{nS({$r}*49)-j^9nbX zp_X%uvJzvF@MmGHf(3A@JC6YSrEahyoNX}(Izca3fSqf@EplBsWo+d$dtkp`qT#ev zP=#B|z2ux=mG?;j{k*ju&P+9kTG4ZGrei7!T}2f9pe1}$eHD#S-kJfC!YNIXAKvpk zNI*Rs3&xpHY(z9w@JCS-30z|%DdTY=6!o()M(7q%N_Sj1bK$bk-7C_=p$MBYd17(i z!`=*Wd`s%jTmR36-~A_ql!0M~C1K$gYd0`de=HDXy&BvGxWqdk5)G1i0Qg`TcpEWJSNhFy8@yh<|N(@@`;5S9cB{|;L4&Fv(fksR#e3~k(L-b z9gc@$k_ZsB652RX!zCgi9_z^RYGC%FHjsdY7_SZ26SxS$VRDaf2kBz2I3!LHb%fUg zJfRp{u^DPaRP)gU2qT`2;sa!gUD6aK0aui7gsI3frIg|((wB(*lvxrJCDBMw#LWQj zdlMDv8YZl(mYT`7mD-MMZO4LRaqN-h#PWgFhUU2=vq$DHzkOn&|4TKpzDPZG)lY|~ z!tXohnU(fK+4e(A>gADq%fL#@V76s2*Ybm$>y-)9W2a~0bfLa!CY`O{GtpOYx@Wdd z{mI0cf_=-hWy&%WSvZ-qcdghDXYGfV49oj+_TGt}LhaUxQ=ht9C+UL2IYZA}nYH|@ zV@IK(d0snve3DEzdFnULx-|o;Yg;oSdsD{dU0{}cS>Hn-}!5E*V;j}rDg39*^s{R%M%34S<8v9I3hLQcJxTuP;vpaVy! z=uNZ=m_D%LXwN#@bGD9*z5{!pK1f3(2O?%A@Rk4=q6dd7fiLx9C>boy@su7arHAS~ z0;WsmltQ?2wNf3gY`mI+P*tN=AYjxFXjwo}AYgu^64WsxPmVwef`Y;{4eE=ASJVK3 zVl?tdP*b!Z{;cw-atgY?N}y9vQI9}m|5Q>3bYY@DZVG}ftccul6Vz4_zfkKc_mOi@ zTSWk4bM^WdMbwUWG->@)&OpHt-U%YEig4x?F&ZYkW>XOK@)(Gih ziIa;u$V$@b6f|OTrHewMjpPWU#=A*GWS6Mt;*lhI^@y4=48j2KA+;JN8BK8b2+7kV z?=?slTQm^KS9(2xjKI7<#N31h$$Ns0iF%SD_gyr_lYAn{`y;SBks20N@Osh~>p1cl zRm4h>OBkia;Ux~5gH4hv<%&M&&eg(=2hn14uh?3$wwC#wdE4%V@s+NgY*){6N51QP zMt?qIJHM(ol83}kPfnkmIy-$~>cafFoYS}De5CI#H0{VVJvVju=iJhPpN&uG3oiGB ziHIrh%wXQqGXHYHRlnl$W?kL|@8Z>EclO}9oU3nA^?B=#54zv$E;KgJ{c!e&g>6l9 zeY1UqU7ahtPGonTDERh%bp3Cx&)~T?{_2e-{CktOzIIXzJlWZ};@FXO?3f?TJGu(a zEd-!8cIve%XN*^#&HTD3c; z4O51h-Se;K?cOgNk#+lur6p@=`HCZa}~ zC>BGtmvy+LOv-&o3f!l1zETrBj!L`{o>2-X4p(^$32KIVpptuo`map>Jt0tdB=F8} zhCs6k0)+>M%?dC|hEWSDcnr|wN$@UR<}HS2L0;wg^&R%AF7tkQ?X>n8`_pdPpR$*# z82+DvQ{jLLoVsVgsoMxAP*%04A|+_y0i9f)*_Tt zvDP3WPg-8l71|JfkhYd}t+3~h;bn}{JvRx~jrW`}!w>PT@`&;Sx_a%fHmJG^@;`08 zW&QT}z5Tz4-`n2{zr`JnQDY@t2JkJTsQOK*a>D;Vb~}dTJ$fwzmvJv_Gv&Sv)?~1P zc!F_^*&rWdInD$2TThDigkl~{%>Q8fy~ze8Pr=9=&c8kT+ti&C#61Dd25?Vc64wMJ z+jP;UEa^~SVzGoqVsXs+S&YFek!nA4Gsq^0KIUPPc=MXX0#RI8b=1@At(COD;_Lk& zz#v(TaX-lssgH@;F6si;zzz(WB2n#&7Un~wUdxSw-UqM+<6!)yuMszj7Dsu^vZ6tO zLexlriE6Npm#hwC1Q?Eq3MuNqhTsQnQ#5md8?2uT;!uL;L>=IQRF}-y5@v{%Ttbg% z3I-w($%38)D}~fZ%neY`ay$b%qLjkApq%2Mq!WT-NnA6a9f}moJ66u({uK7WO@J-~uh3fObjMW3bmvs(3_q8iO)re(Y7b55N($xV*?H!Xe)rRQ zJua%9a<7s#+vQuhxe#8Wm#!?G%Q$=U`m<}Q2E$2eb!*-9*wonc?Wx=E;}6E(8~Z?b zPgp#=^n+Z-sod7n6ECfrt@rxw_RZ)Q4m>jN1?zc}XYR`El{x>c|G}k?u6%gqqgOwC zb*W?dSZ>dST;s)w0eOIR(fP=H=<~zJ{{GhA-CAZIzLa@|$qdIb$?-LWdQ@jrkX%*0 zruvd3BdXsQlf_%>2t7Zl%9;;M^vyKP4=)V=)>c0&JZ@;5J3M=M?)dES1?!SN*U&Sm zU#+OqExL0JN1?952C_5n*uCnkpYEILo7XSw%scl3_+YOw>6Sl+frZjCLS{1U~9VFLX4n#?jl>S~&j4N$9)l?dA z6jC4pmB%k7`7S~As;25FPxlRKe>ERRh3YJ*CGo60BS1X2rRi^078x*ZY=kNV9v~33 z>NI_e9xSCB#L>5#U-Rwe_i6)EfJ~$OXuXt-aUOn@^$`Ca@#DpUtf%DrBdnXq2Euea zNsh2^(v!a?F+Wc<&LktORgGacdN{Or>am1ou(zBEB-J_n%Faz*YkUeoc4q_0Z%MR)c$%{?~C;*I+{KX_rlQ}jG?p0Q9Q;4Jw31m$-x@fxQ9w4me%R?M8VQcHYi)#!G-IIvOEq2 zx$>&1gM-8dLHbGPuAFJpag^MoAg_;lQq32ia&F-JXPKj(fz`0+4_Mj#S(rTM@HMC|(|;aGA}+cuq_XeQhB zmoSEjLdKC>D-mhYs^Y-O(vg?3k#Z9#eN|fSkQv^A=uPy*g$y>yBkSsW*!S>-jPq(< zKeVQ@8eD5e+EQO=Ze3~Kmu=p+cr4d^V$uw}&*Hc@aCacna4~Owx!`Edn4ACb!~j+h zV9b_!!nVy*=1g<<((6mXOykMC{ZwIV-Cr2KUeiGH*H6IyF0pWz!op81Q%hW??o{4< zy5QLIXXi*ygX=4fXk>5K^r+DTSI>U*vD(nHSN*ZMu4lLU;~f;qx6`D4w}#~Rs!4vo z;iO0X@$tHo4eC#}P>}zmK?57THm!@bCOzriK^GQ zigS?}ybj4~SJVctL4|1KlGnhKLuga_Rwjv)#g7kcQTTu$RBD2IgK6lF3E5!6K}EM@ z$t!*t;JaAy6;06z)ilBeu1k!}r+n(b`J`_YhQX^<{0hQd?lj6qfJixCF~!6lCDxUj z2?6`q;o(p)1fIrm>7xk=C-}z%A`gkkyH3;}FXFo!{{)J};KLn&XiZH~)ThYwDKdPD oET5v9&rsdJp*>l&=RdTFI`kPj@)_!Z_)GetigJB{NG#3zU*qWY>i_@% diff --git a/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc b/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc deleted file mode 100644 index ad4c954d6ce071bdc2f2335b6c38ed1a4e2e0521..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16573 zcmd6OYfu|kx?s27PlP~7JY;CV0)#ObY{zdL^R%%|{J?fhVryizz{nCJ-NG1+lw@vY zwn&w0BWAM&-gwsXPIisgt_@Rnevn&RwIq|t&DP$zD@8ac9j4ZIZ+>k4k=R?_c=p!r z_jRjVLY8lGXXnQrmCpN|K98@zcmEfoQAa?SxOnsO2ipna@39~TMJDs;Ey&CfG(nSo zVpvGYVUjP)hGl#zAC^NZ^D6?%VI_%u_<=;sv2f}Sg!@~qoG>4iqa2_|^S?+u&|xkFg4@dRnF%j5R@UEXV6kBgo3 zc)Tp@f(nm6$a+1((}Y-_DLncKWabDjftwzdX$kz%dRq1t3EMW%a!3`l;w@rWNgHV; zlvG|dT|}$iBJX9U7}nr+hPAYb)G@^-U#p{9ITCruCxJwQ?b)24sLmy*Z6YkJuPCGH}8VUMoFT)j& zjk6(U(i0jBjth-tqu4U;3V1`zn1>ZR7Qq;}V=ptocp>IY6JBP-#dt@=4tl<#UzkMC zc-cvRNGPcNK{xH1V7a0R0RUXM7%vOUQm|hC2vY>BWlR_qL&OzrTme&pEf$PQF)HKx zgxt)iH{`nFWyZaJmcdXreU686R!#fqjpT~RlqxL%Z|wqbEo z+OZwex^#65rd8>hR!r^bss>EU#q?oy-7WtO|H8hdqEvPNjN*Z&>JiLLcqAZ!f(HIP zxGh3FM~o6td4!CSZ*B`^!Hn=pK!)RmK$EoWzWiQhuAM|w5t0DRCjpn(8^QBlX2z%z z@B{orfFg;Uk+`0`uwf;+uS%xh#+CStdBdS7zM9esRBtGIhzoy35K(nRO)Dy(PgDar zC7)B!L{u9>u;GIb!6q;G5LWU+gzzC4N#go{R1gvQY2x`SM_7s|d^OqeVoqEI+}$g9cmFutM5G->4WaC+ z4Z?j#4736F38@j&`$ZxwM~$>8Vz^2$I%qTTeT*&%2Iyj`tvJ&rA%zL%v0rLXk_c|yCkOpauNhu zIm+pC5HF_~8S{I?dTKlvqP*9}*ih}$U)2A97ryyF5W-s8J3;~W4!JnOsbK2iW;mta zeaY+Rw0t#Ebxwmj8g~c0OcQiqnjvz^`Oeu2#))IqKv3v0PsmxqDS{K;ae)RXd~9%> zQy`9p*&$|mD#0ljFVx6_ET`Z}nZg$w8|P^Ur{YhJ<&^A%e=Nixp5j#Q37C@Rv^dN^ zHtuDGhicfMe-de)vxq6j1$JOmi4nfz%w~uCxJayyL zo*!n^&9nCY$#z zb$`~l;yAVHIG1#sOF7P`Du-u|{(br8m?B+W7eA1!Zjb37+N$RJW+PBlZjULywpGkc z%}&ilW+U;QJA=0e7kg5+Ju$^vrvCNW*W>hxZA+r~(4u1D>f$Sjt%p*@hZ43eYj!Gb ziL;C6R_q5>?OjQGSIXYAY9C122R_@CvY$>2olV-$#k6T#U3@xe+m_~6=9<-FN3z(Fwow05E;E;`*Ao?v^)8Yqu1;t+ zJ+2^1tTB4Nd!g}xxnq$|+o@ZFHwM43x5RqZfY$nM_!caw>eg8Q+U9Lb=aw%d51dZ# z>Rg&i?&?qPJ+$1M+GGfzKJpJwK<-)R_93}BClIHHGPdakypRx#g z|JvHuDgT|~P+PxV{%NPOU!(qOy|TYd?UBekbV4GI|3@tl&E=xzf}~$WXQv4UiLFz{ z$>D9}LEB040MFv=6c2fv+U0^)V0bn_U#9XU=*J>0eUn%-mBe`vQ`vOB#h+&DSeaPtM1*k2+ve4?Aa%6{mh&1v7V=7!9(a@Ux z&=*<}?D7IrT(}-aX>ruJlGcg&=iHfet4vcw<&(&H_CXwUGD1EhJzigTv0e$`hFfE0lO@bxHnzMqs>^x!)Br@qnK?GLfPrPB0m{#VK`DtZ_ z$Q#0Loruz>Ky36KVnu3&@LVD(>f=nh$A$@G2rY?_pE>j-&ICaTOd9Q+? z41y6&U7Dtby0&6jnm4R!<);WM`E-b9CXiKd1+M^)F|~oZZiX%KgpS#QxkiXMQU??k z_z#+$8Vg-!+ObX(a9;u9%dkBC(f4eK!{(Q#qo-B8CV`Q0*G`8+)t zm|&R~Ft5P8Q_Eu-PK#U32%bP1nUHIYz6Q@i5zITzX)uEbh|>orLlcuBmmhY;kt3W8 zL`IcR$Kp#Y9&rY3>G+4cc-G09OEdZFU`IbKa#SvE?FLEI@8sTM0NY@uHUdr zJN|NdMw7Od&*;~T=9qK-^op@Iell&XT(vrrR_B6q@%(ama_5nh^=M4?u-SXC;L~a%e8h0wYWl^A(Jl$}6olH|;YTttH)`>GKUhft*9gRfG|Wm8%c zQH;WQ(_}<_Un@y{nux|{&-M~?wC)PL)lqFk%V=N(>5|T4T2C8b3?=r=rX#XvC86t6 zq8t>G(CTKu9Z^uNKPZIxyoV!B~h zOfN=#7$KiM$dw*Ga>8}?$mt_p$9i2oy=RYf_nzid6Vs?b0>x&L$2}36WV|i`FM-0- zdo9Gtz2jFI3Me@! zOjbEa8Sya=!!wR1_E zqz*0jB&oxnb^O=;|9O9sI=52%GP2?gk5vSq)#lmexwhH1R7w3z_nNb1;q1MtneGG# z6o!Vc%WCE>&R$HLE9Un=O}eH&T~?8{*dA+@)}ood^&-NyY1P`8v^FlZC#@|r{eNe& zLZ_JQT|?Sd`_N{eAGsB{5df(tQM~0zBRn!;cKq>iGa%gW2vS>{E-k-VKkt68A$H(_ zf%@aR9O|C32jDP1G8uZ5#79+4-FxI8%MUd|;kR3i81GT`$ke~>R73s~rMAZ*|D?#! zL&-m>Br#7ZA>M$Mkw)@ZxdU)$u8>d=E6e}iv9kC+GvzrHpa7nx+%%8_0GuGJL6{h+*;&gQy9 zLC}UmTGA}T3p65k{YG&$ne|KbgI^*aWp!=9+UD&b=Qaw623cV#s{4VSL?N0a?hJH~ zx*j4Xy8r@B)Bu{hB3fNZ&?cloEP34sHDIWzB%;QMkuK&>t0+S4ZuE8YOdGtY(NVhJY9`LOtg zxsYC&u98NS(DrC)s8)pR5WIClh?IzZ#1w7>ZaHe9H$_S!7G#{@&6Gtf(K5O^TLZ9X ziCE|w9^$Mt87;kT2Vt_5`DykkZ5w9 z=*M6svv^m}Vx^eux=le$tMkQfsQ)XG-^|m<%dU?+r-i(}^*JrF>jOfI)8VcK@Yad^ z%?ZXk;Rcoyw8Rt}@=j1AAiev61s(THQvt7=1vW4M3f>%+^70r9Ex;<$lqcx-dp#kR z%4b}WodtdsSvZvRgCUgv{eB8;=otZExOF{A8Wsf1e}8%G&jFc0&L;$8|ENZ3k{ z0cQHLm+|rs%-P0<^$RPf+|Z3*4ZUGm|Tmqg!{;`RN(n<~Coe8_8#=Tw|_7&unF@cHDG-O~FFEz<}b8eXz z9KvcQKy8Kfh81{fVN1>Hzf(BGIsEV`W`fR;S}NHM~gj z%|C#x<*6Lz2y0w8HRh^#0}m)clQZ1@k2O!ZnZV@4f%a|NS^y0A%rc8`yS zA!-Zpt{QE^qjK({Edy?K99hT9czxk!0leWQW&u17Q0RZq8Vb5RleD`jx6S1YTQkGO z0~LzSZ;(^_1{v+JBE(GZ4O_ju>_tf;80BUu@7`}vRUPh}000a%jnwtIL9(I|x z%7~V`Fo%jb28B#cjS4mFh71#U>v$*p;1|K{#F7$xcSgO;ivQb*XM%g#7F$S}-9{}XMNjxsklYRBZcXqwK>z&TGJLg+g^o?tV zqIZtGePmu8H_X4Du(qxk+E6{Id0P{+uV|_t7M0D&*J|tU*lyeI*l*hx=+)NFWNT;Q z#j~l}a|ul~=q3{-O$+ra#@5B!G*x>?e_Q`6W5Tcr#Nnd0brn(XNN8%lvN&dX(x#gD zj-;tEZL6F=IXe|UxVV3@by=RM?Md2tzzArlUo|u*4b2NX9vHT*tBJA((4~Sxv;Rha zV)MR~eSbnzvGH*R<2(It_a`il6+;96yss;?}(j`$HOF95vEt}D5CQU#>=D3H zmn^4Y0vaaC9vLLPvXhwf$WCE$N_H9(mK@>-oxwrnj@9y(WO>WN(1Y@hbrVtBuwG0w zG~GFO``p5Xlw)6P@JSV6Jx{Jz6HPCyHttI{?px|fH6BV7Z%*6nV6UyHzR4!)ccp4} ze^Id~b|`J@PV~P5Cz9;Hn6zC?*VWIf)3&z6_WsYvJ zetVQ8$~MD^6}2UetxunnWHWGv$+DN9K6#PM^nJ>ujy@2yk73<~=2jMf2UAdpUGW!K9UFeaMg2^IlA ztantIx$0mpRgP3pK<8dix*)TLHF(Bm1sGCc9!c9AQHtw`szKkQ$x0xg#g^JYx1kM5 znwjhvkjL&2#H+wmMRglihXgS%_zGx7B+ZUKt6R^}yX$Ec%5`8#kW89|bm^dRDrD4D zWwi17h~~cfUgjRb+ey-0_!{z-2t^t+CA0zMHjwm<5jDg`KXB@(NX`pz=9$|C-J6sL zX;bUllGhcAv_b3*#?-sK50s9EUsg(;)_K79387;OVm_yu0;Ofe z@2eMUz>7GuD)ghe{__ypD)|6C%cI7ep4y0Uqt*~?2D+&6b>sE8kY;TL`7+|PE+f;K z!Y}GX;(ZEj32th1?5cO1I>Br7Gjae0l7^dNC%m4qkuhM>`cIt6=l2-~7CH7ZKpA*0 zD?>prUNYyP3S6MhL&6MW#Ouw~p|OA$%md+yA+Q2x_4pLl(-gF{hBxIkjR?#kJqc1> zkQ$ssmF9oH`Ew9|#(~d;Q4HQcN4AA!dnu7?nW86%7fC@4YqwLLc~E@$4%@V+Wp@)Le42yki%q z5*VzE2NhT>1akr~n!s@eN73xSjN&++3sh&sP*^bEk2CUT=n|{-nNxFt-#7SnU2-#= z9PGsWkuq0sngB%XbG)axVl*@GTv%qpusugWY0Ym<<6H~?B5-1wjhxumCFSp*x-)!x zcu}|F*q5p}nRM)nlS{Atr|^g2&m0d9p7>X_3CF%<#mQ93sTuW`T}M(S10OXmsow3F z+cmpuZvX85xNf2ELFxA8#u;^@WMHklYQFzH!;CR)t>A6O^Tr3phK0>*HBAe)FKXIn z^)Y$u?3%TF&Nu6u3(f{p*2b88ZR@s0>%CWE@&tGm6*sTT^yY(Wn;LIjym2wTsdjbK zw&bR5iwz0;p7fTsbgd&z)u%VNJ=UwLD`S0+D+zPioNiV(e_>(!!fSVTE;KHh7Y{79 zEth=K_1FC$^{@5}C3}VvXI@G5yqf5~kf^%2QtV2DF<3vVpFi|3i@tRbR0lF{jR|AJ z_upxW%2VX`T*<)ydR$GE*R9&NfwLA^HO#x$EENg+&Mz#x)@&439ONtZEY&WRCMplC zS~`=KPPBEH4?caeNto^_I|OI{8*8s!^O0VGab;s~t>R;QXP;C4+p31%2E`}jAu|+y z*HGQpp#D@LgZ!r|4d#oA`fAmm+Lc&StHhcHIp&?pz7F*UuW6)AKxv6bt?~ zbsJ|J0nc5had3n09TIJ1J02lo1)nzX<9dh8FhpQzi-O^tCQ-Zu)<(;-PvmKlK!Q=* zDI4Tu0roC&R|bADyz?k{2nAS#H_?Ks+3$A;F46AJaAPjk0==MT5;cyk`~5*Eu$_<* zIzENbEE3)~iC@#H>V3ba7wGv=yyv~}BK>*wzVG4x2r>W94Hm_pc{gE2e;0g2@o!;v z=HKBma(2Spg^C}&i9h0KVG|!(QZ=0)S{5yFMf~iq3`>^JYdXJ&KlAUQ&!2e%oNB1= z*h?o)|2ekszjL-gr|P;tq!;P$<;(|Fc^Fs1YFQ1Y~lBfvlfn1W*>j zl=$*&Xp!Z%@Bl0sW&nH@gS;yO&oMApVGeNrR5HOyaD$e!lL25~Of5!@5Pd^K zH1-5efnVO(ZlJ?X4ev$9v|%US^K2XDUW5pi>%PW(fW<|KoJK)Cr9q{p3$*K;kb_T= zz{>%YjsccaV>)_?<#d<{jrfD$N+x$-9nE-~@uPUZGk#Qn#SfV{6{Oe#4&?OY1gG8Qqsf2WI47 zm)hp`&hEWQ$9JO__eyp5YW3k{_2E?YK&td8xS8oo6Z)DpTh&~6HvInP__f8-rKy#= z{?)o;$+}~yy1|s~rJ2KPrPkQBe{tabZUDREpO>~L3_BM)Qih$0()KUQs^$-;${aI& ze`mJMYf|Rg8P!8)d%}5O7MPCtvkA-Qm14*8ju|C*NLVXwmfd7;?^@WMs^5{S*|{W7 z*6drNlQmt-hf)=Nf1~^xPx92Oi3{G;sgcy+Xu=mr4h9kvY;urIOkPbLo=TNno6$Wq zSz^BT9rxu6efRXg>0VO(<YqpA3`BIgzze~MpaRDR1?*Lz^Dg-6FxYBPfCo<-U#RvxKVOy}Y>ybK+wJU$GX- z+d4}9#1jN3C2&y+2EfAQ1%t{@Hq0U!lQPyI#4or%-i*aBS`aO=OT{&??yKI65VIG} zW3V}f;`P4#kj&OaITLI*a6Bj)wNgh$sOjJ&HRT=;;bSCJBH9>RU;?mBp-~Fk!9key z@;!%`$&LHZObOfPKk5YAV(!X?Rr7WRaW%rRNV^wY;hYL|Km+$OVf z+L>)|dN>L{|BdJ*e0GE2Bo#6(GzigBH{SnEhl9b&fem6F{Nk8P_#7Y*;4cR~GQP+- zB`=k-#n7A!NthQH)eAf=$4zAXr*93YKeULW%I}>jcp+ zdkM9>ugET8GA?6ePcVs~eiM_&1aX#(%95H>Gy2#o@huBKT^jgIm$p=ZD_m(gQU&w7 zr{}8xtta&ka6FzlvazPDV&?Fdrm8jYHU=I2tz$QiE!Y;DQuci@Ehxz6o$=EvWzL0@ zqF%J;-Q%bhZ3d%2am9R7{ODrUlJB$TglQYY%sjH`|nFNP*qK2RY3GFN1Fga2?>z ziAommrvwvyFu)+C6SOiac5HOq?dMDZ_*Qx>6+rkdRv56{?^C(c20bZM+l5bk{aB8?{D` z)Pp6A1X6lH&4EM2i7Fg8_J82g$E6HG1@+Whkponm_`Y2`q?{PZzx^IF-~8r#_{U16 z2)KG5f3xSzLF@mAS!=*&|B_T{6i0Et8Z%e zy>3V*surITkuf{V8coGw-4-ra|X39D;A(^MvS!7m~QyO9l%!}zyJnpU<}M@jy4ME zv??$t?67=4k~5ZdmGiLE>EbIfGm3NHlEhEjYP~0 z5-OpzM~do&O`%JxZ4pO2Gw9NRu-eW%+QHM^l~1(Cd6sX^71D+JWL^FHbmMYYUdp01 z*23ncFkS0*PK7PeNgGNw8wsVdEmR}V=~AyR4`G1MX8)7pr3r0%wmOgS_$vHXs@@K-@OJdwxgTD> zJ9BpA>{jQtYIm#kTW)=f!N6zBzO7xn1mDiNE^?l9qxG2NGUuPJhjH%7*y+mTxlIwr zm|-Soi58XNH93bUX__*iEUOB_<01TIxc0*DfFO@d{ZBOFq>eByqx;`B>@)1?;u~0; zJSodJOgT>ho+UVna0Bv)+1ouloANpQ=z@@vC#t23l0r7AW0(DkcG7;Y z-DydhBT0V?c?XyF$6BbYeI<2+o+ESI4I?4d(!n8@F1alXCh~%0naPMfNmjdO@@l^# zPvH?NgQ7uvWHQF?!{TjNybt~_aC*=2*urMC2l&~mG3IT(dh_X<&L4o{o{O+uy>s@% JKS|9}@ee0UNC^M{ diff --git a/src/models/__pycache__/Base.cpython-312.pyc b/src/models/__pycache__/Base.cpython-312.pyc deleted file mode 100644 index 0ad2c7349939c7e2b124bb94e94bd8b3f3172b23..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 631 zcmY*Vy>8S%5T5<9F(DTa0;Ne|iONOJ9Th?%Dbhe5zy;REyLT6z?K8W33YUr^4?uSf z1udfCC8&6Vl@Lf&bSGVCsF*qD5UjK_-|WtO-^_d&40Zrd_wwE3Gx{GCY)^Cp){hY^ zL4#&BG^}6^FF2!^*P@YyWWeDBw73E-4f`nix*}RcS^SfNj0f@T=s4`DGv}?Ud}I7% zrZX1g^K<8`zM!}ACRiUJSb~9qX(+g6Mrdy2NTQBJZka{l_K|Yt#Y~$zB#m)Sjg1Sj zHGXc}ivP=s_zvDRu`gl4GJbh*Vp?OBH>o$TeLAwUCRJ&NuJgmw50HrfoEjWQM(pzb-=2rt zw-`AcZTk&;Pj=S`;53UvQkJc1Oj(AVvP8alO?bB~U(HqB|5&OD0+(vw` z%&o1VH1`nA6|6Xq2JauN0NtnKYjS-BC6laBXO*QfEiEg=e$kz`T0_?!Rzc7v+IDZ zAfYO%BDIyuDHjeLTZIFEKo7n5VmFbpT3QvU>LItJp_iWeX4gw7>O_9?_Pu#;-q&yP zb8l}F!LmPpRQp;)=nwi4EuqWo-2`S0sYqpXRHZwsGX`5_nNY@6IgoKWZwOVv5UZjQ ztHuHyug8raSF*16DcHQ8hev2FS>S73bpL(OpuoG%dUlBVD9xRx!+ z3oa&-?7G;nT|akM)?A6mrfL}yzGqt|HeF5D!8-At?BI_q73)gqJ*1FQu=eV(ts#u6 zjEbsk0?|J;rLrqbm17a+;U}n^%0Cg-+0Xr zAoy{EHd4}5zc1vd)e`6+B44y@Y|5H1UW0vR))y6BgHXVXcd))d(vUsLP|^d$k9FNT zv|M?*{wbI5T4d3&Wd)aOva%>cYGzTs4PMF)QOd!cPC0B~$%bGhNi$*cq}ISIT`fgY ziXi72kOycdjxxD+dbAy%ZjVoI08Nj2gZWkdn_lloWtHEI55tk+qu$}tZj8?+S7R@F zP;O|=cxXHyIOOHt^@fkWNC>&)D*tV24_u;c+WPGaK%fnrp_xGBd)qV{5bzWckT_rk zs)ij0vT%G6Ak-kH^6Ff~z#y2qW8e^`4pBtBzUtlCe^p+8IXC;#I`#J20*~?6MDPi- z%&N??@Pu9O&g*?~StR3bDmuF?s$7IEGU)d1%j<7M{Qe%{QRNp;K#NmuU!+roNCyZK z|Ib$C55nPke`K-0*9Z3bh?}P<`2faB2U2OcU1=ersZ(B@z@zd$S_@78dB=K5@@kRn-n1FzI%*G~a# zY#mbwLOT_>2HrtB*)Y$t1Mea6IUSqiI29#t%6bF8Pl!eQSa%2@kPxDU7@^4p;gfNC zJYz#h*i^7Yfuy@lzi)xyI}H#uBM~`9PsXFz9op{zJo!915U2nxv<~>ckcgxsaK`xx z2teWuG%&b!{^9wp>{ur|wv{b+vgNJpR3|&-^&eO(J}i1~4taxPUUtyS9ob2tfirtZ z97wLlcQYt6_~`EAdruE;j-3C|`RVM|rTNaK`HhP=+MnL~<>IYh`wzE=W;gn;w0o{R zPv^GMBc1ffW2TcH-AYe%(i2q425s;YNa@9Lhl-i+e%hWR!$dB zP8Cm{nk(=@*Y=6F3N0xlDgI5_z>?%AB?+dXp;Nv`lI}EQJ=BmK+(0N)2t^&C!%isP zd=73RlfeCef-r*mkp$IKLVwO@KvvN6x8C-~%HF}Dy;xRA?oJ>v)yW+He&%UqGc(nh zzSh3}@y7Hgo5K8#h=xbJ(J^mq+&i4#P4x+r-l35l4)&cSI#&3LO7=J;-e7iFV8eUq zO@)-tS};*=)7H$}L4brJ$dAKwE-RYb408y=r-}Hf@HH&a|F49OK^Xd3%F(KJ$KgU) wmUqZB$mo;hi~-r<7>0R<62GC`Gt~P#y4*pR|3pV#h$ooI?+Sk-*a9d20#Df%z5oCK diff --git a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc deleted file mode 100644 index 29e23cc138d0482512bb0b41a3380383a96f3f73..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9969 zcmeHN-Et#G6}FQQD22ck#cfwol~_VMyY_}9I3%F;&#r1M+vT5y%~p`wGcBp((e$Ld z$F`E93Z8-+UVvxe4RQ%DkgD8q#dl8k%*eYGS6q6l5|3tjx=)|}&fjm4UCY)!U(41%Uthk_eN%q&leM+Z3sVVcm5_x{H>I*( z954Oc6r!|dX60g`tZ=1DwaoR`x|j)Fh{NZ@-Qhr-S?w|KQrSY~t{;X?bt?0GDZI2Z z<>`f9Z!p>wPma%>s2tt!b0toWb{0YvSK69lp$gyYB_xW)d+t(bC%h3eRVXVxWUFOL zO{Rn^bXh7dQd4-T3z|{hmPK4hr!oQAXsM~mb8}6@too`_&Tn@*z1|tiCysOJr^YU} zdp!}w6dNR5mj<207ka(UH@a}%R;kyf_!B+bIT@Y}cLoPy|NL<0Ys6_kiNq6?anv# z-En)DB{`~mS$RRiVkVl}(fXT-h`Qe~ixR6{=v@0H?es~lSypHoALmM3Yd^2`;DCvY z4=TYtyGP>i=nN;Q#LoHd;2vWFhb^Wm2q=tgf?aJm_7^QmQWeA65y2@YgO8SmG^OxoAfg2Jmdu;IG-^% zX@N879&ag;n;FDjRL+ZwmE!On&Ft0+KJapCp}+c-`Tm zG?ouQ3=V!jcyS7YN%IQ0EmW$UlXgiEpfN&vM#P1h>4I>)-4XaVnM_JHLH-rS!W?Wm zzjcF`e*I38%VJi^8LS4F%6Y*1aYeC^INtZgQ2_(hQ{g;LSlTSyueS-gpe0}(g9@cG zvZTs8jV%yj;+h%t?6PJ+}&g7L>JYUoS%>9?Ful*Y8r!~~YMr~SM zf-$M2hyQlfR2MLi!OqhLX92HlGY2nNy2dnwo11H#<%PLYNjphEaXN>Rlqe+}96)GA zfxocN)KoB@K0(>oOxgCuy~4o?1zcPEmLp$SYfTU?9ceP1lO(#e2H*ywYchA&&IlK@ zER}ewgSm=xSl03KjLunjRXE7tX#fxjsMPj!vmKz2obixJ7w!!M!+5x~ip~EOgC#30 z4dzjH0!PbBHD@2)J+0L)gRv7>?=&~kV+|}pbHL*|k0@a$yec2}tsPC;5oB_-rB~kT z^^Oh?Uf>yt-N*}t!4`UClRYxZ%Yx62eEg!+@Jn7V;1D2VG=DaAIEt@H=3HN5oyP}9 zJI}_a!_W6f^6+aEWS{18Mh+Hh#FKA<{Go#xkYA_cjZ+Oz$9vEB4#&^nvBW7mY*9e7 zmR1*6GDj2;U z4TZ4_&gf}462zU`rq^HrC&CA``G|I0|9ZF;f8h*qjU7bn(sZoQCKM!M;2ib^98&-Z zC<2D{Y|g%M!d|CYB*1_G8g5~BQ589yS{?eyRge$3#zqoxESGR`n3V$^0(xoq2Xx4? zAj4P~dUhYMxl-0qJS30P;4tJ1lf;{ic6;4`KAEP~qC)s4v+5?DU#+_RBV;n+$*Kz> zJjS@tc*)v+9aO_11zu`{lnLl(QIigOQ1Lo5Y#A5;=K&`YZ^vwDtm}8Kp@#aQ^k$*c zaVi}~E1xD1w{~m2-f^%@Cq57(BtJDC)QzWL4slZW<7t%(Oj$t6IuAiY6OFn~tju)k z6vY+9ZGalEan0#3Xlim;&ED?QRoEIsw%tEj9sG6169tbJR}vWv5r8C5CI(kVXJYc4 z>|xJZW4Cz&jaftR0`Mn;4)#i?o+L>ynAO+L;ay9LiCDAr!ST_+RW*^cqZ%0nN{j>8&Q>)eO0uG z29j=2yJ0bbm$e=q0S7WsMK(bqRVN;_pVnPc{@X%ggg8wz5nx+@PGK*6d=$nf1P6*p zS5oKl0-LNu1m>jy2^@6zHq8a};9)hC#W&Qp?0plrEaIJi0_ z3s{vx0Wzpf_&I)+@(7sxQe>w1(ANNC36svpt2X&OSU77axSeQ=wdp@hHj@Xm0F0Vq zFOuy2;UO~Y=X)o^edO9rW~~rC!T}?PgsgpmIrMb7O-b+I;9x~|j=g{tkWWsZ6Getc z$h4b$A%=VC?7|5%ZBq%z zMVSW!g>w;3V-HQUd_Xzn7?6oV1Y!*gi^PpPkGCA{UHy1hKL)>ZS3lm>k9YOs zUHw=Wizrg>>c?2Zglhe+e!Qz6dqi&bs1&@bAO8>aKF1xT%jl|a$>nx?o2D{WyKp^uohwNg<2T2 z62O8H|AtBp3l_0}Y^X}e7n(jS7X)j!%}lmIZAb;4VY&>JHx?rj4&)I^@R+p(g~eXVk0_eRur0I9a@Qo$a% z+J3w^tS_DsBZ9^aEf=U>B8{KyjYbqnLbD-omb5TL=AKy~qO>ulI=_PSGcacNa&@Hw$B$Aw*KWbQ3@9fC4AHabgD~?)1^fyZ9GW+f zm$$PmUOwFF3y@fSvAOkeyxMy3QMmjkfv_cByqql_) ztY0S7Kx?u0X%QycAq5yTAWB_AY(OYeC7L~?v3PFkMx_<8rEy0m6{uSt0Md{Bwy{afPGMNzft$ijaAZ^T$3llL97Uv^_*i~{Cgj0+WH;+M{B@Wl)CB>~@umG- zI`kWBjySy76GtWZell8;p2(mg1=7KxHCqrkOPkA*X_ChJSo9tUil_@bxBvujPHdcF zV`fcNGV}82BhWln$;@@B)tpPTq2}Zy0${?1dtx{h8l)DrqcO-_6E{?V^hRzERiY!# zB(u2pS*>JS!3Z5j{?{4iVAjJF7S;vcje6@k&&6D#r5(-J&B6DgCl?GHYtvrY3H*;` zsEyCeVmdd5up`{HLX+Hqb0>f2?UIa@SR+`cBFkt%yh$3^d{j$JQ;p6~&$!=<_W~cr z4YLciG&qLMNByW4Y6(9d91M5iVD@e#^=nzZD>i?sd{uUyf~!QU1wFZ%O$w-lvOs0! z8IimCj(5oZ8Bov;l0Yv!i+z?X%)L zJzdj5W1D=GgzVZE;ds#OVKh=p7}p@#S5Jcv7Vw}J6;IXflqZYaCU%MMF%rOks%N!c zBLQE6)&;1*Ue_CJ1P-r704JS&NCpm*juO2_Y^Vv2t@_|@X$er`3h~QD$4$Mr@u>FR z{HyWh{Oeg?JUKZ!KjtnSjchD8wD?=M*3j&&xQTd5G&PHKOXT*1qC7N&OrW9RYtEkg zAGD=wOZL`wXcN)ekHIG|g}I$7fT1n zCI)_+`^n6;!(~1}Xkia_9KNa%!(A~#aAA&^d32Z!+R%U*AtUM^*leZ-LW7fv{4Zd# zd9Uflpry94x!dIs7Gm0KZ9_I$N)r|(h-NkbiEc|H4?+R;7o^y_J?&PUmK#@{4%lX1 zWdrJU+vo?S{w30dz8DnCFS_IjVeYwl2`}?Su-W&6iKXziWT}vQmuDawp|FmAKSITB zUv$}jTlfoGh>eKmG}BibfC?@5^1U_z8;{<@ay>F_0zD4{+Fl!Wh+#B!hjd91o@^bV zG^N&Iw1_&MMbC)u07zEymyNJ}a$QIb(3@({Bs5CdMbg(w=%6ChIYkL0zHC`hZ69Gd z8$-G9DRfmpP=Jwx%Qbg#1R7W*-F7F8{8XSOh%=z6xbdJX@ZrW2>h6&MfDhq>E~nb! zOO^%+{C*)`eSeeRKeLq5;p|D+IEMu4ihHpQgn)r>Tp(pT-(mNSZ{LO_fX%n>gui2J z(#7A`zNc@?{{7qg*T!DD5@v9JF4If2o-iNsVi=LL>HYAEi+{%udmNf}sWO;v;b3$c4|qN8lUyio}V7rM+?D_o`}o#>pzMD1z0}ImA6Z_24EG;ZVztlD3Mli0JI4Omxm{e?RWF^&z9pJB4OyfSUw6Scc4HLtmOr%g1 zd-U`jYug4t3e64Ed}&~@e3kk;1NYp83L#!=Mm z?io3fkc+k+~esY5j=p07aI6m8latiW|Gi4l=2ZAU^ab8tsynj2Py^jdbc!niVtzrpUSCP8rf-)Ng8O zARr%B!nkwJCdSb!BP-g>=}75{*?Kf`de6AR1*^21qj({h9=hvct=zUWdn6PY*xPNk zFa+ggiCy6l92S7t2Dwg-*jQGXTcqM6_G&iaRgZ~y6i3%zy?Rw}bCl|_x--$WnLDPv zi!ZN75&z?bZ@u^J*Gh5toDTxnv9=TuFa{Y%dP6w9CI~+`x2=;zSUarqd;%X3Ne!=z zqx;$zIElhn?_3HoP$CS1h26r~H1V^0OvWNUw`;258we;w1%E6_Z_Iwj3?lq zfn=P!65>m+0tg3CS#XtzP7Y*J$%Ig-@FX&2Ab<1V;m+REXYft-=-v&x(sMvMlE>gQ z1XKYan#}QoE-P8c&pfmOrm^^l*oS-ncTylsZ5%y)ld3>W#72nU1~55OV(g*Of#N`b zOU%6ZF>@440gh!srEhO95rv5SjB(3$1@Is8Ivi-IHktI;r`)7tVrozhAT<#TT$zWH z^w{=3&JFgp()QSJIR zf!FwEUFlWab!c|vmrAr6%25ps63;oU3w11wRt`;$xsfy+n*7%+%oxLV47LB0X# zpqea(1r;H z!H8t@JQumNMOf~Ea~RZVRHz!NpeUt2<~bFXxjYgZv@OZ&juffA^fQ7sIP>{Ac#u`-xAT!dpn?Yw=r<=&2 zFd9a2rU;(wu7ASz#m@7UX1H6U8v8|%4IokR?i?mP%(=}NmD-B(1wXON_Iy5s!H|BM z>5(>?BQV&~Lwmz7VZ=yU zVryGUL}ErYbqHUZK~L$_6|gs*)q%h}wxL?TAXiEbTFvHqy)QwIpG~^eXJKI|TmPwNblxkWEZ?C(p_>YnYe>CaHH`3sx_JY|-Y>h*GZI)c z0>q&>l_~|!p#XF;HpwMEA#+b$pTk8_g%XvH9F{}vO`x?3W5v{k0X)Dtz5bfD$=}eX z@OPO#~p2UQny&ZXPgp^xft$`6nv^zk{p|s9N9mXry%soYXtVw zf&<6VLr5mPDL34KhEsB!sSj{VTcH0Sz+dyc*l%E*ShC+>x%>1OsrUZg^QVqMX=KBj zzCNRmo>ECrsi<$!h#Rv{QSte9)H#u777Fr^)D!~EL8pt+n(j!xhh+dpq2J|tWj=m% zL9TTp6*;nLc+bR_VdzbwjfX}IDi?VDgAKw4@{b{hPyu4ea4y7LzY#fjCeRrJWIh75 zI8Q|VNS#r(y~TDxF1!Njh!a2o!Um~3j76Ri^bg-`mWGFVkb)ImRmR!cYKr~9W12dd zowcnMM`51RsFw<+B9HXKxsata0bNaJd*^Rxc?Nkv1x6R}zV-9Vu7_ors1lq#6+^d6 z2kNyO4DnaJ0OzB;RFju0goFt9T)Y6vy!0{qkq@y1`%$Dq?k%?h#R$)Z#kZ1vcR ziz@@XFp718Hz%n)?!Z-OyBF7I30QykI=1Tqpq$gYVeq!MhCRVBS~?wf$y^ha~IUD_7F+sKQ2LaR#5T)*O^mo544>2p`mL|u+MaCTIw1enLy!6XYTJ|j z<~%%UOWn5AO^!e-{$Xg7Ubdw!YfD{9_S#aHwWV%b>Z-QX4XvWK)NM=Mw$z;~b$@s5 zeV4jl@2v*Bx`sy)_OY{nccQJu+giM>#iy@rw6%C! zi?_A-tQLQ=a`E&c9)RO*h11CY?IV>qy%erT{Pc3TG&(6xug%8v-HofwK==7i3t!T& z;QsNm8)I#bY>BrvZe%<;;#G#n-*}A^?|fL3-0-t(_}j-*cKn=c`g5$foSc65PN7r0 zI(OGx#(MPAI(rNMon2U5Ts&L)LA3n(;@Qm)uP(oi&hGs1%JS<=XB&5~Ex*2e_VcB! Y#qTe_zIL{=bbs-C%fG#9E@9q30ffpZga7~l diff --git a/src/retrievers/__pycache__/retriever.cpython-312.pyc b/src/retrievers/__pycache__/retriever.cpython-312.pyc deleted file mode 100644 index dc463e336802936d3187aecc27bdae9c5fdf2a5e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3354 zcmcImO>7&-6`t8YF8@SPvSdn@#g%L&VVqF09l3U(s#Yt-ZPltk8x$~&lv(et(V&OES^mgU zl|v65g0pYly!YnKoA2>pW%noVdB+usG^Hqw!fGbmC2;sgg{ zF2iSqgpd^zVpd8>n9cDSPgYLI7;)$v(uD=2ix0TErG&SLW76krFdgG5Kb|X6VZziL zD*V8-ofxLxZ~ySb$dA%$-qh+U_1^Un$c85dN?;v<2VNa7-~_L8hA_kF{DP1Wy-4@z zf+6UlA?cFgF=Rt9yfa)6sI&i!q`gs&*a{4~gx`?28=)3{;~t>LD-CZl=yZ@&fI zZB)WKzJhG-`{)B$x`dLh5bJafMz}dFVHmaS6j6+yj&am)=4^))G{>}ZRM0HlpfE#) zi`f$+v_rFUj*)YcW-ecFl6lp+M1?HKfTCaozm0C|{i#zG{)QomVwu_>!un zLE6f_uFisiVVh{fq`4b=hm7ZoR8E3@(@7@hx_50l-sF4OA}>RC4OK&(m4os1P`o0? zE1`HbdZa8YdaEzT%ED)Iw0a;?7C!WDgD4_x(Cvzw1PnKBOW-hz+l#Ss*o@m)Jaz74 z{y~E?g626#Y|f^bl|8M*J!nLEo`*DpUo)*Rnq?=_OxOCTW$Wks1bTN2p?RT%Qt}?7 zFiZvin^zKafh9|GMGW>+%_Mw5{A+Z@-MWis1*XqJwWsM)`V|P`__=!U3q zvk>=pJ7Hg9Nlb1$lX?lN)TkI0hyfiO5AYX4KsW}UgYwfJ~b$$vhKlh2C>*qfF z|EHgB4Zh&_5p*2c_$r6dd4%D26EFJT=g*_7SOcEM&5Ku|Z&ERNnqo317Mz*aPXp}8 zMh=S7%*_%%gog%o!`6tIXXTnaPS`ROD>1*D?NqSx22^p%X0@ELRGmuNr9v)k!`CAD z0({jxi0E$JQqj(5Ooy)Kt)4K`cF4neU%B@e#o*6e05TiX1^K>vE6f9di{_j2*GKx{8J zx*k3LIB=pSqK+Og13G7}pIHh%@(-;Z{p8s1j#W>O{?Y#j|3-A;_x^{2E8IQ#H}Y!d zZ-Z<8N}#V2o!E>FR0dCML{63`KKx!a7`}1-`uU}o9tD+^ckW&I=)(V`z7^?XR-PY$ zEv*Mf|NX2RDU!jc^-`h$pyYqs>65Va7K$F_}mV^@LZ{aJWb<_(; zGYcqE5&=(70q%G}%EOQjJlF@@T66pCwmzqecldd^B<+V@CAqmnj22G=eoCnpcJ277 zOYpq#d@S$W1SRPwDCV6eEc9~&?qMp~g6$|54MkP*mTfxb6+|Im>}LkWn+N0^`be)J#TsuTfa&_tl=cx&DoU8@7hNp}U4^0Tc+|9RKB8n;m_X!)G@-#w&sGsz1a~s^afknYcIg z(G<*8+uz>ixz4s4;(~28A3CJmJ+^#oMce3owH(?Ccirx})$=4gv>qN>u{Oe|%F;h0 zy>~;)p;hT&*F$L|@{P5wHEGRLiJUD5xBB`ka?jIne|dcCmH6uA``+qnZ+tR!e`*ci zpRV>Pcc+%8R`K%m7arkITX_mr3-&(=99<6_UAgc$aJ&{o9no3`^7cNF``6|EzqSvr z9jNih8~qY-ULSyL`#?oj+~b^%VfTWH;m=vJz@rvkdO0;~B$L#aOhRTCGK>x+;Q-W3 zeTH;^DCOZzN`_nt>VYtPBdk~nLpe9L*SXyeQ+c!=fH5)&Dl9-Y i%hno?F@B2tPf^F;(cn57{2Pis7sl|wGX$OM&3^!$Guw6m diff --git a/src/utils/__pycache__/utils.cpython-312.pyc b/src/utils/__pycache__/utils.cpython-312.pyc deleted file mode 100644 index 5240a44343db32ccd23713863f8a830c080e8631..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2442 zcmcImO-vg{6rT0|*~a$9*iaCpazfhJ5Fkkr5(z;i0;N?FR}GO0#EIE>7qBt5GrJC8 z){(3BV3bOLa==naW;dcm?H+zq#!xFP7F%B6j?wjq$Ds#Nt!`*wcb zytgyo``+wtcDogUB<`$E4zmcoB%P|5DviA^XsjR!NmLY#mfzGUW%Ny>Cdu?Q8l@!~ zG$S#fS&0SB%~3vXo4D|YsEQ}@GQ&f|v-dNo6*Pg;ltiUX>b^~+8W-lPqM0=Gv@)Ki zHF^Nu?nNlgXiNmY`lv{zrxkB$_8k`Vz6GaI5&ozLgd|2|zeN(OaiF=!Y}KAZ-}Es^ zqlS==_orGUc|4%Bd2fmBss8rSv+X^>+ikvp(i-r3Z{6|TZuJFx+vHoe$;YP{F!$(` zSEqbW7~O$SNnusSIy-^m$%M{{iG&=Je7w$t;*zY}$CI&;8ji<;uvXFegs6gb zOy{RXH8d$JI%_m^3$Y8v#A#XQ5+W9-6@~1|Q~{Nm+Zf@N^LH*7|gm;(N8h&u(gcDH_Z$$rztWXO%=Wtm-_L6H!3V zM`1fmC}f)+>;|n{Ls418L6R=G83vqW{tHlmjoG!N7R*g0!L^uLNafm!!iRZW5L!2d z^9AAjdhai_MPXo5xKZs zVA0u-7>z1eAmhh~B4FTJP&y;xi9@6;qgVe1)yghBu%-bXWegM`X>()-OOB3z#WKOY zIKMETJNNJbKxVct+7@g%%hG)QM1COOz2@AgYkO`!v(GSR=GvQ(l06!bjzOlapo8)Z zg}6d7b-z$25rzUi^mZ$9WQHUuT%!`dh8*}wBMT7ZtworEQeG|lNVQbgMvP0V+AF^$ z`rwL)5Q-cTpl{{7)31AS95I8g#1eWxtrQ8cCKwtLAwV9H*Y-h?YoP_Hgq4PfpdYt z>A-p2I(XwoaA?GTvwzrU*BNy#A?u734yih=s2Ddv7)*{z=M<<5p-EgrI<%6!tJ9*Q z>Ncn+G8Usj`Mwy{Ylh}xsyI7{aUAOwKU^T?rlVg~arkahE%%`knW6Y6;xPqo7-%J< zszRXg9IiqUB_6})VEhz3%6FhLXp67Oe)83WRd<0uUJ~BR3R@O?cJkr-t4hIga?4e} zH2dgM*1Til*yfU}Y17qGaJA$I)>?|Lj;wjhRkzZ!+>;y5&lX*0H(lKYSNHn8qN}Go z)xF%E>s-6A4wE0lh||64JW+6-$a{;5@G@K?ZTkL^k(M x$E9Rc?!)ANiR7Ac5)_1?s29lc0@b`g)?LO*HSRXB)UjQMm2&Pi*r-~=$KScP4%7eu From 6a4f019c6ec9e37ac7fe9cc47a4b2125bda7d040 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sun, 24 Aug 2025 08:08:24 +0000 Subject: [PATCH 4/6] Add platform information and optimization hint for code gen --- src/prompts/prompt_for_generation.py | 17 +++++++--- src/prompts/prompt_for_reflection.py | 51 ++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/prompts/prompt_for_generation.py b/src/prompts/prompt_for_generation.py index 6d07461..5c5f637 100644 --- a/src/prompts/prompt_for_generation.py +++ b/src/prompts/prompt_for_generation.py @@ -4,6 +4,15 @@ Your task is to generate a Python code snippet containing a Triton kernel based on the following request: **Target Platform:** AMD GPU (ROCm) +Here are basic AMD GPU hardware information: +1. Graphic architecture: gfx942 (CDNA 3) +2. Computer Unit Number: 304 +3. SIMD number per CU: 4 +4. Wave Front size: 64 +5. Avaliable GPU Device: 1 +6. Vector General Purpose Register (VGPR) size: 256 KB +7. Scalar General Purpose Register (SGPR) size : 64 KB +8. Local Data Share (LDS) size: 64 KB **Request:** {instruction} @@ -89,15 +98,15 @@ * **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`. 8. **Triton Version:** Assume Triton version 3.2.0 or later. 9. Maximize performance by exploring the following: -i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. -ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. +i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. +ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. Primary Autotuning Fields (Mandatory) 1. BLOCK_M, BLOCK_N, BLOCK_K * Tile sizes for GEMM or other tensor contractions. * Larger blocks improve compute density, but reduce grid-level parallelism. * Explore wide range of values like: - * BLOCK: [32, ..., 128, ..., 2048, ...] + * BLOCK: [32, ..., 128, ..., 2048, ...] * Adjust based on memory reuse and L2 cache locality. 2. num_stages=n * Controls pipeline depth for kernel execution. @@ -111,7 +120,7 @@ * If it is too low then underutilization -> kernel runs slow. * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. * You must choose a sweet spot by trying out integer range of 1 to 16. - * You MUST NOT try the range beyond 16, it is NOT VALID. + * You MUST NOT try the range beyond 16, it is NOT VALID. Examples of Autotuning Setup Here's how Triton kernels should be decorated to allow autotuning: * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. diff --git a/src/prompts/prompt_for_reflection.py b/src/prompts/prompt_for_reflection.py index fe3f936..d81f815 100644 --- a/src/prompts/prompt_for_reflection.py +++ b/src/prompts/prompt_for_reflection.py @@ -20,6 +20,57 @@ - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" +**Target Platform:** AMD GPU (ROCm) +Here are basic AMD GPU hardware information: +1. Graphic architecture: gfx942 (CDNA 3) +2. Computer Unit Number: 304 +3. SIMD number per CU: 4 +4. Wave Front size: 64 +5. Avaliable GPU Device: 1 +6. Vector General Purpose Register (VGPR) size: 256 KB +7. Scalar General Purpose Register (SGPR) size : 64 KB +8. Local Data Share (LDS) size: 64 KB + +Maximize performance by exploring the following: +i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. +ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. +iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. +Primary Autotuning Fields (Mandatory) +1. BLOCK_M, BLOCK_N, BLOCK_K + * Tile sizes for GEMM or other tensor contractions. + * Larger blocks improve compute density, but reduce grid-level parallelism. + * Explore wide range of values like: + * BLOCK: [32, ..., 128, ..., 2048, ...] + * Adjust based on memory reuse and L2 cache locality. +2. num_stages=n + * Controls pipeline depth for kernel execution. + * Rules for setting this: + * 1 if no GEMM. + * 2 if a single GEMM (e.g., GEMM + ReLU). + * 1 if two GEMMs are fused (e.g., Flash Attention). + * Optimize for latency and execution overlap. +3. num_warps + * Controls number of warps (groups of 64 threads) to launch per block. + * If it is too low then underutilization -> kernel runs slow. + * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. + * You must choose a sweet spot by trying out integer range of 1 to 16. + * You MUST NOT try the range beyond 16, it is NOT VALID. +Examples of Autotuning Setup +Here's how Triton kernels should be decorated to allow autotuning: + * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. + * BLOCK_M refers to the chunk of variable M that will be used for compute by a thread at a time. + * You must ensure that variables used in the triton.Config should not be passed as arguments to the triton kernel. +For example: the following autotune config receives BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, and num_stages as input arguments. Hence the triton kernel must not receive these arguments as inputs in the wrapper function. You must comment/delete any such instances. + +NOTE: If you face kernel timeout issues, check if Grid and Program ID Mismatch exists or not for example The kernel is launched with a 1-dimensional (1D) grid, but inside the kernel, it attempts to read program IDs from a 2-dimensional (2D) grid etc. + +def grid(args: dict[str, Any]) -> tuple[int]: + # This creates a 1D grid of size (C * D, ) + return (triton.cdiv(M, args["BLOCK_SIZE_M"]) * triton.cdiv(N, args["BLOCK_SIZE_N"]), ) + +The grid is calculated as a single integer, creating a 1D grid, however the kernel might try to get two separate program IDs, pid_m and pid_n, as if it were a 2D grid: +pid_m = tl.program_id(0) # Gets the ID for the first dimension +pid_n = tl.program_id(1) # Tries to get ID for a non-existent second dimension """ prompt_exe = """ From c15c24d5c747e9839a78a5242ec1828959073911 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sun, 24 Aug 2025 08:08:59 +0000 Subject: [PATCH 5/6] Try to use memory, fix format --- src/agents/reflexion_oneshot.py | 36 +++++++++++++++++++++------------ src/main_reflexion_oneshot.py | 2 ++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/agents/reflexion_oneshot.py b/src/agents/reflexion_oneshot.py index 6618305..02beb52 100644 --- a/src/agents/reflexion_oneshot.py +++ b/src/agents/reflexion_oneshot.py @@ -77,7 +77,7 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps", oneshot=input_mem["oneshot"], pass_call=input_mem["pass_call"], pass_exe=input_mem["pass_exe"], - code=os_mem["code"] + code=input_mem["code"] ) self.memories.append(tmp_mem) @@ -111,19 +111,18 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, To check for correctness against expected outputs, use the test_opt_correctness method from TritonBench: """ if output_path is not None: - root, extension = os.path.splitext(output_path) - tmp_dir = f"{root}_tmp_{iter}" - exe_dir = f"{root}_pass_exe_{iter}" - perf_result_dir = f"{root}_perf_results_{iter}" - perf_log_dir = f"{root}_perf_logs_{iter}" + root, extension = os.path.splitext(output_path) + tmp_dir = f"{root}_tmp_{iter}" + exe_dir = f"{root}_pass_exe_{iter}" + perf_result_dir = f"{root}_perf_results_{iter}" + perf_log_dir = f"{root}_perf_logs_{iter}" else: - tmp_dir = f"tmp_{iter}" - exe_dir = f"pass_exe_{iter}" - perf_result_dir = f"perf_results_{iter}" - perf_log_dir = f"perf_logs_{iter}" + tmp_dir = f"tmp_{iter}" + exe_dir = f"pass_exe_{iter}" + perf_result_dir = f"perf_results_{iter}" + perf_log_dir = f"perf_logs_{iter}" - #for fn, mems in tqdm(current_memories.items()): for mem in tqdm(self.memories[:data_len]): if mem.pass_exe: continue @@ -149,8 +148,6 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, """ To measure kernel latency, follow these steps: """ - #self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=script_dir) - #self.dataset.run_perf_scripts(gpu_id=gpu_id, script_dir=script_dir, log_dir=perf_log_dir) self.dataset.write_perf_file(input_folder_path=exe_dir, results_path=perf_result_dir, tmp_dir=tmp_dir) self.dataset.run_perf_scripts(script_dir=tmp_dir, log_dir=perf_log_dir) @@ -237,3 +234,16 @@ def generate_reflexion(self, mem, temperature): } ] mem.reflection = self.model.generate(reflect_msg, temperature=temperature) + + def save_memory(self, save_path="memory", datalen=None): + data_len = datalen if datalen else len(self.dataset) + for i, mem in enumerate(self.memories[:data_len]): + mem_dict = {} + mem_dict["err_msg"] = mem.err_msg + mem_dict["reflection"] = mem.reflection + mem_dict["function_signatures"] = mem.function_signatures + mem_dict["oneshot"] = mem.oneshot + mem_dict["pass_call"] = mem.pass_call + mem_dict["pass_exe"] = mem.pass_exe + mem_dict["code"] = mem.code + json.dump(f"{save_path}_file_{i}".json) diff --git a/src/main_reflexion_oneshot.py b/src/main_reflexion_oneshot.py index b5d4f9f..ac7fc1f 100644 --- a/src/main_reflexion_oneshot.py +++ b/src/main_reflexion_oneshot.py @@ -30,6 +30,8 @@ def main(): # run the agent agent.run(output_path=args.output_path, multi_thread=args.multi_thread, iteration_num=args.max_iteration, temperature=args.temperature, datalen=None) + # save memory for further fine-tune + # agent.save_memory(save_path="memory") if __name__ == "__main__": main() From 62c97b39390e1a6578be25e4b72d3f470a25a7aa Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Sun, 24 Aug 2025 08:20:21 +0000 Subject: [PATCH 6/6] revert change for reflexion_oneshot_ROCm.py --- src/agents/reflexion_oneshot_ROCm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/agents/reflexion_oneshot_ROCm.py b/src/agents/reflexion_oneshot_ROCm.py index 35b7cad..5e5b369 100644 --- a/src/agents/reflexion_oneshot_ROCm.py +++ b/src/agents/reflexion_oneshot_ROCm.py @@ -90,7 +90,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, logger.info(f"\ngenerate solution") with tqdm(total=data_len) as pbar: if multi_thread: - + with ThreadPoolExecutor(max_workers=thread_num) as executor: futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in self.memories[:data_len]} for future in as_completed(futures): @@ -114,7 +114,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, tmp_dir = f"{root}_tmp" exe_dir = f"{root}_pass_exe" perf_result_dir = f"{root}_perf_results" - breakpoint() + for mem in tqdm(self.memories[:data_len]): if mem.pass_exe: continue @@ -133,7 +133,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, mem.pass_exe = True """ To measure kernel speedup, follow these steps: - """ + root, extension = os.path.splitext(output_path) exe_dir = os.path.join(root,exe_dir) if output_path else exe_dir perf_result_dir = os.path.join(root, perf_result_dir) if output_path else perf_result_dir @@ -157,7 +157,9 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, assert len(passed_mems) == len(perf_results_list) for mem, perf_data in zip(passed_mems, perf_results_list): speedup = perf_data.get("ms") - efficiency = perf_data.get("efficiency") + efficiency = perf_data.get("efficiency") + """ + # generate reflections logger.info(f"\ngenerate reflections")