From 3354d2f21e6079cafbef92c4a1a243c1d5a7152f Mon Sep 17 00:00:00 2001 From: adonishong Date: Sun, 24 Aug 2025 07:40:18 +0000 Subject: [PATCH] adonishong --- src/__pycache__/args_config.cpython-312.pyc | Bin 556 -> 556 bytes src/agents/__pycache__/Base.cpython-312.pyc | Bin 4986 -> 4986 bytes .../__pycache__/Reflexion.cpython-312.pyc | Bin 2883 -> 2883 bytes .../reflexion_oneshot.cpython-312.pyc | Bin 9006 -> 10324 bytes src/agents/reflexion_oneshot.py | 80 ++-- src/agents/reflexion_oneshot_ROCm.py | 40 +- src/configs/tritonbench_oneshot_config.yaml | 5 +- .../TB_eval/__pycache__/utils.cpython-312.pyc | Bin 13229 -> 13229 bytes .../__pycache__/ProblemState.cpython-312.pyc | Bin 1187 -> 1187 bytes .../__pycache__/TritonBench.cpython-312.pyc | Bin 16573 -> 16573 bytes src/good/flash_decode2_phi.py | 143 +++++++ src/good/l2_norm_bwd.py | 110 ++++++ src/good/l2_norm_triton1.py | 93 +++++ src/good/matrix_transpose.py | 74 ++++ src/good/matrix_vector_multip.py | 74 ++++ src/good/rotary_transform.py | 196 ++++++++++ src/good/sin_kernel.py | 86 +++++ src/good/triton_matmul.py | 87 +++++ .../__pycache__/Memory.cpython-312.pyc | Bin 1368 -> 1368 bytes src/models/KimiK2.py | 1 + src/models/__pycache__/Base.cpython-312.pyc | Bin 631 -> 631 bytes src/models/__pycache__/KimiK2.cpython-312.pyc | Bin 2183 -> 2183 bytes src/pass_exe/embedding_triton_kernel.py | 102 +++++ src/pass_exe/flash_decode2_phi.py | 91 +++++ src/pass_exe/l2_norm_bwd.py | 55 +++ src/pass_exe/l2_norm_triton1.py | 49 +++ src/pass_exe/matrix_transpose.py | 47 +++ src/pass_exe/matrix_vector_multip.py | 47 +++ src/pass_exe/rotary_transform.py | 171 +++++++++ src/pass_exe/sin_kernel.py | 25 ++ src/pass_exe/triton_matmul.py | 83 +++++ .../prompt_for_generation.cpython-312.pyc | Bin 9969 -> 11051 bytes .../prompt_for_reflection.cpython-312.pyc | Bin 14265 -> 19711 bytes src/prompts/prompt_for_generation.py | 25 ++ src/prompts/prompt_for_reflection.py | 147 +++++++- .../__pycache__/retriever.cpython-312.pyc | Bin 3354 -> 3354 bytes src/soso/flash_decode2_phi.py | 145 ++++++++ src/soso/l2_norm_bwd.py | 112 ++++++ src/soso/l2_norm_triton1.py | 100 +++++ src/soso/matrix_transpose.py | 74 ++++ src/soso/matrix_vector_multip.py | 74 ++++ src/soso/rotary_transform.py | 194 ++++++++++ src/soso/sin_kernel.py | 86 +++++ src/soso/triton_matmul.py | 99 +++++ src/temp/embedding_triton_kernel.py | 144 +++++++ src/temp/flash_decode2_phi.py | 141 +++++++ ....py_gen_triton_code_155036.cpython-312.pyc | Bin 0 -> 5348 bytes ....py_gen_triton_code_176773.cpython-312.pyc | Bin 0 -> 5294 bytes ....py_gen_triton_code_180807.cpython-312.pyc | Bin 0 -> 5704 bytes ...l.py_gen_triton_code_18528.cpython-312.pyc | Bin 0 -> 5177 bytes ....py_gen_triton_code_200147.cpython-312.pyc | Bin 0 -> 5287 bytes ....py_gen_triton_code_211539.cpython-312.pyc | Bin 0 -> 5789 bytes ....py_gen_triton_code_322972.cpython-312.pyc | Bin 0 -> 4566 bytes ....py_gen_triton_code_347928.cpython-312.pyc | Bin 0 -> 4995 bytes ....py_gen_triton_code_355413.cpython-312.pyc | Bin 0 -> 4516 bytes ....py_gen_triton_code_429595.cpython-312.pyc | Bin 0 -> 5651 bytes ...l.py_gen_triton_code_43398.cpython-312.pyc | Bin 0 -> 5639 bytes ....py_gen_triton_code_459432.cpython-312.pyc | Bin 0 -> 5651 bytes ....py_gen_triton_code_474863.cpython-312.pyc | Bin 0 -> 4248 bytes ....py_gen_triton_code_477598.cpython-312.pyc | Bin 0 -> 5485 bytes ....py_gen_triton_code_480728.cpython-312.pyc | Bin 0 -> 4203 bytes ....py_gen_triton_code_490985.cpython-312.pyc | Bin 0 -> 5651 bytes ....py_gen_triton_code_507685.cpython-312.pyc | Bin 0 -> 5114 bytes ....py_gen_triton_code_524778.cpython-312.pyc | Bin 0 -> 5943 bytes ....py_gen_triton_code_533885.cpython-312.pyc | Bin 0 -> 5358 bytes ....py_gen_triton_code_552958.cpython-312.pyc | Bin 0 -> 5822 bytes ....py_gen_triton_code_574109.cpython-312.pyc | Bin 0 -> 4914 bytes ...l.py_gen_triton_code_58716.cpython-312.pyc | Bin 0 -> 5860 bytes ....py_gen_triton_code_600998.cpython-312.pyc | Bin 0 -> 5424 bytes ....py_gen_triton_code_605163.cpython-312.pyc | Bin 0 -> 4520 bytes ....py_gen_triton_code_620455.cpython-312.pyc | Bin 0 -> 4637 bytes ....py_gen_triton_code_635331.cpython-312.pyc | Bin 0 -> 5395 bytes ...l.py_gen_triton_code_64602.cpython-312.pyc | Bin 0 -> 4755 bytes ...l.py_gen_triton_code_68534.cpython-312.pyc | Bin 0 -> 5310 bytes ....py_gen_triton_code_713720.cpython-312.pyc | Bin 0 -> 5831 bytes ....py_gen_triton_code_721645.cpython-312.pyc | Bin 0 -> 5394 bytes ....py_gen_triton_code_759146.cpython-312.pyc | Bin 0 -> 4755 bytes ....py_gen_triton_code_764635.cpython-312.pyc | Bin 0 -> 5651 bytes ...l.py_gen_triton_code_76684.cpython-312.pyc | Bin 0 -> 4828 bytes ....py_gen_triton_code_804525.cpython-312.pyc | Bin 0 -> 5651 bytes ....py_gen_triton_code_823958.cpython-312.pyc | Bin 0 -> 5166 bytes ....py_gen_triton_code_830218.cpython-312.pyc | Bin 0 -> 5264 bytes ....py_gen_triton_code_837397.cpython-312.pyc | Bin 0 -> 5532 bytes ...l.py_gen_triton_code_92676.cpython-312.pyc | Bin 0 -> 6221 bytes ....py_gen_triton_code_940390.cpython-312.pyc | Bin 0 -> 5049 bytes ....py_gen_triton_code_965031.cpython-312.pyc | Bin 0 -> 5371 bytes ....py_gen_triton_code_984659.cpython-312.pyc | Bin 0 -> 4473 bytes ....py_gen_triton_code_992208.cpython-312.pyc | Bin 0 -> 5515 bytes ....py_gen_triton_code_126106.cpython-312.pyc | Bin 0 -> 6447 bytes ...i.py_gen_triton_code_14965.cpython-312.pyc | Bin 0 -> 6506 bytes ....py_gen_triton_code_198114.cpython-312.pyc | Bin 0 -> 6322 bytes ...i.py_gen_triton_code_23614.cpython-312.pyc | Bin 0 -> 6201 bytes ....py_gen_triton_code_269764.cpython-312.pyc | Bin 0 -> 6232 bytes ....py_gen_triton_code_335674.cpython-312.pyc | Bin 0 -> 6397 bytes ....py_gen_triton_code_349606.cpython-312.pyc | Bin 0 -> 6854 bytes ....py_gen_triton_code_369704.cpython-312.pyc | Bin 0 -> 6123 bytes ...i.py_gen_triton_code_38100.cpython-312.pyc | Bin 0 -> 6201 bytes ....py_gen_triton_code_405645.cpython-312.pyc | Bin 0 -> 6623 bytes ...i.py_gen_triton_code_42419.cpython-312.pyc | Bin 0 -> 6189 bytes ....py_gen_triton_code_450387.cpython-312.pyc | Bin 0 -> 6867 bytes ....py_gen_triton_code_506478.cpython-312.pyc | Bin 0 -> 6619 bytes ....py_gen_triton_code_543766.cpython-312.pyc | Bin 0 -> 6224 bytes ....py_gen_triton_code_560861.cpython-312.pyc | Bin 0 -> 6274 bytes ....py_gen_triton_code_576804.cpython-312.pyc | Bin 0 -> 6365 bytes ....py_gen_triton_code_653084.cpython-312.pyc | Bin 0 -> 6202 bytes ....py_gen_triton_code_661704.cpython-312.pyc | Bin 0 -> 6809 bytes ....py_gen_triton_code_684759.cpython-312.pyc | Bin 0 -> 6062 bytes ....py_gen_triton_code_690508.cpython-312.pyc | Bin 0 -> 6625 bytes ....py_gen_triton_code_720655.cpython-312.pyc | Bin 0 -> 6403 bytes ....py_gen_triton_code_721584.cpython-312.pyc | Bin 0 -> 6888 bytes ....py_gen_triton_code_735113.cpython-312.pyc | Bin 0 -> 6510 bytes ....py_gen_triton_code_739112.cpython-312.pyc | Bin 0 -> 6288 bytes ....py_gen_triton_code_754689.cpython-312.pyc | Bin 0 -> 6216 bytes ....py_gen_triton_code_802348.cpython-312.pyc | Bin 0 -> 6976 bytes ....py_gen_triton_code_812012.cpython-312.pyc | Bin 0 -> 6904 bytes ...i.py_gen_triton_code_83138.cpython-312.pyc | Bin 0 -> 6185 bytes ....py_gen_triton_code_870175.cpython-312.pyc | Bin 0 -> 6228 bytes ....py_gen_triton_code_882682.cpython-312.pyc | Bin 0 -> 6202 bytes ....py_gen_triton_code_900175.cpython-312.pyc | Bin 0 -> 6510 bytes ....py_gen_triton_code_925215.cpython-312.pyc | Bin 0 -> 6364 bytes ....py_gen_triton_code_959027.cpython-312.pyc | Bin 0 -> 6227 bytes ....py_gen_triton_code_124574.cpython-312.pyc | Bin 0 -> 10690 bytes ....py_gen_triton_code_178552.cpython-312.pyc | Bin 0 -> 11271 bytes ....py_gen_triton_code_216434.cpython-312.pyc | Bin 0 -> 12078 bytes ....py_gen_triton_code_219875.cpython-312.pyc | Bin 0 -> 11838 bytes ....py_gen_triton_code_243114.cpython-312.pyc | Bin 0 -> 13580 bytes ....py_gen_triton_code_291697.cpython-312.pyc | Bin 0 -> 12042 bytes ....py_gen_triton_code_298484.cpython-312.pyc | Bin 0 -> 15028 bytes ....py_gen_triton_code_308542.cpython-312.pyc | Bin 0 -> 11700 bytes ....py_gen_triton_code_312025.cpython-312.pyc | Bin 0 -> 13580 bytes ....py_gen_triton_code_357204.cpython-312.pyc | Bin 0 -> 10726 bytes ....py_gen_triton_code_365790.cpython-312.pyc | Bin 0 -> 10087 bytes ...l.py_gen_triton_code_41463.cpython-312.pyc | Bin 0 -> 17037 bytes ....py_gen_triton_code_430740.cpython-312.pyc | Bin 0 -> 14662 bytes ....py_gen_triton_code_434177.cpython-312.pyc | Bin 0 -> 13580 bytes ....py_gen_triton_code_461728.cpython-312.pyc | Bin 0 -> 11735 bytes ...l.py_gen_triton_code_48845.cpython-312.pyc | Bin 0 -> 10510 bytes ....py_gen_triton_code_490790.cpython-312.pyc | Bin 0 -> 10687 bytes ....py_gen_triton_code_511041.cpython-312.pyc | Bin 0 -> 11303 bytes ....py_gen_triton_code_512013.cpython-312.pyc | Bin 0 -> 10729 bytes ...l.py_gen_triton_code_52090.cpython-312.pyc | Bin 0 -> 11208 bytes ....py_gen_triton_code_530716.cpython-312.pyc | Bin 0 -> 12788 bytes ....py_gen_triton_code_635842.cpython-312.pyc | Bin 0 -> 10504 bytes ....py_gen_triton_code_718301.cpython-312.pyc | Bin 0 -> 11366 bytes ....py_gen_triton_code_731602.cpython-312.pyc | Bin 0 -> 9375 bytes ...l.py_gen_triton_code_76683.cpython-312.pyc | Bin 0 -> 11272 bytes ....py_gen_triton_code_769812.cpython-312.pyc | Bin 0 -> 11689 bytes ....py_gen_triton_code_790411.cpython-312.pyc | Bin 0 -> 12519 bytes ....py_gen_triton_code_811684.cpython-312.pyc | Bin 0 -> 13580 bytes ....py_gen_triton_code_815235.cpython-312.pyc | Bin 0 -> 14927 bytes ....py_gen_triton_code_816192.cpython-312.pyc | Bin 0 -> 16119 bytes ....py_gen_triton_code_838410.cpython-312.pyc | Bin 0 -> 10256 bytes ....py_gen_triton_code_886215.cpython-312.pyc | Bin 0 -> 11420 bytes ....py_gen_triton_code_891149.cpython-312.pyc | Bin 0 -> 10306 bytes ....py_gen_triton_code_912380.cpython-312.pyc | Bin 0 -> 11091 bytes ....py_gen_triton_code_925632.cpython-312.pyc | Bin 0 -> 16181 bytes ....py_gen_triton_code_927195.cpython-312.pyc | Bin 0 -> 14170 bytes ...l.py_gen_triton_code_93329.cpython-312.pyc | Bin 0 -> 13579 bytes ....py_gen_triton_code_942564.cpython-312.pyc | Bin 0 -> 10905 bytes ....py_gen_triton_code_977481.cpython-312.pyc | Bin 0 -> 10397 bytes ....py_gen_triton_code_991002.cpython-312.pyc | Bin 0 -> 10974 bytes ....py_gen_triton_code_995030.cpython-312.pyc | Bin 0 -> 11463 bytes ....py_gen_triton_code_143388.cpython-312.pyc | Bin 0 -> 4463 bytes ....py_gen_triton_code_167554.cpython-312.pyc | Bin 0 -> 4386 bytes ....py_gen_triton_code_215639.cpython-312.pyc | Bin 0 -> 4522 bytes ....py_gen_triton_code_220059.cpython-312.pyc | Bin 0 -> 3963 bytes ...d.py_gen_triton_code_28664.cpython-312.pyc | Bin 0 -> 4168 bytes ....py_gen_triton_code_338946.cpython-312.pyc | Bin 0 -> 5090 bytes ....py_gen_triton_code_347725.cpython-312.pyc | Bin 0 -> 3936 bytes ....py_gen_triton_code_387667.cpython-312.pyc | Bin 0 -> 4330 bytes ....py_gen_triton_code_404776.cpython-312.pyc | Bin 0 -> 4709 bytes ....py_gen_triton_code_414029.cpython-312.pyc | Bin 0 -> 4266 bytes ....py_gen_triton_code_419949.cpython-312.pyc | Bin 0 -> 4266 bytes ....py_gen_triton_code_433589.cpython-312.pyc | Bin 0 -> 4625 bytes ....py_gen_triton_code_459560.cpython-312.pyc | Bin 0 -> 5092 bytes ....py_gen_triton_code_486455.cpython-312.pyc | Bin 0 -> 4554 bytes ....py_gen_triton_code_493519.cpython-312.pyc | Bin 0 -> 4522 bytes ....py_gen_triton_code_570539.cpython-312.pyc | Bin 0 -> 4424 bytes ....py_gen_triton_code_597752.cpython-312.pyc | Bin 0 -> 4266 bytes ....py_gen_triton_code_637799.cpython-312.pyc | Bin 0 -> 4394 bytes ....py_gen_triton_code_640557.cpython-312.pyc | Bin 0 -> 4409 bytes ....py_gen_triton_code_712104.cpython-312.pyc | Bin 0 -> 4253 bytes ....py_gen_triton_code_786715.cpython-312.pyc | Bin 0 -> 4266 bytes ....py_gen_triton_code_827439.cpython-312.pyc | Bin 0 -> 4442 bytes ....py_gen_triton_code_843690.cpython-312.pyc | Bin 0 -> 4395 bytes ....py_gen_triton_code_864396.cpython-312.pyc | Bin 0 -> 5150 bytes ....py_gen_triton_code_885795.cpython-312.pyc | Bin 0 -> 4384 bytes ....py_gen_triton_code_960121.cpython-312.pyc | Bin 0 -> 4428 bytes ....py_gen_triton_code_972847.cpython-312.pyc | Bin 0 -> 4431 bytes ....py_gen_triton_code_212491.cpython-312.pyc | Bin 0 -> 3287 bytes ....py_gen_triton_code_254823.cpython-312.pyc | Bin 0 -> 3485 bytes ....py_gen_triton_code_318959.cpython-312.pyc | Bin 0 -> 3477 bytes ....py_gen_triton_code_336206.cpython-312.pyc | Bin 0 -> 3591 bytes ....py_gen_triton_code_357644.cpython-312.pyc | Bin 0 -> 3540 bytes ....py_gen_triton_code_392963.cpython-312.pyc | Bin 0 -> 3577 bytes ....py_gen_triton_code_403404.cpython-312.pyc | Bin 0 -> 3782 bytes ....py_gen_triton_code_466457.cpython-312.pyc | Bin 0 -> 3577 bytes ....py_gen_triton_code_598128.cpython-312.pyc | Bin 0 -> 3465 bytes ....py_gen_triton_code_599125.cpython-312.pyc | Bin 0 -> 3443 bytes ....py_gen_triton_code_637798.cpython-312.pyc | Bin 0 -> 3354 bytes ....py_gen_triton_code_650964.cpython-312.pyc | Bin 0 -> 3507 bytes ....py_gen_triton_code_674736.cpython-312.pyc | Bin 0 -> 3425 bytes ....py_gen_triton_code_786517.cpython-312.pyc | Bin 0 -> 3577 bytes ....py_gen_triton_code_800477.cpython-312.pyc | Bin 0 -> 3577 bytes ....py_gen_triton_code_839169.cpython-312.pyc | Bin 0 -> 3741 bytes ....py_gen_triton_code_846578.cpython-312.pyc | Bin 0 -> 3434 bytes ....py_gen_triton_code_964700.cpython-312.pyc | Bin 0 -> 3577 bytes ....py_gen_triton_code_965300.cpython-312.pyc | Bin 0 -> 3663 bytes ....py_gen_triton_code_973282.cpython-312.pyc | Bin 0 -> 3741 bytes ....py_gen_triton_code_114093.cpython-312.pyc | Bin 0 -> 3354 bytes ...e.py_gen_triton_code_11496.cpython-312.pyc | Bin 0 -> 2632 bytes ...e.py_gen_triton_code_14792.cpython-312.pyc | Bin 0 -> 3754 bytes ....py_gen_triton_code_160821.cpython-312.pyc | Bin 0 -> 3042 bytes ....py_gen_triton_code_205496.cpython-312.pyc | Bin 0 -> 2747 bytes ....py_gen_triton_code_216901.cpython-312.pyc | Bin 0 -> 2331 bytes ....py_gen_triton_code_274099.cpython-312.pyc | Bin 0 -> 2950 bytes ....py_gen_triton_code_369711.cpython-312.pyc | Bin 0 -> 3264 bytes ....py_gen_triton_code_412290.cpython-312.pyc | Bin 0 -> 2521 bytes ....py_gen_triton_code_429164.cpython-312.pyc | Bin 0 -> 2909 bytes ....py_gen_triton_code_469771.cpython-312.pyc | Bin 0 -> 3181 bytes ....py_gen_triton_code_493615.cpython-312.pyc | Bin 0 -> 2470 bytes ....py_gen_triton_code_529486.cpython-312.pyc | Bin 0 -> 2323 bytes ....py_gen_triton_code_571713.cpython-312.pyc | Bin 0 -> 2713 bytes ....py_gen_triton_code_580037.cpython-312.pyc | Bin 0 -> 3061 bytes ....py_gen_triton_code_608628.cpython-312.pyc | Bin 0 -> 2747 bytes ....py_gen_triton_code_619005.cpython-312.pyc | Bin 0 -> 2686 bytes ....py_gen_triton_code_620806.cpython-312.pyc | Bin 0 -> 3307 bytes ....py_gen_triton_code_671609.cpython-312.pyc | Bin 0 -> 3053 bytes ....py_gen_triton_code_724790.cpython-312.pyc | Bin 0 -> 2938 bytes ....py_gen_triton_code_738982.cpython-312.pyc | Bin 0 -> 2747 bytes ...e.py_gen_triton_code_74175.cpython-312.pyc | Bin 0 -> 2746 bytes ....py_gen_triton_code_757083.cpython-312.pyc | Bin 0 -> 3146 bytes ....py_gen_triton_code_759138.cpython-312.pyc | Bin 0 -> 2872 bytes ....py_gen_triton_code_780911.cpython-312.pyc | Bin 0 -> 3002 bytes ....py_gen_triton_code_783719.cpython-312.pyc | Bin 0 -> 2925 bytes ...e.py_gen_triton_code_81159.cpython-312.pyc | Bin 0 -> 3121 bytes ....py_gen_triton_code_853096.cpython-312.pyc | Bin 0 -> 2793 bytes ....py_gen_triton_code_869907.cpython-312.pyc | Bin 0 -> 3087 bytes ....py_gen_triton_code_879575.cpython-312.pyc | Bin 0 -> 2996 bytes ....py_gen_triton_code_892743.cpython-312.pyc | Bin 0 -> 2610 bytes ....py_gen_triton_code_917011.cpython-312.pyc | Bin 0 -> 2737 bytes ....py_gen_triton_code_930305.cpython-312.pyc | Bin 0 -> 3115 bytes ....py_gen_triton_code_953212.cpython-312.pyc | Bin 0 -> 2569 bytes ....py_gen_triton_code_984648.cpython-312.pyc | Bin 0 -> 2593 bytes ....py_gen_triton_code_997014.cpython-312.pyc | Bin 0 -> 2822 bytes ....py_gen_triton_code_164112.cpython-312.pyc | Bin 0 -> 3896 bytes ....py_gen_triton_code_205689.cpython-312.pyc | Bin 0 -> 4519 bytes ....py_gen_triton_code_334537.cpython-312.pyc | Bin 0 -> 3885 bytes ....py_gen_triton_code_370413.cpython-312.pyc | Bin 0 -> 4204 bytes ....py_gen_triton_code_424820.cpython-312.pyc | Bin 0 -> 4081 bytes ....py_gen_triton_code_554113.cpython-312.pyc | Bin 0 -> 3711 bytes ....py_gen_triton_code_554981.cpython-312.pyc | Bin 0 -> 4097 bytes ....py_gen_triton_code_561330.cpython-312.pyc | Bin 0 -> 3936 bytes ....py_gen_triton_code_686366.cpython-312.pyc | Bin 0 -> 4136 bytes ...p.py_gen_triton_code_80693.cpython-312.pyc | Bin 0 -> 4032 bytes ....py_gen_triton_code_105954.cpython-312.pyc | Bin 0 -> 12470 bytes ....py_gen_triton_code_260701.cpython-312.pyc | Bin 0 -> 8238 bytes ....py_gen_triton_code_329295.cpython-312.pyc | Bin 0 -> 10476 bytes ....py_gen_triton_code_338032.cpython-312.pyc | Bin 0 -> 11811 bytes ....py_gen_triton_code_339628.cpython-312.pyc | Bin 0 -> 10978 bytes ....py_gen_triton_code_344391.cpython-312.pyc | Bin 0 -> 8238 bytes ....py_gen_triton_code_373163.cpython-312.pyc | Bin 0 -> 10654 bytes ....py_gen_triton_code_385268.cpython-312.pyc | Bin 0 -> 11620 bytes ....py_gen_triton_code_405620.cpython-312.pyc | Bin 0 -> 9027 bytes ....py_gen_triton_code_431864.cpython-312.pyc | Bin 0 -> 8238 bytes ...m.py_gen_triton_code_44150.cpython-312.pyc | Bin 0 -> 8781 bytes ....py_gen_triton_code_450091.cpython-312.pyc | Bin 0 -> 10876 bytes ....py_gen_triton_code_460195.cpython-312.pyc | Bin 0 -> 11369 bytes ....py_gen_triton_code_527413.cpython-312.pyc | Bin 0 -> 11712 bytes ....py_gen_triton_code_540784.cpython-312.pyc | Bin 0 -> 10442 bytes ....py_gen_triton_code_555768.cpython-312.pyc | Bin 0 -> 13325 bytes ....py_gen_triton_code_634902.cpython-312.pyc | Bin 0 -> 11124 bytes ....py_gen_triton_code_669031.cpython-312.pyc | Bin 0 -> 10630 bytes ....py_gen_triton_code_711258.cpython-312.pyc | Bin 0 -> 7884 bytes ....py_gen_triton_code_816058.cpython-312.pyc | Bin 0 -> 7990 bytes ....py_gen_triton_code_824557.cpython-312.pyc | Bin 0 -> 11368 bytes ....py_gen_triton_code_840463.cpython-312.pyc | Bin 0 -> 7337 bytes ....py_gen_triton_code_843724.cpython-312.pyc | Bin 0 -> 8238 bytes ....py_gen_triton_code_893238.cpython-312.pyc | Bin 0 -> 9022 bytes ....py_gen_triton_code_915460.cpython-312.pyc | Bin 0 -> 10519 bytes ....py_gen_triton_code_925133.cpython-312.pyc | Bin 0 -> 11312 bytes ....py_gen_triton_code_939610.cpython-312.pyc | Bin 0 -> 11631 bytes ....py_gen_triton_code_946209.cpython-312.pyc | Bin 0 -> 8724 bytes ...m.py_gen_triton_code_99563.cpython-312.pyc | Bin 0 -> 8967 bytes ....py_gen_triton_code_123151.cpython-312.pyc | Bin 0 -> 2855 bytes ....py_gen_triton_code_179581.cpython-312.pyc | Bin 0 -> 2836 bytes ....py_gen_triton_code_370053.cpython-312.pyc | Bin 0 -> 2846 bytes ....py_gen_triton_code_473025.cpython-312.pyc | Bin 0 -> 2857 bytes ....py_gen_triton_code_502063.cpython-312.pyc | Bin 0 -> 2910 bytes ...l.py_gen_triton_code_50482.cpython-312.pyc | Bin 0 -> 2920 bytes ....py_gen_triton_code_557502.cpython-312.pyc | Bin 0 -> 2847 bytes ....py_gen_triton_code_560359.cpython-312.pyc | Bin 0 -> 2851 bytes ....py_gen_triton_code_794865.cpython-312.pyc | Bin 0 -> 2940 bytes ....py_gen_triton_code_834634.cpython-312.pyc | Bin 0 -> 2862 bytes ....py_gen_triton_code_931009.cpython-312.pyc | Bin 0 -> 2862 bytes ....py_gen_triton_code_108037.cpython-312.pyc | Bin 0 -> 5497 bytes ...l.py_gen_triton_code_12912.cpython-312.pyc | Bin 0 -> 5241 bytes ....py_gen_triton_code_186313.cpython-312.pyc | Bin 0 -> 5366 bytes ....py_gen_triton_code_284744.cpython-312.pyc | Bin 0 -> 5000 bytes ....py_gen_triton_code_366643.cpython-312.pyc | Bin 0 -> 5531 bytes ....py_gen_triton_code_391924.cpython-312.pyc | Bin 0 -> 6371 bytes ....py_gen_triton_code_395140.cpython-312.pyc | Bin 0 -> 5126 bytes ....py_gen_triton_code_417385.cpython-312.pyc | Bin 0 -> 5271 bytes ....py_gen_triton_code_654780.cpython-312.pyc | Bin 0 -> 5407 bytes ....py_gen_triton_code_769893.cpython-312.pyc | Bin 0 -> 5407 bytes ....py_gen_triton_code_993568.cpython-312.pyc | Bin 0 -> 5481 bytes ...triton_kernel.py_gen_triton_code_155036.py | 214 +++++++++++ ...kernel.py_gen_triton_code_155036.py.stderr | 0 ...kernel.py_gen_triton_code_155036.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_176773.py | 214 +++++++++++ ...kernel.py_gen_triton_code_176773.py.stderr | 0 ...kernel.py_gen_triton_code_176773.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_180807.py | 244 ++++++++++++ ...kernel.py_gen_triton_code_180807.py.stderr | 0 ...kernel.py_gen_triton_code_180807.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_18528.py | 195 ++++++++++ ..._kernel.py_gen_triton_code_18528.py.stderr | 0 ..._kernel.py_gen_triton_code_18528.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_200147.py | 235 ++++++++++++ ...kernel.py_gen_triton_code_200147.py.stderr | 0 ...kernel.py_gen_triton_code_200147.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_211539.py | 205 ++++++++++ ...kernel.py_gen_triton_code_211539.py.stderr | 0 ...kernel.py_gen_triton_code_211539.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_322972.py | 181 +++++++++ ...kernel.py_gen_triton_code_322972.py.stderr | 0 ...kernel.py_gen_triton_code_322972.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_347928.py | 198 ++++++++++ ...kernel.py_gen_triton_code_347928.py.stderr | 0 ...kernel.py_gen_triton_code_347928.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_355413.py | 195 ++++++++++ ...kernel.py_gen_triton_code_355413.py.stderr | 0 ...kernel.py_gen_triton_code_355413.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_429595.py | 190 ++++++++++ ...kernel.py_gen_triton_code_429595.py.stderr | 0 ...kernel.py_gen_triton_code_429595.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_43398.py | 222 +++++++++++ ..._kernel.py_gen_triton_code_43398.py.stderr | 0 ..._kernel.py_gen_triton_code_43398.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_459432.py | 190 ++++++++++ ...kernel.py_gen_triton_code_459432.py.stderr | 0 ...kernel.py_gen_triton_code_459432.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_474863.py | 189 ++++++++++ ...kernel.py_gen_triton_code_474863.py.stderr | 0 ...kernel.py_gen_triton_code_474863.py.stdout | 15 + ...triton_kernel.py_gen_triton_code_477598.py | 218 +++++++++++ ...kernel.py_gen_triton_code_477598.py.stderr | 0 ...kernel.py_gen_triton_code_477598.py.stdout | 14 + ...triton_kernel.py_gen_triton_code_480728.py | 187 ++++++++++ ...kernel.py_gen_triton_code_480728.py.stderr | 0 ...kernel.py_gen_triton_code_480728.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_490985.py | 190 ++++++++++ ...kernel.py_gen_triton_code_490985.py.stderr | 0 ...kernel.py_gen_triton_code_490985.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_507685.py | 228 ++++++++++++ ...kernel.py_gen_triton_code_507685.py.stderr | 0 ...kernel.py_gen_triton_code_507685.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_524778.py | 226 +++++++++++ ...kernel.py_gen_triton_code_524778.py.stderr | 0 ...kernel.py_gen_triton_code_524778.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_533885.py | 188 ++++++++++ ...kernel.py_gen_triton_code_533885.py.stderr | 0 ...kernel.py_gen_triton_code_533885.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_552958.py | 231 ++++++++++++ ...kernel.py_gen_triton_code_552958.py.stderr | 0 ...kernel.py_gen_triton_code_552958.py.stdout | 15 + ...triton_kernel.py_gen_triton_code_574109.py | 210 +++++++++++ ...kernel.py_gen_triton_code_574109.py.stderr | 0 ...kernel.py_gen_triton_code_574109.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_58716.py | 225 +++++++++++ ..._kernel.py_gen_triton_code_58716.py.stderr | 0 ..._kernel.py_gen_triton_code_58716.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_600998.py | 182 +++++++++ ...kernel.py_gen_triton_code_600998.py.stderr | 0 ...kernel.py_gen_triton_code_600998.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_605163.py | 190 ++++++++++ ...kernel.py_gen_triton_code_605163.py.stderr | 0 ...kernel.py_gen_triton_code_605163.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_620455.py | 178 +++++++++ ...kernel.py_gen_triton_code_620455.py.stderr | 0 ...kernel.py_gen_triton_code_620455.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_635331.py | 182 +++++++++ ...kernel.py_gen_triton_code_635331.py.stderr | 0 ...kernel.py_gen_triton_code_635331.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_64602.py | 189 ++++++++++ ..._kernel.py_gen_triton_code_64602.py.stderr | 0 ..._kernel.py_gen_triton_code_64602.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_68534.py | 205 ++++++++++ ..._kernel.py_gen_triton_code_68534.py.stderr | 0 ..._kernel.py_gen_triton_code_68534.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_713720.py | 191 ++++++++++ ...kernel.py_gen_triton_code_713720.py.stderr | 0 ...kernel.py_gen_triton_code_713720.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_721645.py | 232 ++++++++++++ ...kernel.py_gen_triton_code_721645.py.stderr | 0 ...kernel.py_gen_triton_code_721645.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_759146.py | 181 +++++++++ ...kernel.py_gen_triton_code_759146.py.stderr | 0 ...kernel.py_gen_triton_code_759146.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_764635.py | 197 ++++++++++ ...kernel.py_gen_triton_code_764635.py.stderr | 0 ...kernel.py_gen_triton_code_764635.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_76684.py | 188 ++++++++++ ..._kernel.py_gen_triton_code_76684.py.stderr | 0 ..._kernel.py_gen_triton_code_76684.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_804525.py | 190 ++++++++++ ...kernel.py_gen_triton_code_804525.py.stderr | 0 ...kernel.py_gen_triton_code_804525.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_823958.py | 195 ++++++++++ ...kernel.py_gen_triton_code_823958.py.stderr | 0 ...kernel.py_gen_triton_code_823958.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_830218.py | 195 ++++++++++ ...kernel.py_gen_triton_code_830218.py.stderr | 0 ...kernel.py_gen_triton_code_830218.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_837397.py | 250 +++++++++++++ ...kernel.py_gen_triton_code_837397.py.stderr | 0 ...kernel.py_gen_triton_code_837397.py.stdout | 1 + ..._triton_kernel.py_gen_triton_code_92676.py | 220 +++++++++++ ..._kernel.py_gen_triton_code_92676.py.stderr | 0 ..._kernel.py_gen_triton_code_92676.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_940390.py | 190 ++++++++++ ...kernel.py_gen_triton_code_940390.py.stderr | 0 ...kernel.py_gen_triton_code_940390.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_965031.py | 196 ++++++++++ ...kernel.py_gen_triton_code_965031.py.stderr | 0 ...kernel.py_gen_triton_code_965031.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_984659.py | 174 +++++++++ ...kernel.py_gen_triton_code_984659.py.stderr | 0 ...kernel.py_gen_triton_code_984659.py.stdout | 1 + ...triton_kernel.py_gen_triton_code_992208.py | 227 +++++++++++ ...kernel.py_gen_triton_code_992208.py.stderr | 0 ...kernel.py_gen_triton_code_992208.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_126106.py | 201 ++++++++++ ...e2_phi.py_gen_triton_code_126106.py.stderr | 0 ...e2_phi.py_gen_triton_code_126106.py.stdout | 1 + ...sh_decode2_phi.py_gen_triton_code_14965.py | 211 +++++++++++ ...de2_phi.py_gen_triton_code_14965.py.stderr | 0 ...de2_phi.py_gen_triton_code_14965.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_198114.py | 210 +++++++++++ ...e2_phi.py_gen_triton_code_198114.py.stderr | 0 ...e2_phi.py_gen_triton_code_198114.py.stdout | 1 + ...sh_decode2_phi.py_gen_triton_code_23614.py | 203 ++++++++++ ...de2_phi.py_gen_triton_code_23614.py.stderr | 0 ...de2_phi.py_gen_triton_code_23614.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_269764.py | 212 +++++++++++ ...e2_phi.py_gen_triton_code_269764.py.stderr | 0 ...e2_phi.py_gen_triton_code_269764.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_335674.py | 198 ++++++++++ ...e2_phi.py_gen_triton_code_335674.py.stderr | 0 ...e2_phi.py_gen_triton_code_335674.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_349606.py | 225 +++++++++++ ...e2_phi.py_gen_triton_code_349606.py.stderr | 0 ...e2_phi.py_gen_triton_code_349606.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_369704.py | 208 +++++++++++ ...e2_phi.py_gen_triton_code_369704.py.stderr | 0 ...e2_phi.py_gen_triton_code_369704.py.stdout | 1 + ...sh_decode2_phi.py_gen_triton_code_38100.py | 203 ++++++++++ ...de2_phi.py_gen_triton_code_38100.py.stderr | 0 ...de2_phi.py_gen_triton_code_38100.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_405645.py | 214 +++++++++++ ...e2_phi.py_gen_triton_code_405645.py.stderr | 0 ...e2_phi.py_gen_triton_code_405645.py.stdout | 1 + ...sh_decode2_phi.py_gen_triton_code_42419.py | 200 ++++++++++ ...de2_phi.py_gen_triton_code_42419.py.stderr | 0 ...de2_phi.py_gen_triton_code_42419.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_450387.py | 224 +++++++++++ ...e2_phi.py_gen_triton_code_450387.py.stderr | 0 ...e2_phi.py_gen_triton_code_450387.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_506478.py | 212 +++++++++++ ...e2_phi.py_gen_triton_code_506478.py.stderr | 0 ...e2_phi.py_gen_triton_code_506478.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_543766.py | 209 +++++++++++ ...e2_phi.py_gen_triton_code_543766.py.stderr | 0 ...e2_phi.py_gen_triton_code_543766.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_560861.py | 212 +++++++++++ ...e2_phi.py_gen_triton_code_560861.py.stderr | 0 ...e2_phi.py_gen_triton_code_560861.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_576804.py | 197 ++++++++++ ...e2_phi.py_gen_triton_code_576804.py.stderr | 0 ...e2_phi.py_gen_triton_code_576804.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_653084.py | 203 ++++++++++ ...e2_phi.py_gen_triton_code_653084.py.stderr | 0 ...e2_phi.py_gen_triton_code_653084.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_661704.py | 223 +++++++++++ ...e2_phi.py_gen_triton_code_661704.py.stderr | 0 ...e2_phi.py_gen_triton_code_661704.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_684759.py | 213 +++++++++++ ...e2_phi.py_gen_triton_code_684759.py.stderr | 0 ...e2_phi.py_gen_triton_code_684759.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_690508.py | 214 +++++++++++ ...e2_phi.py_gen_triton_code_690508.py.stderr | 0 ...e2_phi.py_gen_triton_code_690508.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_720655.py | 214 +++++++++++ ...e2_phi.py_gen_triton_code_720655.py.stderr | 0 ...e2_phi.py_gen_triton_code_720655.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_721584.py | 219 +++++++++++ ...e2_phi.py_gen_triton_code_721584.py.stderr | 0 ...e2_phi.py_gen_triton_code_721584.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_735113.py | 214 +++++++++++ ...e2_phi.py_gen_triton_code_735113.py.stderr | 0 ...e2_phi.py_gen_triton_code_735113.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_739112.py | 212 +++++++++++ ...e2_phi.py_gen_triton_code_739112.py.stderr | 0 ...e2_phi.py_gen_triton_code_739112.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_754689.py | 218 +++++++++++ ...e2_phi.py_gen_triton_code_754689.py.stderr | 0 ...e2_phi.py_gen_triton_code_754689.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_802348.py | 241 ++++++++++++ ...e2_phi.py_gen_triton_code_802348.py.stderr | 0 ...e2_phi.py_gen_triton_code_802348.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_812012.py | 226 +++++++++++ ...e2_phi.py_gen_triton_code_812012.py.stderr | 0 ...e2_phi.py_gen_triton_code_812012.py.stdout | 1 + ...sh_decode2_phi.py_gen_triton_code_83138.py | 202 ++++++++++ ...de2_phi.py_gen_triton_code_83138.py.stderr | 0 ...de2_phi.py_gen_triton_code_83138.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_870175.py | 206 ++++++++++ ...e2_phi.py_gen_triton_code_870175.py.stderr | 0 ...e2_phi.py_gen_triton_code_870175.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_882682.py | 203 ++++++++++ ...e2_phi.py_gen_triton_code_882682.py.stderr | 0 ...e2_phi.py_gen_triton_code_882682.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_900175.py | 214 +++++++++++ ...e2_phi.py_gen_triton_code_900175.py.stderr | 0 ...e2_phi.py_gen_triton_code_900175.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_925215.py | 218 +++++++++++ ...e2_phi.py_gen_triton_code_925215.py.stderr | 0 ...e2_phi.py_gen_triton_code_925215.py.stdout | 1 + ...h_decode2_phi.py_gen_triton_code_959027.py | 193 ++++++++++ ...e2_phi.py_gen_triton_code_959027.py.stderr | 0 ...e2_phi.py_gen_triton_code_959027.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_124574.py | 197 ++++++++++ ...matmul.py_gen_triton_code_124574.py.stderr | 2 + ...matmul.py_gen_triton_code_124574.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_178552.py | 173 +++++++++ ...matmul.py_gen_triton_code_178552.py.stderr | 0 ...matmul.py_gen_triton_code_178552.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_216434.py | 217 +++++++++++ ...matmul.py_gen_triton_code_216434.py.stderr | 0 ...matmul.py_gen_triton_code_216434.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_219875.py | 216 +++++++++++ ...matmul.py_gen_triton_code_219875.py.stderr | 0 ...matmul.py_gen_triton_code_219875.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_243114.py | 250 +++++++++++++ ...matmul.py_gen_triton_code_243114.py.stderr | 0 ...matmul.py_gen_triton_code_243114.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_291697.py | 257 +++++++++++++ ...matmul.py_gen_triton_code_291697.py.stderr | 0 ...matmul.py_gen_triton_code_291697.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_298484.py | 290 +++++++++++++++ ...matmul.py_gen_triton_code_298484.py.stderr | 0 ...matmul.py_gen_triton_code_298484.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_308542.py | 215 +++++++++++ ...matmul.py_gen_triton_code_308542.py.stderr | 0 ...matmul.py_gen_triton_code_308542.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_312025.py | 250 +++++++++++++ ...matmul.py_gen_triton_code_312025.py.stderr | 0 ...matmul.py_gen_triton_code_312025.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_357204.py | 226 +++++++++++ ...matmul.py_gen_triton_code_357204.py.stderr | 0 ...matmul.py_gen_triton_code_357204.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_365790.py | 177 +++++++++ ...matmul.py_gen_triton_code_365790.py.stderr | 0 ...matmul.py_gen_triton_code_365790.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_41463.py | 313 ++++++++++++++++ ..._matmul.py_gen_triton_code_41463.py.stderr | 0 ..._matmul.py_gen_triton_code_41463.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_430740.py | 252 +++++++++++++ ...matmul.py_gen_triton_code_430740.py.stderr | 2 + ...matmul.py_gen_triton_code_430740.py.stdout | 15 + .../int4_matmul.py_gen_triton_code_434177.py | 250 +++++++++++++ ...matmul.py_gen_triton_code_434177.py.stderr | 0 ...matmul.py_gen_triton_code_434177.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_461728.py | 241 ++++++++++++ ...matmul.py_gen_triton_code_461728.py.stderr | 0 ...matmul.py_gen_triton_code_461728.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_48845.py | 190 ++++++++++ ..._matmul.py_gen_triton_code_48845.py.stderr | 0 ..._matmul.py_gen_triton_code_48845.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_490790.py | 198 ++++++++++ ...matmul.py_gen_triton_code_490790.py.stderr | 0 ...matmul.py_gen_triton_code_490790.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_511041.py | 203 ++++++++++ ...matmul.py_gen_triton_code_511041.py.stderr | 0 ...matmul.py_gen_triton_code_511041.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_512013.py | 199 ++++++++++ ...matmul.py_gen_triton_code_512013.py.stderr | 0 ...matmul.py_gen_triton_code_512013.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_52090.py | 180 +++++++++ ..._matmul.py_gen_triton_code_52090.py.stderr | 0 ..._matmul.py_gen_triton_code_52090.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_530716.py | 255 +++++++++++++ ...matmul.py_gen_triton_code_530716.py.stderr | 0 ...matmul.py_gen_triton_code_530716.py.stdout | 15 + .../int4_matmul.py_gen_triton_code_635842.py | 205 ++++++++++ ...matmul.py_gen_triton_code_635842.py.stderr | 0 ...matmul.py_gen_triton_code_635842.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_718301.py | 185 +++++++++ ...matmul.py_gen_triton_code_718301.py.stderr | 0 ...matmul.py_gen_triton_code_718301.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_731602.py | 193 ++++++++++ ...matmul.py_gen_triton_code_731602.py.stderr | 0 ...matmul.py_gen_triton_code_731602.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_732866.py | 250 +++++++++++++ .../int4_matmul.py_gen_triton_code_76683.py | 202 ++++++++++ ..._matmul.py_gen_triton_code_76683.py.stderr | 0 ..._matmul.py_gen_triton_code_76683.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_769812.py | 242 ++++++++++++ ...matmul.py_gen_triton_code_769812.py.stderr | 0 ...matmul.py_gen_triton_code_769812.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_790411.py | 249 +++++++++++++ ...matmul.py_gen_triton_code_790411.py.stderr | 0 ...matmul.py_gen_triton_code_790411.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_811684.py | 250 +++++++++++++ ...matmul.py_gen_triton_code_811684.py.stderr | 0 ...matmul.py_gen_triton_code_811684.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_815235.py | 312 ++++++++++++++++ ...matmul.py_gen_triton_code_815235.py.stderr | 0 ...matmul.py_gen_triton_code_815235.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_816192.py | 299 +++++++++++++++ ...matmul.py_gen_triton_code_816192.py.stderr | 0 ...matmul.py_gen_triton_code_816192.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_838410.py | 215 +++++++++++ ...matmul.py_gen_triton_code_838410.py.stderr | 0 ...matmul.py_gen_triton_code_838410.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_865534.py | 313 ++++++++++++++++ .../int4_matmul.py_gen_triton_code_886215.py | 231 ++++++++++++ ...matmul.py_gen_triton_code_886215.py.stderr | 0 ...matmul.py_gen_triton_code_886215.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_891149.py | 243 ++++++++++++ ...matmul.py_gen_triton_code_891149.py.stderr | 0 ...matmul.py_gen_triton_code_891149.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_912380.py | 232 ++++++++++++ ...matmul.py_gen_triton_code_912380.py.stderr | 0 ...matmul.py_gen_triton_code_912380.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_925632.py | 299 +++++++++++++++ ...matmul.py_gen_triton_code_925632.py.stderr | 0 ...matmul.py_gen_triton_code_925632.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_927195.py | 277 ++++++++++++++ ...matmul.py_gen_triton_code_927195.py.stderr | 2 + ...matmul.py_gen_triton_code_927195.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_93329.py | 250 +++++++++++++ ..._matmul.py_gen_triton_code_93329.py.stderr | 0 ..._matmul.py_gen_triton_code_93329.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_942564.py | 229 ++++++++++++ ...matmul.py_gen_triton_code_942564.py.stderr | 0 ...matmul.py_gen_triton_code_942564.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_977481.py | 205 ++++++++++ ...matmul.py_gen_triton_code_977481.py.stderr | 0 ...matmul.py_gen_triton_code_977481.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_991002.py | 214 +++++++++++ ...matmul.py_gen_triton_code_991002.py.stderr | 0 ...matmul.py_gen_triton_code_991002.py.stdout | 1 + .../int4_matmul.py_gen_triton_code_995030.py | 234 ++++++++++++ ...matmul.py_gen_triton_code_995030.py.stderr | 0 ...matmul.py_gen_triton_code_995030.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_143388.py | 137 +++++++ ...rm_bwd.py_gen_triton_code_143388.py.stderr | 0 ...rm_bwd.py_gen_triton_code_143388.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_167554.py | 139 +++++++ ...rm_bwd.py_gen_triton_code_167554.py.stderr | 0 ...rm_bwd.py_gen_triton_code_167554.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_215639.py | 139 +++++++ ...rm_bwd.py_gen_triton_code_215639.py.stderr | 0 ...rm_bwd.py_gen_triton_code_215639.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_220059.py | 148 ++++++++ ...rm_bwd.py_gen_triton_code_220059.py.stderr | 0 ...rm_bwd.py_gen_triton_code_220059.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_28664.py | 140 +++++++ ...orm_bwd.py_gen_triton_code_28664.py.stderr | 0 ...orm_bwd.py_gen_triton_code_28664.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_338946.py | 147 ++++++++ ...rm_bwd.py_gen_triton_code_338946.py.stderr | 0 ...rm_bwd.py_gen_triton_code_338946.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_347725.py | 146 ++++++++ ...rm_bwd.py_gen_triton_code_347725.py.stderr | 0 ...rm_bwd.py_gen_triton_code_347725.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_387667.py | 170 +++++++++ ...rm_bwd.py_gen_triton_code_387667.py.stderr | 0 ...rm_bwd.py_gen_triton_code_387667.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_404776.py | 134 +++++++ ...rm_bwd.py_gen_triton_code_404776.py.stderr | 0 ...rm_bwd.py_gen_triton_code_404776.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_414029.py | 132 +++++++ ...rm_bwd.py_gen_triton_code_414029.py.stderr | 0 ...rm_bwd.py_gen_triton_code_414029.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_419949.py | 132 +++++++ ...rm_bwd.py_gen_triton_code_419949.py.stderr | 0 ...rm_bwd.py_gen_triton_code_419949.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_433589.py | 128 +++++++ ...rm_bwd.py_gen_triton_code_433589.py.stderr | 0 ...rm_bwd.py_gen_triton_code_433589.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_459560.py | 149 ++++++++ ...rm_bwd.py_gen_triton_code_459560.py.stderr | 0 ...rm_bwd.py_gen_triton_code_459560.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_486455.py | 147 ++++++++ ...rm_bwd.py_gen_triton_code_486455.py.stderr | 0 ...rm_bwd.py_gen_triton_code_486455.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_493519.py | 136 +++++++ ...rm_bwd.py_gen_triton_code_493519.py.stderr | 0 ...rm_bwd.py_gen_triton_code_493519.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_570539.py | 136 +++++++ ...rm_bwd.py_gen_triton_code_570539.py.stderr | 0 ...rm_bwd.py_gen_triton_code_570539.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_597752.py | 132 +++++++ ...rm_bwd.py_gen_triton_code_597752.py.stderr | 0 ...rm_bwd.py_gen_triton_code_597752.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_637799.py | 147 ++++++++ ...rm_bwd.py_gen_triton_code_637799.py.stderr | 0 ...rm_bwd.py_gen_triton_code_637799.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_640557.py | 130 +++++++ ...rm_bwd.py_gen_triton_code_640557.py.stderr | 0 ...rm_bwd.py_gen_triton_code_640557.py.stdout | 15 + .../l2_norm_bwd.py_gen_triton_code_712104.py | 132 +++++++ ...rm_bwd.py_gen_triton_code_712104.py.stderr | 0 ...rm_bwd.py_gen_triton_code_712104.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_786715.py | 132 +++++++ ...rm_bwd.py_gen_triton_code_786715.py.stderr | 0 ...rm_bwd.py_gen_triton_code_786715.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_827439.py | 134 +++++++ ...rm_bwd.py_gen_triton_code_827439.py.stderr | 0 ...rm_bwd.py_gen_triton_code_827439.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_843690.py | 136 +++++++ ...rm_bwd.py_gen_triton_code_843690.py.stderr | 0 ...rm_bwd.py_gen_triton_code_843690.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_864396.py | 147 ++++++++ ...rm_bwd.py_gen_triton_code_864396.py.stderr | 0 ...rm_bwd.py_gen_triton_code_864396.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_885795.py | 135 +++++++ ...rm_bwd.py_gen_triton_code_885795.py.stderr | 0 ...rm_bwd.py_gen_triton_code_885795.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_960121.py | 130 +++++++ ...rm_bwd.py_gen_triton_code_960121.py.stderr | 0 ...rm_bwd.py_gen_triton_code_960121.py.stdout | 1 + .../l2_norm_bwd.py_gen_triton_code_972847.py | 139 +++++++ ...rm_bwd.py_gen_triton_code_972847.py.stderr | 0 ...rm_bwd.py_gen_triton_code_972847.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_212491.py | 104 ++++++ ...riton1.py_gen_triton_code_212491.py.stderr | 0 ...riton1.py_gen_triton_code_212491.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_254823.py | 116 ++++++ ...riton1.py_gen_triton_code_254823.py.stderr | 0 ...riton1.py_gen_triton_code_254823.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_318959.py | 114 ++++++ ...riton1.py_gen_triton_code_318959.py.stderr | 0 ...riton1.py_gen_triton_code_318959.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_336206.py | 112 ++++++ ...riton1.py_gen_triton_code_336206.py.stderr | 0 ...riton1.py_gen_triton_code_336206.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_357644.py | 114 ++++++ ...riton1.py_gen_triton_code_357644.py.stderr | 0 ...riton1.py_gen_triton_code_357644.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_392963.py | 117 ++++++ ...riton1.py_gen_triton_code_392963.py.stderr | 0 ...riton1.py_gen_triton_code_392963.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_403404.py | 124 +++++++ ...riton1.py_gen_triton_code_403404.py.stderr | 0 ...riton1.py_gen_triton_code_403404.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_466457.py | 117 ++++++ ...riton1.py_gen_triton_code_466457.py.stderr | 0 ...riton1.py_gen_triton_code_466457.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_598128.py | 111 ++++++ ...riton1.py_gen_triton_code_598128.py.stderr | 0 ...riton1.py_gen_triton_code_598128.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_599125.py | 112 ++++++ ...riton1.py_gen_triton_code_599125.py.stderr | 0 ...riton1.py_gen_triton_code_599125.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_637798.py | 104 ++++++ ...riton1.py_gen_triton_code_637798.py.stderr | 0 ...riton1.py_gen_triton_code_637798.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_650964.py | 113 ++++++ ...riton1.py_gen_triton_code_650964.py.stderr | 0 ...riton1.py_gen_triton_code_650964.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_674736.py | 108 ++++++ ...riton1.py_gen_triton_code_674736.py.stderr | 0 ...riton1.py_gen_triton_code_674736.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_786517.py | 117 ++++++ ...riton1.py_gen_triton_code_786517.py.stderr | 0 ...riton1.py_gen_triton_code_786517.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_800477.py | 117 ++++++ ...riton1.py_gen_triton_code_800477.py.stderr | 0 ...riton1.py_gen_triton_code_800477.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_839169.py | 132 +++++++ ...riton1.py_gen_triton_code_839169.py.stderr | 0 ...riton1.py_gen_triton_code_839169.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_846578.py | 109 ++++++ ...riton1.py_gen_triton_code_846578.py.stderr | 0 ...riton1.py_gen_triton_code_846578.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_964700.py | 117 ++++++ ...riton1.py_gen_triton_code_964700.py.stderr | 0 ...riton1.py_gen_triton_code_964700.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_965300.py | 113 ++++++ ...riton1.py_gen_triton_code_965300.py.stderr | 0 ...riton1.py_gen_triton_code_965300.py.stdout | 1 + ..._norm_triton1.py_gen_triton_code_973282.py | 132 +++++++ ...riton1.py_gen_triton_code_973282.py.stderr | 0 ...riton1.py_gen_triton_code_973282.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_114093.py | 124 +++++++ ...nspose.py_gen_triton_code_114093.py.stderr | 0 ...nspose.py_gen_triton_code_114093.py.stdout | 15 + ...trix_transpose.py_gen_triton_code_11496.py | 96 +++++ ...anspose.py_gen_triton_code_11496.py.stderr | 0 ...anspose.py_gen_triton_code_11496.py.stdout | 1 + ...trix_transpose.py_gen_triton_code_14792.py | 115 ++++++ ...anspose.py_gen_triton_code_14792.py.stderr | 0 ...anspose.py_gen_triton_code_14792.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_160821.py | 105 ++++++ ...nspose.py_gen_triton_code_160821.py.stderr | 0 ...nspose.py_gen_triton_code_160821.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_205496.py | 89 +++++ ...nspose.py_gen_triton_code_205496.py.stderr | 0 ...nspose.py_gen_triton_code_205496.py.stdout | 8 + ...rix_transpose.py_gen_triton_code_216901.py | 90 +++++ ...nspose.py_gen_triton_code_216901.py.stderr | 0 ...nspose.py_gen_triton_code_216901.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_274099.py | 102 +++++ ...nspose.py_gen_triton_code_274099.py.stderr | 0 ...nspose.py_gen_triton_code_274099.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_369711.py | 105 ++++++ ...nspose.py_gen_triton_code_369711.py.stderr | 0 ...nspose.py_gen_triton_code_369711.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_412290.py | 85 +++++ ...nspose.py_gen_triton_code_412290.py.stderr | 0 ...nspose.py_gen_triton_code_412290.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_429164.py | 100 +++++ ...nspose.py_gen_triton_code_429164.py.stderr | 0 ...nspose.py_gen_triton_code_429164.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_469771.py | 105 ++++++ ...nspose.py_gen_triton_code_469771.py.stderr | 0 ...nspose.py_gen_triton_code_469771.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_493615.py | 84 +++++ ...nspose.py_gen_triton_code_493615.py.stderr | 0 ...nspose.py_gen_triton_code_493615.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_529486.py | 87 +++++ ...nspose.py_gen_triton_code_529486.py.stderr | 0 ...nspose.py_gen_triton_code_529486.py.stdout | 14 + ...rix_transpose.py_gen_triton_code_571713.py | 97 +++++ ...nspose.py_gen_triton_code_571713.py.stderr | 0 ...nspose.py_gen_triton_code_571713.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_580037.py | 100 +++++ ...nspose.py_gen_triton_code_580037.py.stderr | 0 ...nspose.py_gen_triton_code_580037.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_608628.py | 89 +++++ ...nspose.py_gen_triton_code_608628.py.stderr | 0 ...nspose.py_gen_triton_code_608628.py.stdout | 8 + ...rix_transpose.py_gen_triton_code_619005.py | 100 +++++ ...nspose.py_gen_triton_code_619005.py.stderr | 0 ...nspose.py_gen_triton_code_619005.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_620806.py | 102 +++++ ...nspose.py_gen_triton_code_620806.py.stderr | 0 ...nspose.py_gen_triton_code_620806.py.stdout | 15 + ...rix_transpose.py_gen_triton_code_671609.py | 103 +++++ ...nspose.py_gen_triton_code_671609.py.stderr | 0 ...nspose.py_gen_triton_code_671609.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_724790.py | 100 +++++ ...nspose.py_gen_triton_code_724790.py.stderr | 0 ...nspose.py_gen_triton_code_724790.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_738982.py | 89 +++++ ...nspose.py_gen_triton_code_738982.py.stderr | 0 ...nspose.py_gen_triton_code_738982.py.stdout | 8 + ...trix_transpose.py_gen_triton_code_74175.py | 89 +++++ ...anspose.py_gen_triton_code_74175.py.stderr | 0 ...anspose.py_gen_triton_code_74175.py.stdout | 8 + ...rix_transpose.py_gen_triton_code_757083.py | 108 ++++++ ...nspose.py_gen_triton_code_757083.py.stderr | 0 ...nspose.py_gen_triton_code_757083.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_759138.py | 104 ++++++ ...nspose.py_gen_triton_code_759138.py.stderr | 0 ...nspose.py_gen_triton_code_759138.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_780911.py | 105 ++++++ ...nspose.py_gen_triton_code_780911.py.stderr | 0 ...nspose.py_gen_triton_code_780911.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_783719.py | 107 ++++++ ...nspose.py_gen_triton_code_783719.py.stderr | 0 ...nspose.py_gen_triton_code_783719.py.stdout | 1 + ...trix_transpose.py_gen_triton_code_81159.py | 111 ++++++ ...anspose.py_gen_triton_code_81159.py.stderr | 0 ...anspose.py_gen_triton_code_81159.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_853096.py | 89 +++++ ...nspose.py_gen_triton_code_853096.py.stderr | 0 ...nspose.py_gen_triton_code_853096.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_869907.py | 111 ++++++ ...nspose.py_gen_triton_code_869907.py.stderr | 0 ...nspose.py_gen_triton_code_869907.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_879575.py | 105 ++++++ ...nspose.py_gen_triton_code_879575.py.stderr | 0 ...nspose.py_gen_triton_code_879575.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_892743.py | 87 +++++ ...nspose.py_gen_triton_code_892743.py.stderr | 0 ...nspose.py_gen_triton_code_892743.py.stdout | 10 + ...rix_transpose.py_gen_triton_code_917011.py | 88 +++++ ...nspose.py_gen_triton_code_917011.py.stderr | 0 ...nspose.py_gen_triton_code_917011.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_930305.py | 107 ++++++ ...nspose.py_gen_triton_code_930305.py.stderr | 0 ...nspose.py_gen_triton_code_930305.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_953212.py | 87 +++++ ...nspose.py_gen_triton_code_953212.py.stderr | 0 ...nspose.py_gen_triton_code_953212.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_984648.py | 90 +++++ ...nspose.py_gen_triton_code_984648.py.stderr | 0 ...nspose.py_gen_triton_code_984648.py.stdout | 1 + ...rix_transpose.py_gen_triton_code_997014.py | 102 +++++ ...nspose.py_gen_triton_code_997014.py.stderr | 0 ...nspose.py_gen_triton_code_997014.py.stdout | 14 + ...vector_multip.py_gen_triton_code_164112.py | 87 +++++ ...multip.py_gen_triton_code_164112.py.stderr | 2 + ...multip.py_gen_triton_code_164112.py.stdout | 1 + ...vector_multip.py_gen_triton_code_205689.py | 88 +++++ ...multip.py_gen_triton_code_205689.py.stderr | 2 + ...multip.py_gen_triton_code_205689.py.stdout | 1 + ...vector_multip.py_gen_triton_code_334537.py | 88 +++++ ...multip.py_gen_triton_code_334537.py.stderr | 2 + ...multip.py_gen_triton_code_334537.py.stdout | 1 + ...vector_multip.py_gen_triton_code_370413.py | 111 ++++++ ...multip.py_gen_triton_code_370413.py.stderr | 2 + ...multip.py_gen_triton_code_370413.py.stdout | 1 + ...vector_multip.py_gen_triton_code_424820.py | 86 +++++ ...multip.py_gen_triton_code_424820.py.stderr | 2 + ...multip.py_gen_triton_code_424820.py.stdout | 1 + ...vector_multip.py_gen_triton_code_554113.py | 88 +++++ ...multip.py_gen_triton_code_554113.py.stderr | 2 + ...multip.py_gen_triton_code_554113.py.stdout | 1 + ...vector_multip.py_gen_triton_code_554981.py | 101 +++++ ...multip.py_gen_triton_code_554981.py.stderr | 2 + ...multip.py_gen_triton_code_554981.py.stdout | 1 + ...vector_multip.py_gen_triton_code_561330.py | 98 +++++ ...multip.py_gen_triton_code_561330.py.stderr | 2 + ...multip.py_gen_triton_code_561330.py.stdout | 1 + ...vector_multip.py_gen_triton_code_686366.py | 95 +++++ ...multip.py_gen_triton_code_686366.py.stderr | 2 + ...multip.py_gen_triton_code_686366.py.stdout | 1 + ..._vector_multip.py_gen_triton_code_80693.py | 99 +++++ ..._multip.py_gen_triton_code_80693.py.stderr | 2 + ..._multip.py_gen_triton_code_80693.py.stdout | 1 + ...ary_transform.py_gen_triton_code_105954.py | 312 ++++++++++++++++ ...nsform.py_gen_triton_code_105954.py.stderr | 0 ...nsform.py_gen_triton_code_105954.py.stdout | 1 + ...ary_transform.py_gen_triton_code_260701.py | 237 ++++++++++++ ...nsform.py_gen_triton_code_260701.py.stderr | 0 ...nsform.py_gen_triton_code_260701.py.stdout | 1 + ...ary_transform.py_gen_triton_code_329295.py | 287 ++++++++++++++ ...nsform.py_gen_triton_code_329295.py.stderr | 0 ...nsform.py_gen_triton_code_329295.py.stdout | 1 + ...ary_transform.py_gen_triton_code_338032.py | 321 ++++++++++++++++ ...nsform.py_gen_triton_code_338032.py.stderr | 0 ...nsform.py_gen_triton_code_338032.py.stdout | 1 + ...ary_transform.py_gen_triton_code_339628.py | 289 ++++++++++++++ ...nsform.py_gen_triton_code_339628.py.stderr | 0 ...nsform.py_gen_triton_code_339628.py.stdout | 15 + ...ary_transform.py_gen_triton_code_344391.py | 237 ++++++++++++ ...nsform.py_gen_triton_code_344391.py.stderr | 0 ...nsform.py_gen_triton_code_344391.py.stdout | 1 + ...ary_transform.py_gen_triton_code_373163.py | 343 +++++++++++++++++ ...nsform.py_gen_triton_code_373163.py.stderr | 0 ...nsform.py_gen_triton_code_373163.py.stdout | 14 + ...ary_transform.py_gen_triton_code_385268.py | 272 ++++++++++++++ ...nsform.py_gen_triton_code_385268.py.stderr | 0 ...nsform.py_gen_triton_code_385268.py.stdout | 1 + ...ary_transform.py_gen_triton_code_405620.py | 275 ++++++++++++++ ...nsform.py_gen_triton_code_405620.py.stderr | 0 ...nsform.py_gen_triton_code_405620.py.stdout | 15 + ...ary_transform.py_gen_triton_code_431864.py | 237 ++++++++++++ ...nsform.py_gen_triton_code_431864.py.stderr | 0 ...nsform.py_gen_triton_code_431864.py.stdout | 1 + ...tary_transform.py_gen_triton_code_44150.py | 265 +++++++++++++ ...ansform.py_gen_triton_code_44150.py.stderr | 0 ...ansform.py_gen_triton_code_44150.py.stdout | 15 + ...ary_transform.py_gen_triton_code_450091.py | 307 +++++++++++++++ ...nsform.py_gen_triton_code_450091.py.stderr | 0 ...nsform.py_gen_triton_code_450091.py.stdout | 14 + ...ary_transform.py_gen_triton_code_460195.py | 294 +++++++++++++++ ...nsform.py_gen_triton_code_460195.py.stderr | 0 ...nsform.py_gen_triton_code_460195.py.stdout | 14 + ...ary_transform.py_gen_triton_code_527413.py | 247 ++++++++++++ ...nsform.py_gen_triton_code_527413.py.stderr | 0 ...nsform.py_gen_triton_code_527413.py.stdout | 1 + ...ary_transform.py_gen_triton_code_540784.py | 284 ++++++++++++++ ...nsform.py_gen_triton_code_540784.py.stderr | 0 ...nsform.py_gen_triton_code_540784.py.stdout | 1 + ...ary_transform.py_gen_triton_code_549779.py | 317 ++++++++++++++++ ...nsform.py_gen_triton_code_549779.py.stderr | 0 ...nsform.py_gen_triton_code_549779.py.stdout | 1 + ...ary_transform.py_gen_triton_code_555768.py | 351 ++++++++++++++++++ ...nsform.py_gen_triton_code_555768.py.stderr | 0 ...nsform.py_gen_triton_code_555768.py.stdout | 15 + ...ary_transform.py_gen_triton_code_634902.py | 256 +++++++++++++ ...nsform.py_gen_triton_code_634902.py.stderr | 2 + ...nsform.py_gen_triton_code_634902.py.stdout | 0 ...ary_transform.py_gen_triton_code_669031.py | 265 +++++++++++++ ...nsform.py_gen_triton_code_669031.py.stderr | 0 ...nsform.py_gen_triton_code_669031.py.stdout | 1 + ...ary_transform.py_gen_triton_code_711258.py | 252 +++++++++++++ ...nsform.py_gen_triton_code_711258.py.stderr | 0 ...nsform.py_gen_triton_code_711258.py.stdout | 1 + ...ary_transform.py_gen_triton_code_816058.py | 289 ++++++++++++++ ...nsform.py_gen_triton_code_816058.py.stderr | 0 ...nsform.py_gen_triton_code_816058.py.stdout | 1 + ...ary_transform.py_gen_triton_code_824557.py | 268 +++++++++++++ ...nsform.py_gen_triton_code_824557.py.stderr | 0 ...nsform.py_gen_triton_code_824557.py.stdout | 1 + ...ary_transform.py_gen_triton_code_840463.py | 228 ++++++++++++ ...nsform.py_gen_triton_code_840463.py.stderr | 0 ...nsform.py_gen_triton_code_840463.py.stdout | 1 + ...ary_transform.py_gen_triton_code_843724.py | 237 ++++++++++++ ...nsform.py_gen_triton_code_843724.py.stderr | 0 ...nsform.py_gen_triton_code_843724.py.stdout | 1 + ...ary_transform.py_gen_triton_code_893238.py | 295 +++++++++++++++ ...nsform.py_gen_triton_code_893238.py.stderr | 0 ...nsform.py_gen_triton_code_893238.py.stdout | 1 + ...ary_transform.py_gen_triton_code_915460.py | 292 +++++++++++++++ ...nsform.py_gen_triton_code_915460.py.stderr | 0 ...nsform.py_gen_triton_code_915460.py.stdout | 1 + ...ary_transform.py_gen_triton_code_925133.py | 303 +++++++++++++++ ...nsform.py_gen_triton_code_925133.py.stderr | 0 ...nsform.py_gen_triton_code_925133.py.stdout | 15 + ...ary_transform.py_gen_triton_code_939610.py | 279 ++++++++++++++ ...nsform.py_gen_triton_code_939610.py.stderr | 0 ...nsform.py_gen_triton_code_939610.py.stdout | 1 + ...ary_transform.py_gen_triton_code_946209.py | 284 ++++++++++++++ ...nsform.py_gen_triton_code_946209.py.stderr | 0 ...nsform.py_gen_triton_code_946209.py.stdout | 1 + ...tary_transform.py_gen_triton_code_99563.py | 279 ++++++++++++++ ...ansform.py_gen_triton_code_99563.py.stderr | 0 ...ansform.py_gen_triton_code_99563.py.stdout | 14 + .../sin_kernel.py_gen_triton_code_123151.py | 100 +++++ ...kernel.py_gen_triton_code_123151.py.stderr | 0 ...kernel.py_gen_triton_code_123151.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_179581.py | 100 +++++ ...kernel.py_gen_triton_code_179581.py.stderr | 0 ...kernel.py_gen_triton_code_179581.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_370053.py | 101 +++++ ...kernel.py_gen_triton_code_370053.py.stderr | 0 ...kernel.py_gen_triton_code_370053.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_473025.py | 92 +++++ ...kernel.py_gen_triton_code_473025.py.stderr | 0 ...kernel.py_gen_triton_code_473025.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_502063.py | 100 +++++ ...kernel.py_gen_triton_code_502063.py.stderr | 0 ...kernel.py_gen_triton_code_502063.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_50482.py | 102 +++++ ..._kernel.py_gen_triton_code_50482.py.stderr | 0 ..._kernel.py_gen_triton_code_50482.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_557502.py | 100 +++++ ...kernel.py_gen_triton_code_557502.py.stderr | 0 ...kernel.py_gen_triton_code_557502.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_560359.py | 100 +++++ ...kernel.py_gen_triton_code_560359.py.stderr | 0 ...kernel.py_gen_triton_code_560359.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_794865.py | 95 +++++ ...kernel.py_gen_triton_code_794865.py.stderr | 0 ...kernel.py_gen_triton_code_794865.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_834634.py | 100 +++++ ...kernel.py_gen_triton_code_834634.py.stderr | 0 ...kernel.py_gen_triton_code_834634.py.stdout | 1 + .../sin_kernel.py_gen_triton_code_931009.py | 100 +++++ ...kernel.py_gen_triton_code_931009.py.stderr | 0 ...kernel.py_gen_triton_code_931009.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_108037.py | 129 +++++++ ...matmul.py_gen_triton_code_108037.py.stderr | 0 ...matmul.py_gen_triton_code_108037.py.stdout | 1 + .../triton_matmul.py_gen_triton_code_12912.py | 120 ++++++ ..._matmul.py_gen_triton_code_12912.py.stderr | 0 ..._matmul.py_gen_triton_code_12912.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_186313.py | 125 +++++++ ...matmul.py_gen_triton_code_186313.py.stderr | 0 ...matmul.py_gen_triton_code_186313.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_284744.py | 124 +++++++ ...matmul.py_gen_triton_code_284744.py.stderr | 0 ...matmul.py_gen_triton_code_284744.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_366643.py | 127 +++++++ ...matmul.py_gen_triton_code_366643.py.stderr | 0 ...matmul.py_gen_triton_code_366643.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_391924.py | 151 ++++++++ ...matmul.py_gen_triton_code_391924.py.stderr | 0 ...matmul.py_gen_triton_code_391924.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_395140.py | 120 ++++++ ...matmul.py_gen_triton_code_395140.py.stderr | 0 ...matmul.py_gen_triton_code_395140.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_417385.py | 124 +++++++ ...matmul.py_gen_triton_code_417385.py.stderr | 0 ...matmul.py_gen_triton_code_417385.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_654780.py | 122 ++++++ ...matmul.py_gen_triton_code_654780.py.stderr | 0 ...matmul.py_gen_triton_code_654780.py.stdout | 1 + ...triton_matmul.py_gen_triton_code_769893.py | 121 ++++++ ...matmul.py_gen_triton_code_769893.py.stderr | 0 ...matmul.py_gen_triton_code_769893.py.stdout | 14 + ...triton_matmul.py_gen_triton_code_993568.py | 133 +++++++ ...matmul.py_gen_triton_code_993568.py.stderr | 0 ...matmul.py_gen_triton_code_993568.py.stdout | 1 + src/temp/int4_matmul.py | 286 ++++++++++++++ src/temp/l2_norm_bwd.py | 117 ++++++ src/temp/l2_norm_triton1.py | 97 +++++ src/temp/matrix_transpose.py | 76 ++++ src/temp/matrix_vector_multip.py | 86 +++++ src/temp/rotary_transform.py | 254 +++++++++++++ src/temp/sin_kernel.py | 58 +++ src/temp/triton_matmul.py | 130 +++++++ src/utils/__pycache__/utils.cpython-312.pyc | Bin 2442 -> 2442 bytes 1100 files changed, 51644 insertions(+), 77 deletions(-) create mode 100644 src/good/flash_decode2_phi.py create mode 100644 src/good/l2_norm_bwd.py create mode 100644 src/good/l2_norm_triton1.py create mode 100644 src/good/matrix_transpose.py create mode 100644 src/good/matrix_vector_multip.py create mode 100644 src/good/rotary_transform.py create mode 100644 src/good/sin_kernel.py create mode 100644 src/good/triton_matmul.py create mode 100644 src/pass_exe/embedding_triton_kernel.py create mode 100644 src/pass_exe/flash_decode2_phi.py create mode 100644 src/pass_exe/l2_norm_bwd.py create mode 100644 src/pass_exe/l2_norm_triton1.py create mode 100644 src/pass_exe/matrix_transpose.py create mode 100644 src/pass_exe/matrix_vector_multip.py create mode 100644 src/pass_exe/rotary_transform.py create mode 100644 src/pass_exe/sin_kernel.py create mode 100644 src/pass_exe/triton_matmul.py create mode 100644 src/soso/flash_decode2_phi.py create mode 100644 src/soso/l2_norm_bwd.py create mode 100644 src/soso/l2_norm_triton1.py create mode 100644 src/soso/matrix_transpose.py create mode 100644 src/soso/matrix_vector_multip.py create mode 100644 src/soso/rotary_transform.py create mode 100644 src/soso/sin_kernel.py create mode 100644 src/soso/triton_matmul.py create mode 100644 src/temp/embedding_triton_kernel.py create mode 100644 src/temp/flash_decode2_phi.py create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_155036.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_176773.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_180807.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_18528.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_200147.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_211539.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_322972.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_347928.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_355413.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_429595.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_43398.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_459432.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_474863.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_477598.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_480728.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_490985.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_507685.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_524778.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_533885.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_552958.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_574109.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_58716.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_600998.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_605163.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_620455.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_635331.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_64602.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_68534.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_713720.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_721645.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_759146.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_764635.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_76684.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_804525.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_823958.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_830218.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_837397.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_92676.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_940390.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_965031.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_984659.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_992208.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_126106.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_14965.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_198114.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_23614.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_269764.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_335674.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_349606.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_369704.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_38100.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_405645.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_42419.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_450387.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_506478.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_543766.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_560861.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_576804.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_653084.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_661704.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_684759.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_690508.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_720655.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_721584.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_735113.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_739112.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_754689.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_802348.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_812012.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_83138.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_870175.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_882682.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_900175.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_925215.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_959027.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_124574.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_178552.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_216434.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_219875.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_243114.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_291697.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_298484.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_308542.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_312025.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_357204.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_365790.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_41463.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_430740.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_434177.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_461728.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_48845.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_490790.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_511041.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_512013.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_52090.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_530716.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_635842.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_718301.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_731602.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_76683.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_769812.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_790411.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_811684.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_815235.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_816192.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_838410.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_886215.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_891149.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_912380.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_925632.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_927195.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_93329.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_942564.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_977481.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_991002.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_995030.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_143388.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_167554.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_215639.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_220059.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_28664.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_338946.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_347725.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_387667.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_404776.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_414029.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_419949.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_433589.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_459560.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_486455.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_493519.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_570539.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_597752.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_637799.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_640557.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_712104.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_786715.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_827439.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_843690.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_864396.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_885795.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_960121.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_972847.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_212491.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_254823.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_318959.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_336206.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_357644.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_392963.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_403404.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_466457.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_598128.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_599125.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_637798.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_650964.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_674736.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_786517.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_800477.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_839169.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_846578.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_964700.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_965300.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_973282.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_114093.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_11496.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_14792.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_160821.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_205496.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_216901.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_274099.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_369711.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_412290.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_429164.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_469771.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_493615.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_529486.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_571713.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_580037.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_608628.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_619005.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_620806.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_671609.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_724790.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_738982.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_74175.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_757083.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_759138.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_780911.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_783719.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_81159.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_853096.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_869907.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_879575.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_892743.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_917011.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_930305.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_953212.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_984648.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_997014.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_164112.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_205689.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_334537.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_370413.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_424820.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_554113.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_554981.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_561330.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_686366.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_80693.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_105954.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_260701.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_329295.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_338032.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_339628.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_344391.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_373163.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_385268.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_405620.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_431864.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_44150.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_450091.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_460195.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_527413.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_540784.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_555768.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_634902.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_669031.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_711258.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_816058.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_824557.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_840463.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_843724.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_893238.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_915460.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_925133.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_939610.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_946209.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_99563.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_123151.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_179581.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_370053.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_473025.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_502063.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_50482.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_557502.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_560359.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_794865.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_834634.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_931009.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_108037.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_12912.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_186313.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_284744.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_366643.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_391924.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_395140.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_417385.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_654780.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_769893.cpython-312.pyc create mode 100644 src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_993568.cpython-312.pyc create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stdout create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stderr create mode 100644 src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stdout create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stderr create mode 100644 src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_124574.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_178552.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_216434.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_219875.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_243114.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_291697.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_298484.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_308542.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_312025.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_357204.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_365790.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_41463.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_430740.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_434177.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_461728.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_48845.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_490790.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_511041.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_512013.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_52090.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_530716.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_635842.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_718301.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_731602.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_732866.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_76683.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_769812.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_790411.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_811684.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_815235.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_816192.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_838410.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_865534.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_886215.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_891149.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_912380.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_925632.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_927195.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_93329.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_942564.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_977481.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_991002.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stdout create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_995030.py create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stderr create mode 100644 src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stdout create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stderr create mode 100644 src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stdout create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stderr create mode 100644 src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stdout create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stderr create mode 100644 src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stdout create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stderr create mode 100644 src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_105954.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_260701.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_329295.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_338032.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_339628.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_344391.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_373163.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_385268.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_405620.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_431864.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_44150.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_450091.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_460195.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_527413.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_540784.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_549779.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_555768.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_634902.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_669031.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_711258.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_816058.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_824557.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_840463.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_843724.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_893238.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_915460.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_925133.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_939610.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_946209.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stdout create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_99563.py create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stderr create mode 100644 src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_123151.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_179581.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_370053.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_473025.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_502063.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_50482.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_557502.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_560359.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_794865.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_834634.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stdout create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_931009.py create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stderr create mode 100644 src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_108037.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_12912.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_186313.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_284744.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_366643.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_391924.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_395140.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_417385.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_654780.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_769893.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stdout create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_993568.py create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stderr create mode 100644 src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stdout create mode 100644 src/temp/int4_matmul.py create mode 100644 src/temp/l2_norm_bwd.py create mode 100644 src/temp/l2_norm_triton1.py create mode 100644 src/temp/matrix_transpose.py create mode 100644 src/temp/matrix_vector_multip.py create mode 100644 src/temp/rotary_transform.py create mode 100644 src/temp/sin_kernel.py create mode 100644 src/temp/triton_matmul.py diff --git a/src/__pycache__/args_config.cpython-312.pyc b/src/__pycache__/args_config.cpython-312.pyc index ed62ea94b9178a5147963fc8bfe6e3c473ed5043..4168ffc6fe2f2436464f758d8f9bd317cf7de9a3 100644 GIT binary patch delta 20 acmZ3(vWA8GG%qg~0}w2WU$v22kqH1b*97SR delta 20 acmZ3(vWA8GG%qg~0}vF%&DqGU$OHg15d@?F diff --git a/src/agents/__pycache__/Base.cpython-312.pyc b/src/agents/__pycache__/Base.cpython-312.pyc index f0a272db34755c9ad5ac90fc0aa6f7c474ac611d..4212f025bc4b02b3b97c7b383a1b71a7e86edbbe 100644 GIT binary patch delta 20 acmeyR_DhZXG%qg~0}w2WU$v1tUl;&J;07oF delta 20 acmeyR_DhZXG%qg~0}vF%&DqGEFAM-h8U^D3 diff --git a/src/agents/__pycache__/Reflexion.cpython-312.pyc b/src/agents/__pycache__/Reflexion.cpython-312.pyc index 54cab4bd97f26a4a5622127942cb5f6a38e1bf70..665562496c745c66b8b24028a9d2c7b2752984e0 100644 GIT binary patch delta 20 acmX>sc36!2G%qg~0}w2WU$v3jf*Sxl00mnB delta 20 acmX>sc36!2G%qg~0}vF%&DqFp!3_X9It3O0 diff --git a/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc index db3379e04a5f6150b194bba0a6cbd55a6310e256..ba6c9882db0ddcf2cef5f1e5623abe9946c84655 100644 GIT binary patch delta 3707 zcmb_fYfK#16}~gGv+pN6yX=FZmN*lR-CG8a7|;ZsEU;z6_H|<*fkjfuJMYtY1BVe{}4NE6sM|s?k)?M zI#Q!{q&<7@x#xc8%y-V-d-k1{h5gnW7K;(V^Na8=55HP=#rk6lE$Q%COPh)5vLlqx zD`-HAFWY?f(}o_p5B1Pp4p18u1AfuVxg~1Ut{BeHqfQM<7(s1}`DGL4eR2^+xl z$hlejq;=Ai#er43Vr)i)o_<0PpQgu-=WH$t;W<6GPcbOkOH5`n@J5dYfGkJCAm?WR z6-O4X8Fe`UAV;wc6lPb~z6CQrWh=66CJwvG(vHx8AOAq_d<1+!6w)z9hWp2 zeMk#9&IFfmCESVJghz44^om;Jz7z~gESRDRnU&@0qQ zdW;_Z2T5&v^~Urel(b}NpwJjYe&5HTx+(Z13`f_Eict-UQAc5=6MoTPOO5^-=R3Xr z_m451=qOc;)%;+d}6lI|Q zn{Ys~_I zLN@8MFE2etriLev^@ziS419_!qx^V^JLqB}VkBkDyh<3xQK?^yVXbRtQr9~ok4R!H zrOgb|ex-JTWa{Z291O=~0Wz>Q5hW>QBeQJLTTomccroOitPHn-`O=1;4(UXEOhT-|&1`MIi=`GOG0 z3$dA=k0UoCKaHtXEw8kz1))DX150-I>!vd%wIDbn&-AL#KCk9=F4zyOw>)3O=(n5{epSh5@-XU9^d@)o{8%DvIezum%8g2wc`c#nUETKN%izBn1xSq z%GJF8@dfjDmmIEhO&4tQd5_+5)J#2l$62`K-F($}%iA=4>`p|P4 zxRJ$~33SJTxnU{SKi>ZP+%03wGOcsuFNR*c9XdQ0I($2nmDLDa2ER)RjVO2EtsFYQ`a zbJTqOMF8ILE@B+{OPufZ)-$bYplSMv=~|V4YJuCcSkgTGz3C=3um|3@^$%R$*ULIE z{pa1MDO7rh`s%KpM3{k9S73Ljj_vnj=b(s_?>qK2Xg@G(Kz^vVwK>=ibJ;eH?!&y& zeY=g7Ak&oGemF#s38!XaY&nXS&8Lw|IhVc!5 z9q50zn*iO^ux(b;O@kgFZaUaDzu{(qA;{@&J`xDpbhC9f7ACWK1Y)*`4OWrSJ8V&j+f z%5A~8q2ow@=4!j|;4As=K}qZuj;~yN@1FI$s9iakGnDjSDOM0)8tF?P)am0{B0YFY zJ1j*b!!n$?Q7O_bOO2qE4uQab%B!bD?8@Is9mMVVn`$nBGWHxqcoHp|tfw1))JRUO zeW&(~@4V3Q_JOw!yd8cktd`WP^)0IRsRi?%cTeT-q3B@>TfP4OG3ADuPsyT>d z!!yZb8}tMtq%)f4UbDUiuv(}|YO~xcR2Z_Id4N$}Fz#7pR1$j!;eIBq0=X!qqhMCf zq;$a6OM_x0%}~mez6s&R8X1(O<6t4hiP%SMt!vCn38EzzB;kEKlE~jR=>!o;B33Dv zD#;Z7JRs>L5mg|{S*d`GMI7R{QXzbrFW;(yBKA6njhRsWwNp*+wOna=ul-89THCBP zKc{*RESL}C>->AB5rT7&h!7qvdc3fcVr%ck=_1y4?`g`$c2E*2MJWeFxliiAe=E96 Xy@vl7Xmkb{s)HIg{2iqSX|sO=LLI6Z delta 2339 zcmb_eTWl0n7(Qn%JA0qK&0gE>Zns-l?nPRPupptdg%yNAQb@qHbeGaiyDc-@0AEjH+Q6rC3!zj8x+=`jpd5DkC!~PBRyle6KUE!zgE`;Pg_3QQak6oYx}F z0>M2wXHiF}-#jWj=e4Sc?5uKmQ?SgB8qgWBjp;yb%>EE;NmQ+*TQ9LF<0Xgn&efu- zyTlfORFNGoF_(jfAS zC^-Rs)NqL@f?@RB(wnv*V?dOxH)SlPc!6_)g}(?}3V=OqEWnGIAhGM69z3uILOt<)#kiPxd!J~G$K#*&sg`nV7n z@;Lb$eH8l~#`biU`qCk#yR(;ukn7UY^-iPQn++*#az{^23BkJdp58ubK3Gm4>`b@O za3v?7mu#R0j7+C_j%DjnJO&UCV^wX;)R0#>rZQ;l7u z#nVE*@zgJ5yV6~G;b3pBS57N=U2&2ZI+SEjtD?wDVTIzN%d+0rnV%VQ(HrtSG&{~D zK1huF7mchMss4f+_ts9D>!!#SUqj>4N%QiDc4?^jzCAkP`q5eOt8ejF#{0faH;ztq z9e&W&H__EM4KF<2%)p?*@CczSZe*Sy6lp3cM0Q)dlUG zHnMl*(7lS)g!fnV*U#9Iqk7uQT5Nymk;Pfu+_FjYiie)SP{W07=eCW+M(W4QR{h|q z9b~7R?vLXi#iv|lQ;`LqulsD>gUE)7$c7v29o%)7iwd z33=fpFCz0Q6 zPQBt(SvXfPc_g?5cR&sC}F?|-$S;Q*A-l}p_6N@u$+K3I*3T+0oKZE5yHf1P+q>NxUooCvy#;f$uso|VB z03FNKq%x!CITE^DIc%bE!HG0`)B&8nWSt+Ay3+g@idnW_`=dg_tSHxzvarudUtu{8 zBCnt89qYrj_#nxJ*Wlgca(JQR6L2eakmCfJGFi{WPsgtbmCYEt$&N_b%Zff*ep+Rw zSYQ7rE*+g3n__?`Tj0sfRb1-lfFxHUsp4eebAfIjw$5*lv3a3K?mp0&?vfR{XHzb% z7dOLG+7g%KaNXl=5h?xpaOQs0&VJc)Q=J*x>3jIi?9|m%c3d=H(FYk$aN^08cUo)&Kwi delta 22 ccmdnn$hfzWk^3|+FBbz46vWNh$i2w{08MNLkpKVy diff --git a/src/good/flash_decode2_phi.py b/src/good/flash_decode2_phi.py new file mode 100644 index 0000000..9fce41a --- /dev/null +++ b/src/good/flash_decode2_phi.py @@ -0,0 +1,143 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_ob, stride_oh, stride_od, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + sum_exp = 0.0 + max_logic = -float('inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + offs_d = tl.arange(0, BLOCK_DMODEL) + for block_id in range(0, block_n_size): + ptr_tv = Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_id * stride_mid_os + offs_d * stride_mid_od + tv = tl.load(ptr_tv) + ptr_tlogic = Mid_O_LogExpSum + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + block_id * stride_mid_o_es + tlogic = tl.load(ptr_tlogic) + max_prev = max_logic + max_logic = tl.maximum(max_prev, tlogic) + sum_exp = sum_exp * tl.exp(max_prev - max_logic) + tl.exp(tlogic - max_logic) + acc = acc * tl.exp(max_prev - max_logic) + tv * tl.exp(tlogic - max_logic) + result = acc / (sum_exp + 1e-06) + ptr_out = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(ptr_out, result.to(ptr_out.dtype.element_ty)) + +@torch.no_grad() +def flash_decode_stage2(Mid_O: torch.Tensor, Mid_O_LogExpSum: torch.Tensor, B_Seqlen: torch.Tensor, Out: torch.Tensor, block_seq: int): + batch, head_num, seq_blocks, BLOCK_DMODEL = Mid_O.shape + triton_grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[triton_grid](B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3), Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2), Out.stride(0), Out.stride(1), Out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + return + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/good/l2_norm_bwd.py b/src/good/l2_norm_bwd.py new file mode 100644 index 0000000..a62b863 --- /dev/null +++ b/src/good/l2_norm_bwd.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + scale_k = 1.0 / (var + eps) + rstd = tl.math.sqrt(scale_k) + dx = dy * rstd - tl.sum(dy * x, axis=0) * scale_k * rstd * x + tl.store(DX + cols, dx.to(DX.dtype.element_ty), mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05): + x_shape_og = x.shape + x = x.reshape(-1, x_shape_og[-1]) + dy = dy.reshape(-1, x_shape_og[-1]) + if x.stride(1) != 1: + x = x.contiguous() + if dy.stride(1) != 1: + dy = dy.contiguous() + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This L2-norm backward doesn't support feature dim >= 64KB.") + dx = torch.empty_like(x) + _l2_norm_bwd_kernel[M,](x, dy, dx, x.stride(0), N, eps, BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/good/l2_norm_triton1.py b/src/good/l2_norm_triton1.py new file mode 100644 index 0000000..e6a3f4d --- /dev/null +++ b/src/good/l2_norm_triton1.py @@ -0,0 +1,93 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_idx = tl.program_id(0) + row_off = row_idx * stride_x_row + col_idx = tl.arange(0, BLOCK_N) + mask = col_idx < N + x = tl.load(X + row_off + col_idx, mask=mask, other=0.0).to(tl.float32) + ssq = tl.sum(x * x) + rstd = tl.math.rsqrt(ssq + eps) + y = x * rstd + tl.store(Y + row_off + col_idx, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-06) -> torch.Tensor: + x_shape_og = x.shape + x = x.view(-1, x_shape_og[-1]).contiguous() + M, N = x.shape + y = torch.empty_like(x) + MAX_FUSED = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise NotImplementedError('N > BLOCK_N not handled in 1-pass kernel') + grid = (M,) + with torch.cuda.device(x.device): + _l2_norm_fwd_1pass_kernel[grid](x, y, x.stride(0), N, eps, BLOCK_N) + return y.view(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/good/matrix_transpose.py b/src/good/matrix_transpose.py new file mode 100644 index 0000000..cec61dd --- /dev/null +++ b/src/good/matrix_transpose.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_idx = offs_m[:, None] * D_HEAD + offs_n[None, :] + y_idx = offs_n[:, None] * SIZE_M + offs_m[None, :] + mask_i = (offs_m[:, None] < SIZE_M) & (offs_n[None, :] < D_HEAD) + mask_o = (offs_n[:, None] < D_HEAD) & (offs_m[None, :] < SIZE_M) + val = tl.load(M + x_idx, mask=mask_i, other=0.0) + tl.store(Out + y_idx, val.trans(), mask=mask_o) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + device = torch.device('cuda') + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = (triton.cdiv(SIZE_M, BLOCK_SIZE_M), triton.cdiv(D_HEAD, BLOCK_SIZE_N)) + kernel[grid](matrix, out, 1, 1, 1, 1, SIZE_M, D_HEAD, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/good/matrix_vector_multip.py b/src/good/matrix_vector_multip.py new file mode 100644 index 0000000..7ff22a1 --- /dev/null +++ b/src/good/matrix_vector_multip.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A_ptr, B_ptr, C_ptr, N, M, stride_am, stride_ak, stride_bk, stride_cn, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid = tl.program_id(0) + row_start = pid * BLOCK_N + offs_n = row_start + tl.arange(0, BLOCK_N) + col_start = 0 + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + for col_start in range(0, M, BLOCK_M): + offs_m = col_start + tl.arange(0, BLOCK_M) + a_idx = A_ptr + offs_n[:, None] * stride_am + offs_m[None, :] * stride_ak + a_mask = (offs_n[:, None] < N) & (offs_m[None, :] < M) + a_vals = tl.load(a_idx, mask=a_mask, other=0.0) + b_idx = B_ptr + offs_m * stride_bk + b_mask = offs_m < M + b_vals = tl.load(b_idx, mask=b_mask, other=0.0) + acc += tl.sum(a_vals * b_vals[None, :], axis=1) + c_idx = C_ptr + offs_n * stride_cn + c_mask = offs_n < N + tl.store(c_idx, acc, mask=c_mask) + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1, 'A must be 2D and B must be 1D' + assert A.size(1) == B.size(0), 'Inner matrix dimensions must agree' + N, M = A.shape + C = torch.empty((N,), dtype=torch.float32, device=A.device) + + def grid(meta): + return (triton.cdiv(N, meta['BLOCK_N']),) + mv_kernel[grid](A, B, C, N, M, A.stride(0), A.stride(1), B.stride(0), C.stride(0), BLOCK_N=32, BLOCK_M=32) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/good/rotary_transform.py b/src/good/rotary_transform.py new file mode 100644 index 0000000..f1e0ffc --- /dev/null +++ b/src/good/rotary_transform.py @@ -0,0 +1,196 @@ +import torch +import triton +import triton.language as tl +from typing import Union, Optional + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + if not INTERLEAVED: + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = x0_ptr + rotary_dim_half * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_ptr = SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + out0_ptr = OUT + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_ptr = out0_ptr + rotary_dim_half * stride_out_headdim + tl.store(out0_ptr, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store(out1_ptr, o1, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + else: + rk_swap = rk + (rk + 1) % 2 * 2 - 1 + rk_repeat = tl.arange(0, BLOCK_K) // 2 + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_ptr = SIN + rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + out = tl.where(rk[None, :] % 2 == 0, x0 * cos - x1 * sin, x0 * cos + x1 * sin) + out_ptr = OUT + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_ptr, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + assert x.ndim == 4, 'Expected 4-D tensor [batch, seqlen, heads, dim] for non-varlen inputs' + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, 'If cu_seqlens is provided, max_seqlen must be specified' + assert x.ndim == 3, 'Expected 3-D tensor [total_seqlen, heads, dim] for varlen inputs' + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim, 'rotary_dim must be <= headdim' + assert cos.dtype == sin.dtype and x.dtype == cos.dtype + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert int(seqlen_offsets) + seqlen <= seqlen_ro + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) if not is_varlen else output[:, rotary_dim:].copy_(x[:, rotary_dim:]) + BLOCK_K = 32 if rotary_dim <= 32 else 64 if rotary_dim <= 64 else 128 if rotary_dim <= 128 else 256 + BLOCK_M = 4 if interleaved else 8 if rotary_dim <= 64 else 4 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + + def stride_or_zero(tensor, idx, fixed=None): + return tensor.stride(idx) if fixed is None else fixed + with torch.cuda.device(x.device.index): + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, stride_or_zero(output, -4, 0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), stride_or_zero(x, -4, 0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/good/sin_kernel.py b/src/good/sin_kernel.py new file mode 100644 index 0000000..134a422 --- /dev/null +++ b/src/good/sin_kernel.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + out = tl.math.sin(x) + tl.store(output_ptr + offsets, out, mask=mask) + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int=128) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/good/triton_matmul.py b/src/good/triton_matmul.py new file mode 100644 index 0000000..df3bfad --- /dev/null +++ b/src/good/triton_matmul.py @@ -0,0 +1,87 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k * BLOCK_SIZE_K + a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < k_remaining) + b_mask = (offs_k[:, None] < k_remaining) & (offs_n[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator.to(c_ptr.type.element_ty), mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.device == b.device and a.dtype == b.dtype, 'Input tensors must be on the same device and dtype' + assert a.shape[1] == b.shape[0], 'Incompatible dimensions for matrix multiplication' + M, K = a.shape + _, N = b.shape + c = torch.empty((M, N), dtype=a.dtype, device=a.device) + block_size_m = 64 + block_size_n = 64 + block_size_k = 32 + if a.dtype == torch.float16: + num_warps = 4 + num_stages = 3 + else: + num_warps = 8 + num_stages = 2 + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=block_size_m, BLOCK_SIZE_N=block_size_n, BLOCK_SIZE_K=block_size_k) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/memories/__pycache__/Memory.cpython-312.pyc b/src/memories/__pycache__/Memory.cpython-312.pyc index 09f82b91d9ce14030c7823c89dc54caace80ecb4..a738e8ad7e106bcfeb8e41374a2ee6eb30064896 100644 GIT binary patch delta 20 acmcb?b%TrhG%qg~0}w2WU$v3jmlXg#k_CJK delta 20 acmcb?b%TrhG%qg~0}vF%&DqH9%L)KH%mp_9 diff --git a/src/models/KimiK2.py b/src/models/KimiK2.py index 9728de8..d65c1fa 100644 --- a/src/models/KimiK2.py +++ b/src/models/KimiK2.py @@ -25,6 +25,7 @@ def __init__(self, #api_key = "wisemodel-xxvqzbsnecjtoxufxodx", api_key=api_key, base_url = "https://laiyeapi.aifoundrys.com:7443/v1", + # base_url = "https://api.siliconflow.cn/v1", default_headers = headers ) diff --git a/src/models/__pycache__/Base.cpython-312.pyc b/src/models/__pycache__/Base.cpython-312.pyc index 0ad2c7349939c7e2b124bb94e94bd8b3f3172b23..838ccf56433a55800304b2734ade7026c2b0fecc 100644 GIT binary patch delta 20 acmey)@|}hIG%qg~0}w2WU$v1thY0{dj|IvA delta 20 acmey)@|}hIG%qg~0}vF%&DqGE!vp|7$pwV~ diff --git a/src/models/__pycache__/KimiK2.cpython-312.pyc b/src/models/__pycache__/KimiK2.cpython-312.pyc index 34c5a6765825d2537bf37a5e5d6d5131d020ec19..f42b878d0c91fd20d074bfa340679c41695429bf 100644 GIT binary patch delta 67 zcmZn{Y!~D`&CAQh00hh8S7jXC$lJoq_CQYc3pd;5>C6Y27!@{0ve`4Tim-Bh)R{b& U{WiZQqwxgIuM9w{NCKz}0FW*cR{#J2 delta 67 zcmZn{Y!~D`&CAQh00aebb27R&^0qLu-H=oL!pXXMI`cs$M)}Q=Z1#+-!mL~$wI= 0) & (idxs < V) + mask = mask_l[:, None] & mask_d[None, :] & mask_v + + embs = tl.load(weight_ptrs, mask=mask, other=0.0) + + out_base = ptr_out + pid_b * stride_out_b + out_ptrs = out_base + \ + offs_l[:, None] * stride_out_l + offs_d[None, :] * stride_out_d + tl.store(out_ptrs, embs, mask=mask) + +def embedding( + ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, +) -> torch.Tensor: + assert ids.dtype in (torch.int32, torch.int64) + assert weight.ndim == 2 + inferred_D = weight.shape[1] + if out.numel() == 0: + out = torch.empty((*ids.shape, inferred_D), dtype=weight.dtype, device=weight.device) + else: + assert out.shape[:-1] == ids.shape + assert out.shape[-1] == inferred_D + + B = ids.shape[0] + L = ids.shape[1] if ids.ndim == 2 else 1 + ids = ids.view(B, L) + out = out.view(B, L, inferred_D) + + D = inferred_D + V = vob_end_id - vob_start_id + assert V <= weight.shape[0] + + BLOCK_L = 64 + BLOCK_D = triton.next_power_of_2(D) + + grid = (B, triton.cdiv(L, BLOCK_L), triton.cdiv(D, BLOCK_D)) + + embedding_kernel[grid]( + ids, weight, out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + vob_start_id, + B, L, D, V, + BLOCK_L=BLOCK_L, + BLOCK_D=BLOCK_D, + ) + return out diff --git a/src/pass_exe/flash_decode2_phi.py b/src/pass_exe/flash_decode2_phi.py new file mode 100644 index 0000000..1e7431d --- /dev/null +++ b/src/pass_exe/flash_decode2_phi.py @@ -0,0 +1,91 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_olesb, + stride_mid_olesh, + stride_mid_oles, + stride_oub, + stride_ouh, + stride_oud, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_head = tl.program_id(1) + cur_batch = tl.program_id(0) + + offs_d = tl.arange(0, BLOCK_DMODEL) + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + for block_n in range(block_n_size): + tv = tl.load(Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + block_n * stride_mid_os + offs_d * stride_mid_od) + tlogic = tl.load(Mid_O_LogExpSum + cur_batch * stride_mid_olesb + + cur_head * stride_mid_olesh + block_n * stride_mid_oles) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + new_scale = tl.exp(tlogic - new_max_logic) + + acc = acc * old_scale + tv * new_scale + sum_exp = sum_exp * old_scale + new_scale + max_logic = new_max_logic + + acc = acc / sum_exp + tl.store(Out + cur_batch * stride_oub + cur_head * stride_ouh + offs_d * stride_oud, acc) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + block_seq: int +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + assert Mid_O_LogExpSum.shape[1] == head_num + + BLOCK_SEQ = block_seq + BLOCK_DMODEL = Mid_O.shape[3] + + _fwd_kernel_flash_decode_stage2[(batch, head_num)]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + return Out diff --git a/src/pass_exe/l2_norm_bwd.py b/src/pass_exe/l2_norm_bwd.py new file mode 100644 index 0000000..d88ac96 --- /dev/null +++ b/src/pass_exe/l2_norm_bwd.py @@ -0,0 +1,55 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + x_masked = tl.where(mask, x, 0.0) + var = tl.sum(x_masked * x_masked) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x_masked) * rstd / (var + eps) * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) diff --git a/src/pass_exe/l2_norm_triton1.py b/src/pass_exe/l2_norm_triton1.py new file mode 100644 index 0000000..1f14639 --- /dev/null +++ b/src/pass_exe/l2_norm_triton1.py @@ -0,0 +1,49 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + row_id = tl.program_id(0) + row_start = X + row_id * stride_x_row + row_out_start = Y + row_id * stride_x_row + + col_offsets = tl.arange(0, BLOCK_N) + mask = col_offsets < N + + x_vals = tl.load(row_start + col_offsets, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x_vals * x_vals) + rstd = 1.0 / tl.sqrt(var + eps) + + out_vals = x_vals * rstd + tl.store(row_out_start + col_offsets, out_vals, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-6): + x = x.contiguous() + shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + + element_size = x.element_size() + MAX_FUSED = 65536 // element_size + BLOCK_N = min(MAX_FUSED, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError("N too large") + + grid = lambda META: (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + + return y.view(shape) diff --git a/src/pass_exe/matrix_transpose.py b/src/pass_exe/matrix_transpose.py new file mode 100644 index 0000000..3380832 --- /dev/null +++ b/src/pass_exe/matrix_transpose.py @@ -0,0 +1,47 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = tl.arange(0, SIZE_M) + offs_n = tl.arange(0, D_HEAD) + + matrix_ptr = M + offs_n[None, :] * matrix_stridex + offs_m[:, None] * matrix_stridey + out_ptr = Out + offs_n[None, :] * out_stridey + offs_m[:, None] * out_stridex + + x = tl.load(matrix_ptr) + tl.store(out_ptr, x) + + +def wrapper(SIZE_M: int, D_HEAD: int): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + + grid = (D_HEAD, SIZE_M) + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M=SIZE_M, + D_HEAD=D_HEAD, + ) + + return out diff --git a/src/pass_exe/matrix_vector_multip.py b/src/pass_exe/matrix_vector_multip.py new file mode 100644 index 0000000..9f2680a --- /dev/null +++ b/src/pass_exe/matrix_vector_multip.py @@ -0,0 +1,47 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, N, M, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + for m_start in range(0, M, BLOCK_M): + offs_m_cur = m_start + offs_m + mask_m = offs_m_cur < M + offs_a = A + offs_n[:, None] * M + offs_m_cur[None, :] + mask_a = (offs_n[:, None] < N) & mask_m[None, :] + a_block = tl.load(offs_a, mask=mask_a, other=0.0) + offs_b = B + offs_m_cur + b_vals = tl.load(offs_b, mask=mask_m, other=0.0) + acc += tl.sum(a_block * b_vals[None, :], axis=1) + + offs_c = C + offs_n + mask_c = offs_n < N + tl.store(offs_c, acc.to(C.type.element_ty), mask=mask_c) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1, "A must be 2-D and B must be 1-D" + N, M = A.shape + assert B.shape[0] == M, "Dimension mismatch: B must have size M where A is NxM" + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 64 + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, + N, M, + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M + ) + + return C diff --git a/src/pass_exe/rotary_transform.py b/src/pass_exe/rotary_transform.py new file mode 100644 index 0000000..8baa298 --- /dev/null +++ b/src/pass_exe/rotary_transform.py @@ -0,0 +1,171 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + stride_cos_seqlen, + stride_cos_dim, + stride_sin_seqlen, + stride_sin_dim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, + ROTARY_DIM_HALF: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + seq_start = tl.load(CU_SEQLENS + pid_batch) + cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - seq_start + x_ptr = X + seq_start * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + seq_start * stride_out_seqlen + pid_head * stride_out_nheads + if pid_m * BLOCK_M >= cur_seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk_half = tl.arange(0, BLOCK_K // 2) + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + rm_cs = tl.where(rm_cs < seqlen_ro, rm_cs, seqlen_ro - 1) + if not INTERLEAVED: + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + (rk_half + ROTARY_DIM_HALF)[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + rk_half[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_dim + mask_m = rm[:, None] < cur_seqlen + mask_k_half = rk_half[None, :] < ROTARY_DIM_HALF + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_k_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rk_half + ROTARY_DIM_HALF)[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_k_half) + else: + rk_even = 2 * tl.arange(0, ROTARY_DIM_HALF) + rk_odd = 2 * tl.arange(0, ROTARY_DIM_HALF) + 1 + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_even[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_odd[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + tl.arange(0, ROTARY_DIM_HALF)[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + tl.arange(0, ROTARY_DIM_HALF)[None, :] * stride_sin_dim + mask_m = rm[:, None] < cur_seqlen + mask_half = tl.arange(0, ROTARY_DIM_HALF)[None, :] < ROTARY_DIM_HALF + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_even[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_odd[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_half) + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Apply rotary embedding to the input tensor x using Triton kernels optimized for AMD GPU ROCm.""" + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + if max_seqlen is None: + raise ValueError("max_seqlen must be provided if cu_seqlens is used") + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype == x.dtype + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 if rotary_dim <= 32 else + 64 if rotary_dim <= 64 else + 128 if rotary_dim <= 128 else 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, + 0, + output.stride(0) if not is_varlen else 0, + output.stride(-3), output.stride(-2), output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), x.stride(-2), x.stride(-1), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ROTARY_DIM_HALF=rotary_dim_half + ) + return output diff --git a/src/pass_exe/sin_kernel.py b/src/pass_exe/sin_kernel.py new file mode 100644 index 0000000..9787742 --- /dev/null +++ b/src/pass_exe/sin_kernel.py @@ -0,0 +1,25 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + kernel_function[grid]( + x, output, n_elements, + BLOCK_SIZE=1024, + ) + return output diff --git a/src/pass_exe/triton_matmul.py b/src/pass_exe/triton_matmul.py new file mode 100644 index 0000000..ade7921 --- /dev/null +++ b/src/pass_exe/triton_matmul.py @@ -0,0 +1,83 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m - first_pid_m, num_pid_m) + pid_m = first_pid_m + (pid % num_pid_m) + pid_n = (pid % num_pid_in_group) - pid_m * num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_offs = k * BLOCK_SIZE_K + mask_a = offs_k[None, :] < (K - k_offs) + mask_b = offs_k[:, None] < (K - k_offs) + + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_warps = 8 + num_stages = 2 + + grid = lambda META: (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + num_warps=num_warps, + num_stages=num_stages + ) + + return c diff --git a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc index 29e23cc138d0482512bb0b41a3380383a96f3f73..72703647a317c69b7644d4e2eb73438f7f63cd3f 100644 GIT binary patch delta 1152 zcmZWp&rcIU6t;*SR=~uAi6`HY))qp6hSYEXtb|ByBSMVElwtR^J7IQaGjEmxF{%Cu z{s~_7WIXvdc=Y0*foCtiowl~by=->o{rJA`eQ$p^zZ@<6n4b@B@bmNg=lEUe=!Abuy8c!bziHH>-?C=!}(@Ic^3yZsXJvW!#0| z++pe`t5QW}=o5h~1-aLCE5icqk<-`aWBIoNyMfRnKrv9M8JSOmqwIVREE zdewP)c+hwPn^0R{U#-FNGN8f)wd)XrDNZOEt>i!>2-RbpE=~xX0dZzz1WBfxOjRx( z5mFFTbLE=CS$WQLn$US^g(g-2%fPqIOkJ1lM49lk(wD5@ln4{bgE2TCGsHHs0Zub! zk3cBGn3akW~8+FBV_sv9F#4NVz~5gMV8LKZ&3#MlcEVTe4pu@NQ`91oUU zPGGTk%$~8}KXNwHXs-`BSoKN_3vUezx5o|}JEm>1i+0)q%7}=Ti5-FLUA71i^AH0& z)oFrDq^UKjmHGS=L8s)=otR8Vg876-M>2Qt$Sh5%^TpxVAuLZrc#rc;V{5Wr43Thx zt(y@B!CKB?D|2SxGR@A0wy+{WUK77<-t`l|)#1&Q)^9zQ#E|^A$?@uNdesN*HkwVI?{tEXhtZ;c`!)CH!NdFf fU%9$jC={+{_X@%6``cGXGp)ki*$;Q@Ex!5-ucUva delta 65 zcmZ1-_R*K`G%qg~0}vF%&B<_9o5-ilWX-qHdLhqdQNiz$a-u+)56p~=j2~DGxL6w4 RKBO@EGxD=Eau=}zWdKFQ532wG diff --git a/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc index 71813cab567d2f4251c15a54b2db2ae822f7f37f..ee66664c4402389b9e6e3dededab4f8c908cfb58 100644 GIT binary patch literal 19711 zcmeHPPj4H?6_;$qZ8k^(v`sF}Vc-BQ$&ys0I6@%g4`fSrRGYRS$w4c_phvqya;n`O zc4t=<69jNSMQ=Uz(td@0gx-4ewO>F$ZUqYT_ukAdSGMDzh0zwt3Wl{LXWqPd|KD$h z`s?ej2TS;Q>o0$vtle2!`a8WSf0q5j*YDur_e;^z(NeU0v>aVJx)fbLx*V+>twgUJ zy)wVj{(+odT3QNzZ3-bXC1fhp*;Hk@(5aYZI@jq~9OAc0g~9tWH<=Y9lL<8%=}@aQ z7q|BAi)m)YnM@K4?~AQe#`7fc2a%FQ7N~4~E=eg`-*6Pe;c2Vg;gU1+kivCE6APMYrR>>Sm*PF z-N6FrQ}dKBy-Ww&cfQ3TCaVk1d0Qysc=)u-)#vzaRzh;qAjD^8HIDjkIqJ_=UbdqK zi#!4Q^ht)zl_I`IRzSi*s7z!Us#c81l!kMPLWxTvWd?F@9^SljZ|?xMDQ;a~we2nk z(y=~8PJ=@w2t@N*elXKqC;AH)S`nr}d`#I#I{RlX;HI`8>^+^TkeDbN!GEiODMhMg zE;MpbJXXjhS)Kj9;1o*^4zft4&n~awh4B1|kX8&7;y?K1I8de9aNZT4$t;>trbfy^ zNDY-hu8jRox?*=9`$qP)+IGd3wJJGA3|YK&TV<(=g$cuA3bBGk{vYST8R1hBX z?{$$EDW(6D)EqmBZxy-e;a!eqCyrET(vcn)kRatbmbyw$b!JkICZ{sfG#rxr_AD%f zz@VBXc&A6sd4BfwMzJZr^iN+Qr!4HE<&2{1Oah}8AN}XK`hxDrQvdtenx#CID7X~a z86|Y68^uwY<2STuvPfnBvz-37pSZg2?pkLYXJybX+t`PltVw;Pxe32!vjx48G(K2bQQDcueTy>##d2z(T*t<7*&>q1syCG=9$WI%=m?COsr6NS+((XF&Er0Z zmb7+JW=<3cq6|}1kWvbbWG>O1oVr@ErCib2n`y`m%(%{Au`Xx@>$Gj3(VnwoVmY~x z?<$#xMj~^BkbCPCl`rJ>1)gIgBfDzxQrQ}>c7n7Rm84u&#A}3n>a0!`@?gucXF*0Bn)jNENo0>PEB6DP9=!+OumO~>?OzPw=642UW2V)4R78**x42z zi;WLHcz;8zt_rWMAqk+qXA~6EUaTOay2(@`Y{4!flXpcQg;Hu;ZQAEG+UG&pdHcFz z&wEQ{&YjLaBT;*PB6TJJ-*YiDR(@nLUmgWqxXR?>kzjx&Uk3q@d6ujXZ> zA$yJ0nYguo@BZG9I3EtkhTtX%!qmlpoohrXA9PVIh5{YJ+J0H|pEKin-JZd1hQh(lV>B_cvoZ|w{QT)0#E0lJe~yor^uFk3p7=4p4F`LWS5XhSGy znGXc25evaUp1^(AG^1Z`M;VjEW~xQ^1WRE3YyK zj6t)0FZ%^*1s-5wsm=>D$}Y%_1inU?7h1Ja(vGzr5aTo|$E?#Bzy+8L?eIu%)BOAk zv9IpjZD!#N)}K+ND6(gJO|hCr}6ep}Z$qSG37Cf!@1!sgg=? zxB|-B#UvP#4|Sk4QaBv~P0Byvq`+YZR8X=b29BQ@83neWA^R9fbizJ1lj7JN`4|q+)I{OH2*LFhRf}X%fK}!Uo(1h!&JZmf__O+pI{#LR~!ADB*wk1JnsXUDNr) zHbqUid#oNHB0Y3I`>u1b5bKerg70ltDZ)pn5)iE$aHj39a2X~;yvS0~^41(i`m_b3 zqU9c3T@-D@!Bs5R@utw{VX(GmhTX+5A{~${xj=;^QFn#~B_Zb#W#;oy6nZ0GDYjP3 zb}4kBQ*_6VUfl^D;<^|~s5l;x6|JIr1Ey0GZ5R1G8bY|R0D4NnD2$O4(Tt!5NZko# zGjJn67HB>_g`tjG@3n<^SKLyhI&1{yLv})sBb^~^RhGc>Q*^%<=dbbn!;H#61S(?I z(*THX(jMC&g;a|m3@#V+s^5Qv14;o`TujLph18a07(t= zsKX=;0^A_L4FXKZ5)A^}Ai#W^Z4oIO1enhPYF%lA05=FQos%{QaDxCh2ry1YfFw2u zaDxEjSgJvQ8w9vPfN>`DeL#T!@b=3_fNNI~0mlQDrY~B3;~jW>(Fs>yrVnuIjL+Km z9LQnt58ZhQE@0pQq`z_2W{e#dfA~vU%@z6Y?G<@k(IcB_uE_sCU6F4f>IS05jgsny zZFNDjxujfP)NL*)ml_ytlT^g5Hp9Jpr5ML~M@~VFx z%O|h-$IOJulQ$P*hWy4G)j)pzr=_pyPl^5G(^dNJyq)5X|LR1BCo-SlvNXOjif;(m zEL?T7tp7d4ooxI~tnzPK^{4YEKl~^$5x$}IN%k7((cdSt*YMx@<>lq&^Gk0BE03?7 zzxU1?E02TokAC|4%HvnhH^b#OS02B1{>#gH`Nu1d-#q{P@?`l(E5CaqdleJ^1BFBj8p-pzBls+l$~aHwXSEaP;|qdYSuRiP*` zFFjQuKTV-5F{d=OSRp4fJJpIy0R*%ZoP7M9y{#0Yjg52^^z`&}6by|lpe!Q;69_X_ zVe&+swTya`KRW9HNqH9wM!n79E~Sj@h6*M53WjEr|G7GBws*VE%AEL&Z}MEOhnrV1 zA01F38OV3XNtq8)IPfd|+l|WcLXW_Q61gkEMa@gN-yVO9STz gD?5Ic2JR0IOfHPVEDii0e3+aWg;*N-ia3G#0J&R?<^TWy diff --git a/src/prompts/prompt_for_generation.py b/src/prompts/prompt_for_generation.py index 6d07461..51f3e0c 100644 --- a/src/prompts/prompt_for_generation.py +++ b/src/prompts/prompt_for_generation.py @@ -39,6 +39,31 @@ * **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`. 8. **Triton Version:** Assume Triton version 3.1.0 or later. +**Performance Optimization Guidelines:** +Based on analysis of high-performance kernels, follow these patterns: + +1. **Reduction Operations (L2 Norm, Softmax):** + - Use single-pass reduction when possible + - Leverage vectorized operations (tl.sum, tl.max) instead of loops + - Calculate optimal BLOCK_SIZE based on hardware limits (MAX_FUSED = 65536 // element_size) + - Use online algorithms to avoid multiple passes + +2. **Matrix Operations (Transpose, MatMul):** + - Use block-wise operations instead of element-wise + - Optimize BLOCK_SIZE for cache locality (32, 64, 128) + - Ensure coalesced memory access patterns + - Use `.trans()` for transpose operations when possible + +3. **Memory Access Patterns:** + - Prefer vectorized loads/stores over scalar operations + - Use appropriate masks for boundary conditions + - Minimize memory transactions through data reuse + +4. **Autotuning Parameters:** + - BLOCK_SIZE: [32, 64, 128] for most operations + - num_warps: 4-8 for compute-bound, 8-16 for memory-bound + - num_stages: 2 for GEMM, 1 for memory-bound kernels + **FINAL VERIFICATION:** Before completing, verify: 1. ALL functions defined in the code have EXACT signatures matching the required function signatures above. diff --git a/src/prompts/prompt_for_reflection.py b/src/prompts/prompt_for_reflection.py index fe3f936..f185670 100644 --- a/src/prompts/prompt_for_reflection.py +++ b/src/prompts/prompt_for_reflection.py @@ -103,30 +103,67 @@ - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" +**Performance Analysis Framework:** +Analyze the code against these performance patterns based on empirical evidence from high vs low performance kernels: + +1. **Algorithmic Efficiency:** + - Does it use optimal algorithms (online softmax vs naive)? + - Are reduction operations single-pass (avoid loops with tl.sum/tl.max)? + - Is there unnecessary data movement (element-wise vs block operations)? + +2. **Memory Efficiency:** + - Are memory access patterns coalesced (vectorized loads/stores)? + - Is shared memory utilized effectively (block sizes 32, 64, 128)? + - Are there redundant loads/stores (multiple passes vs single-pass)? + +3. **Compute Efficiency:** + - Are vectorized operations used instead of scalar loops? + - Is the compute-to-memory ratio optimized (MAX_FUSED = 65536 // element_size)? + - Are warp-level operations utilized (tl.dot, .trans())? + +4. **Specific Kernel Patterns:** + - **Reduction kernels:** Use single-pass with tl.sum/tl.max, avoid loops + - **Transpose kernels:** Use .trans() for block operations, not element-wise + - **MatMul kernels:** Optimize BLOCK_M/N/K for cache locality (64, 128) + - **Flash kernels:** Use online algorithms, vectorized exp operations + +5. **Parameter Optimization:** + - Are BLOCK_SIZE parameters optimal (32, 64, 128 proven effective)? + - Are num_warps/num_stages tuned (4-8 for compute, 8-16 for memory)? + - Is grid configuration efficient (avoid complex GROUP_SIZE_M)? + +Common Performance Anti-patterns to Avoid: +1. Two-pass reduction when single-pass is possible +2. Element-wise operations instead of block operations +3. Small BLOCK_SIZE (16) when larger (32, 64) is better +4. Complex grid configurations when simple ones suffice +5. Scalar loops instead of vectorized operations + Maximize performance by exploring the following: i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. + Primary Autotuning Fields (Mandatory) 1. BLOCK_M, BLOCK_N, BLOCK_K * Tile sizes for GEMM or other tensor contractions. * Larger blocks improve compute density, but reduce grid-level parallelism. - * Explore wide range of values like: - * BLOCK: [32, ..., 128, ..., 2048, ...] + * Explore proven effective ranges: [32, 64, 128] based on empirical evidence * Adjust based on memory reuse and L2 cache locality. 2. num_stages=n * Controls pipeline depth for kernel execution. * Rules for setting this: - * 1 if no GEMM. - * 2 if a single GEMM (e.g., GEMM + ReLU). - * 1 if two GEMMs are fused (e.g., Flash Attention). + * 1 if no GEMM or memory-bound + * 2 if a single GEMM (e.g., GEMM + ReLU) + * 1 if two GEMMs are fused (e.g., Flash Attention) * Optimize for latency and execution overlap. 3. num_warps * Controls number of warps (groups of 64 threads) to launch per block. * If it is too low then underutilization -> kernel runs slow. * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. - * You must choose a sweet spot by trying out integer range of 1 to 16. + * You must choose a sweet spot by trying out integer range of 4 to 8 for compute-bound, 8 to 16 for memory-bound. * You MUST NOT try the range beyond 16, it is NOT VALID. + Examples of Autotuning Setup Here's how Triton kernels should be decorated to allow autotuning: * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. @@ -173,30 +210,67 @@ def grid(args: dict[str, Any]) -> tuple[int]: - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" +**Performance Analysis Framework:** +Analyze the code against these performance patterns based on empirical evidence from high vs low performance kernels: + +1. **Algorithmic Efficiency:** + - Does it use optimal algorithms (online softmax vs naive)? + - Are reduction operations single-pass (avoid loops with tl.sum/tl.max)? + - Is there unnecessary data movement (element-wise vs block operations)? + +2. **Memory Efficiency:** + - Are memory access patterns coalesced (vectorized loads/stores)? + - Is shared memory utilized effectively (block sizes 32, 64, 128)? + - Are there redundant loads/stores (multiple passes vs single-pass)? + +3. **Compute Efficiency:** + - Are vectorized operations used instead of scalar loops? + - Is the compute-to-memory ratio optimized (MAX_FUSED = 65536 // element_size)? + - Are warp-level operations utilized (tl.dot, .trans())? + +4. **Specific Kernel Patterns:** + - **Reduction kernels:** Use single-pass with tl.sum/tl.max, avoid loops + - **Transpose kernels:** Use .trans() for block operations, not element-wise + - **MatMul kernels:** Optimize BLOCK_M/N/K for cache locality (64, 128) + - **Flash kernels:** Use online algorithms, vectorized exp operations + +5. **Parameter Optimization:** + - Are BLOCK_SIZE parameters optimal (32, 64, 128 proven effective)? + - Are num_warps/num_stages tuned (4-8 for compute, 8-16 for memory)? + - Is grid configuration efficient (avoid complex GROUP_SIZE_M)? + +Common Performance Anti-patterns to Avoid: +1. Two-pass reduction when single-pass is possible +2. Element-wise operations instead of block operations +3. Small BLOCK_SIZE (16) when larger (32, 64) is better +4. Complex grid configurations when simple ones suffice +5. Scalar loops instead of vectorized operations + Maximize performance by exploring the following: i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. + Primary Autotuning Fields (Mandatory) 1. BLOCK_M, BLOCK_N, BLOCK_K * Tile sizes for GEMM or other tensor contractions. * Larger blocks improve compute density, but reduce grid-level parallelism. - * Explore wide range of values like: - * BLOCK: [32, ..., 128, ..., 2048, ...] + * Explore proven effective ranges: [32, 64, 128] based on empirical evidence * Adjust based on memory reuse and L2 cache locality. 2. num_stages=n * Controls pipeline depth for kernel execution. * Rules for setting this: - * 1 if no GEMM. - * 2 if a single GEMM (e.g., GEMM + ReLU). - * 1 if two GEMMs are fused (e.g., Flash Attention). + * 1 if no GEMM or memory-bound + * 2 if a single GEMM (e.g., GEMM + ReLU) + * 1 if two GEMMs are fused (e.g., Flash Attention) * Optimize for latency and execution overlap. 3. num_warps * Controls number of warps (groups of 64 threads) to launch per block. * If it is too low then underutilization -> kernel runs slow. * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. - * You must choose a sweet spot by trying out integer range of 1 to 16. + * You must choose a sweet spot by trying out integer range of 4 to 8 for compute-bound, 8 to 16 for memory-bound. * You MUST NOT try the range beyond 16, it is NOT VALID. + Examples of Autotuning Setup Here's how Triton kernels should be decorated to allow autotuning: * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. @@ -242,30 +316,67 @@ def grid(args: dict[str, Any]) -> tuple[int]: - generate the reflection wrapped in a code block with the tag `reflection`, e.g. "```markdown```" +**Performance Analysis Framework:** +Analyze the code against these performance patterns based on empirical evidence from high vs low performance kernels: + +1. **Algorithmic Efficiency:** + - Does it use optimal algorithms (online softmax vs naive)? + - Are reduction operations single-pass (avoid loops with tl.sum/tl.max)? + - Is there unnecessary data movement (element-wise vs block operations)? + +2. **Memory Efficiency:** + - Are memory access patterns coalesced (vectorized loads/stores)? + - Is shared memory utilized effectively (block sizes 32, 64, 128)? + - Are there redundant loads/stores (multiple passes vs single-pass)? + +3. **Compute Efficiency:** + - Are vectorized operations used instead of scalar loops? + - Is the compute-to-memory ratio optimized (MAX_FUSED = 65536 // element_size)? + - Are warp-level operations utilized (tl.dot, .trans())? + +4. **Specific Kernel Patterns:** + - **Reduction kernels:** Use single-pass with tl.sum/tl.max, avoid loops + - **Transpose kernels:** Use .trans() for block operations, not element-wise + - **MatMul kernels:** Optimize BLOCK_M/N/K for cache locality (64, 128) + - **Flash kernels:** Use online algorithms, vectorized exp operations + +5. **Parameter Optimization:** + - Are BLOCK_SIZE parameters optimal (32, 64, 128 proven effective)? + - Are num_warps/num_stages tuned (4-8 for compute, 8-16 for memory)? + - Is grid configuration efficient (avoid complex GROUP_SIZE_M)? + +Common Performance Anti-patterns to Avoid: +1. Two-pass reduction when single-pass is possible +2. Element-wise operations instead of block operations +3. Small BLOCK_SIZE (16) when larger (32, 64) is better +4. Complex grid configurations when simple ones suffice +5. Scalar loops instead of vectorized operations + Maximize performance by exploring the following: i. Autotuning key parameters BLOCK_SIZE, num_stages, num_warps. ii. Better algorithmic implementation (e.g., naive softmax vs online softmax vs fused softmax), better memory access patterns and numerical stability. iii. exploring all possible operator fusion strategies within the kernel while adhering to resource constraints. + Primary Autotuning Fields (Mandatory) 1. BLOCK_M, BLOCK_N, BLOCK_K * Tile sizes for GEMM or other tensor contractions. * Larger blocks improve compute density, but reduce grid-level parallelism. - * Explore wide range of values like: - * BLOCK: [32, ..., 128, ..., 2048, ...] + * Explore proven effective ranges: [32, 64, 128] based on empirical evidence * Adjust based on memory reuse and L2 cache locality. 2. num_stages=n * Controls pipeline depth for kernel execution. * Rules for setting this: - * 1 if no GEMM. - * 2 if a single GEMM (e.g., GEMM + ReLU). - * 1 if two GEMMs are fused (e.g., Flash Attention). + * 1 if no GEMM or memory-bound + * 2 if a single GEMM (e.g., GEMM + ReLU) + * 1 if two GEMMs are fused (e.g., Flash Attention) * Optimize for latency and execution overlap. 3. num_warps * Controls number of warps (groups of 64 threads) to launch per block. * If it is too low then underutilization -> kernel runs slow. * If it is too high then register spill happens and shared memory is overused -> kernel runs slow. - * You must choose a sweet spot by trying out integer range of 1 to 16. + * You must choose a sweet spot by trying out integer range of 4 to 8 for compute-bound, 8 to 16 for memory-bound. * You MUST NOT try the range beyond 16, it is NOT VALID. + Examples of Autotuning Setup Here's how Triton kernels should be decorated to allow autotuning: * key argument indicates the variables that change and trigger autotune to re-run. This is a must argument and you must not miss this. diff --git a/src/retrievers/__pycache__/retriever.cpython-312.pyc b/src/retrievers/__pycache__/retriever.cpython-312.pyc index dc463e336802936d3187aecc27bdae9c5fdf2a5e..c04971bccd2cc1a87e4005926688051e28408f5f 100644 GIT binary patch delta 20 acmbOwHA{;7G%qg~0}w2WU$v2&pBDf&Dg^2P delta 20 acmbOwHA{;7G%qg~0}vF%&DqG!&kF!FWCW!E diff --git a/src/soso/flash_decode2_phi.py b/src/soso/flash_decode2_phi.py new file mode 100644 index 0000000..e0157ed --- /dev/null +++ b/src/soso/flash_decode2_phi.py @@ -0,0 +1,145 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_lse_b, stride_mid_lse_h, stride_mid_lse_s, stride_out_b, stride_out_h, stride_out_d, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = tl.cdiv(cur_seq_len, BLOCK_SEQ) + offsets_d = tl.arange(0, BLOCK_DMODEL) + sum_exp = 0.0 + max_logic = float('-inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for block_id in range(0, block_n_size): + offs_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_id * stride_mid_os + offsets_d * stride_mid_od + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + offs_lse = cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_s + tlogic = tl.load(Mid_O_LogExpSum + offs_lse).to(tl.float32) + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + exp_logic = tl.exp(tlogic - new_max) + acc += tv * exp_logic + sum_exp += exp_logic + max_logic = new_max + offs_out = cur_batch * stride_out_b + cur_head * stride_out_h + offsets_d * stride_out_d + tl.store(Out + offs_out, (acc / sum_exp).to(Out.dtype.element_ty)) + +def flash_decode_stage2(Mid_O: torch.Tensor, Mid_O_LogExpSum: torch.Tensor, B_Seqlen: torch.Tensor, Out: torch.Tensor, BLOCK_SEQ: int): + batch, num_heads = (Out.shape[0], Out.shape[1]) + BLOCK_DMODEL = Out.shape[-1] + grid = (batch, num_heads) + _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3), Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2), Out.stride(0), Out.stride(1), Out.stride(2), BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2) + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/soso/l2_norm_bwd.py b/src/soso/l2_norm_bwd.py new file mode 100644 index 0000000..9b96b6e --- /dev/null +++ b/src/soso/l2_norm_bwd.py @@ -0,0 +1,112 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + mask = cols < N + dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32) + dy = tl.where(cols < N, dy, 0.0) + gy = tl.sum(dy * x) + dx = dy * rstd - gy * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float=1e-05) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + M, N = x.shape + dx = torch.empty_like(x) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.") + _l2_norm_bwd_kernel[M,](x, dy, dx, x.stride(0), N, eps, BLOCK_N=BLOCK_N) + return dx.reshape(x_shape_og) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/soso/l2_norm_triton1.py b/src/soso/l2_norm_triton1.py new file mode 100644 index 0000000..93b6007 --- /dev/null +++ b/src/soso/l2_norm_triton1.py @@ -0,0 +1,100 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + row_start = pid_m * stride_x_row + _sum = tl.zeros([BLOCK_N], dtype=tl.float32) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N) + mask = cols < N + x_ptrs = X + row_start + cols + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + _sum += x_vals * x_vals + var = tl.sum(_sum, axis=0) + rstd = tl.math.rsqrt(var + eps) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N) + mask = cols < N + x_ptrs = X + row_start + cols + y_ptrs = Y + row_start + cols + x_vals = tl.load(x_ptrs, mask=mask, other=0.0) + y_vals = x_vals * rstd + tl.store(y_ptrs, y_vals, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float=1e-06): + x = x.contiguous() + shape = x.shape + x = x.view(-1, shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + BLOCK_N = min(triton.next_power_of_2(N), 1 << 16) + assert N <= BLOCK_N, 'Feature dimension N must not exceed BLOCK_N (64KB limit)' + _l2_norm_fwd_1pass_kernel[M,](x, y, stride_x_row=x.stride(0), N=N, eps=eps, BLOCK_N=BLOCK_N) + return y.view(*shape) + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/soso/matrix_transpose.py b/src/soso/matrix_transpose.py new file mode 100644 index 0000000..974f7f2 --- /dev/null +++ b/src/soso/matrix_transpose.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + mask = (offs_m < SIZE_M) & (offs_n < D_HEAD) + m_ptrs = M + offs_m * matrix_stridex + offs_n * matrix_stridey + out_ptrs = Out + offs_n * out_stridex + offs_m * out_stridey + m_val = tl.load(m_ptrs, mask=mask) + tl.store(out_ptrs, m_val, mask=mask) + +def wrapper(size_m: int, d_head: int): + SIZE_M = size_m + D_HEAD = d_head + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + BLOCK_M = 16 + BLOCK_N = 16 + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_N)) + kernel[grid](matrix, out, matrix.stride(0), matrix.stride(1), out.stride(0), out.stride(1), SIZE_M, D_HEAD, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/soso/matrix_vector_multip.py b/src/soso/matrix_vector_multip.py new file mode 100644 index 0000000..bf6fdcb --- /dev/null +++ b/src/soso/matrix_vector_multip.py @@ -0,0 +1,74 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = tl.arange(0, BLOCK_N) + acc = tl.zeros([BLOCK_M], dtype=tl.float32) + for k in range(0, N, BLOCK_N): + rn_k = k + rn + mask_a = (rm[:, None] < M) & (rn_k[None, :] < N) + mask_b = rn_k < N + a_ptrs = A + (rm[:, None] * stride_am + rn_k[None, :] * stride_an) + b_ptrs = B + rn_k * stride_b + a_block = tl.load(a_ptrs, mask=mask_a, other=0.0).to(tl.float32) + b_block = tl.load(b_ptrs, mask=mask_b, other=0.0).to(tl.float32) + acc += tl.sum(a_block * b_block[None, :], axis=1) + mask_c = rm < M + c_ptrs = C + rm * stride_c + tl.store(c_ptrs, acc, mask=mask_c) + +def mv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2 + assert b.dim() == 1 + assert a.size(1) == b.size(0) + M, N = a.shape + C = torch.empty(M, dtype=a.dtype, device=a.device) + BLOCK_M = 64 + BLOCK_N = 64 + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']),) + mv_kernel[grid](a, b, C, M, N, a.stride(0), a.stride(1), b.stride(0), C.stride(0), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + return C + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/soso/rotary_transform.py b/src/soso/rotary_transform.py new file mode 100644 index 0000000..4e56347 --- /dev/null +++ b/src/soso/rotary_transform.py @@ -0,0 +1,194 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def rotary_kernel(OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + if not IS_VARLEN: + current_batch_offset = pid_batch * stride_x_batch + pid_head * stride_x_nheads + X_ptr = X + current_batch_offset + OUT_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + seq_len = seqlen + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + if pid_m * BLOCK_M >= seq_len: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = tl.arange(0, BLOCK_K) + rk_half = rk % (rotary_dim // 2) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rotary_half = rotary_dim // 2 + if not INTERLEAVED: + k0 = rk_half + k1 = k0 + rotary_half + mask_m = rm < seq_len + mask_m_cs = rm_cs < seqlen_ro + offset0 = rm[:, None] * stride_x_seqlen + k0[None, :] * stride_x_headdim + x0 = tl.load(X_ptr + offset0, mask=mask_m[:, None] & (k0[None, :] < rotary_half)).to(tl.float32) + cos0 = tl.load(COS + rm_cs[:, None] * rotary_half + k0[None, :], mask=mask_m_cs[:, None] & (k0[None, :] < rotary_half), other=1.0).to(tl.float32) + sin0 = tl.load(SIN + rm_cs[:, None] * rotary_half + k0[None, :], mask=mask_m_cs[:, None] & (k0[None, :] < rotary_half), other=0.0).to(tl.float32) + offset1 = rm[:, None] * stride_x_seqlen + k1[None, :] * stride_x_headdim + x1 = tl.load(X_ptr + offset1, mask=mask_m[:, None] & (k1[None, :] < rotary_dim)).to(tl.float32) + if CONJUGATE: + sin0 = -sin0 + o0 = x0 * cos0 - x1 * sin0 + o1 = x0 * sin0 + x1 * cos0 + tl.store(OUT_ptr + offset0, o0, mask=mask_m[:, None] & (k0[None, :] < rotary_half)) + tl.store(OUT_ptr + offset1, o1, mask=mask_m[:, None] & (k1[None, :] < rotary_dim)) + else: + rk_half = rk // 2 + mask_m = rm < seq_len + mask_m_cs = rm_cs < seqlen_ro + x_offsets = rm[:, None] * stride_x_seqlen + rk[None, :] * stride_out_headdim + cos_sin_offsets = rm_cs[:, None] * rotary_half + rk_half[None, :] + x = tl.load(X_ptr + x_offsets, mask=mask_m[:, None] & (rk[None, :] < rotary_dim)).to(tl.float32) + cos = tl.load(COS + cos_sin_offsets, mask=mask_m_cs[:, None] & (rk_half[None, :] < rotary_half), other=1.0).to(tl.float32) + sin = tl.load(SIN + cos_sin_offsets, mask=mask_m_cs[:, None] & (rk_half[None, :] < rotary_half), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0 = tl.where(rk[None, :] % 2 == 0, x, 0) + x1 = tl.where(rk[None, :] % 2 == 1, x, 0) + out = x0 * cos + x1 * sin + tl.store(OUT_ptr + x_offsets, out, mask=mask_m[:, None] & (rk[None, :] < rotary_dim)) +from typing import Union, Optional + +def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor]=0, cu_seqlens: Optional[torch.Tensor]=None, max_seqlen: Optional[int]=None, interleaved: bool=False, inplace: bool=False, conjugate: bool=False) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + if not isinstance(seqlen_offsets, torch.Tensor): + assert isinstance(seqlen_offsets, int) and seqlen_offsets + seqlen <= seqlen_ro + else: + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.to(torch.int32) + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + seqlen_offsets = seqlen_offsets.contiguous() + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and (not inplace): + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + BLOCK_K = max(32, triton.next_power_of_2(rotary_dim)) + BLOCK_M = 4 if interleaved else 8 if rotary_dim <= 64 else 4 + grid = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + with torch.cuda.device(x.device.type): + rotary_kernel[grid](output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3), output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0, x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K=BLOCK_K, IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), IS_VARLEN=is_varlen, INTERLEAVED=interleaved, CONJUGATE=conjugate, BLOCK_M=BLOCK_M) + return output + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/soso/sin_kernel.py b/src/soso/sin_kernel.py new file mode 100644 index 0000000..020ee44 --- /dev/null +++ b/src/soso/sin_kernel.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int=64): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/soso/triton_matmul.py b/src/soso/triton_matmul.py new file mode 100644 index 0000000..a3aa0a9 --- /dev/null +++ b/src/soso/triton_matmul.py @@ -0,0 +1,99 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + mask_a = (offs_am[:, None] < M) & ((BLOCK_SIZE_K * k + offs_k)[None, :] < K) + mask_b = ((BLOCK_SIZE_K * k + offs_k)[:, None] < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=mask) + +def matmul(a: torch.Tensor, b: torch.Tensor, activation=None): + assert a.dtype == b.dtype + assert a.dim() == 2 and b.dim() == 2, 'only 2-D tensors supported' + M, K = a.shape + K2, N = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + if a.dtype == torch.float16: + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + num_stages = 2 + num_warps = 4 + elif a.dtype == torch.float32: + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + num_stages = 4 + num_warps = 4 + else: + raise RuntimeError('Unsupported dtype for AMD Triton matmul') + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=8, num_stages=num_stages, num_warps=num_warps) + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/embedding_triton_kernel.py b/src/temp/embedding_triton_kernel.py new file mode 100644 index 0000000..555debb --- /dev/null +++ b/src/temp/embedding_triton_kernel.py @@ -0,0 +1,144 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weight, out, seq_idx, stride_wm, stride_wd, stride_om, stride_od, stride_s, + total_tokens, d_model, seq_len, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr): + pid_m = tl.program_id(0) # sequence index within batch + pid_n = tl.program_id(1) # token index within sequence (BLOCK_N stride) + pid_d = tl.program_id(2) # feature dimension (BLOCK_DMODEL stride) + + # global sequence offset + offs_seq_m = pid_m * seq_len + offs_seq_n0 = pid_n * BLOCK_N + + # collect BLOCK_N embeddings per step + for nstart in range(0, BLOCK_N, BLOCK_NN): + offs_n = nstart + tl.arange(0, BLOCK_NN) # [BLOCK_NN] + mask_n = offs_n < BLOCK_N # [BLOCK_NN] + global_n = offs_seq_n0 + offs_n # [BLOCK_NN] + mask_seq = global_n < seq_len # [BLOCK_NN] + + # read token ids (int32) + offs_ids = seq_idx + offs_seq_m + global_n # [BLOCK_NN] + token_ids = tl.load(offs_ids, mask=mask_n & mask_seq, other=-1) # [BLOCK_NN] + + # compute offsets in weight tensor + # flatten token ids to compute global offsets + offs_weight = token_ids[:, None] * stride_wm + (pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL))[None, :] * stride_wd + # load BLOCK_NN * BLOCK_DMODEL elements + local_weight = tl.load(weight + offs_weight, mask=(token_ids[:, None] >= 0) & (pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL))[None, :] < d_model, other=0.0) # [BLOCK_NN, BLOCK_DMODEL] + + # store to output + offs_out = (offs_seq_m + global_n)[:, None] * stride_om + (pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL))[None, :] * stride_od + tl.store(out + offs_out, local_weight, mask=(global_n[:, None] < seq_len) & ((pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL))[None, :] < d_model)) + +def embedding(weight: torch.Tensor, out: torch.Tensor, seq_idx: torch.Tensor): + assert weight.ndim == 2, "weight must be 2D: [num_embeddings, embedding_dim]" + assert seq_idx.ndim == 2, "seq_idx must be 2D: [batch_size, seq_len]" + assert out.ndim == 3, "out must be 3D: [batch_size, seq_len, embedding_dim]" + assert weight.dtype == out.dtype, "dtype mismatch between weight and out" + num_embeddings, d_model = weight.shape + batch_size, seq_len = seq_idx.shape + assert out.shape == (batch_size, seq_len, d_model), "out shape mismatch" + assert seq_idx.dtype == torch.int64 or seq_idx.dtype == torch.int32, "seq_idx must be long/int32" + + stride_wm = weight.stride(0) + stride_wd = weight.stride(1) + stride_om = out.stride(0) + stride_od = out.stride(2) + total_tokens = batch_size * seq_len + + BLOCK_DMODEL = triton.next_power_of_2(d_model) + BLOCK_N = 64 + BLOCK_NN = 8 + + grid = ( + batch_size, + triton.cdiv(seq_len, BLOCK_N), + triton.cdiv(BLOCK_DMODEL, BLOCK_DMODEL), + ) + + embedding_kernel[grid]( + weight, + out, + seq_idx, + stride_wm, + stride_wd, + stride_om, + stride_od, + seq_idx.stride(0), + total_tokens, + d_model, + seq_len, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + ) + return out + +################################################################################################################################################## + + + +import torch + +def test_embedding(): + # 参数定义 + vocab_size = 1000 # 词汇表大小 + embedding_dim = 512 # 嵌入维度 + sequence_length = 128 # 输入序列长度 + vob_start_id = 10 # 词汇表起始 ID + vob_end_id = 1000 # 词汇表结束 ID + + # 创建测试输入张量 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + + # 调用嵌入函数 + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + # 保存结果 + results = {} + results['test_case_1'] = out.clone() + + # 测试不同的输入 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_2'] = out.clone() + + # 测试不同的词汇表范围 + vob_start_id = 0 + vob_end_id = 500 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_3'] = out.clone() + + # 测试不同的嵌入维度 + embedding_dim = 256 + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_4'] = out.clone() + + return results + +result_gold = test_embedding() diff --git a/src/temp/flash_decode2_phi.py b/src/temp/flash_decode2_phi.py new file mode 100644 index 0000000..1e798b3 --- /dev/null +++ b/src/temp/flash_decode2_phi.py @@ -0,0 +1,141 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_leb, + stride_mid_leh, + stride_mid_les, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + for block_id in range(0, block_n_size): + tv = tl.load(Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od) + tlogic = tl.load(Mid_O_LogExpSum + cur_batch * stride_mid_leb + + cur_head * stride_mid_leh + + block_id * stride_mid_les) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp *= scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + out_ptrs = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offs_d * stride_out_d + tl.store(out_ptrs, acc.to(Out.dtype.element_ty)) + +def flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, BLOCK_SEQ: int, BLOCK_DMODEL: int): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ, + BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + return Out + +################################################################################################################################################## + + + +import torch + +# Define the test function +def test_flash_decode_stage2(): + # Define the parameters for different test cases + batch_size = 2 + head_num = 4 + seq_block_num = 3 + head_dim = 64 + block_seq = 16 + + test_cases = { + "test_case_1": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + }, + "test_case_2": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + 1 # Different block size + }, + "test_case_3": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq // 2 # Different block size + }, + "test_case_4": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq * 2 # Different block size + } + } + + # Execute the function for all test cases + results = {} + for key, test_case in test_cases.items(): + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + results[key] = test_case["Out"] + + return results + +# Run the test +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_155036.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_155036.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5464ade1edc9bd33256605098978ba9e7fa71eb GIT binary patch literal 5348 zcmdTHTWs6b^-`oLN+k8L9+qP}jkBSRnxt-QCu`EQiPJPm6UW)Gwm@BXf-KUOWLa`Z z+If^3*rLNMtbmmLMCl5|Ee6zC0>s_V(xE`JVZes%$BYX^;5tA+^HKjPwXp&F+PTy- znmq6Mu?z6tbIp)4;z_M>RbZVa5W0#4BoIk7RvC#g z!kk-X$T5mQB9Sol)I=+zU=gWAO-mC=AUDBN?eUYeV10+WS=BIR6KD}7z}nU(*fzm4 zqFu1V=h%eK3dAHQaPLrKj=GopJ7|oxAd#)F%UvQu&NaN~s9`%ca_)sQ;ZcLB0;2|3 zfsUU&cl^bXsI$p9yfq|z>m2k~kpkOUX$eaWYK#JLvr5l6x8StYIbEtt)P zvjz!-%HOQ6c^Qxyx9V2uX~1MWO>5CL{>BKBl*--2UO+pJ^28h4L8?RTXp(EcBKOFO zUe&8o=88Vmsgl@c)oa?~Z_>n6)@kllr1~+R>bphOG$z<)jtZ>QQSCynzGH!=)tb+$ zDy;@o`}8=h7Em47+RJbz7+2eYOrRY{rc(_-0HGgP1T<;s0$r-r^w5p57+=Af{sx;Y zXy-4X1}pxyX!Nw`!j^SD%7d+K>{6x+cG_0i=|}lg24l@t+F7`*$-*|jn`%+FnX+A~ zuSIVg8^NAx^bim|@AgsX5(-hH5)Q{T9PuH-pk^ZSbjWU4l%(O9mC{pEWF{OJ3_2o3 zQd6Q~tqfE$9T9*jZCK+e<-m|(y$};6Aj?Wx5<}gFT}e-iDUg&6`hpmriYW$}o>L4) zR;0Kfh9~M!vJTDHA;GX$`O*^&Fllh}>4~tcL?i{<*g8$0Qv65^N_j?6Ywp zJP}c%F~f#IAwFZ!(R5M{gDcoY7;P0r$tnRxlK_6FCnseQFCiNw?i}-xBq+gTgRc^+ zPEkMthAI*~l)GX59h5;C-Qi&@>lpmWYxVxd+7_OJ2#V^DGaYVw`(U$yf3fn{5O6&@lpJP_%gA4c=hq4 zI)79Jnnja$5^2m6;FzvKJm`8%B-_v^vo1^x?{CwC!xp~!XQ-_W_8VBdR0 zVW?oc^4Q{mrEmV6SbE}vBf4i`dFFk=?uDKz26AWSJu8%0H=dJw?vF*`j~3 zQFh0}-N-k%4o8Kg4%~Thhr>zg!JQ8}gO)^PPD({_l-C+g3^@MRe-D51RivUsjnXu# zL_2KMa?vuq3zWN6svRBEWb0}4HbEfVgKnHHTAETXJBc=h%hd7MmXd7}uqkOV*KL2h zAxzs{U^_1GEln^zlPmJ5T`3v|UozGV4SUn(j$~4883wesp{iDuR+&k%VWZ-z8PM)# zq%!rwXBWtfL!dIOYHgCiN8Zv@2FhJ^zb7pjPVv_mRL(4INX3@|l<2JC8>3*Yt+xrC zuG5Jv0|u4C#lR}ZB0#4rsnu|$#EVLJHhn>q z!s*E{a7BgqyqP-AZl{N{fpam*k_4B}OTIBwRThFu0V1@%tm!GK31Mye9xAfF+>z{& z_r>ee*QW2Wp{3Vuhi`;`WBo1rEB4;rGb_}Y74}S#_2iz-K3nL!$96B0A2A;=OT!=X zE7aZ9xU<18{P0QZ@o z&IAjdRi?Xj=FIU~J|9_Sg23GbvpGlBk>B?t9yBySUyXfmA= z;hh6-h-d`f6!FWV!Pm2~`Cf3z;`uo-6&1rtF*T*c8dovH9#75AnGb}@CC$8=*(6b( zODeK#ht=!JzLMyjrjK>~(ks9cgHOI*$@UED2xzBEzE16IDbS^zD|!9eNXg^ZPH%9? z6TtWN664D!7LTm#9QyRd&(3^uX64(jZdfS~t-S;+PH*n@?CT|_BY*am|90<<-aGR^ zbkf=hAhPb5LsG#VZ(>4&V_g(u-LP7`Oe_y)DgJox;)y+ zQtyuIuU&iXC$DQK7tTWN{^zV{(bcL1bo)%VeA?*;ZvU1}ZLQ89)XqNOJN`Z2xmMr* z&Rzy_UTxGYC{X=zh)K9a8?MT&;i)7%iRa+uW-v8K!e?)giMV3eqv@2ah!NVnJ zF(s!Zcu_)t9>cnD5EyJFYQj_Lq;Lb7YgIyqIXsgV=91#`()WN5!UV~;p(~RFL41k0 zKO+85sHa5Oez@(eZDmV8;V1_Yyqdf9H>6AuJ2=LJ2=vSOg2pZy+S0t8PIgQC;JiB>qolCJ9N$ z>QJ!)s}FIcm15MDNTo_mHjzmCQYuv{mHM!4AL`je#hOyd((+K=vI~9bQ_mgSli92x z?Y4dBm1gFgbI(2J+&kxf=iWaui~~XYcH-mFtPi2jFjH$xm3VF?5Sl|QVu>UgsCQC~jnWa-AiCHa=2f0R)}+b zSI(X9jk*niQb3#GdHy>SUOo{PvC`teZ*DBHRPE2Yx6m}{njXn5nG-_@@x&C8Jd$hL zf{cF1&0b*X3CXoo^*;^WS+C@kED4ZgePE5o7C_>}U#;H{oC9;&fFUmHS1n7f5kl0( zaXkA?ZP{G&oDCXm)sK?VbsMAxiPR!&ga|1u!Uiq2CWDna=_etn3F}FXEU`>3;eP=w z+blIpR=jHPwPCx{u=3d!ES&}`TEAWM641PC+rkSWZPUE?Aa&3OCKe5wn!%6Mf_W^d z@YoA$F~reouv=%7NUf}C*{ZKXPzuJ4*`>_|yqM~3jyP=P%}vi5(bMG^*OuRSg)Ht_-7wlMx)S zpxX5nP%RUYBwo}>O&xzhAP8*`$BtLEe++d_X843K7Kw75qmgJLB93O#--wLBYIh2J zv{U3#W1YZta;X~}%f`|pVOSb5lMW|1KFuZfk4=T4yC#gn!iEnXJ#zHuVdzmk8>1@f z{6UBarf@#ed8-;pnryiJAUig))yV1-sfKpmHRN=>bWDW$2SNCuxe zoHHG{j(l(aXCV%^y(s`wL zZD5TV1Z4<#cW!sSp%7Vf?2ylv9gM~|&b7_9J@$qaZ)nx=rhN7ZX}xpt*2TNm^3w{_ zK7V-O>YvH3=di|z-v=*925o2X+pwe+OU#gYs+H%&aXuXdbk~~~D6MLJ0$Y6!Nhq#& zZ4g7aURbhLiFj$REla^z3>TZ4sN;RSP8CUXKw_AB&<_c){p!(mDvZH)oq8~An%J;i zJKPwes}=gHe^TBr>MIj>85FCeWRa+0Q~fwWTIbNROw{l+ON#!Q#dtjwrgXfvi;XqS zEyh@DV_1`3Hyo11Fsi9$Y2zzVlD(!qY|7FiP7$qauLW6Gw+ znrpsc=3l3u)W?%sjIZCnG2S8Bmn}=xT1g;@{1ysX`xi}>LSkjPiCC0-_AX37t7bgT z!jNeLwL@;GFr$$%t}?&^#U?tcT4HHjn<_=>;1M3%SDBaqK%Ew2BjcHILA3xNY$O@ZbBu>OmaK`bNG;IPQ}tt2H4nyhJ>oMQb$5gZL9K!aa{U7<-0J8{~C|~ zg-V3^7R=y1q(!WfBT%LUe10&3JbwUd+jZ>ayMf2Hgf0zq*Ij7auKS($Iu)`}?!G%v z@;1)3&bHp_kx$E$B|0#3DtGE5?tbE4;!`?QBtu0yRHD5zCvqn~I)A_KUf=Rt58hdP zr?~ByLLV!V$BOhZ=zdIZSLp3)bbFC(FVgL0n?wHo6PJG`kxLYg&UY7E_N`p_W$({> zKWX?=+wa@fT-WZJpZc30`$LL9w6JH@e@O0q;_`mD`?0S>@pUY;topiEZmhbF$=zkA z|HG?ugR_J4jSKr$gGY<>+w!@mWFY^(LbjABFI0yTMVBadiSlXhfJS&9Q%wrhRPe4+ z+dyT8$uarJD%JF5nx1jw9QlJkV!(t>dKyc?_P=lUSgh;2kkc*q@P}aaqVVVqpG!2P z9ECSrkBBT{O<969XH6`bH3Lr?PIz^&4E7Q!nTCz0W*r-KFZj%8U8>?x)?%<&2R-4$ zYqVAeAvJ4NFIlU>S{=NE-Tv=bvxz#UYG{FHXRAd{pcsORV5@hAt>Lajv9=R1aYI!9 ztyOzvWLri>;iyfS_bb3oAzy>sQw}uB zedUH``AXRzlzYqGpxpDqiM$PS%q&x;0{Qs-sp8JAU!VWYH@KhsLz8M-y}oB{6ncL_u=Ur=#^@P9sJM(#Ga72cw?x%?BZnOTwm literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_180807.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_180807.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e0eba802607047841382281212d6132d3c86a5d GIT binary patch literal 5704 zcmdTIOKcm*b(Y*8xg;f0;ztstP=+0aux!P$Ydf}))bU4-EXhr3G>|QWptviE5~oBmCSZUTR2XULUx0=3-z-Wy_Nx)l!S zsm~M6Dy?jdLJd%yN;6NalDbt1m*V1ysS(iPQp{NUJnU4$UE_sKvFZG0(Y4bP@Mz4?g8iU$FO(Kc!L%}k)+vd$UHoeU|7x|i*6PCk zDehVxjN_@QUGXXg*kMLrRcWf_SJ(<=r==#Hl4!58w9dz?v@70kV(Tic?>&J&{FH=S z^x-}d%^^QCB;xC_iEs8Mp7{+b84XHPDixO}1aTZtTJsa42_}A8HOLXwJT1n@#b7iL z=2c1z#>NGe4wgnT5)T4WlH;P_?@*0lUQ+2BLU?>aR*mr)S+z*A80G~Kl>(zRkfb)0 zX(j%^>`nw<;Rj7s(ioCzXuN8kjfaAx0VzB$s01(KFGwVDhrcbAnhx`zKpUjez^imD zFdCEu)f68alLT1;oiyGI7#VEtv@EJjXhy`gt7eU&Euzw(t40uXX=GmI%D#ZCQ~_p# z5lLmS)fg>xfJMHqZC{P_+=z=)(sVE+^h^XpQ$cwm9y<^m7h-acB!+rqAv)ayP>&ED z6?i@z8xO!X%kfxXN)Tg0VYavVs}h31y* z`Ng4+-}+?u&hX0hwU;*boy;|#Oqufb*6f>$?J2ql60R(@Y{_}LbB^wnxd*3mJ$*Sx zUy6C;XkO}C?8v&zcIA6+L> z=kuQSW%`zF)p5V&Zp+%CwaAA5OwMyA^+w)_Wv`{q7g=O?=N&B{b!IPTyK;^_%YmHZ z$P*)B|5m}#mOcBk&Ma84clq6%ZP2z*v(O@M;Pl~&6 zQi@6E95W`&vQt;3n6-q7Ac@kjD{dXXJ@T}^gU_Ke6*{qte0}=1JiUv2efnDV>7Xk9FE6)@^~}i4!nf7GhFd^u@<~a(qgN`NABGP?sgOb=+<>0zeZ<*PQf4ka_ra^# zH$Ue(*2#`_t^;FdGfG<7;JnMk=UnGH*}2Yj7Mz}?;l<&bL&&)}A+tZde_3w_zQH`skTbb3MAVD}ao(n2TCLjItf3zu_bYo6uutWE2=j5%#yIP`%P zpzWzGV@+EZf*Y&{_^RerEcK_3JZs7~y7TsyytDIfFEmngaWCv%vR^z18ybRNrQxfh zC7lWA;UP;T5N}8jyfI;bfM^7mtg3}8$)_TZAS9BCVT_O~sCtPBDs=i@`E>;W9oM~D zo2-+j;DuDJtI zDOF*4<`m$ntom>5+SgjPt5qzviZ{WTn*S$f`kz_HY1Ua1MAg3ipWTgJPr!NX4us%!6F5{2^lA}4yaiuSHTv)-+JQOZ zXT_sf69W%Kc$moIaSY*SBOg4)?XD&3 zqP4&_FHA1Kw!Y`cZ_a&s@z)pEzw^B~l)0L|T3{UugR9Q_?RVQB%mUFy zC(i?sYs8Vhz`C>TtHS;0-RQ&fK;-D;nXP6-*;3)GYuURp|KRW!`HzyGUE6bX(X{(KtU75ndGopbn=J>BrC&jRZk_i7bCc*1sSxq+ZX8^Mb{m8AuNlDVT$PntOlb=EyS!pl7oLBCJo@NrGs4GGasx?Hjg^o1A9! FKLET`-xL4< literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_18528.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_18528.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5099ff956c0f6feff729364968a34003d714e426 GIT binary patch literal 5177 zcmdT{O>7&-6`mz`m&^Zu>c^xUC5mM#m1W6E8nn`gjNxUI3kTEYCSPQ z=(K6ZJV6qOC#G~ejj6SsGx6k9!-_$~nV-NXpTMVD^Aik5ze8THTb{6R79LG`3=7qV z9PB*)YXtI6StzI(FjFu^6b-28YK_93B~q!$mY3K2|gQ9^N+LOOW6yYHbSW?TsUD zn6B3~a^4+u%?4ec?3F!>6f)W;w_8UL0J7(LeceS6=lrrycHnurfMjXtgO)>fU&mfU zJB*6N524Qm<)Fc%*4EvRSr5q}*)a|4<=W(c9MV^8H`rkG1s?mzNN&e^avPil+3+FP zA$Q2;=`+CZkV$ND5l%^N!{1ydmMy|L=XM!vw+LssQ{Kf9O`E%bg2(U^6NeI}4!ouX zpM>nkT+FX?8O+sv4;$>a@|}>wy6>Ia>~6Q{YV6`UA|V246|xI8 z$k2$XIya-y_?TiDjY*_gyv6gH8K!B>Je!8u7*7PAP2&Nz5CcmZ#m}7Ofy!b^+zsnJ z_2um!rw10Y!n8Qc#`%FsHa^Wtli5r^o8&XnfGETVBz|Ug0H^`oGMCCEV{n2}HWQoX zg$$qGH+vxl;s!IGg~J_u>4l+}K%jZHCe^fs0dNW9hWJy6ggkoGj-2l0@uhL)+$WCi z{K!3r%Rp6}zTz3x+5OvBuO7W}bd6mPZ}z;bI$zEosd$2lWyPI8dJj|vi%DhSuCup5 zl)b)Ret-4U$|>bssdqEjr+WJebj91I^s3(O0$uipm3h_wY=ODw@+>EplEsfV*T%1fmm ztg)NE{RKSvV zdhcARM_Euqy=yP7_uL}Zd#}B6^9}X6Z>yp2{GBpc4-ubrBiG)Ca3x5yHlLD4*PfAP zhA`J?ftTilOdJPsqxoT=r>W!czE_cqrWz7M&k@smA@02(twV!Bh`6oPh3B_O5NB{I zpmMZKJW*=NF*4aoDpN8oGYK>HKs1e_9}{2&zrbw`X1CyO&62E*39Tv~nQ5k9s4Q|e zgLW%eWn^muz-207%Go7%!@6wOQwov^nFRPX!-*Kng1Xw)nLR;HH2i_xe);^?@~!#g z4)a^fcbZQ&{i9`z%*Zy`3Q4W*DS>424Lxb#L)SCMgFK8L=sCo+g`!cbMrHt4HS1hP zJU7SlfU}f1$<9K@Nf)4_Ybd3;GyHi-0@(##08ESxX=I#B&FeU%(KW!2I)p*I+z|dW zQ-yEk9PxR`rSG6I= z`|jHp-&VYLES=>*?I+ zyUMgH-c4qga_VlRUybww=tcII8P~FN$ysDKLAvL1|D}F#y|b^(*q3cfw&LK=oJDbU zab>a8{fk^Fes$`~)VlXKGiAnI>{Q+;d)msu?!S4h6kQD`$JiqO$z z629U5qi)dG>!_Ba3>I4;0u-7? zYc0w--C7;Ra0Wd?B9bv*S~cyB2L zp&`5o+y^*v0Ws$|Hfj>Si3ytH=TmWBvrD`v#p0|8<(`ok(osA02yo8GKJWl-<2)Z_ zGzx0?_#{-ZEP#1N(kS>G7}D!~E~8lzQ1I*YA}?e`9Yi%Mp3Y|Ys8vT}&Dn(57|sT6 zsK4jcv6})p!pWU>$nbuvpAQQL=h*NCUnHMKb z0Yte9DJ4`2uPxp@@Co@!)q&iA=#|J*mrq?f^{dnQ*9zkR;{Tc#+iYz&0=5JCwgUOl z`<~#Ao7&l2usuJ1-|74J_U_oe_b>R%AcsGHTu&uYr!WThCE&o;+_gerUmE^O&cXj0 zjcJSooVGRdR7%pU@oYwv`17+uy)t-{&xlz;Gebb?ai<4?X0Js}ESXJn*O0zeEo7MC znJhP#=3f?KU=zXwiGP4k)l3k?r^xYViRH%5c@9 zlMH&g=W6&$xN6l&AKKl0IejTz_3LDyuKalrlMl_vevF`Et{=ak1cfRwTKbp-31nt+qyEXOO7KFYcK&6?t2U~k7gq9JHcq)v>%Ky|D zrNWw7?HElVftptJR7A`Fyhfm>O*?S};H46l2KzP+tzjMXFZ<1IWRpy{Rt z`q{_eIYH0s;WIpjZfuO7F!82$bYq6<3YgzPV@3@Uj8&gZ3zW~gNn9CG*b`@88#v=1 zQMm2Fh+-;_2G6}Vc=Bw}R^u`UV=a7J5$KnZ3`<>Z)0)cY+kc+Ay-knVWL*W%Q<%n< z+e8brL6oL0rVxrXV;g1{feSY;0zPK1(aor=J38D9KMUc!y=^(>zz$5$02Fi9$P$n1 z8b;VoBx|@DSO7mQ72SjzJ9%nFQ}H|Is*!FKOcZuiq<8Q(2r#VUwTqx*heYjqS+KQY zcMx6t24>iTTWah#_TE~tudZh+cH&lSA?vdUyApF_7X~fZN~j(b^T7J+%j=I@zSrvP z(2os7R<*Qi$6hnIsmAwTg5@^c_C(*CcKGhaUd#~BW4o~%!W#Wd!=nj}o9xPN0uxLb zR_C?{Q`l3-$5s5p-Z}};5^JyVs8PhU<2^iGv-=%bgLlOPp+I*DA5*|*` z$YI4WCq^g5z^p&SD>^X{nG_TzQ2wXG(EzWovw>@&+4)%*gd*xdC&^J!@O3KM5HBgr zRUtGvB`ezKysS_s6zaS}y`fNpib;~i5HEmY$$wBWY(rOJT%mwO9n`af)~lGVL@)Uz zIUvf!KLde;2oGbCBq#=AGCDCK39_P#5C;^7EQdd$=;i2) zP)5NtK(C??@z?y5Vsw5^F{pojb>*}&DCEFH3FNOSEMIm*VTfKo*isb|p90|!UzRbH zm%=|MiwY~uUh-cNf{Foj%)vwfDRqLY?N4t181KFs6=$TmKv3wO3It~Y@>Dc(ATTLJ zAX|Nnb*w#L=UtkgyBI`^Rj+PpV^n)mphW*+q=o{%W%()`D|#zaZu%Q z?INyZuYCd5PZnCeA8J0*f1uBGo2xH=(e&~|H)?zSYddoFJoF-a*B0zA5}Qe9C*5w+?IE25I)xDh zd0vbJt9J?@2_HFUTX!HVmJvqN0Mxf}l5=%222+ovvM~;*Br(;sm>%%t5ZASs0c+zL z%*M2sj_;C5x>DP8qr9ub4CW@Z*nssbR3$gYXkHUD!e^3=Rq00eDp##P%vd);#VoS5 z!jCQL?SbH?pfRz+Vh+}zw$|BZ<+T+pFrV+H(?5t#SM@iR*H!5nti?ub!sZp`_O>S! z!t{CMV@Gak6df-w%t4Chg)5<;u=xs1K)o?Ua{M>dyF*)-TYY*(Crt(B1hJoRzW|-0 zgBu_?rRX3J$qNcoPCkk?B3y$+7riQokbeEWYF<&8a_0A09~*q4lPGqP4k~sewvw)D ztj-;Uo`h>fBD++VNyT2kDf()569)i31)ua+=;EkkK_+XeKiPkMFn&CKxnQ)VjwX+8 z17>^bV)A1A)q<%xbvAi6ezM5g?y~NJ)0Hq3SwsA%_bv9+OmYS?Patc{_;Sy$x$aqB zNob3frub07YKxz`Z+Cxb_vP)rT*roefBbCGWKQ(O2OrRm^k|;;6gXpnv#7m3(US18XQsxC>m=fD7mV#(UIkK5OCYcBc`j- z*r4!);<{E_gF~%q+fEEIrbb(XG8)RPe@7cDS6+cgJ=A1&Y{hd5g%BhwcDkAbs*^^Q zml5W8_9&ohy!vmg+IM>PM6WpF6|aXqHT+NZ^gr{?sCs9LQ8nxKe-;~0+yOJ0BjzC4 zT1G;DgYy@)!$PK^s4zd^a!rpTJ)h!REO#F zapFrv6AR@!)PPXyLU*WFE#vu!!cIW_uEL9g7?sp)sOW;>XheXN4z+YJ0JU{!QBcg) zlf%tg+Tp3WY- zySMMtlb@aWe=M8 zMXn{|T@gN>y*Yd5RUjG}sAwHXXGw%Iu54RwaqZAO`dGXl$?BY~@pBK%%|C4Jruw=6i@zK)+2bSXUE$k~%VGd*LcAd94s<0gMNwZN)1Q&~FQ~0# zKpO5{&vj2p(@z~M?M9ZS)Od2dw1+_Lu&1eTGF<8)&|YMBF7IF3|8*xoTUw+sm1Rz7 zklRymc?zv<4~12v@kB3Y@{40>0nG=5v~$TMjWy^N_sLfAcIxqCX%qx ljPc#^>*HHyfVM7Ch%$f8O;eQp;kXuAy&IO@*Ny57{|P&2d)NQ~ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_211539.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_211539.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd49deb17c2f59f6d568972ad28783d8f1a569ea GIT binary patch literal 5789 zcmdT|T}&HS7QSPT$K!t+n|~9?FQsW9{DhK(ENw$V1A%VZY_>_sy2dkrjj_qhKnfXa zC%au0)RiDorI>Cch+3&;6DjmZ+eBKa*|+ZQi>Iy{YnGLenupCBEqU6fJ$LLe4sk-e zKM#8)``&xbx#xcO-sAh7>%ZD;W(4I=7k@we3lBoy;*M5K*~a>Q0->vjK@1T_=Q1^M zj!^r$QT;iRKr9hb{b)loHKSw6ko8d+?lJl&$jK+j4Ta@{G(){j-pEdH&S*j=#Ml@k zi$cxtQk?0i35wME6^XA8E*i=NY z#3MvEL)_Hgz$=6j^9Wk7;rT40P1AIu^PhpfSS-%-TNtJ75tJojiD*#~m5LrXD)bv2 zg9dfwFq=qY++qoqVtH0-%B`jh&^*zH((9-=}k)rn>5O3N8CZ>1HmZw7UK z=cYL*ZT`oQ88uC;@&_UF<{>NavlI!jQbkrFdT$!EohHz4|3IQiR7dr3__!b7h^o^m ztB~QSpHU2%j)51V z99vhV*o5c^#D9?CePaTr7^CBY+Av>agG0lawu2YAAjA4pQ{FeAY!dK_HEV%K4?NO) zRLA6oX5CM07BC!+Kn@2K2Qh&-e2|J^EXaUkwTdWIba0UO@rsdWFZjZ2L~(5_f{R{M z@l$w{unnm2>cCj=bAk`{LmdsZYQZR$EEFux3NI>pc630Y@p4o#t=dADuL3wVPb>fI zZ@3ucM))y*fNdD|2S)tDa5S>dKg33a2A&Hv2<+%s1N0hT3YKAlks%*!uMmy+Mp!Pw zhU>>BeW2Eafhg>JQ)AT4cRIfG z$%p&plKyE640xZ2x5SNCsux=myFMWjwbwgjSJO0gKZ9Z}w0Kx9sb4;Gw;4c8>%p(5 zJl-z50o*69<<8~OrH0!l=+5?F{5oN+S$esczBwg5Xq5kAT-=@bJaOpdV%#jq2>0Q=wt82mtjuR4#SXM&@OX*|MT z7-w0?dd+A2;c@mD$3;2Czz_S!Sgr*u6$2y%fnmis2&sb5*qUuLHz|f-M95aHEs6mc zOqk>j;34KfG$I6t#-rnK|BVKLoHJymI3ny@kUm5&vK(-zuSp>TOmISFD@70P6e^Q5 z)LFTG;E1co-mDW*NX%8DhRg*L^#LCnVSvB625gf{4HUzGUjU?f$N=)v1ab)S6du2C z@)??Y7P~g(h?DDsej|MN4^_roH9FskU5@?wm+|di87tGCs(YSV*;AV!Ql4jLOlecm zY}HKFyVda*zcN*=*vpdTFQn`*Cg~Sfx9wgSogbCSN~wDe_^RS+^_A+&-O{n?=Cq|~ z_Q=eUYd>2!KY#w_v0J@My-DvO*>WgJ9!gpcr7fP>j+u^Y;|rJOFM+yesVBLuO}4Zp z$+o1WZPjR&F0I(zvm-Mj@dJxH7cV6%nr}b<#gRKllBZt3cgiQ9@_lW8Q>8pt#XaXP z*|{s_+#{Vx+l%knYh-&(%DzMDN}Fx7^)vM;bA@zt)ehvoVt38$d~f1j={~u1Ut%Ox zdMIT-EOo6~UGKGCJ$&VGoL}rpxt~i~_DP-hNzdGAnXE|DuDRy8afP-_pSWlB%2sdu zRLWYNqH8|fxfuB18LT)u+db1gcj4XM6kT!c=%W53^L2Bg@dI1RTmPAU`RM1}A9vrr z@M&*q|4Z1qYqoEuFW#84R;B1|AL29deXFGuJobJXrR`d z$m)=~x#!@t0`SQdobcdiSRaOOcnK3R#OPuKqmStrGNy+fsjC4H805 z9xC*RG2EnXWZ5%j$n#i$8{q+#@2vnoGT2*omKc@iEytTVFi6z>C~ro=rj1o^7Q8y6 zhMYh_1e|RQIW(J%T4N@WW{mA1&O`Oz+O?-5+Y%LxQ86YsQ}h4iO#d@<78RW}M&#|= z|5*@PMv(?~viAl50l&8(g3#a~wL}ka;b?@dGpW>| z*m8Hg4>O(v*kPQF1X%n?91@1P6EK|X#y(8J$QV5C@WYVGr0*)jz^p7k9u{~WKc;ge zl3@_Ly0+XD+?!zHgAc!;(#RT(ib|yJRZpqZyIN8%y|U^qmQJp^ily#{R^%$VYP(`v zr9JR$=}7Kq`Qx!KdhYZje|qMjfpk&Q2{3^R|Bab9R^bNhy;*##YN_h>1Ssv4bR3ix z=j`5@y{mL-yy_-H)NgYl9@Z({0p^Vej@niV-*aIaU@ES#M` z`@1)!uIXO5%6}M&O};ie0pA{VSstnTzO#7CMQ!b_xJ>H3Z!7xu)^6Rr_Q&wj$m*8* zR0gTj_x1HbzJcS5>x3`Gkx5qS!|)HuI6To6T5ECmm7wTDK|wJ=io*-+TVq@%d!bIN zDVG|6Y7i)vOw{;>qG4tUsdHsQhHf8?GUH+PC)_Z&gD^q-r|`9=CkWzOWc?em{T+GX z;i{wGskmIRrrSdtgfQB%>m%=V@0uyoJJRvR$vY=AJu|A_z7U=duUW9pj9S}1@Bg^} zd)!;sBb_zl;Z*{iC;UWu$DZ%>+X+X!c)>gGT|>}ZHhgaV*s3$tssnm~l@dpblA?(h2b$WfKDZ@X+!7xgbOncE@& literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_322972.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_322972.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f663a05f25726ebd7df7aa74fb623579880196 GIT binary patch literal 4566 zcmdT{TWk~A89v9J8ILdHOY9^hkP28S;6g|fSOMCIElV~jmzYJnU6qi^GM;f_$M%Go z34!dmnw6;5E|u)4s$#a4$V!#!kVs_vP^GI%rB8k6%UVHOYf2?c@X)-44V9|CZ2vR9 z#JEQ3^`Xb|x%~e*|Ly$eJ7<2+aW(|y-Ukn+jyVx}N(QxJt#-D%F+z)oM?8+B@k)!w zu`xC!%;OYB0*)DSYD+6EZxX1O<7FBd@#a_1sn+xe!_#-Dm8zg|ixrU{@8B5$#oR6A zTV6qDUqQFFrbk%bc84Cf);+i1LE|+ax9gbjw?pE}y#?4CrjN6VM%% z-tgGv;kPe^?G2`~&^N<>d&>x}0(GUtOts!l%j1 zR;sd6oMLI9TUB<2=!Ph zs95!oRjQlEtEQR-sW!C@0*t(rRw0m8xrPZ_wHz*;S;r8(Mwf-%EN|shY1C{cxQ^w~?O?NMvjLCY?%z1hGcCc?MLnvJ#RMLwQ(? z%RyepQJs~L9 z<(W`e=${IOr$fq=nCuHp3Q473mcsptkeKO*QNNJ5F7SLbIT-}M6fqf`7Nn#QKQ?na z2<#d$EP~6&2M14|90ZQ;-dRw2^0ATW%mXwPg8%JDkQP@7{c9Vd!TEfA0UpUpQNs+cPB2obcah&3*!&90vPq|T<@SuAGrVv!8WuLm&Tj69D%`kPB#$IXVw;2>=UZi?4~bvU$%A;l%R9qu$^3{kHG1^PRN0>~w3RC6@MqChN>#dTS%5PSoPnES{)->S+=WHEU}8W z!k*gxCwuyzS!Xw_bEI%X-2TsUW7i$vh>t1@uZ6qJxL`{?vepWg`Bw64k3J)?xr=dDNXuM4Mq;848 z)6^Kx3z8@s$fncbxR@0DRvn--90DjM*v!@8FUUs|x(gnavqCZ~1mi+-Qkjy-1tpy( zs;trE%q)3{$_+0SizLXiaYdHN=V~2WE6D0JysXQYuE3O0_{;af8$eeUwRLD0%Dzr* ztlZJ1T`qguwM%7ByLRD)19>_YxdpDw`f{X%_Wd91Ws$?H1&q7IZw>?keut51)yul_&Nvo_&!AcOK+xG*Pa+AR zn`2S<>IsWUSrKl{NY%UiJs~NJl97Z)66p3y)&wWTIKPZ4&k1tz!UH8C^0RT_jC2)- zkS0hjK(%GY7(YdhKO*i=XkQsyKI;Ckd&|^`UxzE&>CR4MCbsNkL^d)T&&0PlGIAns z`{L1sqhGpUv~5NvM`g_^v~Pc@tEY6Z_qnwTyK=8D4lN9AA?TJm_AbU3;#+3uLw@?Y p3=WwDqoV^4hn9!7$!L2RBg}ot#xV9ipD-g=&xZ4BAKDF1{tltP$}9i? literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_347928.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_347928.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2775ea3456204bd86f64bf8bb4845d02e636377 GIT binary patch literal 4995 zcmdT{U2GHC6~1GS$K(H>#Q7mO8&Ve55D1VEcF}IxEde$_LRrvmA)%J>OdLCo9cCuU zCc0J!iK>oNDd7o6TM4sTY0ZxW`xc~1OL^M156viOv2LjpX;ZblWj9o+cRq7M7${F2SE8UYPa`*VYUh7HYByRqae zc2WtI6(stpG*sxXWn%;ORq3nJf_dsIwY9>FM(oDks&)FASKENOl6AT@O%>KuMp^OL za1(C)UiK94l0m=x3yY>vfE|!%h&-lI7+?rH7LrE;W@1p{#5yh|BT{HA7!?Q?l0u1y zNKAzn8&8G=VwxHeCE&?QQW68r#5N^HBO{6&99Ja5MFstBPEIKLz#=PBR1kwx0x1s@ z$cvub!#7fr`H?`4m zX=BIvcrrX1Oax^mBq_!D3BplVgRn(ZkR^K2gax~0!cKWtoOd+soS1q0XtK zG%AmW!eZx0C_EZcMv{pgp@^7JI%O%`sfc6aoj`SpW9LOdh$bRII0hw|2#$(ULX7Vm zpALe!#0)3lgu1%-y|w2}kP!E?RfXx&7O;kcM~=abspyFtIo)&TX3u4IFFX9|;YSYF zGwRS9hn)5E5zXQM%<#G8nkBa>_tA3OUd^#rJq)tC=FFk&k%beQce{Et?=4cfsJ3yR z=G~_rTWxO5T*z+Cz4>)>ciNHnw0`FOy!l#lZg(!Sys1a?^rX4Gr!lit^K43Ut6)cD zHZmW}_Gk@(JKO{3zO*s#uA9Gbna>>7T#9+-C5zr?wpX@cWvy>A^Hi zZK;w_68x)?q;EZgW(%yVdvReY`Rcw{Q#Nd=HA^U`!&VMEHTG?r#zcOi8^3bkM>wm}MpR0=Dcn8&VL+|435GYIUdaj&>+?T{uWFS5tQo>{p&_%Wb*a!Qnf$>AW#0p^rxX(wIJ6@Z?wVIexH16IHzy#W#_K*iRi6paK6=2=S_ zV|bjH^fQo}=p2*oFfoF|#964adV{Aw=sSI4fEWRpB^oypy+X-Hy9V0nX+qBsYy>im z3ke17Df{;QX!gf%PNxV7}$`hb9= zx7|5;b4Pkw?E|=WdDJ6$msfrNk=;9YZ1z~jnEC0ly(1?s*}B!*6NYnHPf7*dR2Oc~%~G6Kz3R3?@1S5$$Gn8k)+Mk%X|m7%T7LV^iCog&DjOcfey zfFD$HmD*~+pRTRAONy(|R;x}ld@pTg#Zj^vK6*~~&RnvbK~a?Au{q`{&3X1N!5QWS z^V`6#uKlDJkpNNKD~)G9d;m#QH2r5l0z7gOIL>3xIk;OoT;R zv`3T?VxLSxI-~VA{d<6w=YwKGpkqrkF%EwOfX{?0puhB9!5mSZh%2&e0jdn!1-N$V zPM5co&cKwPz$0JPi(sfCy$yiQ^^I!(N<*`HVx`Wf9$)eJ)T2-A$kTAeaoMrL*F$M~ zV5zf*^RlG#jXeJ0mvQ{rC0oISI;b*`D9Q%lI{mg z{@1Li(p9Ylbl2-{)vHHW-M&{k^=fs#CbfUn;r;h~w^#fAg1tPl*Qo<~YJt*%#+=jx z9dQ+kh@Elxqnv<0CBm0R61{t3j71e<2`3Y>B7QtB=@kUwPKyaSDd{1p-#@WQdcXuD z$+&PG>Dv^dg@1c2DNMx0A4(xm2N8ng%h0VE8HV`=+5dl`c#S#1 zKmfhA>1uo-zGl%$Yl*b!q#d=jU&RafI|n7#jmTgxuIpvmnb!GALom#Vx8v92 m`fzdLOMDlvQ||hQ3}PJL@iB&}`*z5PT>fQe>qVPx@;?Bstx$3R literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_355413.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_355413.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1f6ddfdf96951c6f334dcf556e505ae72bf6a1d GIT binary patch literal 4516 zcmdT{UrZax8J}J6uGikR!Pv$S{seMqQY#`{Ai3l&B&W;&Kp-dGHPVq(TgSVA!PsPX zfdq}!O%Ma%R_l{4*F8{q2H`+a9l_9 z??VT6=9`&szCSbb{pR!UJRUcKLT-OCvRs4EceGO~o+7d3We{3KDpHv^8Z5NTAY=F0 zamOIXAdMNd<&>5dT9wtfQTg>T+EX2G80V^w4+$#&fO}FDH0X38`lm{&prO&)D)OCg z7#H6#&T1~O$)$3dYe-gI5BNd%j<@ayXi#F2Chh2#?=fL{ld2sc&U4qUpYKx!h*%^B zh+Lp9^k2Vl@mkbXW}X1F1HP@iy`d)5g)+*PsNGg&o)qU(u01%}9ZolPVdo5wP)fpX zEK;g$Ew?MZM5zoGpA@yvz+g(o(w-}vudVQ459Y@}M#_t2?69etG8xqCih>#Jq1v$b zsiPz-Rcr3-6?T?9x`2NC6Oin~J}gklI^#lE!ajR-L2DhZGnmf;xEA{wL1xO2{n$Z0 z$Fl8FU=Ldm4%pIa%4`Yl=-|pE)t5=Ca@{eD>upIvOl8rNhQgrFZgc1!3UdQG-RMnv zs$qt3;}K&lED+X=6ZeFk9M&V_N=zj}M2{qfHNwS{5tZbZFs;RgM@-^~ zsRnT*Cru(6rXEu@WtzB)6diRZCdU;sIi@8Hof@Lkv)5r$5OGX^79J`^C6ZxAbhE%n zgyhhW0fSCh2bSC=0yR-Fi!DrmQe{eu5++7mQ;|4`09GQ}dZ_LEv;s1SgYxi6N*f_i z*;c*$^;YL}QXexWB2le#BoZBqm?OzVM`Tz_n4N|m?KHLViB9NsYU8&xRgEQv74X$e zCX_KvPiXO@6ZaIDE)7PLunpZOP8>hh4Kt*6cU6I1KM1xs;WPdQm4&tgNb<~g%ynd# z4fj6l{0rGT-#^!%Y2T2Wt&6$3U`9+IwfeH2x`kWop7vj#dvx{T)z!$e=8c1A*F9&g zOIdH@ug)!BUAnpwS#8*8?ppVDS(kI3z`_i$!?ef0FtYAxwJyE%*3J*l4KGA9m($~` zhn{ z*8_)F53L8@p1t^@F0gnweK>Py<;2p_)v4!o@64VrU;OOGdf+s$fkRj3t}J|%Ig!4( z^4995-^<;v{K(s~1^y!)@E++DK6M{d(egNzV8BMIZ5ydD8ij$%qB3JYgw14o(a#vn zROMjHIM}KjKGpb0x&vgVTxm66X*LZ$RleM%Y#G9Z-~o9iYcpQB+3> zKQX(<}rTdPyQZQKeAdL5t8e&$-gPAs@{<82N4M!nT0;;KIl2+`g>n&59EJ2OTZxmdu?EUwhVHzxZDIy-eiu zGg*I2wlP`$wL^OZYRKxC6jIp~qdHQo%B398 z=gOYBD&b{c7`faOxe1{TbpdhTv*!S}?jvZ!Un~q#6mcgjdab zfSH1-O~s-bkxb1nm1xA!lw;+cE*pts+Yn9V!`+*ZWxh6Y9MG7DlrcI&oRC36>SPig zs_rg3qf`^bIRyEo-Jj9)q+z2Z;iK_nLJPZydny@?+*XX(j7B^=P^wUPuQ9YclUgFG zDRC_^Y>p6lDhZr+9tkNRth}vg36-XUu2^DX(geyV06R@Hz#h$*jGKl*-k>>#wev^j zj^xBprsb*j<@n?A-(G@ANwCgCqR4w^W10Gu=GB?6x}S6J!y{7Tvo7aa+Lk|A`s9~4 zt;@6hfYSe*6&127g@EjU?N-3*eOVja)2Y411skk>NH_mI-{nf*|G{1n$$o3VerduE zJ7L&tgF_n#d-Su=>PMg=HHAFjXdHgrOhSrJ#1f^`i$ENsF_XBW$%J8QcPI2>V(>Bi zeo5-^YbRmaH`KlnL@Hb%WjGmEA0u0F;r@s{kQ_|LwX^ySkV)Sh<7=qkHD#FZko-sF z`4ei+3y2kezVH6NJlo0~huz-az8qhQ=RLINMQsO`@e7|-|OD8u+)-EyFsmJM&~eaUw1??W}bBme*a literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_429595.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_429595.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9ea70a1968e0ddbfb58544553c9a3ba8fcb310d GIT binary patch literal 5651 zcmdT|TWlN06`kc{xqR%B6h%suB}WPpCu}4u@-t52yevzyV@YxB*tMdXpt(zm6d&^L zO0twCum(~jLIa3G1Bh-5h(SLh)dsBOr%F*YN`D%pKiV>_g{=b+h(D@7C3aDuK!M&_ zaw*!OqNb0J4!HBUckZ1#cV^C=;qPhMh9LcB<`*N^-3Wb&8+xM4mAexJLi314ERjG% zrIHvTw6-B>95NA@W=NQm8STHgs|ux@8+WRhu4h zu+}$C*UM%MSy>lr<50X_w`;3p?c2~ToSmiMb8uws9)3{5(zSUoKJ%t|$WigC^G!5F z8xTiVY+|N}fO`WQG7unCYH&=9rBk7VYCAW0?8TRYXHEt$P~ftc}1^*c(? z_euC9ujGM|L3$K4Wf~-(WY=nq5|Uh!5nF<}(9fjTGybYkNZp#PQu0d;EOEnh9eWp2 zrm3nwRiahrQj^5!R%LeW*w@J#rs3>XB$}%vv@>0;8!QLyct+ALsi~$#wPnhqjCO-$ zbG0PahsVz}Nlmba(GLwQiWu;Ym(18hk`Gfcy-X!5eru_+w3gopsYUa3!vlI7@7FjD zJCQY4`FA(0ltmlcD^a+-UtZJcldL*t%Qno(JXB?s<9^@amSm6~(lqRbRlR?{HSD@U z>OLmW6be`dc)US3@Y&hGwSYl2B}2lfYEFwI93KUemI4HiQ8!>w4I&)$F+LsTLrHvm z$q*k(MLE@!NQYR}8W~R{f{EBDrDF!E0tKMKM^t1>fdM9XXNkD_;@sZ@g+la@D$y;!Sw>OWG zMKzDbSP<6=p;UnR0(ZHDOsa(~SIBfEB7p8v$*L6gPY^n)q*h~LW?=l7$g5T{J<6qE zE`n+eCql`wU?M%K+Dg?Cm^92ih}EeUE_o?9!G%GqMi5%Trq&(O--kOV)BLC~77BBn zBcbqUNE}J0Izmw{C3XsYxKrekW1Z0I1nnHl#!}HBcvnoPf}rMiQzOj z{=mK?2ag;Cj_TTyDs|_(z$ol7;bYh!GP=hg+Bw@l)1PZAIGW|-w;jy3t)k1Dv&=c= z6YGp8JEAbV<=!Ig`&r$B|C&GFm5&zudlh=G+_O&Czst{0%}wRq`MyF!heCJAJw>NC z=g*&4{OyXfL+)FrJ=v&2w*Z6j=FEAA(%7yr?TZsD4yChOVY=ns+pfBK%bX==$&W0D zK6gDOpSS}Q$DAW~VbM`=9h6Vpc6;WV=bCe|#ir%Z((Z!$DfuLpqUIM9#;B>*DhZA1Iw*lL)$Y&QR(n4RDuGjfGWFkubws=m zXAL-OgiIxQP8{b`VSGyTR0kC;?!E(nFpne@*EvJhyXovhzMn)-&c zO)^)FSbcjbn*_77>SOKDlag$D&xkQ&7d`r{B)b;9kOZ@Z)$r={t&*+WRJP%pbj!id z-!5O9za?+lLB2MBCwX(lK8r+2R>=lY%?Poo>`?+qrZYHNEg4iZD^87ZD#>yau`svs z0(e1-VLaXiJPwcn^4p<5V27YI5&|V=2ubgH;JJr6WeXq!MRp8Kjz`5cvdbaW(GqPhtzK z6)fiAhJ)JS6Py|duiNWpkIo#u+AY5*UtVu&U1-17u9zC+lUYmg?D@5`!FL{=ePZT` z?B!fr?tC`Ac<$!#((vk@-*tS_@kiS$Yut!@><;73KKyRvmakp$wJ)A7_&N*BlYpVl zy6nDpCvE`&tndm}O>|?BxRG&mCTP?%H#UZG~O!>l8Cf&(PUWfoja{S?IXd zu}Bn}zPC=r@5y;_fNMzszm?2nW#=x2~MrfO=?o)tb4<^Y7r_2x` z#0sMpZo&PIyk17ljJZlaWwxd-Pwu-gFn%e#+Yxahxqy`8Y=}RIt9FqS#9%liaKWzX#y$%B|&;yhM0g8VbYqG$51Fp8BYKCn8KGlrdDb*5z2Z+|5=J>Rr!KP{sC(HI^D1$A3}?Rd*%}&k5rR zQ4sJmqJrEd*k-g@R`l~1pvx=p2@4u_Z&IkvC!Z|VH^}`(pI;s7peMeJpcUK!v}tS{NtBDdU@@YbN9?9Hz}WhE{tdP^_ka;R9&|JhWBR6Qp?H& zFd0(r0jAxBfxk#Kl2G1-WhaA-ykBRW` zOQ!^pyE4X?uLx(kl#u3CBkWdfQ)-(+wU>5LFq%%VOGq2Lw1uc=lWBH5!9B-^K__e* zB>V%iO(Q`NUn0k!5&Z>f{sQgVB$0u7tLbXfrlFM>AYf}f{K)&w?=^2)%grOteC+th zQEuAMLv0I*Yl%&}#P6+%i?_MtB^LEUIdB}>ag^JX^mrK;4PJNAqx z*@0~T9(pCuoWFbSIrrW<-@X0=$1w=nXO|xgh5`tEff=KD%fyOQt?Q6rbY6(c$r+FKN)J zpk-wGtNxUrQo|GTrsM)Q0f%Qh()Q! zdTdeBD=tHms)xSBa3928r!Y9DNm$L)28eByN2PYP2W+Km&}}!alV`oMQT7{&RCY@x z^HhW4fLU}KF)4vHW7DQ)Z5|#&lWMd?Ve`EyjY`AU$!VT4)`LJ3DC9aV;_bAIcla`< zLJo}_iAcj5m6QhsF+_-XX+s{(A;&dlR7?(tk&$qW*IW@Xk{A#)N*aue3K|(tMtCqv zYE&#Cx3_6jKS*Ry%SllPHEGTmFKMnTLTq4A)||;PS@TM=7~=(SCWSAp;^LaWtSi}v z$JYvY(;hrodSfoWO2=!&IgRKN@h%kc2cl7EDXGWxNzzDRYq2!q-XMw2j`9qsM= z-UEeJRq0f+7TYsdc z=fbTExy#c#7K3l+xwqBsB|Bk(tI0<1#BRl=iCghJw+)-sWFzTbwY%W2&VDb~mk);W z{?JVKqrQC0;k^H_di*)#&Fo0;$nKo3o4Ndq+4DKmk!L!7_0cEiAD>@j4ys3%81Cap zwsvYD-J%{Tx{${YlB_?^Y<}q^JX;D(Rkrr_!7=N}U&uqd&ZiXRm2eMuIkUe>3Cmu2ck+A_Y5?$VZW6uNIL{D(3^MYmYWp&(+J|74LE+AlZs^or{b7|GY-UJ z)?*_#z)lmLha1e8Y(4hDjHIZP3;fuOr)%9Q8X~npwAJ~Dt#TRt0RQZ&Hl?@}y59*S zRRT3BkXD=m_p*n-AycOC+QAVLGwi_xo1_iNrjI+WzmbAhcmi^8Y zD>s@xSbU0rQv;!puwKoEyz;KB!)x>|BVx$^&He55mHy;9{q6PZ^*7Ih5Z@d3*=zI_ z{~m>zch8kirUX*R@1hWWT6Du?m0tcCpI zyCBjiz}x7ch&P)K#ho1*nc!n18kHCu5#l0;6{HlK1n35Mktf7XEOC{9Me`?ws{mig zD}oqK_J`XvGRnusb?DGsB~T7|D^*Esh0fwP@k6gd#0MJDqY+_o6Bbo4mmC17m+;qQ z#Y3?L1e(VLDzO{*_uwZ@=?GQyA+JBPFTL+)!F!E&8-FhRa`^u6!iKgy+qOWqEwF7+ zz%%cs-_LD*$~G>LjSFn!66?bqZkA*@>E6WMiJ6+2ZMh#js`+H&sLE;AU=2T$}jO_p_0w z?)vAxn#^!|IJYMknGWPaGo23)%pUmE^%?g&?wRlOl(XQgQjb0N)qd{Vp7(8^{&3N^ zLp_Eui+N7_GHja79?o7_q&KhTif+>DQoArp1hZ%JWWy2w4(-#QY=%iQ**!No;3}jq z!=<@wWRczgb{5)VVo#QXxCzcHWL5oARsCO^{gkV?1*Ts;A?}71iNeopc!T1#T^WQY z+&YOARM!;2J5vsxOgVukEvHozU(EhUaYA))^b?lzzYX^A)Vyo1jAtpz;?V|)aK*8B zv%@0ntvpJ~W$|WnD~9v>zv9g;&n)mxRk>LLw@J<3eIU z9;~2-=7}Xn;l2dmpr3L}DB>0cX)G>F68@vc1dI||`1Bl2|6(tgjKWX4SAq;0RoAH} z3bpm>sY2aG^-Q5AsP+^BLG=Vw)Ii-G?iN>|YqLYshZeSV{O0IyPyYJk!gtQSqR4GUo7DR5f*aIRFSzP|FL!{d)wzEqUK)98)YCeWggAYRi3jmM^OuetyW{X*!Wdk9 zHQJy=eRDg9VzTCmCKHk@Tpbn5=a92PLQ0Bo6Ne3oj|}=I(Ad(h2@fRW{4CPPDs371 zz(|rGiwg%ujAF1&kTeUAqLUzqFOc_7i2E~YDiH1)4c8iqj@`sjqS%0ZRhhnYU$IeV zo678$&6r(rB8RtZc!}6W_?JAsOjEk4=mfgp4P<)KJw*y=!1|iHq8n2*3e*-om|_s) q*R5I1d6A2~(RZzHg#&J7kU#|YB|StCH81;|$ltu^+jNc9Q~4(!c#Cxa literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_459432.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_459432.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..366efba720d32855f496db9de41d037f5e293e10 GIT binary patch literal 5651 zcmdT|TWlN06`kc{xqR%B6h%suB}WPpCu}4u@-t52yevzyV@YxB*tMdXpt(zm6d&^L zO0twCum(~jLIa3G1Bh-5NI^d$)dsBOr%F*EN`D%pKiV>_g{=b+h(D@7C3aCDK!M&_ za!J~uqNb0J4!HBUckZ1#cV^C=;U8$)h9G_O)~`mcyAk>lH}pi8Dt9Lcgys>8SR#Rj ziX|~bXl+B%IAkI)&5$rB$sv+J91&-96_t`Ta;CVe($;INX&btE8#=iS-Legxs!fkL zSnHdn>m@UWtgMT*aVTD|+f^x9`!;k7XJ;w+92{A@haZ%%bZs7t&%9|Ka+JO5d=m}P z2E@^2o0usg;NHZB3)Cg*2NSz=(S4864y)o zjBAIUhH_8$4n2+Ko^_I2a!i{M%6OzYiPmcMlBY`ERIx>`v4rHfUb1x>h#9X`zoP_w zpM+2HN*)*)q(?zhra|&acCFSZA;~2fu_c%b{Y-j2^mO*Jj5EmIn0v>POw zt0l2MJbtE0YJxqCer#Y-#DI6aWX2wne3**qB`R6=TT7Ltwfsg%Et;_5S(Rux8MwKZyqCy zY95QRAg&ccu>kW0?s5^CR0~_Gkm*Q70NursRVnPBAaqnot;WL4!1ytdSFK`tluN-} z1l1Z&gpy;yM0!%S6{{mKX_$Kut5Ypp@=|bu3xigTAhd!_tvjS|hC3(I{HQP%3Ui$! zq3~!(97(4-LQyUyb_#sBQ{<9kozUt8?HtR-QqdrIS4^jZqa2^&5_`v{g22{^;WRk@ zz`-L2_U{9Z>Z(W;yYpRO6!w_#F>DYS-D42#ob8|K&$Zx`2%dDOt;*7+f_GjnX}|9iz6$c z&s|T+C++~nG3Ur#SaRfD2jvsD-Jbd8x#nDKsc9v&ygToHN8FEtBIbUGJq7|G|~%wIg{~Kg`DE&z(|SkIN?tOvBHJ+<}}$VOkfTQJB3ehgKi? z9kKe@2hXovQ1%Zf^q|~xn{mxXXQJ6iZelLI)c4`(Rny9OrTyp{xh8&gdTsI(r_w&S zetzh4=5_E7cxqH(+7=^AZzy}7{>_W4;%cw*XtzRle@d>O9#;B>*DqfBBg?H{JhyhovhzOn);Ts zO)^)FSbcjbnFO=5>SOKDlag$D&xkQ&7d`r{B)b;9kOZ@Z)$r={O39XQD%K z@0PF4ugIHrkgv_(N#0zx&mvKhRkA@;GeWE?d6Yns=?som%LdiVic@2pO0wKUEX-}b z0AA2y7>{=Wj{{_Y{BGzE*dZv5gg}WILQ{A|H3MvlQ~V*U;4oyWF&Rs#WDz=5XNtQb z2FKEq91l<(+^3qtY-~cqMwKiErGTUAV3k8m@rV&Gsl=ID2C1efL_PsqT#ftulh^`l z8H;(i;h=W-1g8eV8}_={qccaZcFQlymp7VP7uv72E2akdWY$tRdw%_F@SR6zpO|?f zdpXyZJD*K2ox3@_JiNB&_Z^>f{K@vpIyWL8yTiD%55F6^jKg>Vgvu^K| zyYHClvu71k(+1_t`f?s9W|)+w~JVE6h*NO)s|nG_x3fKmJ~P#r@0VhNor2)qsC)|Aw~>nt$`! z&E!@iayaFa{626@7`}4CQC>r@5y;_fNMzszm?2nW#=x2~MrfO=?o)tb4<^Y7r_2x` z#4@86Zo&PIyk0`hjJZlG zh5)54ac6n5<%<5=lPK2mEKsYg`fquyjMcR)+tDftTg6&oO>O^^HT}=5vuoBlGDMZ% z{?GQtjw?XpIZ6~DGcyqTE1b6&f~~PB19cPcTtLckHpCyqRlCRuVlW&MxL{XxW1j}z zv3r1R)#Uw~=mE-s0L4FzHCf=j0asg5HAA+4pK8YKlxm5<14L_2b9`FRU{f`R6X_He zu;M=%!Xfz6AvVpabQxQNY%Hld;VM7QrNUe=!KI?&NYxv`8cU6h<3A~csymZ~=Y;Wu zC2p~{|;UE@6Fxu;M~7u zFNGW)c|eOL0b(HF)b<&F6mK?X@e;5%0dL%Kcxb7VUgI%l!X;s944zp$4mqkZ9uwi= zmre;HcV&z(T@lW5DIv|PM%b;|rqnitYA^1hU^JazmytGhaSKt;Ce!SAf_sh+gHG5s zNcab2TSkH)zC@0{Ao>f``~}*zMIruRU-)r8omYPSN`PlK1 zqtvvahuRhr*AiQHjcY@PkNi6EL8WP6+%i{&0kB2C;%Tu~BNzSe~98j-gAzmn~vkB}z?n0q`FNFw)R!mb6Y#wU*qO8-GKJ}lYixt04EShYT%p_yW zjvPDbWRS*WY@F8AMyWb9XGVInjY=x}mUidf?NhwUJ#ap00!<1oM1QKR@*2wc_rMq4 z(k^N))dipXE&UGYW8rgZt|`CjdB9D&+g?i#(4@zKG*81c)!2x?lLthv6<3JA}6oWX3S*jOHwils& zG$S0caeN(`!)V=e^$SG-+@CiMUPie%NEhJRJH3ZHu3%VBXCj5eyPEDJHO{)fB3v(tB z4O366nzBIL%^GcW=jU>YS(wrChE5I9>3-|5D2TALpoN1>QHjSe6S~Zb1PI!qPEZf!2*~P}#S4`s?k% zg@Qg~%qEiB;B+E6lQ5?X`GG`A%bSCSo*Xo_-0UDogIewbO;ywRlmfn*g}gGO>3J=C zV)mW_-L1xC0hZzPd#BHyIR!n$zcZ^bTt5o72=E!(kXh(;2zlg{f#rcRv*|u)U3?*V zSH_pe%l(^D)Vf^j3YW#w32Uq>cdg#8$o;>#@bT4;u5Khgi*6o0Uy;vSBUNwD&o4Z> z`ta&{V&lMObg1GTvaZzR(CQ)#hi*BzdUw6Qge&q`#>?rW)>BKUq>X+?CMKlB*g#SUy=gwQ{}l* ze&ft@@6Zxk4MfX#N{81+EB+IoMJoRHV3_D$5tqf);Obndx18MIe+c^GlUhe{6JxH?<&E7noae;8g zbYfN`92{XSOZY~}6JK7tYbvva1x<&8r3?{gQccg>F-`c!$&X7ro2wtA*k7j^-OOED z2%``|RdR4*vbGh(nF1Cu=tbRmSoL?Hfw)X~s zH4s|sDfN8VYmL@~;4`83MR)A?-2;{GfsNE=_n9T>rPQ@DwmepLuYYG#8hFu^g!SmA zbfV5O(!18B*F56Ft2ZjngH_R66+QGH=w6GJV&yxVf&OYRvUaX?uAKPrTh(B!>hG=w zdaB|6zXaVJUymTkXN~G#1KX1DwG%z`beXrOA%|BCQ$(tx$f#`5p*o8!$j%OGr!2iI zVHP68F~yke_6#M|T}kDh@=uy^UF13t_lj_cw+^&DBHKo5YANy^XnS8h{|cI5wqx9m zR_dM5vYbI_IK^1VP>$G@T4xxGsxS#@L1=7Q(7Ht<7kFzN;1@fq5kIW3>ml&qFdnbSMu zN?hsu?3@XAreWWyHraxH&6vxYhC#1@HlH@g)NA|M#@9z-$aVOPN8k@Is))KGK%b#* zYrGcev98yGVe48g5Vl5NODGUolS^_<46SC?-+y-G^e->}YV3=#XW#w7Yt9+qt*bD^ z8(g`ye5)pQt&Tqpe;Run`+Occy}UI7ogV+n@#W(+vAZ05s(qS!ocr|%bb5H}B0PpT z?^3!PS&wcketzn?^ZmMnd;#l9E!Oww=EIvmyJcNj8V6SX=d9R;wHpMmL$+HXYxJc* zysuOH!-Wr6M literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_477598.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_477598.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bc8fbcc7c20e16c05ca33164abb6c4c39d9c73a GIT binary patch literal 5485 zcmdT{TWs6b89tOqiV{W1y7?kocHL#E)p&7}I7{88$y(>?+Rlm<4PqM;RFSr2>*A5L zU3(PDY%LOD50-%fR?q_RfB~}`1M&8<3>dHs*uyRllVW2LxC{``JS=Z&lfCR||Di5s zRV5yFdDsCY{{K1u?HvC6zyJJ^<7^1VVC>WB=iLZ>hMDT%D#T-&Kxhr|h$mubto$U# z2%R^@&0{2i1Y$MqP9P)_OGs%_ z9Lp3k`XTGRz|j+m<5tE0GAyTDic6tpVJ77UYb>?^Gj{yc*FC`5FqiTg;-bH5S#eAe za&3AB@4jwqnX8^tK7*~sQF1leW~Esn^$1%aLPn3US&yyNV5LF&DM)F>a!LzNY>~I{ zyMU2uQ`!_7pBg+Z*zPi%e9Dh!m%&P^U1?XWx)-nRrDGp2gwmmVaYO2$4^2FpG_`>r z#gBPBQ|7T3o-)J{FxYMIkthM)v_;oW!Ke6UjM=5{1-zIVeH=6JDc6r$))>#;TQ7wz zDjt15KE=Idsh%2vmQaux6>;i3!kHQ*G>Vs(=7gY0BjaIdR-=;gv>;AZErI|UelIEo z2@&T)kkw3bOta02$tf`$4@G&64vXQ$l%P?|f|!&v%VaDWmX96PNOa3d zwq_GzLR?75A$duoB$yR~of;=6XN3fql0tK`sM!{U=+v~V)0X6Xg^+~rg~dA2m&Y*kz`CG;ZM?N@Q;tohNDbQPD+BTxhm0@*$@^sha(Zqny5&CTB5AR zOvb`EU`eyqPC%m;!ZC>3a3)|$qZf5;{DgH$>V&YKd*=AZMBidkoR#Ln5utB79GMNv z)5*j)!c%bieUcdIlZE(PA8>s_{9S?Pqlu{y9F?3*gk}XXA;b>PErnpYIvIh34)vcn zaeUxqSkapHw94Y*Q3wgw6KNdsN<~i&Ag603o{nc<-*I%R=N~wl3|!IT$h>gR(pF#` zMb@2}$g|x)dE=)Se|&Kxyxp;L;8dPHrJgBpuFU!Lchz$RS8MjwyzAhGcbm-zfqGRP zDztWFm-4MWInwvgW?$(^_hg1Mr?%X;{hR(>TfaK^u*tnVPr-oUy&@9kAbikv@t1y*yxSMuB|>bVD;bLDdSa%OD(VBXiWvGDWo9q;zk=KFVj zcjZ64b$9VMo%y~?xyxg_+&BnBzz5R@GtJrXuI)K>uxR6SzGW@28o2KY=3T)Z+l%Vp z1CqXW;p&C!mov+Gu6zC1#yfu|`yay@Be56$ARF{AMQq129eBnJnW)jCAkT}52vqyp zZ~>*G*eUp5LkgO~Uw_@GLL^Vj_JBnXNL8w>QE*QxMW?8Q6Qml{{t1jlREk!Jr> znwgKrH40!0fQpi)!*fDej{^9MOlwp$fzh%In@7GWVqY2;l>j;ta&&4wIWK7x)M|N2 z9KcI-xk_u!gzz3r=ozA{(SH_>V6w7>`X}@mC_=~QVwT{(}t3&GR>SBTQuAEMv{zSMrdt>$< z8_bcx92+dKu9cJNlb@WuIdWrU>-g>0HebuN59HZ_9669<2VnU=+m&a#cG>P6*_~s% zix!*u{Ra-u%4~Wzdt&{3&fmLz?ia&98~$DMp8~%R>^d%AH$U{W-S-6Zp5Vp{JD#KJ z@B@eIql5R|J$ZM}hJVN1zy0oxV?aG$w0k~!XKj3Se7$9(cgJ@k$G)tdc}RLQZ|6yW zfpGzP6&SX_I17wh|My~|tLJ9wM(X4Dw_W$Tj$XIkXIk@2Yu2^PwBI2>V}(m|neYzN z`emA3v88R9BR}H6VvTHSDfqho-ql3Wr31+BRENZ4aMTg_)dp^>^uj)^7j~XVA>NcC zcyr3clPNRsq@h8@eP7*x&y`e6lg6gMK{(z1WkHfdk+39$;F7V&9MvHr`UI+_ar)RJR5#NfWW-b0k_E0;(4NcU zY0OcWOXw&e;s!=9z#0{aB@;r>itj#=Fx-Wr%Yw$$s&oiLkP~h&^Fksb;Cs}RJiP}7 znl+l3gDVWw`!Y;mZ~(u8G#`^C33oU(3@8JDL*G!%zBmCQG5AS0%P7Dgce6TF^tM2! z+}x&KDtdhCaM9&cho0Dxt9gxE<%*0qGqZj=_k91a&i?w%U%r|9_B&4~(nYHmK?KV8 zczV3ZxHBVLzT1J#z?}tHbkOQKSY(@44y6wjnU-u|OSm21jNd&6i!7}UKJg;TaXp%C zUhmjgzH@AsJYBNGH3TYo;JKSuZe00zTs?n%1d8{+7Db({dLv-lt8dGz4n1u0?YpV{ z&G}l@k%yf7-`l&re(zt2mqB)qI;vwxkQfbedVYz6_+>lGUB%%TT;=EClBh9NN@Vf6 zc_u1rR%l5iS$JG2xUb zfG}hVl0JpMl9?cg&yf8ui2Ez*C=!-y{;U3y=|$prsRO~aaUwlY3Sh1id33&pTDSdNoT>QH0Xwi4>dF$RMTu9V19FwtneGwg$Xz{~(K znpMY=tC%mAL<(#5!&r*s#UO?LL?nt-E5H5m3aee+xDrzHQT{f8N>#sl?(Boj8kh9- z(W~7%_ndR@dE9$mzZXRZf)Q38Mc!;c=+6YG3{QEn^EX&5AsNY39F3N~)F@@abb=XW zDWp&_i%;cf=_}I;8*}{-M>YNsd|=dOM|LE;WKKab(5-&R{0s0l#V!l* zx4(c+1C3Ga<8Ik;hZ}WN?Q-5hqfQzr&Z-8kTU0>YCW1$>@b=Yf7q3VoBiL11^$lO^ zyL2@q)aYf0H3R>h7Mk3cPE}|zfa`}$t(NI&;eefVb*J4B41SJ7wY=GS&Tzw&!LOEO z&H*{?FdPOq0hqM2Mi&44NrN&Rf|V=5fdF2iovew0*O<%}g&G(>~3q`D0k+%x($ zEu(R|jqqm(KSLx~f-hmMQ=S7fM24rNVI4hMCl~js;Cuqp~8+Vn=yJu#V((Lef(cN>Wpa zbd@*><|m`FbTf%L@&y^G@o`N`VxH&#U%Pbxq-$joY_Ul2ot4OD%*-euAXcCE)KIssBqqB7>Q)jr6yx)6AXku~6j*oq+|SPToC6AO*gI9)t{w(8h$m}4xKk7DHX?V!!u#{@XHRa5 zezX6X=-GoV6?x=t&P4L!;ZGYMwmoReoy@6WybPi*Qakk9DFdC3$Ax;9y^~G&zDXTo@RehK}b9@sFm~kL6}pJR8R!H-A0#xb=(f z{40Y`9Yf|tOVy7qWc19yyp)~Gi-BDK`spvd>t|Q4<;6FF6U%>i`S#_z?_|!*k7Z9S zUw?EVr{@M9Nb55j^4AwO!oU4h{-ypuumd}AJ4khtNrVZzW>PmLO%JL%aWw~ll%x!6 zmE`=84W`Ni0zkt?A27q z*X-Oc2yF&eNj3-U1i6)+<8<1gJ1ZI(4l5N=1~ORCoZL;_*h|s|djK8uG59rlwX7=z zJl{b9rWEc0P9-Hw^^%Fi&0bUnY?y;+RpJokk>I3)`BF4sPg40nmnKuQiVE>2^eieokmDMKUIP!qkXA7E^+hL`$N zfM0^Y_6bZTD!NglZ>epu?c;WHsNii~>R9Z!J!D=mXSZxkPi*baTK#`)J&|uckqd9O zp1JGV<&oX9;GB16hL(j*A@EFaThKsuY*To($T2S7?Avu9&X;*N&mJlW?tc`Y>ea+QE$eI*%_$UxxHN3BN~>438KFu{8-i@>1$>QG z|E*j5rgN`T^!8T}+=ZR5~4S@4{qrZg*;VJ;L;C6$1kRJ>3SDqwU@ z!D2PxNpduSJ(@B#tt3N=6jzdAJ+hZou|1lcoF)}hv#O|-WkJ58O~-XjBfWn$o0f8^ z+p@FDU+se}!|>PcgFTQ?1=QGL4sA8Jn!{TyZRWMDCZBnA%j+|Tc3sHZvLr5wTS9Xt zw*1zUqo=>R^t&rxTzT@#>$@E5<;_9Z;%-{FF@IxAXv_?+`9AZn_%~*N=;qCSAUYcs zUYmbyOK8pd*Obo^D~ZSbKy>ov#ocDax$j1^Ez2Fbxs6j#*|&-=gl+0UL66R;4Vn}9Ht^eQPelc^$imL zggT%>q=gR;-9A*L+o{VG+;r#B`^G(^C|aN!9XtN8_d)OX9s=$#h<24X_EYFcSK;XK z@5Bzulll2l?_zHe!LrcOz7$`K7a3S@bsQ;j@C705hl@6{5>RuSwbxD{2WsmmI?2j~ g{GB_5)y`!KQR4SPjG~&p8)J~izv(`F+iBVL4>50i9{>OV literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_490985.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_490985.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e38d114275513a5594b7b8988325fc25ef58816 GIT binary patch literal 5651 zcmdT|TWlN06`kc{xqR%B6h%suC0jBQCu}4u@-t52yevzyV@YxB*tMdXpt(zm6d&^L zO0twCum)1ZLIa3G1Bh-5NJc**)dsBOr%F*EN`D%pKiV>_g{@N%h(D@7C3aDuK!M&_ zaw*!OqNb0J4!HBUckZ1#cV^C=;qPeLh9Ld<+AoG~x)J&kH}pi8D)+|-gk}+oSR#Q2 ziX|~XXl+B%IA9_$&5$rB$pMl;91&-96`Lh%Sz5SiXPpOtyPW2yq zHQ1+8rAnXb7)xIY3Sx*CgE3aM;TD%-p&qVW2P^0|!gv21lOv2eMOEgwUXlJ@oH&_nZ@r_r{8SVth z#!5-7509T|kQ!hQqaPYr6fxi(FPX82Bp;??dWlMw{nk`rX*ItQQj_NEx<~ZZKdf;Q zb|P!8@b6AoDT_9?SE6wFu)L zbsrOG0tGC6Jl>$|`0T9XTEL*1k|AMOHK)ZPj*kLKO96t%s2i}T1`!VW2%nDfp(H-O zWQY%?qMT|fs=dS;7lUJ}Rd29mdRhb#y%RJxC7{9b_(<~LZN%fn#BdrM zf8fZ|N1izd9M!ccRqW1pfKk|E!pE>dWb}YRv~#+5syEk?cQne!?>d+*TLqUlXPI%z zC)OBGc1U4%%H0Lp_p_Qg|8@U-$9y#J->cAj<*qfl_C0=fVrF9AJ>QeBYg6bpxvSvx z=KS;L6@RPZY?FJ|Xiqk(&`rQ#ygBo{L#c07nAU}{Wrxz#_gz=btYyZMv&;`I zg+6yZBcHej6vvDscVWSicO8^Z+;w|q8)q7Gv4w`E(BjU#`x*HpmZD}T7;8}RcPt%Q zj^$lF(6vmCFLm5VDgJ{?lPgE^u3ngp%bz=?xSo_x6qvf75V-?6i^4R|Kc_HzmkupI z_8VgPi4R^}xuERtQ|NxV>n`J(j!s3hk=)o!dZFjT)61r%^GfT{6>>%V?DWd`Cr+ic zfA##p=ggbnAMn($!nDjs7T!|!Jo~Gcm&N67xs}e(T2{J03Mzq9e>C;phjm1} z4`+2atA|V_c}^VVQ(=5c^i%^CE$+SxfG~?B6xTTwT@N<6wn=opqFBvZyD+I7QNgQ(vPQMJ3CMCuW}9G$G+L7Mu8 zv`sQsj97VlDVYSbv+84Q)02{Hde4Y4V;4R8tR%Y@y^sX6h1Kxt^v#kj-&D5Xn{>;; z&)+Uzoxdq>+D5)Qe>-_|**=RzNmj`QQOyXks^n1uNv1P6S}ht>Gb>Jva4N}iW3e!| z{t|dWi(x$81w0Or0rK0SKVXNTG!z0QW(ZB;A=M19DNgW*u!6&osm5e1rIJPHRGlgA ziWnS8k8?agb#R|*3bU~>4I5Rm7?c8zii1@SF~uWByrL3ksu`r3q7eB6Y;h&-^G{<7 ztYs|b;f8}+;S-!12(Q^|rjJe?z1k_iEMH!0Xr61m-l~}DRaK;NsxQp5L~8()I`2YpdLleC!_M&hCCMa>v)I_*xfE=Y8#Y=4rrC zXH9nBdt-M1gzMTCqWQW*dFNrdr{JL99h-h@>aAQyE}D1lxkcW#FWQ&7m%pEXvTN1e zEqC5C)n?BsriL}jnf2v7P)svZO!jh~^5+iEy>R`7g_it|)-{Tmrl;s^C{NYr_RO_i zZ(AVp4c}X%?9;X>TejmzbXJ(1oSB?&`Eh1G{C@mKe98Uut zA;dDH8g9Y;j=Why&5XH1V>J|mlUt#!3ZiM+O6Zs&E3{R+`x?HJHp?dJXjqG9XVEPu zPz(V|TjI|0WXl!(wI@-m<$0i1SoPoXS{bXWS+=cJ6t;@B!kXItCu{njS!dU*b7Y7L zzx|)>jcr$e#&eV?KxU>d_7^yBF$7y+bPu&fd`1zp5*wnpuwhU4kywn zE?~ufGK53$r$cO#Q|U6c2H99rb;4DCluL!VV1i3U#i5Eff;E;J8O4862o-lG3(pCo z2~iO6Gop;#Mc8JvS(f$l7of{)@CkDoc5hIq#wVXF)Yi$p1)pE;FL=E2se;=ppM2mz zZr?0DLl>ypY<&L3)!hevef;BBKYDfbwQ~>5CO0XcfG&(@`pu~~3sg6dqbCrPp|jnQ%#%7=dRNk3)`XjK@TH z_@z^V$XyxXOIL)mTuMmusu6apwkfquq1ubPC>TvA*hQp`UED&{^T{+jn&4jG!=Mwk z4HEtV*@lrIh%b@jPl)~kHGYA1Y>>!6z0+{DVZ+c&^bxSNc0cxhp@;h&PBkuU}MNzgX^(B+|s4#1$oRz%dlp)-eg#8VN}U;I-2 zluGMH_5B743Dk(HM-wf6d7WSwv1}Y86t92DxS{&^Aj8u)4ENR*^&5FcKqK}ldW|m` zXRD8cUegUjzo`@%cLVivIwWuCBh%VAPhkVL+#_1x=|XAhQVOBC9oJyv zG|a;GX&8??%5*d8)C~=`!=HuByti&SUW;q7;R;Z3XPGSVxZ=SGJBVc2tST13Pm89T z%*IZhn%8N5$6aO8)slq5E=_vF+$xEKX>3B4#Yb3kkB{Uf+(rRozOyP!#d93Ci-c%t0 zTH>3_JgSy4o3W2Kl54+f6xh8~%hGj(`8|h~5C$QWf$FHxZ14`JxP03^*_*i;7Vgy&RYjLJ9{u#vmLRl$yZRt*?8&A85H2 z6R${P!I02091L9v%EPhfuHcXmm0Kh+)FKO`V=X|nfC+);!_lDttfw4{2CfKVREWGi zHWdK2HW-S*Qt#cpuVepS;3)PNQpM?FD|lssS9%KBkU-D3BeONNYi3u5TH-v3&gTqr zHe{$wYi8(a?T#cwOgH4w~4rtoqPh;j~5zzpXfd{eQe4#SIUl8v6;`7U2gg*-} zQj7bS4)^7E^*y@4=Xdh?dOrCsc;#)%I+w4}5$e_uhhKY1^TE z^0r;q?u)3^Ia;}o-^isyB9hhZ2u$mp!p9+^1fqf4xHG$JsY~)luTvK-c-J>}<*PY+i zzJT-FI`h8Hhl7iKe>vB;`2Lr%yzl2A3ERNVvKdeI{e^d)*bjfxh`f8hb)ve~Z+ytv zyaxM=#9Gogp-~u7kjKSnh+u;DIH03m&kK+Ta|ojmjZ^j=AV8{%>oN6Gt|3kXw_NkO ze4K%Ca%L-@xKY;XDl}3JKy_G;=|MdQG7#5e|7m%6xf70Y@TQCtw%v2 z%$`GjWA7cELi6&}7~nK7OoT(i%1Kwkl!ev?8=!@)5@+(%}2K%>wA zftdr>yAnqeR|{rG>fp@5 zb;9aQU7ER+I8m_FrcTeCP8@&6)ui5rCUF;0yAZu zv8IDdYy-@CFl>t`NMXhnx(uSZl-3!Y>Xgjv(hL zULV)-hPWPRLs@Z9<ssL>fU&nXl}aTXg#N+(>K3#Gm$#q!33&@c1qzqM+w^z5}> zvBWFh1bfQ;Pxkab^Ukb#XNgl~>-K+^8?W5~D~VA>Q(JHN?{MD22vTDu0i!D7Q~pOLZ3p1BC)6dn9WVZLcz-cDLgGG))GDj`0%JwBMDc>g=k0!M1<&& zJgitIVlb0byQG!}$;$yD%9Fk+93{0WK(AUBsL%}l2-0{&mLzg%l;F4s#WpqSl0ESP zOqqaJx((5(cwAeT=vj8vC(bO_H73q3JKcby4tJvGxdl1u=B%^UW!9A*$sT$1M*HW- zzc}^TsYgFM|C}~Bn8Zn#VsoZ0&Rkq(Ytv`u-S>TWd=DmoX=4&6fN8c9_+Dn~GroD@ z{^*_2hbMq(W}rrOA=;J)v-1LJzj}1pwEeHU;d)c@3_6N5<_}#@sck&D%_z&}< zLRX~{(Ct#gawU46+1;-V>hLx;YVHZ=8naHY-P@9p{J-2gq=eJvcTGJz)r6@B=ZKG-mCR@ro<6YH5f7#+7 sSNhvL!HQq8T6(tOqSdv4dt_p$0 zCTrS)PH|?|41dcOd@D-~*;v~x!+@o{O8YG|VAUbcTGng3PWT)Tu&zGUbgH-ibWfm9 zrHYL{)mCVAUFh#R(;ISBz{94|!~fy$b=Z(1n6pqup})DX#8QzTckD1yS03q|pHcLjZfs%cxsB;?FahK$h$T6VA-I5{hsnADOo$ZlqOL!$vy3p%i!aQ1{D`r*K zX?sj3BQcUk)9IBQk{+X$7!6HLg(Rw=rLvZgRD;J!UY5A4FX4#SO0|-CY!K+Rk_F41 z2K&Y9Dr8l0O{q?*1xpv#60esCsUEf-{ZPlEAzeMzV$_hiFcs4aR91(n7=?zop+ZYF z#0`M;ajEXdp<*?=Bpq1ARItjXxl~EFvi0&FomLbrO z9~sbf^KjwZ~AnqVLT{|sm7Q%!tujE(u$AZu@QVGRVPMN%LE@A=7ZybFsqV5 zJ{TS5RAb>`h{S@dYP<%MfhLGCp7S-SW-&I#MPaz0lGnKK@QA4DW0Rsv2_he6xxkfj zD^hMvm0PT8E{=;`*=R;o+f?jIKoEnxh=EvekBhR<=JDq9*kgIvQly$E!ffD5Pz;Tz zCfsDh<0=`7MT7w0$IFGMm5*H$0@1+GWF(@R3R0sgS)dk2MO8~|Xh`6&K0(#v%A!r5 z3Pu0{B?N~!jb-A-uLLGUUZrpYxPoee^%*SN=KZzZT=igo$8?^N$2WV?l8w z7JV)_%tb}Nz=!-IH$LHqo*#sBEE|pv2f&PCEE*W&_$U`SG;uuu>=H2)0}DUj+J59n z2XIu^#;n3{z8MhVLqQP01!VMSFLKlK_Bnehm~riqU;5PHN{%PSQ=J*dZuwl!&df@R zec!LH{5t%L@G`M{JhT5r#r~pvHfyJ6M-=;R`E1V9AoqMqd*?gmI#Qjf8w=f^(9h49 zvhKQfTkjmbeKdV>fy(S^QQR#vWY%4qI;gmxo*{GI-RU~T>svUhcw1*ISnY6Pcs87V zR;g)TKD_Gw*l~JBpR1`)4JtLhh0yYO<#f{uwIgsY`YL^5TLa z{i@RV+%ly!cB~v;edKs35AsIH5(e&GvFll6J0t3+bBWMeGt<8O46FsYqlwM=-Pc zxP>*ut=n4RB5=K%gh02n8*v0Hul*`t0_&8b{I>ILEKz~Du8Jg5GM6L`>Ed?Q7JF5*oAp) zCcXwL)m}_OSO~0>oPftw3*-DjJlI^0vV1r6{qPrzaGuF1??5(h@?_%VyOVcr+`e)5 z%)Oq)p0(W_inU|S(6MIi$XeaW6WSySm^~Cib-ADz~ zPu}L%o%KQ?zPbu|qX|C;kVGlne907t7BFU)1U33HBVn(v$IOShzh z8Ro!!rcGhmGR*UjNlpWvr|#2!h4yFYmMw)h-lz8|^uBawhTf0$G|r!xJCWX-p`V5x z(>UKh*AF~qKh|8?%Q?Ntcp{#_<1!V(3vdd z$+~K@?)t3v$$vCCjU@jZxI+m3<$_-=EuN1+4Ht4Ej#yosfOB03=ei#HhKl^Z9M{Y9 z5xlhqNjF4@Wm2`6&l=&aF6I$&V}-;j835eLDy30n!R3Y~t(aNF$qH#zG6FbPzKJxG zSk4-w!t;EaIesA6DFP$CL^LCm_d6veiee2LtA56K0gRU@PQ)S(JqF03VCYe^Eswg zz5v%Sxj#pH<=&jzE1!R4L+%=UXUtK|>}dMr+TQj*o%wUm`#ozfUwULTxJmgX7~-Il zuO(i~QJ&cgcfI$T7MoV4faxIRv%s{vk_Qt9b5w1r=`MF~d~tmBEHJI4eEJcCjEb1LFC??;I$k|C12s9LX3xx8{nWgbJ3gtzK&>K6Bv#~*hQr2D!2@Fd_2ZZ zMz|9^2ZO;)knlUG@_K?GK0~&@BKzM^V~#Mr({QsPuX~#4Bl2|!ZuWzT!F;{eZGf#z zMiP$;6mO!*ZgG@?mpBbovk?UWuk*1vrZ-{C8WF_+Bus$*_-Ii8=;*i5nYov z;TDD5Ox}!J7G!ZI8xxIrEAH74X?bXe*27aeMA*NiMhSxcVo;BqP4H#ord3 B_XGd{ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_533885.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_533885.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78cb61d5f93dde9e25e264b048d5a2717a36202c GIT binary patch literal 5358 zcmdT|T}&L;6~41OJ2T7d%(Ae6AwZ0iIN63D**LcI=fp8Km>3+(u75DecD*y~5Bq~> z25iWzG_DdQw2@fUNUYmRY}AL~+LB3LN}@`^eMyzR>~!54OhpOk! z{;(!XNdF#s#m>Fw+;h%7=iEErxx?=`t_nf>$J@UQ{-_$EFELSU_QK$PJAu#~5|BVd z&``c7h6t0kM6E+4fkYx~+9|5$d%+@-VaLNfObOOa_~a&hTV?(TBT#RV*9!{_(E=yX zA__ws8=Jsv!nbe2ud2+471_6JL-tZG9B-kaDhm>;N?x#&goj_n4h$Mqr;i^x_=@-B z(V@cz-QyiN-gEevLBDeB_@P(5g9cL=4Ep&pe{t5W@Voz43(mR%xzD#S&>NdTTrbQg z_$?$YB}wNNlGc)BwdzzIQ#OPWF11?a%u$WXt5!TxT`GzDDR3%LTV|8o=q1P^wHC{j zEW69(sS5IL)p5NLZ^Es*%A}cZN(;ogwCPQqTC=4FSF~?x)T?z6j~kCiVckT7TCcL^ zXj{>DkXVnmvCMv@cpKFQfw)0l$2$ov(WExPnL%$`1TOQ`4DVno;1V$(dOE9Duern^-Y`gs=YX~?wgsXmNKs@yz#0< zZ80r0e;2=QP{sTZXcBqoK?!g9Dn8Mxc;vAdWYi~*8n(C+6r}*L%;q5^ytaojEK036ar#h0|-IF`!C@GB0;8+gpaym z8w&}bWOj1C6Tq~1(JL!HNik@lFknoFe6nb86TS#2VVQ^+GpP8;hzv{Rd%<9Emhh}W znIlL6Rvl9$!>+_f#TbN!WAw zKk8G0@mPm1AjXs~S@L%&Vsxwvs4lQA3PLCr@Ip?Nc+5L0N-;6gIX3Bq=_1n~hjc&F z-MxSRGcaRRuWRMSr4De3hMznOr$Ivx9Efwy3``HC+Om#j?eJX(zbTcgu1nFgPVMLl zUz-f-e5=-*bGnnN?tCmAT5Q*M9?+c!w7$FU`nlfO-c)Zos5f_J-Cf!MX7$hZrxFW+ ztb4CEkmKrqUNhf#wK2Ub9mqCz>RhLGWQB9TFU?KPPNrSyzHCE>&UI)41`Vg!;fK@ zOIQ=6N^Cl7OHiss>rzU#!<9hSs*)--Vg);uN(vFKP&#f9MF@%!swH6&ED2g+i~b1~ zg)5H3b-LxX(KMbEa7 zq*SITIbu!NmAaxS)ozw6q(XimwuT~qy%$OshE3+lVwNE8Z`ZFpzpiiDLcj9-R{FM* ze>Rm;X_bM(RM>3-sg{$tG+nb8HbI#j6AelbCqjO4^<~I_S+FF0t|WYK^N^JQ-{XJ+ z6!d|r4a$%|XxIQOl}TwgR_lg`VTFQaP+2-49Yldviu&p@<_ujK#`_z#P+pG25Bi+eUI@z?sIFt3MTIgo@X+C)|%QU9;%^$dWV4*Gh zNc#%I&v4UR(wAknrFPDDT8j+eA9G20ZfbTa-S(42+W%qrgYcs3 z7txj4mX+!T{O4-?o6Bya*4)V9)cU1PNShyirG~4R34uX)!dDcLKmam?UQ~E5QzX zTJ=BK)Bnsn*7VMiAj^h=o?dwzZRvsXWPHS^l( z2R710X-7eXubnwReLlz3BnNKP-E6tmvNQpcJf$6hNwyl7i5%0AYPliaj9!Z_AAw1h z(hfavBOCv2C{>?sUYJ_yzD>Tg=0HxD)|+eDK7a1&xu3z6_T2%%^6%!w23s4AfNi%K zmRsw;S6#PdP+Oa;+olcN<7)oB*c}^V|AxN|a@1;rW^M5hgC2)@?xbh&;dAD14V@AA z&>e^GF@q_NB>bN@!x|1Lh8?aCvLap{lM45SQ({bxOXiU@kAT7E&x|(^j|kV0xl;bX zFfT;o!gxeHApHQ$!68BNB0Sct1VMa>9DhRGpHcG~g)GcFO;?)MEG@(!0pyOyKWzS> zd5tcR&p!W|;}b`LWYD8+^O38OHM=>>pncE&D)Ld}Ym>ZhMHWZi%vUcHjf68fKBvyA zYY2u5-^UjVbGU5#+;Pic4l`$8zjId5&kFak%>5n$5!}~Im>_Du8nz;5OO|iF!kQ8N E6JJh%WdHyG literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_552958.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_552958.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fc74db3b27f8a85441939a51435b76a12ca6b34 GIT binary patch literal 5822 zcmdT|TWs6b89tOqiQ+}FZ21yj;wCGaD0Sn;ahe6~+NE|bP3>gux}tGC2#QGCl4Z#u zX(#r`f!R7FS{_6f1FT>Jq7DPz#{+K`Op8K^YeV?_zyhKAZU5%t%hy)}MX*5!O z6C;Ghd#1c2B!NUCW%;qYR^Nh0BvZB*hf46e%g0E8e2u(WvmEgYl!#Idj*c}_@I658 z7XtzfzrX|ZjJte{6WG_t5vHCY_Zk{uJxFBhv2b%llwZT+pEmu^4xJfzJ~m_%gNcwM z#sJV>_^p58!Lz2o+*Ml&_^nMKZr1c#Xd6q9ZK?9}6hd0VHum1SebBNGtIi(e)<_Vn zuFBu6#W)XYtx*j-#&GWu+|;8{U5ltqDrZs6DpEtL7e|N3gTACq zYb}nE=`Z_JThxd^+#+w{nTD2Pt8eF1sWWahXwQz;sUt)3;vS?LVoZUREtF{B8+~iSh-s}NesXHv&_H8n5Q`-e< zBV(;Y^}y`;?FiJyEhEu&&DN*`wnpo8j2Wy?SDjEh)sC;9!7ZP?8U&g{QU7TP7tFNuz%8rf6 z;Hdfr9#0D9t+!y(IOb_Zl5tuWA+0J?RnJ;;Ws>65IoYIfMA%a9g3%uOs`kZb?}e;1 zDNn}}V(&yeF&S4TvYE%?<6=hXm8C?lB2G>B0@n+kML|er#$zypN;VUl6s3%q-akDT z1GOzoWMOg-9qK!B=$oK18#a1X&85d6BD@jEx1a!YbpK(*hw{U@;li#Ju2VmDj|HJ;hz{gB~iRu6Ki@$m0 zy;oM4r}cqVhBvqp(`ivZjT@0Gyf%ihi*%q!R5{yF{rIIPW7KYQfRPPIQ^Oh?}C zeJPx`N;&LYV;8*E{P-ZVub3aeJy=D}uTmPVdR32wo&hxgxw*w{`J8O8uMA4z!rIbQ zpX#^I5*mUTcz}&c+bE-vf=^?VX4_G!#ilmiV2$wyn#QU=hizL`>U^qxm9Zn(#*$W9 z+j7jSaY~2nN9C+}MsRK*VGi-Oyrl^n=Q4Nw#yOAJroX#>qd&P#e|P8m1^8On$k z0T{CvL@AaXi}jggLP*Y9kYxI*;2-6;Y;jUA_^Iw~OfoYjN|GqV22J8wlNgeAVjW)0 zjZjT;91a*V4rmkZNmTtmdUTm%yz7XTNgVQIt%fIZvE7%pSe7DY3|2AFyeRo zt@naW`N`bma(iE4;QDjdo&%ydUuy5W{m2JB@ArH(`$z5f+TGyyb6x-)V^@yTkFHjktE_RpM7LqS(Rvu8 zC3+W7ENA!@xdpE9!b0c{+r3P7FS7tse130kuW`VLU)^tEsLd9_D@=zzP!2ZeC+-C! z9|s>U1s^WHuoB#(pTOwK+@pheHpd!AjSDMu_hzo*Cpn)!h%v3%c&S9TLvX8f(0a0Y zCdU{DuJSnGIvLFKIo^n`&~0GnpdBXeWMq{N;J=1e`1@OT1LdpihUL>wN&8_%6Y#4y zT<;hz)+a3R69^3n9*q#ZnnxftFYu(}uw4hidVd5PDXDtK2*1t8%H;JaA^U5#>h zbb(8_UOBzF0VeiV8%6Uuyt!=S@VWmh-uy}(KvW7gcbuKyj+{VA1kmA6IWNw&sjxv< zr3L>}Aag|Z- z{|et*5=YkdLt7;MQCd2T zJ^A4>3il$V21}UBnEo-ilv;Puc~Q#B)^2Z7iF7t2Mg#cRoQT8yFgY)pd>s;EfMjsz z6wl9ynS>Zii9fh@9zKY zu@9bq|M}(bynLS`!#@2vFaeNxHTPz4S=)QzbRkAsr+=>zv8hzedw7Fvp(OY^r6-X)*Gx1q3pqTKb+^|RN`{`^(_#Fb&# zz`tG;TYPPG0=^?wS`q!!rwz^9Cbhl0<~Du!Q@-)v%N@{ZbM9Y?mquKZe%eA3IE~;E zlfDHHGgLjr>`%l0K4##8Zqha-S;g&5B^5J}$Yx|kyf`h@jx{fd896J#zYL&2@sYzS z0+X#)O>8`y7H%MGtZK>7PfcZonY8${bPkN6Opts79u+S^5T79K&xrpE>MRrf%k7ui zE1t)R(?q2c1sn3CxzS3O#qO-JpLb(+-HSY2%}^i?5~0;VFyEQ$tayPgbK(3@Zm2>5 z4e;L7Qt@MoM&U>$fGGwsApXDl7YRi0pVKLVX!>l_i$Yy1!JU^_ HE0uo&INrC` literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_574109.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_574109.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fb76e01de297358989e75863ef9a0d8a826b680 GIT binary patch literal 4914 zcmdT{U2GHC6~1GS?Xkx{acn2#4<=of;3Xs^!17aeX%|8uB%$rD7LlYi#xsfS_$Qqi zCyd|TJ@nZLB*QV%F^;sUMlYhrIBW>YuJ}*Gzjh<^39e|xQ4u!&3c_8)B9@4uvGSW3 zBhC#1T;yN9$UCvj&cg+CSfhYm)}nc$0ikQH)uvh_$mMhoTK%v}H{%X*Y9L z)&hU(C3Cw7 zW(4a@@HfJL;|Bzu9TA@85;bU2_019wD*Ciz8%u+3>69Fjae4wFp14+vPgZd?NkU+2 znFpWdp-D$@NiNBR5#Vi}2EXRPi*y5Sorm?3DBL6MmYk9uOEsREHRS4>B_#WUN^E$P zbfeTLkr*}Yk=!6#wei#58l-yO5|S4qNFJ#Hd&C-7$Om{Xp`?8fi47wGBCR7f)gV%3 z%i5=8keXCPpX6CFX;=ihhWzGH9y|R4&;OLx$8S)`cu<&DOi3}!@qU8G0sBpgL5wN3 z89q712jhVVtB|o|kX6hMmjDAHFeCDcZH|jfg+*1iBvTbZ z;JzJ*aS6r32I5JUiz)Wmc{I%|pdB8;s%MpaJj$$;nBD1PehigI{*WBE5u=IvjR2MUd?3lsT9|MI20`%QVMNH^vt@^t5q-}%Y8ADvqcuC}l5emhUU zEuStp+_}Nbx8yTLciRHHbSb~HC-3f&ho0H&*{)1i?u~_cfLS#^@_pp|IP}CeB%dle zeG7En`Pve@Jf7ckH19kLZj)Z2pd5$^PHIC6CpS(dchQ`($ddh*WX1Fit4JWFo9{nBoTFW_&}S}7U+r3&6XnS z$2nSuBBuAk`WS}LcL~6Lb?-WdxCXmUGMO~YR@J zWSKP7_Y+#{q@ooF=w`9>Nn37C2woO$VO_D4UrP1o7A#D{~bd42@ zK<$J?VsC-ujIb(gn!fthVrLC_1$Z}DbyT&Q#2PiP5(JAQf2GY!^m=-d$Zh)9*H`;f zK>C$+B#=bD3-lW_z{a#;!WpmJftgJ53>OTA8TEMVkFZQ6!CcnQ z%4H@c0M{`6`%g3RATZ8V=5qP;T<)={(7Gve95DD$xF?bj5B4)cI5@-YeRtsNf#HGC zZa*^@5yK47VTu#*5aYGv-OW_V+pnUQ$H1;1XUZ~k2nDZnGZps#Cd2F%Vv!KX1O*`# z=U5n=QIi1^6afxwnFD34`v%3tiq~d1plEIuy6|`>t7!}}6+Yd57vv)BFW(I{#iYie znAAWN%OoT%vG2`_aIjx-M1(*nnGhpWsboq}OhC=zHGV(tYA%zq;!1E=A#Ww;I37|& zpdWGs8<|yES241YxMD76Ais^@gPrNZ51o&A*mj7=X9lM=2qhpy<~bg|C{Z~SZHMyyYF@9$!2*dXDQI$ z?8(f@54pwZd()3;{~GCEqx}Wiojs8`@!`PY$i0!3!w=uO|JGXDu{?cjjXbtS9|QGM zdRLy_^@Q$RBRkjV&Z5O8f9ILglby~?FC1AKTx;!KJ@fP7kA{EK^!twAc06%jyls5$ zX?g1L=RN-Au60kpJY2MUzW@H+@jK&7&CA{Ez9Vb&QTg<9(wloPPqr2)H>8{bMHeWS z`tug3#=@?>i|Kpm53a9vKHhckw)H8sBTwyEa6h5i9%1vWBjd;g*Qp(!m+7o6W6Sk@ z&jE0{=nlwF1RVAxI~D<*OMtRT;*}a8cEE+{ROzX%XH5MJ89|aQ@{=#CJ!zt9*1mDr#W_hI8)FNLfdVIj4Zl64~X-Dp+yN)}K zBIV6RmrkyI>Ci6+es%sA=hwdW{-%j^o8@z00^~fN881?exses$!;bqMk7hyXG|Oi| zNjGHoW%d=R=7o+G?qU3X{NppAq|NfFO)oMzZ$}oImfDx+A02!`o-Enn!XOV8JD9~w z_bz=fE)U)w0mlC4yx4-Z6$G%oYFJ)*=y`*0+n~0G^X-sFo;w=BWvR`{9_MDFSgU;aWczsDtn zB(E6p#>B@$-3JO?-Z+7&WQ@I!)ZxmzhB^T^;#7=#o4*Lguv3umLwJ;o1VMa)?0-a# zKcV)L85yWstv6arhEAft)QUX5yWMxXKWl^9h7lR;<@Rq7sHNjmdLQ8`xV%{@BbAIG z7wzusaAvq<0vUijO(hGKDCG8*tXQ%kn=9L%X)ocPuz+Uz*2Im84F}XV!UQ55pHWeQ S@O(O9M6M3FV&0%t|NjAFr7k@H literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_58716.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_58716.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e518920c2fbbc1fe56ac1b934267c781a4e0fc5d GIT binary patch literal 5860 zcmdT|OKcm*8J;C~mrL$0R}%Gp%XVWkaUw`|V!L+hBvx$MmYo_g+(uTx(%h9qiXxTe zN{WTa)1KqF>MbOt)L~gE=nLiR5vAdFTJ$?EV&d# z(^2|5bRf?DuleWy=k?9#ccSP*P%gwi7-#Ab`jjwM5sJhE8->sfBqNy$qv67r8m3I% zHfbNGDWp&#(~s4(@Re-}9kRTXsbqU~`6w;ZZ_>*}%VCGiC@54@*04y*>|^8(#VK>} zb3R7zsxBWDWd2Qh*j36xhh9E$);DAk1A*GI7#X1L z@O$vFjf|QGV^=64;46zvEf@9i+8r!yCCj=UEbS#r53a+~JVaF1g074K!MtK=xftC% zsN?n6gKNH2yP9q78gM-pCcsm?5jSAbq?$0owb)L)nZBFLVqqVR1vi5a+Oo!tGPOc4 zlktQSau&YjwN_%Pt-#VqwBQ!(nuj$S2MXpH{lF%pQCkbN1_>2!#Z0^nv+;J^3YKl9 z{u_ubu`bx+c1siQ06)~k2tr`*CYjE|OvXDQ;%(k>({>NU3U`{eT|10K2kw%Yt?c97 z*ajo%C`HhNdvJY4Z=9vYm61CwO?+2bzg2c8-i5ny*9vQmI)&zt*KuAY3uJ?=)(t{= zZ3aE*k4+d%R2x^+F%X%>OQ{?XUWZ}R!iH-~jgG1QNncPlSikCzj41{kj{0TLMhzwy z(GK(*OiYWaig&l+n9@|Rj6pLXW|iRBxKCZp zm+Zk)S!=?Dx9PxgCVIse)BLJNqNZLns3C(IFsMr^S%xavKL#@ul);UYe8Rpf;fVqX zG9cOL=x7Yw7QT=;87IP@%wT7IR|qG}`qbzwbWPGETp${@IKG$} z3=A|d9nF{nZP3c(m6*Yqc@ZxrMa*KmA=2)h%ik~e&PLUV*pxq@^p5)j6Mk(x8hP43 zrbM*fm>THSl*y@H;Ci7aMV5n+F&~7eMI*imMU5!oXQt+Sptgj8C`{6`hY$5V2MVKR zt5LyP?T4_)#*h677L$%1?n9!KJeN3^?pYVw^_T9|)ZGxTi)sIQO}Bn}JC_$5(@j~i z`@R16kGy;2)`eBa`tHM7@vwd}_gGC{^rZZWVPaAw@={~EA=931c{(dStqm&kYN>AJ#oL2RgUjukcEmXIyp-CXrf+dDD}Du>@U`#sq%NlRCXS~sWd=Uz$(&mBW%>QLX0rUV`oR5SxJNS= zS3RE)Nku+6sVT>~-+C>zXJIxWrT1sKC!n{Qy*Y31D*a3Us$;3^&atfb6cG+m`rv(< zy*7AtaN%1iEwkqj^xg+BH^{J)eaIG8o!UkuF0wC_h)i~cooq-}RkUd}5+I0Y6%P<5 z`nAG7xq&bWS(>u1Mh+0OgsPxXnYIc(jM`>U%ytHSkHS=loHP{$&g6Ztwk(UPw-WQz zXD}TSD;fi9A--%Ue$-=H5qkiA+>ct1kqoQ!@iPiD*QsBy5aXkb5J;=sYRXP}I?@l( zRq+d_J=6M_KFjj1-j*SGW2VfmYV5Mz93K`kd+hkCQuVxL6Hz0SY>PX3+at zw#>G9%f_Ox1G}uacd+CzZ&{Aow zeRcg-dwPfV)%82IFWqP1t^zZSIhZXM=2slc#XB+ubN&kQa_3bCq^*)2Zr&`GYz8B1 zb5n}J%F0YIpltjD%Et4c_u96Q*DI+dG8t+osk%tDbyc9p0M29Nq}c|=Y9B}pX9*qE zA>h9Xzu0e0Slo0YK}sG;9QkR(&9=q1_muZ1-kn%$?$7f5Yjpn_-w!a6JeoL~?zzji ztB(GC zSC@!qTqS9qg~*7&ZxyGB2C&t3mV_nxeK zPv*k9d#`>fFNi;!NxqSIBi)xCTbG_#VQ=$G{OSv zU;|Z14thEDIPv6whb}ymv?4qaE@C)F;oWOKcjpx~8Z*y(g9(JA5yk5ym+OEZo}IyY z#SlxN;3J48#gx~kl}JGGg_X#dHogT#hBFwMf)^~{O97-zXmY4ZY&xvPV&uOiCCn+{ zjoZwzaIedE$YhOTAx(_jKHb2)Jt&hpf`m4@3LOC5J+K45#wG6rW-Tux90C0_=G;bDIlp9&i0P(o3Adu`jgB0 zsfBZJf`7RvDtuKs0pImzTJ`#w`!x+aCbhG>hGzZTeX;J}%U!6<{gHS%BzW}mCX&FN zge)=jBvgh}xXU~fhX1Qf!+YJ}EJ`&Ow>=cp3}+x3iD}B~Q)=;Yb5V)JqAEP_V1bfL zhq(w0zOZV1W6`j@gv`DQONKi(8I`BQ%GcE(7{f9_v7bY=X{RXaQzZNeiGN1zdCGCE z^=j*;t&iHb*^S&a$&tj!W)Io`8on(0QRYIg}XM zWPpZuv8Qp>R5Mo JUgga+{t0L}+IIi| literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_600998.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_600998.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a624365fe67673ff967d99ccc7c926a134c85af2 GIT binary patch literal 5424 zcmdT|TWlN06`duQK!K+1N4b>|6I&N05PwvEa%7-Dzj|lM zrD%nY(zN~P0NlBcnLGE+-aU7Qf3(>w2->yrPe%S$h0xcS(VDqPY&=IGG=~_(5D_$3 z=!ro><#kd0AW0yW2&;CqsfC`=v1Hi#s0?$A{xNd$F>*s?`7p&8uaeh`3k;eV8)ITo z7~pjam20soMX4^6QO zJrZIVHdbIvFZI3pa<6Ydp^C&nz*gq731~h1Hb{8RAp!Qk(84;6+UyzPdQqRSZ8K6= z8fh18A~|J1DB%$8qE)4;%4Uo{g;C($f%0q8`HSX1UxrA#7VR&h_TR|%)+6st!;lBg}4SveX9o&n=aF^RPd zv830$NVp#H&mvki%}I3OCah2=)|IVM**oha-qnlsqDA%2ExI6w(EBgRs67l>7HNa%5S?4B8?=?b<&#m< zBpSb$PfcY$*&);|O(eud)u#rrY004NF@Zk#1qpfZfssU$$ZH(n@M~VfneNpoWYo`( zDu%c)!g4{)o|oV-WO*q?Cqxv>7#9z6{-`g+C}bq=XFwWP4566N)~Xmde=NuE&A4yQhc-xvpJfH)&O6vJ4E@eLKa7&tRL%=?BkA_fG{Uh+j)ObC8X@Wm8L z9Rr>z6y721qhf*wjz}BDC`NS@^y(N!p;f-%nL?|=B15s!Ar@eZDO!dgi2o zd}sUT`R2=UZj>MM2iWEje_+%vjKpL6{6RJ*H1k}bSzx1M&A>Gm77E3JKGAws-zk z_dDtfYSsP<40F0u)v2!Z*}KktNmI`0x%S4wJM-_PFJ<ugcG>f6W_bC`XD!QT7vr+$ShDwC{f@cIA6`yPWqRdZ?JJ|% z*U#Lke=FIYckfC+^GQpF$+XLR+LlM+G+nS;wuenu?sUF?+WJCfaam%FFCXV0eIm0f!??iK1Q>xpli$ld;c zMz#Al;Qb)QtbRX*#4KbRujp}p897!M=VAf;9caw~1pY!cehDZrheQNM+nji2csoI1Kb43=n`hZsfEty1(!yL>ojV!7b&ffYb7pW5e;Q? z+?!)H3j8x`qla|~T4*R~Xtf9;k%V%9!cCfdvu8+cJ(2uDdc!vJjU_#?js42{P5UO1 z63wFJdhtj{KvZucuW4XWrx+MvatsO&!%l<(?AnVE-$hnYDIzC31`?L9~)!fOXXUznix{_ zP}Ps5M*s3@JY)y?>(eea!?PDp%i=Ti&O`KU}b zNIkQ3j;@(GHht{cTMHNFFD#w9(ZATAZS0Wgjx5=cr8{zT?M&x%=e6;LEAv-C-MiSE z-FZl+4`spU zj;fw*OPQd%OSxqVD#9w&ka}}raDFgDd~_l6@{OLwp5;r|`&ZgN@3_@*+yAR$Im$6( zpSI6lTBUZR+82(`AJ6PwZFnw6)y(#SZr1&fHX4lUc4T!(JzO*FQ2>6WR=OCB02;!T zj7T8(Y6(W4&@p5}4?J0R*s6dfM>DnY>^uy7aG|veXLV)+CF#bxmH0`=3QjF;+=*qnM zZ|&N5dbX`s6!wZS!#lP7Pu}T&<{hniXH5`g`}Tj98{57C8;((=0J-f0p})iT7DA93 z++mQe;midX9cTR93z(z@mKS^hKhOGF%9&Ob*Fz5h=E}uwYiNt)fS2O%dWs2dBn(_q ziqMX^7v2rIS{vig`%RB-@0kF&7=i|+vl-5=AAC7FJJAFdf@b3T{ma@(3YoaXT#~v?4AR^KJ|I;t={ZwZ$C7U)kf(R z7y|HoZu(pv5U_vAb)#vqX=MVGcB6C}lyue1v(wM!srppY5_==M7`=TOl(bQL`C%|D1ZN43FBXA2Ihs&5k zX%vV5l2P>Gkf4|Y@fa_#?~ZZBqrw?B#>Y7|0o5c>=t9=`g7FBmh}5+TDMKBN#+mU5 zdz@pz9Ha^2e*=$oJwXs(BkNy~?XReD-H3G5Pd)E>)^!cUO9Zfcp8R>^&l=ZFMfOiA{;5#LgRenIs$Up@Tv8t zRVA~*$lV|>2YHoSN9aY}DcwJC*Qa|>=r-__jiZRLeM5x_!ujC59yyv;?K|J2)$slW DzQLO+ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_605163.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_605163.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e85c49951f1790a6976b74fb24e837aaf92eb851 GIT binary patch literal 4520 zcmdT{TWk~A89sBf$DZ*$_r#C@UDl-`%Pp|GfrSmQOG3g*U9ENlm1R5=$Btu%nMnd@ z+zu5|<5krp+J_kRfgr6!D3Rcil~!t>+n0E~YTZ*RSy~>Nm&$BtANth)j4uqqDBIgZ zPvSHGf6jmY%bD|k=bzuXTuub-w~=2>-0&du4Pk1{RV6mv6hg~LLK2lkW0fa0MwvW4 z#f-5OlBtBLr?#~6lxUevh%b&2PGVj$&Ndz&w@KVh_I_2-nB9TMpCm{&86~_8^xI!C z&c9-umL1@eLt+#!R$2o1txcltSLZX1Z6xWsq*Hh3_Ujx%8NsrK z-KrZQozjKtFl>>sy(}eBy8V9DHzHlPBkE3_Cw8`7&s*1LT)Ne-?_g2qCc%Ejt-Ex_ zq&!-EN5Qx14xL6n`Us@Fx?A_QgP$r*p=(GND_q8>`*enQtGh@P8NbErMtg&FKatZt zYfR1iOh6ZOuQ|$EG&NZ?HAn!R)|<>V0aF)w_s}Tx^N(5ZFT{>2WOFx26d?+;QxSDC zWW%(U#Lj6Y9aADx;kbm^WI7^YTSSSZVlw8c4^>Mma;O7LJLo83bw~3wL+w4 zDkHqADRD^-&tYekf*I^g%}j;0^rW0p72=dac0qv+LYSEbiyU-e2@9$gQM8I6W|9Oyp;GuZolRYhFc1)h*(RR0d0fu1!X!L`tv?afgIXNNKPSadIpWJhv61+mLG zQ}j3I_@({E@Dr!Fus6GRac^!o@87$&>p|#V=;5h{R}0@infIS8I8PaaMVEi^Qr^|` z`MY<|-8r`!dDvChb0Y6LVVo`bgP+rP9d{h7J*yuVy8H9~{zn~o|DZ8kbOjbC@~$1m z*(b32*e}BKRMFFsJMig`|0w#3ftJN9Oa1eF(c6|gmiO*n8P9uvFwZ@9dlzEa*kWSk zNWSIu)!DDz{qs!G*Oj}nv~y)B@7@2fCGR~7a=dea&+?0{xg#*Kd-d$^`9n`V$Q|5- z$Puu z=XjLiERv0If!Em@vTP%1*SVVHI1M*J?NUa%qY^tHISo4im-p5tF*SVD>22nDozYpH zTVwB6(Un3v{X-PuMmNaTDkN?bt3o!#Pv`*Ynn?E|euQ|;rlj~3=G2MEG@unAjhw`` zO1ffCO1=ufl%A6nfQ0Zt%tod7teKD4RspDxQ`tlGk|U{n?_~>^3S+wzo`M`tnjT;_ z1`wx`ZGH|eN)NEuQO|tkAn-(+`d8=-R2GrwUwAwFcCP0up>v(>To*cvf_vdu_SoW; zPmiy&ZS%()@o#3|%mr?@-)dhuaI0%&_O5ExnbCe0IDT2o}Aq zMPGZdx#w>ICub|SAkkwCDZ3%$DE#WJUWiP<+k?uE${>l(P!f}&B{stV&sr|F2FYa} z7kZr1>M9fD3fa;mZq0VTnpBVy)yo>8gcW3}Yjh|{EZQ2Rm#M94C1bN_Ym{L0zogA> zNh?mbnbFy6o>M4}ApPx)Qm6(GIxpD=AQFpL|E*nnp70Gk+tw?dcqKXDojU&~@AN1V>}TfYA+dasdKoq)3RyoTOa?#I#BB zY&oRuLfiz6O-J{A^X5>^<4kzVQ zOq;;sY#JoZY7uHbSa~@trzBDbI^wBms6_#wE3i&582pjdnWUzw##xvYybJrX`$~Lk zF1RK?n7TLho3k(}*o;A_5jppKJlC?)wR-*0p|9C@;0p2hjQ2{x?%SW-`s9}vjrZn9 z0Hwd17h7~~RRX%3%&?k_p(ozvZG+lgU2~f;0`=v;7rVGM_LuzSkr*&W&6_5~nB~Go zH@vjbkV`oMPh}r;*i)$j_9x*(W(KNz%-1M|90bfH;u>~D(M^$Y0WE~Th%koYI$`ZMY( z+YrtFtm8&UneLzt!*1{Fxt+X~EV~HjM%}w^>$mi>hj3oB=e4^7cLu)o0k_E@TCA)& zNueD(i#xiDyIy-L?4Uf0@^W%1Sw=t>18vI#O9N#F=u(yDfCj5=on;~$j@l|lyAlyk&L$GPHAt13?n_bof)&(8cd2A)d1&6khDuc*+MYZ9!DLOd z{d?$@=FYw6+;iqT_n!ORJHHb}2ZHj<_eDgX07D^}idcS1riHge%S^=evW)a3TXp#$BQZCb`J%x9Z%2|-Vr3NZR`J2TLe9TJ zE>xEf3X=UMJ76zI=eUUm95j+0WgN~4%I{hw5c)9x;i)r+PX_w1tI+5QJ-)cPQyOQqDl*9<9L`n%COHDJxzi-;WAtXKLV13 zSn)0Ho0bM;xQK-zO4I^9PcTABxGSuLf~>MO+!8f%8LZp}OTID*c!{j4eRaH!65fQz z@EC%H)oak;1(NB+|Hdq#z(o9ll0{PiKl`O8G#!M`zjbK*Mhh=rB zv?6{=B~kl1OzTnX7*&-aH5dzoCCmlYV0=i%OjHRmZz8MB(MsItEf^U9bW7}ZtG zhb7HwIL77h&~QO$*K{>3$pJ;f-fbj5WJ0WFE^t={1#Z+6eEki9kwwQ-kJhkbI$;4;s zor&8M*_!Opn2!@3hN2{?uDL!%y=We)fC0rk(ejm$pCl_7udzw2<-7 zUcKF(?O$xpwSAatKD^Zb=pP3ROI^RMS$g}ynTO$AH;6Ch#2=cyPXtGbPx5Il(|`AB_RZPx zTuaA&|6R-GEf9s8%v3b6GUHWUmRG{|AaTFNL3CZ{F1O4wArqcCACVV3fm=l z!!}5Rmufqjs&<%FFbOT$5{&LC`>%3r8MZRtvs3&a1qYsx;B;Tf#^5X_q=AztgO1df zvO#@bks7r6 zB(Jst%uwa}0dpe~+xkJMeok$68YU?N&Nngxr4wV2g749N7SjMHa)k|znUEB|W^tsf z2URjl8ez>=Q*DJ87P_Os*cB;wXmV?LSGRXZl~6FM9fC$d*+$G7q56x7K5MwM_4d|T z*jd^BInQ$Q!Cce3>g%Ah*m#x1dPR9S*{5u)w*Y zVax&@=@aTs;$#=8+M&XpxO^Q*R~eVpfHD~9v{(jn1p@X9>KkO_o22?KsTxQ{GQL2w z%nU$R`vU!BCCxbGmqwZhxwx?ZeMC7Ig;6%8OrR=4&IB-MVAiz=DlCP_t<@I%9eL~&7)6z&42Lu zb3T7|$Fi@}JoTmXU9;yYQBpvC!pNMH~W^#IRcKc5gD}p;E zCdG7cS!l|1BCUwqO^yd6KzM5^ROXYWEG;rxo6+ zFp5OH6*&oZdUXSuQI3Wy{|Ljd-5_l5ri`z?$h zHC9cKu9CA6D0oZ?s(#XR>awN>LP1Rq?5*r{TI3mi4mn>b-nWW2NV!b&clfd_%K&Chw~^PvvXs%@faEsHR~?oEGy! zT{<%R{-f8setG0qCm)=AboS%tEK|do$6*K%@ulRYyil7yy-Z~2I`hO+Z~dl0ZSJnV$vpj3to`?5cWsRQ8+Zle@|k_s z+2W`A{4Oi+ibr$da?u`z|EMwet-wO5r4nApwn$jVb~yGmUA{i57EkUAa$HkXD1csETG=~b8h*n}RbjK8FEGi#T2cSz10_|r|t=TAw`U<(eMB*ROwlxmX!mTYg zwye>OR4=vWKyGj9a`N(;)9SjAuYP9t^zN@k(p|S9+EpBQ0KKu}nXR7kr1#HsPj{~& zXlCp0ZJXO>HCKF1mY9rK*|q0h_gwcnX|H!9O8i=gP?Yc4WgGIeEW2O3;j~cy1FSsh Ay#N3J literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_635331.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_635331.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84071bd69b7ade512dfae4ec9750b50a6f13ab6c GIT binary patch literal 5395 zcmdT|T}&HS7QSPT?Sb(RU<{6dq>0+5-Zo)@B&39dO`8S^f&L_IwvfbYJcF_E51AQA zA!BJ;Wwl1C1bxD&k?{6qO@lT>D{V=;QWI&Vc6T4vvqX(Go0Tjz56c?~dD^EvckD5a z>x8Pi`>z@)SKM$uTdg`K6-394I<4`5XV~5gaRL=3S@!Mdd8{+S`#va zSS`}maL&+9pgH0^(6Jh^rYNRzl|(?V{G8yZVs_CkR*!)sR$J7wd^R@R1MCYqiM0%| ztkO9o8T~HxhX{_gp|s0F6!sJVNZZj$WYYFfHbNY0uiCDRWWehh@VxA-rU>3 z3sH`B?^G;8bPU#LfZ;vk0;iaw6M{-qUSb3O(Hw2z1unp_9@Ugr=sXLi!^O;F8b-0^ z&2o+=OL&IErO)9n#z7e%&Ik|1Fdkq$!?{-kf*Bd%J;NFi0fJ{QdO|EF1TQCeA_}G2 zKx7JqX9&lrnBajk(ySQ8s9M2WwP6%m<#UlKv?|Oq6dN990k)XDXGk*d87O`y1o_NU z+rQ2=Uy5>L{J7W0HjjFJV_so28rkRdvk{@0=X}iq8y;^4t~obQAmaDH>FvE&H88jbKxYL3d_$Qk?W4ESv&7qyJTzEA9Y_- zx2ZMj>tJSgCaM$N$+LIu`{JgI-F5Zt#rGE8OI}RvU2AHU?X7VtV{b_KW&3k+>RwIl z{N9iECibQF%C44`p4FIq;FRn-l|DU`_Az&A*tqF_ePiOSB%S*CxAlkO-5E#YRc>)= zVJcOf>RH=$P<9-QpUgNK5@WK%4U}v5r@ND9lHlzpsSC2}$m+F*!`!cRQpSC2KWV`HoF*UMs=8KkU-+yg#iZKhJ#8yv@}A?e^Z2ck-yHbcTW^2Ue*2B|yI%Q(S9W^iJ(>Dl zpAyLfNt0avLh8`UvtJM^dzZT8`i^)Hp5>mIp1HG$_hrZ4lyjB()_nY-9XZ<{(5QC* zCVU^HRH*AGSD3kK;}t!wFC)ha6I{fX-wd!ZvcV7Yo4*GXm`5TCVrfY`hhbkIBSqqg zT0_hTT8tzmO(=k65HG!wR-`Snp^fM!5wGh-9}*DP(|wmYBo+BCvxOA-)=A9*#s;+o z%K@ogBt_$hzI=utEpsqhCN#i}=@?ziB-phP7`@=o=wO{jZM{XKwvlUFT#OQn`#863 zH45T0Y1SjUScT9~&{U|KAQE-39iU*7Cf|B9q&A+Y`%!wsHvJ3vz)U@T{POxO{|3=0 znndba{zykaR;Q8MG_a&o42&=}4jT`{P6mAJ`p+Q0bqOBo?#AKbTE`XSwkZaF)H}{{ zhk#QIus!)k71IdpV?xWJJki$5;dP3GdrJK3@?ObvO(&dqce2P?9rK{SKnDYzi@u} zwd?&${prRIneIrF9cj8FL)XrB&2(LzSiHP&8Pt7Ced(QtW%_WMJe;NvXHAvT<$G59 z?AXj$q9x%=9!a~JS1x@XyAk{5>|ZbZ>B1fBxwt-St&+O$*=pw6t{U$;_sh=xsmp85 zBWt!!sV8f;d^9Wlyw<3R%e^wjYegO;4QKSH!tpkC7!g>oJC=Kpqkgns(1*n{0yxc)d(gMp1 z9-o(IJuSsdtBT!$M*wrh;ykRf4?xQm?xv~M}v9|fC~@&@i6RRBMiR$R0Ja9 zaE-w?0Auc;l|y(F$MO>)f#>n1sQ~0T9Ji{u7KY@`gGmqk_(c`WHz-t7C-r7)>!to| z-45xEtlc4<%2qq1UN~`A*Uei#wq&W=xnQy@z30FmUi+%=W?%ZvcOMzZYNPZzm;hW4 z%?xD$5&M@N*PE7_RwqGeHA*KzNmtFjIP+qbs!ud6v)98*;aev`NgJhCAJrm*H6BRR zB^y)It8I74t_?G?RZHF3rf2Z+`;#H5JKhhd{llUt@m1;seAlXJ)k?khs~p=VwY@us zQ|iBOsrhlan@e;5Uc3}C+ob_@M{yGaZnIi#+$*@SY`L?*{t(=IC*b;}P#VSIKVcMo zFd!%uzG#FO*!Rb|{NdmX8{wmzT8?T7D0Hr7JpO2iSwiYaxq_iO!%=1;#2({Va0f+# z_}{`~Lr)OIcgXyAWcdeb+%O^?^()s0t_@uS@d^R#o@am4`0K_EQ=UEA_0_4Hr}8X? zcI{paErd2IR4s+tU;aFFBlJ*ZH}yzo&bfKe=^$(g$716`;|2n9#qhcLhFK+V9sIiU qcBe|FhkkK)$S)81H?hrTCn78lsUShvADq)8Thp3#=LfW!&%XgXK9*Pj literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_64602.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_64602.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368ee229839e47e3f8285c74edbefa2ff7fa3251 GIT binary patch literal 4755 zcmdT{U2Gf25#BrAFUjMNsDE3w<+!pLTaqQqjvJ^>9LxWOQqnQW$1=#3NA4Le9AIVv7GX9&M*9lJ7UvhQ5Gaqenh&~z3D1J`d1 zoWB-wR5{AP*ap8%hQu4F!8R6K6zH8zBvwlDq(gV;{49;CHK?8QVH*jZUnyCd1#!}; zJ9Ty(I7y*uE$qdPrW3joYv_VVuG&_xkI<5?Dy`b<*~{xC-MU-nthGE~1(M?Xq=-hz z7OitCKtRbQg`I$ilAKCg}q4`OZ ziU-wkla{qHNeP$L{e;58*iW0J7BlS=iX2vg@jyg0nV=F(ge8-T$wASiRZUhTe}l=& zlUiUxQ%p7@s@A|hEk(j(MJBImN<@?bvT9a_@yc-eC3iY*5|G#f0V|aioE8es1JgDE zYjcV`tp-%-eUryd2QX>!nmjHgz%|umLUK$Eh$Sf|OuIF4Dj37APh-balelX#vNcp2 z!1jS!?Ju{xr)6baod|}c?y+EKJgAMyiOyhHN@(4x66)5Z_(V5Q-BSFHB#MzlH~=wf zaw0G;DG4doH8B$aahVyCA^cN)r;hi6!1TOSD$Z68f+Kjx)ow^T18vnIhcMrn>dX*p z_EzIu-d&geb}IRi@^I$COxBydxVFD9=k7BuN=Y4G{@|gN%@G+a~ zIGyvJ&U4-$wx@5U59henOfbjoe_O;o~f&<(<4c(84-708`(4i*cZ75f%Q34hv2Q4H6uw{xEVNxM6 zGG&3U1WAgwO2wFbz);%ZV{-kmVEm34JTtqef)ri-F*y{Bsb^qN6k?y$Cg}cbppP0K z?s>3h>1LM74lTRaT2AMh-&9@$**Xr#Prigwzr(bLr(cZYEY z`Gu55+`YtdDh$P-7&K1Y3jDrZ44$uqtb7B+*r56abPy##jl(tHmFg;jBIlUrQ~ZYx z<80nlKOalQGQBz1KI3A6^*&`=@(p_yM;1nw+7|+I_5y1+zVo@uH$R>l&-7%%nPb_} zN3rEtzWwko+J4ryWh30j=WJk~t$EHd_inlaKJ)gJJ>B!6kX9FG7iO2*f0WE~u6ZFPq=Rc*bEe~A=Y!5H zvDW%}9u{~b?`g_6wEsz9X=cNNoNnWabygvWveNL)$9uLp20grei6j!qBq7?8B-~XS z&{S3N$Jw)^>3|?QIZ9|1eoZ{0SDBR(4kqa;i?yCDaI>qd)j%W>TPvX+;7R2w)kG*j zu7ANAtCjx^AnR`9=-9I71d5;}r?b%tf>nY>Wkl9Fk?jX|l~@0*UHgZg?dlbWy%Kpi zQ~UqqO#d_Q9F}*^BvG|*|7W?e>j?-rMx6twb`3`U0QW6|AT`zv;I3nc0=%9SgUV4n zbZC;Q1wuhp3iMPHy;e<$Y(dUfvrnv}9h^Qtr(ozdS$Ijq!%;J7=uY&SG^P_KI|@&0 zi=LGfS+zjWq(d<|A^CaymI(#nZ4;T5Ore650WlIc-S9G=loBB+5R(#NZR{lsn|vfO z0S|O|A{7A|!!oQdsgp5HRe7K)C|m?#mlb8jzVa=Yasz(qLx=`|m_v1q#+5>SlX1P! z*kar$_!^9B1#g3KWy^`Yjf=v9P~hs*(WSSZ9_;(|`QKdq)zzopd~b`Uyo_-frT`M( zPTekWb?NJ?4UgKE+n!8;(8U-RK*j?2{AEsJAwp$n6ElZ#(dC=)urWZyOiqt^+dvb6)JQwbKaLuD8OfH?BPQH0&DG z?&cbrjqA^ax_>Wr=g!#w&R-5Wea4_wOZ+w~9|s5hf|YHh5?*|Nbj9F5+$8*&F}dY_zh$bQIGk5SC-&GO`vezh2_9I2n`9 zD7V2Tt^w-Lq1&(#1o0(ueujiE&|WCDB=y?qXPUyI4 zl)O$6NFd_6oZ8aL$eRQzUXx-V@#Hr1RP*^knm6C09#th=XLwpbahE}0737(1=2^kQ zTi~;9L+5y6(8k;DQP-_?FYWixb6GqhtgJ zsMsG>bxr{}?NVF{Jpwanw?P-(bS0(O9$_z_okp|7kHE@wi_&6P*Ic9RIgK90qfoj= zui{WhY_sCgZSff-u@ybWUO|cv^C{j%vL-R@H>_zAfQnzSj10gUzhcEYr@&@BSMww7 zZ^n^sRs0YJ=w~J#4Vr>jr&+htg=g`6WmdO3V6dQ>)r1nLSluAe-5?8F*7`gTw$|Wl z5cZ3y?AP;r3X5lTEzK-!Gg#Qm9j>I85$ zH8>~sy^$dyDesrW=zduk9orAseqr>c!1J->P#CsMP9?)5f|wK%`^N5sfn1x6reHUZ zynguLAs}e3twj}XF$Ct|L?F$;v1ZU_2XcEJ*zenOkyTe$=G+rk%NBOFC3$LpFJxfQv zbR3%|H)-T_&kSdWb7W3Y1G|^VrMPzSQW-Q7jE zujsEDcR!V<^YW)x@{K6sj2=(*EYL&j3cSCmS}qYsCeT9Z<^#oo;2tt;>J2&^=JUROPkX) z__H0ANwXlm2~}YfR-p$CI%Ir;dW+-3YiJ5xukjY1N?Unzno}6V8gIh_KouG)T=fJ7 zO=+9lTH{e{dNo4|p-^BwMj9Gl*S8Iwso4Xu^3Cg;)3?@BFIeB4zO~+@kcw5|6w4y} zsJhbxu11!Ajf#%*k#)QU>p14uyKsjfkLh(vWB?734XHn5*Jxgth(!gBiAiBBL8Ih5 zV^Ev;$^aD#%7QtTl#d+Js3a~8W@$JA9CRhoYR;r^TMmzA$)8?Po)}N zDNG^z_Wvz$O(RZg#7T{~s8^#c3pHv8s-I*B!=I}VBK&$#wCc%QJO%gx_@v8v)-GF- z&9}z2t6Y0-Vm>{W{*pVkME=J5*t)_UD^SN&y)SUbikx?i>saAB7A$LB2h^?uYh8!c zuEWdzm9D-5)mPy9fMcdF+xLMmKQcF>a-jkhDsZ79=bm{Z`$n#Ng==3Rf6ac#=1+WN zFHn06++IvOo;`l|bmm0nR&{v|am*ae9?iL-7~INPR=Dm0)m`AapL&98o{;JZ<-1os zdomY`9$)5c(c{m2uVk~oKQZ%O_PtzBF0ty|UEua)PCuo*vsYEBt;o8ItfRg;u~cjz^fxRC200~hSpycPD!{mE+nHuI3Xm5zRdYe8*=;c zCBDRZXX6Wfg`J0f_tx(({PsfO``0$jl$*|+0~Uv8=0^5LiEWv^wCMY!^I_-m1P~o` z<_r)y*Uawh?h@OY>s%B*8GShV*%=^mbmr8i7nvQ?v7CRQBR{o#@Jp(%Y=f(M=4`36 zYyRE2cYk#wb9VX?B=3LDiUwJYLO`}xcgveO|J3Dsp;Ipw=L=*mJ+-&|d%oKmeg6l0 zS!DBM`t^zeHv$}F;xXtnXXPfbF99#*ad_KlYz-3e*=uAxCTo^xDk;gr?J==>wYVxI zrIZM7&ZMjdp&kSpSBaYNP%6PcM7pj@$gsypQ~Y>BcvHlI0%3xrMd->TK@eXd+aHnr zPpG3rFh6g*+g3KcOdKt@Ab5NYWCzL~jCt{!E0IlloxODKabd47vUpazk-B%xb)n{wm)Lt zMpKp3v94`4;~Sf=TH)hAw;xGgJr_d2ya{OBIbqR6DRA zI=r$%C6;1&R!S-B9oYfiZ5`p*syhQ&5OL;tR4i(hutSd8Nt1bzRtZ+YR$yJD%>0y3 zTBRpa{eygR<@scRPz#zs2-UJrRf6lLMp?4(Bsl zdT7a#Rm&KfD&T1rr#_3n7z=5DSS>sx&4iD34`*H>aAstLa}O&-2nde3=ngWN;5{tw z4oQ@(1D;6~ULowGq=N^JNYSDtt*ixaS%;R4GN18GGRnd%LowsS43sT8YZ-#fS_X=% z0VnIexBc_o=1XBVz)g6(O!Jt>8}RUB;ZT!jlnL?89P4f7nemBc;F>cN`9h;^SY19G zat9bT#02+FOuM055qiV0@a-*a?XCNuM=IGIl~HH+f=jp}aSM( zU4K&jnfiiSv%CgsHb=ZP-jg_c-_{hY6CDs8KZkBYWk zG3r5?eWBr6L%b>3AUa!CdRHUj{*$8fWa^FcDKCAmjEU(UR#wM{62|1)UsoQ8^`y(I zud_?ji_^){Wbaz-%VPPeJ@c2AyJv5Q zQ?I`#9s$vNvHrAk=f^DxI?*mV_asMF&V1goa&|c^IuFPCA5`vGxODAOd?wj1*0!$( zQbT9%RlXVPU3b(bUi_pbNhjOIooy=tac8&a=)Qa6%l%&so&K`@i&Lq$JmN8r=XVtGT-u}Xby>i*7#9PQs4 zk$vxD_&!L|%ljvjn3-(jBsI=2Ez9teY{-lEfzqr%;2rWf0nY^_Ab%F}6b1ZNAebfs zFTxfo11&tJpwXKkh(a~nSW*Yt!IF{B`9hNleEpmd2tVGY1?^o7KAMOYEI+GP7olKm z5oqLd2dQF5r@Y=GWpqR}g*a6o0>98Uwu>B)@@&IVu(0LXCZo-pu;{lJYGNQ0C_y`- zE*>FBE1n*+S@t7jo^{95nu2bG5U4zFi^Kx)7oQ^wh!H)ljTrbU#X7CyT?*~THLait zTIHK7U>-3F`D5w|1^^3dU{JJ2R1p*ZoT4L`WVqK0Bmke@ze|yCzBe;eAb(VU?m+mA z@oFU=_!{3ofARiJ`x-$j7~lw&J-iY~Am0E09%ONHXAm31_ROx_X0PMM!jnO|ZkVZS zX*yOHX6stv#WNu;%r2`W4b4wa05H?cl+VjNdIbhQio&DOCaG!PxTK#9aTg~U1`gDQ zw>-f~<~YlSSxLi#xIIv&R0PgV9>_IFd zM<#fhN!kqfNY)VZKA?U062rQ~Bkop-^wPd58C@i82FV-E>^>OHwqO&csW2S# zWe|qr1|0i`-Assv;>NaO9bATFaF;a09v+-i19G!sPh$ z?yvN&2Np+a$17`=?iAI%UbAayd~sYPtD=4L#&oH3q4rwsmA>flSXfBCCDQKkJY0iaQc* zH`$tcUM^E3gyby~)dKjzeqK&S>wtsr1(e z77krI6z38>YqpnC#-`}8hopV}j7U1uly%-VKMJoo%dBPo;u^Ih-nP`S*pcvjbU69k zt=-GJSB7pht=jH5Z#(aH|G~9Ivwi8v;*n(iTGdNws%*X=8*@A{YBkyo3o==weQYx< zzZZVFmb9j7YT(JdLonVF1x-*xJy4Ahc~_?*$9GzDQ+p$e zMG;M&#Ue-p*R_0WMbJjJrqo1jp0ykbZ8|Ba{w-@dK3fr>65<+xqthv#6Uc`E0d;=7 zAq$#<;Q~eLI-xJmtN+%jeXnQRdPQNcXgz#W!~f)){%779W$#Q8B5&RP&vIkiFJQ(o z3KZasWzhEz*l#`rslj%D^byWnsMV9ShdqEvBhPTW+w0*NcS}CgDr1xH2~^O0aoZzQ zAlctYv0d1b4sMDxTpqJ1g1Hdeje9gR3>)U;vMg!5!ElIa)XNoLGUtl68)KIhivA=M z@-q0&Hp-8&eK4Hu$2Ro7&;(pH@hwf8ISk9d0kbmPWRT}Le6h{JLIx5na%{ONSP!TS z!H-*l7@}x{LS+@vzIA(LbYQ(=NA%RXtvq^iy|g^q2e+cqiUspE^Ezdp_b0kiJNN(b z_?`aS{i)aAdZHmqwb9o=1uFRY+4JjA0S9iD-*PRxR;QrTqK%$_PGia3p4mO?RAt-+ zb$fhy{O$?pG-{)-KCvT>CFYA)B&w4$t8Mqlt_>5imO?Ri)!{4BN9UtGu>q*jKg^2) zTZKlzwp|X(9_@QrQoe0a+nXzQLtFb8RZxK(&;s0o7j2ATeb%zsDK8(O5IK5$-fZm6n=qXe)! zpZ{(3Z>l$RS@v+(os+juW?2f=)-MGYgByCemqP77`7HQp@H?4(tVSwR)=Ve5KzNAs z&KJH_*AUit`BL>_^#%fRMe~{IQZdBfV$etDj{3TcqXwO#}j5I z2}EOQsHhfiKWtDXf?5e9wc0?Cc2}xYR;sjB)xZA44eCZyDn(sBntqv0`qQtTJN7t} ztar10Nxh2iJ@0ewW6rrZzvp=tK`Tr=j@Jbd`V=E}qbtDbUjVp=L?jX^G*ljmA;Q94 zs^VF?a!X!*BI(F1#ulVe^uJ zZ*C$nU*TtjEhJsGr0$!c5X$(skat(*c}u=dcWRKEBY{;Xbl-d>#u;E|0(zapM{Ny$ zgGRj`&{5Jp7Y-MqYdc_LU_H~I*Xx`G1$Cqgx*JD_{edoN!v(xkNN#MF|u_Z_+6{7QG(xu|%1#(>5p547$YF zFvQ-Zvv^K3u$!4}5ZeZedT^r!ETwO=WVdfI&t`qQNUhH^vqN{mntAL9^uQ)hOqH`a zYOl@7v^ab^Pi|$s^cH=`SDC>R+FpGE-9TZ_IRzh{6?~poFcfx~=Kj*1{8n^YpL^>hPWlNBl4 zYPu7mYSPnEVl=LqZh2BOeX6D;L=P350)I-7)x#7i6(K$ zB>EM6>JK4Z4+ZMEM{N#frGto0LQFb*D8xEoe+=*e(En^ z!ws~y8}a_!V0LhJ*P^f0IPu&kRI%CVqVKhuq!NvUhM#=*-sL-&XRj677aLzM@~;~w zm-zb66PNjheC%H0PNG2ENfr6sSgau*%?=qS!HmB?e`0o^7~E47_RJmpV4&D>v?v@k zPCaLRx%O;(zP&)tMW3!&r zyYd(F?b#!(6%< z*;lZ`o#60pAm48}94NATu){D+kqcDgu*e=akRc9!O+b|Aozpi@-~N7HE4=nNFxNlV z{IKhTlTUkp8+h9H;rZh3<9{UkS7Eo1Qjfa^=uBFXv`Hl$!{4gia51%t*YALO4?!}w zIgU}QY9wnE1qZiAA&9%Ckm@-D#l|kJKO}S_StG@ZuCYCEys)g>A&Y8VB!(fPj4R^? zOEs<;7MILWTBTS(xu|jJfW+$7sS#|$F;^P|YK>IsNX)cIWjtEImQUh!-qdueT67rQ z<&cGo)#e7VcpMsbU}ZxH*B%XMFa#cb%y%iwn*jgNUbPU$urH^ z)cB;Tg|12=ham;$&6(D+rH~qzCsQI8>E7FK+kNCRDN(xtMUWPysYFa#ISM*ffmQHV z6?U!HpfIl-0Fp^Tag4&O_yF_Rcja4AySaEH(BdK@M>3ME*@AK-eSQaS)N1J$LD zjL6D#R1uXUKtBdQwb!a2r6BV8b8lwf{8{6JmisL~m);wDIJU5@yU2Afkb?`ogNx8} z-H={$hq8xecm08DSs)KBv>jT6o@-g=1gw0s-+0TohJ_Ai50{bWtUuS0?Z{u7-8Fk5 zpD0Ydr$5x6p8R<5qrqqF8A$k0I|5KAgy1K_t3~0}!i7bl-8cngFZ-Mka$J_nzm=a} zWOi=EN*>Zj8~tlOL^b9w7Ren;47bDxR?p?wESo>@V;-<|P{{FFJ|A6Vwt<`jc8Jsl zy-SRDsjeCR{@&)N=u!xl$~dF!h1H9}uiEgj!YOVw4n4fbhzt^48A5btTq2oq15P?_ ze$^yi)klzYNZmC;IBw2053O)b;|56;-}kIm>)wx(kHcE6G>@%SykuyHwHk>W zQuM!K&7-Bik}621g}& zVNcor$)5ga(Q#IEz6{}5w;ET0|L3!@}V0O|@pxsXaH z#i+6$gPbO*S|k=#rN{v%(rxAQ#2S1-PIAu*+C=FMGfFSEhNYH`yy zC4;0ERW$r21I?=uDJ^2yn@CTKg1-dlPOJA@jpJQiNBop1#9M2f665g z!B2hlK{@F&sJ_WKvm9(T2A7+*8|RlB8jXSFK%;R6@@k;z9)E{lW`g--;mE@7-d~^i z`0Ph#7rt|8jUofI@of+h8gf^%SC*Oj{NR(u_uC$|J)HukKpQ84$<^V+zsxkxwmp&F zAAdCdo0GugXyf==5K+SI#B5Wcb#CTq&olA}zRE)tSZ)hFxP1TePp=rKZVy6E|7uZe zvbEU=*bZ9L3L0l#)HQB7sjba5ZZig7@b&-R+@WA?ocrhEWstAIIA_(Du+KVp%GdGf z^Ovs;`%>^opM;mC$=Hx$9eH;$p_$&8oK`jI`h-%sM_iQBs;pS`-r9eYD<7E1sGJfX zA>S z>T<)`;Zlo*wpP&3+AzB6MlN4PP$N1B|B_e8wPssOZh)74f!siLphN)-R%z~?;hV#& d9>7-P1S0s)m?S|oyc~8Tf7_z)%1zD+^`HJcbou}Q literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_759146.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_759146.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63877afd3e3a32b19d0711b807fb20f26f86f06d GIT binary patch literal 4755 zcmdT{T}&I<6~1H7ukEo7<~M|0NViFy1PFv|k`>8z7XtYu$tKzqf@C?K0c>m!nHfT$ zV`(C#3f_IdsH$+=3y2z`Yog~sOxn=FCQJQ9#VgwbH`O$-v| z+!C=4k^~Zokg2C2ntKZtkqkLrr{S4kEl(dN1@bPrlouFeIK&}=7bp>h+-0=TZ=f?G zC$R8vZ{XX?(}x{`{Vp|VEBfiUiw5l$B-)D>Ii?7&a~+$~Z!jO8zu0}w*Kas-gPsc) zdrqAX*h>uMU~GkF^N(=@PS z4$Yx4V?c@VCEudIEvXTj150SUKrCC9u#M1S&Jrm$M=R|$r$CGifs|9D@K>h6UQ8=1 zU*29@!QPllb7`z;t6Q^x76_}KSOhd|sfSN$6!uon=`mp_iwZ#8Q4QY;TGS;$|kgr+dDY2x?HPAd-<#(}VnJnxifu@j`>X-51dMySS z@hf8nC8?vLJW@CjFCk;Ey|iIb!-j2KmPTZM#1|9{+AsT~BcefuCBI-$lOQD;l%h(q z=&dpsX+rgltFpla1;rfLCdJ^$XpYD!svHzVpQISYVWc>m+@elK3<54ekIxKy8HbRJ z&)l$%g9)rGO)5S`ylHUQb|0P^oGOipQLs=k=ztVfd_ta!Ck8!+Y~v_rizjAHkh@EksG2Q1fwH9a7&e< zzA;gbis9DrDIc&4#DD}IA3busz4Kk*815~p+;X`Y3}N6=_Cv_%=v5`M^Rq4SmL#!e z+ogABT$PDy@z`hbgQ@#dsftwJ+S{FJSEqhD@|yc2kmFsm1Mz{x&E&x~exKf( zan~ek)9xnyOqQ=r9)EZ_MJ`-RH?%zF(hbK~dRIHYs8~Jv5iUsW6#s>=j%6Hm z>AE8;^osiV<(0`NuBYDgu`B7iD{1~Js3hror|+De`8aX;jY|JOp4x;hgMcs(JQrhz zK5e#gS^mGchXQu@&Odpk+~=kfQ;(=f#H8kcpySz~NrkGVD^;Q*!gV{o5kTT9`)UXA#ltTleTuNNF z7Vdrwl3242PK}EOEr_>*0dc(vYz>qi zGY$#t;N~&ten#w+PUD0I2@-EZA>|Ac@KH;@2~_~9bPf>=G9UzRnIJ1dSmLcxFzhzz zaud7^wkzz9+z|XHryGi}(ApK20{*aa5(YWuAuLu4-CuR|SALsiJ-}PX?szR*Sn0JJwty5>gCi5tgkn(9iHH~mq#pIdMo@s)fl-43C{w58gCHum z;0}Mva7D%20NB!`DEp*gUx$fM2Au! zFUwRsW9l>24f8{DLkrDwz8PDVvFSg4;q=Uo#mAEE$&uv2)bOM5VtDo4-=6r*iRaGC zGuEurt@pj~)c?`rO?$kleQTZ${rs2q_w=5ZWL4s7n%tFP?K4*rd*L-}i`x?IpYjQ1 zetK?tq4}q=4C|cb<9x!u#?~eGK4`h$k|Nf2?a#2>%+ZXyHdEdF7oMT$4L5SQ^t0yC z1i-0c!&e(`(dH=hQ1=osBv@jEV2xSeYFc4VmLz-}EZdp}5TRLy3AIQs3pms=y_5&K z7*(RNTrmR|xkOtT$Puu%JgfnH6mL#h2m)yO7ql^I;ZFbxGq}pwJEmYcfr2Q;YD}mI zSVgB%0dh4~V7h=@V%2|Z*Z!$xJ6gqJs{{_t)b>9))BnsmyJ?*xMwIN^|5?W(A#zJOm5eeI=_4l{)WUqQ^5k`J$= zZIn(gE92~8Fi;>vRj3*ibcZ_(3eTekGYl28IiD6~Nii|cpaNkjDtb9wtOWc}yacC3 zgD(Q3PY6Z~7gV|vVl*K7!eVqp9o<4*g9}E-q2PuZD2K-wh2i_6G7(l41^*!yK{toF zPSeYxe)%I<(g%<70K5U*Wl?2~el}ZGt6#|0)aw_so@)JkwxU`;`^teTYUcSlKFd}m zLJRLd+tc~0Q@=j<PL->jVrf+>7?~uVA|cY z2jT~^Y;Ce}S$q^(jI8zo(@yK%uc{E`oCzjt7IvqmR}McX-`{W`R|Oz^W7C6y`vbqY zuJ_Gc0A&7dR&3L?T?y!}GTo}u&%Sh5@95Od>ZPNd5cbmh^?zcCrK+3dtn(wTq&Rm8ivIh=5>RRmD#N*8#0ulT-Y=|H{ SUk_Q4t8vZw)*ZX)$v*(r5&;+h literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_764635.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_764635.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c8a52a3b94c39bfea4d68cdd24c9e7e88455621 GIT binary patch literal 5651 zcmdTITWlLucE*pH@!L+EI4`$J2?SF@nn#^NOUk27mxKhC?KUDUvK-IYu@l?r%s8Q8 zET#)960uSoEm9n|QqtuEhem434?#i#?H8;4u`#07Y*(^GKem5?c7OI`wdamKwrkvG zwU3XzlINb+J&!wQ&bilrGMV%kj8EVEWO}m>!@fs|GIT}Y?mqxn!)T1gXR*n`j8Eb+ zt_rIsH8{rLAz6>Iv@p{uMiVmLPeT~3u1)uAXwAEtYsCeVqzw{ONDO)z-sv2 z{h12w7>_Hw5a88=#;+CmamzlEs*__x7#B>}iZ&O49k&UXU=h@4 zCBg-ppi$s>9bozykdD_^_<>@n8bX~=PvhI_YbZt-aXYW%2(oAt%nIie_QqA2vMX&M>$`+SO(uBR-jHHe7(+nE`k~~~EO9J2`Bo#j^>F3yJfc1twL0Tfb ztTz&1B+YEpOH0}UqUHD~%edMk6Ca&nA|S_k=6F^jqcL6v^b1TdFkOHx9M1-6#v_Yz zo++t{z)8lU1X?|HA9g>FE-fvppRXp+63H1d?e8(<=Q56Aj?cqTm816U7e zWoSAW33y;n`Dny5!>|!%_Nlo`9$+g(Ulg``&^_SpKMEYlT3%HUXP*Fv(EH||holg( zduGgJSs7a%OLk<9ZQ|)$Mso$0H`OI3mOWS4jZ5p7Qnu8Y?7@MIX+Rv#Tb-NQlqJ(X zn6VD-KK9kg%rmcLtgneL93SGzfh4(laMPXY{1i`h zT|Jqx^(_&51XkDlp)GkT>B!hRHebou`gi=hXaCTiWwI$=8$k-;92*rMA)17H_ z?GU@ZZ_O_(snKdjvJGsch)*u=bfy-zEW1x;9v{xwhQDr1zk2qYa}()TC;oUaeLj$R zDUh)RKozB5p_ZwHBk4vK`%~w>qXvFx#_WA}VJ9&qZR9K|q$gwnWVA%E3?E}7zGBrw z!k7(ya-~;6BCcTq7D8WV)lfPJjn6y|DO(3pl{$*jbh)}?kpuf^QqK&6CIb&Ot}tkg z(i?t-3wWqTMZ171yE?7`%@5*g;{>dw6rJUnFIB5reY}EeF0WCb3T}=5?N`s^5J$Zt zO=~OS!1nRQuBL8-Z%`ydO$t<=1+72`8o#>s9s#wbEUQ4< zq*BM@D%g3aLNCuscEmBz@Gw3us2<8!(Nvt|4=S2z9p9nAp+gE(o`q7zD4jP>3EGM^ zon=`Cg7`^A+OLZ1_^uL1C--JdP(c^NLN0|~p3C|D0DkEl1aUrmeQo}J>($D;hJAjx z{@VPqenL&3QGXKfG1IKF*~Y47bm)gEQr z`?^mpb`~p4JI!C3V@eYKxn6gVf3Cl;XUgB}UVNmi<9@kdga;XTx(BvYl2&F)1kKC` zeaxNb!2BH%9xhe6SGpP`H608~x>$s}7-JY{s`_)@*%))0WuvU5<)*!J4BHQ@BrTi` zeA5!?hZ6(e>n;NQeG)aLkaTCk>*7k;7fCBSBWa=0@t4>Eq(|?IM)+VL7L7q64F{n? z7kZv#i7;=%(IUFQu+U#UeUipU2j}IcCJ}|BgiFsp2|Czr^mqzA&^N}m!b8$dd3oP7 zi}J1&)T0@*5!0{E6!!-`v;THf1`8WGB>502eam5Y3f-1z@r=C^>9YhDvl%!xrPI4 zo>kA~aq;xhV4gH1l~;bb@y7Za+o!LOZH;|#^vj{ohth2$8FJ*7*^zD<%9>B6sgpUd z{`Td!uf#UqT7L^fN4G}P2aaXPW3te*S@Teu8p@M;@vU3t`jwgGndH&UuGH~#%V4g( z>toYZ(@x;C$c@N7HEtQk<(_D3xoPdpSUa=U$DoN@9XG9qGS)*`Yo~Z7XRTi|ubQ*g zcJal$*(#p7Wws?A`C$I0qdVj1PR(Q;&t%QV#WQ)M?F0AP@zvwW1!(B?N79B~@ziaN zJ#jXpX~|KxM1PWm=_b{bq2M5prCO7(-*k3ooZTsJ)_EjH8CLYm`b6*hrUbXPxVpI6 z@!R;O@8i(b(2nh+aE`L9n3v6oi&?5UIk+*jK9uUnwmu2uL`(8Kc)aQOn~BsCdsfV7 z5l7g=@UDFDD-Cq$k^Am6JmF4?$1z$J$7yw3MQh?}z%>BW7z{Gp(BPlsysJJefW;@WtsbM?W7;zk2STR%0W?mp}xH{rTndd8#fkw%u^u zx#irQ2d0@2hk S{SOz^n8lejAGmCgUH&J<`#p(zPA`tR0Mv9mjT&0UP5415O$zk}>IO*%@G$We3j; z*haJJxN>DSZ3TX)yjD|zrAmf^WYU)ssZ#7GKjekYS^-U62`PRkZ#Yn?$y3jr{b57b zrcKp{UTN;!d+xd4J@?!@-#PPz&1ObWzWCW6`fM(QzQ&zGG3Fa<-$COp;t)r~QD^Q= zbP{@>9H2UB0`Wvl$0aF3&2BB#sCdl=4ei=N2?bs9N_M=@s^T;ogR zY%wu@&kggP`f-{{b+xmI)g#pZKu zky006=TY#iu_Sh$OkfOLV$k1QIASJ0o-}WvNftFNibXMv7*HuTy%`ytO(>?B{F);$ zoU|%dh3N-L(pIt-Q+6h;urqiUP@2(4zk@k;#a=S2{CaF*J#Z}g1<9cvavLD%t2-34 zLSd^Fhu(K84u#gOah2e*`jX$P6r?!u48=7|74VZ4N`+$3(X2X}dkYLgapy1;+X#5P z1j(~umqKGiUC+vqg?7?gf?ayjc`FrC@#=FtimQm8_Z@+LH9?~rC`fmR_$XHJX$J|7 z9tcbQK}I9xxMm&{gG?Og!^9j(CH0Bx~ z?d_8_N*IzgR+7ai$A`L$t$49DTx@Y7HcG@NC&D(7W*Urgq3*C8>C=q3$wdb=CL+Y8 z5Ez0Dgtt+W!=fCD=Lh1NQRwN(S)gmgH42+!#?rh=^l&&X<@CTf$p=dX%9ii0)Ls+B zerYfq;cNTCk^ZpUCnRdZy?jEhmBdJ`%nuCKLa&w|=;k>tn&=I|Mr9!p>gUA-A3rpB zBLu?*VMG8QymkEeiDNLJIX5Qdw8hK)}6@ZzH7TiAsCbi-Hdk6#3l=ja+Z2#F3{VWQO-k(Y@6TFET&u_O+ukz*OdkN z#+zdc7%v$&6~~D!=9eGen9nFi#V~80$v~*~6WQYroY>Orafg%D@pD+~X z0`cfD9}p6O^Y4DJf_*woY6eceG01BS#}7v%{7L|%D+oT{2qw}(!5_67p0!g@{f4}Y zQbY-5mpPpjsX%zdekuf%9un*#J9J_Uk{Xla88DNJ4hJcXZa;UaNh8i^M6-l%en3^l zUGUb}`gnMtn+u;F*-_-1L-lbX0vvc68o9v}pek*L>K`gfyFO`|xNz^n6ft#d&Yn5g zko7las+z>tVOWn5BEIlJY7<)rt(q|^g|OSigP@?%x??Cz+6JjF3{EmYm;f#}0Bg!O z!~+h7q#Mf1K#D)nP-3t*$tivNm zp_1?iQzZ4q#+)-uMHh&TFkXWdzlG{s6)jmjix&Tv7XPGaF>p8&IJ_8mGZT1o-uGpo zVS(;iXzI#BZE0AtxW-PUPW`ezeFT&jX#ax6k7f0#`kT#alX?|^#u`!$xz01QW9(q+ z;P}<_&h&f0o5K&3`^tRF=k1@hKQ*_iXP>*bExLmlcW`QF*8PUswq$pz=bt&ei;g`R z$DYYZ)^R|6`?;k_ZGKLB#xG^)?Mtj{iM21W7HBN8RT;J_?aHz{(w7&jYcka}Q{im& z;kmtks`<3$5s|Gs^Qh%X`{VY7pS+)KzO>}5oZLU}|GR}TFspWCb*QcSp+>+1#Zt@; z``{WuYE2}8?veyYB}tA>QqZSgeRUNIK5Mdvkc*II$<#2*%-|poC21%@ER*eFswhD# zLz5gv%TqQGRPoa*m-<6C^q-&^Wn0clpe%#J=({r(EGJMDC0WH7!{KOCDTDxp<&5<( zR$^6|5B>*v?Uj~oX%&mD;vl$i=Ksl={&&_{bnC3iQlk7nyBk}cfDP|aVS(7~9nrtR z_ZCHPHC9xRuHb+q?!#3OD$OGEk{pVJB|dbdv{R=ua&#T8e5v@TShH!?u^=m+#4;mf z?vSad zUmfBT5k3^>6TNbuW*rtlQ_srbhXH1Ghxi1Ce;!PFG{woep1SEg4)*ZUP+XQI6IfSd z>>OX)^<5X&6F&fz4tS&q*sqEpqpa|$t;?QDwSC#QO})76_Nr~mF0b0UZbdF1&YhQ8 z&vGEMdm$Mv-!yn`j{6uMNp(Wz8*Qi)j)=Uv0A<6vpK{1~NT;dawAnNx~KYz`V zyD*_%AwMr2L@insk1=;?J{L6&#_V%0GSION( z{i+$kA5B-PYt@qL?cM*N>wd9!hG-_9_#gYf!J}&wBCWYuZG?w#q`edVd;V1f&8eY> m%Cw?49~qulA6wV36mmVn_6-{&2=|LF3OTCbukEHqH}{`Ix%({u literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_804525.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_804525.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d7696c4e832a62d0ff64b1e3e00f5af55b7e06 GIT binary patch literal 5651 zcmdT|TWlN06`kc{xqR%B6h%^$B}WPpCu}4u@-t52yevzyV@YxB*tMdXpt(zm6d&^L zO0twCum%z&LIa3G1Bh-5NI^d$)dsBOr%F*EN`D%pKiV>_g{=b+h(D@7C3aDuK!M&_ zaw*!OqNb0J4!HBUckZ1#cV^C=;qPhMh9EI-|8n@c8=XGVou!4Rge0R@5K8HkD>rzS>^x9U*64%T8 ztZRp!hDuNO4n2*Po^_I2a!i>K%6gkwE)UDYnrGVsTi5sTt*t?Lj zjaB`r60JIy8YM=zD!XgPzE0LK1!uP+(Nra&o#|@bU^!^VGm>^mjWsQ*EmIz4qzfdQ zswJ^LJbt!OYJ@$E-ZHQ#YQQ^QGGh-(K1{{*GL@|Ot+~q5T7Dy>X3f|B2lO`FuW<@? zB5SVl?`~Kri#E1bqHuY?yr$D9S#{2qZJ3jNsLCqG{l3F3$sj$XY1j>`djEWD*mZ-{ zeN3Q96twj7c!O@>v$KI~L4#^ag@qB-oDqjPJ_aN$1qmLbZqT9{L^$Z9d?v<+Q~3Cj zVLqIWajGeq3A3s-I+jd^lJOBvHBW#7CK<-6;4pe2`(NR7FAx8MwKZyqCy zY95WVAg&ccsQ~i@?s5s4R0~_KkeO&y0NtgMRVnPB5Oh>Yt;WL4!1z&-SFK`ZgiFI* z1l1ZzhEt=VWM)FOm8!!qX_$Kmt5Ypp>QZQ&i-1;*Ahdu@Ejy&Y4|Pmr_z__=9N{{K z!;z7&IGjnhhhtn?>=5`!hsdQyJD}A8+Bueur(+@Tu9!)OMmRprCHIa_hJdXTBN=e~ zQ%@c^xbGlvRM(bNsXN~ZMq!T$pTGu@(LDyy&Y8aHzILT|y}uF&mrchTw1 z2NupNfi}h2F88j}o?J|!n}NZ2^X3JI($Jn2CdR#tHWc)uR@(1!3g=txMMq&0YA6j|j zx5Ub0A3nc&LD}E0&;xS!ZN@bdn~vq8`SIDzV(&+%S4_+2mA0d+D7r(ol4ul z+WEmRnAgES;HeRXX$<%ij))DbO zocVFq0GUeioH)j(Blwi)sSYYy+wLrA>d`rSyCdlxQugQNJ6aT6aB()FXN&I$6JmH1#cM zn`EvUvHJE>HVI~D)yLYQCneeRo>61gE_(D?Np>xIAqi#+tKrq@TP0h$scgeH>6U|^ zzgxaGe@ou9gM4lNPV(l8eHMw5tdb3)nh|1E*`oxKOlNSkS~93+R-7E=RFdV!;}LG- z1@M9v!+5+4cpM;uq$I@8=0 zF*KT);CO)Q&_2}^VdLW(HmYPPCf9lQ~QA?D@5`p?4phd1CsB z+~s_0{(LU8c<$!V($MOj-?e|*{zuy@YuvDW><;73J^Wtumak3mwJn}5_&N&AlYpVl zy4=3^#%}=#``Z^|1^=Of^RV1obkOgP&%81HM!qv2D>(PuAaB~2?8`kXKP)`ny=L!` zyY86kb7vJ(<2vQc`SKnpW|(Ovcey|X@`vZ2yY}2-YhhR0I>pS;({wIepc?Xf=G(8e zFA{~uAFNaM8QZii*ZE^QC(KRFPA# zW^%IuIh^uIejhj{0$(NJD6gT{Fy!zyB(m@V%o40IYhX=TBeYFb_bEWJ2a{xkQ)Y+~ zVueu)x8Qz9UN56&)?B5r7K*{it4c|+fWeasQti`jl=#~>G zjsT@C31?-pm5Tn_lPK2mEKsYg`fquyg4MMw+tDftTg6&oO>O^^HT}=5vuoBlvP6~N z{?GQtjw?XpIZ6~DGt(dc3!Jw&f~~P319b!MTtLb(Hq0NyRlCRuVki<8xKL+xW1j}z z@q2)6)#Uvf=mE-sAjLn8HCf=j0asg5HAA+4pK8YKv}%dM14L_2aePM5U{f_ml9@CY zwBkP*B4PN`AwI>abOl>OY&@kp;VM7Ir6XJ@$)#iBaMc^Z8c&ao;Xf&asymZ~=Y+AO zCN?7mx>v9Ug$ zH%HIXMXElRSa^Qz;RC-p{>jT9zr6O!xqD`lo0LyL7sfO5`t<8XsxH@e!+W!Nsd;4_ zm<%a*1Jmxpz+a^N`Q{tk&D2tAwHugrQa*OC9+{bU;(6af)8f?1{x3|=Z#t0EE%y|g zTjnoZyYMr(eZJEN2>!2G@qn%em4NPgZCds6$vZCZj+5G1owq^myF=IgdvkX@IQMVa zOCg6x?$=^TkmwIOwSC4P#hcApx&-V^!W(xC9$G4;*LaMXa7maPg=ZFzLyl@p#6@`c zWzvGkT^Z%eSA?@%TFCIK5q7J#DYZ?Z+Dp4A6w4&pC8UjA+CtQ`sSG=orxRK5=~P zC^v2Bq1O52wdAH<SBfm<1xYayEyhhx5^`i3X#Y&G9ks)p+uPDhY-)KU2jmY3A z4fb_6;U}EAu{mj0+C)%Yd=1|+wCZObtymlUQK@8;YbpL#YEpq8JEhS!5$Z9VGW-n} kXAEDLio00P-4_T%(BDuAg7ADjWJJ#90<-(7U31Sr0p>Eg7ytkO literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_823958.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_823958.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97c0cd309ddd31bb9ea80bc0004e2e1158bdbfc7 GIT binary patch literal 5166 zcmdT{TWlN06`dvb#rKShvF_DV`{py|NL$pFi zO&=c}U}xsuJ9FpGojYgl{>f%DBPb!)2NUOA2z`c$TCr9;Tm1w=D~LlJkwoK_mKY~= z+AwV#Ckez8le(VT)Jn@4cyh9)r9s3QAHydf!>5|_V=PC%ORiNlk24&@qe*AOLe(zE zK8A1LO<;?OGd)IbZqAQcIm^4`xVi2T>$_;&VnDp5Zj1FI5wbnN>yK&Hx%Bz4C`AM* zoZvJwruYWxu{+Yl99}U*M5SMsh12C%16K4P(uQOq<*$ z8>h|zzfC5w#s%<^+=A;&JC-egk28B3bTxva85?$9j@%3sH}s;3BufS(&lbPSEnyv4UPBTUoSxkwUbV+;}aND_Cb`7mhGD1Q1p4^$dc;$B$qiLWmIIM+9y z7N*3RNR;oJh(xC%(nLDd8;SENsZSK5eG)%C(+5?DHVeoZGW1!-a!j69x?Nrw|HR^sp7#9805%qe|zdHIN;=XSFv_ zC7U~cPPKLZ_T_7*uby6wtOqu`UQ}%_W>1!!K80C!WKZ9NA&2vEW&Wpki>22{_tZ_}#hWR5MmeTr3ecjcISZtu#0Kw^Y_nht(W|>jUh1b>Ao?>v{ zuOqA8LR<|VT;G4wd+Wl@mXG>wA6NTNslits8VT1Bk)z5qa`|CdXIEkFgUGsfHGcKS zHv?+-pz0j_g#JXj^ZFZ~%>U7`F&e$_pZ<>Of9da(!E~JXyc5|EY{8Kr%~Zcj8eO?Zni1k$qXk}? z6;e?g#kJyr4!&1gV{oxokc=j4l7^NerVc>hyFgmSga#qvw%;K-JsU0mJ&!NUqvB$Ou)}%<;loG?*V>1LWXS^4Vg4D1;DDAW>eyY zS)K==rNoKI3^bf{5gNJzQ<@{izXh*AdY%^m6vO=*8RZglIu2=c1@S{x;SgT#Fs|xv zh0qVAMxBpH(Fq}dML}G3;0gzfZj468;XM)YQ8xmu&=2vYo}8)Z>sSddhcAs~2Xo&n`U5M;<>aL^*`d6#Xt6IH zT|BDz3e2i!qqX;rrFVnu-LUi)Ev}`3#R28Wm7|xB-m!FTkX;*=uKS+0yPlBh39UZ0 z={cAkDcW7xQ}?XyrQyY4#i*RywDznHZdm%WC+?Hp{A()NUSwVQ=aeHw)|;P%{#~|J zWm^^3Cc8&DdpFpt273YYf`^K%eaW_H%SSdry8lY=xExuDn`wwiJDxe{-2Ay6i<(M|Mc)1J^|1s~7x4=;$&5 zHT-@M89<8+0XNLRkr^Yc8c+ue5Aet5|FmXM!^||4+{JKeNuFTW8G>4SxGSi;Z1Rz=roIvw*wyvBY2Dz9kTZ zhVU$KAK=IZ#GK_KfW9wag68OXhh^84Q?R%b<|Eg1e|MRA9;XwaGnXV z8U;ywbOO@Y2!MG?(kQ5o^y>*fm(rLRWc)h4zzb zB;d1rD$3(5JT6T%p!qh!>$nb?vpAcSL=nI4b?C0(wq4&^U9<2y48fmxr2^h8a{IGG zkSk_KOa8X(nUcqs9Vxkd*`bG4$k$CBnU0erKcC4F0n02@CN_k^!4f6CqZbTv&SEL z5oOOM6n`PGx^VNzr{uA+6*&OWOTj0uoV|SZS8rxd?of0Mftr?-38N9})#I&FpAt3d*(}O^>RH7yvPbazSNMEZGGVH)~ znww4XFA6^boe(BS`~y^FBS8?KA?u%!?Jp<*`MiOBul-Vc+3*x`oB*IY+taG1mXGooDn&K MoA!=N7JZ-p1W4msNdN!< literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_830218.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_830218.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3be567f6f2ef84a1c692b64b74cf4915e4897c2a GIT binary patch literal 5264 zcmdT{Yit|G5#A&3j>q?#dYF`>Scy!fk}Nq%<5W$W*m7c7vfJiW%WespcYH_`Dc_x> zSU9G1kRnuDKonX)=M<>vwt(f>K$Kr~g8=of!a;wOX9X~EbpZnLkNO`IxhT+IojpE8 zE40=0@uLge?#|5a?CkFBH*q5lbY|WVIzG z360jx=qE`6am19Sr#`jXvO12Os%vQyvHAz_$p`SM*8C{L((jO0YMLhvtbs#Qj;4ib zU6y$OU&k5279(qXfZWuYAGNUNcgRUo<3ud)ph>e1aps0Cma|0AdKa&MOts9XPlp6C z%!{D}tC}#yrC8iIyf}XR;L#8asdf%c96vNV9)XK!6nw0D_%=V*;Y$$VDyn4)=Mp-2L6>Kq- zLnu#tAKHvp@;2GiTD$!|?LNsTS*BsVOq=A9eArIE z9a4v+pFRcr4vEAX7vYSgHe6>qv1}2}IJ2uscMDgRI;CB3y0v{}b^!$s;YT_aMRgr` zO#?m&$&I;~U*j^EtJ&^v(%;He4I=NMFUa+DQ@M0YBA~=qw6S$9U{Mcd>^Cbt*X%7N%7yEyg)M zNbp9?HK;lp!}r}s4XLJ@Y#}t?)Mp!`u(qu^m>RlI zHmn*1?o24jrSLnCgyG#M7CBYVCT4i>1jN(W0%|tFhF}}GIjz#^XjA~w$fp;oFrgAt zHEv3!@iB!`DickH@fP1y^)O9k=EF&tjWLAh!b#kr7DAv&rMQ{X98hUY3A1O zacXcO%})!n;RrVv4@aiMVmzJd568HaI4JOuL6Mu89RzAn)66DPu@Ib~m`;VJIX=ZD z2WHQPKwM`=(r~!X?B6#$^bAN;$JVUsaDD)6!hj+C7+fKX?zJPUeR*POLhjzM__L$8 zEw(1AY<1;NDc0`aym;xz#UpFsb^k`sbBgu3?BTM*D;rkq*(0}M$WT5eFWj>B=7^Hh z_4Dtpo>(~{pDFZic>5G*Uyd$2+vHxw*`1?HZofRQxF5?gw{4E)*itM%Q-~|Rf%Vs~ z4}EGI&FRaYHhDzx^sXIT?^C`uqPV|#omSk3b4^Cug|r@~rM%*(*O&cw;TR z;o6rolpSsF5%QC=L2*1%cvf-ruhZ+|N8edr_|Sg6PuVxFIL33-ZHH^suwsx+h1Zn! zzEa?^Uxn8^g_sf;SbzMQ=c6;%+CCh-eoz^HNeR4sPfs||dcEf(V!ikB^VeQcp7@sHd*Sbt&Uldcyc^l}Zo-uy)!2ATDqVd?svi7Y zrFl-A<5Lmr#r5Kb4t`ggd3e~XNJ3L}NmI)b(|f`3ogl4YLX!}&TX6`_ZxI6`0+EtF5ce5;4AM zrB$LNF^y;p2}#D(G`bZcGm^26{Sp<`Wz3?zu3a)~(FjR|L;`Z_Vb9IkL~YpCsXsv8 z$Rl9451-#!zBQlRVSa1*PV>q7dBN=%Bt|kx#;eqonxzDi$X8I1KE~rfTMar{=-tKP zC#c&3;-FC_Qvl7XaV{mCnd3NsYD$QQXQ5%mv(V60pi}KB?oEgz=>?7lSPc!UWQ0x3 zYe=Tj)tC{q@cZy`L%3>08-5r_l{y_3BXQo3MFCuO;R-v9W{gV4AR-C)s+*yj?+1oz zY~UyV0?-&dg&_?%D<;G^-+k-cTe9<}p|j)(-178d;376WeK}*v+qD{637tEZ9nF2c zpuRHdcHiese5dBojq*qGWb1k1UPI`!5_gf8eIMr%3h`%{_PA9k<*;#T{JRv*8}d zj+boC>`S*TuH~akM`gWyYQxgEHd-_fXAj*WJ^5D^va`fE^H0nBON=K!1^ru0yTY`~ z&JAXleBxH1UkUUBtOoX#7~8US$(j#ufb{VT{pb5(Yn^>1#=LA=GUbPUV$BPyiz|zT z?w@4}kxNq-r`Dano+&Z*e5d?M$0$-!xAf(|}9BTN% zAu@nR83JCQjwLgCpvk7(2z$@A0`Ci6S(FeP{8lK$Qdj9KHQdZlO&VJu2IR2K+OPv} zE1I?%PG;yPZLP9a-B;3P5UusaG~~p)Gt@07PyzwY8m1apQy+^^{kKPASi=ago2>e8 zo!Ub!+tDfpTg4jTPEG%lJN?hBGi%maGDOq4{h#^9jyGV%dz2W!W!tgDU*Wwa5V!{a z6maii&jm!CW5a;X&tigRx%otdQ_UhLh@nVW;2_a%B8D|oPuv3>Y-aDji?(r|3^FPO zzX_2zB+6j``;@3sQ0*Vqeks_LYKX!=g+?!Od|J>zR;41zbczcaHMCZ(4G0e5XkmvG ze~wEHX^RK}W{0SGT;LaeIFFOX= zXm+CP>&PB2yS>@*veTO#yJvx%aMikEEn@&JJYRhD$&W{WfAn`pi{F0z9z{Co>=77Z zb1$DK3Q>Vq^CCc3oj-&^+;!_sTn6a3A^mQ;6LxLAn2BSyKqKP z(L=xN-(8`gg|&8dRSdYtAWz5Y$jV5?Sfw88yEJ}rd=pcfhY3VjzhI^a!u|OvJ+cQj LY+dKf+CKjYiAZtv literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_837397.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_837397.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdec3fa87e0bdeb8d4a51f193fb6fd384a6298ad GIT binary patch literal 5532 zcmdT{U2Gf25#A$@yyO3`_-9F`lplwP995DF+liaRmK-^@tkiI9Ben#B=A9%;{E_cY zlBF&SI!zH0H6RIn5Y?nGQxqs23y9GtrzlW8_UEByWhBJa1qj3_sy8}PU;5I{9z~jB zDL75qhc0=0Gdr{2&hF05?Eb-I(jzE8T>p6Dj~;}+AWXHGc8IM`3ZWImA&!co%au2E znNoOFOnsTA5Km1iIMuF|H>cw1N#pJ|!g1<7+Ue%)V_Hu09(`vA=rW^2EO}tAvZK{A2V{s_cWLFwG}YEhb|jbh&R-1GR{*0%O){pSY}QPz14pv zI4rX}#IS6vPy=V*8aO!=t~Z&5E^2tTI-y=c5-fG4OscBXlQ&1**^wu$*ox@|4MItq z1a+lw6sEE14#9&mfR?CH4m9l8UWco#Nb76pXS%a9-U0}c4(!166tpCr*oGZSTT2}V zvAUsPgj93s%k^hRoAPTi%{5A9eI~-6Su({PQ3sayL!~m-Gt56aR<=oBe33$ z-F4$PpSig@zA{S>_Fxm4Id(3r<_9(tl%&7TqGr45$NM?DZa+J)3hz_q?P(aNwhp*ekvAKyvu-pq4Bt?2C+@)e*EH32 zqtMNt(&*-2Y4GHm<20H_0qw9r95tYl=~zgdk~Ilwf)_T)XE9zy3s3?>bXZv>MP>c8 zkQf(2v0#LgH9{yp&db_Rsuq3e;@Obm!fTp=`~lo0s9GqO6uiL!Q%kBm=9 zvN|y%$yN;rKWQ zv_QUFfhHC4YFRxEa`tL4F*YXhlFW$wHA2gp2zNagm$hX6@Mge9=zU$*Nr@>w4x}hE z6_O?r@q?jpJ}z~OLbzMvW7FNhbpt-nagq3V5SCI( z#Di135a**WP0t5GTx|>|V3m*b9qv1N7$mawNw12!&MZo}7Iuy&_4MVl*Y zTlS{4CA%+sp+>e;*6d!Yx;*VDRfEoyPEoL)Yiy_O4XIJ?q1 zqF3QW( z+-$*h5LE2@vV*xx`Rj$}4&S#QN%xoBzU)ZB9mrq0J6AaP>Uvk9Yq;PZPM><@ZHK`- zawDI4dzUmtZ~M>fzwmzK&G+P^8y&9{yss>2i{AaYj#VEJ?)INt%RQT&E4UBjpTFz; z<+Z!++ua5C(IvX9MWE&Q*||KOyYx}(U45bB<$~wsdl%OSe&4Y^_-U}<8Cudj^tIkP zuzDc(T%No4%md%6OAP7Pnw?lR=j|UK&rA8i+xza~LZH9k>0jpyo-@B4yMN)GKV2EQ ze_`ZziGpVoFs$|!cA3q#zS^FNlk+1Jd3-a_m_9EvCTd;*lu#?I`sy0%!lgbH| ztQB}^Mu>-NF&`R$zn7~0!c%u zCVh^YdI3yRW0cgz{h+%Sl0InwDe=#S!l`QcT!+;u6=srv_7%o7NQNpDr>t^1=C+kUOPqY(Z6OTF(^D+9-tS z^C+Mn7RZKf{u9c=g-a1>!ryanzz@OKKcATK$7V#yA5O%kXC%H`;Khm1H1CJVD?+~T zO9}s#6J6&={RzR(MFc)9Md$sHrDln+7f?XzKRgKsBTCb0i$q~XdHn`)+htZYAzC!>DdTBC+jbUqBHzSK}ZO)2C`guLe@YqmgZ$` zB?8NqIDcIVPABGg0rEqzR|&4Nwvyih);&MZC&LP4(*(FqAekMDhDe?h1+vkCi&PD1 zjvj|3Dw5T0Oe{hVfzZ{Wy>JTn6Yz*flrUemBC~UGG&7odqiD1*4rPW?CyNH#;<3!J z8w081scR)h|Nh*KxgX3g9e?-xC5J1+mMr$g`OJKVPE+ZgqQkXvWcf(;#E*}q*%D*B z&v=V1uC#tvw{1PLc#FpHmRe=emyt@RFSE$%*b_yc`rbpcV{s}omF>xfa<*(B|LuFO z2jXO-(8Z26DsZ$Rf-p?I@f+LXc*l={EhDt_L`cP`%5$#-hyFmMjthLCRi>#r@ z>gxXHMN`JK6xv`tAX?OC^h-TIGVMh=i>w_IT+s$|>iBzyO`|P8i;R}kX`v5RFbq$v zkQk@L&1<$W7QZX)gYV2LQb=$y|WX$lbSlj zCg*avP1K_`IhiVGI}tajtwU>a%Yg6rKS5)j#6YZpNJnPJR4u1a1VLtCCM~te)`)6m zz%0icgVs8${#&b7i_Xm~d#Y6|v5M2dp6dT6d-~0+Gbq*>lT_Wh{h#^9Q+L2b=7?Fy zAm-u7-{5e`FhiV10AnCTeIo z6kvtp1c`zB4xF+jxI4t9zFy@X#KFyA3@$(lzrYI#QHigzCLB$~`GAgGO~N6#szere z*;EU{L6S=>a2J{3<6%A+<>TYh#FLYntc%1U=_w&!Im0O@HwIGROjHs@k~3>zz7pTf z$}DU6!U(9`gh#wp3GytmyHcl1PIu~T$>mMGRdTdIz_hibPH!8L&9!1$HkDZC(q!(n z`<;h=eeySFK0R~)yBD`Lv`w3O6I37=U&&l4vG%32Yb~F&-EO-#3qrFt^#%wH*2Nby zFP2z$wr!37Bz8Nt{sssQ+Eo9x6KTxpNY<5W%`e>RdqBTdHX?}KgQYhAt#?-6`QS=w zFntz+^;ff^0jm)Nu${`ZoT<}~tSwKS)YIWwJgKvfO!lugcVpw+|IJ<&86Bx%C6@%$ zia!ny2TY2$3H`+7ES2NLOHsIS&%n)3W~-E-__=yABFVaNA}&h&^=V<}l<_tn7vW^0 z1~*l_QSk=ZP;sB&cp}Q(M#`WS-(io%65LFbe^uaNOz;;Zu0vH;Qxx?DGX5Ev{(@S| zdZc3C_ucT7RejW&p({w$w;vtn`_s zkk?mq`HG&_$A<4vmZHwQ*qUiAtAQ>VZRF@t)&LDRB8RKY5Q;@MXIV!mJd$>T7_7;tk-yzaGRI0u6|kfV&6;0&wBc0 zstEMU#sbuH?xmg+eLi!Ac{ns`cpks6!hsYZoP{<9{q>C@7K-$+c?(WenbRU#MAJ15 zs+3KxmyBQnMbko2=QSXQt)f-bjsqratI(ybI+7HP3)l;2J?OiCfiX2=O~t6{`iyJz z*+siZD*7CvSyW?-MZ022ZG|MNpr!IHBC!_piHIt()Y30dn41$YnGkNtwLt{^B- zB!F~<5b%JVEGN^$0k7Z*4G;4y*p3}BOnU<$3v+>`-eFd8&phMhCX0N$Cm^rwGcwIi zTq==e5=8Ja4H`-)ejC`kebe^$i|sQZZk(U=`q=hSuW#HdjD~`Hy(4T;Xy-X!yTDFN zwnMKSRIv=>4~}>s%0ejU8E3g58`w8_)dS=b(iegWIM{jM*#kh3t*e6y+*})2gwF!r z4=W*pp6o<6`z_N=Q^K3IZjYS#quCmth)pEAv*vA)Q+bnPR+LOn{p8Y5{Xh0E5{pN& z-*`bXy%0H>GudZHCDXRZ$w&6uThHBmF43L1lB)m2-Wk>9Z1wMT&L5mRm^`1NvP~V5 zts|<<)wz`?LbwP4=MQN;YBsMbZPwtVNdl&m2n!BUwyt5%OBstwF z-(s({@2FIJENUnz>P@LrLs|QwD3xz$d7nt`PwM7cqUKyfOOkxwTj=DS+upsL{90m0 za_&ljvd+cVBd$H%0u z!H1TiD2Xj~OU^y1;JvHTzSENP^n;_4a{$CphBy_YX6>`DOSYZKo<-v?&)+BSy(&F@ z;(_*okooq72QwdAq^Do~gc|tl>&Vgd7$S~1oZdJae&IujYs4=|&`8#DtT4p|eRz47 z8fI1!%oV`iLLwR~Q7XP7Q4Q-;T~Up*zw$TDAC(ouC9<` zzJhdRZUlQNG-HkuwE7jbr8tH43>h{snlQa-6wWZJ3W?5DNhL_sLIN+Y!eLd|2-5H- zS=UOl7e)}NrgcQ3VaU1y{*q(TIZs z1r#BY8^#^)0vk|UN~2c4OnLrbjpKb(jiVWb}EgD2x)Zd{Nie)~5%O-vQGG3IN4qnUugtojLLkN;*p`Q8& z;N0qUC5qcBYZ-Rh?}MU=5w1?c2TCWXS2X^haA3dk>FSipATC`r{HPam1g9q0fK2lK zYpkp-q}y~oav&y8njcWU6S=eFm!n%~hM z(V4wx)8yGSeKtqi;)i306FV|=Q<`i_(@hw2BzELRPoz6?IoEaM!=oP@z2BYbIuX^! zEip@$YDf&szc%+;>cy<9JwvrW=!kUZC|ledGtY8asxi`?*ID1Xe&c$=meDolEH&|~ zv8!*tp717j%&{5EQz`Xb>Nd4lvtYbOWLgiVse@_D!Tegbq<7ArvFu76zkB-j>BW~9 z`tEgS+MZ8S&!;WVKe9J0+1--eo!XJLw?+DLW^3dWsEM3jqt|Nt;elI+ZyruuPPS$3 z?P<%?Q8MRnEjd~xM{CxxJF3r_ZSm3A=}bJrJ% zOiM>pzeLqbRDHtskZM}I3_QhtO_pi~_P8l#n)PO>`ZaJv^L+hWeX{%A=BPSuh#6)( z-Z7;*a*eL}*16W?`FBC9g6@3B48iNY-Me^s;arY##ne-LSmJg2cNeU(x_D_B}kE0)DqTKNUL)G!I|@Ar0E2_HI*i6AX{SZbS29P zBt&iAQWWX zdVD$H^TK5TPL-x|W>PK!EIfO8iVgZ$Pk;@M2&1dNTC(0BoP>XKNOFarFy+@v$FclW zK;U^jXfG$eLeeuUK9=QkL%OAF`9ZMcd@7({Z6Td$Di?iu(UU;G*ZQ95gU@_a{7h@OmRL$(UJGFORw_W$9 zfoRr7P6Cm(#`nbbkfN&;`YS-lR%`kk>gJsNMnxr6V7CF>e{`ohvcypBeK{c zr}D1t^8<4OKe`w>6+H(B{TH)hovigjK(<41%Ms~)WUbxOsjbD;)<@1gGSz%F-;L{i z|2KOnWVA>6m4pJ75ME;3QFzIgLQS zMaDlM)1Odtp3uG3c%yMe)e2YcEARs*cY%o8gQMu>llb10_^yOF-dJ6j%v$p!0BeaGr+>Ki)RLG z&}=)d9C_EO3Mj2u)JR|{l5uUx{3KGP^k=00VMEkzrc^>oew2R@h*U-W(Q{`X-fY$; zP1PU0(%iZCo^#JR_uO;8bMK#RHV#3VdFKy9haCugjyt(x%Qkj60--r1Ab|*@zRaKK zBaA*dO7+nM5{Y3$PHt-EFOVWV%ztwX_XO%W<8a(7Xml|-U7ubZDFeEcR>+pC+;bzkb}^7ra& zw$ZEenb!H9OXoYggO&pG%+RLb+hIxUOa)w;DZ|h&E&}m5OOILh;3V^$oN7_cQw%~e zYr!14ST;f`p<1V4xPZ&_Z7zXO&5yITVeY&hUgcC4>)CTYTR1;vQw#0-4jNTv1oX%3 zO5TrPQPwZ#(c3=co7i zO{K8jCsfY*m`k;)Wd@BdkmoLtS0eaTQgs`1Tp*FoO2iQWg+RZ4hXz}HbgzV6yop!k zBXoK+AdmP=I;n(pZcK^}N`X;-NYLqUG$81vfE0)fiaL}1$x2ibeIA{=C58rv6#tka z>E@6iXByG*%m^zhQb-W}w{$Ms0*cN>#z*~1bVQ8E5;jW09+ALVNTx)zYe>ab+p~6MXXNdi5QDlMj2Ht8vozTQ zn(IrIudg@Xib^B$SRg1i4+VlF0c9u}X$%aC5v5s{g3XFJI@S!mW^wd}C6x-rq1 zB-XfUtz(_HPxmBxk_~Hot=5@#dXnsHlh(D)@zVzq2WAc=yHd^rPxd|aJ@PHLFW+7} zcsAucyT-L^9UJAIB%dnZ|AF*y^1UM7pdxc{uTlIQ?0f>oeYwc2&&WoNbA-X>Y^gPg34POOq*Y zTim|xshqnudo9_w)Rd}gTbca0Ep_Jdr=BZuE?xcd5;5Pr)R*$MuJ}^k4q#@vX*R*m zR3=ZS9KNL!D;GZIQYU}<7xwbD1KI0#VEsrpWj0Z&!W{~BEm(MRtuSh!S~-OAJ@%gE z9yfoSf6@eJ}F?f}{h<-(=L`e>^fP*Z_Q7D5fg`6N6AE=@tP#7b?u2M7-(*_E4 zm*l$+b7Q9aENH;Tq^RT)#C=X>G>?MO-YCOtuJ!DnU~?!IW8!Sw66XdefxgSe%qpeQ zDkCuOQ}-yx0b|uF&IKg+r7WlG*0p6j z_Y>3OEK(ClB`+hNdFovbpXqB-U~Ej3PSj^&sa_rmjEVLA6VYJc#-Vz_KMH3u{Imzkj2C)D^h0uhL*6E(kEA6^D>lsK3BEYa4l*$;U z>orZ28la)GC&Pi!8$#gpR81bonofqnuVMK#G%~~^cw!S&|J2a?6$>@D-#PW1AwporLs5nBt>~(u2<=IbmGFJ~K=zrn5R75kjLn zW9*&Iz}|^r$p`FuSwy@INhx|ulmPhsM+~f$4q}oilbd{2X+IW>(MEb1S4GK*F@i>? z2O+0PhcH7Ut_pS^ykSL=+orXB7%@2vN|8tvkS*|ln}S+A`7Yh`1qvTxtw-X&ehG^Jq1PY-TtC~yU zV=g18tYAI~dJC*7ffL`$YcI5HPpepLm0*EA<^Cso`rlb+HLT-fM1kM_&thZG9kAgw zsx0IKdvE9;aNa@)LSs{d(M=q=kPpU%0OSWoP$tJj-6V(;p`fT+6;W3F!GJ9K4;OZh z80jsv3-DGrd~_2Y9N!o@=3{jRUJt<`-3*TjxHKy|1J%(Z#^XYW=;i@zIk<*p&cge z0#4^^+vCY^(S_G*y10X2o; zg*y(9JDtt7Bpe7jH5^iO3p_AoMZ7&G88-*rbOoNTQ3)QV5ygm5BO-KbCWQQh(Xj9c z8H#h)CAevhhsD!UAIOB5LGn|mAc+#h=ZOCsvi%*^ZV={s)px45$SUFr0U`BL!$NpI zyk#-E9ID^9pw6pXR-?nR<~yV!5yoy Gpnn5+&og@f literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_965031.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_965031.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6433f7687a7d5ef72c5a60d8287479a9f664292 GIT binary patch literal 5371 zcmdTITTC3+_0I0>?6B`QEG#dFhO85eF^*Fw6gv(MH8wc5i5)wn+x5->3(GE<85V+;i^CpR86Bf@XKTKl&pVLSJH})SN|Na~FZoJYo<-#L&6In>a_P zxNbs!jwBFEjHz-;QwwiK$5M1KTh-?lY<*uvPalIKiiZ#b@^b7mc4&1GFImx-YBE>>hfv0O}? z5Ai~n6GBl&F=332GZ;4V?5m*|8^^>*7zRbBSw+u8CluqMW5=I68iL6Mpr`+M&*5VV zT_g`gv}R(r>*3k_Scf-BfUPRDF~BPuL#!0}N!K)$6(Iq;j4M;1StdtBJNQjP= z;+oUIPI^SQXuotMRcxU3!wF&SH6Ss)nc_qyRFxLBmjpa^J^W0hIM|tvJv}0UEPB@m{;V`B~G$lv&vRh)Qfc$QY&TI zN)~7=Y!G##uV8(}Y;|9@)X7}dsZpboMeEZ7EH;V_KS-;hj|lX}-$`^C1<3&phtFLc zp+Q0+C&K)A(5UEym|~jb5+huA0-qZy%!T74tYRp<$XFuGD2Ay~mIE3uBsezMq*$Zz z$qV?%@u5k9Q%qBAbYxUeVPgUqK;-cH91bqU!V6rKVb$TgAHX^kQ;|0nx~PGf@+d6r z{v3>TJiyz-!EPgt!X1u!4sRz{g*zNRieWO!gunp&I2MBgkvBPGZb5ytO``9DQrKl7}<&Qp^I!pp^Glz z`6f`QX`A-9q4uc+H_lInBW(L6WeC((^fAU3z3LAsx?F2hvk> zzg!N;)lX(f&mEZ2F*}mJ^ttu16p?kh-aI*fcJ6HY)uq-A?@rmdGeu>c)#+B**_@*8 zczpA{bG_;Q<>PYTnYCd#&~tO&XP%R(s;sYZsY~{?uO3-9$$R@`@6nsXvUf0L%KCRK z4axqFl%A5(yhr_ZnoT6^KJoBw{t?5UXrax?69b9ddTc2JVULX9lWBtrWiOdV9)>au7gNTZyEaqbCcE2KsSjUWYgwK8gMIx8dG~|#|js?cm&6Psd)gv=Wa6(Tlt#nHTOIc(O8MAc@IG< zc-@sWh{VIVN$PdCNN&ZQafyZ*_ppBaF^CymANmymQgCXiXfS&Vs%91%)9PNFT`Kcf z!=y1ugY4}_RV7WZ);}8sK1rIBmZbH5WgOcg@6Wi)RoTW6NjsxUI#RBbdsxTluhL1U zU@n;`I@MZ(p#G48wM;Enbq!5lk|`QRTC5UHq8VzkRWrg+8lj<4->2QzPPB_Q(Q$)Z zDc)2FB$B64&@gcKAgFRzf-Zv7e5qmAxZ&A+8oJ-_0sCtYSDH79yi+9_DcvHc1lK25*J#iiIc6GvV&(x^nv2 z)de{?9OJtIC@{NVC|?WRw-S2Ww^+AOw{&W`W%(EK&h8EWfpxE3)62C3JE*V7@=(SM z{AhR*DikAJh8I<+D%_9p;`TUu2})98isc|dhxVumPtix?LeRtQh7nu`?#lIw+k;W; zvqD*^!4UQahjruXa2LQ`J?b4#vy=NHW#a&yNVL~X29NX&JYq&wJ6w^r3ppl5HV?~3Du2$L>$%MhuimN0k zjWir06%8(qKSY{QlPC-#rCMh!X-=Rhg0yKIvsWwvXeqghG-Jf4O{3L+%WL0 zs}+s4Vya+IP5+ZU{m-;BtJ+zTgvM|GXR)#E4p^~`A`RKzHW2*^&RY~gXmI<0dbg0= zA{Ur&P{(xxMWGn>Vl=`kW`X5}P$bN=a6i`|d(>}^=sl=LT5{)IR6*Grq!j~vT|`C| zBYb(ltyWMB&~@%nzeJe0VjPA~5fz_ixdgA~O2rU~CE{$bO3l=YwVX*q_(R_gzW^?< z@d%6m1B?ixniPLO?c-z*OS1fhn85Qmk(V=dAwAnv)5?;$Gce>dJp3Y<04bVAE}ztw z^HfXyIiFuTo^yL4Q98X+-#rU*`sS^3)*Ph5v84l<9eY1M{Kumo9nHLO`ksMwQqmC^ zVsl@6dFJIDB+C99-Vf@p*RNd!rj3%G2d3F^?eUq%b98mO{s#NO#Px~w=YeUar03w8 zZm^}IY2Q*{d3vq$Gx7iw3%gV5&DA$8o?STm*2_|FsvnZ`zotcntV$su+oPK0k^1g9 zyxSVJy*O`;)DNEq|DNrZO54AuFO4j2X`pajuyXydU&YHn;o`C@1|PW>;G0UJOAv?u znko9RsK7mop{htC&I|0NNv?2Vpa$7EpWxKEQ-eS;7osLKl87m9eFdR z0l2%f3_ksnB4|O4&5N;xSl+4*vZ0no-|Kp(D{og(2Wr}}C@zTKI5E1ZM>hP^1uLn36Xv`bnou^P;L*9rX8hQNlU2|ReDvX(m*rpk<&65No Ytl!XM1mXUANRRCG8@7fkX4Mn_0N?7X{r~^~ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_984659.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_984659.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9be1bd43060cbc744d462f4a3644ded9c3c5357 GIT binary patch literal 4473 zcmd5-B#-b`l8C5>5@cB&6X_fxC!14vrM^hoZYyO}NT3-c4f1_J-XJ zfoNUrNmOSSq&TWhG1p4uq)K&2BywM>aH<5~`{k^lt2L#PCHT;M%Ux2bx-Zu^Yp-om zqx95P>PVh>^WM(;GxO#*&wp~c>pWytbk(P270Zp;B$@n zpx1Vr8L?F(v)@J|b`uipRa=f*IOtfXriO_1?SaeZ`a?s6s}MuPS*FepUOxZMK*Uk! z5(l&yew!PZZk_^txy4KdRU3J{RFS6~szbHSSP)7%>((%}Zw-pas(q=VE9Fv*!-7Ln zZ7R1!Z9_Yce*Op0f{Wa^y`#I^{Cbfu#)nso|La{HH~juQ?cr!t5h$K zmraHTsV21vHW>ZV#G_HuZYsCX_*56=QhtTYQm$^*zk?l<>aW;&NiT1%vl>u6+pL-! zT4kzMo9nDLNM;qO&AQb9#P^cjWyXjVqr0mNx`l$4A&DMM&_tL-Se_sjQ5hGcF=NAm zSfb|@WQj?M6Z@njj!EG}D9RHyEQOO}0`Ai zj{oq)X^;?aZB_Ym=>WJwcUb-gIt^{Pk;66HmFddiHG7+OuIOpX{w(tg?MKgCp4s8d zaQ1qxbIo;7yYRx*l#MDwc0ajTI67GH4W=#6obK7aOkegw?k96Y`M`tlN+3VBxNkM`)cJNAm%O_^_T|p! zS_|Ixg^LAm*9yC$e0F7J`cu!7VBrV-1#ds7;Ow(phRgc1Q*+?%0{@^V&*x8mlK6r< z{$(?A2RC60kVN59@PwWTb)}REhxl%V>2MaOgvB9;ql+MJ)Xby)$6(SDcDt z6@;lt0oZFeW>RQ|=`F?@RBg;~D%^X}(hx@)JtlQ`+svC&49}!23iVrWF}Fe0M%7eB zRW{9yVt7QjlvSadpf}a3qmdbIHdf6szrnAKJZorph1^<|;~myFmRHxqelV&vVVAM3Eg9N}ATfT;biSW@II-CiB7WAgeG=KJMQacQU>y3vK*ZA<6nLX9J zc%(NjM#6FV3=n1EA-ZfAbpO!MZ(Hu~xwmKGN}kD&6b_wQ+kLvQ>ka8O5bDUePJ;@X zw*gkrPAti(gb*i;9GwA_1YC@a6APe+a!aD>r6BFPpd+57a8n6QiqnDw@Dw_#g92g8 zfEl#a--z=2EqUnLZ4k+f0lt<4wt?7yx3|i1+p5t@uY>fv@RR8!fdGq~W7d|j-EnAV zi!T4{)y&mg)18oZ;hEhtdpL6#Ag3cck#BjpcX99P=;Oqr#8dl#cCO_1YJEkI51`4v z`_KMh!5_?btoeJif#;4lwDT{RK=!=?(+0}Z?`Auw4SU9(?YZO1%JVaGGYjp%P8B)# ztSjTnhS#`VxsLl?_qy_Ut?m0o&Q|oc7Ja*m&Fz2pa29r>2{}F5C4J8$;B&Q6;H0~- zISxJ4eK>`9QwqbSGr^@Z!&lV@8q;8NvyYcI&^ zy!vnLS``2rdA6fh9QBH~!I|3sCujPfdFRl*bEa_JzBPDC{-1MW#}jbT7*!51*F6;d z8{D@jg49^ofV)mJ7l3z)4@(Cr=}-h&2}Q!P5bCL)9Mussx&=3=o_%Z`ZQ~pda?)|C z$qLma6j=(fKzHn@UTyM8VjYEYQ=iWWk|^ttMl6xIm=uCGV!r`!el4T}ud9lCAwHTQ z9w-#2gk(er#f9XUGEST~MEKIED`DlekdWkQjcJP}C#Pt&DeKjUj=A8EAWy{=S+0lJ z{bh)C>rq$jNmn3%x8Nt=hv+o4!J(!W?NTYwstuM}c59bQ{$_2UanNHn;T99xZPLUxBvd|z2V2-Plqm%e@M&h4@Pi@zLl`n4hb zZ3#N{q^ms2<$B^s9RB4@LH$oSgOcdk6LTyIe`+FPQdWeUlTzi)eosisqNL}ee*VN! zPMFY`80QyJxol{s>$%>9$WO(EGtw|jAw!US7rG5I#`rmM{sp=IiuP`>$i#iv_Cec* zDS%JF%;%qX&N(-%bjHCc&c-wG4O@A3xa(o(V&_+MwrNHtXL-$$0QO{0%=gaqZXh7@ q&Ga4D$<^;Y?tIj_Nx7S^BkcN$i(%~l@_jS%>|1kx=L3iC?LPsU$GIy2 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_992208.cpython-312.pyc b/src/temp/gen/__pycache__/embedding_triton_kernel.py_gen_triton_code_992208.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..123a3912b8659fff659834ebf0788a62cbd95fc1 GIT binary patch literal 5515 zcmdT{O>7&-6`m!R+$EPkqWB|;5@XAD95Zn&OLpuiaqPyHEn9yn+QhD9H3ZFFNfaqk z-CfC+IxN;ffRJbaQ9uFFXaK7~fJ(IhtEZ$u(HcFa=%L)mh>5KW5Qq=eO^I9-=&5g( zKcW>nPSf_#0e0uTnR!1m^X7Xqf3n-H2+BXAzZ$oB5&E1km0~FqTRsY*MZ_bXilgCD zO$}2zZWE-M;lEQq)9W&uq! z)X~j6p`Cp~yQO~nD9c;lH4R&;v*zAK!&U-~+uv4l&xYy2`- z^ljJJK1Rt^c7m;vRbtgxv7GYAT%{R2qA_=ytQ3?U^uynQEid+_eAtUQ-G`R&ycX~- zLE+X^J8s8}J~OwjV^59EdNZT&9(`smyqjp}O*Qk|i#>SnoE>a>-jCAg{kt^6o(U(} zi^mH#)H{e1jgOBaTs1yDHGX-%Iu09fqwbdnyH?GWb*0cO3Nk|?3DOO6CW4e^;^o;X zA!yL($*>gDOi6iM5MlL8B}j<`#Dbh=kmH(lN=%N4;mJ^x*UVutoEQ@{)0`kCC5;)4 zC&TiQF3nVGrh(-_C@GR?)vQ8Xm=qFnNS@V960{1zeVScP#)JeYlR{Ins9CQG(Xnw^ zrVqY*A8!!of=mj0O}6cGMyqh z6e5Wt5e^{In5ifq0$U{_l8kFK{3Oi`rpUTji9^ZJQAv}3I1Fj!K}d`5zrVsSe;egvqM{&qotuA=m{unFz%MF(JecP0fa&yV4keoeFh#cXb_p z8G5vaZK=|5@i;g}Ku`JzLQ+AG_8_Ngeljzedwt#UjMDqS(NMz`8AtZHPnnhi>)13S zn>#zAa_t|U{^f;VTv!gTwXN?vsd6Wk9_Vv4W_vSlE9VNX=G?2QYyYxmjZ=d_y{hyT zn%i=-YV(0S?R{vq%^%1d$PQ#rt~TEBulVm-k1A&#xSMmXT;DSDnfq{>Df$9wuHXwS zF{GsMCxs_WUb8O(ZGbC*^3{^ceWB-sZZ-(v4VZ*Fk; zyxMm3<71z=?s@vtmPcmfYt41$#+I&s<~^J~Tk!8n_rq`Lfa-rfeWBogcGYmlvSL|l zU;DxOfs?BLBQnuP)?U5aq#il1wp>W}6?*zV3ICx#cVQ{E*01hA^=aF8@^8JJ z9t1ypdzUV&-b4A$)2jD$`rHF=)8eItOS$3Y{c7Ow+RR<6db}_Hy|?apFQ(5GY|iYl z%wXos7hdG++-yOf&dn~wIp+6g_GgF5oxgEDeJMMq+S`|o zEMNSLKDq__hr}1*2l1vKW6cOBMl*|oJS`?7Bx>@}s#AWN`pN0G^~*2_elWEwjFp=Rhq?@*xo>!s@?+8W(;?rWolnQh$w{5wnK z=2NzBP!7oiee^O>d%-EUVIww=8tVHAt#xj1t5gN-QwG^xR)Z$$r3iC1nJ~N|#q!3K z1siK*d==^K8k;dwNwK4blvO5qQ!m&`NJ#V)q#h$&JILO&mOE=Io~~F0|G!PYsVZ-H zihTY4ZFv@3Fayc4yzUgjhVPT~xMC>fEHf|6L?glmS>+893VPlY$=x$(+>t+m4oGX` z;VGdcYJveuX12~tKvj*d+yrzOnAzX*_ zm%Jv3fZ?Go$VGf~Mh8mGT*_!cPCsp8H}UHg@@pnPvI!a$68(fBp|$O&Mvp1e!ZMXmd8b=>Sw=_aK&JzQ`#Ew(MR-RM(ZSFRN}&-|&( zsSku(v750^xnQ0S=DA>jbIqT~ocQ4Et-+gvt1sPoZRNFmYq!dE=jrY|*A3nGxm_x^ z>mJvhr`z*fdy%m!KYrkF&&M*c-0`K}y#L_ZxsL{ZJ@AR|&w)P%?l~@{jSt-|_uWC& z9bA5H-QA@O6r3)lx8Q70&KGU&pIuzMvT$XoY5Cx~_jsOrS?PI5d$Mn;w7FT-j@;H#GE>tq~Y4xW8f3c&%R;D8x;I%Nc&uDQL{fgt?=8AfaG{8dP+ zS|N8$^X)Qxq)as$>wq3Pm|y>b(6yDVq|7zi>cAWv@qb4fBUhmin*bOHJTMi{DHKI1 z7Bltk#TB^3EYF;PiPw1b-&(b=_3Wu$vBWFh0()xxpX}*><{hVdXG>8v>-K*Z8&BN< zJDDS90k|DQ(Z9iYiy{b(4Fy^^NaO-^PV-^$H9~T-AjzRfSQ0|;GT3IibYPD@0>GXkx;%TCZf!jA+d}Wa!rT*mM6FRz!B$?Q^;D;bh z$7M+(kF+YPmoVL-Pq1oF48xFV_@rAUU}sUIPw6Xqnv}t!uSI#I==LfDMVD9Udt^f{ z-=clNUSvJliKSEdJx70g_V@k2>Cb=n;v*C7GAkEg2;lXV%#|YBm>pd8-U+M(KAwS2 zhgmrXom|8G3z-*+Y*Q|9lJmIXghm+ zh*sz*4FWVDS>SeOc6MiHXLfgH?$0)x8A17H?6=|Hc@X*nBjutm0rxWiEFu=ML==q{ zTVj;Za9xZXr3l0k6B?iL)MCr(IBLTBa3984vUdL%#ZqrmH%b#mjjaA{G)h<2H2`MP zAVFB7r+` zna~y6>Y62Pl<2Z!myx>4NN3eZSYu~l3CVt=w3h6u8jB^{Q6wZeW}&ald&0D10H9)} ztCj)EGi7(xO0^6^NN%YPuf^(S!H{JMp$}x0>#JthRTiuOl6{=0lA~JRr7<_jGDq1X zy2}$;sw|%c$qnD*8jX@eqGsXvN=`|STf8&b3%zuy*Hr0k65-evi&S5hrysrXDa`Oy z$yCcegyhAXC7)y{^OIbXu?n_Ng8o?jO~Va*>Y&K}U3luGrV>wcnNL|Nxw&d?4J%5` zQbR@0EmdP1ifdQbQiIf@S-O0}Tct*+b=CzEHtsUBdB@B)nAs-zcR8IkPDOcN+CFV~ zWyu3FV7s+x7V}6ga>ILnq>w~{?JQ+|fK{y(*rknZuaenYz0V>nwS$bFSqt3r8h1z{ zCwyO@GtvC@zOuw+EfZ)41&rr-e4)3nCIf`3XT_N*P9-DpvGFp8fKH`iL19wWC&Vzv z2kfd&jH>1-J~7S*V@!lq4M9E_ALmpmnh3H$O{n@vTztM;HO7M1BC+Y1s-NZfgrFM7 zU<4-UH9*xD=aJmCDOE3s37+GPu&iq4qFjuNi;OrE=uqk7%mwapl#8qS!3fKoSM6HE zoK1|Myf$?K*pug{Mb##Ve1zqg7>r3=dKe8qj0z8!x2D2Bo+jlnPnYBD6o#sr2{ofRS@aF-dpP3xs-A~eaw86h&u zsYYQs#$XMZK$}c7F&+u2WH1y`O;aMz2x5?jy;ndcDsZZ$$N=+;gFUXO2C+EL7U!-q z+iXmUD9Z?;V3bo$utJHsF)=MNz)bc1P1ITCgKe4XZCj=+W*`kpr~!k z*j1L9r;?}cdVL=lmyFBwO0Uv>Am=@}?(I>$J->H-;{Dj0^B&Jq zo=vxB@$kaoOfcKgeS1XdII`+gI*#YuCst<^_bbV>n{|%)@LV`!P5z41jZ_Ruld7@xICRch;p*!+sSNaXb{MbE` zFzqRjNMBblW7>Iz-h*k3VtyRcp1?HM12girWM2Qg@#&Q1F6`#w+{L`rnI2HAEz6db zn9|;-So`jbt;))gR~754d1qs$OL0D#cX=KdsXB8?2Mk>Gi;WA7nYNV|m8PCMeV?@{ zeXl7^ujQO0DZ?gc#Jpr)KD5#c>PXq{);BDkS~!*IU6vL9fn5E;lqpYJQZoweO_62d zq2)ovzdu*sT_R?>KRCE_@HeiHd>{I9{w~GdrO;g|GVf~23@ENAR!oX(f6ADz@0c?c zT9EgdLOZg%(gVvTr9Pn41y)Yqo>97vDRswI4=Hs6piRo~&gr*Kr!J<2%*6$Hy`fWS z=v;B%YP#OE9(Yj+yqIe^@)zppeK-avRF3&DQV*XNEPVxsV-ZQj1n%8YSfUAfSltyQ z=+2;D5|0=~;!%&FQQtt*xXVV_AnPT)WEj)Qv{-Vbz>n5u0$)i7E`pe-mcAq>OO_w) zN6;%l$Rp_bUFI7qblooN*QRf;Uz@(YKDi73+VpMywdvdZDapt}%zcx-0bu|x1UTLc zC}29b^&col@I$eMz}Ia71#MxEzSWMaW4KB}r8z=g8>hKMY^g2$eg~-Pw>%E7@zo$d zC8%Z$YHyJcAi;lygHxRLD5-{`ZwQ!pY-1ifXx@)2?3OiVs=nk8urFYMkxKNd)HpaW z0XzB1qs#XIrn+lAf9HVm6ZnLGK$S#!i(|fTt}o5c9nH4(Dwf_X)tj~S=9^kS@Gtq7 zhn8AbJh%MU{kMm%x8Cu5;{VvcI{a~4dMMSCJc-|to~)%g*}qBcS?*V;K%REwXccDGnaGlz|r4?S?0}i=Jft|VVRns%b)i&4SW?!pluWXEnW_r0uyJ=3)&O^~ z7<8kt>U02`e@9`BaK=h^wTf%45W94@$Ed0VgWkaoP>o~hUHCQ}$H?!*uSSjwNg5%M z{|i5fVomSstEBuZqfWnPox+=kb;_FGr(}Hu z$p%S30g3O!QX|_l) zNX8d{^4?xZstw@V9q0HHfb%DD)sHLe5Ba0G!tb$4g{IkHKvzsE4J>yh65;}S%_&1} zit|Ls8=WNpXFrexLAIs2Y1LFoq5_m^5jjC*LO}sClxn2A8hO4NIlvENIWFJ|@2BYd zUj-;Yt9rwtD%;RuG%ZAxFkAE@oIEaBP8geIr$BsT=Qq5Y#pyh=8%b0A# z`CTbhgaCrpLJ2k!3Cwh{P*&%mfj zRG^W|ojkK)cO*}5)O(VH8+Io^4p(yEfdfI5jA7Pad zHX9Sxyy9kK^WvKeZ!Vv?P2bu3nK!#^bF&*wkKOZFz2*n)|9?Rs^lJIx z2oqdt{ox2x*37*T;;_q*JX1KPLvV;*+VDNT?t5DCJ^i__YpyOurbaete`a7gqBK8u zJCJ?hywd$jcKB!6*TjIFAR!1VX&og9_q}5zvOkrp>wL?iiSS>NKxl*j literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_14965.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_14965.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8b16354b674d92d66327779d1f82fe5a1b59361 GIT binary patch literal 6506 zcmeGgTWlN0aqmI$_K7e5DA1pF_Q<1Xg^JT4 zK=Y9W?sjHpc4lsFcXnp~YPVYuw9SRzO_B8oeSs0ZX%s-7cnE}65QjJ-g2qcdF-~YW zIZKVx1mcNlom-y`JxB8Nw5@8O402TM@es|?@6mV4T;nFrz@uqbjrqnM#!WknGdql% zYmbL4oaH@w++6u$>w9S2LL%N$vCXbD?9KEj$DF;05E8(XshF>xKFa`^Eg_UYen(;D#9c2vL$aMw`L(;&I^VX&x9A3urSeu7ZEQFh7nk^!u9%Ld$Id*w#iSb;sVNr&T2 zRd&_#5Fs~JduZCLYu@U=*>2qj@n*1#Dcfb%^@Cj&9gep=#}-m<`JQZPtrEg@V4Q#Q2EYwxj}Bn-xP0$O$;83OL@W=HR@*UJgjX zDV4$C6c0W!#zG;HgA|Mfe<|VNU?CgAA z`}Z4t3o&6voC^f`zNtWPCLm44qWc4r@VELzA=oGJvvYlb^@SnmV2tAruya%4 zm*y5(Ah7V4q*#>2FL*hJZP7b#@+KD8l-OSA2b?6~6VF4%ik?Rypb-n+*?9lkwW zupdg2Fwb6}Z|qTQJs>>a7;Q3IC?X^*|A&3f~PNOdQ?}xGME|69?mT&EiV=7`jX}~hdX^H z6-!cUo~G;>#j`tU!n5{d_GIgG-Aa?c;OR}8));G2NS{v03ey6tj5B>mVVYsQt|etI zzKnc(iwBX_kv66FCr+$88?%QM=dRqfdw!*RKyeP-4?S2^j$Kfk7ZO8{Xyf}s*N2j$ z>HgF^+0l>Ryzyr4#3w^HhgNqTQg$6GwDtduJ`5)%>?a)AgeLrIz%LxsRHML4^FlOO zom=3zuVpn~!irfzGMa{g?TsU_q#bA^ro|0$dIvhC(TzJ$X^m=uW_J@U)MbM)7yZ|Sd1-_lP#M}BSkmi*fE zE%_$d$Qe2MuIY}a)F|>ZrBh76QDgHz&^EP04gZEsyiWXdb$pIOuw4iUbE0a& zpcZvRKLzK~RA7$R!Vko0?S%20Ypy{HfU6HF0u-eld4aq5W%&YmKq5F%7V50a&)BZ0U<1IKKthu*$S6 zOndf9f$4(z-+)<{Eh$U-NWhstaKIHOFKGins7&eHLp-_|(OUFJ@C z4^~O7m09B*OUhJ;uF9WlecN{zgU91ioFO; zz@LblAfIvSd@{pXwTy=0Y~|iw?j7aciF-Z9!uiZ_C5llt9R~DwnxjRM0*EI}FC5LJ|dh#{12xMavnqdQZSdKGdcB6}DUkLSf5j$pjak)zqk# zip@|-3%ocVkwlf6;TKi&7Ntg-(Wp^tK{8q>{807ILhS#00#L#-3FL8F= z=?2KDhJHRfNu^l+SA9iLRjC-ZYxA2uv(bnW@1H9qo&gS<&OYU0OC?zyG%dMe% zQ~8e9{yhJe_@CnWw}TJ8+-;@mwug}2q~8) z^I;(`UQ@!0MJUdQQzUeQkXYbpVH8sG^MHE*otPMZpyMc|TZ9bYkRQUpPKfoy+A_J5$(b;9(1 z^Y!K;IY1mKwj-BkSx(8t4vcjoxRKbG+4p4^U{5GS+DcqYL?_`~v$~d#rj8aVfY*J^ zD-)TCq5)uhrBJsVNkxh#j4{aKTy9OZ7R?y5AfxsDiR%+jtbjeaL?FWcrEQ!bypLa} Okh8ns=(%pyKH$IVG;2Np literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_198114.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_198114.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89ebca7cf6aefe0bffc26d4c13b819a9e30a2887 GIT binary patch literal 6322 zcmeGgTWlN0aqoB}j}(tDkrJtgB{_*D+Lm1_P9nud5;(RayOLeoO%o}OuQcx{Ql?1d z?i5?fUBLhbLIwduAs?cF0AW6g$hCp^ulY=Y0u4|U;DN3!TwOp#^y2~rsMzR7zdC#5 zk!OdFlb}HJkp=E{W@mPGc6N7mX8+7EP6Xxmzy4j~UqOVv!iZ5^b>JRJAhe1Sbp zU8{*{LdVHD%QQtGo|x6;7+q^MNAlFHA;m=EEUnw)6h~dAuGKlG?VOcIvtCojI>^}` z(Qbc4JN<}ul6PJ4{kqYy{ae zLr6w9)M`syw>C!UlspD2jNWdvA4YH90)0cWSEiP%pubPH;u_CL4#~C#9FXk>tb|Q^ zwXzW*hs`#c^%)~K>u=Y)VSz5ti>~YCFvbVHoCd6PKSmc)?l$Y9^qBblO??bV?vQ&7 z+bF#z-hjc;tbw-|uX#h;ZE~l)4Qr-!%E3Ame2k2qhh+4Dj6S(bpC40~)VGA8pxmo; ze6UWz4wTz49oEWa*g+29St|X|JGj&a(12;&j;78ABwOTueT@fNt&VJ&4H=xxdbXJOY$pd;jZI}{k}=H^jm0#& z#t!y(>a6lTn!v%q}ileCnn8bWScS^nH0y0v1p3d93Z|fnby0)v|t&N>1aya z4yyM*w*2Q;doN~$3*vk<#_vr;V;7=QB9k7D&cIUc6@}PdiJzO_3)tRxDk>&ej*n$H zeuSM*B=^iOvOr*A38hS$#qI3AvC)0|_5(%hW8)V&_5v@Y`4nsJ%Zd_69}%7d{jj4J z{{fpei|+40zV=mSg(*fW-u~>dJ09P%l2?kSDxSXV_$JUgS2~NK(lIqWT=DJ69$yb^ zD}~j-&b3o&V9$*)wf$grVw3SKPvxfyZxwe}n8EDP4Yx1*${la85dV+?kxtk0P=2T| zQF!6H|Ff>ET{oTkvoCK3!o`8&*;|2~IeNntEHbL=sdaZpu~&8P*s>7J;D&)%NoKHW zL!MwkQQc2&*)7adRSR;o-*+LmclovaYlT;f(`sxVu&e)V{Tc`ipn zKc=G`8d6UU8Pj z*5Z}WGdbo?pkwvu%F*IjNm0A^R|3!E9P4gh;aEPCv#k5W#bc`fshoYY&h>0*QtjSX z35)<2-IiM{j^ux%(!Ib+dkSMJ-IKF#1P1bs>eDE=t2&At{=$Sx_pdwsg|n)205o-M z-*9@jtfa#SI$D{k6EWVxV16=tA2WVydUO2N^yc`M$H;F@Z^~~?Z_20Oc5Agt4GxH! z>M^JJHWRRWxU1i5cx57F(vAvbAItr9=_{6`!FN>;f)Y-q9UP+fuRXPUOIzQ{V+H)g% ztuHs4J+kid6-LW#J=w#X)Dxw{Div9$eQ?Uw>F_%324LBpcNgBOVD!7d;oEdSoWzo^ zLO%gAz6EnHGkK;EtJK8>K)?;*_6QP^-nhaAKMgs9=0jx{D@MlNu=0h>nx<44;SOz1|AJ|A#@g2La{4! z^I<`zxi&pip*dH*X6m)OUVCtD_*eo#j~c}$+wt}I!7!w^wlgT=Jt4dZ_+9L!g_kg- z|M0!6U&I;}Ti~K_3EQDXvvK^}$rvB83OLm@UAcHPi1YM>2x7emrL{@CD6z390e8ffpB2lBiiO@Qa$GNx=_DM)1&I ziylHcwN(gl4wwb_#4lkHA%vijzdd_=!{g14Zv=wb$qkPWAg@1r?7kO;w|&f9VM@`p z-p|8oUrgSdI(K`DQ>VD{oOm-NZFcsozO(X9>G%!$^WiT;f?tAf%3x87npp@T? z-E8MJ!rQlku8{M7|Nmc+(y1%GKOSYGSGN6llrb!`HAOt{w`Gr4Um_7~qE|M$p1j?) zQ|;P$t7~`OowMYoHtFu-u~Jg)`Pq#~`QS-) z5~`!*DRK*Ar^)*WInIz9o&BYt+BulB7cB3)st(jWT=gPXd-jALJw_PeG(J(b+CzZm zsnvT@@Tjx^slG-VkbvVg%`%&m;7uiy7A45{gn9`7OFk`T1kD1cSwFJ^J}_Ea?d-8L z@aAz9>BH3y8hvms!!4xv7lZ`x;p9U6E&Qq$f*`&^uD>DX@2F>ku)o*!Zda9jkvLf0 zhP?h|IWJfHFxHRSgR8q&c7Htp*gXp(T{W&lWH(`3_XKoK3&0zUfB9_wY}E=dB=Mo% zRe43O+A&5WpiSf_st$}fk+;3th9MWSIqxw5-5Vzm!F=tSCW!W}mn_ILSaA=%+on(B EzbkfCPXGV_ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_23614.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_23614.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9adefc573a01290502a0b34b7e41188fc0d03849 GIT binary patch literal 6201 zcmeGgTWlN0aqmUmkw=jtMZGA?Qj^N2BU!PW1dS~uj$_%f<40|`HKaN{(A=X$i6Z5@ zQ*7BsfB^!8OoB>8KSUP+!T<&0)G6vuf4V@51n7?ra%+baKGrt`RP|@k35oc zsJICVG#^>uc4u~GXJ=<`XJ_^=9*+w_`DgC;)3ZT@z95KQJXK(Wrx02}7-1@bCMzv9 zNm)2O$4s&mlBpR>jy<%}Vp?Wr#LayK!%Xx37>n7r*c;UWlTHB%2)i*SqZwZlDUL1r zom=$tTlCYi02&D~5esi|ldjr2U2maDH;rU>O(SuU3X2bjDkn|h;PF$hz9yX=J8g;; z@X(1Vt`dfN8{PP)d~+Q7-yW!Vs>bnlJo zTxov;ufyi3w>r{&R{*Wa4m9xdbv_%?19}VL$MhB0e1KsdOxI>>YZ%?+=g_@VlukAK zJ!V?$8S3pPOigZ3Yqk3^YwPMp>2}PeJG6jJ)7xy?A@us65lVM9=+kWVl( zO?s=&)!8>Y8ra>}qj{?a;4er!kHe8d|-5{-*)w}ifnk_;NylwW% z>#g3dhsfU4WcTPny@$-GH|pGMWcTKgwwCFj-cgb8gqa^74@#vw^iaC}z0X(#T6A1- zLr<_lN_$;z8`2rDoZFroEGO1rzwW8QLE9b_T0~*TNrn9I56D>xQznPC#d+Ce;>p;Q zEjKKhw3aYk^Ga$;iOflHY&s%JBsnFUY$6rGrgJWGIX<^AXL46$C8a9l9GfmVA$yEk1^5uDxQ?-;LhqsyhLFue~F(D^S?nE3*r%bQaNXJuCW0&X8f>h?z zf@XSDO^IV!nggEHh0W;nW>npbV$;_!HX*Chg|GEYf2~I~#p)N*5BT=&Wiok!aKuzJ8HYz0hMilkz1AxWH=B2&`Zm%+>cz<=;OazLZjC)%i$N9-fXwXCvBl zDmfIHf(;o~mFTc0&&>}5HXKVt)M*LJ(G-?Pr1|Oi?)gOt2ojiEOC=@JFOBSdVb60w zFt_XF)UKF4RxqhsvNw1E@8y;>mjoku_gm=IA}2zmAO) zOvnp`{=)e+->%Hz`+z(1&H`V2(dZmp^9^N=5EfV7RX9?N8$H7}`;DHxw*({f%I!UC zzOl?W8NuiI!eF8I`obTteth-r@R%_?cE@)(^Xk3U_T`1&^elTaM@zmy=E!}IfB7fB zma^16k9X-@?p#T1T|Qd$7;Qs_I8wa@6?XNdbsy#p1P$l1T%yS<8mp#2*R>vv@xk6b7aT-lZ1 zRrp15V$DCCb>0I>1NnhMSMi+DHMr&<%5o)tu)yV?DN?^X_|f4H4_|-b<`LtW{cBH; z8hxXNe>BUK{9T2y{P5SvKy#MbLOG$_k6N~spGNMMWoK>(tid|o9)EMZ#JjWc!l2PI zVDJN_mUasSWrZE2(d8)9V^CPmtf7UYFmTUYF16 z9A+`|iQ|U9vV015?JNpAD|dxL?DN37HHg!DKsGAOHhRG%c(F?nWnML11hjmE8fL+L zO-JTs%So7yihDKLB#zI8g=W84gGlP2NsXKA6x;!-0Q1$HpE3-X8EEES1h<+T{2^uxuUP7~3GR zNw{W7tRDRTMq-TlPV$-s*hE9-+sM;dEWFP(X!(ORYyXF}-{01n4l6zS!KPrvX72;v z$)=Q*X8vH)|E*2I=aI9IUGKALE{-5*;b!djpV}FKM6wD)O8pE%=fOWffWvEEn?>ib z+X`)XELK}jwQZ@kUeek=7Or6)S11r9jRNJpAxIA$a5ei*DtiG}o+H)sP(3(6D)PEw zUDqZXUBHnrUCGTHSbifOmBXBM>q7oUk}62DTB`ubMj#7;)WZt(Ouas#1Q?wbwON!g0m{t6r+yCm3sDA-{DI7|b+0co4tYf8#JaZ? zAfG>T^pOvB^?u;F<|#(5Z@blHgb&`HynW)%#INp7U}FNW&Z&1gwfh}CD_8SZi^p#A zw}w9NTz%@yok{8LeHgkvb8F}A6RSg;$^Dn^2Rc{w=l2&~ z*T-*8uWo<&(}mB{e@d^u9=#L5>s>#37!*5QkNW=qiWE;@+xGoYCUUL!`=g9)nTHe9 zKEES#tbBk*@Dm+b?|%Aj_b#J**WbDab1hjWJ8_>66^@e3_}QD`)zMSN$WK?#{Bret z%s4Z(8dp};%SL>$3@H~iMniL$K21L)*cti}Lc&>ky`!%fG&*)F$_^}^kpbC4G*Pt;Sm1hCdTWNPE;2CHEGJlh|A%&bClbMNY@EVdzsv0;F zr5f`8OirpPg}g3l)*)8N;W6Enb0|%v68K|eOuudU-p+`befdm>>MIAFvGu(PD+`GKJbZ zOP!&Io?gmZY6&bQa)~kn@VeW#G@2VNa{xmM+#XsvpFdxA5{ySK?@~{$rz{Z6g?xds kn?NE8v~LhO8z(75dA<}UDJt;r0E4_c)>?MH>9&^hAMaZvaR2}S literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_269764.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_269764.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7ab680764bb941df8648644815c6634ad8b3b7 GIT binary patch literal 6232 zcmeGgTWs6b_3|N!qC~yr_p@yh%Z-!8c{FLgE=ZF!&hl8|cCFo{gd$S599eQnx`|zK zP=NxK)&iE-4=X@{su+f*OMvFHpB*q@1@Qm_@@UH4tZFt01svWAm4hL0-1#6v6cD3>mp>{U= zXx4A7+!(*x?z#uMz%Hg{m(yAw>|!lg>wb(aq}ttVi`LV`A86QPL28@YWBEqwZQ>1C z9E}lpd+|Qjt?g6W)jn*Q)~<$YQ1~%AdLGiz4?6nQF7y4knx_6G429KRt?m6X4PT(@ z!gSaxx8(;lsB#vp4K%F@;pgF|QtE)YzhLE(3+%6bp9ES&5&Iblqj&MinYJXC=#fM! z9w7{ype)Xd5w}4qNrRo2($i9Oj!y^%olHjs!yc8QscF%$#Q{|e+ma}yWy4`&dq)jh zl}ajU!!Z}VnwVRdGbr)uykV1-v?N9b4dw_xA-<6mQ-)&OC_952AwMYm!aMB7f;yGc(_skqviED!rh9L^C$?047K{ zJW2Re7|z& xXku^EHG;EV_^VlgCF-_f3q$7NnH+?SH+*kwM&%ZVk?aL5aDJobtc ze8eZy(}|ctMPo68sj`8$8B-A69`@nQ8V2uT*p(_fmlCh=4QXdODe!VEniLHuD6Z+O zDuEZ02f^Bb$F=|bcK4OEbXlH{#>Cw-(b(mvGLue?M5o1+vRjs7yA^S6em7vd(j1vc3hfSRP4wb&TG2osVxV^ZL3ho>A&wr9^djSxmWTp7hcuF+t#@4x4DSUMQ%Rv+3=^s zYux@U336OpsbffY4}tPhXGHfzu=ZilUc*5Ba8)0R0wMEt`>!9=JNK^n_hs#y9bF%> z*Vy8r>+!XY{aNl#sBQK5%JIUXqNaE6TMO;aI?En^ems}XQsqErVO$Sv%Q`k|TnCD$ z_3k}up;6#sT-n9KXzq79(+jMOH-AWHda{mkXfWrj>_OqDD+iD>kUym}17$Xlzo4^& zV5#$o4c3<*)7kzl8|fSXLv7rB7TFx{o_PC2b|QZ;_j+NX_{_D7>-5uQn!N`*izFOg zr7-Srz>?abGwhP6EJ&#sgt~ged}ab*z_+l9RALqfs)t(R;(r9aaoyWBn`%?-aZ+pIj;j&6ImV_SSYC-r z^RmHW&*sHNS!nFFzOkIYp7N&8C=0T4^Ps0TxGq9W$7p1*`7{OKn z>IhI?ghwtxmq8T|YVTQ{U70QVKbR}N{K?eEQ@4meUd$fM94ov1`GX}_Pv+<*{bcc| zPDjd&ugq{|rX5Zj0L$*2JAdvy@Am-PZ!>*5(^q(7jTr>_@4#HkY>v(Ec@LIp2?n~$ z;f~c;S6(edKX|Pi2$qAPe{T!g?3E#8b4mN*rDE{Z8>9@?Gx#s~?xBJqG)QPQ0@7Gg zpfyUcX>?NpQcEuC(f|l{I9Jtl6UCcTSiHRtHNkO{xlub`O{EgLkk)Joq}f&5ER=x{H|t*t6!K3abOClkb7)MX*jE|BWo9#s z;I55aZS>SeFCMKJOCXS_QtYZ@3@Go9K*?#JKoQ>==^)_Jv$#71-QB~u!~bW@d)lC5 z3qmwPN)DJY?1K1aA|^&`=B*0_AFkD)cJtQ&T!la}0>w@>+jCZ{8z_w;wBb@jS>a<* zsL%E^Bcsj8-ezQj%ge$c*3j zq0YXKxNBT7dcF7aPCatui>WV8uTQ>yds5IRh0>h7-l1$lR`S-$Tg8*Nn9oPP>?l2X zetn9+JvFUQO_%21TyMX!Ik^3k?vJ~#&wf7i#p%+>gXF<0o57CN=T@F8ve!@Cnkn`F z>aPob)Bd8BF2>e_!bay4TVZzxdw<~nZ%FanwcekOG0|&%KOJK%&umQ+&j#$7la<3H zf{*CMjjpF|cMa=Z!(Vmn%z3g@c5;*HE{qowde1LzMM`65_0ivy&c9Z=5ZBL7ml9G* zzN#k{D^Mg6$4D5Cljq1SjGZU%BjlVQH`)h^VZD7v){&=v@2)se_ejNu+`-HlGoOrb z(ph|>?A2cahPOKIO2Y4>1t|Ir#)2fA!5P$SLV^D&>6ENM%#dn%{Yzp>PD=&_r`bHS z5NVD~6Qx~p7=$v(nW_5_!cxnzX`c*Ese z9?OkYYyd+!-`2f)VdX-_fiVWL-sPTLPsNEb7Wsk|7lz!(&fenyx_5>^1ow@5iXehp QhbiRUvE~_iyPAdn7q-kkOaK4? literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_335674.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_335674.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3c121a7cf7b4026026492640ba22e768fb49930 GIT binary patch literal 6397 zcmeGgTWlN0aqsv#zQnhv_akXzQGUpgohXi-haJh5Y{@URO{FM4(7dBai6Z5@lWi%- zf(8PFYywC`KS%=s!hA)p4MhE_3lyje6lj3{@Jc5lt|}lP`fuQ*?4U9%xkCU@u73Fr3?;^of#Y& z4b-Zs2UZPy559mpg9HMP#bt&px3x7x+^o>!_9t+XRZd3@r?ysF85WUfUjneo@2ug+ z+C5f8r`T|_qFdZm!%IEJ2q;5{wkofi(ZoH1S>3%u=1aM`Qt>=3tuN zDx25{JI6)^T3GaT%hWM?oV^rbqq1%+%+P0Ko6^vuvFV|!^W#{+nFT?%@&XrTSb7e4 zVizApLyw~TqbMWWYsN-+mcICvp3ql%c-c}J6XlyDY7>>&E0Xm>d4M6xUZG(;*gN2lfP4^GkyU1f#qiyrY1fm= zzkb$tCC1J2^Zo$a7xD*Y{X!@f-Rqx*jqBsNK%cK3eLX7sSTGWGmsw0 z9$qyTT>XjBExR*gOdGS^*{OoPH!-r=;7FXBzN9o_%Xgx5uPsMCXWhD6vhKch?BkO+P8O^KNpjQGmhbG7TzzXj z`Ip9|{xQilmOuOR{2Q~<*;&aoo1`{i;oh`2+qwFJ)O?`eIJoW@kQ@WQC;wpjovGk> zEvd)L?@jk+Csy^a_M~A0)~ZkIR~B;vlDD^D-<{MI8yqV`soA8a=xojorGaV7-IzI^ zKAwF(cS362U2yM78jF-U$*qi}ej!oqz)9Iw4oOs7(ol4_ri`Ud z)tXcIE4VejAm-}s_<;bMLW=|(a;X`YAmt? zU8nkRk$A%V_37L5*QamKuX%#}`t)u2_37L4wW5L1Gum5*o6fTFl-GzVzeXIs>-YWx z<(~35;Z3-QovBY->sX`U9k2Mgd0sYQQ1KPKPXmrFp34|MSN#ot_(j zMFrri792h+7x4?Y29ReFambwF1ccTkqm64AU0hoe+SE%S^deaafl;tFtj62?SYZse zsGF7RRwK1uQh+ZfDQbe+nt-=nLVzV!uJ$*P7}R_#d3Ey3vita7 zX%4?~9PZ=}Lv`;}T;V@S%A*Y_;R56Lk>#vT&#;%n0oJEe!Z}F)aH0l@o1+5YGz78{ z$a0iWPd1j*8pw-$TG=eHygkeoeRi}cqb2WpXn+)1p*DO^?KM{X2QA0_J`*@P6# z0NE%LjzXFWRV6!Rq8zprhyUn6KV&gHhu?FbQ8p>DK{+v_aBx6FlmNLTWNwwD4bU=y z3C}5W$)>8}kOgxrzYr04Su@Kn%EoO_C&QIsVT%Nxwg%=^+6^S7dn&ad-ob`&tdGaAb#&L4f+m-W3-d#zgl|CzX4&OR``Nx24E4TY1@JDF@Qh%9JAr6OevSubMz#o-Zlo!Bp zaFrPUCv23Dak2)kv~qDdd|_mB`R>uvu?TYmDbtk$c+0*_PW7&-6`uVgmt6jdB1QdKw(P(bWyy|h#j0aFabi2M<-gi)6IgY;pt(zl5=F|p ztJsp4EnNf%nFJt(9z>%7!dx1qHc%BkG)EUG&;Y$~Eg>ehE+8N})ICwL(g3-%Zsx4$Baz5eZLv)fe%m8###!C+^67I&Pw^LzpVw_A z^4OVk$Bv&4Hq@AD0-A#F@!#M-gA{OKDUFa7YBd*#J7syi;VCuAYE5@dO?p*Vg;J5) zpwd$eLQ_ywZPlnYs!m)YkW(;aK!ZJTk3hw}in-EOpsV$pYIZ;alR-6^J0-<^;<;Tz)UxK|YHJa*;oq(}3m9eXC)l{WhRCASX zR@oZbH41hL)LqjZJa3RjYUi_5gq}Rss@GbdVjl_iS*m@zRcn=QQMn3T(e_{1(N<%J zyF7%NOl?xzDxQtE*VOgYR4S8Dw;ivTidFGVs#o1K1zO`?)mNr!XSkk@wskr>pk=Gt zVT{-gntY(MV+vDVwLRYS-d8k&8r~@gMQc_MgsSyPrNF1Qitk{~?Nz#2byR6zJ#VVi zNb`DqF#oOpURD{s+N>Ju<*vE1S_m|W{Onl?FZM@xHxE@>Ptm>MXvk0KjG#=8i+)n4 z#{%-G&cu{qQSuWK-dBEyPAU=IIxfYAq`(*-7IZoi3kW(Bf=tnwDN%~ay4m3N?bS^I zDG(hJb*7Y&N=!G81#X1LCdPD1yfLmbvJ#U-oyJ}2Rxu)uiBW}DCjFar_GSK}cr7AE zb>>W1;Lqt!BjZoUhK}DDzlb%Qn^1J8tVm%&6(f9Y zTV7T``(9~37>5JAOu-6Gpp`Dv&^TkCwr2zLt}Tg^51kF^cq*PfKkwX}II-k#rH4{O znc-Z^EU$GQ(wYvB+wo~)$oU>?+2NI_j9Sw(3+w->DL%w^5m2|luHU9@^Ks%G7`%#fC6z(*Ea z+LSUS<@AlzjR(%Hnse(syRBgNL1X*&e{d$o^cdCxl5p6QTJR4CJE;wRL4XX^)Cg1w z#=xqXK`I)lidG4Mz#8I=O4OB-hSaoylrp4dl@^H6W3W8Q0rYbM0`jW*Dy1^2Ni~P4 z(XMK}`aQxi-inOcXrz3YSCs%!L3IF+fGYzb8Uvv-r{R-_0dFKw(ShvV^x4$e>>kawDRI2u^rUa5Zf5tSewsK@Xl$N2 zIDIfHO}_{soAYF@y>mEEHzyAlI7j+G>Oj_;V`qEv&HeYe{yg2E=lTnrD}69^Fe7Kl zJlB?EayRC>e&6%kp5N`dcR<^F`qOLL-XG?Fd_~)HCBK=!&+&Ph&vSf%Ykk0N(YP(S z;~$;+@YEbN=gvRVcc1Ia(|viaZxN#W5As}l;@A?sEq6?#{RP%lVC@Cg2f?k-vHgR{ z?Z_{qbHx3QJxMn0OgXc}0^2%w4eF$ADO=|JyUuSw`W~|U;O+c!o0~BeyO7P97?AqG<3aeU3BYM7M*bakJKU3rIK-4V0aq$gpyQOl z#OWGnjaSmTB$yy37*{nFg=();7}%^!g843cryNykwBnVzE>&Vph_1P`V5qq+85+w` z@|$Q3QQu0RbzQf>8WvLji9R?bEblWlR(|izs$0L6H%-O~_`OfTi}fBazLif2j3mBa zr+8uh4}A(T5BpTGzE8)QFp8U0W&|Dz)~5=Ff`_7=2wjGmh?@b2RUea7R^SXI#tOD_ zZZGGKa_+==CB_0GV1<$du;(F=yw?kl5+*TL z0Du8QJ>4XTZ-#@SpD{22T-Na=8gAJQWddJA@g)+V07l`urF8v-3#^~kIYpEeJ{XWi zeorm4x0cyg%k)VD*p5^9XMlXHRKjS%&*}_34ZuCw3~vWQI0}yo@Pju6{FWd@CG0E7 z0Y4~*>!>VYoZ+|VRs*saH)!c4EVkp@b@^fqj|QxPWdQ#e-N4PbVlzBjNTNIuQDmJO z6(@Des-zCoX*Q(w@Xj01=`Hx=PhnaC@4(%!F)^^{bR|wKdc27Ypr2IkmLdOvN>Z@aKC z$Uhhy(gug}<8Lnbt}k_N|ETT5w%L(Sy6&CH_pTLpUSBFfs+@K9#N2Rx^NW9+_*483 z@%$UXg+^hqW$Uun)@)tb^8Ys^cmDRK?~XBn+a2E>V=A6mzC=9lHYEm%M@R&V=+(v6 zXCAb6Ypva%wf3YONh*11iEYcC%!Re~XXpI+L+3E|z3`L#<&bt^C?A&c@(nFKSp-N! z948@rg*;C#W9|aEf{^7Rx#-)H^J=~wNppsJ$5ynUw%(!(*%}jP4Lk{tfUn?9W|z+6 zjb$pm@#xM{zAFMxeG~ACqq7xC!f;BbM#2g_e8r-&0--}H-9St~5~FfV(ka-~#@3ba zzR|hT7UYLw5#eKG^jf-Fu!qKC!bC)TLBcPV_~lam75s`6K@guK+h36Vuc#fK6iD{% z)?2Mb@+E>W6}KRlJFTYFVyD4w#m^!=(>-5o!|Y=Uk+yQh3nXf4UgowF&Vr*c9Z5xs z6wr%sw>XqKRAhh#)lF?Pm!~fm&6s16)tPQjwHGazvm#gHW8C`j6#^0VFKmMZ(YSns PLe3rYj;>prF^+!$p)#oz literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_369704.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_369704.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75428ef8e9fa60ae8505784557c0e45e52eea3b5 GIT binary patch literal 6123 zcmeGgTWlN0agWEx@%@x2k&@>KXl;`s=j;lO{7VBgV!USTyZw zA;q{wyJ?GdYKwL=YX(itQ5$1<(=chNuhaS_nzZT=YprW!TOdOAyI7Ueiuv%#GcUeG zpF4V1u~pF{r_LNXdNSf@Qq=@F3E#ae)Hx)8A**Ftr(q4q5Z7z{lH(Dax;m$`iBsQ@ zrC}lx9g9HJ`CXz*wBZ&;w+LGXIA|id8B+2HSX;F)dX3+TByZCUl*ZA3qCG|kb+Xkl z_{ea*wpz*8B&*fIAVhnew@q+rVi`hoYwbtSYcgouE)t?s)Grzk!mAP6MME8L6O9@! z`I}~JWgow`QmI4Kfkcxw5+qU@E_FUaN@rb4K=ew%ChisoR}6{)Z8eft^fYmMv=uiX zaC`AiX)|D*9x(!RcwhtwlRRRx?Z4(cV#ha}skR%E@13F^i9XT881e2bT0w&@kp%mi zHJgfUqD8}{?xuNsjkRl-*e!NV`!%Wo#qQ~j2Bl%s0S#+F&3c;M3ABJhrqdign0N85 zi)p+H8J~-W2*tn%3n@1C^S}RQ^`C!OKNxZ-Iw7H0Qd}~|g=gtFqv#XKFryfv00hOb z$Z|$Q|GXlGxLIChtY9{rDuUBIrSiU@j;SN z+|A5%f@i1xv+1H@u5p35I?5;((C7RQ$IsKpk8s$Xarh=F=EyupUknS8DTTu56bt4s z@$4lUYi_%kNJeJpIhv0zvd}y~OXH=R!4PyJ8H-00G8~C0#$+_g(~Lq@*}<&Fl)%x# zWia<;FgM1l?8X|Tu!I!HURI18%g-kS#e}(%^8()kYHjTpQ0Hyl4RHsIyDu4Dz!jE2MsnCl5;d}r+4pt z_L*IKphtO(j$UEt8J3%46Lj-fnioL&2)7$_!w#1J6pmXO-QSMf-kg2Ko(~sY+tSCs zaJrYJjFdlHbUv0I+im9>Kqi*30-)1N(=J@}F56vXbn0eA7hQpQwpAQt>k1rX^-mWF;3txNg z#L9{M)wTAbZ`YE!B1u%4mMyYFrKu!+6=TYklNX-Dk#6j5jXdGRrTVE$$nacaO`C@eEn% zLv4eV0pxIJkL8(_)3SX>dK65T9m_AszJA%!zqWV%xIFm0?06o`sW-kg_Qu%KtJ!F7 zW@Toz?}MTDhwf~DPTu}pv3>ua^#|_3CL#`kCB$Ku#9gZQs)!LPtzb8+KL0<4 zJ2zv0x{~D?^s<_a=@W;UqgR=g~^w`pL zerL(und@2UStVrm_VkIew>y8e;OS4lw4vWoqP&}=&R~R#Yj&3?dx`RwC@Ub#wu~)% z_HE}@sQ(VtBU3&3OGT;|20cJomn|7f_UX4lS`+Fid4i?3-jc^#^7{Vqgcm05M+Pgm z2egjBS1)+w>K*tSoKDCh2nj+@iGYAp$LJ-JF-ZER#~*LsLxnNI=~e?h5k;FrO}xtw z6~^=tb-i{Anxr=0$=V3*)AomAi#ULs2Y;rM4qT;%(J-z3zp38= zNFr-EB+-KiU4WgyaYrN3hWFK~#uOA|t3i7WI%?2~q2^-=1S6_FM$vQtdfpj=#LsvR zg8(7T6lg&%J9rH1Tho(YF^)%?fEu$SWzHe^uWUd5!QG?3!BYY0vm zAbo%=LJhSPb2VLn_cx?htOCmmG$d0j{d6-i(oF1bCiZa0upVQ$QX@IMDDF5^Axbeo zat1k#2{Jb(J_o5CRFKi(Opcl3@b4WCn-U-tc$Aqth}*FRAo0?o0yRjeh6_x5 zR+X<<>dHd4SWRGfg`8m*6mz4;407suOf^{OPK2>(Q0dslM$M3sPXU}`~YxdRfTKA1kIdu5u$;DKKViL{ova4jsBab3PTTe_g>lX zcI5W2>|eF4jjc}=9(&>U^M90nD-~Xg-1ai%&h49RwhqhvZU28oR?lAR{_ZFfzSi^I zQAV@O<^=Ja$Cy4*Iiy2yh>n#5kKYLl$bo^+13NR0C30zEg9_%4Ll)MxYdusra7G^a zQDOWig$q%6JXVNv1^%iWU#LKiLLAir7}cHCZDML%cON11IbGSmZM9AI_br*Scn3JrhZPL&%qE%n1bdK+03ZHO z#D5G`g(L{#OJw^Cvi}uzLA0%--U_@CsOWw`JYNYQmuFebh?O9wx{$Xmw{vCZSKUC} zBazNl{~A+&@r!c+o{m0N0z zvT%Bu8Dl9VQ-z|Xna2Gw7PIfMw<-h1oB|RMc4JOPlfDL092@jI zH|Xa#=%-}?G!kMW7T)8=T-9~D-a}(<8p-aeM&cY579SH;PMX5O8r5$0K+_(uFlp{H@d;kp?k+Eooe)Z zOg7mw)Y{ujRc=sgw)-(_>*_}7R?MZ_w17?1TWs1P^v+)pO1IbP(`faS?$g<;971GG zdb7?|;db3oh5fqIhHd%jj^{|GK{9WT0?9%ZZr5ElT$R>YC#{jyJN4G8Ekbp?E%wT5 zt=_7K$lg?Ccj-aBi_ECk>)dE$=lYShmg%70R+jLbnV%gGN~PQMP`dT{=PUv(+OE2x zCs-$?wWhZX=?qxTZO;vs6RWUa_f+AaZ4U~~p|InmLVozi%qQ13t~<%7YmB%8bdWT1t__n@s+obXLBYkdr2NB95g~rq^nuG+7DsyU9 zGd-%N#IY<*15fJ0dURqvs;)<|>8l%?kX7lzH+m+%(W9DTWmIZbgAw)E#5xQs#AhkQ zotQ#&R*^15wCIG%6L3NX^Egtmm@0vpRQaMr7Ro233yD;8N;R2CG-?W9J}HqJk4J%J zc3L9lq2eZW$#g(3^q9`1d|3i6uv#<%E7$~c^?Y0T_s<6}r<5smCK8ngCnC|Qh&GW* z4n)RbLk3kPI;hFhGlPH)#u5>ALc(%1h2`DS%tU<0%$x)S3Cyjfk`n2c_Pn-h=T0D) zo2A%gEKSKuQcg(qLnT!M;k%Wcpc!!&>SypnWzf?;)ErpyTjua1z@0f~o-gb*+WS|01DPX)#g%jAj}+oY*WjIAqifGy!3e#1Z`X=% zI5R><@Hsx;pYOgo`=@K4UVAV&Yzz+H_Z`l>^{}~harU=ei=NEUqA!p+^2p;~eC;<< zmU`&%E}WY`R}`BUj}|;e%YY#c6kED1@CieCcC5Ap%Lqq%c9a|_7%HPq#ZyqQElX{=wxSry$CtP5GrIQ~;=W>cUl}HH_OEttu3#W&{~E*kU0KId7YYQIw&k|v ze^D4+@egL54?$92t}owFIA?V9ulNVDT+tuQbGhvW>h}jfIsEbAo3GwEVr<{P^3sse zGi3OOvP{w6ksr2355FRv>ii<}Sjv zFCm?ogpOwmraGW&0~$?eX=mD@J9KA^PV*X(XBC~d#vz^3Y4CB>BzWX!OD!7MfL@IO zHlQ=M*9Ut2$|@6Kc4{juym^DsoIN^8VOpo3qhVuuO~c0Ynug4CfhG zR&*%u3k3^rWkeMR0=m)V^ z*DQ&Zga6-1j4|I!UZVh;XvlmQc{+=Q54k!mf3#-x|FHJ^yIRv>rAI&76s*|jec*f9 zl(N#yA8q=-wJG>Ka`v(7LpIID5d4C0zuLcP_7R^dgy?w*>_Ue1Gw@ssa}EV@c~ki*A?r! zHrePbj)duQZsx%9rFc{hbJnd3`5Q^9AjxX303;iMECf;yE7UUuD_xNwe}PNft;wn; zMI)*#?W#w1*CTuCk(ZT2M2}%owUCO$kIG@F!o0~r9ti1@6LLcwPePIi)t=o}u85PS z)5;w8>{iGt2hbpIQ5ABjhXvDRC0ddizbczN$+RR0HX!$dEUT<9LCXXPF|7GaS5_K z^Jk7e^`VaLk32U#g~-iKcRP&m!Fyx(PTU{;)q_!NjN;{K^?tkdsI6=1TJBom*d6}v zz!&YyFP^zSCOsG%H^#=7XD;1uyZorP?~~BSp_`L;x86IkJg}bJd-+kIeQAGgf5CNg z-U%-coicX+botCLm(RzH zGvmu~Wm&yq#OF$oa#6!HG>7Ta^b>-ep`RiooTXRWdI~|KZA;d<$oyI?2`Dr`g5*Hv zq!n0#w>SVl_TvmxVUO}QG^V5cEMR)e?T!RI1IU! zOCl4h#*SmwGyq1N_d zd+3R$oAMT$0t<=xM2P`-)$Lmtnjb2007DAg8d^G^J701Vj7Ki-Lf3p(Ng$XD`2r<3 jfkYH&T_bYVPEv^Sd?k)iRN%=026?xvG;Mv?Z7t_N3$Y`p literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_405645.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_405645.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a25e6ce5c3391da66a2ef8f0e9fff426745f43d GIT binary patch literal 6623 zcmeGgTWlN0agTQ-kB{Sq#r8W?MxPIZ6?TCnj}neKfS3fu|4a8HFHKDO}grgH*L~y-lU)2 zq~B7%KVs#qZ&Ra|>WXb|qfx5?@z$ztwmBkbzmHWrrCJW2IDP1s?C_DZs=W*zK6(1^ zkrQEejp`s>)wHr=Yp z_8MN_29FK_Z#(`2s_Wb$H_2P@nvzX&Lj`Jhij1}=WORUx7P&*S3{#QRF^{1JxjotV z?pG87j)r-gMsv${%nAF==(d!-vO|ZHoi($AV|%j5q}++^DmT^JU74J0dU6CplYy`I z=@!r}N1!Z%zV<3;wu*Lo(t)+4Hx$|#tv3QlZ*oPa&3fSsFa3C{@Z zg^(1UP-zTK@L)4jA`%f_qQpfK(-3!S~GIzG%+>#nxXPgTonUUTV(HF>#D8B70ZpGKLGN*X=rB19l+^J*tT>fn2JqF}j?F-$R z?(B)|ft#LB0v`wNTKA+5t^1mDow>2keLZP<&F;@Jiv8K5qaoL(IJQ3^38rgJM~nuh zt7Jk>e>SN&wmmSD%(EpDS-g+z$l=m-dNnty_`3>B_bRhPVRqa+_{q@6Lj`7U+5qF2 zhI~`EV($jw`R1VF2x93yAiaWt_#Q0J|Hz6wzNN;+#@v>b0i}6w!M$(Q-LJU&|8Ve+ zL%$y?xR0ex>%h^o*pnMsF%>+0Y4bf_!_v{kqq&!tXO))c3ckLyrRZ>Hk7p8Tvgm2f z9alWhrpuos)0;Z9=4{HntT?wVU%3@jy80Do|Lw?~Ipx3^#d#(*c#kr@GkA3{J)GT_c{4Zs z!TIawmk)h3_~GE{w!O->y@j@YpHVNvr3ojg5m(sL3D`vn%}}W(ftO~4c(^u%@B={$ zUe1GGTS79LtcEW@sg#;~oCn@BY3 zd~t#^B+W7<8zY7!EtxA~kWAxs%LWK~2)Gsulhsg0=YW8x8j!&(--ljZ`6lD5T5du& z>8p^5r_5iUzA=A&`o{d^Q{>mDZ^*At-;i&XO&rA;Z<=p-%l0LZOb(+UUA{C0?2+#O z7s`FS=l8XX6#of9s@D6*!IjU3glSQ=Vo>vFV$i5knx`kVyH1=4P4k+Ug9|~s+Jct4 zzr3+>kKj9m1##L9qs^vL<8UR4c*m=bTzCmE)mzV_4g%!}d?KlN)RG-Bo`s>zQ1;Tb zlc^&`r*C0CGoS0cmP`#6y-iE|7x(8bymugV;$G*ruv* z(VE-!TjusJQpeZq?(Dw2tu=LcoqA^ZutEikw5v!nMY<9E2EYr>j59|R=m2ys*faL* z*=x>k0XDAE?F!wVyHub%LDn~5wgqd(n%#8`rmPEkT8jS8qQ_hG`o8M&8BL`gWOt^9 zgneL)FnrYl!CTpr{0o9N$SH{=Boj#jQV0V_B}vYhq-p{;y!lTQ&IBh_OHE`HsZAg1 zcnHpXlfD7-;8G*CUdWHnzCk71h>2El<+ga1>7kdfa?A$xWbQn+KsPL;TbLzGzd5>Q%xLyIU44JMlCLdL>FhE zkfFLO08ZW@--3Lu9J5-=2^?g7K}xkryeP5ZkjS&UYLVTw$P2Z|9^n+$<20_Sfv|7} z&>*cEA+3fa(+qhx7mY(g4pk)&=i-6`kHMyeen`|r0fzuVi)z(E3N7yxj$$$c%CVJn z6cTP)i#53Dls1iOt?CS^zQBt!F-cU(i~O8w*`U;kLYg#cy-*|slU{~T{1mn|g-SH? zcvC}bPFHGh&F4>@Tywesa(Pn6AGuI-`v=T*W;wLdcDq>#9=tPp=j7dyH&;hEWrWL5 ziFcc%^~Tnvw-(=89=b)}?)|hW|IE3&qwMPFxH3APpT2yzady3P`$sJwwyaFv?!I#} z-}^YZb9NoVo&LrCW$Vh|t%>}WpZ|I0FUdb8^KXRjdbzdctq=V6ChMb)|Gy&3XRo*Y zc$5iUZ~yTqqg&>|2(jN|N)44>F(CMfj;#f@tp<9OK+oraof$`(OpmP7ExF^%QKj|y zTfzLmX=V4B{JGcjV-e-tcs?rR#Vbm5t^_F;al`=4QNvlo1B{(BJVMAaY*=gTSoSN8 zU1@Wc{EfY2K`p%{7qWX(r?hx7Xcu0^C(2ZQ_E4SW_PH25p3T6+hDz&@fFn4SoQz8F zT$G5568H?Ek`VlgkBbRGJFWsgE2^!0?AUR5)A<-_GnFHKdSEKS&BXWv!X)tF#8>

M#23i^cf|YywXPB7cLG-fCBp!*ztn+To&`B0mpU=F6?y$jI~R9;`3ztWNo25> zxnyE1;Vjx*3j>*f5()5{Kd>~mI94(Oj2}3B3$aYBWX2ectj>kjOl!%4F)K3J-Wj_( g_Rt2{!}A0pm@n<41mS(~3W=Ou1xNQ)o3?=e0bccm=l}o! literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_42419.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_42419.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df82689978aab4b786c874a3375eb3b87601d4fc GIT binary patch literal 6189 zcmeGgOKcm*b@oGY`KS1iL{i^4jm^ZeB1Lwp+I5rI`ni@I*>Mf2wzo8Q`H?A7-lbw& zUKVsx^X9$%i`8mE(ANI&o9JN=LSJG;YfJ^${2G8|#37D|qsj6}OcE+i zPE(UKfp}s{)uT-Of|QHs!bQ=|rMjT~{S!cR9mN0O{cI%_29EnOPKk&@$9Wv#S}qtfn_p~|DH zT0Avcw2jupV3ZsYB2~rfZSHMdb(O{%U3IowLXuOW=X40ofySx@O;VF&#}tm3gTn(b zSevD) zD(sca8eG-puhCZL_(~qhU-ex}jjU7q=UXGlwqWm6b+*BWf{YC@JeiYc3PqMD>I5tOeFjF_W>VmsK<@$B~X>F&!(;i5Pb4DsF3VCZ5n6-_3(f)PHE>K27icZ#2$=?1Jj z91n_7mg7T7j_+k>qOpB5S6N`NursM-g2i;UueYzK9~eqI8@|l37kMGU$JyGcte67j zy+RMzhQpuuD>!)>^kffmI2Y2{bZ&gr-kuqL=x{Dum#n$qs-q(_vSGC^Ok^h(FXi^G zT6bp#i?)W$&_jo3G5kwwX5^vCvalz+XYu&r;oHvN`abuqnffwEHf)ZCNH($<&3jkQ z%bf@1#)GT2Lz$tX+n!0Y6*Z-;MrYYaQ z5?*cIKW}~LZd@K*8qD?Q)3U#B)xCe-sF4ojM`eG{s=IgIxNd7$9LXm0WPai^n{3-R zPZgcs+=%S_-ilFn_RJfK?#`^SeL(D>g;KTXci_@0*O0=ts!gIL zU07WioT}j=HVCg-;bU0!h><6flgbFOYy!wMY0ZdW3iB5k6)Cx+c6D6A(c! z2Zb3?F=0>*DPn+vXC)e(;nmop=*tm1z|=dsimn`!862cp(6Oam#7UPHG=-42pk;R@lCiH?oLh!3LNLb2I;h%6~p;95zh$GVkB)=p_rzuXC zrfbqByye>(r-vJ6%o!PO`0PtVD`$ zqA^UpkUmHww>Q;7>c7z^(VX!UU5%CBd9&)&7xJcFErGuCDR{A7nfXFKC2%SE-8sbv z^MB}5hlx0Bvm=S&ES1q( z8Euu(jz=xV!mG`2We&XPhk)~O7Zj>`c*z}OLLcD50o?6}?$KMg!!@`1x+`>OmJ0?* z0ha)Zp5xz-h4_F@eX13MTGc_-*H8g)kpu-3)Jw z9V#PHz*#`RsF>7rLao;X>hbqf|Dn>!nC;7oZ4t=S-!S(Er09|bGPe$b79xX zwMq8DWJI2f6lUIEYr4GA`O2^TpZQm&?(VraTIhP3-FX=bDxAvXO;A=u3++Gp)9jzq ze@qw7h1Oi$y7#5W9!s<7Nyq=+ko@?~)^Cq7!JBR09%D4mJf0wa;M8Z1mEIx|9HOCh z-^&ktugbnxzwqtN+UBYGi4De|8_CDyme=nD3WrX}z3&!I{6ZN_>`b98YJKhPNAk^DfsM?Oo%Cn8A7EPc#ltrNkP4?0=_GXxqR)| z2z)vD9H}dn?-+AvI?2t(`NIMWa`>~DxC~v1A_(G3WcfR?{sXm?^oV3W^j-Ir$S&eM zQR={7FQlwg>QvE}kjt~YcWLieyD+**A<|M2{29?n*o!vTLOdHUQ2?)-9SeuDhe|qt xL2ILb`Rvl!k^y53GT9eevMrlNz&76_5MlkwGD#4w$8S-{zI)ZS=ek)n{NIA{FKGY( literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_450387.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_450387.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6fe3287f6167dd69c00d6b59dbb770cb2100ba7 GIT binary patch literal 6867 zcmeGgTWlN0aqmSQ-xMFFM9H#c2exQSRxB%39NUQzTk+ej?ZSptB?p>!lqgZ8d`HES zd~BH_K*%HjDfB}$8X!!6qSOYeq94tt3lwO8{_rd#Cax|ZAo_7XS}Jl7pg-;Gkw;M~ zWv6L@<|7N-?##~a?Ck8`?9BYhW-}uwKXCqP|sf^YRc)N?2SDy)=IQis-Bp14z`C!3xzlB|t%)s3XrcuuG%cO?3Nw4hbL#?$N&a;k{+H)dJ{%{E>G9SZ>d{_)#+=k>>MT%HMG$sxXawB z?l|eIMrsmS=`xJn6A6DlJm;-a9^n7L{eqQbZk zOK^$lz&4eAmAfq5hzc>4ITzu%3#vnFxU=!$fvJhhn8St1gzAtIVuTmC>zq^>ZL4;S z!z7Sp9Xhfey}lmh*P+q%bEE6$Mz5>Z>O%3!1ShS-M%H1xh|Q*m|7z6~niRR~!9-|8 zWidD+fbxtWIfplDy&jE+#<&x(o*d_}l1!j7E*c+>gj6aR3aRFJI4p76Se30z zfrB6uZ>okwWtcH0+~i;u=rj_7L(g?>3p*Sz}auXwweG}6h5I8uLL_Egfey;C@=MTKt z4-|D97rx1JV}ckHqFnu4PD%jxUa=2^!$w>B8(hN_deVuS+;g^BTRynx+@3o1$kCKd zW|H}fi;iunlgoBzb~rPf8!5ERb4tfyWy|43`;pX%lBYEvQ9QfnI~32}1(V`Fm^!;; zZ%UnbCBzYy?od8YyT(y2j;%iv1M4RWws^1KObFe+M7DP>}kn& z`AkaY>w>J%+->$DR-)1Q#{>@wYy~BlHaP>cdbx_ZTFInFr;mF*?=4^`6G(G zbHzy60;?2bH-BwL7F%Y=`lvAl$??RB+}^o8vwQN_<_(Lkp0x3i(=}(9HRPrX{fckTqH}MWDY1^+3kutu zHkMdRdNN~8Q_Ei8+?ClY`PT|pmA0-$Z+F^^X=3hl=BEnVo~H1emlSqu+PLIt&6vtP z$h)`Pi!An>G1Hwo_J}sTbMmc|cupkW1-ek!T_v+CH>8+5KrN;nOJ+|_R?ItB7}E3{ zsKsQ>8Zw5ol%2{uo)Avbrf5m3c-tCD_(#!FND>+ z5SbjUk=6j7!1I!fj2co&jcVA0N@-N1O!LIpakz=(Y4kG!Z0j2P8YDBaK{ke|vCi7K z#w&t@=ttUUauiFyq3BN8X#GyBV{{sLKG0M3acJ!0tXr38DBn}0sgoIY5hj!Dnztmg zSO+o*9vAE!bZpunFrd0zJUKQ52h{|l;CTwt zb3VUcv2INblpLPyt<0_b{>)EOCrj?;xkIyu^5X0(VBamC+>LjS6zS&lk&?xpJ(xL| z_ZHasu3~fleM^6l?k`&UOBQGLQ07oh%9BM)Yk?_DEp+~_>o;A$-S@C}K^<>W zKX+l>I_+Tll^uT%^NAOSok5KeTLDEZYkMAD{W?%mTIGDn8eD-_lp4 z`-+ynC3qZOELz%9$Cv4yh2shxD6!5GYb&unFp#D8T^~koM}HApAnv#CPqSG^#*rr; zu>OS`FeYowSaTQOb9@8P_mFJ|+nm3#$aVnhr(l+>Ib+W4e-9*R2)bHJ-j0&XU2=QA z+U;Qs*L`ec-ERv^_6wf5-I={Ob9o1$7 zz#G7tXtA4&!u9bLzMUHiZ@kOisanE1uJ~?mD4KmDWhzY9MKld!I*nu1?tT-FVd`7) zGj9+N(2529Px!&RWBP!plk$6Q*3|l~v}w@Z&F^&zT5Ocf_*Ob4u>bM>J;fLEf2dQi zd03~s`2#x1L{QQoGoz42*jRMW3o=K05gLM%NE*R=t7VsDmbYmB8_Qd(t*zSHtE~gK zx{U=Ng5@g|@MaGK<^68RSQx;saGn=m1YCRxR|lbbd=ytWd#U*vDjk~Sg8@>+F`H`O zg*PK1A;4(9J4E_80*4U2sS4m27Du4q8sPO+QzdqV=r=&CmV_WBxKL0Mxc&7=Z#{CL z9_bU$U^&j>sunDW=Ku|`Dg${12(pckc;F*3$V@;52??Bt;A0}zs)!SI&<}o&Bx1iK zU{cMR3!{bK;xSCN;n2Jql|#mX)qET9hqPrNtkz|QoQEh#lhK5vQe(ojYFeW-aDAFJ zY9s&W4Olb|N>fDl}2_UB{b?S)|wX}a|yKO53=eK^^ zq6A*OH+b*dgKKX*yv8fn_~N+qpgFPZYn!_@d#iAIf&H}mv*zN?OAiLQhl9h);Bax` z%?G}l%N@HuZvCire)QAMd*_PX>&YEAmn$w-!90I*VWhb2l|M}WG5Pys@%7LHH^0=f zW5sK2Hb2?^{}-fi@%Gm5jxfR7?cW_?bj_?>BVKYDQm4yDNd%YZ#FGEHhyFc^f6wRs zu8ciRrLQfst@%@hh|>1_LZEo~g3|lL;-w!Khr-IG;bKHAN>fT?x(qH2F+f7|8hMdi z!Pq782|}jJM7gY*=YVFq*@qeRQD!(8% z9FOv!AZ^u3sKFi{kMomJ;bn0I#^D5~^h>D96hRPQAnTuz?JuYek`pBRj{hxxnS7Zr zl(!?ND=TN@atFqC;Pgt@Y}c1N0b8XIX|0SpOQJ2!E0#{efn$rqnZsoY;3cnrZfJI> z%m9p2L7r?h6D=Dt#v-#L+m>l7n=ocZPWLLFy?UKMgzZb~AVIiSj#9|6d(qzcmPOmc FzW{HttOfu8 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_506478.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_506478.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f66b97682940b04ae03f008b9e00be6dc51992d GIT binary patch literal 6619 zcmeGgTWlN0agWEx@%_@1q+W^jpd?#~<0i7>I*u**q4*KWK~bqr4>a#cQKCrs?qpfc zv0#b-p%Me4&<{~*04qR&$h86M)A_0k6lj3{@FF!Lt}Y-T`cc2A*!}58+u0+Jq8ut} zg96P*7P#G+o!Qyh*}d7B{gc&dLQn=Kemgo|gV0wPQHw$WnMT@S-C%?Wwh^U?-jX542+INV~z^r^?URi_UNbf z=-05ukclzAryVhsR%dz-jhHovHJ9|T%n*L_7FOeuY&>=0^659|EB%AAr2zI_yxiA+ zAy}@a9#~2Eo_qoI0ul&37MC$ip)Ic&;$D%SaO}gWDRDZ>IJM={3b2So$1MO={xQQh zd4y8e7*zsGfI2ea5?!JV%VIRQV2c5UH6+|+^HPv8yXi(uqy{LgPl#0U`V+pPkO)8x5l&!GRZq%zT|2E0`IOcmI{ z+WH&xq`-kyfq)^IxFM6ExNL~A(=?0#Qw9T3z(E9!2~>PaD2y?NBC$Y}z7dG>&w+}~ z`<8$Fq-Q$Ljq{U%Alow<2#yDY(Rl1&APl?H!*jtNft{G_0jwt!4e+Bh!v^CFdzhXa zjU1Ysp@Bfdt_blMjr-{zz3}ppS9*aWH_@SKh8}0R7#pR_$I`q2(hqaJpdY?&_@{6f zljvz1a=I6+^VUpY)zO?h|JdQ$!Om^k9CP8>aC&s9UaIL_wRNrA4oS8{_qhi*@7`Rs z^(N2es=S$qRMo!RB2{&*7$x`7 zXe&riP>|4uhd7>E5U722_H6phhnG{tMDYbJ)d}1ecj8` zlJ8~7_Hv5Mx#~0h^F2#vC0AF<0EV13=6g_eSN;exT5=|LIwF~xKzU>HhRKy4mQ0PH zys-(CH#^fuC6jO4KpGoBd7U+HLC&glcpl97SgU{k+`H#eKTXdpi1Xso=}*spa(?|l zuXLbywf5+rwa1>o<|D2aSJ(k^^|->$EeQU)RPYU=L?jG2t_36#VO07(K{MA;HIp!M>`hM!j9`0{e9}5tFRFS+~QY|-&nXKIlMz#|AKNtfe@U(Ed;)5 z3kS?w&A2*=t0b=Qg|MYut@z0JwdL1D47~kxfScrH69yIk$NRP5#zzB_tl|h|ec{@X zwJ|0#;Wt+LEfkTmu6W_#%Z3JVGSMe%!*DV3Mwqzd@3~h1liigZ?537IY{mlDl=(DEuV5&FSpEFysjxNdEmDP4-&0RT* zGu@jt`;vW|+UJ-0B&|P3IdYUWN7d#iGXQgzSxb8GLwh;sZN77C@z}ESqt{Zp0{ymq zg-|($R;fA|zD_ktRAc7)D%J7~)_P>U^{CW(^x@zattV1;m2r(~dGtC=I%k?SrC>y`#TYDkd<`fcm-=?4RM2UfUymvSy|uDT}Ya_8Jte{Xf`^!W~Cu_p(( zqhOvOe5C>rTfqzc4bBuKltcouiUh%s2@S;cB%@1c%i=kF4xTBD9{h$9!;2_Xo;6g6 zx)}qU$>P~9lUl)JK(?g#qmZ^N52=t7DH4mm>_3qhBEOTo3i*_xA^G3P6Sa)-V_lh+ zKUlMLmcElU^*B@c!KT=Vm7aROlT8Wymhgj3|Hn23pNDPAm_F7fbP<%$i@Gr@o?~}W zpSR$Bqzj>|uoDo$QG#9N5-G;4L~9gdDYn*PYb&;P+^Rkn?tO|WQ1qhVI8bgMgp6Jf zv4G@paajtAEwN@C8$I2^!m*j$j$Lw?6|I3(~JWs?$HCJ)0F(1d|W@y)}34x5@pkWsnZ$$<^K zBYAG4sycab!|nvg;Yyx=>Ol34cdU!nrNDB-!+Obo>e0xfi)+I_Ums?qVJ17luhj{g zwZ4US=HFQwSfL&s{Jbvv{LtD6y*?6_M#9<28*8=Gn=NghdOz_lk3H;obTNByH@Rhc z(_Odl>inxqrsZ=hquHj{{xJ2&#P1W?w}WeLW~2VO?P^P%>1p%-UlEAh8oocu1Qr{= zKgy_<*&Zg2x%A0_{7DUhLv(h-b70-mDS0}-^mNbKQe?Cjs6tJLZnh9?$ClhG&5q@50sjt9b*)C}gtF`DT52BgRa~;eKKU=*gP| ZBCKCqMhL>aeUe1>_ElTQyJkiCe*@FWdN2S0 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_543766.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_543766.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2b4b8b68107256363c0ec73b2e39b5b426632c1 GIT binary patch literal 6224 zcmeGgTWlN0aqo_g@*{R4q=sWVmMhz}l{ArJ=Rot05@m{%?@p1e z+!YKEAY>9i6!IY&2oMHE5vMi~f95j<3N%1~f(N-VadiO!(U1E@#YR8+)!E}6Njr3$ z1O=LpEO5IsJF~O1v%9l1_Z7=B2+AV=rQ*IIo?z%3XD@4?D7teo6cO4!-{lag!iKAz9PYpbB z^7N6TiQF_(3_bWCy zfr=Z`VXr)v9h9)bTCh6MG-q2wUkg$wWx&{9&~nKQ`kN<_Kvz)IaZ1GKU3_zRrv=@f094lPOM;k@bfBHQFa499Eb^FOA&z;tNM#GI~rjK5p zpTPX57i684WHHGL+#K*^E>??!tF^X1>!S=^k@bWg+;D0+LcN1oD`3x1lkX{`OoimEoQ`v(tIo~?3#(iFUI7VOnNvrEu`gLk{I753v=_k0Na&F z#iSXI7vdRS7~$q;k~`6B7E)aESWc2b z`iQs})Wc3x`Wqa|EV{oH1wyOr3R{e=`3JINcYJ|mHLn)Wtoi!0$2Wl1zS3UoD2-{I z!)t+^*%Os;U#U|I4_!N>g?HXOsD(ze<9Ar!@??IpaH+Usjop?#w(bpNU%caw6cQh> zAd>McZ_jTpj2E825&W#{)2>_0-s}q-;m+b<@!Z$pp&Y&Li40?P&MyUn@)<|Ql*e9aNmQx{^eKluM}P`zN$sGtg%~fvr&zW-gx4( zp-+d_*nK$?#<6YXj_sOfI|wg#Mm28~OCJL1CI;e%YVsHq3>(Yszji?D+_M(gn{#Y* zbbZ8JWl9IHCDuCj<=8vnw$)=R$BG9_s@A=CExa%1s(1s1v3w>+Rf3(xF)g?y=iD&4 z_Loj--MiPqBfv$wb61KZ`QK}FFR;?S!aQ%a2&+gTX0exgs5D;vN6;H@y+gGtcEyn( zRi}*Q8bxt-LJRLap^)GyiP>fiOr8hkXhpBj@rd!1^=3SRUXKA3@-g$brZ?tqO>fLk zJw|?OdP9C|dPBZVp<3=zorAaY)D6dV^DQ>sAZP6j67i+KYrIByckW`R8nrb$*fe;` z#h5rR=?n%9e<4LF-G&{i;WuYfj@4IJUMUym0vVe3 zm@R9zUbFS#*7C6gf)O=}LvfA*<^5quH60Tu>OUnO09$dm; zAB&Qr6MA$9FT9bA3sJlA;6kE@b2P}=0w#d74@g2F$*Bc;u3AyJbO=xp)ln zvfa(dNHemh8QCiy$9jz6%7`Oz^eCQyDoX2iNGl=xazft8C)1FSLS^Qmd|LEkPHbA} zhvY>P593iToiXA8BPkOPVKNKKab~Ut;Q?&~3Vd?Tm_}#nIzy@}3erMKmUQZ(a7A}D zDERk|6-*4&qKOWJ+A5@Y4VW*%Cw&Rq3LyoJf}!k*b)P?bd_5e=o?Q0@0P+X3WB2{2 zv+pDJDqD(O>;1e_iyr=B>Wh=NCSSij$!n8*c}}|3A#XrH^5)8$r4u*l&xgP4C_j1j z))aSpYFe9`F3-Plt9@}}aO)@CA9r7y{e1fuC(FYRk_Q(zLLIBmtvpv^t{uNQQ||xu zUl#tV{#h-bkKYRM>zz+*Mm!zN{el0#BBe7|dw)L4#IE-Je3Y>)vpGpT8+2q(R1c8| z4$+J2T~FQa8q&IkzV6zQ_vWbFhp`D#BC&jXK zSxa82LViRXC80S+o*_3ec9y)4kZXcmZyzW{wDxT|XMuXxQ+1*4;i?~bLfKPBJQ-!h z)A&X?YQF+>U#;Dlg5O39knHQU1&KI<)2Z2{4F6FwX-Nj3A(|omOF~-8h&lzg*|@VJ zzA(DGcK5hx_}lR*GN!9tGZ+2X#Exno@&}idyi)DQ*Z>MeR(GuI_;wJmdlVu)H7jXz&@AkLc#xnj3R0us8 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_560861.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_560861.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e54582351d636430eec3d78d126599c0aace5614 GIT binary patch literal 6274 zcmeGgTWlN0aqoB}k4GL!@gY%f$%!q|w(M9*961*9aBRnRB)hho#!{RdXx>qz%!hJ! ziY?_>FhGEiO#o5IhiD*Zn2#cIZ6N;4M?X@aKm+uL2f8tFbpZhp3b(xA8>8!87ksvrJ=!_V>`NokW7Yp^IaQ2s-ZI`A=%LLsO?-dWAcCJ@ZdKI+tS0#G{ub&rB!dZSWu7#kX3G=T>h!ex5u9X$v!jKRYRWPo1u5=j zKpOlW#iKZJ4^J+`W&sSEq`fUtu+9$@o={xZ>KdngyeaLMZ3YWZHDq_TNNIBr6xSRf z8^ch$BXQl_n59d08LTk7*XTct-uN2y4JdAfS~i3J9f}#ZctuJ;u{2-MW0zIG?UDwNI%ny1o2AuAFgf67g+p0^tuZ7>!)W?97PNmPV zO}f8@*Jp4vTj1@-d*0A?KtpMZ2A479SNhYP zA66(hfr=f|VXqv99TcC!8gP0DWcZiuFlM-A_Rgl}2BesjA$_NZ+wGBJnhzM9&33-0 z2^)}M0mE_xT0%j~Nf96MJNORH8Jx@1SS%R{5}KKpmllMeQzPZL#w>`bIWd&rV!TGh zQz2fngv3yCPSDH|KxNImEQl#dv+CHB4d4Kgw(Fma%b3;4XZpC;FWpNpsXkr50rv(TYXy!*~(9J`vwI%@$r1xeFmV z9MxzHMg=e$gCVK5M#qJCDtwVka#Czr&@8D)MB;eODlH~BY$qF7jf280=7324@psu7s4rC z80QwEu{{e*91u9zLOGS>a6k9-{wF5(?*oc9$VD#k+(ki53URJ=EGNkzeO!D7^uz92 z`Uf1^47#@ib$HjJ#7{9A?HWp*TUbi>(^dHJ^6pO1XRf=I@(GMXFB zPvxJx?)kjuYR^sPsmx28zV5@Jnvb^XxiC$63-v(IEn z7{_*&0;8&96oi+$gQ_!#rSAmkbqvJs#Pa<2807Kk%N<;wRJ)%kcT8k0n}MEBm@7>2 zzRcLf_F} zBJ4SXRI3+Ff;+qv5*H+m!JzI2q@YQou-DYxfo7?B=OEqgkm#~BbKMbRC(MByjhcfe zM519j!`X@Z0nxneoX%mO9D`5#15_DQb)v4pwZv+o_+q(p$GZ7*<|=bz-!(QnkvURv zbmS*W_P)&FP3p1YVU-G2Xm^EXD|8oJAplk!IY<8V2d;;KaMQYO(E}ZJiQfvVC~%MxkBip*DD@x#q0amcCXnYPQU`f@HGmAAhqlAPdEUO z>kw%OlhXtQyd+PhO}sfxwFD^jfVeRl0Nw&upca^-NNcc*xAK9)Td&jC>IbbwYMb!e zxLOgaB@D&8^+3@j*01G%BQaw7QS#b__W1IfeuzAU;%y(BTeSSynvLDm_xp!h(}Dxz zpKXeb*zPaHkFqI&ljEOl`aiZQ_&jV=p81$cn`0<#QOxs@Fg{p}zb=sS??UK2>_pm{ zrklS76q>i|fd|by>MdJuo%Pm*Tf@f^2u9Q>7R5RVln=)sIklWcLH9}VAmAeIyn`L+ zofmM0|K#Y;u||a#`B0D)tSO}W+7ZwsaoOwag&`}3)HUUWmq%O6<%vQ@) zAkhg@nq3wonG1&?+uGNPjJG0Bwjxi7uV6iU z1*&?!$tOi8=ESCjen>(j5x?C*o5tt?fSziJhcKB1<+yqV2Jrx`M+ST>p--bR4V@vS z6$NQAE=!u}qOhddniTv)$cP?>YT-jbr?!hB-UMd+_DEmBwn7X+Bab(8V#DRm9NX~u zGt(Qc4uIUA%<+3})IIPCdxb5A*89KcR)dGWoc(h8=FFS7W_WdmFD0a#0eQ2lZ|$Ac zcZw%&&|i#w6(~J+=H@JSYj#eZohvQ8eY5M*=J1ZsdOz)5pZ{X?%jweC{p8_Go8G|M z!PSFBX8qWWXld~IKP~<_{l|3ajqpt`ztO#I%kKy<_lExeiWE;@>Hp~{6S^|+(^1B- z%+?I?tjCf$QGJ0#aEM;s=z08B&rY>x=hr>EbIzi928 zXI?LzkEmznN-?n{T~=dDRmf+EBP29O$|DwIrw{V73tH}E*d?VNb!qt;W;r1d^quteg{?6L=eO` z$niJC{vGvg5Z3p5-tDQ9&l8i?0p#|qC^@A%h_NB$^{?$--Tmz_V0TT3bkw*Gk^O|F z;`FY>bMdMP;0?QbWimHeH3JOkcxUh0`PK7PE5>NVxK{ddeN`LA807X=?HF<(3v-tR d=c9B%7k^{v%zRZVwcd{+z zSTjI?kWB!Q=!a+^Kxv>r|(T{w(Xn_Xk4-a%B;;I4yq95mrj@*9qtFuQQNh?&G zrUjahEO5IsJF~O1v$wM|`xmR#grIzN{dZxT8=)^Sq83XX*d_^tRuO|3B8FybEip?d zI2qT>Y6-*=QALhAwAL~ttBtBsT1bqheSc8PXy4P`u5->B7#)kE_LhNlkkRkbZ`h@u z+NGalji8G$Xkm=+X=hE16`I~dvt|;p=7uhoCBkQZfagCa8(*6`f9!So;^ee!se#AO zo)NC2zcAcKGKY^2Ma5`EzwM|(nBqGti3`m3D z*}{*t+fhWP*m1kATguhKOYL9;lp#c0gSV5>q}+m8rvMcb*VEq+i%ue9hvfqBmdfP0Ay>>yY~iGWvGN7z7z!aZs^|wJvFJ8AEQd zKh^c&KeR~HfHlpk)xesT2HYuHRX8=&GIMuRYZVfQU?%r6{ITth}E*awdaXP}t+E~&L zb%2Rnk@bGgp9rzCE(o21tXpQeBrh8jZ2vx4SEG_bQZ~f>*CO%7xU6BXEyy}vNOG*K z#Z$;8Hpa%;gg^^RzG0bqjlRfUj5LALS&7hzaB z4m`>E$IZ6j20*q+I!jEALhn*>hz-EOPX` zUkHR{3WH%5OlQUr&t9hSA}#Z=WZ)8=p!vu$D;tx+AkPXs&Bz9RF-~J!8NqUNEE$Rf zWQ{)%kg3`T*a?Ghgzu*tB6JdYgl3xfX6XX66?Arv-1ttVGzAypUL@?&(!!*MNk_@|# zUI<5?U09-lK*N3s$pnr2>7N`t@ZxhX07V|8gI5{)63Zpn7~MLS<^_z3DAm_h!h7`|09R zY5y^4b}w9-_Dd)HlG|TmVx{DDiMcMhuV+jfuzXM6Q}C7@Pp>;hB*(}d@)Of7 zQ`xaUqsQwX%a0Xi*7UG=m=G4L&+Btb#epX{)nq z^+f(e;Xv`E)VrtLxi@31Q05GmJDvTdMD+tFWy>9ssJ@J$(%G9eR{N2APjv{{UAdWJ zpX45wY~y#SPmf5?os?`RKl7E|n3vAXOSbv+sSU0E{ZsFr%KW^bDTHoV*F8ST<6C>` zlhIqF>tioVV=tFIul!Z}D%^{3Vm!FQPLS)t6?Scb*4rQ71WKIpd-xMp!NW$OV`pIq zJa`uxiD@ZAN-yd~Ly$~S0+y!~ky6GXxSdJx-9!{_h@GVw`|Lt*L_E9DHL4%ph2Drq zMDhvqx2HGfZ%=Q|uX%#}_VlLw_VlKFt*B#kjOLF1wzFnM<;F48ZX6Ep>I3EO!S;WE zU8PUk>NXSL`>y)A1zt8`P;mvkPXm4|>|bCNpCIdN{@7=1_v5q()7AX15evfpmzT96 zxH@>ip-AT{=DYBCI0-U!&lF7;nV(m+;d=_C%x*L#7A_ zmPtmN(lEM|wk1qymlWvX&?q6Th=Q$wGTzx8g)!WrZr2Z2i`3SGULOHzfpQpw+LjQu zU5Jk*V*R0f8;L>9caqm8Yg9DU{3r56Eo1yh*P`VQ)~szD)ZfXPdYnZ4U{h?wcJCVB z$)?0I{FOi0^nYwq@Ojv#jOin7N*6&8gy^D>GVUzuI}SV->_KP_b^-zxO0cPXBE^`M zP=#VF_10Q%JL;_sx2lgN5R9l%Am}*)ln=)tJ=I@CKKnWDCBQk{$sLC3!BJe{w;|<< zh9qy1@%u>506nsvVXs62tWT%hw~)}`%nLFxM;*Y41>_fywkV;SY^>!ekm&fdvRPnx zfe!c~+j_ng+1HBfZ$)0<&R{)W$CVNp;&_mog33q9I!GTOr!qi($wU&6ctTarH<<*7 zZN=d?8t8|-gy--T?la0JB@(D*S`-ciXov|QIfPWKo`nHgBQTM;G8d#(s^XB>axA|X z6L?v3iCvP7O-h?kqFbT1iznU&Wmx$5&tPvMo}iG^l|Hj+v!_pOcDmDNH*F4p?9TM* zM|RXR@Ehw5Ytg^f|7nlpd+pxry|edc-d>+!q#33Z=kIq5Tb{nvck=HP&p-%1{#keF zsSEdK>Gj!=G#e@{T)FSLx;6aF$Gtau*P@?}-aA_wf1Esgb<5SgdN_Z$Xj(gUH(VNg z^-qg`PW>TOdMj|>#ccNc=%L%vZF)5H|5v0qeWU;Tqm2K?!1qTP)iMufh?kuD^qJ~W z62VV&a?|_tx_3sc#oxYap%eGqkYz*Fr79p9JDHY;y7$<9@5doe^k_lb_$HCRZ`k%20KFP@%IMT|& zfhmtYo7E-3G#q89Pc#>I+v9EAp;KOfO{4J=e8iF9cK$gED>)%n|CSiEr z`>wZ2P7p_l>HxAkSH!GX9aPXE1WyEe^LxJ<#^|;Nk(N3mL-Z5+N{4GDmW@?40B@S@ zD@U?NsycwHdeqgsI+vfT8Zbs7lWnCh+gCMW%mmqS)r=tva&>K60otA-5MlkwGD{Gy RhetKYHd5{web=lk<==jdSjYeX literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_653084.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_653084.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05922d65818b9028bc5aa7d69b5f65c356f13940 GIT binary patch literal 6202 zcmeGgTWlN0aqmUmkw=jtMZGA?Qj^N2BU$k)NNgc-9Lts+KWe*3Al2!C<{l+V6e-`G zV#_`P3=kk>5>z7kA-V_<1}G4xc2R%&(*;r_K!0?QTN76o5D@>kUv%XB^sBQ+9!WV= z+yn)hk1TMzGdr`hv$MCeGy7MM$AzG9xj#(J1QGg*Aa?OofpwljXc=LIsRWv+wA2J; z;q)vs!BR-3rY$-4&`OJGnVl9l_7My-&HH04X5VFRR0m8r1tcKs#+;0%eNCh|HtBb6 z($8z`` z{SCYho1@<9NcUXyz|YtDY)B92ErcJ_S77r2hIueuo2{*3bd#S$_fAqe)$I3} zZn0;mw>O!Z+@RKK_hZ)9)s52am`itP0h^|`*|fvxt-m0Y?rhMf+3G3Xr?Xc$gvgro zR-LQCow}n2`*o)c+w#+0PmxT6WZoVHl7$-Fsk>~rCat?cS~IJ6>+LmLgc^9;?3LGB zyjU#O>(?PwXBH<}BKRF(hN_Xg?bo+;2un4s1xZ;MM zV1tzQy52UVGhjKlJvUfRtigWWQ-g!HJt(w*!j4l4`QaauvlOOG4r>c@vdP4gu}NEQ zSTt!ZVY=p&)T9!bmEzcRM3hK!Qa0H{DuPYtY~*r$c7E36uEY4ghk7|n5QK@+iMl@nm8!)U8pQR9Y zVhYiDMY<5tqEjYMz$qEb<4DC~ssv_I<%<$oD4&!rBvR2C)np>ks40N?q(o{m9tD>9 zS&5j3iksBirUQDR$8;v;OA>H_)uItt!4{aS@4L#s-x$7>QfAb-NK_u4ibQ83+Egkz z6q$q#8CI3(uqMyW4FfhDOGMNu3CqzGmUl~YQ}La13lb0{Ft?UUN~B+UVc(wTU)&26 zbE_1)gryl-Ny-VSajc|jAbq#;Jg7$eh59-CQ5p1j0JR2|J$X+dvg+&09J}xH*Re5z z33;K=UpT+&+mSi)0B~pCS>TH=8J&ZxzM;%f!s5!i3P+1^qi6V5ztOYjwqS%_y|Zi8 zH<}qEBltXD7%cQ&pa0X<&#vAZ9yNwX@A{5pUc2AgzBK>ao+VG_SjiX29DU&NFTMC1 zDNEh=co*Nyy;%}lmyQ)ZM%$1f4wc$^EbtLSd3LU~1uFW}sPR6ryo#9$s zS?aNf{B6q}`HqqpD#TZI>^FM%8{+;_??446at^HZZmnV<=)gL|`dwMaV;2epmv`iM z6y7L~ulk3x&if!~AU{y(D!yrS4X*l!vRug@EO7bfiq!89eR|}RBiCQJb<}w7!0NLj zM&F3xAIUN$e^+5NKm08+(43{VP);cKqn0h@XOX*Q$(b7hYp{;@#@-n#@$PKAFle+4 z82mt~rQHHSSz+5+OIrouXki;HS`fbxIL@&Sn~4*_AS zwl5={nTC!h3#PiDYZDqxXlZBKp*wVEj85|!k!Ka1x5goz(rIvV)HJx{CrdpX*o0mS z0yd#Dw%Z4K}sWFqCgiAmbV7{97Q-%RE1I>KV%RnKY`ZuUD zsO&?cck!j%OG{^RBZa`nUDvvbyRP+QE@oaWiJt7m>+fF4auwDOy|*YBeUI-o_};?BRlfhbG54Y?=UUqJ z0ZiEx^oL5p{*pgX3bcLm^8n{4Z-=#Hjw`!hWl{KQ1>%ItLHrjS5=cX+G(?$c3NvXM zvuOsRKDHr>sK&ar(Ewlv95O4i(@~7lYNTnh)lU@cyvg6F9>xZ#O^y#l%hutIu?-TN z#A}ws>cRhSB*vH@B(GV7O*CY_k35~l!be<#mOoju_J3IW{e7+Ju=1myYzkIv_CWB1 zY)V;q=1(^L-`W&>9y$Bi^%0xq;s}BkZrXkTs+|GIB&#sw)XyPw9{d9YIK1YyS#%z| ztD+m>qUC9Uma;Tq;~g#tm+2v9y8g8a|{SF`VwvIlTwFRAuH_3$96$n%PI zU7Kul9!J7-B{_3o`R#aA4s+J63ke*_svyg1tpX$+fiwhi4=dC&1uI{XAc28P+^xx~ zCPgEvEbVGUb~hq>8j-!qVWP(>X)#;5M}VlAIKbE^ZGJlkVj-ru6bJl z^7%8z9{W&N@5i2No?_(smfKxM_|Tn+J16gs|N7oIHpcPFta`Uod(hFdd^LZyc>ETB zd+5u~m1oZ0osjNLOd1oDD|2t(?YQ)yf8f*5C!y=px3}LpxiYko+<)mopmX^^{y@=n zeeBlM%GOstpZ_BL=k&^3(Ypb>*7dVTL9x^IxbOe3Nb$_IEk7P*BG-C(?q~0H?=ZS|{JncH*OFzj;}7^y;TXw`pT8Ae898n2{>94KU#*;v z8D}R~;>wD8*@!QcA?2b*X=sknXXr-+J4-)CNH|BYb@Ua3M#r|SbBTFhEDI$5NGShJlo%1Je)kmn`MI>ZV&Jf^#H4yDOd0)K|AsVh-EKQf!b^9lK71p^;>qEi0|RhgkE z>MJDv19|?5de$iCd)@DJm+5_!quh&p{zW~fm$wqEk7U||`N6OI0b6GfEmpWLQ>eYO z)ERo@>7~4-mcU{nmnbs;uep7TBe{_>2QZ|-?V;uK`SWEb!Fc5IF81Vl$^yY$$QLNP k2_&LG`#O=ceu_er=WB6-q5_W&GRV7awPpJ|ZfiOJ0q@Ks%K!iX literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_661704.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_661704.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c413dc8bb8a0892b60250d92490e1aee25e8a29 GIT binary patch literal 6809 zcmeGgTWlN0aqmSQd3=cvQKDqovIE<+ttfJwD30ws{D@@5kJxULT6J=zc}Iy7Map+n zY{|!lE&_y10+2#KM56(s_NPj1pw35;e7Zn^2IvpZ5@O=&0s^8R_oJmE2Lbxi&K`Lb zrBZg97HB@Qz}?R5?9R^4?#|B4pX_!kg7Qk!Z-#%@jL;VtF$z}&R!9P&86+Tqh@rts zO$-t`PL5H7G=W57MCW4+t<(Z3(j&IDZ5R`%#_bVWpx>r%R0j;2ElA)5Mnog720Tm~ zw3|0*XE$gkMN7mgSl(s^t+f?e-$sKRiA1g@i*1q!+U{d9PHC3ICr%%EnLl^zjApBV zM^BzUdhA5FsZLB2z!ZEdci=yRlAyv$86gd5wHAmQReGxF5k1LTPj_8UdX<*}DM)Ej z=qU!FDQK#V@+cm~iCYA63bqVjkSFC8sFW{hHpU8ct$%afEUZCatrgc0k*uMO4#7?4 zMs>w0e;uzz$AVB?wZ1J$r@>Yr6tB^K6utg;khistQ=_#Luy)2+yW&<%HMmtV*Wea~ zHDF`CBlUY&72x6PGe)KYkC2H3nJmV9Ad|HQw|xioXig?`k6+ZR>b+z{r5op|98uoczGEV+un) zr9IXB{xXfAhi?kZ&zchh=4pLVDA1|R&>iU1UV~c{M-BEj(xyU*GNGBrLRX?=6Jr`BUK!UIIhl||jmFG1s~8i< z#CVcVP6jt?>|y?#cqt~vHRfbg;7@B#z2Z+KhK^ksKZiM-o=9p=IVnX2k-x~xmEQJh z!x#(#S=OM#Ytf5qQDF@lTRS$kc5LjTW~)w=m`L*S8f(9O88c5mhgncUkoL~ z!y1dhVG)#P2FZE6P}{{=B0S2+c{w^ILi5BJkCkKrmGQB}P&BMjp>SBUCL$4;*ZXR0 zWe98pnS4_-B`e*`aq%(_qd>pmPz*F%X~*2;%^=5?hnIi6+I=}8jmqPpu-H8u3Xg`8 z!-;rTXb3i^Tb9DzNpWnv8?f$3EF=%}f*4K+;$D7yIQsbbBo720HYJ&e^SGUV`stqD zC-wnF+ssEU3;d`k#l;w3KbDu1zo74SE4p(j{JCq-uZJp!Q&Vy?6!Fk7_bYIEaR*0(J9dn(kclRxe>e-(@ zvFK<@_uX^(@{yn0fsfUe+nL>&Kaqd#ru(;nj{*zUed!}hFjwnzYoVtQn{V2k9$50W z7CH+TKKJg*u=niF+-up_O17r_@mag-?NV)BB}a2%lj_*MOcC}Siw44w_8nyta<&!@ zsg9k?X38F1p%{ntYb)aH*{$!N%?)M;-+R2wAq)4wf!y91|Fpkk1Ci&t)Q(=&*1Iv^ z6^eFSzh>7mvLK7|fem?lGrOjD70%C@=H1;H^F5b)#x!lpPtNwL{$2B~-5I9DI`dDd zY)i&mV!6yj)|R1`eEykNr(Z3+H2a#`zGvRom9b))lpn~xrLr9v3XgeKWjAHai{7@Z zrQD5tyUTkK=g6D0UFjqDXwy6W*ZT38XkicNLS?s?tnT~;)!GSav20zmdh?2E-L}k- zmd8LX7F*7gHD%=7mF$(f&aJ9*>pZ*dD+|MzR$$waQvJPR7$6s6XElJVfd>+=vZl{iYm1!U&mSlUQeZQCPe_g#>R3Rt5U2*KW0s z;Uo&==mpq>Cg@f3q^ez@TJ`bb1b8&);e9rs*SxAPo<}aA;iA?W;x!(J{?bdA)*`MpsVfs4>-t3O`ajaH~9F z^bowavIS9L9=?1;*Kz3j5RQ6P%N40HAA1gV%xvSZ$g`VtB z)BPn+%gll40|jaNc`(?VH-G8fLq)nJbEw2Qa{IIU3%*%)ZcnkL_YT)vqAoV@SK`|2tmRMJbwU<~w zn7LBN_77v%W50;sBJOnbWZ0ZD>nsopY~a=<=##T$ZTU0rIllquzsq)jg)Ll~XFGxQ zLohC9&06z4?|}pjL3dlp*I9CVN*?c5JG_jkybIZ!=>h2}P;b zZiz^RiqLg&NrT8t=UBC4--Kg?`d0j`>%;@JVk`d>e(-u&K49vk{9cQboUacL~+KFq!#)7xN3Ka_Yqz8fWei!5&OyEJdPD%R#m!8GH zXW)1L1^mM~NZqT@=JAM!LHaZh z92+u2=0XzXiC9wBs8MlJv#e4Y_&Y5+wUN^UPBVH7KKV1))-)=!$n8lFEIM83{zb1Z zeR9#+1dz*}KK{UkT01_pU$@VO<~DuWss<0=9=v^W;rttS&kO2#p*SWlv?Q1O?K4-W zug(tKVn6Nrtfjc^?7|>_cW_7@94d~#x!}LN)VclFZNF@r8~JqS?UTi>wdBsrOBKIr z){1?e;^yc7F!9IK?^DIs!wVi^v32XR&(>mnu;u?R$n2Txo4z~3gsyjdcZ4xCvwWU- z)@@1;lwTkb9HPF(z+-m@4{JA((Z6NZr}}qf%z5fvTiJrzy2>tO z^Q2Gd{$!Alj^RsYmtMga%T!78XwFLgcnq?A6OhQ!SOb!Pqddug+shV=S&_@Lf=93X bgg}J-OWPnpc$QzFkaNepW9KzaU&FruSm~jw literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_684759.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_684759.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d28209f3f4fd6056a77060fc22134e5fdc5e9d6e GIT binary patch literal 6062 zcmeHLU2NOd6}}`TiljvS+p;A8#BG*D`E%{2-Lkqr%U_c;KT9&Su$PWdMaq^VOAbjl zu}cjqP_)uoz%uq=1t?JMY3Wj+!GK|T%&-kBuon>y8iC6I1NBSxB-iQ7o^~!JN=hI% zT`^#NC;{Gk&OPUTJooS1bNF|j=MX$Q-u>P5u^xoJz>M|qO=6QF5L!VZ5{WdLY}CXg zVe;fGHEAP|M9i2tR@X)?l9Fx4^>7>JM5=px(k9y8vE68Pm~^m+MWR!*OK8U11%=+C z-LXYGvqd{8v0xFKbcyUc_DQZaPVOBv=_HZlY+2-5AY!ij*pxE{d+5~J!>4}3Xc22QVsi}Bxe=FVRBE88FXBZXQm{>-l!b2@vEuX`B{UKuFe^D zMa{}mY|vm12^XZxX(?maPp3rTtl=>$;Z%0&=+(IkSpV$2YPc0uPKlB*3p&|16oc>N zrwvaht4Nn0mWzh3L!Oot0T8^8G))`+M_M{@vhynLp_9l4T~FDY~Jgfuc8 zPh5(t)7i|v_>`1UM-(|RqDr%KBfyO$({W{55T!&`l*WX)>D2zY1px#Ctc#k>2)JE% z_D2VP_{>3&7(+txiYQ!?2^0`&dQ2zM3Ke)1Ud1rZ7WkQeaU-ggVU#SI$Dlt8{=XzWZj@&$<2cFNL zT66pIFW>X^mgaxmU*Z9Tb1m&E?kb%sz3_?u)5yn>JKTZ%;dS1#G*O%=T`uok<#*F(_ttd0M{CPlls0{Os=z~<0$sOE|SR5h|kckjVidjYE&j_K|g2HS0d zRl}wf(DrUpc0X63p?f&`5qFKN9K4=f?R%!c-}Cxc=w-UJPmD753bTv3uU~7!CAeT-Qf2lJApvibj2X&^u;8+U| z7uotK>e*92fSm4BS&kr=17t`M#XsMqcLVbq!AhkZjhJ(r~K&CzTUTxw=VT9!OhHn%+0B3_SZWvhDt2 z6nOXP_?%>(af5ES%oy8kYqWbA#36$?ZrG;4%_}TGw)~d-6mUkMn-h5sBq!lj{tQJP z)xF5&SvpueSUOi6F9$vfUkg`8ul46I=YLXj@rBEUXY)ty*`BN%(QUCBxT!#+D*-wu}N zPs(F3js(1|2k!gEarqZ)WeEO>97K#c0^&qcwB;z#p0jmC1I_rZ)f)(rhNECcU>Ztx zMu=UaHPP`2bEA1=I#4^0Zxb5gS@Vb`Z5=UZm#EZ4Y#z^Vf|#Vf6&^%mk3Ob_)PIAg z*+lk3dxw?Z&t_{g5{OL814=Arcf8v1`g$F<4ukn@m^?_SqJ$C5!a237PhAi7ue`PVR^{YP=GMN?`l?T!zcVS^ot)Aqr>b*r-U(ex0Pa zU($+NeF$^gP@rdJ@ABR+hk@Iq5b0`YogqtS?M(WiC1Vlf|7acqO=ubO)e2A8* zItcrge&x48XFWn+Vx-*Y67a}FAhe2D#1dgNS?q~P zLd7+6;BrElY%#-ooaMDaPYx&YgIj89#Ym zu@u3fv*(6Rp7B?kuLo8Vo=1aP>{Jn4TWm9$GPK>Z#H|uNW`72!ro!o{;shN#8%s#G zTP0mFXB98DY)6xjWM77{vg{e-4hBHQ%22Iepgc3?s?t>JmmnmUWW~Bz%`*H-0E3Mc zJL_07<`!*bRjjrw*NkHIRg$YcU1FAO(?pex&a#Z{{&XJ-T6R{SqizOzkgklOG!#Tq1c335L} zM#~Nv?I6P=wX2?i0H;c7U&fGIYK_%@^o|4q2g;r!&=T_M$9TMrw{V1*F3(ZOP$b|b6dfxr z&2wIxq7lQ2X`YWx^S(JI#436p?~6=xiZ&eeu|SO~x==)X`JiH$^IZweEzBvp0FXpQ zx6JWTK{2Qp&~?RmBzI+A(FtOd=XeUVDkd(>&2bTt5tqE33U!Pp3+AH0fB*DD1M&hVmRubWg?6aTILjk zurSA9<grOgK6n@++jz?^mcI8#vyi3Q#{7+@!a$?nvd)Un&n&pbCh z_f3c5CpPPvGM$;JFY0;{)P}{Kp=Hb7b!&a5RkrSXND_3{Mj6p*=&phu+1#m^Y~AzF zK+<~)Br>`lTaeYRn)=Jkr0nj>)7@+Iewp5X`}k+0H%GzLga+o(^|{7w+0qTdb4^~^ z>c!G~Kza!S@jY0c`>_c*>sA}m4Vm^EgL2d1yrX~3F(5kz{&4({qrV@`J5D9^o50bN z?#WEt(C3|f3B!ZB`qkm|aOOyMQEq-AU)Pr~u3H_ckz_PMt~;ADBeHXE!hmP(OYh6n zXS?JkZ{FFHFsxJN1fM#Ulw_(ISSefTuuOSiyskNEEWCu=`wNGV*_tvWd*dfIY>k;C zvTaZH${nxVH6Yst?gs8H$;aN1ZEwU+KhWwwIDPeWVm#HKd^a=x$(z^S%%1pk-(ZmUSr#%&Us zsgSn;mcY_t+Fhs`Ox=Y>sx$R5-7a*kN;gP4NgvR}C=uJEHi=TDY9KozAi|@WYRgMJ zY$YGtg-(`Z`Y!ZJ=C=!7SD8;dLw{}hw*K1mZT;jkR$?yy3Sf$>R`eSN$`CxlSMZ9X zf)!DYoS|{Ou4(o4^y^t7J(e(T)Hkorrf0LA=}4CQboTn}ozCl#M1TC`x~p+@AU%-b z)2~7#vp7=yIde;VXjA)qc1YHG*D3otMXys0uzvwqu_P_2^B>y21=z4gwaHXl=2D*O zg!%siGq0GErqoLxf|{D3vw7X!x$bnWyXwB#>(c29-N<5#kMjNSxqf&m4RWaBF8&I3 z2o#V+46^$e0lB<})y7Cx7t>bdPk6C+G*&NZN*Ol_RA=#c{dY9haGSbS+KE+CYh~Ve z-)bZ>pslJ4a4xf>sS-=O{@X|lkl#xlWVkz9F|vFYd6Jejeyppq@+WUrocg`IsmFQy zPd>#?tX0>4FP{>)rvJ&OKd?_B=3$?*rjNBTT?obWl5VDab%#}zTKUQf?lSfxGzBXW zGeA~Txl@rS)~x0=6l*EP`F&0i~iY-#~l3@@iAN4}*tcR1^KE@9K&L6|C zpTleGHT=StGwMOEX#ESU&#NgGkb0K89P)EsothIu9g1s0s1Y3{0GCrx0YQaT%omNt zk_zfDuU0XOoFFoOpTIFMRU-$hk(aBHL;M+R$65TUB%%Bmpk7MRL0t=Fp8;xJ7RqAY z3d3sU%SL#do$wBIAOw|@z~gh?Yg9~X%AnSPJU;fJUga(^rScF8UP{d+*wCD+L@`w? zhT@v%goUsuDC8`+q!_m;HBya6m0By`co&pjhDZ1u78UXh3OQZz(G8nDetM(M9Y4Ea za{y#_#z!98QB&I|^ffx`yU}{LN%kJUH+k>u{fT$iCRllb&CLn-8^z6rmeu#t?`22t zPs#cl`HV>_*db58akV)8mf+zaiQ4*IIu(#`vza{dkNid*fzxYe<9YJrB*#6Sj4; zePu8?SReu3aC=s#(o+Q;!1$h_ZY7)y7YrDqkjb{vl58m$F=j$`R{_s1SdiZQ!PM2M fM>JrM-X;)1e{Gp02-m~cNM!5ETf48C)$jQ)Ah~>G literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_720655.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_720655.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce64a82dbb9749cb662235e35416d00e31c08f5f GIT binary patch literal 6403 zcmeGgTWlN0@s4*SkH_~zltjJj*cRo7632Piansb6Y{{1F$}ZYUb$XzACz3KB^4-a@ zlskY10z_;Ah(bR^qXDV_1tM1lqR)J$K+y&$5b!`FBCaYRAo_9pLq*O{KZ?#Cc_%ud z;xs7Gd}M*!o!Ob4ot?d%o!P%Soiqj|nfycilO_!N3?X{aD1hw#8>(fD#aKLvO;lTa z0@rX;ikKjAjKimNZhbVgtcfG19CaNvkR=-TPm(P8Cb?GQny|8F4x92c7;o99-?~pf zwNJl^vrW>h?M-rmHdaT!iA~r|7-u*1aLnOh`y-^rWz}|ewvGU0j9= z2JnC|WHqs*92C7ei8bqb+A+CV^c(%SQXsTU@_F7$Y-Oe~RhlMSd zdGtd+LT5FMt7cZ5o91HrAeBg@C&O;lBqmjQn$N`eNQz0YDw)hgSm@2D=0sY2{)B3Y z@R4+kQ_b@npAl5+Bru?Uvj(WQi;OrI z?oz3<%r)*-l1r=R(FDs}QQca@jAUYicc!l)0as>3)hUR4g5{VL@MLa0iN>Eqg(p!~ z_0)|`3LJCeD?RbA^a!f0HVS4l!1x}>^2j~%=p3rH=nT)?h=|d+N+B@LfyJx>cZ)&m za@mKJYnq#7 zfI+ook}M-cBT49kQ4AtdnHiC}9Z3p87pT(x_2pkb>z~c=H-+g)l*n?FN-}sgn9kT&Lw9dQu1_Xhd15cML8!I z#!Bw4?9jH$vlz?8^6`~cCD>PT9ocXlS6s&*@E_c{f2ZVnIeVcTXe%U?z~R*%C2(Zj zrua`~N4A~r#qr#D{#M~=$$2oeE%X$we;Vjppf(+@#S6I$Wrw%yYAzg8T!(iE+}XS33KlvP*P$J=$$7Y9 z!Q9PN)JgzKVeY29thk=pv60S$yM)C_FW`?|7%T{uP*f>3reT%B<3MycFnPux?+fYb^{b zzJse)#n-=J1smpUl_OZwk;-w* z8S<^DLQX6SKG`Vfme8V&3j7_*n1oLmq&?s|aL8uaB3oHprho^H_)?L~l0~8<>ssx8 zlyS+lKi^tSf^3qkvJH)$G|9ARXeQA}U5G?L3ixfjkxyogh-M$U5gYA8Cv;!E51lmV zrl-u`n7%iEWBT6w#8c!qrtisbOy84FN;cNa5)UkEzUsE&m_%H|!q&??at+nXNxPxY zzC1!(7$)o8Y#O}fY=oZ{R2qSrYZ1aE+%ECRG^dic6WpxkX;e$qd57(demza3YCX${ z88MPn@gWr-fI9=+b^+9Qm=g-QNIg-k>@as7C5T5 z>MafTHC+96krPBF8iBO#xq9S8J@R}#a+1G<^cX>v7MG%ElplpEOsQr_Qz479LJrG9 zw#y?sY8fq?<~8+^xgZgSJW1fubsx5=w3hi*vpWiYEvAHrE#h`_`|=x0Z>(Hg zr#?LPaZB;Q)rS+z#zah+h!v-AKWv`e?m6^++k0)RQy=zzG+I3NB)MmH+uySM%F-(< z^y<)hyx8@lKh69_{-a!cJ^IkkZngekr^(SmKkokjE3z_nuj9L;OypkYcSjlBGCSk= zX`dx~vGSS;a@p*KtA!`hBj zJeb3uy{rY2VJCkT9U)8g!9jIb+v7=idzyhK1(nhv9-SDKm`aH75R*v@A~*}amInNk zOA8quJ#C5Faplp0QSH^U$Hd@0#y}^oWyoQAe3N|Tp)afuLA`LB#=1yQ8pc?4W%?k=k{^)wS33;o^a%g4CRW3tVQU3B z+E>tZub@+}psP4N=%i=$oPLmI4IgL*45ir`KR^RU72=E~y-X8?&-@gt( zW2`R%66i@FmiVoLRWM-+tD1y`0~j86t++Uro=q!^}X?qZYxx`x)FK$EFeuAw?KZcUTrLCZ|OP~dEQUi0a2nJl^S&7vG zAyi{Xa0!|sRj3lQMOdtX78Ax^tBa{PrXtTef-O~B)~mwC6l#YZC1aF;GW$47#1c@n zb+H!RWn)xY)0z^eM^Ke!OO?loQY5PHp)~dbDp~r<|4ADH^ee$WQ$TUr=vs1uaBL_V$O`m zWg2>iSdJM5p6KPEY<>ZkP35R;FUOd0jDrrgmuRwnl@z_a1_swamdCD{$0tnI2gZ5k zvOgXO$`l5J9GFKJJNgQu|kaKrf~7XtQgQzI~M1e_*I!4^~XkJ{Z)S$=8K0=0nwl|76fYJg_^=+jPcx6 zS<7>=@o-$$VXo+SJhl-OZ+Kn#``c~TqWnl~%pc&|g8smWKOT%mTKz+?Ky5KT&=%)L z$JzjE8w~qnL5Af5QI>0G#)6?Y#wHkOU|?b5(FlX-%-;4LyE}G4i@bpuyv8yk93SDr zOnF}>76<9={60_*-zYIXY}5pLz6seJbLLrd%Ad71B)Y${*k;6OF?BX;*^oH7Oj~D$ zriYTjbgkstmZh5$CzhO^R7i3*-Q6fTo9FeCV_)L*5^YPI_{!!=j{nS?G=mU>X=dy6 z*5v8ro1fYrxbL~27}7#b6$<~mOPED zH7TEDY5e5!ozU&jJTbp7yXBB%IkcdYEGMSOubi&AuGy~Cu3LL=?ER$uPRH$ztfwXG zY@O0$z22FAC;9eInJLwh$x}9Hrgob|Z^NTD<6#`n4an}CtDdb+xwE#-Q`#kmcZ$NI zduDr5*YCQr&h1nBoW1sD=dG?AUFnXyA*pe1*4rW3J07)vzUQ+&3tJxVlRDqZ9=IUc zFHGrjwC(4eie4S*kmPBD{t~@?zWtHvbM0r^hgTo#rNd{l9p|K-=Op^vN9ueVs%p-+ zBifcckz!~2By(e;3#^~)PU|J7PojNyPtH$BEk`8!$ikk--O|wu5`6*IhSYv|^8J%j z?<5E3MrKFSO?O&vw=QnlFKybNbshKvdGHzR2*lUo4|bNk2YRrkR0Kq(S| z27I3?G1RZX*Uce;7zV=nauH6ks1qqsFB*UYyT)}TUb;z#wdt7yiEz)C&k-cC5+w1ru1gjEhm{|pB#qdhs8_E!)jSZ5q;gjQcdT;lBY21+^ zcVvt^ayECOC+BQP^yVDiR4`+2N}O71-udA0y~7W8WSb8y*uHG;O#Wz&nPt*vpHwwH zCK7$Q?!Mo9f9L(8ULv~^eaXJe<~wxmtVWMlf2L~hSf)*NNdQPnxhn4?@d%B_4oIjTNKIamMf-dxqDoYkGP zJ8}-^-x}>2ZGHcS zuLO#Q>itbs?akSb7z=;JWBz zA$TdaXa(ItX!*Dmo(i-O&{+HUHv#7l!RIM$hPY;n0&W)A) zm+!*O=2E0o7DBP1QDq3(P|_Ek6$*K4OjeI@6S97-1&6YwP^1uh)F`c$0+0wWzX!kA zV^~rMKqzE)Bzl%C*2KvrXH}wi$zlV@YEN`Ox1!qmTjm?)wEu41!&=GLxiGNM`=tNf z#eP=mXEUR*CpGcqYVX{8v+t#Q=BbCRk83iG=bj8OivvT_z))uF>XYhg%Nw`c@!a;@ z9e%iVp*Pd|V)Mpp%Z{441G5Lx21tW~nGJ`3GyYrg*J9>U;E97>s@=3wWvVeeZ}|T! zl0JK*?z^Lm|3>|HN138!R{Dwkc5R|3e?)~~6P;LcZ(ek7lib_>=x&*&r_@vZ%akY8 z4QaD?`@Ao6@Ql>{!_2uKXD$v(=Y}#NJ`=kxg(mWF0TahnP<5%!s#Y*|PW2oi{dv_= zbwj#Js&1OnCDk99@_OWH&0CSlk?2!mNuQeUhmHL-0UrhOnF|M0TSzfvOQC)v43Bc- z@U$XRMTo~oNLCMr;_!eKjl|;MKzJnwlC^Jfkyw<+FM@I95cBv1$;QHAWQL+)_8wBk zFPuE;;AoT`4|8ww6TpXG6=RR!lUEZ2@ij9237P+lyzn5RqCRxL@6N0C5r^__WVO!- z(?Z^ZF)wmd&9%(7d{YP5Gc{6~3S6BkQjTkZ^hbwQ!5RGikXISF1%LF3K-H-BCpg%lNjfAUP2#9{%ew2=#pMJERJ@QD} zq2eS6(0pWp+nw2&ot>S%ot@b~GmI5Mxgq^#a?Ov>XBg3oMgioJhd^iBaOs>sqIj*wlqN^>Q> zRlHcUr;3bZw@l3&5t;`|c?O^ClbyJSGt9$w0}S$#evVAmNfuoMN0s^Ot7fV3d&sT{ zLeht!T0hEdV_D+{$*Hr#O8k2NarDL`(4i6ZHDY~bk8INQmK$Yr8TQJw4krUudevGz zAqT40JypGrB_+OE@3IA&Krc&4FRMO3=w;L4Wb-q0G04r|kuEJ&GkJ7dR$94UZqcop zY^~z;Rd{p=cw6xwP+sQ_xk286*OY9K>q=1FGi0n?%gOpIs)1j-;>Dg(O{R=^ig(%zZB!+s&P2Vv1e7M*04i~i4#|6MzDmlvy$o%B_Yc3>;mH{{33dm4t9noC$^bp2<-7Vn~`y#JfWiuzfwE5blxq>6sqDdLprqILUH+IKlA; z*qO=b^D}cS5LnnYDG_IJKihlo$NTpm1d6(Yja=c_i@Xr$V{G+UR+K>c0ihT4Gs7qT z3x3oTdfbWJ-eqQq$%gW-w$$nSPWNImoy?xgJ9nf8wixH)XnHhrDZ4MvbfivV{Ppzf znP03}R?jGHhZX*p{y`+nyc&)Yr8@Z`%D?MzKG);i${DD2_c3NrLIv)DfeB z=_r_x)1OHyj@=KL>BL3J94-*oqnDjQ~VuyrgNRyt1x?SAN_dnqrp7WyI_EE zOkJ*_Q?Yk~@LXe1aRjmSE|6ZrKztXL=YMQP9^Z2PQhj#EYQNIhn|B{tclRmozTX}F z{orp0^X^j%rY+#;TI$MC{?+L`H9=m7Nln$niSe&pX< zcmdfQ8FRWjb!^kwkbO~c?q0ciC#ZDvDbBvTk$ZE>kyjPxtEqwelHDFtEP6SJ~a0Z#ncA>P5IT;Urr0!6YrfoPX;e}bIWMjmTq$RAIR%9Bl z0m2tTHh_yqQ*cE*T`GFQ4Bv)cUd1-!jrtn4p_|Hd!!zcuO|Q&fn_iipe1`nm^oso2 z^oo45Y`Sf}RSJ&?B$Fd3NEa_Z0S8tO{tM;7e{iK;haFGWjsxWseBwW$N}+-sF`mW2 z^kC-F_2JZs4X1B$K0Tjpzn)ADYY>U>kHM9RZOj#53G;R3Xzu4t5nhJZ6-I*E`4#9fE@RbWhQpIEa4>$&p;1NlP z?2`n<^9GJelAJL~RYg&F2cIgO367bTC&(yL9m`ickJ1> zqHOx9sR7j_mX7|{kr*MrmAqP+3)YbQCh}y8v%F`l(((sumN!&$*Wc8dCLF#0U{h?w zT2C+E%BBR)=YO#2cWhJedDy0$^*t(Sj3NkZj8l-+K3(*uC3w-;i_kdiMA8hQP5ISB zra7Ay+|Zo8)H0>kQEHvI)qN~n$23=@m}GN5P~Po^4B7-&wCjxUQ@|f!FD)Fw(95{O zFKybjty1AxE(8}f&JI))$6t-^dnyFNCdj7{WbF0a*J*=pbOVbcdqt5bJ7?+xd=pCq?m*@Cd9Z>ev^x;k)YGPmQU zKhFLs`G;igjqsY6+icwV&~I;#cZB7R(v)*Y<)1HFXzU$nH&@(c;OV zT{w?Vl&Sazp*oB0^D%hInuS*emDV8vM{p`R6_w!aCJ`4U@EJlW@c$Vf7ZZYZTm^hq zR9o@bu@msX@)6QzDn|5l|8#%?1 zr~3;=fI(ti)AIPzc)^S@8d;r-&FSWX1!GoZvb{5YZTyi9ut(KSp*P;e24~uhWRKW*9PLxd?rAF2IvnDbYtS`0sKQ356@Sm%cdN64P*4ULr`XHi0u!(id4rtD6Wlr|ed_RJ8H z(G9iR5;v^PQMx3L!3v`XjP|4G^{+wSu2d2YXxePnVL76dNWdLM^mOG%&plS5ZmgWW|TjT+Kr3c%sk!+a@8=S3ruDb;r zkYNGCas*mJQQK(&Z}GeM49*yw%hX6Rorn^em6Mhh_^4YWrIhAe5Hd4DY@SVW8kNe# zIL#IlV(A%PvnBwQH0v@iWJJxbV+TexYn@6;8O=T)yONw=oYyS;l?BZzN*RIIC@fra z@+p3vPfM({6y2`TN7(cHTPZ%RStpYmdq(r<6+4lcId)~?JeF`~QPLPu5|SLx&I3>8 z;)CeygQ)l*%4uFx-;~I+7r)ao`<)h1b2WNp7A5Fm#%3SDI03sp0iQU{5nmM8i!mua ztI-&o<-ur942ktQx-OdF8P_bacwD3F z>|lSV&Z^HI*5+-^)x=>&noX*AV$%F&wk5-nNpY+gkEM9c0pc5yX}v2<3zjjNj-|xy zp!&dL%fElK?{Y@CBre3_{Jz;({8CJs&7?_7X3o`4nsJ%Zd_69~BOQe%M`$ z|A0-KMfZ0iUto<{Ws0$icOZN6j>os67*JNsrdG0PgR54 zN)a`2$3O=Z6au zgRK&FLd9D{6T%X{p;gu>MV_| zCo18mbIhG!=i2erg?aLl8ymT)XLPHi18MN@{`$PcPQKYC*D1gJ6|}If1~*F$FE*}wG{t!_LJG$ zJIB^j08ZQ>?O8wk1f#E@8QPk?j$vAqXL45}Y?N zXV##>^S~JG=*=-6(cfaYbB>hB$BfXP-Ws7jy*UD0G3_*L25yg_H|2wd6!5gObaM=j z;w%sdH7D6@UY^8H(z3a2yYrQvP*ZXsavd5|}U*S->qc?kWlX{|bRHdR-+6O18N=K@+8-Nve z-d%XBg3(98a1JB4=>a%oC0~Vp;`=f83X^9Fu?igq`6k#=_4iamJ!`M6zE<*m_(s(q zs0M=n+7YnYghMckIDE|lAxizs{1bKoBL-8@tUUwM}?#oUa6B3RCftJydjwjkEdRNK9CMl)QGKKlT8Y zA0khtILAj;la@bQv$>l3djC*s+Hm0fvrVxP+r3i!D4P;EdH&g^|6`ki&%-w5oF7q& zHHj3PY@LII^5J5<%D|JrZiFttN+@=PZao{wG}oa=A2jD`)J&sxH);>A4IfJ&=uxNG zWIMhRKNx}J)pj05y{Coe02iKz-(BoL@9GeK>Cx|GjfyXFu_!6np+&QC{M*SmAGHcN zg*9EXcr*yz^h5{3w^b zUA=4XtiDq^b(8*lhw%`;q6;pmp2D@e%kX%&-&cw!(U96M;;^( zUfv9Z*AA~9E;-jv+?*}kc^@Iid2*v`pcGQOhI00T z<#(={1NDs5yvP;Ep4OwsC?mWC8ylV>^aNk?)a$(|cr03k6kekZNWk%$W|>P$@K%yZ zixQ-8LL+Q{iBF3eL9@VV*3Yc)Cg$y^pFMU4-ZHKseYpBTqsQkn++vFVr7#D4IM)#W z0Kb}rAc${}>u-qpJL=sa?C*EK+g&4{BgSjnkk`K==jB>I#s*Lzw6^7U;H3X+j^gSrsJ|My_X^*m3N1W?{!hBJ-mFVvN@X1OmEz(LWd()-X(3K*QpYE4KxqGk4 zhxC)84Ro$*6Wh>-m*p8u3?JAC5wkyrT1u`{}(3LZUq z`smn+NVBRIV5Q*OItTSCQeX?KB~F@9qZNqjHM-_`1SeVNY-{4A8?sDDMXKitAa#DP z>Q!C1M%0NQZHg5L)o1n}MQ{8LH12GY)N0*?>QU({3_^IVYMaW`;ZD_3hXbnBgv}Le zU5}7Vf@Ibl1(I!bxKrg!xGpW&B(3>o)Rv~Ut3h*RT6Yt#*Iav}1>SCaW_8&;kKhkJ zz~8HOs6ALiwM*?aHPt#)f33F{@0I3PyS4TY{y-z(_Fu6>PY0~nZj5VF9p*SSpgK)h z>ub`)-_Y2ER7&kL&Y>e7F#8&T_CI26)L$xSlCWOSPchhJGUst%VjM0T&S!2Bkn3!*ncg$ z*7;1vdB@{jwX9fjONAFofv46ydos)&f5+;vm1Fs53nxq6d)EAWGd9?P+*o!lf4IzpUIzr^)`p0?f!g&5~O4l&mEj=ys}GnpIBzL}r==-kzFg(IJge>{GB$7pHC z=vwdozthie!LcFd zKjsgi*Mo_N(3!e?@)7g5rZ?tqO>fLkJwkqKdP9C|dP6>~GAgIqZ&u~?)FcX7 zPu<0S?=F5vjK@U6r+XJ4Lx^tni%IZvm%`ElxKj)o_AiGhcoSyB3!>o&bW7E_hMcXg zO_MOD=3udN`ssNzAp_k4EX(*q@XC_R=L%}o{9vfY!0NL7;pv4B4s$9&*nqX%ukjIb7|B>h2+2;l$f`(seqr zD1<{~HFL8F;#;wZ7-Ec7$j$HH2`NJkc>cbV}yLVt(wKaGaaIJyCTX89|=PO zwx=1{+l)NjjO>%}V=PVJ%7`fOqNG=#3b8r^=_I6ER>(MoSQ1iFs2~%?8K{tyT$mGY zE%ZYgBTM*c57~6ih#`z5OnM2Eo%I+3hZNE;@Hjk!&ehe0JXaFs#e^d3)SS4a+Zq%b zaMpv1YGBc6P+LV8uLCo7Rq~f`q7YrM$m@f^!R<*OU-!4CPp-S$0P=X#MVq>^?u$}3LXAp`iqmdrrx|gC6uOw;=FvTQ`rdgtiHYSc46Ws`}y#foyF~EZ%y;J zr)NsjGsT6sZUruF3_S5k_s8AW;-3$Fak4mkKY8HNhOcw=z{-IFckTGi*<$}oe_H&r z_D8MwM&y=HSnqmlv)$3j-P`v6SEO*}YVS`+nef%VpN=x-&TLK*&wDNDiOL}o!6kZm zJ-FjGB%Q&XUj=t(T^TAfwZV4h#|yDi&r>%;#e=6ydtWV{eZ6=-S~@#Zj7dfLaw)b{ zfy9XzBcVA)o*_3ec9y({kZqD&4{R&6mjZ(sYmWN0qhdqd!w?)heCbn0C<&qatJr2N z)&BvyyV^dPfS*E(kn!uR2}#&ubSfTG;Gan@B>W>vG0auM2BX`n zw#Uz;62hm*n67GQ?7{hzu$T~EltkdeKeOadpsG*=L41Q8UnA!~P!AXx$-WnSCs-k0 zAP!XekjJ~MX4OhR#cG^AF(GoIsgCw literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_802348.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_802348.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c0906fc5ca67e1fade118342bb35d2668b6dcf9 GIT binary patch literal 6976 zcmeGgTWk|o_KwG6kH^p0PMkOooGmRe5K>ZFB|zC00wpv+fvQykbYwi^N8%SVV?yXy za#xjVx0T{(SBjA;;qFiAg0xh-()P0}t+bW)$8OkSO<5`0^22_ahW**EJ$LM}4K<}( zRoZ^^N;CJId(S!d+%tF1cWPX%o-;gR z-1vxb`Vr$go(VCW>23X-xw%W`Z8T@mA>Pu|!@5NH&3Ezg#}w24nd1juW=|bFp;+qR z)Y0Qp2WNtgmU_qlFbU5_0p6=f0yEazyiS8!Jx5$?%x4^rq0}`gyV@xAEm;~QBiV5Y z5RHD`sLix`sR_>0O|S`QoEfL=lxJ5tw`A7HIlV@2Mw#w5-8vi! z*&Zg^*6wgTu2a_4q!pH?q-%U*3~ znf^A~o)(P;0c}72W=+`xkD>QIL_a9^$OBkIxlbO{G|lwL?#AdKwpGS0_h-7_|AQU@ zwfhpB37~MuW|gl?wrG5^N49EkW~fb*UQORtFO%|+`W<>Yom`om_qJ#?XnsJ0TD`V& z&0+*vLVm+B0Uxis_^%6VvAs_@2VP2u?oiQG`7>%dn3dvtgDU>KB1zy1ao5Juh zKF%j3R$B6JQt18cDgJznPbk#UD90XG>}to(B*Oo zBM+kDgD9ss+IV9k&z}3%NaR~1qGD=r!D>w~(gHaFJ30aXQHm*;7T9wEDHu^`3`Tgc zm>EMNf1bsAvz&`1g9~he6{DAU#VDrZEH;Y?jKs#0;b>4H1HqtTNQOcp%PDl79IUUC zE+_`6PGU{)i)@R75`%#lub4nkLk6v~C)1L+3Doj^cl*z4V~a^)K}-dL{8%IqTnI>! zWMVWBhEFpl3c)dnkEg}}8w-ik zm{B&dp+$~e;DrPqW7~OIQ3C1X!n2?p_RHeea1yiV{s`)Fty)*C#lV`wmz}=taJI0+ z_YKI^y}D&(OL3$$Tk-B#bB^6|j#r%He;ohQ#3vJL&b`@}tL{xDf5pA+YM|mCyD?pH zJ)fPa8IawXed%^rPa*y5fr1s}GM43${77M@uRC>PRd3w#UGe_Na z_pBaTIaGYEbhOgHbIm=TGp*aY3J3EG#r>te%g12F99ebt6%VeAm8L7s(VVgB9?qL; zqo{jZZ5)|x1!I0RdjNJ|au%j5^k9{77G^8V5GZQew9YsRlNDwF^fV3DNJLrhGl(+2 zbNH>pxl@IS{2RqnS6;jPTIs+?(;rUX+B{L&Jh3)7`FH*94fxhb@ZcAAc><(eNWhLP z2p+cCu+LX>hXHU>t4Jp1n^P^|4E%2yDr3kPGqkMhL?KhMK{kfSg%LPlx+CaUgp4}V zg@`s9vNH|Hq92e&XH0m7kS@b;M8?eNG8PGItu~p#TZ0^gkagh6iTU>ZHE1|e^UaT# zZ)Oyam~ZAnkC>;L@^z0{zw>-+{m%2P^~uM`?>yg<-+8_zUoV?vvuwF;xYkI82_%!J zkl&al;D|8Ty}hb>Am5@8NTqne7soz0nh5*&ixT*sKrFW814qRR0g2~)=|of-m3UF| ziP3N(Af*L<2cwF41!7m62NrR+ieu4QT*9!{cKc56KkHM&pzojA-}u12r};#X_koWO zOA*nBl~4mgQ;ZtTH2y}9#ccN--j4a%Mr26GxBI5fKFHPzf+}NYwzDud0>L94Psi{! z$5H}6DkLR!%Q$dIKG;}X^fzTgJbg$ZpRXx<3d?J6BM?tgz#%1lsborBqa~_Y6tr>d z#s*T&qizj<4#*`EzzH8~N!8*c55`Zl85^>LU}VRud_qhLnyB4AtmuU(zvv5uU<+C*HGv572Xq9;(no#q(?QFz|l| zW?p9UOkwAHuuMnL*8HggG|li@b*yC?hN75X&XXk;7i~XU9xLrGR>LQ z7p~5gY$_05$xOGj?dpN@XsTYlo? z%{lhgT(~k9E~hTs^eo;Pe(IzC5Bsmqe>(En(emhn)5D8*T)nG%R`!&btA}qy%A0=q zm-JsVf6kQ825-8!^}b)+>$dbV_kI6=MM@_w5B_wN30xlf=_sRF=H4u^+iA!islA{> zaEM-7_inxgS0di6UwXIaZ8YzKMJE6OWv6H&{2$@dl);+#bcg3?MXDpEKT52ZLKU#Ajiz|CfO(gv` z;RODL4D~B4#a{1s#NY}u4VP33tw93*V-!79mJx-Vk4kXamrRHfxD26DJJhQde2*%r zKdpd&8^v7z1KDsg#(j*`MeDyCy*HlZ(lLIYfbU@OU99+9c-2UPAihSHuaNa0Xy7Yk zUMGz2c;E8YbO#7xPi+V}oXc`vu5D6LAHEdZzOwzBVT^8&NM~sf&gga#_NvXb9LvXQ zB*5!t$MW9%-WmlkQ~*8wt7lfu)QlLT5o2E-$Pd&^7-NvbwP6Nm;}U@g>o=A;f^gk? QfkgH#YqpWM%<3Ng3&1G#`2YX_ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_812012.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_812012.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8dd2a3f07ddbcfb6b22d5f53331ca64eec677e8 GIT binary patch literal 6904 zcmeGgTWlN0aqr0^k1z2d>SbAy<4Ck6D^?QNj^#YH-2Hpm0B0Dsm8@Kke+1M^XxH zr*47fBMaQ^%+BoG?Cjp`%>2pWupwysJindzYa>EmV#I8G8Q3HVgqDzqL?VVpOFc14 z7&tjajnV{?h)F|^IknV_q(o2Jw~k> z&p3CVaZ<8IY@+pDX4F<$q3vBX%9BXqE4tX{h@ky0R^yazeeuNUmtGal9Xq4jOW@Iy zr;i>x5pJkalLeT9XY(e!mrw#sSZb4`2~~TMxK^eo8y><*RyaL1oOD%|38_eJQ0aLF zp?Mgp%<`&U)rChyavrt}V9+P&6RD&>VKL{5bcMgM<|~{@ulA~YoJdsAdWYaTbFIAM zWS~Y?y<H!expT3DswgA4PBe1N3dKkyLN(1gxDg*RFb0wgNY+mI~aYawcr9 zcd~X5%L+Vredf$$%R_V`K_{#EJo879R7K5KU+==xToz+rNE|k({^A}M+I(DofSAx&zdSV*|c3ALX$1u9y3>OT8%(+ zD9E0Y@v**(|J!kM&OF^0osI+voe>jrGg8p0lZlvan~~$=a%f73iaH&OheVx;K$p;& zc}b2dy2Zer9MIX29GV`Nbf(miiMVc=3SEv)%}(i*ba_T+ltf&XbQ%lSZBk5{lBN?v zVlKE-=Ux=fNf%?%w9cH2io$8#Wpu)c`1rBQGv}~`)3XWPr6lC2C<*5UrNr$h4@|+N z_<4N`IDtL8GB9{|Jb@UJSz+5Ly7Q& z&S7vu0{dA&c>%B0em)itUl67RB|0y`@a&X;?PLY3391?9JVwsjM2`~f98xF<5u%&Sz$hf|nB_j*u+cgeBn$c0wj9jW7YTn(9I zI+;7O>e`taT6elLw>cpC} zA@#}~w?7;C&;fF6_DoN@Cwn6M>~+uYT0UvHWqT_1(jA8@Gm;+3Ud;8bI=WLYult&F zow>0ue0vr+OnW2!M#0{YJ-+PFe0`d|ui$LV?a-W$Z%~Azd(A``($QUHk;|V=YR<keO^Wu|zTvhCvLdVNo*jApOM4dgQ~?xn8h^#ywuJd9q`gtrOI@cCFdm z*+I>=bAutR9bgu#J;SEi1toJiefhSlOLKLta=X5^G7P&3+m7S_eqk>rW9KM0!z&CI zex>;VGJ=qHBM4E+$qK0ohyj_N*Yx5K2*w}T2xvjE*yoENDiZS2ym@c@fApA zR93Y_s0%$6zWQG@^ii3I@)??`>=r(}+HK8@;Y7NIlbC!EXYIF%dTS}!Y>STd`M#|q zSRq0txkP23R1VuvB_Ut~%fTL1!m#Q>ssSSG!=unEVdH(~v*smtA39T^6A$@*eR}o# z_372`QxB0}pI((;pI((us}_+K+3S{T-jaO@q*CWlkSks2G7jAD{tMb&eE;0V$K)=) z>4S9btuzfc&83h$qv$pa8V*?r(r|@Lgy3$b#ORdao^`f#=LN0xepW-G&Xlh=e69V@tenJeil zx&7&%r-lmNrlrG+hja4cbKvTEU-shrNAh&j!jS^+%p6J|%K4YMmEL^Qft&n+JbfU~ zA1Ltd%;EIm?8Wr6xgX!;yYh5bp6{w1{mD&!cb?vz=Xc{#5OVc!uJb0}nx|Xyd~1Pk zxy`q0eETZjvHU^}D-3sNFxH^)yOw8f^1XSwH_!JP{8hr`n|yzs?$7i6YjC3+%=2xj zqwDmp<)a!MEO71u=O}Oi@Tvf3TxnO1Smjz^Fk?^KvuCclz5^Jz&9#FE&s|*QIziS0 zFg{~T+p_zw!YAv3p4NiDv*7U-yuPoyeGFUNgY2%3S^P8gm_gYdTGoBsx$0J3W$ z2`O2UfJBTG=_DmGNxCLDFQ-tI*#L+v_zWY#Q&FTg3B%`hTN5qUxoc&QQzI3h^=-}Y zLZnJX=$f>rPU2@sEc?*!A~8Zel04fs^#CpT$^S$i+)mbynHnvBux7=qk7P~OC?Nb` zQ?O#alE)+2l)#0UAI>R0nEykYg3rS?6>T5WNhXSttjbJ6sbPCFV?-z~?L}w|b|Ptk z7_Cx*QaO<~qBKslmwQLKcb0n>?oA&HF$E`z_4_~K`ypPjrl^0V6`qBbJtr<7YwiSIf7QA&`DxE*C-Z$<$(@(hOVQV| zZDr`jM1JRUf0+Ga^7qO7o8entajm&)!*6f0-Rt=OE3$m%qaEKLWkMgde}9xQEweE~ zJmXv;I7UMEGI@sFz}Q*x9zxc0w|OMLK6fG7jW)YBHLDx>|f%Nr3B+ z%X!L=q-iBC>lFOd#;+^m|3>Fazo0N4kBOfkiLowRgy~)E{c0Crn-n7LC9Yu- zH8ySVdkGg#B?i-jMGD|Gf6LO?;#iRZ7#ECunOHhjv|x-wHdm%C-Bz?>%!b_FO+0&Z afkgI#u+4DAUtl{74D6L@t literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_83138.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_83138.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3258c79ca30c94a9b929f25c7d8d61c01c8e2303 GIT binary patch literal 6185 zcmeGgTWlN0aqq$7@lEj|iPSqzV>7X$NPfk&>%1&WaxFQso4Qh+Txj0$AycG$N5!^$ zESMrd#3q0w^g}d|BFtBn+CW|O$6tjD6lj3{@G3VZt}Y-T`fnyN!fj^HdofmQ(0fe zY|v;!ioZb$tJCYP?2Zu`18ugz*G9 z2w>1U+n#ah)5`5S?@{#TZ-5qf%sd3lW76l5TV%$71G3qGy)vuAy8LYLF_KA;Y|%%7 zWUB!OWSb5f(mEQXHQBO0&gd_jv>42G5j5m@OR6tau?Q64$9rS zrrB26SL^M@E0Fccomu~fzoijy`xb1_)3~O-x;{E2Q(zIhE)Oi?Fkr9jG~ibKKM8ah zg_wzZc)N>ucgA!{3v_=X6%7%piO*b~6GAqX%p_IooS2S@;aM)ht4vr7r(%L?iYC+H zjA~jC#I&TEHEhq2N<%NNnyNhJ+3=;r?EI`s376(nlaxt|La0Y&U*%2<7m`9sHH{~D z?xgC{8tz0oHhgLBH0D1!pHZDsMojPmcTTD@J8JkWj7gtUA41~~qte4Dueuw@CI#u- zfAqvv*CRZtwbqfIKi2@q8(?0<-bKU*Uu7fnA{Q6n?=2WQ7tTcDDwB>zC5~6k()=uk zjbH(jaLIHm5mBjdBmxBC0*6guKbK5L&T}bFN-PMfxrTxH-%*)Nb$}x!T;%HGE$JlB zNs(|;P%W6Efl?3X()-orpKlCYOpE8Gxo|`nh=(KR!RMg}s%?A!og z1JPtyigUaWN%O)GHy2MlKX;h}0tYLTNvAm6&%L;3aLn^h$$h-H4f#Z z3b^Dj2KX>HD?pp0jyYjo1Pb?p~?)i1_)8K9E-rQ@OPWMtQA6t$W z18Z+7eTS6RL+j4NxskH3qmWR1JJzB zskJVpYro>(f3yA8xN>k(@lTfhorSF8-+iAVTKmfb`))dvfnkgsBeoFX93d9T%@*%! z+e%xZXYF9=z=-nPh!Pkn2l`4+?^ObOfi<|VeBj8f9_7I67@5T5Jg107<^h8O9Uoh- zT8jtQqU-JZ79Dqet*ggYjuj3Tvr5O_b>F^4i{7=rIIeUIuKR`-EgR03<*|IaNEW9) zaVpN|7pbx*P#9A@-&wOLp20Qnt6Zp-OJSFKQ~mg>%g zZ5shUW)#e&-L#$HPgapk&KRURz!L#LvL+cdrIH$zk%{f8GIM?x>>u(t`Y9pPS@SlM zC`}TZ!5V2rHpxs>lP{BS0}|jd$(cs|Ykc(>03<&N-G~gf8E@SA+t4ZfZUHNN_5980 z_4%9A>+@5Ok>8wNm*1RTmv5F?-o(?_nQM6SKp~J!okk(^gm&eKI5xY77w8_of$m`+ z8loHgsU-z2^I}+>lT<4PH6J2{D7Z)B;W3i||P}{?4I_6FK}#6Zwh4pknXJ4VPWMr7QU>g~9w!a-(Hi z@Q$rlvGo>P6x)vCr9aySO7uX zJe{R@Q5GPw~e5AKDar9=0iO{fN$*5-7{arWuE^`3*pSDuEpRIfTx_N@UGhwjqmWc^i(6 z;BL0pT1TyQ)>;>Cbsr0NHp^EjjBGv(ln?tMFJ<5!cTb4>0T&P8Y9CbhUd9zpv9-Hh zr6covI7Es#;#V17cqb7NLMAPgRLxov2T5K_4Zt}Ozb%$uxmJy^37YR!O zH`s^_H6nW&k-g$ktidp@jHtCLj#W*NfkIYhhV+zAq#%2R3X)cwuJS1n+f&3(CeRNl zkR;-JKV(s@S_GjbX5#lS*@2U`T80MMYBjXr6SLKT!)oXZd9f%+^T~{)Qs;%ss-;e8 z5_+_2)MoL=4`EUq6HA}LfoO6LnVLZ!nem7&*5r+!>I8&yulN(r$fT~ZR4E8tLwVG^2S z0B-jq%yTvuS=lDI5{s20r`*$hEuzwk z-zwz1TCJ68Smnr1hx$fNcu1))ee*;zq BGoJte literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_870175.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_870175.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..473ba90a57318efeab7938214412c437c99897f0 GIT binary patch literal 6228 zcmeHLTWk~A89sBdJ-$1(6DQ6My4wYZgm7KhF4{%hLLfkasx&EdWIW>=zJ!?xknFfJ z>PpzVDxl_M zqzHt^8a@Mgk3O$2ODusf$!^P#abhflKSkDgU+ zHS)-*Ge?e&g`7?98iA(ZT9x5mL=uF#*5YJ?TI)IDPF5k_-?tFnw z5^S3Exh|Q}X`gIv&<#H`T}{?nEN|H@cj<5HZ~6)?mU_Hh{~$U6^Zhuk`t!P9!baCR z8$Gf|?#6GD@c`8+_h@fr*Z8)Y24vtvrcL%|+#mg!Mx0S@n}H;f>4h1+^|=&8%F>7k zMBl2@TEsJbO|j|FUr8tAKG_rR(pjC5`{I6`(TRAEPU=J>DvbyeXc`4frv<#2pWq!9 zX$+mAhhnL4fKUybG(E}J<<)lsPL-4rs%270M}%OKjd7|WoJa>H)iA>gX;C$4-0oc} z4N^`u28CcM!mEZ_XG#Wd#FA4G><9Id+W76#9b$nEHKd(=u zLswK+lPn?fFl)Ues(ROt(Po;|z;JXjtNtlf%FbhZ8|D%5r=t&GEa~$!Ki*3C!XMgtilA|p<^4Qfr7yh*!bQas}*8JAFvAK8dx_{^Y-2c$BCwq9= z;hK%)BXiMGm(nq?=-8GWsd&5RV~Tgn!e+(0?VefjypGkiSr|watx|9Ex|lV%vVN^ZsdN=LyAjqT=W%cWqM~+g2#TK3Lhl z2lsndDav8Vk@gM7QH)jW(d#RXg-~zDvUKJ0QpOW=>^W+H9xAjf4E>)+&gn7h(_LA4WRa|)j{Ox zm^T%MvL`DR`cmCmu_Er)Vo*d-MiTF=HKXjtMDz6;KR30e8b62z?n=n*H$zv`;dN8gS-tc4ITUz z7C;fnR2&5B8%JOh8|X-^%NR3;4fK?zH^~Os7$!4}grBB0nbEA0aB>qeC8Kz=U-dC- zcw_^;jzTujQ;ke46ECpeTEAw$wSLV$^#c2?^=tN9>(}hlvgs~;r;bRp^Tbd)PXvst zp5T)uK(6J(00sFz8l2>{J+2ySi8NqtwK#+~VX|tdCr_L>+0<0BmR9KqqF{TxpN=jBKOM{M&mOIKI*M-<-kLvK zIFKD%g5AEqZ0*S&S*Bkr9Z~2&g>h9Fdxi1AW&&c?mbcBF{m{7)>3hWVDNNt|)kS7A zO!^kiI%~;W=5~Gvwp)_!?n?XSfAzZzMqw|!LkKRt0nDgvmVd%hg<6Hk0JzK$0Dno2 z&QP2oLpK2n{6*F^04rn-7Ro5x41MtlwytreyUd;XmT0ot3P$4%M9Nf{Zh~Mfz_ezu zzWx3ii(%@e?6rc!+7|yV_GFqfe{5*-@>MkT^-)Wh-!Ym-44YrY6bG@@Uxb&6DS=DM zS26t`#}slNjwxsPn9dkt04-!g97@3Tt?@?$3i53TU4$=@F#)_-E7=*&ssRj!v(ChAx zgwt0mtBf3fEf(Sf1`U0wCaqF}%BihR;4%bC5-1ilWTqXy0a~?6yeP4uAk;@Yo0(nB z%#-J9KK= zRsiB>@X!=o;+OD20YEUw?a7{4a=M`W@U~}9EjinOa=Ejk&s?ah?-ToNdnvfkd%sHw z9C|SE;MBwMUpyM;lyR<{6d!g<%f6oCdxiH(C+;!#hraAAzjp571p8l zt}kzX~&k9(rEA`TDY_v-o!5?UDtcc(mOA?iW*k$^0=>z7%@s;g-67 zu+nbpv^?AN|97Nx_IB_0$C=>mzVDASjmWHw6F+hrvnQ$tNd${%WXb>fBd8Gl17G5HKW!I0(=YC$k7*@_j$}yoV-cVxGRj7lAqa<|4$g|`M z=FXAN5Hg=9mwcN_?TT+p&NN4TV5^!@_YmM;nDN%xKA=F^JDV1j7cM$yU`3avA(}GIDw$?VUfOn2+t!+X! z0zW7|N7{31Xq`EbOmkBS{v81$TKoq}{5{;N6hRPQBirAR{U4|Y3S*MF<-h5#lKY7L z)gI(>&&qkZ+Kah9TpbS<2A}o=w@M+>RvQHvjdW()Z}r~n{Z)UJ>>-R5t84Z^{y>!i zda2!CyjZwcH2{t4ZtrX&pQxHJ#~_PywkO|HHDk_#T%M;^ps?vS&Ccox0ulD7wh4mp StQ@3}bIYP*>rJcnR{sH$u`bvE literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_882682.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_882682.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdb78fe3386548768d6f67ea903f6fb5936d2937 GIT binary patch literal 6202 zcmeGgZEO=qcJ|BOwb!v@$N6#y$(0Kj5>i5oC`2u#1OhE^z^N!XdMo2uCr%tY>~0Ey zwd6{b>a^7**8PY%RjMnURC@Gu{NyJm_0&rDC#!IWwZ}=3${+5RY3QH(b#K;Q+X1Is ztCOxjI?~R(H}B1xH*a>{yf=UKcw7j|zW4qxIUPjkYl7ItQvud_3ZZ3$5vCGoyxdab zl!enX%s5LSnVPcX*h9-Lre$_Y+}KAj%rx$gv6y|Ay;&JB?i7%Kup4tSn({S};@G6$ zxk*32Nk1(Mppg(0vG6W8?y9cS^)4EB(@1t#H4^8ku=tRua>^7Azj6At*QIkKXH2mS z9yxjX$jBSfdR3jk%D}fCgL)ZhV90WrqHU##X$t&VixRY0q<19kj-jn9VkfZjy-F?|&_A7GdV)79Bp>P9#CIdty?rBjW5 zkEtemhFbeMQAyh1+$P4OgXg)=6t*^-jIDYKu@EZ;QS1 zTC2C}A+k4B*x0NM4W9BEvgHq`>J(O70pNiOYm8;zo3GBdTshvFWQDn~+uM;!{17PxYv#SQ(X?(_lnBHn{=A3h`MA zaVMq_ol~TX5iL4t@&ufe!90#sET&3eCRM&9k%jU}>0%-komNdI5{;Sym`_TiCgM?G znVXS_d8oKay=^+67kW%*Qobw!7g#MCffa0lxq7~<{QJ$p%PD1AosC50!O2K;I-*Uc zk^_+m*pNY0i4JP=%QbZgdUa>NUFd-WH6|t9SOS z`bILNWCWk%^Zohm8*_iU_Sv<2gCoY^$X(yj%xm|XTbJg3+qLA$954C;nPU$;{-yoD zk+Rf%k9YC>!ug`uymY+aFkQZA z%2JO-0U_eYw7TN8!BD(ZA{+$Z|!0Fwf<76{z1I{`BZ4M{n%Ab9Qzp~0in4R7M3vb?JG-r=aQkd51XK2`%UemBKy`~}a4Ec@eHTjL{HTkU0 zVHPu=J8t^R%co%1&Y`fgd|4>OJ`b&HgLu7%WTV1ty&Ft|8@n7)W>wQgK+89%VHW(? zWMo#hoP_Bp`&W}q;`mHhX!MKKiKGsj)Tqf$z$KsxFkjXCDT9ESfkwXQWuTBx{Toyn zRPrIwyZF+=OG{@KhVy}sJFa&W_FV7ET*|y!6g}BX*?pNK_u1zPM+`Pxy= zya2E$E{IEKKJab=hI&3eaQ(oI=HI=N<;tuddT&uM`X1kH@ZI@Kt9t!WdyD=+G0^hu&jXyJv;)?XIic)z_|JJ79^T^r9u8-I>7e^4Za8vdJQ1uKzCRu?Yr{0Co1@I3L;P9H)X3=@< zwn7^oiX`3pom$bHzg=?6{WeNmI!$A3P0P;fzT+O~y%3i>g7fAIYR1Xi4iaf7a z*R{z;=WrxUmy# zYEm?!%F>>CWT+n5TaUb;93gs)kgA1LBz{zmLKWss4iZ7gmz9WM5y)-S;-a|sR7uxWYT6l#oTW)t4;lp>v@0`3l_Un6N*cihrGwR)T?Lk}D^0nNx!iih_ z?SZe_SDrh2cU-zRK4FYctjxZBx9#$S-o8&mpM-8q-QIEM^2^{6xprI=9{H+-Y>~{Cj8rLQ|H>jy>Q*`Qs!re*RW?W%#r)^oy0VzgoEv zGtN$|#FZ8GiV>eLLCQsq(9j&E&(Mzuc9wpOkZ_J(YwIZljkfJs=MwY2SQ1cZfCR~b z%qc6d1aEN&e(Z-CsKOrQb!bdS`CY*DmfPJ4cn6w;%-`f~NFk@kWTxU8Jcp!`ss@fk zsf7H$l#^;oAtku98477xD#4 lZUTuY(7I0Kte>I~<@rV&r>MZALk#k6Uv1j)j@w$!e*o)7Bmw{c literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_900175.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_900175.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1477a8b1ddc086308393e76fb1e4322077f362cf GIT binary patch literal 6510 zcmeGgTWlN0agTQ-kB{SY@c2pg%lNjfAUP2#9{%ew2=#pMJERJ@QD} zq2eS6(0pWp+nw2&ot>S%ot@b~GmI5M`NQn5C)fN4eTEUeXcRymdI*G;5r;S;hQ^95 zF-B;(VVWGH2*eXpI=4O=TF$^zQ}(Kk63CIY`y&)by+z$DagCWdBafzBHO8B^={IlF zPjAz2sofv3a@Mz~F-v*Hwztrj)qr?wSvT7p5wzdOs-00SM~BWHdzl?MaZa@t!Q;bc zkDnL{yQ@?;0ZhX8@b6GBBMD4eEK>#@sF3B~on+BPa8#MUzG{{lzlZFa zAS8Vls`aDXGL|)NkeoUzti-SPA4ji01RWYdUnAC6_Q)n(Z@E!6mtn6=>u@qqrB|)h z6LO$>-BZ>3SW@Dv^)6eW3G}j*^s?&ngI+crPBuS97lYjV9qH0iHIqlTWu=wtkQ?M3cumO$xvm7&Jw-;#6EfOBMw8s8S%xV|YMaMUo!pwN zf9DGd0Y}}uO{2MGJLZJ_W^`N1UfH3;$@Z$*>vX*#tusPmLI_Q>QBF05 zgiw5fSE*Pc#Hq#zG?HqZ=Y@o*nl&JR%m|inc2-gyq9jB)p1mLz znVAwk4Pz1))W^`t$5HWdlv7<*V`IF?Uiebac7T0mHGe#6&c#lA&-| zrHgE^1y)UfwqV$v%c{MM$3$K=NyS0tIDdr&8LA}_<5)2qih*zt%w7(~R9YL7n3cqx zpl91t%fDakxsnhriZh`w-!mBsUkpi;iFkKt0=BP56v90cKRwd}SWhGt5+_-X4<|VO z06Q}oeRgJ!1p*7(CMDu5?q?4@_uT&89|1+(!A7od>_uLP^D(x1EGtSN{eaL5`kCPq z{{=s43O(vXZtpU)#AHKxS6k}zJ*RsynNDWU<()fH16z!9aWp-exs=_PXF5_RG5%Wm zwam{~EURafw!@15aGvQ)o!syRvZIP`&#GJTb*;`R-h-*3O@}*m>YmG=iM-2zT&sPt zGu@dP${e}n`8e=VV9okM>e!aAG25OU|J2vDKyTXpSw^uxv*D=AwkVE04@iRP*whiD zf$1oikkg+@DvsR`%p~(nfkYPXBRg`qG@X8y9aH=rd8Tun*{d*nZyo)3@T0*z)4O1R zaZFvVp;NJUg792pP;msY^e&KI!a#f%mgj$DMIPUB{Zf5)$7;XQ*qe7BTzB^=?!Mn1 z{oUYi2J`My3#KjL=vwN^j;@;Wo}LBsJzw4O$)%IoLn~L5rf2iMo(0Q>!<{*uPArfc zp2qBH#q-R98IRhtv?p7)(xEg4^Pa8+^9F5O5HhFIvO+fjEA7nmDs%w)eNAagVL$Tk zExdqij*L0oojSJZY{(u`oV!=9-VQ1qeTuX1PUP;Ka^w}o`ATZw9%Xua;M%~#NakSr zjqJ$#uikid<=BS<9}KMT?p1d8=35T_g*pV6CY+>tTwzZqU>7MgLZzAnUYZr+;pz~= z3j{5AX$8NwjAS%b4qt#$Dpf%XwGBrlSTSzHLFx`QY1)RP5?+{PN;XCeNm|0XX+@^- z8e{_mFa)@G!xUT*PnL?FFvGW@mshdPc%#0CZRn;lop{Rpwds}lYtt+9lTVRfn_iJ$ zn_iJ`mQAK~kl#Yt~#5cx~R4G)jBgV5h zm>$erx;~sbvElSB&ZpV75hT+M3ya9j2@adYU%;?O*Kj7)^ye$nH!H3I}1mVfe}gBB|oB{s$ZbNbrax zMD|Gn;&}r{B}vYhq^hDQyn{~^&ICtI%M)Z2sgC8VokwX5XTC+>EFHHhskP#7e0U8q z8KJ78Z2F0*0o5dyj{et?7$Luvyjq!y=3c&uJelGw?-{GK{K1;#4J|tBZ)i;uj^2N; zDK=uQrbw>CJ;PHNT8hO4Nc|mv?>oJ6@a#Sh|0~(}NBjmV{@tGkB=Av=Pj-e_g z#avu);4#>=&<|OYDB!C;Xi=?N=%6Km!ck0SKsmOQ8$$L=YcT~Eoz|vNt!14d?-qD* zHYSNGd6Az}Efq?QXrn=+)(Sb^fJraICw>Ck3LyuLJl@pcrqh)g*!1~R!<$YwKrT<} z^dlE)Y<-`(!K{QaDWhC&T3l<8w(6Ug-&}fgW$-qAr~8wJ z+^+L$W9<6agfcdfo4LGJe`TwE&xcJPG_6kE>AX9f>wcWveq{^7n!csJ73=E2?aADZ z7ymf>r{wRGx!1#MUT(8-=L5gJ!TPA}|F6i(xf?CtA7w%}TE9Qa=$3gfN*wl>QiFw; z3wzvM(Di9xU)r%iE{tx`P1(~cQKk90+reD_S>?bhx%0osjYpL86S=66 z6R#@KxdP-�dj5Ck^Ke4={G#@CYHxh+(t7ZN;zDcPyAQHRaw;mp+f5=aO5ihuQsDnnJ}xE%?YIi~ ztf;o)v12FTf#oBl%~Xu&>Hg^iHyh)R2>4+QXQAQ;P!&jmAU;F(zar*us2Q?+1O0a3 zTA*O)Ck_|dkjt|ur{zLB#&#mFe|g{1zR!07_K-vddy(sxL>uAUu(=ld)BOb!;7xyE zd3&TBSp3I78+O>GVU literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_925215.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_925215.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ee00bb27571109031416341a6a148ce7165df98 GIT binary patch literal 6364 zcmeGgZEO_Bb$0hY_P)P;_W8S{-~${7Cm|+Kk~(0%AsCX#uG7_VyZ2$vcgOA;Fxiz3 zH4?rimCtG=A629{zd{O9sg$ZjiWDkUDN_G8$A??dFHH(+_F4aMqZWwKm9Pt7*aZIvTfwg!Za54k;jaHuP>hC6#4Lol-}2 zm2yDxH}q~|0?<2vcd$AGmgx~AK!yiKz`U|wYTT2rIFA(inln{)Np5-buLdNwNPgCY zckKc=Ah}DTK-adaPC?i98XS}yRaowBnA2ZTSwkgtOI?%T;MNzZd$M(X8$y$zFSk`S zS=|Ey&7iRPD31^CU2K{0>KK(urek43HL~JNmW%)F@4sFA*T>g)hn*@Zrc`T|&&2u2 z6q96CLn;$tRbvb)Q8ixR_>7>MHSEbj)fC|)={Tnv^~yXIIiH-Go>D39d{#9IVut5{ zVTQ+6r&_rbH^rqzMw|)vsPr&%f;*Gq(yH-Dl4XvmF0En?XX1O$XHQ@z$EHQqA&7jE z<(MhxlR5Puns^Wu9znPOWc9tZxuhOJxX5RRa@skma#2VU zOq#*%%#OjQ2e&;96t$a)onx8P9G~V=OygKa5P|z3|1?O4{j=~-IEguQeK*WS9fmzs?$9e&&Lai#Wtm7^RjbGZePjYUQ8sq|CZ|( z#kV7Oxa9QY_I>JUDNO&oqu>B0R{LCEzOQh&u=|?#*P-`9x2;d*URZUw=Em}4g)_yi z%Z^RCkyU?Nv9~z!iGTAfU9x*toWWwh;@q@G5sto+v!&RnI5(~tNynzL3AutgYNi0C zktC&5$j%>C92;^YAgIN@ zV)Yf0inSZoVd*JZy@fHw+6C*dbOYJuDeO?J9cyOF(h2J@+F>0=^QHY4_s_mkh%F@N zlZ)ne9q%~q^b9FIL(45Y{%m;m9&A72v5VueN8?-YuMK{xML+*wJyGzqwSv8*7Jd8@ z&ddUm&}5BN2PVKAWut78%`z<+B~y%)Eg~MJ70H4pfZGtkeM+@5O;NP5H=ii)O=Wmc`*2q%VOjj#6ZQ~uon?Rv= zwc7^k|1S2EVMC+OOoJah7vZylYQ>=D8iX(fcSRzSB)HCFsC z_PGo&QmxqUYG+8$y|b<2rQ&dV2srMq`$hgK%)-~qN9_U10r-TE;g>^YH`=!ScedZy z6hnV*@9a=ve6geCYc2LIc{k(^tr|A0(7rW_G@7uvcvfh~3LRXbZ2-*K^Y+5=H(Xx< z4Bnx;6uPT;W|{7V@&5(0%~|u-!nQYHnx>$)eI?Ml;`Ocg{Qn3CK&q|CXydoTQljwH z3WO}WQT+`}0dginhEQ52Ab=)WgG{kT+0YQER07}HXaHCfn3EQ`N+{MCI^+FbS6K5k z`fA1G8o1(vvaV3kkLS5 zqeXDqy&N{EHjxuVCK?eqW?LgN*oZvYh&;s~#1b6Bzgj@3v!hibB#)3_nIXSqlW9mm z;Rksq&NtaK@5E!UQlTC47Je{J#L zb^1pCt=6RtqqoPIJL7R>Jie4Ydpme;wRhvY?eDZ-nY__=^T<;FgXG?GtG?ESo%1^v ztylJ6Pb_sm_lM~}%Dz!g(b*Q-+L)*h@k+LLMjAFg8lwN62!5ECqWO14?kythqq_)Lyor_I?P1 z?Y`VmEs_j7`4_Q8ne_Ntb?NovDR>8(hCE-Ts}PSZMx`c`B0QC3(t-%yg4ZK_b2`Jo z`w~2MoaE9%hR3f_qGr52wj9-_8zB?Vq}cb6Hm`1R^scE4JDuWo^Y|qTzi0{X!>>#c z1o1ht{{=ZdLmd!|lk}y~#ZZ~tP3$apA-8u<%1h;LjP)R2U}5Y0)-QSiyGJ3?uJ;-y z5$VbWE_Gh){7HA2>?Ld~Huv1F{H`(ua48U4n3$g^8vzFPF8^FApDLR%MkA|ht|Q-3 pwqVSP+`cbt0Abc1jnBQ41R@+?*vARNxAr`RT$`4ieHZn({6CjdQ#b$s literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_959027.cpython-312.pyc b/src/temp/gen/__pycache__/flash_decode2_phi.py_gen_triton_code_959027.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7f5a8a35696d1f3e969eb1450c8275f6d45800 GIT binary patch literal 6227 zcmeGgOKcm*b$0n%{wb3BQU9okO4N4if96NtUCEZR zY-k`r$R>b9%f@{R^rbVh}?_(M+W!W(Wl* zW11N)fmkAAfstc4`>q`!U}x?vkSwGEwQ zjR6y5d`CNDs?E;y4w^BOh&9*juq+T>^8;+gY1#PlE2J77F9O|t&Ai5*%bzFtSfk=l6*V~Av4y%yr2i1QBT0maIX%27P z2lzAvRmnx|a5x_D60(jF7G~Khy?oeflSv^en`XI0kn_dpFeB>%(S%Qsb&D*Q;AMk? zJ-=7h0u`8?FCJuNy|2<5V!mtP*j!B3u-9f~9WNv}78E9AEgnTSu~9b0#sykf@IEC| zFVpAPD^WHs>rRCk`iyK>8hSDjJbrEV9M*DXPLOT9z=au>j)6>qzO)4o$+Y_6SaR=J8Iat&ZyP1oXlioP-I*Wy z+`S`BZNh+^+0Oi)e6-ZEJFVa7=qVUiPD}jzo^8+fm0G-M{if5M+m+pwpI*_IoFi$&eRo^# zmFz3|(ZW@!>)DcfByChl`wJ(euH7Z~P}&GiTAs)x(wcS3oaO-hiA41P=IqQ*NY1BM zjFNLt+5kSu7|VmGb!Yh*30s1x-8GD!(| zG^L1?q7}g#ONttih&0+&m8nN6+wg1g$u@kAdT(#TuSGE;`GoPC^Xua`=hw&AJVAeR zeqDcaeqFy-G(a3xef5>sNnd%LI2@BbP~Iu*doNkPSG^i|NJ73@R=EgdedR*-8k=n! zrI)O$Uc&eSrsH$5%Jr-b!ZX4fL3{14Hi# zMhizjp13)&;=4JSK9D-T?rO`8W=Hc=*|F5*hNWfsK+)WtI<~1DEF6=x-gU~cPFd4$ z7vQj2mn<2}^3;3weCtQuH@XYHpY^Rg{n^kbL$|No8d-he*JHmNyG#7)@H*ALM)gWm zZ~jV&8YmoI8yuAeM^^_*gJY|SwL`B-hh8hbe!g_*0vz8ZQ^vHs=RNB(pIgi>76v{@ z7W^MaZbojm{5-bq>{@Rf_=la+>A3?i4L^Lf0wGN0Q2iaYG-NSE62j9Y0Rc40XpqhsveyNt<3`Q>QM`5|8qF5XsjO7|3+g# z^NsX13y!gcny;fz)H22obq!X2=gr#wfMn_GdsD9@Q{VX%yx8pj!Z-3Mfs@kj_9@<& z|3ja`oritOm_F1dbzulWL|vqo_SX&pQvazy`o9~Yi?9+&1B4-VRYIf~vl56^u16^>psi~;BUVJ#B%=aAQNnmY*i1AI$!FJb5iuJD_zJi@ZpKgams z0mdP>tY_Hyu%Gqnl-E`^D0v7ZWi3?zXEc!AKwhDQZL+bFzCdo{)yigp zphGIcbNB`K8fB9b3@FJJcNnv+IPt3HT##c@N|eBaW6ChHsb(|ex*W^TMFn2gTxJ(! zW1Z6^i0Dwb%_506U{D`?{9Ra9h$JZFbfr#i*d36pxLZ@FHta0`Ih?79#}3rl`;qmA zwcuOnS?!d(FW;HDbL!sokJqLdX__g<_f=(U-k zG!rb&&fjalx;gO7$6YtORwAoA@0=Lu+`FGdffm2H>5CiqvzXWjPFM8x5pUOGY_YUgHCq`rI>=v86TX=}oFDKT!xv-Ot_j7RSy=dw*Cw`$q9%Ksp;NhPfhtO$skSnkf>; zNoZcFWUC@^mVAtm@f^9)-d|{y+IOT4%bK5B%0|>RTy`LfD|K3lC%sng3_el%N@6eD zEA6vU_$M?6sk=<65Ql%L%9==6fWIV(I4{7R!Byk=*Vs6p;A9P)X64Ls_`t~K%GslX z@FU|BqzqSy)~T^rf|-l5hd37GAR$EjhftL@1VMa(EPqATzoG839+A|$p6i}6`8@Gr zxfeN{OJYVW_hGCbxmt5W*`Y570DGiCq@^MiCESF5-R4?~W};;cz#C@A(pYA!tOFP{ zwsqw$W-pcv7^9HMzSN!RE*mjsLJrp>GeD0HBEtHmWriSJ507Y&eMiZ*^SW8F{@+!M BQkMV# literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_124574.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_124574.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45a057bf0ecb4a01bc7467974c02634919950994 GIT binary patch literal 10690 zcmb_CTW}M{mNR;eUX~?4gdboV+h764*gOozn3#tl0UXRb##yb<$Pd|;?3pnJrLLl6-UTT4Zk(GMl=H;HsFdRK@JxD%kui`>N7Ju4~mSTSe97hy8LceC+PW?w+2} zSh6EyZg$(HneNl)oIc%s`kd2e{7)v6o`90De%=2sr3CQ@+>nbYU3v9e1wl*^G(nR? zM7Q)NyGgOF7*=+xNJ2rlNsq`$CYIhJ+_)TuTeJfHl(dRg|3*b?Xw8iFfpSd>B~ATF z6iQl`o1#zIN6`AaByH#;XN>a5FoV1tlA7EqPks(~-YRhOD)9VO;FjF*UOjF7RP`V| zu>9`U=?HoQUEn4>g*jv@Tm@dV3cPp~xXo>#*X8c(HPR)Ys=JMEkOHWCRdgv}nM7&J za&w}8hE%;2T@HNn=n8uMr)0Mo^vk@oHu={{d*w^KES)a{@O=3Weo*kRs|4erlkXos z;W+zl_ff~GWv$L-trLRr*qQgwpH5@GRAF+3p!bdqJ6O)y=Vk>Ru8%vJ5%y&z9)TbM zMJbS{1hP{gPY7z4V}xS>qILPay#sx$pd55x7fg(sb#o5R=a_IaK1Ktij21tzgp3Xz zg7)m`lkc8$oN%p!mRxWTJ{+%l0db1pU^J@hKvOXROEm#F{z{Ppv;cJ5m9_-0eS7no;`b1BhF_F+LD%iDzQz+^PiY)km&*zTz$)|h`9y#ocEa3!%jQ+Za z;`2fk@$+o6zvbsyW&eJpLM_X88s*DckM=J=mB=FZ_*+qnWP?}G?w9QGayn6K4j>z? z96SO(dr#JGlxFrzuAx>`=UFeeGViR@@rK<-2WIrL6(fo3**FX{~gP2=w!e| z5_VEhIj;?{FL6Wx4K>YOA944|gF;-S8g{aSg4)ORyBP+wn_#f3dWn6Fon*@JD_>A> zLxO&U@%1szVaEV1s9f~GRY5sC;1x7Z#_8>I3u+t_v4XaD$misCHVJA8tb!&D^gcl? z_M{OS_XtYb#|e5T=NleyIh-^NoLDTJpk_HAa6>T5U*6Evga(wh~-GC?yjg4>K0 z3ybAsu%ltXs}X{Z<(L85?YK6asSXPIG<0-4yY5|rv#!3ZdIml*l7UZ*c$ZkZGTdE& z-N*RGMjY$_tN<9Fl`$cMRYXvZz_L;BaIhneK`8>oE#QQMQC#&3`Eqr@>kyd=l*Aio z4ZQ;l%ZZ4?f-y5Ba8OH(#20LcueX;4sPvZGgMtQcG=__BfQ#=zLFsh41jXPEA#S_uIt#eHD$eq0#>xU?UvxG@-@?gO2;CTOMo0UeMwh!qqpaGn?uG;IGs zF9!qzUT&9TT+jhR8l)9LfTO0}qtK-l2ZsetSG)q{8-rw%YRU>R2f=N=wufL4Pj;INY$9vj*|a@_$S2k02* z^EzBU(6q*;UAvoiL62aPkQ{?<#_JwpXpj*{{_GKm&;rEkCcZipa3F9pWwb?i#>)A!dfvD_a5QDgPZU)1rs}wnH|`1? zeV$hk8RGLces$*F2h$(KM`yPsYY*~y2cIJC|%c&zsH%I-Z*fBF<1Z@cK#Rh29vh3)MwV#vy2%w-f}jmd2XqQweXhupgIMk?7nM>9f{d)x5a_~9^SfVuK0=Sd*>7Fql%|3 z{J!(N^?Xp98*z#G(=NWHGjXPe-vO%(q3WCJkS9#Z=LYyty`R zn{B*bHcS6a`)nUy*T$RM0>__g^CFf{K6*y!!={ia(jQaa(@*PXn!ahd-;%7{Gk59t zRL2YGq(W3=U3BN2mfI~~UAgC(_9V+1=cwOP?XSR?5CRE?c68i5(AUp_r*t`o+=3pv zUdMa}cVVWGQ1M8)D zkNAW{EYUrJ-t|RWldpaa|GElN;w?#YlFHVaG_P7Qhq9Pu}746(I()^ ztAIOuO7OA<(XmauC(Ut*OJ!`=G|21ovahIU^aaJgOCCj##=rhJ71{$DHBO5KG>1Q~X` zgYrcz>w1wxM?(t07gH-TKGsu!DKU>;1xTvu%nB*-5r&9eBZVJKC2lGO)v2TB+B0!5 ziw}8aAi%ibEl{mP&fzPxvu$E)Iy!+cYs~58AXs&Xkqz6ly>-Y3c7$z%iiGR|Qnnc$ z{}CXbQF&oQ$Pj6daM4cQwlzs@o4NYUNB2Jhk54~o4?vvC_B$cI*F#d38|WMJjX_8Wp}wnMP(!fAU1#i=%!?m5c!ID| z4Cd_yrWeETt%d2wj{*F^mQD)#6({yWv^&O)jB$dh590r<)i40oQlR+GJJPC@thyEq z0{lONAA1rM;Xeb!qJ}VKx*V;Fc14T78oTG8 z_Rn=azVztQys0D5@mirYY63@IQ-s$l|-BGv_Sw;l`PwqqD*+)NY^cMwCGOx?eduOFDp`%b%843 z0kJAcl}9}b8yffx4YOo&LsJSg1bZ}Rl3Ewtbf^AyeVj~|ZUKtd7M)tNSV$PHfm5PR zxUx?HwymN^x(iRy8v$Tt@xvMMV5}7dQo@SmXU&iO*$OP&6g`2`uL65Y;=~fuF4mw2 z`~uu{0SZ#V6V9Gr%|Wm(70;4|mm!e)$>9$?SjAQpe1`EK*lr~rqc@Yrh2*tgx$2xq zq+0S3cs}wMye$6k-Ay&4mMw-~16WIm;SAaeEG^O^4VJ`D(HcY{Dwjo-Js>Tnukrap z6r|x`etsRdUeB9 zcxLu4M`}&x;x`~eE{y`M$n26aOTJu*AtkLBwMLbtFFD49y(#~xoHK?8X-WMiu0rmo z^|FLVh!0_Q{ybztMwAB$Ed4K(QgSPuZqK8oG&7uw30Pad>GKju*XdT7m89R9x#^Jp zEiv^V0bnCVuL^vTsq^LkKoL#4Xxz`)6$}gF3)(@qdju?eIA+u&Fq<)<8l2hACZ-zdk_V}Un%yjU4y?+8-p35|LD~UakmxqV ztP$i+oHXQPAhm(Vz%LjCauHtNFmnkr`vB9Dqk>|zSI|JVV9*T%OUVm-?3R3sRFV9P zX_<>Ssnv_;~2}g0X@(RxB8+d1G~~Jx<0Ow6gb3`P9H7~6-UTOC4f`<;jMaJMNqw;H^7&YDY-3 zs3feV&vk~dEo6&WBA&^LSXEM26F8L01v6M8O%W#Ak}%XJRJBqtGYR|WTCy+Dz}^A- zrB_d3Q&;m|qVG|`9YsY;{zbzr`7%#7H(-Pd(jYarJeQA=Jpp(P3M>zqu|N1IeZCp24VlyHs8Pl<>5yVg z96i#mf&;!&w)J4Ia@~KH;(O7H>(;;l|CMVbyyCJi6>kPBY?NfTK89JRLs`*eiN=Z< z$Iu%55Mx$sahN*%2*5+ojd|J8F}Hid4FN1xpOvxebqvd-uE}xcmzYi?Ug{eK<-|xM zi}y0i#7O4FFy_UENQ~c^+ko+F__6;Cj{x!9XbqnXolHlRls?!M(m?zu9YqM@N54|v zGfkUjd%qigI6RLJUFH;(Ri#f&*^b%6d};GSNgH3%mMnQENCsO(JJcOjMB9_54Up&47ls=`4bhEJXHs7oIGi%&g)f9IL=HvuF~HaqFZgEt{q?h> zN&7zDv~TXh!huu#fm2VN$pfc(SHWaq&{3`*PCIlTh_Y0prqff!S%{Cmy-i zP=SAEBK^uYU$8*kNw{QZe`C*a=(4Wmsx|mQ(D^6vQ)f~MI26o$&F=v@{2IU3PeEvf zv&3S7FJKycfpf!0 z0+*A8#I#*4`BeNRhW1LAE9{EIJC5BbU6Sy@l9W)SC!GR$MRcZ)D}wflLrO3iuZaKg zBBhuF(jl1A8SwO_OAG%PGd9GrEKYi4scWUAuZ5COKKQ!e;^bQ=^IO2Z2R{~{15jK` zVc@+NX+7=x;u zk3d~8OLN{n1fRRc;5&gp$u$Ol&J!r-80X{0ylz2rSp2UcLFpOb1RdOVS1j<%mM;bHA=6cYVEA_3G8D_g=m7KO2p?1f<);AKuO{Cx}1eijoRah3EgJB8Yng zOR(e!(Ie+%4=L5DQB{wcBq+j34oM>^$Z{^V8`ihs8cVS%R?TWy?Jv}Du__CKS8E!d>pv`Ks6ChbMp+Xr&l;&;`LQZp)fPY!E#5<|rq zgrthBWs~-jP1?6@(q8J+vt`-)1`O=>chx-xwvsK+4#8GP*v9O=uhPnr1L{&(+klqc z@vf#P?_1{FY=t|s&zGVxD;I(5yI{6_b~kH%m+UEkGgy%3y=j`X?Rkdz-T5pC-)95~ zpHsL*!^fr;4W}<$YCCVg@=epW?(Qvl=()TZ?AfMqAZ==Lqw|~^`^)W7| zSIoiXNroHqKHG(7P~(pVf3UdPR3dp%Bs)d&JmBdZ9{0eY%PXpeozt8S`Z-JxmxDi8 z7m2znmoL0=)qdWw4QfilFnH$YK|r}j_+T}18cLlArH*r!d`vw`K?JsK0jtV@EzN+% zi`bBO4PT_3z(Cm+uv)%24ZEGU@Y`pABY|>$8(%&HGzTj93f>~M?eJkXNQp1wcgz46 zcuN|0C%*%5cFt%CVg_IWyQC4I25op8E5C~`!sBGngSoXKvE&-<(m0j8l`noNy!1LG zyxn{$Z+;C4|7-k02|I5Mdu?Pse;xK-UcCu+K{o6ZU8Z$o)MW4} zi)H4E_~HS|m)_fR5wyxW)X`2&w6YH#WG%ycC2?#PEi?Tg-KF9{-r# zI|u;SbJ*ZOaV3~g)fn(j10Va0sC5C&_R&=C77MVyhPwvcc4;h;mWKos{lFmS^+~;> zqG4?&K$k{KWw05ZfdMZbRLYh5Fl-X$n^)9I-1Un3HNJYq+%#jMgK;?gqy7=tAxsP!=|JPK)NIWK(aKW!+Qc4xaRGu+dwtryX|ow0B70ltM_t_dY^N2tR8Ch62+s85BN|wHf@I{I|!K1sL&QX z|5cGCG7;0GlKt_%WYMwUg%v|d^iZrEZ9jkvb-|7mV}Ve#8yR=U4am?K?D*PP80kYs z>n~3~cw_#JIFl$(?rA~Bmf)F{SMK?J;b*0vmM%3cjU-!pk+C;;29G}Q^RAdWdi|~^ zo`Wn6i8Bk0zcVc~eSB$g5H(&xmP^QZS)Od-=S-|P>bm=bxD}H>jV!M%99i7+2Xb-m zlaos~Q1dlpxrU6_Vc5d*m?f_MkcsQ&cP7puYs+Fi+SMgo?fJ5|Pk5_e=

W9m0Sg zUG<|~epEOSJilTpi%!JY4_jldxgR9<8tUa*pH)16pn&R#85@c;a<_2VJNW8JATT=glS<<0X zmr>(oS*-KSKV_oD5f{qa5i`MPvB|jaS6%T*w6|%&iuN`m<8csvU0%fW!#iKmxwFQw zF)|PxPH6u~H~j_bbeN73pG zeSXgE05>k3uuCRg$;-eOc#nQB=N`cmL+S3dLLcnhMVWS2$t_ zT*i%wRVnV~CjIJkKe+{d_Wm_|?H2gi``7TPE%3AVui>kF#%DE;u)73_BzWx= z!lvzfNXgEamUT`HI-HJl@CJbJeq#9N_)_~1)K;OLFv&xB2k)EW_Gervuur-ng+XRm zDSV*7V>cVxR8qi>0{YaQ8D13joxo+)%0`&8;;KqiU+lQry5yD3jz`gXgKtTrX15|GleEvO1d`0KY>cjd- zYs44r1T&SSYadU1bm!rng@y%ou?4ksC+n{u`pODz2;B(hf1^>AYB>ksN?tgJy>_ujnwW|aB+-W583HZPnP8Bfxs(Wd(?b1iXea(nG2uz8ZAWRmo@ zX!ZRAa|hyNvTPsBk-b^}q9|9RU9}JfWAJrpGen)NLA_ivtk;n#j9AF=20cFw-_6uy z0IN3)wITV^>M)z~Q`F874xnyG5c4Ayu@9(t)lgdXtpVaUoyTegiq+s_5l{ncHYJ!1 zRs&H{HNNGu+8G!>pykziIQUXo=Sv%mZ0{8iZIdu%==4fVYt$(vctO%BG;YD8EO7;8 zVJRGcVQCyG%-F3gF$>t?EVuH_^j zD$Y=MirWE%b6BJ}ytz`?K@L?IxiQ>e!=G{dse+HFX2EpWz6$vRfT)5)^@-sx3F1`?JIMPTW0_q^--e4birEWqc?} z9}Tvy7>u(W;f@G+uc3~lp&GnngAK4YX$H5{86J)L;0h;Qn?h$I*TR?J;+WRU=DRmy z)u^~8N!ym`dPLVJ$i%@UePo$Fis+-it@>pD6EIt6f~|0Wne7etMosri=1O9wxw07Z zz%}oRGxI}T1~nAPbiJBa%Wh0|r&#VA zu7QoG&JjNZrze}%1O1-W;a=|`Nsf0GyMnDlR(-e5y$X;#waMJnOJGwO8aJEkuE{!v z{N9tjPjcKK;tZB0L)T`x003`t*S19|!RttBw*p#U+QzYI_qCn4`7dC_bUggaf!9ef zSsVqI^)LSS1iU=)-%42u3+UrjSo6IaG^@h4M0wpBO26MI^aCw$UCGKdg)ys-~R3QKe_qj zrf{Wa8NxPKn5X&u!Pljz)=%Dq5Y~>J5Tjz|+vbMi?Wp2=i8H@xSYSSJJ#j5EPllGP zORb;pMlGGfh!hEVUYFgl5PCxxkL)H^M-ObgA^ zsBz}s2;%T*>Kp01N)T;S2L;(#>JkM#XQ}VwE}TK*HR`%tdlT1ss9w3&OF=;7a38Kz z^ieD2)(@<)t_RoWugAOQdlIHZ!^cIa>TqHPRkbc!Kim81-X-g&`<58ta@Xe$bh=x( z`W8C_#;g74)isFdPV z>NtO_N=I4LTZ{*W97XqvT_XABo`Y2_HpK?{mV-hlBA#+});;PPO#3#JW>Y zO$|pRe@wdA+6?l=7+)jg_aW&xcjM|+JR0^5?fs&z-!28zL_@zF2h{9i9_ZjOS4Cqg z)bH{8#{53X_4JSUyk7jMuN;1y%L<@aCeOohPKQqkEO1ePorRxw8a_edYf~`KCXC5pA5EsOHcKoMC;04Za&ku+AG}T7EAc&lH zY9$TDEgTXqcMGf&Sq22JFLl&0R9gnr%Z=q-IQaWT1HUxdu8vGtx~5*^3ZBAQDnYXd$)GgoT}D0 zd^GZKBwamEUL?P${aop{DBvX$xP{nP`%(GB^1n*e=PH6SJl~FIDVMt2DKlw|nIDwT zm#-2~T+n=Ccw&%>OU^GwJ|B^a0z0_Ojv;nTuKA~=99H@^*@lh9+vGWH190{nkMd2O zk=zSr@c7*ERRW5MD^j+2c**=l`RC=r^}gln4s_iiIER*<xjGFJ4&jD4q9QS|3Kjh`|V#j9}KD{89g}0bgVPt%HWD1SI02uyxbs1Uo&2b`s NcQS9!J9q%LdtCjGqEO@B@3m42;2s888L|12YVu$ZGk8E!iy@k4g$Ds2F&ZDi zE=phgA};0?qsoga98+Lkd|0eW7M8wZIrU;0%Fzm1Nvmizt@*i%*3#O$#2sh@g$*Ti zwC;OKD89D@YT^6!N&zGJz7lKW(9`PI7&&S~k(nzpS7JarfYH=1aoR9|-!;lJ2rFTf z<}x>?O8PP6*VoGpF7?o#9+jA#sJE_f3qt0zvnHHz4RxdeY5|1y;hS1qF)G zS#&A|{G@>Q2>1y>?Qx9-87R?u{J#F7fqYWk{j5pv7x`KY!w3qQST99H;4;dYb zkU9{wXHR#(bIx_bvjuW;Ks>04XY(4QSb{J-DGa05*998CtGJWl2yNMfSDC}RHHQ}s zd|hDGto2T&y^v!QUJYx@@)om&Z1J=j!$KwO7Pe#>>IiLPx3Ps{+4dls1~IS>cKbBU z0$Z4sEoHZZOzAX?dm4B`JH#3w2W6<9GIj@RMfK$L2hG;HK;vtD7wg%{ma(?ilFN=m zlq+Yqvi8@Kv&&-;YjU&- zjhxlXJZ$+oX;P+ys&Zx}kFjB`tgT-W%+A$t9>a52EsrBq9h40m7`B|6^X(NaYiEm4 zKV)h7+lj^S%RR7iHEfMElM-LZ$vW8zajtgd)SX8Gvz6?wbtBeZg&CnWceS!g5AEKh zjw-2+wOM7QI`Eu2YS;6?%)`#3#dQnSv9)a7G#p%7Hw~vPRKK2^wTL^kVUwOH*#_1q zuh>g#Y06T>F~8K`+%fOjq$L%!RFJK2-LA>(Dsz^ezrEb~X-y$7`sYGJE$`fJHF~yYH@EK2fCaLAHs3gM>=dbZ~0SJ0Mp8wpKOj4qO$~{@|dOVMrjXAg$xX89>{Z(OS;g5Z)3reF3pd+c!NXmYSiILA#c* zBDyXRWQJ(3%RRc5yqZn+{4@|wkzdrz_(pS!3g(X3=6@dAZzDl>pBjB*US6v{1K0g1XfV%|(7~Gz#URu!h zxdMYj{Xs$1=Nf7eR0C+0QK82@;zbpq)fEU>2Z;?xkOodjh$}AO9ziSi2WB{vfF6mE zEFf;9z}1F;YJ^<$@v??31A1=koYrw=^}jwluc@MKDNAuB%?g=N)0LK>yH^2EGB<7sj3+ zz(~VPU9^ttUDB6^k36D`Gbf`bx$z~cEZp^^z%r|iYdM@dkSeGNzw>BaK3!Cv7~+c> z7FD*i$+B2j#ha?aZ>KE|?h)f9l~7f8YXdxDY=2L{H7sL~CN* ziP}YL!*7{yrteKJ=?{mGJb`K&f7&-Y5Fg;&iIRCOU)Gi?d}GSQBttBq*WwvvRFw^b*bd24+{ z^|-+Pg`2a*2I9uVF20~PqIzVp%~r>&xvE4VfUM@dris8y@!Z z`+FA8T;}&(<{g(`s>6)xI1tgSYOul*uJe~=;@XV!_L^h|Z>xhL*&N&;Z>vg{^VYhE z_OZ2ytBxN`T;MHeUoDW~)jgPw3U_R{!H&I}G9c|6 zzOX8}l`m}KjZO0x7S4UwxNw2r+x1Y#@9pJ{y|BehgEHv^p2YBO~;m)+y7%TfR`-srb5K$sl$e9zi-xF0&Aa95oV*RkK=-(5&pTYUX z1pGrmJ3i?h8W;=;8c~ITG8`8T6w{6vG>kVm!T3DLAj$z<<;DA6HyF-Ob)V{IF%}z^ zsX5w~CuF%0KH_fcI` z->-lk<@Pzl52bnTUw4%eDr~wHnQ5(>2Y#314 z`at6{+S`Dx&eHKsjkOPhXQ)ZX%zbz9z4_gWw4S9AnjIF{eht zVj}$dS|Mm?@3kS1SHwb|tXBcp(aIvpXXtkLmx|4N4pE$iJ0V&@8Gq+sdj$U0g(&PE zTD%3z{AdqCnG)5}uYjJadU8Mwf--;&HIngUcA~6OP@U{L*TH0wSO9I>9pNx@8NLE> zV8lJzN4pPB@6Om(+B-hs_5}ehxkS4hXsSOj;s;a_I0y-enFi578$|yZ#vT#I88S-7 zI$}Yt2h2^1sJ(mb*Vn(kzR?xZq+a@kY$P?a%urw z4!XwxjsnE=3<_!h1;Hr>VO?>)MI#S7K?IIYlKB%vaHFUXMIIDM8p1yj9GeJAKofTX zA_AzHGi-E*vRLM=x1`~3-Eh!M0Qk`q1z=NP{}sj_lhjOGwC&^0a7TDNO`2y8L=VUr zDyE$^#0{JqS;a=apyGFA#iFWWk*r8-4dEX>G8fJk$BRMZnYr`Hj$gn1_1g;_-yXYv zY{}di(L7fuDNVTZIf1Dyu?k*QlqPhMw%Adm>OtS@B8?F@d_*m3{)~+E&b}9ak8}UU z2Wi4IV~iSO<0)b**LL&Z+`(j7s<;;9L=Eds5nH&bn{{(_Nj&A)^?i#i&#gLj;i?Uz zEa8*lf(iJfCwG^k-H$fYv%iCssC#qI1`trVBYm(C(#^TjnR29?LyC}+RkFC~H^6gr zEH(`LeSJWB4|qY2#k2Yn)sCRxx_}-8?yBU|Bb~jF3V4w#ab3vEt7cWK8r{cf^)&3D zkY-I6=ko)xY7%2Lh*Q?B!&z_8T5~i`Yveu+ql39F1fe}C4ClG=%Qc6zoAd>AndnLQ z|LPCH5*w|TJlYV!YB%DM)-=nhWtAaa&It23^kFCHXf4_a;GbKYvpg<-Sn)hJ_q9io zkUogCkr==W3bo2Rs}sBYGgx<$B_yu{OR^-i37oPYa+4Lz)0^xHQY@A6W#}ZICas@_ z6O}C`Go=(Nore9Lr5G|4<4Y7%hEnhn#hjs7UZP|?sxYqg?N~E1rk5zVL=k(Xm*+-a zjRP=0pZ`LIZN7fes)KhBBxLFT<60VEkLA%cv=v#+knx3;HomaZemDk|sei}1Y9Py+*aBMelEMnpR#7WL3WboFH4m3d z>D)M{b%Q)3@(PLT%Yg?_pYJrkdi}Cv(&TN`F4KoAb9*DDWoyD3et5h6O2AEp8H9(l z^EtpwN_q&qAXmNKF-Rcr0ipga1}#bk1xjcHkyenQSwX9~f=u8F(kv^;P^};xC3)kJ zN1j2^3VJP&pbjpgnw9^7;Is;A(F?xGHzu*>iQ#AJQ}tP}7Hl*5(`V+ibMquS4OaFaC{XD&o9e0(%~ zIHE|ChL}Ejh;#prEL~KUE|R5)Qxnxhm{=3%;H~9}iz#AHxC248z0tj~-h?W$H$_&0 zKY^@)8|enc26L=4`T^I;Q#-<4k0|rZiRcN=#;FtPIYWx7TBd4wsy5j>uSikN%Ty~* zwf=VEo9p+kr>G<0&L@OHgulI>in(! zo8o)L4|@OUy?=af3AuZKCDs`8a=pJdl!~62FUn%Qv2o5E55(S$)x^nJW8C=jiJRBw zuBQriN7RogGjMdo%A?1*ojg^NUYCEW#0}flGd)!+4BNoVVX(X)`&6rhV&bJ@oeGNe zxy5DCQ%^~?p=6a>Uuw`msWGSYsYz=nT2*2?(~a>Mj3r(XnY>|IRb$2-&mEX`%k1v> zZjMQGB`@+7Z3|_5Mf(GsFYj2~e)v&^^Y)2bCl>H;iTea{%^b$VCm!nzGmfYu))Czn zJ}gaEA6NDh7tECzd6LXrBiB3E%-Lgu@$JjzoxFMHS9@=_-)c{p8zP!VM(fN-bcAcY z**@2v82{P9`Lb_n?$s^%Pm2c=z_@TbJjXQkAXC zmF;|Gd#duyh3m_QPVt9MEuDUM`Sb_;=?_weE-jLm7FCxdG)_!r4j~gq2m1?%Jh^U& zEIj`aoT^v3MnbBPnpM*{x?F;jM8hh@Cpj&+LqKwU06?sz6*4@^r-tD_H~{Cf97tQP z3s6TE0Eci?W4{8RVP5x)bOFxvP6s^s`=pbBV+K%RtX2kSA;pGLf83EEXHTBq1+%W6 zeE((2n!7xbXr83UH+ggbRI^u?ycZCV05JT@7NM7U{>cMR*qI)L$YAzTLc@NbRWb#@ za6jx3ij@H$CG}2%f*}KI0Du6Oz~PKLC3o3h-6we#0CI^>{4CY4Kp+EZB*f>@KU%%u z$b)*N+CLuz2t($3KFtCR#jB{|^$LF_2b?K@8@e75MTj_n`54+~5b$6=L=l3j63~Ea zBL*%G1PPcwg-FnNKOA!d2GGJkVvv(W1P08fi0(tt0E$H8kKh6GV~Cta=C4sYs}n?v z&-^)*35uRc<|im2{!>SMrwRyq(VR=w&mp3A($zN=^adon(2a`lz5v=WSt~9kBvVdK z_>jRDeMQV|U`oI*@F{EpaG4nFGgqQlGS*v1M%>ZD@Zq%C8k^?LW#Qvsd}fYEkH;8p ze_}jEIhUzDJhf-uGJj^lk)k@6scxR?et2fF_iTzf2QMrZTWoky6kOBloSD}-pwcxX znj%ZS>`jmfM}p+8@MJAA{IZ30BzjY1^)gw{ll4hvzG;3sMIK%zkMrd5hfNQe6nO?r zud#4uG&=h8uA3+3PNa-A%f>q1SeNXZ-?@O_t4OqIMq%4oc?F<_EeD+CGHq_ za_^#QuV^ehxynTgD}?gToueu>W6T+$CHwO7wu}qMs+Z! zQ1rJDi5T70$DBiHN_vuYjrjq%Gvlbpka^{lJeT4rpBM=S0_e(+d)B0rW|O++XBe+1 zDEnrad%*km@C!^q1V_$n3m;#@q&PgrwE}3|c5mAfaVXsJ3QAS;e8+=bOI0U%Re87z z9FX0KvgZ_LuLl@wU(h}%c{rW2T@0TD@S-b>g;v!nlO}v@wE&~6;U005ITZ|=)gG8j z^mQs3H2W(UT!Wb3paQF;zlImrC_FtN2;=rh%o*3~SWWEJv+ox|$Sk}b z*X#8DRdubJ2!rN7ESPRf+UkfopBGUI#i|2S5T80ecC0G8aJ&}W$?ytX8?Rlpi`has zTO?+;$l0x87Vy5rTr6fwFk`{YrRb&q*oLytl$e5gR*c#x5y@v1*m?N>_*11Fw*o#Wjyxo?4|gnRScz*foI}(^*tQh|K?Mr1#jglZ#K*|tYVOyuUu|A!Z#feGwH(O hS$`a=2SU2J^%?p;Yjxn(rw1@odCFMvF)8l5{{^R56h8m} literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_219875.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_219875.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d69c4169ca31ed4f31702d0843d3e4c8c8481d9 GIT binary patch literal 11838 zcmcgSYit|GnY(ben^j*g%ycK-A)Zs)qyW-T`-9)2N884-nA&G4792r@)~KNEPZi1{d^m$u?kkns#rCv`H_m%vf2l_`^rr)^sF9m zBn-orFwhKC27}%b1_c-|HwU8W<>o-^cv%c8OV!rafoRgdUcHWORR+TJ7k%pd{GS?*hIlJ>f%i19s7(OG$V9AyLvnGQ;J zl}MlHe)VLRxVNGQUt5Dnu~V-A17Ipp?>dQ`sQ z&j~oU8t9@K|QE1)1#Gpg0(yey-LL6b=W!?hOb_St(SX(bvab#QQEwj zw|EqR>>iEeQQ!|$@;ZX`IiuxM4C?F&-WiZh3WB%sJ9w+C^iL5{^6g1DFT40%Iiqcb zgZ3PFYnJ3}fv^1*f0d)aHQv6CJ8INA4Y!ijJkZK?1sixfzZ0M2-M9oN8_o!V!h8h}y?|GiWX^_RN9_`X5>(6tpv+Hl? z+xa#u?OM@Vf_wSgbG|Lhuxx>E;P=XRVwwSrdN_MlUb?fWUK=HU5Ks~S?-Eh*ts7vF zcjsFtiIRJQ`@V%G-JA7NC42UQjl79fjP3$GS(>wlWmyJg3;aI5o^O0$kZ&1D+{~4^ zdF>fzwZ5}XNTNU+#RQOyQ!PPjQ_TMWSuCs6OIF{gi2RQm$g+#E+i>r;ee zC%NtDV-ghsuV|Ryd_$ab%rVT0Di=F^T~v+@kBb^7=Nun$i|T1N=ktqNkJsl6w6}?R zuo(OS_q!9EsFom9(8Iz|K2be9jzE;GFCeP@0UzfU4bFgXY}n;+vaEf-NOp?k36VU> zA&JEmpbVL#qTV0ihFQ16Ikr|D%@zl9isRW<*BY#At#xoR3or<#P-fi0X|4oKmdCE) z(2*9E6R>(UWDb}M*-xkl6UfDJ(JU8-#~nkQZ)!r+;SXXPTr* zzaLmq{b!{=U24AJ<3{~3wY&Km%R}Soh9+8oN9HVY--0kK1fHu%Zgg*^dXNdTwnqwb*h0Z@&S5KGJ#D|5F=A^1H zZL%hd>jhJN=v>-b78?+(_E6W8g5v1S_+g=>X@O1^90_%&>5^D`yj zET|EfnglJ-dqbyJ7;AJuU~He9xOe{U`Gj+!JXL#CV2*~)q&MtYF&2C{I5#vq6m!N) zQpI+`Xb+uzVzfl}#ad?%exkUiy{k=-i36$XcEQ@d(3>(IKob>y$j#lHy&1Q}yHh2t z0@I4-Fh+|%B4c}F+S#i3LBX;|VD==gFR;JpTp0TFdrPH4^96yqAOY6LI%ju(G#Rgr z-4HB06L!JUE->v2p2aKw)Vg@}bKgU~u)kMedP8Tv0+ybOoQrzn-N~ZXf3*Ix?B`{V z=~KX6OTk?EYMl%@@n5>x<5>=m`h@Z1anLhqaF`^*-GBt7FO6-w{r}Bb2_AuG|s#B(D6dHF5Q>Z;=%04+@os zmTccBN$Y7ctXZ*?%$3fT{;2KFEAy|UEVa;5R2(aveLMd8XRQe~ap3Md3&TR)F`?*K z*pS|^OD`WjL!xlu@?-PCuv%`>%Pp!EV_{7FkuzQx8=5x?#`**tuEu4$RiInH*zsk< z=M5?P2rw5mx-jx)tVp0Mp{8R#w9XaJ7Q-=b{b9MFs|c%CbjIkOc&T7*5OfXiYgTs? z#@5w6gt6eIkLT786>a_7)+if25P3D`6zHn>nZ(|oS%5pQE;xnS0|NC*=u}#3jM~1> zKhYUxbP-*&I93$D@>^ZQS5RjnOwHI!E(ya)u1M=f5JN) z03og6+<_@>+yy2|HX(h2UD5Z!eYvf=ZMtoqBY0vY3$Rug0Jxw!s0nI$GG9MSB9|nn z+n^UaEbp6m9j}$lS)~5`<_iR?;&mSrtXgjWnE3l_Eav+pjD|G#x&*dSCXFD$wl3k3 z6&Hx_lY#77NAfrbyqecUOdchyH_sfHI$~=S%d~g`Iy9oyyFbiM(_g*4Dh~NHX{GkMtb(V7KydI6K=0{n;^K?OR!s{>8J5 zlfa)Bk0h@t|43PS(23w_p5cuVh`c|b?`NnZ30{Ahuruoo_G#u4HqEO2*PIjX>S;^$ z*pxp|?Vg-+daJvdo7;}q6&!LPL=Ee{KJ0S4GEp;pB9M?k?0Ys2+21KQi@8NWBuL^+ zK0G;IN*+-WApsr>I2j&Y(Fnw1U@6L;H5(fI3yoy5an(Qx1Tj??PhaVj%z6j3h}wWJ z;Pg5!`jLdI4uwM8Tc}yr;dPD;vd*K^yEC>Oikboxu z)!z=;e}#xAI^zr#p`vGFSK>`VMQcje_Qe}t_J7{LWLY}(uu?enW@_JCg6^%f&h)if zS*qds0Z(#({s+?2|Avf1d;~sY#{h{JJB`d|*bs^JYVcoN*El=$h-$D}12?%rTn0w~ z?9f_NR>ZRL<;5YMa7`$~?hp2SbX|^&nNGSy z)ev|=qS86_E{9BT?f}YOK^fX_&I6er9W7*<{}sxlbo{u)*U}BS$xQ(ki_m`uC+SZi zVpT_&%rirgp<8D|ouSFFEe+KhksHy;$TZl$I{l0`qK%RfL#Q)N6+|b&R=oa*sz|CX z#VZo;CPtF4oPTJ0*!!E>hsDXNODWZ*Bvp~7=$QkN10TBQMrTJKQMRPYmZWS?bPV3A zUSym^&&PZ3UA=pCq4&!-KY#PF?r5lU#b}D!BJR-HKPr^8CUoi>13{T*c1L!@q-XAQ z&v$>^lX(57{h#*Ve_JR&B$OOZ79aWN*MI%?uij4f4=ne)gnrj!0~vrGkyW97ZfrZ{jb^D{5IcAoL4L&y87CTXA!@alezV(xSsZ%{i&(g8i9-FIC z=T|CAv{q1+ASP%cnlKk_{eS^tY*LmV@5haBtfH1^XL?)d9re8WV{N?kBRWiC zi^lo`I$_(m1I9#lkM#o@4KEk}jBzZ{gRx<-e!<-QV+pO+tdJKkIJlRqvqB5j00AK zpGGs>!So|@Pzz=dNCSwY!N@_~T9iEB;~%e+*2x<9>_aXXjT(yJbr>I+Rk+0Kk+H*4 zD1VXi%B!-JEG^68pn*4Rv}PoyVcnS(AxM@0=uCZ$KWKN5nZaWImk~jVN26{6#}WWZ z^At;?2qRBRUzVAMlQa#16)=gIyp$=VdDbY~WVexAT$kkChB3F8FdMr~|ImW@?6(n884WrUE$m}#m=SENi zxFQ5MiZZ05IHWg4trKE;W6pO)^$iFGLR1gJc|M5Xxj?y2NIFjIpPX{Lr`>iVcOBt% zp$zQ>RzA`TBc|Igsytq2Kvay4!|z@03W9kZWp6+Rktm1P2X5P>;~EGdt{2+;$flRI zmGqS~7W16i=iURb2!j6vPFRRYQ`VWok;5MxiFIP#b2wRbIHfv_RgHD#NaV<^Q%HG% z-U&}fj>dW)QI$zmcYH^@H-0j?^VpJMkxy24r&QfZsxob{Mn`6=LKo1Wha!hSOO4Hr zrKs9v$}Uj$M8m?!k}XA@3Y}VEEHeXJ5O}Ea3PJ|U{#uVKVxf)XmD6I5^ zw#;-zx?<|MW!?}!nJ7xp%_taT3ZIPDMY>=RgE3kZ*%K=i3>9&uV5nKvHwgNMlzz{$ zu3gZzFLW+&DcvE^t{E*e4WSKFeS%Es8bBv9#_)Heq!hWDH6V?aWGc}F#_*)9rem#< zVEDVJ{Y)C%dRJsLc0Hk-=NCw!62d`kD9TeX(;w-FP>6HBIBrQXHOoxBz|?;-b?-ZO zzmsCxLuXcCAt+k5BSlv%({%z}7avT~c9@Db&vZw+V^#>GEz`9ET^oNbMel-@Q7AA~ z3oK2lDzUn8sbI_GgvwaK{2Pdqttl9;BEL$>@d>l4K%RymuzJCMF<|HMhp|9c1=YL? z#19C@poZ7*_~{Az{CO}WqnFp_jI78vavnGksQB3>k21QSAsIvuL$Ao~P6g5vt?60A4d56Xbd2Lt~{(#Q%UpITypXW9NdbL8R&N${LEf%%K) zP=Km}Cu3eyM&L&(_%{|VqaCJrD&PFMe8{I>bI#cdR1XMztt9ZV#9foZ@S1`|H?QkBw|g-(DB2XAA4^m3$WSzjgSq16^T@m0om zcr9j)o{R5ztgBnrp?GGZHQ`L@T0nYcB%C$Yn_{XY=@y?%G^Ch)%giBxIkeceIGJKP zku0>%^hNq&idbigsa$621g0+TOq8XVz01r&fjPLiW6_ynj)UYgTcW39JCL(7%fyQl zii9OzCYW}ETSOHgtpoD3GeuP{Q#OIJeNuO?@or;^YDubEBx#7GpPaxflBD4`P=J3D zl0oH*op1PC1nGgB9gs&LMslTVruD&^ajV*7Jh+oDHJ*CS;04oFx%PL50iXlLOM|+g zKA@M$15;ZACbJHtM8<><8X)><7$(teggc|@;gIr&m&Ga830BS^RIq;y*9MV4?3jR zE+51cMA9KLnFl7{RA6E%AZj_cf65#1Z*W-(CBl3h=XM3;umm>`GyMX7{+p16h!sm= z=z@Gq!NrA_js0H~eOd8&#be#EQ0EJJ>cIwBYIt0KQBc){PKOUfyW_TRXribI9*v3? zwM*p>r&EP}p^K|Dp)ZaGSJf($CUkbyOwa|P9(>Y(gPGS{#9_M7j2sKzG z3B*EKG)YTu^1_?$6uicXI=RGEqJEuo3Su==<8D!N68|GiRE`V>xH{;Q;u@MO?r}d@ zj7s1o-VRY6n40jqrTsy&1hYn@B&Irsd|vib0@EqU9$m+nkDcols#?(%g*E6&AkHneJ@lt{?)q9umNcaIA6ROIXBqt& zat*@2crw|0`Gp|dVtew^Ykw(-fAPd-qilheEsV1Hnm2-2ttntMf7nInw|%p_SfN;L zA{4sYWw**!6{kqDWz|R+&B(Z2HQ}n6C@7lSGrMQif~!`!T7avCaH*@C+-DTuuRDI#fy*n5c@^QUY6uN=d*IfMd1_uzc22xcK`qY literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_243114.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_243114.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42dc858fd1956e79bd9d63201dc44fdf436a9741 GIT binary patch literal 13580 zcmdTqX>41^neRQ`r$kaBby>1(owjUQvTVhUVq3QzSxOY2v6V73A0_Lg{7Csw9BxTLTOKQ{og@^v~`$^GJ%K zDa&aJEOsQ$o9p{#=9`)Cn)w%^vT3yvxK6qQ7pyM6Vw^; zOP`@ho|%-Kkuis z;NFcHfXO5&FL<(?mBXKcRkErNq^z1%E^2PU{BU?C2<88;FlfT$@PI+X?7h&-AY1uB!Gj# zI9PgC^}R4jc?mBaW{y$y%)leuWxYgT+)nsS2-49h_jC0!L3(*~ce5ayw7JH2iJvAx za(Q$Yj+$ID0NRYhUtVs!A-K@x0AGZgZb^PVKObIHufnozqvWP-Jy;qi*Mv)$U{+*5;PoO-(IOBUr4g$~`*kuo9@5DL_BY zVfLVapX(4rUTRsB7pV5PEvZY9y7YsN>%H^6adnlqGoj57?C~EB4kAtYl6nhLZ~0a8 zjn;+MxOzKOYV-k#?>X;Fj{yfIC_0Y5)Y)h4PbCKInH&tylB}T`;%k&Uy>?^i6~kD> zz|!fJQg6_bNw?JSYF;@D{bWGo*aIxhD^sU<7P|9jc@3|6COFj^I32GAj!pfDVX0xJ z7%=fNjD^>+l3AD&uEVur9qCE`1cZD8h`c8t6dOR~^IBdq4*h!!ypETWl#w@NjntSm zI#>;lDT8kEKO-7JnP9eyg?(+`rzr-P}984xCiX&V`!QnL$ z)YF`EgtJXrM_EB`<7`tS4naEMw6WxbbNJK=ieV^l?{2Q&BuJ*+oS=5uZ4(Zc6{l6R z4$dhi>3)G85R@+9B)~&1q*JbRYL?3fe4GIzl}(Sb)=BX*B{1A1{H6re1*>y-*yV7$ z1bwoXR9+|cl*-ES907s&o#MqVRnqkfVTjyt$1#{~Br)M9Sv`Uxz7m-<#s>GH1j`}>2Hac!Bm`+-)UkxQub z0gJ!lgXZh|=J$mL!-{z6E~MJ!?O4CUV0y1T;12Z8l?E>(Lv6S@Vz}8F>H4fP(u=mW zAw!$@NWxt7z9HBZw9Ji%4kL3zgo)VhD53?QU5HvxV>>dpdj}GQmiOskOHeU)HuM}S z+#cp4gEuclrJvcN3bgYODm>)vU)JRXD*Y$DJ&ANeV|cGJa6C}!KNvh2>iVcM)H`oQ zng(yzLzv|$|Ea*S;J!I)xFvGozHXnl=YdjvRpD0z)WIL1{OysdsQtdO!`q%JRt5V} zetkq5?Yyr%iqhuTku%*O zax!xRYjS;pw)d5nPP}XtTRy@vsGbqjQ!|rRmm6fOOHkqTB^x*GB1h}dA>2g5RlYreX@JN+GodmIu-EI zCh{);oa#HsKXV4B#Ec>L=cJ!CBHDTQCzl$tm5CM*wudzL~9^U*RlRy|1Ii}^CbVNGc)oA@$-*)$9Mjqr9e zUjSIU0BhL>V^IL_PtwR!&KL3Jq_>JGDX;L9!)Pi1Z$2)|z=PK%D}aWOoP0GhhD5&9no-qn6(8hky!i?2=ZfZB}Kx%Z`-FXR}OnWAAGQYSr&Pe2V|U$e{B z?%3vxUUH48nQzJ*6PyNPlJ|gW<=M(t@y0dW2A5Xx)#*}^i{R2~z9vU$O@MwkCZ!0U z$I|2GU&*I;|CM~Yg=fHO+P#+Vrm3mpTh?;URE=j(HViMr{2;qMFUO_$RCxAgP`W7V z0N^#+!`J58|NNd@UYfD8Co!Z>5x;j)m)b=%b%px>K~a#|q76;#V!Ug=CymRWq!6ZY zHg)Rtq7}wr-t!0+E2(DEf{b-vnRam4ycg7ti=%e8b;9OygL;S$3TMIPd8B+4yf4$v ziBbC%4zDa%h$Ac$9$}GGPZNOpd}6k7Aay7~YG+3;3X;juDXtXfWyC__ut6iph$+Y6 z6(&X;hczZa!aCi8%;k1+VCCU~!I&l-li(k+x~~Xoo7*`#YPZ^0w!T%6*@%HBJ12hZ z_`_9Te4BBU)^52@j9Ha5`)re|$?ROA#p57m(zQrkd4C-FAEAxiV5c>L5UORM@P;JupGDm`XgC|&FVToI_wq{ zWMDwiWCDgrRt$z_2Tz}C+|>%?S_0K0eB)9l{0 zR`5fE>6&HE5;O#l#trq}7awHhmldYKg|8Gv2~&YjnJ`=Ynq^Z_@F+6X;cR)R44HTM zG!IPXz-*`wSsEi{ank|s(S+6#+!Ja;rHx3t%iF!I&<6}ZeB*(!;9BimZLl(UDsHUy zDH4FHb?%^FxokAy9CWd)GX{o`uJTtMH;ybE3EKc))d8eCfDPgGYYIPLuB+x%p~}!u zyr>!Jn!P<3(aHDQgYIDOoHcX>74Aa1U6H-fs^8Ml+B*mDzKV7aAl-n7cliCmkTiH| z&KWi!^S1CPDr`l%)@aS$g5O@aYx%VC-VwC_1k#-Vw8pZK`6D{q5>_l!MA}j1p6DjD zxd$10@72TxPkeRibZqeSmyNNZQS{;{GLCxt5{8oC#b4UO1)-7ozl&6&iroO3JR+&ew*vM|wVMin7txJ1^ZGiM>4d#W-pijJ-N^zvMMv--F`P>y~*- zs3cPJnIYO0wcKgE+k4OUMIUN@DR%12{o=E}7av-RuU(qE6ucPj`E^r-jkMl;DLQhG z{&XBQ9ElCSa^G^ycQ{d4nx_8+R9F{oM+G~4-3d!YsA}%-mdx9L)QioxTW_|;&AXS( zdy#psINP|n!`GWAt_%&%ZwgnTqMb{YW@KspxcQUTJFRic^GlWk$a3H}+~>2O&c-b- z;-1Qav#78>d=?cv2R*>PEJ5b&;Wtp>eqYZ+*mX5?HL%6IL%ZX~T3_b_bJ4Ytxsl*- z=*p+&JMDJ|QEe|O@4Yu1JMrpQXNO`ZhQ4&hUbmqETZ|o#O*s+kL<7#)rP&pV>StbM zzNIM38OGO(_cO+93DtzR#*I6Doy$snpdnO@jN6cMn@^sAY#Vq;OjThOGBzP)6D~6a zwP6|+)rVh0mi2tNW-U4^&V8b2O0K6Rj`^}U%hhdUN7o8 zj-ES?498(q#^U!YgU5rl$hajm2pxr6BQ2j9BCTlq^HCexeh?WB`Q+Gnr1UET=Ah}i zdEOk7-_R^*mNxH2oA<_xTcg8&RCX_~Kb-Yzo_V0LQ$sQ9>*&<$F*}P+vBLM+1q-2K{W=&*}#} zKB-UU(+^8n>03&V#;<23#Ldy-#u!P+<(~2D{ZhYdSdwwo1qN7lfHCB7ZGfqq#O={e zRc+7@?uFqM8)M?-EJKVyId=Ydbndjfo!5|y5xB4e1)Jd|c;vhsFoH3V9awo@C~hT_tb(|(47~Q~y-dMXOU@R5n|{(**6l};rWccQ9IN!q zU7zlOSMn+zF7Ew$Ubm>al{5iqitI>ikb!x0jDp!>&wo~6v6~_3PJ~NP%>cI#T5$Vu zgtcjVvW|e^eToYo|5zFsXcS}(5pIiJq^~v=w!W0L?Q#5!4B#t+#tKJ5EFy* zVN_TP6y&1n&iQdh;+mPP&*NM;osX}K&%vegl_!PBXFZ6&~{PRC)CBq4=b&&zm8)Z4vB;=UYESwPXOi}__}bj(;w*HTL# zYU#Vz9&Z^))YjdYS(y3oQl#>ie2^A_KhQ)CpBH^vbhq*!OE+#%RDNAIuM64YMRnnt z+dFRVh|=+G`xDhossEb9=GsI>^`A?1vf7n0%1{(+dQaXj+{tP0;LWQ;bK4=d-3A`lWO|G0`*5`UWNM-t51H=sw39#W(6z$knx&D&MP)(p#V_~t%%st z4m|oF2*>Lgku@vE9gb;W>LX$YI}arj%Lhf_&=2KA=!#W)>!{aqKgG>{1`*uUb0gr>z@5o2TZgCd!>7wu+!6lZoYj}am`L|*RsAKFyo&J$&ju(EJ3>MOWJ0nZ3c_$TwJ?vNplEk4#hPcOPX$^ z>HeML_hVm-#WjNfr^&mz&%ZD5$~!N3yB_F_SI_#-1`WhYqJL2wYQNF5&=YR@@JPh4 zFz}n|&$oZN{hsX)=f6CEUv~=TkY6x&EN~%kY)*AeJE#5N#PzfDXXE+xF!4N7pgqt9 zbzT11s~*26S(jJuJ)F>)gLNU>{0^k6!+s1C>9aW0G+!3F0{!rj3#j(FIDQO}0GWs% zLTrg^Hvs`)tZn|bAQRV=db<+fN!SytifPJY(sCj^$SQI-U>&$HME2xyhotz@j0lf+ zfqkqeJjin$sRvj#Jj%mYC-SG;Cvfd3|Ep{uCe4blZ3ChpJET~(hjORw6keKQJ6+Uk z%P=}+Mx899E=Je~g>`E2=@&c)gMFEk&7}qGSXuzAjwOW?*jy$yDY@YKzBu8@n@DAz zJXV88aB8s6l)N_iTn$7f*leqLydFAuD7R4?upl;=DOM!1LTaX9Yl3|&UM&qX_;m%Z z|Jz^_g}sRL-LIt7yZ;H-O$x9K!;pWead@!%@D(~mtJnU`v&Lot&0PR=7Q=Ytn zt8e(6R#*Sb7+)W!=zT47ofxpw1?H@klqoNAvKI^gG~_ zWjwKw@e+7MndCJ?u)jTWkHVw$s4^Z+0x!ji3>V$v;{{P zUbHu9xH^GT1ve;)bIA)9?Kvyf6yRzNDg@dp=tSe)>YQ;;&%nhA7^X85Zr3`hLV-9e zC&xM9>W=6Yu4e9^pvMRBbK%nkub&MC;6hjGu4?=myn6yYp}pdcxvxCvZGTEl&F)D1 z-EH@4`jK?2x7*hW4*1G%wN&9wxP&N-D(;run~fKo@%FE1DU~_kS&>Qe8Yw8U)Z(F#OBOH5?AT*U5^M-&JJ?w@giNn}EAL0ZTY=f86T-g?f!`!zwCN{KH* zC1ay*^7vCw!HYGQ+i`iCBaa*frESLTbk9tIX4XmGvJvZ_EQ%mM?wE1`m2N{9;#y&D zLI@>jL=sy^oD=L#N)+~3E(*%FNhdoq;W)tk0kXh!D3=7b*@}dw=|550f2VZ+L6xn@ zDMtBr$(tqLFcP}pv4mo@D>wxt5zAMY656<2U3+6_VQ57LX@IDy!?|x1kV`_~^8FNJ z@S5MQc(Y?lYgV9R7Kw!R1C9H zkS4?+P48OKQk${I-4n97<+wjh-)Bv~_6R(`#0zHP-qvPD$srt1y! z4c`>wETQD%LfoYp=f646Y^EE(?x9U|Mc~A>p}CX>ozKHYg}LTriN^i{{>pThTZ@G literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_291697.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_291697.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cccee613bcabb0a6f47c0f0ff37a44cc60d6f936 GIT binary patch literal 12042 zcmcgyYfKwimag)%{KD7<^9TVG2w)%(2!V7$lO`ca6B4=!c@kQd<1%0#Hk3_BJmqa? zrhD7&ZV5Bfqd2{@iqmQ-PIsiHM_?R+@XJ#2e+$3m% zCZ~xJ;Y*H?I4_xzj>t$tLO97uEJ%W!=qYU$+>)n!=N9ypTa*hu$=k-5maYOl4D>GA@&P$wg!OF9TE}#qwAMU93JgAZ3*Ju% z5?&URqR@a)qmDD@C0@r z!Ybj_?)e#;*JmGhdU+K}FW8w`?~@>E3QejfX#aSECXZFhlc#xdkS7O#PwDWu$6VuH zUOMHx&LGVg6?&oViqW7K+Mpbz@Dd+Hgrq%dRHB{FhkAma$;|8(^XN;H_a$FvC<3M1 z&`NV?%W`N@A2&U$oHdE7B~ZQ%t%5De((Yi*><+)2AOaO^DO=$OodY}Bovay`Rr;V$ z!o!xcm3}Y-YtHgkv6aA6DwwNu!de^dLY%xylW_NGfda>9eETxVm*j?bShz}(+F(yIljirz*whon| z@=U1;l_DO@SI?3H#ReZOLv~=*n7d{Zx8?=zh8!*P_(s;m7L7@KS<6l4@ss!EVm|^_ zpI#)T$sEb~W{#4r75T((9}ywn3&RLDvUO}DTPtcBXkwc-^+Gv?M!K2Z+&cv>-;>jJ zo}HSlXPaNLwH$H!VU5s=`%2Nz2U@m~)Buw9U@kI;tMw&)jd*76-A0;}-OCn;y?&{0 z`fQEZX0E1f+sKxI?82;sO?J!77(=#RY$3N-?QHv`h_}RHhHuX7$m4^uyH5LNn?Ree zW@$Xf1N%1j_J!IGww2xIhaHhtAzKIbv+Zn0PQH`fzl~;j=B=n1+xhyucxP%uRtt2o z)|?SA=NKrHWLwxSai;rqz+r`XSj!5l8RhHQ2DUbL?P98i2>I!!GU7G&v;(iXr@g^S z*aO0@liiK9+1c^Nrq!quYXS#f;==2S-OPd4DIUv1YzZx$Y5*6P>uYj7ooK&->jJH|f-<~p)4PeafaPTAyzhj;OEfr`kWWR2nwDZ)g-7)2)ZC7~NxQo8VD}+Je zwV46IwJ@j?t}!1^`N1eSg+?5B6zJoh<;g3&5-S3h&m=%mJj|dBZq?!KhT*jQ1<(J! z+q&Rkro5oBvvmR#v->7I?mYli-M&^Y<7o9cXJ%U=*9uy7*k z_Xu*Q!>6NjE86PdvB!nRo603+l#F(#3LAp`k2dAgB{eY@SJIl46{Yp2WJx`zuMfVR zHkL<6Iioc=@W^D2T#p^%N?Mk+Dbt(5Q)z8Uv?JEc?P%e&t-)hyogrCN%js(46`Xc| z@YuS7C@@7vIbHRaNAC9D>5to&D^j&RoUSK$JgqZsQM0Bii=N6vwf|fHiuy?K=wr~k?XxR4$Ct*V_LwLu9(JB03lAyyfQ# zts!~ZV2=1cABrq2RmQ40LrqALHk3qVpWCC#rK(sPXQ&M+9>dU_2%m@?kD240zpP$9 zyIgfx#nKk%+-2SVgyQ}NBNrLJgaT(NAvr%%kSn0iAL^t-S)a`yAK$c6CV z#L78+Q%LsMV7ghdR1z&um+g$Hxw6KmQnI{ht*k00`@$Ym-l~cpH;q4cIO1YnHm0@2jEM!X5o`8)xZA zXt=7QAj8r}f)2Zz;^*(YlW2Rcl$Q2{wCf{nXs$_3SySB*=u{o}= z0nAfW9_{$T5baueCuZl0>f>#ksX0Wg8OoyPIK!@(LolMTGSA=9eJ~W1=2K?zYWQlz9yNTza@3BH zbd4&A)J5OmjI|sEwogO0Gdi?%C}!o1O&rzqv10u&QP8-4ge3Guk!h}A*OzCI`1ss% zb87ctuHbO+1Sqgo4a|uP!ZhTHcf&j^ZsT-q%O?_jzpGB1;5zyqT<1F8<#g}DJS-^w zsXba08RrTrV+Lp^wh;IIawxvQ)psSTx%xw#t~+>qO<54B{xSQAQZG_rDq@bBVx51a z8XrSm7uH3_qK>%Yk5ub3Sos96TyVO^CwyK8ZCAX4ar)*Nw*xs`@j?C;pIb}eeUo5` z$?VBhBq0d#{7#by8F(cpi64_dnJXtlJPKt2Nk9sCg~WJLlEZ<@P`wOz0i$hxSQ(t1 zq#~XO$XWSh_T0HaOH1#|M68HgQv{T(l2st=N6Yl(7J!f!5SC`5}%-1vHSC=H#`k zMmQs|TH#A;{4me_Fs=a|SYXG?+GTsT^Ly&5UargSb)bMPptYiA!Z|&l$7LDqGG&FV zKBHgZOZpjD9ZRu>G3hX2)yGE2+i^H}-^&vrCh5cqJp$1MkN z-g4DGJ@52dW}S=$BT5V7S!lr`&fsO}wKU8LGMX&{PHDDaoMN>s0K9c#Qj2AeB|~-4 za(Q5(*)sT^Anw0@^Xb5X6*3y%54~Ewj8i#nFNUkEtbm&wqZo(#mpG@X321a zK=l%tpt60;2kMI*Yd9b%xX~w}JqAs74C66;2jniLT{9Wn-ttx^v3~d-$t!5*RhL7= z^lOJ16r~)*P`<@l$RJXL&+M~KUUt>xT;P><0Pi#QYrGtgxMPBsW83fwpVRI2FpLW^ zsOH_?xp}A459l7c=K*Z)n6}T%GO$?@j1|iS4D6-mhRL28s1TLTg$)6h8FHo6+7s zXm%!LJCmB7X^jyfyc>PN-r!tXqYe4PJ<*{*XeyJk%A}@p&0vg7E>#6jruF7UHq1s( z#?{NcDSam(In$0+Qw?XTiIYiFP0G}G<5ciiXg`8t=fdwrX->N5-92^Z)UqjY?zf}gjjq(5e0VISJ%f;*F?1Z51_7F>f%~Ry$rc+*71wj>`grX( zE%#ayhE(GrPJJlh;MBdVs<%1S+bPw_RceT%hLUH7ljkm^sEe!AWsbVMGHPEP9p^^J zQ`AJTca2glD#D7;T%`IFZLl}3)rIUKI${XX;S;Fgg3w&Z7pV^U!WY0%f=_mYr=nNm z(pxN`x=L`o?ZM*+Ar&q9!~Rd$RecqwuS)5wSM_zAzAiQw-<8tu0b3N7Mom$Bl#Usq zbhL~stjSx}pDwD1dY1;{s^y~qzxFOF6 zdm!+#gbIMtjC|s`X8Yc^oHKa$0UFY@uOMCZgQTNAWpvN6Z=PE+uy0=5Li{G<%LTn@MaDB~z;g>|4f$}4<-lVT;Ky_N zN6T0t*7VnyA>gy~AtE$0L>BGzCBrNm@ zt#4L>_I5xT8vkAFoco-6#Qop0U>Wl;7VzUv#^pe`!bRJCE{|J46APPwfruuc_(j*B zffp~83wQ#zDjX2J77P+Bv{O`p&Y&HDAP7VVv}ofHygqW`*oWVfV2{ioTsJ7d0)To1 zgZvyb0WYhDL5o1}%rhvp8(zE|f*ro=8@x;ns+vMo3g@+1JGgEov%QAQSv*?0v-Fh#jY6>2Cq}DGsg_|OC(dx)pZ0PRAI~SMF{QBaz7gy9h!K3Ly zHIv?oxNCxZ#c_ZvR3bF-CeLeUZ6H-{;kl z6F=Rtd?rCAsuR+G7)mI<9eF|K>oPuvo)<1QhnpjFF~i-`JEbe?#;@Aqbi9jG?}Kh= z^^vyl+u+;HWzpWH@!&gYePN^+I#n9}5d>8Zb5-q}e&4bej`*uO6g~RQ{BHx_1yZ`x zaGo}n-fUWGf>4YhWvmGffRzi17B7b{N6TYFaWbWAT-D*jcw=JnVRcG(YE^fZ)16J8 zJDrMm=L zYRpkH1f^oTIZb_1R*#)OPhR7d84SNMZyi~82(t&phSBEr>?+Knk5nJ2w>yYuj^O0? z9oDxyrhlvgctLJ?h*x~9Mq8rb1U5iLj+@B?9Kx!w8oTj)Ena!Q4Hd17v@8-rXz#@uwl_KBhqK+X7T5e z?f%;khR6_2q9GH#cFsLq25UfN_6Y;(6OgM&kNqJaPW&Gi38*F;%(=Kgp&GaeQV20{ zrW$3XyzKO`v%MR)Enc*mEWx0#V^iI58oWZ=d;CqA5c!5@!`ryc+uqU*cMkSxZx18{ z;@^V^Zx_7&J4igD3KlhC4Qz+L=pa|pl%ksNU;XAs_kNUUOVAH`xSlhq)?tnsPE*>@ zrLg{~Tw12Ui$CX7Oz}FG7yJuE560&`^IpLbVjuXo&`OT;y8>CbyTDgP;4nW!wLeF% z7rmZvZ2cG$d**$!^Ki5shwG&rGeDea!;8K3mSB#uEsen32Ks-5pVt60;D5lMsR^xa z@o@O?il!ndt4L}p;DouVv2Ys8ie^_*wkxUG1x{+QE8KNMICH|G6fPAq?oXqO7$Wv1 z^T*vuSxKlH;^TPxt&R>w&2jP@>K>IKf2;gXxuSb3c>K9Us#OI0o~sGDF;dIPO45`n z)D=070u6oT#71-cHQo|)f2By?Pvx* z91*prsM2WN?LD{lz@A^;xMi_!Q5)7qhCb0piv#wGkPGn5+V}T?%4;8HW{6({fhrA|ckN5w;Bkl_01xPuub40k^Rj zlWBZTxDs^Oy-r)3RnDM^jxSqjcjof7trG_|@T9Y9g;|R{p@8I&3}zU;M$u~oUcB;( zO^Cf{uh?*OW!8h@E2NFrWv;|B>i~hp`Dvfmi*}1#7$>ZNB0)6)>WMo5^A*tk34Y$| z@Cp)ZhN9rf$CM5){Glibw=tc+D*pA(Z+EUxhl9P}Q&YEZx%WZiO5JHrRs%tlF1RVG zey$~oTi}wbIH7z{@z9?t8VQ~Tw5T#i0_$>_UJ*QjBRauBJjuX@!oSiOWc4QmWEd0} zwH9GJGEqqLdZEuP({Qgn4;P0#C8ih@ALS|gywBsCcRP8-QT%^4UOMUW@haGY;Li2h zEQ7BIdBs^Fl!mr-EJt#h;rARq6BAQi)`}osdu;mER~|mkg2Q z-gP5UU2|7^N4svq*&;DpjI(AjTY|HtVzvxt%ZZB0+ugUi|GWcbpGgUc_SrtvMhD72 zbxSN{eRMyfSSKL4Jd8gN_CGBCw~BwMNM0CSz2M+3IFioERp%V%oJ;o4aYvbFi1z7m zvXX3wJDw5nSwBeZl6{Im%yjtDPRkyDbkV)4!-+lki`|DCjuhliJ)^%}pUF&kl z!GQqsyF0nvo$E@Jr--JObtC{%Z}hAl=;sdfKSTLvC&_Zs^t7AM*Q5$+Z)or=|37@+ Bj#&Tz literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_298484.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_298484.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdc89a8a5d4af30c5824b6e5941c08771b014c55 GIT binary patch literal 15028 zcmcIrX>1!umhL8-XY(c{>as2Bv<_RAWm~r6*pV$uzGL}HEIZ-SY>K9CH7UtfniG1E zLC7Q+QJL9s+sSO`Y!;y>#*8w%z$$|Q;>n*4uz%R0U=h8VT>)!g;U5dg>tL^6`>L9k zWJ=0`ZGhc%y*ggKs`u)9#edZ6wFEpTZ~f-_f7TPkzoJOyF{T60exe|V+XO=}6q50lz{Fq+R-;uji;rZ}tdX;!i3@c|4#XL$c!RS6F8T|u8X8vHl zP0yFq5E_E1V~ifE-|~DeV_GBKyhgfZjdW|_^gbXSIg8hMINHRxImAJ zA-$wPx_yoEC2N!~_2`+h!Zdw4ru;2MuVGDz*MMwQd?;h+?t-g#x+X zC$&#r3f1>183)Xpm8mLB%ODvk&=OOg6LRo#&7D9(A5x5j(RYc2d-JpL_Y@*>uS# zXfJM%iQ3`uA*Vm!@_PKD28AbG?5O`qC2D~}FBvK)7s=Bic}65pc0pL>9vSZQdHte% zz;jbHjQc%K_c-HndRf=#byfugEb1JqftRQ{fA-Yz3(k`)QW5U_^@hrzUOfA)66u%# zj7)m^QJbqChI}CVEL|d4vO!sSUfI&TvS>i79!|koKTEe4EZd;0lCx#YmUDKld0;3jSpXUA!q1Beay;mWwpQ!seWo~7NwZ3dbxQ!wr+C=;y2H9!o~P(2Q= zlCz?E^7@0gwd!HWT;Fj$Rh)yfeJ8r?I51r`SIQNAC%U4{7%*K8M{reO3lKk(F+`sX ztsG8`>fvfoDoRhMYEUYo!F078DV504F%%j(r^VFOtElM})OC4WDPVVSR?gNZ3uJlK zugeCrfd`7rID++ojI9GUPF~H0WJSdlamA=F8S2cplZcY99fz6Q%59Z6DV5`D@OT@z z2F`+Kwh=t?re4O4ID(wsX63G{N26 zZi$X-L;FN<4_BWz)&jN^*Uat7NZynlDC$94(vrzZD-EUAM#=Y1Q3O{XZ2y7#$=)Bh zpX}rIatbWv`?-Cb5{F*kWZd4B{NQxi8Q8ayyPQ6g6Fk7}Uq$u8%Cd;>!56t=Mn13= zlurJBZ_6(DY=Aq!wR0~%&}LQ=N!-leWeeDeg(w>6Gb}7(qJb1C*A1Wl3DQg~OgM3$ z6ctS1=BVf2GELr>iPVtGKOibb0@pn(i?*95=yjiDx4@@CC&~hYqIQ%W0R=kb^f4mk zW_%N(e8@K}DqXB=*y|A$QyzB2FRJtj4l*HA7vkPTkT3l3*f z-MO;vVbPebL2AV5U7e6CfrsJr?hp-W5-9;H_Oc`6qkbR6S6UFP9d#6K^rCzen4^G~ zb4oH+P!KwT>=qOp7EPI;Zy3xGs8Cen58^7_=VSc=TyRL#Ny(@^r~%iR{s7NM`uhBk zDt$BY0hkNtXn^&L%4^b>cDdc-L*s+MaaJT@f(M+VFh>e6w01*OqDglSh@?wYqY6C% z=QUAv&FQ{AK0F}mFsqzW25=S4j&s5_C@QB$@mJ-)?&}M{kd(gR3Fp+Pm`;FZCY`?F zfJjX``$k35Evj&LVM5Y@VNr#fMk!JNDskg3V}bJ@1YTDDsONvrwoHz&1O8E$+tYI0 zL#_a@Y5V9+CnPyR*a9QNPWK3i+|GS_ zckSH;Iig-F8 z=Lj8rY_i-|&8T>iKaezS4IN)x6;Ii#V?M#&lAvrUqa{&XCm8EO$5NIu{;FVU2%TKC z7DsQ!UJ&ftXLU*Ii=k5~otbOqO?Ft=T(p#ce1-;{w z!=E0%cRcQztw`1!6!Zr}U8x_Cv!pjf>+YKQPToG#5<4zfw#_PM1Apk5o&4PRutC^y zQm~v9^rukehUlKV4*mjPKXWMN60G&{eS&4z?8F@NRmYt7^EbaP723}UmUDvse5mWO z#ddq^%vSzXyg5eF@8xfwS^Tab1}b9Fx7+=OAs>8l!ldv{oRr7XjiN_PW{FeSKZq(C!g0H z2$M^;(sU2cKo4!Ght}tMXgwEJFGBa~XX<$efBjB#{L=3`W&^X`_ni+r<}F=e)sns9 zZcmKjUzzzy+$h-F=jtBXz8ZUI|DxsVmxcXT1pAe+ZqZhJyL+aa?}}X&Dq821e;9jM zGdKB#@oWFQ?R;4C*lfLRpRw~rv7(R1;x)0!dz)ul1V@KZ+VSwx*JFRWk~nwin^!L< z&RrHddlRn>2%Q5$=|FCZdt%=Ja8ow!+_lA{A))KyTX5U>`%*J~KOI_S7SQ??I zxrFZ)%+)bJFdO)6D&jSQ$r0;9QEEwT{Uv*QVrBv+$@Ysk1$AARPN@yi=9pcuY{fG# zsCT`iT;4_)_Al=sG`3L3lEE4s6bx0Locr|hz02{j+2&;JA;EAc)SX%vS3L~I3*MpE$ffX^NQEpDe;Kr^Sg+g% z1f%3s9}vIJVjs{uMW`hJHl_6R=Z`#{7NH}$J{i=j&Xmf1A+^8>DbJQy43}__{g{)+pXA~Wnr?2 zDSbz!-vil9CF2{aoGOt>JY(ao!Afa`RwicpbuDfoCa<3-<2pM(y z=!K4)ne9g#RAnf{x=}&(fkD^MHO6&tswr*rFo?M@9bk)Md&$3Z`+>m`FbVw!As`hV z08;-^c>Q;XSX3LP=?ER|hz9sGg1srJZhkQFyEpE?G1ofBJUl2IJeO=aFR0His_Ag& z^s&gX=;auh)K$eglIpsr3VA7Pq)^in#D)fqsQ3B78q0_S93;>LiVCo5-Pc70SaN}z z><|>i)69;d5Hio;Oo7i04SKc@rFqc{?cMBk^zy+AOlT|8wF{;iw^@j(B7D zi&{z8oFHs@977vRo{9R>OA_x&*W?)cCQ!KG=SR}!uGy4$adQ32Ikj^R6Vk z>57!jiVVkFM?;6f+EHu6JrQN7Bc;|%t0Jl>8PP%n2-_p=(bmX*-u%v?1a+Ptig*03 z`+oOq$NdxW1Bv#R6K8wA_9rUNC#mz{Ln)O$^piz%@$K@Na?ojJ{t}cvcK_I1$Cuq- zbkCbRBg!RX(KHv~Vq~n9pZLaDn^1Qq>N}J0Hr9SClk1eB&TrL(r6}41Hf--5W751e zbRt!3=UdCWHDk96oH} z%p69a@+b=6d4Qi9GfN(6rl6Ixa2Pe(kj}J%&c<0+VHDH=(s4zcZ6#KbMn#G-S_9(> z+MlCTV(nP6I-QYYd{l+Dvgc~WSQH*XKKfC>Rsg)V6qV!bfTvR#3?x_v_%p_ho~s$- zvAE{)Jf0M)nFKUks#%d&Gs4C(dJ8q<5v<}z76A%w+Ms?Jx2ZtSH#5rM7T}fqd9)3% zbl??W>!@Yu@v4@kTnC`@l^BatVH~Lnkat!K(OTK#Ea!A&5 zyv3*~9#aix&-uC$BGsUg@FlE}CH=e^n9FWLFN7Vj4)m&mSA&ABM&TOtszA&A9a3yFC;X6>M)d(3%i zI*n3%05$v!e*PpZWkm z;l%H@pz) zUZQoO&Xf@yL8tDv@(kZD7^-88U}%Jssg%s``q-SGyri{6y@GaA%qnPVLWfgYbMzHK zTM?5Bn(ENuC4(utGveh-0H_IdEmF$2kKH;JekD2~9m8kdjF$@boq~Sn?4`NRuN-sT zpF1Dk6n38#^k)<2UQIaNk0=JKH1VPDzkyxEU*-Kr;ZnhMG|S$_e| zgioLHAjQc!67OAP_9Q0D(8AdeZjLG`C+C#dDL{ps^ke(iy*_prZ!Hz89KoMy8p;X-#!&`cox!Kknn zr}Ape_lTz(h674jejlVTXo$3=1nB@$1o(mZP1{GE#A_b|$90@GV~ZgFE2oogH1s1^ z!>T7ENkM(CKCp@O{Ybg2dUAFcll4MEi(vBy4V<1ca0VC~(3BNCT;+|oYE{J$&5R&u z~3iS)*AbZGMw*@t)`G2fEe^u748URE&&HDU% zKxbE_IVU+QB=Ik>Bx)o{MBGKk_n;+L8sto|LXwU`z5(p-JV^>#88T=KG-P@Rw;gh_ z4l;^&$)aE}XXA?dn=++x-;8{NOaUFEl|1uUc3>?BZSPgXOxxhBfMhpc3c5iGu0fu5 zy}cXef1RD3Gq0E1CkI$TC)s#Yl_#@c3Ap`ye&4l0Pvy0nmD2ebImFOzZ~-ps$l;7E zss=osQFK^Ft5oti%A}JyKBJ?(Q8E$BA=FSTk^xbhJKeJrkSbCWKF=hJrd~SEPohKx z_{)aDzvc$SVHA9lBO(Q8#9r(&6crF`+bO{Sji@wuB3*%iXIP|=4^32zqq1O^jg5Oe z5Z6k6y+PlAN7Q)0E9YTd0got$QVn(qFFVIM;IzlERwUqDI#S%8DKJR*Nrv7jyy)FnxE8Gr^x6EdJ~9Sa=| zx29-Qbe}+1#K>=GM}j(~A)lb%0%jc4p^PZQ6X1AqB-J$%?65nsI|^=( z@a`mC3+@woE0{!pMiffcgQXX_!XFlNo54Y)Glh3YcSl|fbuQAT=@%j|L<32>Vu7v_ z=qi8@t|jS)Q0HTVb$TE&@L}yoO?R4Nu6XS|-)xOgwZ3*gm1GdDbkH?wsxTsCf~gVw8A(%fSg{Cx5Ho-DqZ4;d#Kt~29dCVP-4;Fqevhb| zujXBMY7?f)1yikHs{O?C>A<~#q-kfkBc;Tc~?UXU>}K3p$6ObHwTZC($)8=(d9gFsa)CJf@A)2Oovz=L>>=d)q;MLpx?wZQZMH9dsgtp_I?j9k6xP5Wq8pQ>v&|X z10GoHyqrHDGvB$GuvDYCHGE{TtU7k&ueQ%9qhz!#WwP?}k5qS5A82An?ik}ev*!C3 z1$!H~=w{^0a-zIpsj7DMkZHnJFAdp3=&TF$CV}3Rgq{~II9dfqYtqp+dud^Jm$17l zx%*gxK9-=4VUt%hNGr@K;l0we0e8Nknl*tR!IHoxFszT&58j|DFi$B4n?GPy7IdNn zRp5lluRUE6I2cen_$lb#P;v@zeW*DVT=>fva+&}bp3<{IC3V78Rf)jpbb1;xZnZ>Bqf+OFt&ke z=>pZF0cU`Jd@$fgp47rxQBsPu z)Hrrgun8!<2tSOrhlphzVcxP}s)1_)w~mLpV6z5b+Kkb3effgEO3+uqR;#m2pNgD< z`x-{u^c#^k_zr%I-xDvIoxsQex|gu$=dYnl2%Gt_j~#pfUpfF}Jv#QQ5~$Kv3h=t) zjz_Q|@4B@q+?lFv`1F-~uf#9@_3Pnd0G^v?F7iDymp>+>SHIC!Jyj5vGT8nM<JuU?Iksj_n32T;31C~JzUQd?VQ$@|)2lZ4jGi^p_NE#S3R&9 zlHkeBzbji{6Mi7iZmrY?^^qLh_x&FADRLVRoQ&jv8>f4D75WiW230{docptw1xLZQ zseqO$G_TnBp&5s}xhDJ)&(k_=PbH78nOaQ6$=wn#cf zefow|+V`bOuP;A>67m@RvX1y%0Ksg$!Ie=m3KELI@)cp8>w2Z-bmgsLm4t zP95w&qXawVEjt2X(J1j@`ykw49EZDMqB;}8khQ2hg8xS#8sK<>{!fs;Cnl<0N9uSHs#C8u{}khxFbflKqFJ}@-Gj1PJavi}Z+VM+-9D2#nsPLkxm z61smP^nWHQmX(CU66$*U^sUpewmJFN=Mt`Kg2kN}9ZLj%GEcpkqRi>s!|Ud*Pun(U z`4Kg*PrH8Z`XhRtcG;@6$U=tY5<+GU6~DdZ)|Qm6OjO%b){<{kluWj)C1mQi%Wjn| zml2ybryNyJ%ezUkdD%i(i*7g1G%s6m)Ru`B;b?ItYRAzM0-)rpk*mw4I9is@{SRE# zGdUsCJ)==uI+XqNoUEL*Aj@~B;MloLKxFRXm(DMoIGDI}W#N)nxa5^$V-x=cav@#) zLLZ#EOb{>2&Sk^>vJn}S>X*HNE{LJeGgRsFft)1lob?Nq7rrJF7cb3IujFK4?Q~YK z?0-0xxNvcvx|Cl&KYdxFu*yP4&<;7?mZvnjusf`Mp1-m>(nCVoG+&K4YRa(6jH3Ch zLh+|($$IiQ8NGP>>de(;0)^vl={tAyVeLOO|EW1~_Ts|XtHRlg7FgCR9e316a zGn!TGWLB8jJ;lte6whi-<7^^@8Lc$zuOp+ClD#`=T{s%7zMO=#^C$nvW>z}wpSxGx zZsQOWX6E*Ex<+(Yy?XWP)vK!aRsL7KUQ58!I{DknXG#g;&$uEz#&qH3AsIp3A!vdo zT|}SwCHqLBE*q2gDM*4LM+7=jSMf{B7{v(nDh$_X`5JJ=8gS(ra8+*jpq5sDsJNdV zxKB+N(i(;s$WAS(K+M{QF za`v>n=XhTSPrY;UCx<(Qvi;O6wXRocoxHl|^vUDB_D-iYYZ395lf(bZ**v^KeWcVj z0vmSqLzDN@?E&jLn6eD!hAd3QYKjEGDY=6C>9GRE>)@)m!VGQ+SIm`6DhVR6k+X6e zCxJ$wlq=1!BXkt)?j~S;CW>SHn7@mga2~vV~p#BabSNd(s z;xC6%;|jRKL76Wjpj-;>fn1tFpvotyQBa>*lymJSHCM!0FfUSX>Dx|3WS7pss&0Qx z63S6n?Q0}Uk%TzgEP1GLTcA2?)Ep|1tK_O_#RKJie3IY^*tiPLEbO5j>-3a!t0ifa zB+2bqwPFUTC3CiRu0vh%8g-LI<2CB}QcH$<%{tV#y|y<5X{9B(dR{6l$ycOQ%hiq) zh!V?wXYGr^Y`9p?a0k6<)gb;ihHT%|F%omeJZjc^ur zarJ*EN)1_Z&2>78XyFw8pfl2#Mbq-S)UsN0>!Z27X2|W9bYGx}YorNW7u!UQ7}&#A zWzBfQ>$2sVxIGfDlX}>RRiF>jT8I<%DTY>Z+cHo=hYGlGh9sy^uZYM#I;$W!3oRRM z2OD$qM_QP@KhnbN<7C`kNj7PlKuMZ`+lO1Sw4+{X3G9EZSGHDN-A~X3f#xi^mgJD@ zey)mZexQ|BnIwL+z5V|t4G1GvKU@xJ*mPYiC{^M)_R4G|ucUp~$C)9it4+o$#vI;J zUg_~&X4vQ0iay7O?K!rb&v8#1$?EXO$jf{#UOUcuhFHg#eVFDIPI~wXFCQCr^C}1H za1SxOa*|;^US2)u@;H3E8+oM&RzQ`8{vKXA>=xixVKj!9(;gqMb@)7E!%n+{rg^2; z=V2LJ3$Jw8$9*iX91uTF;iL6B9WKTzl=FmJq zB_I|CU~L4*D6A2dmX}uvGWGIuhttWE0$kVvq=Pq#sP+s4Dp3$Dwii4(z{>`9!G8m< zrI`u8!|k&}eWl#M8wczxZfZwm}r_M@M_ovy#?VMo2=4kuH8+2I^@_%3_gwT>ai?W^~)&Uzm+HeL_4df1Rn z_A!TV%p^rih-O_Qdm$rD67{ zh=`in0kbKC`K^|kst#92PR42y1@*sUznc7F@~QTn;Nfq9M8mHJ?hH*2MIEuySv4x( zm$d9()HjF7lxf2!rf7TAicICPHe}kiV0t@VHCy=Q#H{sR{ahQ`@pedsm*!~rXyi!D z68Ao(c8AENlCszdR8ki*J~I{Eu})i~C9$D6{bgHx2vzN!+k&c&q2gmt4lZ2mPrNsf zxHy11orys|>hz;xf6{aXXr>o5FLpHE{$+XmD5`9l{TZt4KpQ%q7#H5Xkho}1ylY3t zewuJjpkotg!$dNl4Jp4};pMkvyeY2!bjxf5DsNgSKQ(8aFZ_C9-in%!KWRhFrxIu1 zeM((`Ia2zm_Hg^J%y%|SZ@AqB+#8IM%JAiAdHB7MEM?4(w1uxk3&Wn63F)_nWJ^WX zJ7=fQ{<`sQ%S=nMs48wjMfJ1WP|?1SHuZ+KZ%az=R?k$&P69(P8CRm>@V_D>7Hdn~}K*X_`VRI4W;|Ax-ry)$-efp=J3HNob4_(+4+F zW^1$t4($W;!{U3zvkkMZr0p;=AI5B~*`q022+}_CSr*e#G1a*RKEo630CDAuU)ZcyozOfe%9^8u7~ts zH7jJGIC$S7IASD|gOX4nnok$S7{n_(L3}_0W;UHPMnAv;vVfeEbEFWp$Z*t>?tx*r zT>*Wu1z+`m@2lWQ$=?hpIpqlEZq-ARwETfW^7v^Xn-Wj~F2tUzpPaaAPR*%sJVGlc z!TSnmIF%0PV{kbSa#cc`SRcXO*{Vm9#RxVF$q|=|5~qb#c{6XYVmV^32b%lobqr9xjnYUQDN>Z? z(@tF43$oR7dVm9dz`*IG7zal)oIy-Iis~wRbn3e_A@})596L{6TpSVWYZO zygh`Fq8EY(IXw*i+bF|~Lt#i{vhEOzw;gMM*F71F(-*voX08l785ZwG9H-U&a$OftL{+)&@}kl*9?+Dz;@jBp+!C|J)`FT^e$mzURKk6TbaUS4xe zDk-noaZH0hm^ zhbmJ#V`K-?mBz>gU3o%Lp3s%2OhwTWWU2_BSTg1dHja*-h-;I^-QZDDg;OWPC!^hQ zMd)Obs*O7kwF^AQHNfC$oef`#(umrMkxZezk&f_zVEa;GY1EGjE8|L3P`e1G(%h6Z z?_D&vB6I6p_q;4=KD=l?g3L$$;D3DM(T$|JCvomPGM^82rZ!Z>95cp6>rQ0d`Jm%r z=e^FPbw9Yh;MGlC3}1{EMknJfbEbO-|FG@x&PO|;-{QIV(7E@X>g~ZJ&w!zpm}&ap zH~I>2guw494>Mr6^_xC%fcZQe9Y*@^>G@~?8ER)0$j~$g zE>X)1IjL`5mJ@l#TN59bN4i5-ZyA@BgrRcTKo|<9F2Rf&llskzdK=Q)lKP$CwNm+l z*v#O}FH+kPwLMAMKq5NR)LY@VBED!}d!Z?gzk7kahyF1E6KLw zWtI=vCssBD?9*}Au+Ki~#1W$O;HePGc(? zKaR|0h-u+APgJ;q309wg^sl7i-ekUSX~CSza?!-b{gy7E_GOIsNVbI2k`8ObI=EiI zak}t~~|e+8|EKUkj!dH2&-nQ(EV0RyH@8!*8TVe2+o(>^@LxV^(3ciGsm7h-(p<+1_B zca>q>Wo=~+H(e&_XS(bi7B2M&XSVl=Hy6I4co?G2(5kYjt@b?dClI1rhlsCbSz~{~08nX$(`k zurAUDO0Wx5dQwyK;L5`r_ioHJ%+d1)(1A0_`W~d|NokBPl=5Pgpb(!Qg3iy&;AK;; znJ}~E3C;m)2Q-$37q5hrjqf^(7gR9#Vz$vPWMX)gn3uCrg0^L^V5B4Xh0U*E`2_*6 z+o8&ncAj+diXq4rt_aABuXzL;{ZsrtbV!s@@tTrhe-0f4sbUrRw;-{sAs{wiWvz9F zseR#nx0|Entpf=~ap(ZH8ZF_L+pW=t1zl-EQJT<|rgZsJ&Ee+TZ$--&bQ=?jjS1aG zZ26&Qq#@iAEx21eQygoUDUFj5ATH=^3B`rP&I?I+>uhPn_uGSQ!3j*gHQXvzzSZc3 z)#;9skt3gN`@Htk+Bx!{HNV$9)wBiMQfh+`uxj$A)M52Ud4E;NDOIrjJ1wC!M=FuR zn$qNjc1K7Y11oX4A>@G16prCcw{?;3JMT`v8+H8hVoIZ%(uTE>hTHmx_s-<>WUTy` zfi*Ar&Kjuh}og{esQZDX|PZs|;E%<+#~QU#kJJrNt2WmEaq zzgj7!YPkgD8$2fLA!kN+!y@2ADA>*K;0XPnTg;qxPzX8|IA9d*-5M5Aym*)6L!K3s zruzg`Z?nyP-3O_SeGOEaJbzH(SHPEN~K0$K=1W#0oU+s-K^+XaEvKS6k<1+^wz z6LH$~C7KRGNip=IAiPrIYuj8_;LY=sMu9L5Y&WM#k3orJV5xF>t6;EoIQ0q)2WJqn z$#6giKjMJ>ouh+s>cmr5~+sH*CC{HpgF%xcl*mIYl)LeY}YwJeFv z!S-h+%bktW8>79krU(9qfqQ}Zp2vNU`ktCjL-=JWU9{98OI>_&ZX#i+OIq53KS>pp zEEd(EqMCT$Tz8_VCRuba*o9NtZ+)ZR0x?YyPGGl13#TtYc2){-Qd?~CcC=;J;^ro_ zx#@Rjzq<6rrR3(r^A5E6XgVwHNL+NF&H=P(Ai;PSeE!7M$%JnbF*ner8zJq-4H0K# zFVd8xbfELWsPx`Gkg9H+b=)(~HKFYX7po4Vs>6Ticzo*7sbtle#MyJG>Rh64cwuBT zF*cs?`p}p!G2%mgezeWMOc2+}qq4sc#4*`%8PqyuUHH>2J0p96YiDrnm{_|`_F!OD zk1TWo5|&vq&-^f3v3qzh5E; zWpws$hO$jWnlAiiC_PF*u?{_wgI_%T0D8)-Ul=I4to&a;@DWf2)B)`JXEK0TSwogj z!^t!GDJ`w!w8LaT$LZimtGb>Cfon)r_-ThtCgwPBL;^?qJZ{kW7`PI2I=qa%0URkB z!|;VG4Yz>81#H7!A)6rhDc~4w0|deU5&c0-Lv)so3O(V-+XwLVoS5*S2JHBd+Q&W6 zzi!|=U?Z7yKX;RCGr}n#Nv31myZnl_&nkI6Equ}c1(LAy6~`&5P}4e z(&ry{F)ORWV)Z4+gx87Ev=4b)^gTjgP;%18Jhb1%9AG~KWE^{V`{2DSCrR?p1oiKP z{y&J)WfdXQ+$?#&WLb8aByG!P0xq(sX==GZs1_0hMR#hZYnF@B)%u2qu6wRbwUazW z{;B4VQoDsPStAR}k*^iTIM6Sy4rir=}j^Lrk*K59*L zpIhv{gt{*!?DV326xl};hey%du9pJ#ax1a3=7n-AsZXh?o80@{vK*I7w%o0msS(S& z_IyQsL5by#6Oa2J^$X>t{DNiNZ>e&-@P!$xbT|9o?|-QV=!+9%A!(8Lf?|Bn!`6GP v%LEkXm5-@Mlu%6cT@t^S#}+U9&}H9CjPSChge-W`N*F7WhRXMKg6#hr>qrHX literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_312025.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_312025.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c74bdaa23554fea24809ef612263a6ee5b43282 GIT binary patch literal 13580 zcmdTqX>41^neXwwC6W@U%aUd5v}EhBWyP*yTelrqN)(^5l`=FRCF>&PN6Lrth%!aG zA)8=LrH83BfT^Z{=~!6PMSw8b0&%)s)WKp=PH0rzw_ZTl8d&(B5~skTe|Eo_M^Y3` zSx!@6u_JNbT;De{-^_g1%)e;0Dhh((?>;<#rHZ2dgfpq2OC}zh#1wUnVkm~5qRt3k z`V39-qG|CN2~FAQ34$jDDtwu(jOb1JRrjK^sqzp z>2VpZ(DmCN;eiaAcl8X?34>I(Aye><)L6i?ubm?b9<&kZS_G zdo0sd_w?-4u9+(qD6+WOaktZ9u{jyLWp_(!b88#a@Ma5Rzc6ccxX0)076LWv0Q8eA zW)E`sxeh_(rB>8=fogxpvZ@rRO5g9g-oMZvS5DRSi(7)&<1A=e#dI1{@SG?>hQo_kd+El^C*RaxgqgvW6;%uTk#w+D)Wa3}X=j zL#J0twMk1+x}};^af&(UCj%nK9$;urkvhe5(49xasX65{!710lX*msWZ0biMh8h(W z113(2v2a>OJO^{ab+}fjBR%P!fRJqhk@p0Id=rR#PQ!^Op?{B_({d7$GI08=ks7i_ z2dm*RX3$NRT!@ShoEROR1!rXFB}t0IdkQvzqPe^zlr#gXa1$sIm!E`^Wk8t-ZDDK- zp@PTEnIQ#Z@)Wt@6bpeXNbQIbOod`k30FL!PSvh`-Rb0nCY?+tQaCu%9PDDQ5L4Y# z%9Z?$FqLtoTp8gSiXj#sfJgW|JW@b(Q^ZR!r{JYd_jx<}U&4@94<-{7#S$(fXYm^G zsu|Wf##*N>;|wpevR22KotI2GtqeKgEIxI-d=v`YyIbqG@ZuRa%d1>A>y+JP!D;23 zoplOHdXT4wc!di%3Gk2$>69y-nq%_;AFIbmr8DDz5Qdt?EBOnmI4%UQ`$}n%TGd%6$={a7RY-?#=OSbSL;G3dx_+EtX#TM5V z03_^lqlndF)K7B4NjvMXPqEKIE#`)6VP9xCERUBqBW1I!YohuPX&$z_sv_vdmm!nA_kG17Z|92M7_iQpf-UnJm_&YYuybDPJ@PY$cv+`5V2SRt(Wi`LE0QUa;6tV zPG)XsU2cHa^uPS#iI*%w%STuS)r)wQV|Lo&a)WGj@k*S&WMyYu|xgxHd^j6`JaB@#MjHUwc=Hs#qJa}EQ0%!>7`I8ZB<;WORV?QWDLti!3ed0qE9|{GuF$3A>No?TN%>?61vTu?Ge&`BTvp(=KSBF)t`6_p zNuV*dOqUe5lDej+eiK~P+0=q@uHjNHsZwM1?AQcT%~LVSGB7o4f~o4MnCJ{ljoB?> zpJ9!R>?Y4nu949Y+sv3nLda&UB4*-QA;q00o(}mGUCpx#dRJq6?*DKMP24W7fpD}P z_=I?xxu$jbmW-9iln+U2aP?d>SDW4ewHd8*?@JX|$ch+|gNAiTo%Ad|0X2Yq-7Z_d zV_P$N$u*`{t|fC!a2gCr-UF(YXFFHL8P;(dTw29dr%MGcf=jEpnjEDy0s5Vogd%(% zOOKm>C7<5?SMupSoCvI@-Rt>onwnZ}&w9?8s`0dC!*Ei}53<{HGF*yJg=cRDrHir- z0A8aut~S^H=h||4X(q~^#E?2g+}$o~j@#UpDXYs3>LETTtQnW*k@9iyzRWnM z#%)(vys~T|j<8I4ghf(4O#tfiiP^@2)S-BZjTyhli>JpOY$?u5iG{>sgNBz9Q;x+e zOpG`dYfQYDak_b_%k5;r%EJSLF-_U0!9QeiU*T0&w{v>jX0b9%eLF9;5(7_qPWagH zhpoW)w&EzQ-Ll;nvodM+S*O>Mlj-ESj3gL&$+ot&wl)W^O_tg(TBj_N>$8qr`ZC*x z$I*|Y8XVQ)aS*d|64Z0cxhbb@(ghacTH3)Y##ra<3^sPvqvNd0Z4s9^W!JUz+NR$fg; zW+5$km5{+Ld4*uTTF&v(Nf_c~UXDpI<-EizaN_*<*m)k70~bJlBrCUCT<6C}-MpL( z3@DmRz!1rb!O(2r>2poG+JQV9nYf#j3l^|ea2Q&WcJCX2uMvK(S`Zsx_clMx?p<#M zKhztqndi+xeeh^pU+;b4K~{c6ZVX)bQeKoW7Wfnili9CcF%|`nB4ZuSmWRraX{S&9 zz-S80g$9thDN+_U9`GJbXw1R3PzNe)LYii8?}}U((EsrD2Zn-cwez*X%HXNEq1q=; z0IK%+gMP(|!H9Fv#fsJt7(v>~Uv=F$vUntH1$Qq3Nz=U}X#WYMJppJ9Wg*jtba+o#zE~0IM3rsPEof^W zGW6f8i4C9l^3>_r@aZp_Vk6_|g>hsU_YNfVCBciow1x{pV+%isRHBOA5f3WwLHgc% zvU~0?Pn^7W>5IJBnK5*D4C%+bFFXWGy3Su0I2vk<8Jj+0Kc2fYcVE@z?S2R|Y6*-5 zxn*smkN!8cK41+tB2!JcHg0NpkZ-sqpO*(^p_AdBUss0v7cG%XclCE(N7Y^Tw#O`! z_w%QGoexUNuGcQqh8iM$pSDDqX#1TP?~cV@8vc9|?HP`}GIGDdgJ(v%VJ|nv1Vpn!gmh81DOZON5EE-+VDTc8~sK z5;Y!)4ZnQfe9U(^QCOO${{>W77w$v_JAJ(gb493X{s+sZ9YE@(*4ynj+vBF)%ci}^ zv{#sI+|=dkPZU>%h8MPkt5DIdWpgVsw|><6ar>S2xcPg_<^#xl;5Y1NbDzw`%`f1d z%7SyKus(bi6+8z$z`iU&riSqAsBpip?;-5En)w>o;=Q5WaYL=I`+=$G+SvS9a5Qw~ z6Vsi}yThoqAC>pt8;zZK<;$}pu@fU-IAgC_(U3L9OvW5e#5mEAGj?fim7)ekuZX^( zDDxSSuOIJcjM*Hj32%=ZcKNzj6uLlTs2CY`AjJ-!ECJaL@Q@g*!b)UlL5db!W(;b= zG%BhOzlzNJkz&73olqD8T<8Ka*M$d>X&+MT`>|@Zp3?1HZKM?72>KzHFqmPn4K+)K z+r>AFBQ24sczriAbi=XPxJGXZ7KcuKI2@6LPoe5IRNRL2ZBZqxX3v)|AG_C&dXA&# zjwAhX7?q*;y~^P6U@bCi3k^d@;r7U$PxX;@)bPEi6*U}0`a?b$b{;AGihwC-ylz@B zg=9C>i|XaAd(qat@#6OA=pPlmD;p1IvK^;*ovpi>OeGv8+DTE8|h z8nlIFe^fLPUz9EFya4WrMLlurh#-DT@|N^19loHNfa2`nk%J3D0Y4=p@~9Z`k99tY zPwLZ+iW$k93Xj^aW5mSGvB!-ul90L)9!XoO)f^@!VVN{5hunY=VX8pjDhUH$kNJd?)tO? zi=#Jz&t1O;ubD`@3D!n&E16{E#D%5jG*9ov0aqYwlaUUaz1#DLSUC)8`oLfw_jH7L|Rk>T2E9 zLG(`70Z0gq-^GZUA^LAG^}gwP=|U-78A!t4M|B9`U@eSK*DD1iQm`3ppnk#)vik}#F<2i) zg|$FlCaCVLA7{j_+3ET`)`ipg_{#VkTq;k`uy&UVZf0bHg3P{zfh%#e4M*L$wcKVA ztflqNab7Y8H%dHh6Lft-U3~no()!wKIL6_SKm}GyfALZYUEb9( z|CqNwVKjmAC`5&j(&c&k|11${wt0J26<8EvEmEVuI^Z7wsS6rHFmGx0w&$kj?(pwV ze}4MDX2?%JOxEjzouCelMXVp6zjOY+=Ae&W(dq&Q|H;4w|5z0oc-WZr1FeqFnu4O!zwb>W)Z zJ8$lc((xVp6V)xL|C+?s+C)Y5pG&pU+SM{jUleS4SL17mNp^)QZ&WWnA2dRi z>^Ppj%&XQ+DO=_d!$%j3j>n@AZz$s608S(Y475!L@Ws-pBk9575loKkkYc~^HbXp- zDPofi8JZC-iPKo%-NgiW`m#e-D=FM%NU%2!{8J#5JTgwArZ^c+Gg9nd$aSBK{kiFW zeR?ZhiAS6=$aV4i9{xs!}#F({jLXR9hq(rgVWa!?QsgHTR{t{^7f!dXo9`aEzNowtI!NOE!9 zeu?-~OU2VG%_Zy>JO>Fv!?0;cl`|Eoa;fh6e#PlB%x@Z$Iztd7w31J?lRk)DtU-{zY-9^G4rdU%2IiBN6@L z&~K_gYxtz$p7jsszc_zidkW@|Uod|xa3OGPUU^M3uX+E(^|K3SBsP?%yehiQRnTQ`k zY>R8Q00CgE9sZ7>D6TH`_9VcQ&=#zUsmo)Maw0s)DzZ0V9oR5Lw&Zb#r0~*=2#d%EMPDvZvc8aP27jt85@qniXN&21G%2NU>@U(s*2FL(|H`!Xk+OAFYEv;bHeOA05jxlC*la>4aoal(@~fyx|t ztOk$ZRA8SeI8E}o8i-7=+1B!SJ+$ypZnHLEL2NQptVm>q)J(zF1p8RHS{fDM*A<-Z zZ-Y$`_5#j#zLZe!{3l#D$-y#^4Q233!awo8B}^1l=LMW5eDG~ZB_C3?>=oFb7`PNW zj9XxBhUnCwH)CG5b_xhjBziJ09XfBF8WkRtD1=v-_@&c0Pfze->=Tv{kFYZ3;l9jX z#;C49gL#ElIY0NtVQ;5*Hlff4XulGk^5hj< zeck^$)>wZvd&9HniJtuI)lXi%pEuw=yaL?djQ`BrBk&Li`wjp8w2ywEHw9*YUgqme z7>mINI~=Tg6fHSDD2E1?uyOxcu z$k_T(+sDt}c|LCJ^z~qExpP4ZzP6$U-;sp5BzPQD@>fEQXiF1z1-C`@(Ie3^w7U

qaCR2FftzoH(q7kjgrNZrQX|vHwWXD`+bK4 z!$BH%5jc%B<%#N@H_k4e{iyKcEqAuWs}HVKhv)@mVB|~9wy*LiL;0$h(wSBbl%~M@ zqTuDpRAaFeB>W4Uc)-8Fha-NvB?>2&9))XJH?$W*E>8}qv&mjOl1=c@96Tib7WkwY zPi&-|7#>k3dCdsyZ;#9)_b5Edj7O8eOY!OB#J7^FIe4g~oNAo*s5!N8PWHh>;Ld86 zv4TEAjK>$C1TMCS+*iBpa2;c_y6l#gdMP zvd@5)nC(1Tg$QVgi42#E*nRSd0-?bEGcGU+Oz2HW^LfJjcTK@t&sliC#w$`O;YFx; zV%$w0fAUIrvF38yFVC>#kt46L&bpoMSqErl-Q+DBvHr=T@Ur7}hYP558@dqI3Ud=e zC|)g)*fQpvVs27`u*Y(dS9DA}nb|4(0rn4&1*Su}#IVg)#WYR-iPHQ#rTq`8Y*j{y z6mONhQS!A&Ocy*BQzFeOP60`T@>Nj@ZCI(Uy)m*lvMPl%K-ARX+}Co*B_VM6eoCbG zn%=5-qhdu-=#yOq*KEiY?T!r}i`hq!X)NZR1u}Mty0Noe*&|AP%T{c3ckHE?V~ib{ zMq@5F;OZ205nSEtaCJRHD>mACkHrpWlRd_|0M#K;CqdOET2;%mvafZNvgjMVQY2a} zNE2e1rkhuFROR*?nnlfO9?9mXvU-v=q_RemEl6bxN!CP_m0#~z=vXzAY!Owu<$B{n zs_0dMUB5>l`$o$AE1^>89V#^$GGTma|vzys)Neg9@ka9p(dN;e*ttahRXl| literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_357204.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_357204.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88fe791303f6d7113cfaaf52e13dc8f51e356b55 GIT binary patch literal 10726 zcmbt4ZERClmhbuV*?zWTJBf*%1c%TNQy`E4DI^3ENK*oYh8Biq=(vub6YL~*^6U_B zpJ&XdBW$SF&g|^QnXML6jZ{+@HH<1Xvs!8CXry+vtKGaB4c1#$LK^M#K*j=r( zd+vRHw#gXM>AWcSoqO*&=bn4+z2}^Jp8rX&*Ann_n?5~%y`CU`k1EomF91(}AtQ(d zf*}}kk{A)+wm-J&!60!Kk?ez0)sK&_PPtGV9Dqn7pmK1v5hrNN*6K@=`-E4Gis7_^ZlWMVQgQpxy9-cH;}$ zt17pTX&KW8%3FnjrT2&iWNKv0Ucz5pK_<%!+N)pCUh{(XTCa|=R_+_4nGGK(N9dPG z0n}p(rVg<5fMu)9hxraNjZuspxEh#^O#KJsh!J*|v8b7@&kGbADI@0et@my`_x!N&-9yoRK^qU3D#|lI)6SRTZDOV`$9`}X>4T9&~ z>~!d{1C2nCfFc*j69PFTkjDU9?Fk0PeB&WOKHqqXNmoYM>C*Z@qc^=-aMi zo()iw9@K*l!PDD-y+DM4ulV>2n^dZ!H$&c*-6~*2H>`szFT<@X!$q@O^KeSeA{~mT zZ5><{S6#%lbG4j3q9llDJ-30YkATe42Cjjt#chpYqyhYJHm)%Orh=<2(r)A$fo5X_ zR22c7r~`9=8nhukO`LUjY7iY~{Z`a9 zBAzrK%wsdRnNv?7>0jYCaY_ueaLt?wLtDybQo(F;7Or|s7B0@uU%`alR!HMTTg$jr zFk_$>H7ChfFtbO9leYgxnf)#v#|VE;49+bYR*M7MJ!w`NR6jdl>H z#qu<-;ir2}o*nDRvx8eRKDXt0epnu|OrAE-ks0scot%u@$x|a%($RKNicmRsP%g=;l5xLjCJ1w?N?SBS6vVr@!@-tb%Rdd_d z8KJf~7I8?AV(&!Ng3!|+m*@sOEf>Zp_$*+?3*4owPOj1vD{~w)f ze3D9g8t<1b&dSIqj)2cB_gUp0v!EGnKiA52-PTHWfFv%JdkR}ofkgmO_#2AVB2c!w zUZ(#wV>xw0EHaHwQc$=r`a+M9Uxtd3316D_j!WDjGEhvpLlc5B7(VZ1S=c88i;Vwc zBgSWu<#ZYaS$I;=PP4&r);;C&F@nOw_$~Rh8s4_qw6y!`WT&#p^94NMjJmY0>sXa1XYL83@ ziY$XA?m@sFXM?lTu8KwyJ{HUtKQDhc!jT$rOk ziHw0zH|AqQVT?E>(53N!1Lj$H!9oPb#zN34zNPwvph61_LyK?d$Mt|9cY8bnIUyK4 zvn=Zkgk6|^LQsqNK!8^vgjQp8657Sp2;pVH4dd)IY=(=1T3jFK0RZ=i3k3}Tpaxaw zyl*TFtbBp+ZirW*xzM5p5H$*_u>vYyYuHFwjHy7V8+7RSGoJr>vu!TOPK2i29&g)u zw`al~J|7Hhhrl@yZVR!Vwy<|3u zedliI5%eOGYr@M0yp!yEAR&sBL;nJiOpJK8m(UxNClV*pGxzAG*x?7Xp;YUCRApXJ z&#Tj9x;t0Z8aw)6O+9aI&iHt1+lrz(Z#1vew(!Q5*c*AXEq#tRJ7dQlSZY(3GF`lN z#}b{h9Ecsy)7ErHrk=NN=jk1>{=B|w#nQ~{o3k`e?~e69)SFZ1czx3^`aU^&<7n2s zRG-`2!|Quuujkk8$qexNS8tDg=DX=zCYQUv-okec=XRXp^`~O5Kh#@N?!*XUvle)y zrIW3R*3|LLwiQd;=j<1eJCS?ZzS!YMpqI0WvnhAlls=yApS9PzRvd1sq617 zzLS~BZdQ85~ zR^GO22^tPOk(2hGCj@Ce7*|PJR%On7+M5k$2X45Q=9a_18D5^_I|ukz2EGp6t9mQ0 zd}!W~{z3N0ubQ$0e8cW#l5g0zJi^xx#y7REk{>|Pu(-U6<0m9G-XD3OIuu>hhaJa zwzu)-wxyQk>c5*=w%%#GJIJ>W@#dkp`azNBg)H-{-t72|A1!dSKp_nc-_%v%P%v_P8c*sY}1ZTUxTsym>pChIOHJzV)N_>-!h?=gcix>u0u` zw$J;%7`!u>a~|Z)2k-9vYTuXp{`u6uj{M_D?(kXObQWen={{tW7ZMjz-i7n?=Rabv zM;0Tw+O65wKRbH!=;xzf`0x00+j?`xzWc_0-q`=!(Xaep`g6u%o*KR?U+o|ayI1!T z221M9&wJBhEaD};W+$)Txl~feJAB6wuOEW#XsG__a9WZ2ULu%LWx}5hXXb7gc|%)l zAg?y0Oh11A0i{jq6Z+Ix+LKlPj@tPM>QsVC)ut`iZHqQICxbTzb9HS?juT+>0Y8SH+9qhh_`SvKH7E>wDu!^YQVR9>O3Xch>L z@E5Mg5)yh#83Tr#XazeW8zg>Aa^$zhlznSV`M1U-DqZoO#iYh_2!o{n67N|F> zx~hQ-yNZ#a(5h@U3Ndgh%#HAu>#4v*(n6FdUCsR!u73$Cp&65bCM#W2i5FZ<3zy8O zI*j55{KTm^@L`l6Q22sKz*S8KmlZYQ4|5>=D6UxZFmfsAd_lje*iWuQU%9_TFI$Jc za({_lz7Bom{t~@<9s0`sC3?lJ!exmhIQ3hEQ$6$;Z5lDiM9@tci$Z5X#dt6HJYG-v z!9dXzB7Yiy;;&@u<(?8G2l*$ojM8nQ>=s#Jt0rLkFgPzD;iK)!$j4+LsA8xrn6m;; z2$NM}B)~#`0;SdzC;H#+Ek&Rqw7VR!80&(!Ky^>Lr$!lfPvq4?pul)%X5E1>go!xV zafRA4?-Y@JMpAF7J@C8w}=e(eV za3_2TYo7HZviG2X2n|QiPAxl%+B~RWP~k-dvT$OQEj#uaM3}ST>Dh4E^wB|-LbOAN zL@m_4aPI6JpuY$Ip);`L{~aS%4Fs)E?n&&q)|ql&?Ost>b(dLQV` z$?qq=pKeMIr)x7apUmBuTblXe;+>24^oL@vKT?>U$>bXC6S-2UTBQk%Hr|=)Pw&rM z02c9GaW~?Ge2VFsE;YPxcK&SI{g>a*Q^uqrVMxv7sJe9L^`6C^Y*Ws@El=susHx#= z#&pf~`o;Q;`)@YpDf4}*jt9i1995t8-*0T=8{3x1Tw{BlGGPDZ&QTlEE!Vd%ZqJfA z+g6}>W;cL1jRb9uoxnT7Q@-<&Aw}Et>Hk8B9T8O4K2c;=5!RXq@^!N3&sP+=vNc$! zik*NwssKYuqC5=7x0MnWBS$tSssesRH&Q~bh^r14U0tOQ zI*4MG4tEw1*Usk1p+>3Zud* z*-8>hP$?@DrNhOH3!{;=cZ7HsENN8F(VULcGdf1|6FHQ+V$N94h*?uOClR>eN9hmi z6!5u`qbmhP;{==>9YflPk4a7;N|I2Ag4ieamc3$+92y2hg<~sO6=coK1om>a{+IZ_&^yh^m1P; zc|(Dm5#%$jF+q*lyJnEHMd=OnokQ}1;#>5D`sbWKdkxU8z<=mJ;b?*L(-LNUp^Ob8 zx1ma?;%sVvdUU=gM{SDrKIh6Fs;T(?AOA3KGN(r78#7J3sX43VtJ;@F_^N|(tCFWlL9%?oN;ot4whBEsd6=8et^x!TpK; z)M#dNygx^8fg2;e6%q=So6%TD6Gv0t^y$poi|29}6*3B^6Yr)Op57Gee_*Ozu+Q7m z-t?J_cku@~)3(^)BXD-Qgzjhk3&-Y6Qq#>d-xzko5h&GEMVQ(7CmJX>ggnGnAtb zgNjV{g^lwYGq!B+jfSPId{g(b|7-KF13A-)*x*B*8N~)Gx>r^duVA+ z4aof=$NRJ$OxugxGj^^DaAyiebfBtm(OLV;xd9{ea4Jp-MzRhIQx?M$f zVEcQ->j{GW^M*o>Q843XwUTniwAN>C!)cY`DdpW-kpCpM>hkNmX6ORwx%gw;{Xbs9$i>c zyHf6J2UZjt#PGd%wqfq*#|qd!JOcoJ`9wzLFTdi#$s{4WQU1xkf`zy(FP?M7I{e@Y z@F($8ma;e~FfaW$;2g*iRYlcN3IbJ>IL8lsxKkm^!O3qG0_LceQF2-z8P#z**x{-n z7AXKeKFu<22rK+h!Pi(PDW>3bXn(sr?vU5D%c&HD2J{(<39wrbE^@s1?&+k(yDQ4? zi2D==1EgCZM{&^L8Wq%|E-}+bkGe{sjzGEuec_{kcncQPaMorg!=VuRKv}k#9pY}T z7KIAJHx^G=y8W|%3%I|9{}4KFG2)@Ay678I!)Y@0dS)-UwT3$l_o#!h-h$Vx)YHgp3y|j4(#&Pck927eMwJoMdhi zSWHoyRQFVnnVt0Zu%7}lx|)Q33FWGsB+1_s^uH1Me<$i8qLP_nwO2M?-uRHJiL20c zITKnwyfXY&eoHH)@~vC*dpqy$dL}0|uao5{r`AEq%3j1}Rj1ag{!m3|HaxRvWwO;; zLPlM&UAC>(6Ag_|?ENHpaMes$Y8JN5Z(FtCYPD3Y!PQ!+YQ@zJQne0OZG@pJc`kA8 z517?cIU%E;TG2ooaQAyM!fJb>Ya-2gHGPG<%&p1+hHpQJW>pC=1UX=Xx)a^2YE+{L zh|cJFdR2pJTEbMlil||msB~A(T|W0z549&lWE0tvI=ygi{@f}7U^es=y`H>5R*`jS y<#l?IUL^p`p8o9IO+W$g?(A3GmmCIH+{{zd^t9bZTAp+h#^#)1^JN`g&i@BKzxXNu literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_365790.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_365790.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..366d08df2709eb9a1aea555e6a2e569227dc11d5 GIT binary patch literal 10087 zcmc&)eQXogmY=a_?6EygVkdEM2!xQ34+wp)Av>}Y^9p25~iTn@KlLZTB+InXW8x_>sv=J)+{R_y(g{sFFg2TcmLTv z_m0Og!6AL!KlVzVxgY18d+ygczjMd`Vm2EoNWZ`K&82%~6!j-uP?I^6dHx#>MNLxz zCD23E1?5d&ph;OXti8a{l!kKC14NR_th`A*|8gBJ2^v8QKSp2$-50E&7r6U|d)ifa zj4Si>Yx*g{^c5}e{q%jaDhYHd$kTU$n;9s|Q(w43ebEZ_mKEx)`Rn_PLh;AUy-dgI z`+`BJ5=z|EKuI2rO5LVDMzDR%Uf>Osuu&)jy7D}9YoNX&PrZEw`tlX%p+8}5ex5#F zr~o|*LECjekxyIPTfX}BE7YsK<>7(3s)XK(glb{K$MgjY>?2Ed#;oIXL(Nk>+^(kw z;rTR#9~^$r@MvIV{;hXU9O`hKK7Qedqw}Rw*Gr`inLl>w#MzT7rT+BEcaNWObUcL( zPd#O@FpX^Vjtx8f0cXG4FB@?Fic=i%Kiz;^wKCl)(_J#%A+s*WNI--dz02qA^Yr^= z?Vx*HHj8e*JKzZT920KQC+dMq#Hxx0{J=*>G{M8Q1|Dj{P53Z8F9-5zDu72t(g4sG}hgHc6DU5zHj@YmGqlY3!h> z#9=woCR~f_Gqna>i+KoNtwbvnOMHSB&w*ql+||ptO)qiR<;+hWlPncWC4HJecGd>+ znCJV9I*wp{pim_h=p4!Uc8y-jj$h?JOhxI7eXx2BQiC#w%g0Vey!n-RVh8bfH%of_ zo`4-5+>+CK9`h~LN?Qc>zD_m%!NyhS>z2>4(wDS+c6F)-0-Ym$>nq$=vP-wE(yLb5 zCKagb^-5obYD@N{HmxF$0eM9#*K&E8U1L$(vr@M&*B&V0qzQUV8EnoYXQ_3Ev{h=J z02?K6WLRX{45A zR)|-3gtT{=EitPt!8cyvQqJUZj_mmewn}A!X0RSyTCRV|^=z31l-`i`NUckdkEX_R zt$O*=I;**S91X@U5sU*R!rr7r*tk>!Ewe&ke8l}9Dj)c7X4vT;lv!Wkl3NrprhSSn zS_3Vv#jgTc6Bv?>BciWgbPhW_g3P!C&t+LV?D5Jvr|9(dyJZ&pZojPW8}c~=J9o&e zQq4e7khOv@07873=H@*z>ks%uw`_C!zl<2`()OM@qx_t;+Kdqk8Je2oI`Fu zsW*@=NSg)PM0`%fVoYW52SktHb~uNZvV+-dZ%)>mt#vKYy1a5>Mna~${^b>3xgbNK zDAB*H?3IoEqHk=(;r9T+N)SfYj==1(@Nj@m7%1T^D57S)a*>+#cpXHd%qfC_(bVS= z{ecV*ztkT{ATl#AudMg=_4)B^D4AXuwF375@X0#TI0!3;r_wL$$aDi1bGlqI4fTT# zOo7|UwxFFd-7B;Gu<%!9oiagUHJ*0Xf63DqkhuvMHpzjg%g!Oc%=9`YM)0$5M5bLC zEnvbGbzr)61PS`d{_^hwSsQ+N`tOrXSA622f5ho>H(heN2AzRRKJQilquxN1UvxDE z+`}VHP-^mc1G`{Mf#I>CZ6o6jsB*x*3HZDYmk)N&&ON(Z-q;Q;vRNT=47x?Hdq}(p z3gRX8KLW@ZqF(HxOy(Tj=iR( zs3LX+71f50r_I*rd1PMywe^nuwms3D@T98tBJ>m&8icjJx8;->G!hZ7Gb=8bQ`JVKAnzw4dupX!e}7|<5Z!C$!iL%>GGwo>O?%BGznjh)8I@J2kqi)pNg{n>@&%Hm-pNDR(Yi?O$ zhht?^P4NrJ+7xCVTZ*U4rpjXFad*OpYTrcmH|O+^m>->w^yuKp?GBMeI+>Tf+YSw!=VW>z=a+2vLt9B!qH}tI<*uR4&;gGBnn>?T%H6*)nWj5wk)aZ7c@!S|ACvAjFL&>A#!I3*b6TVD4Zj1nYrw7Uy)frh^|XA^h<_F;p+z~$K zvZlk4gBzJEMD`UTn;;vR!T3(5M`a!0i9xp@p2ED<@bF_zUxF!O4;E=29dklo6K)5_ z7vg28>%`>!U^PNi+Qd)pj_m%lCF;Dt0p|`-xn9NaQ5KMDw=GKv__q=?gTS$+8R3iYE9Z~i5`x85IcnU zN(>CF;j_`Rk&6p_72*M$R4?!ih;K+#B~GRI?F;-)#P9s>$it2Y9Vz}G*iC~md^Tc_ zo*21oHDOlFxMk! zwc_94g`ls>`q3+HPyeL=Sv6TVO2E3T6WxI^(d)_@9-zhnGcocXfqDMK@QHCXL%*7d zzhMaKBzi>$pub@R6RQ_!96aW@3^uegkTw4Uxf+6kI7JK*W1j|`L>{vr$X+q!`?7Ff zlXV}g*CMg`OsJq$%>+$_4Pod>U_R!8+|ren#1UgffjeTsg=w)qFcK=CSrk}R+6wKu zY&*RQfByC*es&f9{OwEpnpODow=eN)SK-g!zQoVm!;vc_nvz&Bp875lp5N8TI>CL} z<8q65+C?1ayK>(O0Y{`UK zg8L;C>hSN_);8pWz~8?gG79k!i14?-<3B^xW3FJ*6fs2s3&y%2+(>ba_b-2Y?e4YN z=2>BGKiYpP)pQzhr_&rCJ|8K3#%k?4Vj`CALK+OR+2e<}PxTU_0J6nCc*rcoi-B=5 zh$|XMMW1JA$U(wD2!P2|LxV}Yg4?g+2XPVOI7fVp&Nmhq84D;^6Pk`IGSd&Y6FJj? zCoStp4jxpd#lGTXh!LR2F3fL*o$y~FYSBo64>=GyFmI|zG8IWv1@RUHf;)kM3ebT{@T8C!T?T5k^h>aoQ(y>@CX*k()Yy+r%BZ{L0U?~nY@`F-cS z<(-J`2}~~^;iISLIr|HZme++2zu+j=8m&Q08F;2}OY|s?3LrEvgqy=o?5r@-wtQ-e zc2B=Q^?uCx+Yi!Q;bcLiAPU~g9&7n>|IPi0s?^#>;3Fv9nc~*OYQNlib8CW5m2duQ ztF14JjjV3bM)B5Ar(*p%xW-UCHg?votYw4(J~LTcznYy7V!)s0WDpq71X+nuYXFZWLr;0Yd5+!I1)gO_WJk}+>UU=SdjP<@e*J8$_kVqtz^{#Tfr ztX^1`%x5EFo5-aszH%pXkxviMF3-9O47QGfLA7_PREbqTJrRhVFOXSU+t&m#@IeaZ zSGhI^xNTB^GIse58pXyFLqo7IR!_mu7;6&v441H^-H92Dq8Ha^O%=hVA~uZ*(MFte z;|IpZVhubRc<~0#XJIU{QQ{}K4r{GgDY&@909rPvca2wNt=}~y(-Y!3AoAmN%33VQ zC??A^=EX-&!M#x+`4y%;0oxUDZ(-R&;YL)rF;!Rt*IRtaWLKmsW0E*yxI3Z)Ol!7c z;25ilb;oR9kKMU;``YZ@9}awfVBXvT$j`PuUOeRu9R(Y+E9?yK`rHz3xqc#DRvB-e z3MRD3wq<$Sa~X@G_FL`I3S?X#I+QjRMh&-` zqdYRM1ITJLP1Z(gqt<9QxzLM^e!crn+wHa)dZzgSGyDDzj_)1w)t!%opU#|3p8Md@ zFsklBl|9IKG3gYLL0I4h5I2w<99kF@(V&?02hiYH^2&IMoA~LPv`A4O((M{}9n_Fl zrX)e-Gm0|2r2!3i>*Vps@!z?pho^>9Jih1Q>l25NsR>ZIsW8+*ESw9QFg4-pNMh1n zKnDJJLI<_4ckB>N(*W*rziMW;uPT(H?l0XqFssO;CHOM?SIEn`kpLj6UHX3719}8? zL4A-58Uk>Dl_W73E=wHTfn?xu&?vBy(L-krk*-U`1AuQ4V}etZAPYx@_+L}N4Y%Yj zr{C>pZeTM$9X>}0UgFM4ghqTSoDG6eFrHEl1LhY`;KxaLkl@nME9-k5N^r=-y*7l0 zju9U|EVM&5XYP*`k1y-tpA2I|0sji;r9@H8CyH)YfSekDI^y3!PrrqqA0vnm^~6#V zdgsranIIh4k#LK-=eyE}>mIC|$065iT55OBw9jpxukA!kb?69q>38E*FL4x*D`%5xk#L@D#tm z4aH%2XSUdox+GWMULVlPQx;%49qGeM7=sO2KzMq~o?> zpD;G$-Y?#TB3L8Je-oaITAHT+MDhPdng5-tSkzG(OUU*~<@L%ZTxnQ`QAxUD{fu_* zRPy|VWdEgST1}bpiLET&{DRdOYJuKh`x#^gy;=8{BFeBPU0fb*cu{K7XcpH{8t#+w z>*b4@J#^z@2~}J=-8j{_SgMqon!g>oJEWF6=uY}y8~-`eYEx;;Fs+T+-1u$9-HN~9 z(sM1P;h)>F{BkHh6X`;_E_!ylXR2qBf^5S79KRMfQ`L3P^cA!+i?Ymk1;q868` z>+c-6eSqYj@RkLB9pcw5ve5FRt`Qqgl7~@SN)|EqB1h>=3#<*Xw&w;YKikjKMX^m^ zHr{Mpq#!%PKI9+pBs+KZpL+fQlz`f#qgQzkiwnc!Xn6cNu6@3_oGyOWMio}43Tm#K H$kP2U;W2*N literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_41463.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_41463.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baedae31197c86b32cda6b0b7521d431f850b948 GIT binary patch literal 17037 zcmch8d2Ab3nr9VRJc~CeQkP{>mn~a|cL<6_(YmD~CD}@`L(gUq zG6_ahXSQ8-dJsC(n5NqXqBM55<;DO}2ZL=7X0Xi)wUM?)3kc`{i~O;GGBz;BAA|Yc zD_)W*%T6#;h9BScz2ke|d++;x@6+Gu^;!bHU;LL}-KcFKh(Dr4<};=vPrp_W#2tbm z7;=Egh zWY!PX+w^=*4WS{J?Tpb!^;w=TWlZbjo7c&=tdnmooZq8mZ0{-_raR93c4?UgrpQP1 z*$a%w4*A6e@*V5cFJ7m9iBHdz7N+UZF=g*6x(w?|yar^e{O2-;eko~3lKAqqHL%(* zUmMBBbG7|cmvLRmFfyB%3cqY)$zax7<8@@BhhjFrOJ&wD^V?-&DuK&Y1y;gQkSIO;>m%zZG=Q@ywB4$frS_a8pp>b?o zCtn(0K^oW*^)O32)5L6pSvp`H9a&x$-y-XqzeTL<{5A~Xw*(2nk5EhvK|Lkvj-7q& z_!;-bQ(f)ubE{IFt5Rn~UB`vjF2A0x`7Ju*>t&*LXl&3O2zva!fT%(72@g9G__h-D zK%qc}>d8g&oJe+xY!0P6_x1j3c}2!uMhRP=3~f5vWMv!q2dkd%Jb@$1agHo)CAS%9Dkp(QlTat*#4SJy@=!ZfoRhPncJju9IJ@R!$lN?}JJnnj zXL~8S?3^)O4OhYyy%b$`{+OG4jK+6v1r`?fahR$^P$oo*dxza|$fw2e|{B62}g4GH(BB zesH?%3LIF?T~41V2_5DRuAw@#x-Q~-=m=+LPeCE-0Tf}i*7L%7MyrMiV7xpYsB~eGF?89iPWGc&@U>6 zgExFEi;kObQSg9dH^Wc8PLu@)MC}MW443VoyO$9uFVj0N$_IOgM5TxI4EcScV$#PB z2SioRz_2H{Yp1A?vMES1qMR8Hii$vRnDvQTPjGm!*X#B$O#M-jbcy6KkvuLc176R7 zFTete2-b`OBv`B!0#OqPvb~JY?HSBP`?JyRyy#Fi*PE;B9TJV{7Nkzx{Td}I&g#FbyWW`Z*v58QLEdCQ52YS~eK?|hu%-9ydREO*n8(a9oS%mTdMr9n&5BiTHJ)yk%DT+tIn2*@?%-CxLm}#fat`v-OT+ z+QApai#{Dq)Wj$5SIjgEPM1*Pn!EC3^mo^i7p{Eu=GEkdt3q2>^2hx`Tfb1!pJZ=- zb!&13ZX41igP%$G`3}UAJcd%IW#?qiKu|kn5#^$_gm+CJ;Deu@=O+ZqmV{HVY>%i= z_8Wq=CLxDp;+A0E7f~)-i}{x6{r5a^8}An^bqTXz*#=e3#e9cgu8s$Q*}!L05w8(U z&iDzGq?XmzKWFcZPmjYQ**E{W(~zn?Dj1H2I~HG1QpA(depDqB0XcKrGxKe{Lohqzrv!7O zpl_T}%~HSh%&LSPKbqSw>^Ljv&w@zOmY=r7m{_Yom+}5M^D9@}FH|%MbW^x(S!IYh z|AbpoYp2vvbqqu&Zu?qY_YI`=QGKij=EZ!i-u@JJJt2}eMb+qpuh)Mg$nHcL%2DYu z)aDLuQOWv(W9*O@-NBg;UFO&r{d>5p-_yQF-yt}nFKuvUDxu$uX&z9W1|uMo`V*2P zH>@qot4$O#a<4<*?`uO!XbWi!Yd%yOQoT=eDn^b~M>VJ9)JTyD=wU1rr~a5=6rA#7 z;xDqG42(+=ZVrN?DSdtUb57?)wNbi92D4RXYUO?*w}9F7NSRTILW&_Xr|1I}a_z^c zGUhr-M6?m;fll9AxUFZIBag5REL70*$6W^tzF$ zpNsX#fHjKFJpB=?j-@~Zpqf(BHP~86f^O}(Da!r6AV6Q#x%Nx0oK74-M@D5h%rdB= z`p|%9u$%E5ncSY%N*Lehm5&9B$7epG$~1?W6wdr{B_0jRlS!I80`WERNIOo-YuZeLn#M9L3VLLQ$Gu98z7 zfAk}XVI}LbiG2qMy5WDI4Fv38!o(u2n>rdjI!~7;solwn-6;s^a=14Z=t_aEoTsaj z)PZEpffR&v6{_7DZN1$Vc7;b5bylSG-EI#b1I@&2jd`Z+;nPqhrhWftl5+8biPMQg zp#FGfFDA=fDasW&x~S5Jf3jq@-zl3egA2yYUrD$=JNe+`tm})8#~t(Lwy1I$)=C%E z#m;|dd?u6Yl;O5#YC>U&)d-YhQLTxz#!jGnR*B**5fA*}`^)?R9Xo&L&FMFJ&qvo5 z)y649)DRm@sZ03QPmbI@l5nQV8h{SpZJw06n6LY!@or;+OqFhX)ojZ%t5%_0u@gE= z_#76DdKm^~C3qI?RBYLK^AAk4gYZRYfpDW83>RQ}|A)*k-n&c=y|r*=4&yy}6a^48 zKrD@!eTxvEkd?D=7^m5g&a{fo##z^(0#rY`N{cw#YIGotL)bAs0`m$vo}*M^WUhE^ zK3vGqBBw&<(Q~b4QO;04-caBo286T(V_ObDyQ%CITMCFTMrxjG86%*WH~DCEp_WO& zpMi&*1709=Etf;f7%zpE@eJ1RBa69&Hs$fWQ2R1KwWap+M_vJZk{7_XlvsV)jC+OF zR`<$CrIi>9M`)P}qc%>!%PTPwCZ!SXU4^moa@<;V-UtgRGV--+mcxyn<15A}@qDT| zNA7hS;Xm~n3C+PrEWRw+3Ya*%2?YqLVVx+bf}mc(R->hsILLVbOPDd z2yBsl!GK3;RnG`6z~`z2yc`CKESNHygIGp~L~RyK@#n>J`I95L{K=6VB-4{?rDrI2 zCLk=LdjPP4A@swZLE}A*f(vK}5;QS{(Ezv*SpXMc5(q{R)eylK8{iGzbQVO?SuizR zO9c_Q5E=FwYV#^;Q#N`7{!fBpF^ceu0Ag0|qF6sBT_EPX3pR}a+=B-G6#fVP7wlj_ z04fQ+5!?LEl&(D7o+^{TFxF?c3C=db&<1n{)098z2U}`Fu-7Fn&6?*82g4mOR50l#@3rs@-zpfY z!4fxY0RvCU=68Q;&d*-fT4H`dyD4rJG&SL4i&}H+4MAHTmkXNe@UdlsDYi4}=ZgXF z2%lJ@l<%FqeKPV!Y+SNtr{77G2#%eCe&@`U*|sm8vmKwi=WYpmUKjMQCojC2bbG(1 z7<`3^B!6#%`L=%BrMh+H^I`278owaaeoC;Bbr9Z9N`oiTodf5i_h0AmFstxE1muKl8rk~KjI#GcNuI!W!QKSX0vL(>mXLY7knm>zP5GS2^hY`~DpI7z&2UH6* zK~PxLFQEPB>GYsrq!oJwJC5QLD7c9Nbmp<>zN)9$k5T;RDEJE$WYv2X-y+I0Zoppq zMXG1S-Oc_4)yPeWhGh1{&FJXHHM3$rg<8Lc|AAQu{slC1lNq6?w@-ynL|jPYHAl5! z*Ai`-Iv70&*lG)$U_CSK^Yj7GkaLQ%J;n%BnWQM6jn_z8a__5Y$+iV`nV>G?!TPC) z>lP|^36;BMT&c>vi)x)@M}1)6sZZ#;blml^;hA0q+OUbxnZjrAQL^g5Ksp0HFwVk8 zd}Z)wT{tJC45?lK=KvrA08LZQhAjlp04Sn)MSv!EG%12va0Y-Dba1@_(8Ey=)7oB0 zn>WS+{%4@O!r}Hvo;eap(CR{v16!QxoGd%%0ekXhWg#4bMOz;p$K+s7)~neuNUfFc zoDd`76J>*gAm|rS#Wm87WSsp43Xn-7i8hB~78KlvAh1fvD=RlJL|fRu0;0b`gdae# z+73OKtT>pW4q`iWmA$D)sCL@kRB6L+FBz;;{n38Dh4=7)+J092pm^5&h2ybf-r$PJ zm*`R`@yEHD^S^oX^EYQlznFMD@kEoNFNE8cC=I?S&&Mdf`hQaui?lJ^wn!VorzJxl zP1uW0hRl~0V{{8`B*uN&n=V71!#6`RPFCpt!}p~;M1dfYShXiH0*TPje0&p?!ni_! z>=goK;2tDFxH9)A>Z{PL2+&Q={X&9@03DU}p+#KtA(-d6x++e^DUp!?Mg`D>05PSJ z+X;$mavP(skU(bHu;{g*7>Sk_v2DP8&P9Hj<*L=aO zqe@?ahwDr$uxYcKzmRJsS_P!@fQbcb#Zh^!<~wSo`-X#~bj>gz3Fc8+OA67!O__37 z+q6E@M*MgbxUJ)~86zCI`EWYP(XJ1@YOnborqdyPt~~$)^nK1X6(N}eAOImw^9y?1 z3q2Y*J!jwyFgK`IN`#CI88UH3C@q6hGiL%f44~l5ebvvmn;8Sh4a{Q__EO07;U;j9 zFmW_)6Y)#eBRd47AMT@w4+Bs}zvujYqx z?7ttj!(az<)*d;iJ30vnL~6X(H^J7ST5#q{76N$x437Dq=AYDNWDBN|3I;zE*8|qDRpVswP-XUFy{8jaC@X>kv7Hl3v_v$ z{EBuasrF=bdkR9@iQG1N1^s4#grE&&R2dluhY)8~w^lf<6ro*1|zME0cUS^!h% ztw@(EoUI4cD0+=QCg>`_aYSc|?1}A(9tMR_va4dj6kWbRR|#}gyd~b9qU*zL-WMqBLspWTee7fs-p0X`LS zz8-uz7EBF-sUc-*iYS)AIl;`ge{%Zn>Gt z8ae(CMI|wE(NW5Krf&#h{Z$Vcl=<4F5@Rp)$SkSc! zy4I9#Kd_NDP92IK;>iz=F3>dsT@ybK?<(}R1$vJ_!{U$-biDw?wezyI= z_Idr@yxF5k+JAN!=4G+R0{rpV&FQ8EOO0TuNr5|tRj@Qb(dH^Vc6n^ySGw}=6okzQ zbAoZJ34<-&IJ_vO^wkUcCPCkn((hQ%?-BHSW_o5lb1jd1<~%9=$pt;Q-&{yuyp+6r zC8d7@#*keqXZ-yA8ouJu76`AC#E56o6*EJt?2S6MyB*t3wq?~kdtuikDW`<&8KeWiE%|Rq{)gVQatYg$AV3Y1Z1g&~QeOz@QDAsb5}*eLLqY2t z+vWMK<1XNu$pi2TxFylP@4+6PGnT&LS_mlhxewOp=1Wr4H@QZzE^0$ z0wcEIS0C(=4R3g?pdYP-H{?(TjI1GDmPL0W``;me z*Ky=c%c8N0X4W_2^90eGzKsIMvrcd5ik81$~vEuR^=oGIch37F;Kd zwyC$HZ^O18<@Y9vX2$VDL*6Y>ac?wU&5sMFI=C9l_SpE(HUTbOpTV7nPxR^^dw1iduxCmYMoFbF!!w&oel>}b-~ zoU*yX=avn2ev@FROTbyY4`xW4r}jnneQ+R26-gjrT3ZG_k&i1?HvB%D-GRA{vdSF2X4|{5BxP)3p{+ zMBOHhXP9NdlPTkV!~PNKmVvkjZb2kWESqiN)8DA|7%e`}li-TF_rDeWX4B`J=G8~T zt{0Tl0n$9TZNBcDKvjcM*cv+e!R@HA$3iO#%BT!? zAo$rD?!=M>?+eHR!0#UrM{MncBMJS&a|42;z9m4}XTL!iB@*0^9wbDg#D}H<@b(`A z-)T{uiQxxlQF$EyeV%9l`vCo&UfPdZRC~sP!@;p3Q0x0ZW7a@%AP5Hta0oR^pFk(v zCErj0{L(e2hSE`d{NY6tjhP11Zvyz?K~zd4_+bFUrQl&SHZtIoRvR5MnCl`ft(V(B zJit64B>a*=9jgxw4l`o|z9Z~YNW$U}fhpj^iku|LKN7lsAoTy2C|^;6b4~cfd*^PS zi|?A1Ke>?fbPE=5a%40a`pG=?&LU+_mmYgz>5FN1&02my%P*#Vap{X^^gQi~Rcn!j z4J*Zj%pA7AxB2$wMO~?=c7VqHOhw6LD_TOPey{X)=}IZFsbbMt^*uqA<-pu%^3vsb>Pmk7{QMP-!YT_J zR}6W)EnC#+BHoDhdH%|3NgoMy(|k4Jq$$HHGfL*O3Kw^hCejoeyTeU$D+ELnzRw08 z4B+UK+CMb>ZNpQP`?N+w8o#qQl2x(Gcdk!gUm+lx@Zz7@_PN^MH~g+4`TFIB*RKn& zUr)Z}U3jZcc&jgYtWP-9{}k0-d5g4>_PFAgy8F5n0-`g%-wb>X{Xq1|o*Q_FmO{CxcL;No*pPAt=}CYj9_8a+@|q{{dbjlhw140pc7BihE-n(aFVt$vl5E+EWjl^-$+EmJitX5zogpYbN}_f>QeK$H3^Pp| zG6_ahJFTc`+R$l$P-CG{f+2)KVV7^|bG-{OoIOsT{ZtDK@{DTZR`5vo`C zrF&_TmyAk#Wi%zB-1HEUBr*%Xq+GYL4CfdLBV}ZaoKd_dW0Z_?L3Llc4TqXhKPQLe zxjEFJ*Yjy5d2SBei|1H&(;e?T*a z0eT@%93>cSLyl23x@AMg9Ob4h%FSDpTec|A&t2ZHVG7=q-A^?v{`RWXl$v517^|Ba z%Fm%yA(UHlloxG5U%UlBUNelBIX%BDcWHa$W31@;$44Q4OwkZk;gkv@yNuU$p1pAF zjN{U&-s6sQ8*=A2jnlWC?1cAm5VwE_ z<1yYdAL9jQm-6&Eo<7ghXLz~GG2vsOMClqI>-P+JdFinGF0W_ZUboNT8+Z8K>^Q3g zE*7iFs&Rz(ZeDrm;@MM|9cNrckP`!@!R4UAcDKK|;j7__xq@fnD-ioY_-Z+dtA^8>`cnx*^-GFDKTd_^;99O0*WkKTjasY;)N!;> zo8e<*xaXWkq~U5dacjkzKz)`~a(D)uh0E`k_|mo-%Hbs}$i;pH8hm*oDMM#T&NWw* zTme^z^$@wm-!3XbzY0wTc5picOA2(GZ9uGM=j@!BjB_I=LtE66*s^Sv|>dZgO?$wbZ-~Wqq2mWgE(!o3-K;!!V|dVnHdM?txv~$Y|O^ z2Aw5icaFIOYlk5-FVN~Ur^j={`~w?hsP=FzT&o{8C`RpvT_v!W+s*CC%C~WQhYCbF zncr`oRidT1wr|f%c8Xni90MxR>F`#0`8Vol({W*$zCXZL1$Va}5|b&xA$q{BPl49q@*%0vrfez^nOKJMUx zMqDK{bvJwG&*8D;;;?Jav!D=Vr+Jz4j>r2LpO!34d1{{K<&5v{g!@0mHs>Y0Y}Dx; z=H=tQK{v~4q2e*#2j^+l20wNqFY%4=nhAD%fOU>KJPa>$F`iqzbksA(E1axzY{1RS z@pN=cUZJkHy{KHSQ7jsCx{&g}c1m>l4<90Yl zGs)p}vM(z+mM(Q=XkBBxCMBT{ERso&@mjImKQWx00geH7d}_ks^?(U-I6N#q3VG=S z*i8kI!#m*sHL#`cnyMpJS4j`XPcrW89NG?MkTuMi%q|6Dx@C2zt02c;b0q z1Orul`|!UPcikRmhrJU{mwVTs(>3h$4UUgBgF`#!+vR0lyL|4^iCvJ}Chg_p06;x%$JNm@Ik~71YJmsGt?;TNk^R+CDcewSRJ9*@N0HApHe_vR|N_ z8IPNgwK48N1?@=RzSOW>`zw060UbR3@H#s93evw4>|QhH&+eGn5j`7kN|d($jQ!02 zz`tTX4gG}{TYu0uJ1{d4b;imTmB`kfwC;bTKM zd^fV}UOchX^|@{71ZwMgNTaq>4{xBo*O28}NVTRjOrHp!h;+YmAw+M?=<}wphOb7h z#@gd;pPJ(BsAlh?6VHGUPj?IO}&gvn+o_;Y78KQe#}6)_W-C3ZXR`?M#18#S~q z*-*oaNPh^Zl!l1yd)%5zGp!1%BG#xWX8lsN<9CqPhxL(uutfits^%xKlqsIR!z(9m zyFCMgK3?Hn6ysCAi7BXL-M%Sy%!Q9nF<=jj+*kh$;({OJUbKj#mdRuBq7B!XY9()sy__x&o1<1mO{ieW3@xcQ=8}=WS zThPcc&D$DI!6~tKLbOz5wJX_LGVsXo9fpw;Kb73shz!YT3i>Dg)||Q9B$|db+qRIR zW0XX<&05KWEk3Y^5QefxG*uQ>hc*2Y7|R@YY$(lf$IRfK0eMIr0&YhArYfNFfrl&v zPDO6R6x?*>;BiWE8(=+ZD1F-$DZ!dj(Hq{PUY@R}x8cuSpW&Bp!=Jl8!!Ox}KX-kG zU%Cx{?)nVB>^}DUAkY*izeL$J=l%H|Rijg0Usa#G3ScRpXJBf4%IhylE%O|8#7icQ zF^ipXf5}<5lktoVR52cKO}(D+v8qvz7rX)2;A89ofU~S%+_yX~H}PRz*()B;CalpD zd^7w@$#y7KnC4(dz%xUY9U5_t_A$=G{+%fY2z8o?3X#o19Dk-LH9%p@-xfa^h4_q9-)_B@Xs70gh)W#`%N z0{un!^Zo;j{Qm^0bt9$IPag>%S<#jyWNnFxwj{*bvQ@45k+uqHt5&qt3E9y^?a?H} z+UixUZn{0({zg}@BRILLHcpp^%OjIWT^#K|>hfUcsw$5hRZgTTfd?yF=}(TGRBQ@S!LjJ`$ab9h`f8kw%rRi$kd5=oi|fxFzk;=cRFep{-2FDihku zjhwA>Q*qPW$QRnigsd^4ZCq9AgRifdth1#vrErp(qOZg|K0Y~ra;f996Aw`At(Hh!`bwXHqtYDZRT3k`4=~Ex3O0- znN$@=+uuJtcQ|fKmNtPiB3L7-DvH*>-#phGr;{a(&s%JMYEsJ;>-m(<96Tq?Hi-a8 z<4vAS4swOGOIUoavxN}3#CJ(Z6W^bl_J?4NZ?~@Dluzvn0TXB92yU0PvU6aGfZgVa zYkCtF0gVUDS=>4H1tx=qTLAAQxL_7X?)Hq??PYLTf=T9cW==K?Ed>g=d;!@ASOEjF z5+tKAOLmU)6Hs6?rx1}6;8&*8fPgH>L|h|K1at<1d}bpexlY|8hKAQgMVd1kna^<^ zbDQZ)xl*nK`zYWOzY$yRas&z_+TUfY?GGWMZF3z zM4*~Thdw|X_;jr%E21W=jS`VN^UKww@V^RS^>Q`BD+lZZc-MdssO4(mb%gZ54`qO= zR&sTIuwtO{-Irw4mO8<)wuMgFFCV#ds!!=up3wr9ZO<|A<+% zKuT72H z$Zdpi*(g^6nx@~IU_VR0N(FXqBWovD##ys~*D!y?_;pfG)s|A{y(T7!zDO zc(M^*mwB~-hhmIX?w#rbfK~1TaCJ<`0pf^pMkp12#|W-1xQtgwzwbCOmJNLcg8>f) z19;lc)04bH=oxsL80iEs)HOchb@X|BJnaE&GO1vhO!&t0*5+mQhEnuQvJPv7x^cfO@$C|&{=EMMiwj`Oddg&( z)z9do&ZMy_cxqLLK|tL`5b$%`@?}8o8crkK>ELmUft?GVi`tU9GC&%27>(0aeYoeN z1NRQZdjOxTd1;CHykltq)pR3WHwGuok)H5%0UZ<)T2rJud?eU~Q5j`e8IecjF*Pc( zBcnZjeDM@&dI=d{TJB0*ytJaa3_dS3(u;JJLgU@x?nrC2HO8#ycIGOcb*ZKa*xsyp%U9dxhglx05GqnJV`Dgq|b9L+rGB*OwWo}ut z008#LgyFNFnLjIfP?R)vE_W=mN!4jk30ziLSoZGA)7Qe+-Z6kQKm#Ee;bJ!zdr;lJ zgcSptcOKkX(VqmQ6Kb!AuSX`Mws#!SNu;j=a{yRL�q4;qL-M43^M@HIY^f_~KxA5KKqUsbd` zyzpCbdd?s3_>ptb8u!efU9u$$4lSFOdwyPk3QmMHYlX!GOB{+%F4e9S9(h#Qi3&Ua zanCRI|9pS4@N8JOu7t+cRTNnDc>FY~-G@y3mg<&UA2dExqXRw2)Dt|tCR4n1@{N;Yy@CX8nz9$EqQX z^ctEM3zqC?_c3HR7Ca#mUX8e8^D#LIWVv`+`@ZxJ8V{Gfoy(gq5Iw3h_yI9&3?7gmTv_ z(tfpM8~)t&8Gh+D{JHBh{IYHMbJu71>H8_f5d0ltm3}FpKFagL%dSKKMhHYQ0fxy9 zPGPX+{{x&NfRyxn);$|vtn5&1E6X1RSwPOoV9SDQT|nW(JR}TdOG}jj zI4_574Y-7m>!AQPqa0ubd~BGK98hMp0pP>tz-3^f?56@6&!^2nfOy4i46dsCp%3b; z2Zr<(2zB832Q;ufs5yD1=c@QF>H^~X0qs@Hu#%aYiI2GpammC7lcN=BY zy9eQtng&a_^p*tebwPW&th#Jn+K4W2u@f}G5{WK2pGy}#e2yc zi57Gsd3-rn0=1~)KG3lnE|$un^Z}Bw;{8+-HcN%DS+ai*MZ#u4Zc;UnwX3ssfy}bU zy8wKvg5hV!hzDHjCR`vVAG)xMn9s0z6X=9eHh?31UXks9SEQaAwSx0c*m-(@r^nd8 z!9^B0_a?bH~)j1VB6eE>wCW9kOjHvos$QTlxBs=skO*a;q_N)ze{ zD~MN_+7x=C)(@vX3fv1UUH$Ak55BXaKa1bw3m=)Q5x8volIBL?DSqGVp_xN5dgjQN z`WkoyH(4W-e^nIf0G|{)w9(1A`p8}EJx<13?@cYz_lA*a&!P(%4}?0_s_J6Szf*i#qO3fojd#Q+<6Wq#6_vFvzOvM@G`ZCE$xT$+5rS(%#2IOO`|Q`s zRFl?q4FJB6v}H(Jmef`xWEFx_m|po{J^ZAD^G}{a1pjZe=JDTyMSyPc{B%~#6E{SdDh3oat;HkDJ^e;9sxx_8BoE_fN#bmz^9_{vXxt@XL#dn1?A>vpu-<>YK6Uv)9? zi4mE-t^3I7Mczd0I^j_zeewx6^lKoGy^iH$0K(BnJ{cmnYh54tLfkPi4i!A@;Pt65 zaZ-=&7?_wE@p--2gQAzH#5wn44D|BVXI|(jh<~)Zo8(eeOFGtneD9?ZzM9 zU50l{Kor!c#>Au)OiV_J#hwPQ5GaJf!}|!a8?O~C?id&!Vdg2}xb4UEszamW%+!ec zFq?!d#;Co+uxQq$G)@1O()}By|97fvT|r4qLF-$UZ&ZGzDhMg?yG@K;I+p0UlyLVW z>p;RcwN6o;l4BBBYw9k^YMRRXOjOS*tMeIJWS5?W%KJ=IMAqd*{{XTMCZ=vdACF18 zNFR@9_3_zS6lDz}t0!^mHpuFd94E56CF>TAUh#V?r7nVxnIiV5r3Q&)y@Hac-YR*c zWL>h4Zh;3bOTlc@Ow)RSklWSz@yPs$m^(wCqkrA>%T%pZq%Fj>R;sb-q?wgL4Nu|_`=|t8e4?7aqdskHOjwnt&7`&|Vb)`(N02MU)kSnBls#ikajA`5Lq#GfZ zvYkoD3AXz^&I$G~VIDJNzrIOZ=<3Lo*_$&r*C|NGy-)DZdK;MOYqf>01W<}tMGEP! xbf%O|E64p1#^11;pA>@T4$Umv3KYLbT9H?(9i{vX`&K1Tom literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_434177.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_434177.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5a0d40d9eb2aa07cc4b38a309771cb7aa0aec25 GIT binary patch literal 13580 zcmdTqX>41^neRQ`r$kaBby>1(owjUQvTVhUVq3QzSxOY2v6V73A0_Lg{7Csw9BxTLTOKQ{og@^v~`$^GJ%K zDa&aJEOsQ$o9p{#=9`)Cn)w%5QF9CChr=^LXrBo}m(hoo)jm@ny7Z`?4MNAR z12G~Sp;-Md>BYQTj5xD(D4(C3!mMGPa^pJXrgh2-)+sN{T|TU3&8sV$USSPud5b!u zWh>bt2Q^kA&L}fxiyeB{6(w)U&g4O>(sf#uty8{fopOv-$(94we0B?4@fLl?06WBx z9+&Y7UBC4a9>{<_xr4~4a2o#}2O+vf@f1(b(iAlYzwF>nt`F2ov+F48RuaM^0UQj* z!P2v;?}bUqOL*xpbBwBI1|H!q>m>r?cEWE$kd98dpR1P%(#xZ}n+4gV%{9JD{4@!Y z%cHw+)Z~%@&}JO|@^a%PCpYezw%Hwx=WX_JoBOHV{*C?a*cqu z#X4zoPtHv2oW5d(BCDGlbvvi5b|>qww(M?cYHfuY!D3|{7iMfz?$KF?l|ao*0s3(c zvj+wIT!$d?Qp=jWK()VZNnMK6r5|)$@15_BtE;@732lC0kN;?J5NXPn)LW2x%deVm zv@W#9)!U&`qYp@Y&v{>Z3^*u3(Q)*p&OU2@DluTs3A)0Z0bi0OARx{ zfQgr3EWD1D%)*>-9j+DYNKg7FAmkfBWCq|EF!JAlmQJUiLo`MaaXg)6qCCh*++yIK<^OI2W3@9_9EsTvN zRPb1M3#4F7o+3A#VlnUqsU0zdsZi`G;fu#KsoK@AJDnWUrjzMJ3I}hVgiFC$z-Il$oq@CcuWM@oopGJ+Iy3PI*{pLcNoB@SujU@}2b9N{tw4zH1* zp5~k*oNdxN$_jEDXPX*v2+|3sjU^|X!>3MA3`2o?cXRzFK{D;;1hva za85Bv_Y3rZpmYHz0UmN8opPm9vs^yl;|v(7YTtcM} zSo{?qG+*C0zb`ZxR>VtpA=NH#$NCip(|hd!cc6E!GqYQxPD!_C%6*JqWHUbL+Z z8QQ!@66T`!4Z*IUWo|rl7?~R)OvH9a5iR)aLezp9+mX55JCG=}yiW&Pf{MAbq32NH z_AnP2ym=`q{md3spq+m$-==Y8cNZ~L;r6tK;igH3ZyK?>!pHLls3nXoaqLU zlbIV>lj{?-y|27<;$^GY@)4Fn^^BmNnwhk^+#p+Bf(oZE*|=#JIa-en;U*HU5XM~@u0_6hRtlidT>K0Dsiseq3* zk$(Z;RNq1VnKL*gW)xY0U6p^94O&W4JPo)d<~Ahc(vo-?pP!O+p5hJK%6UV&Z3)N* zX1osQA!BJzDg4T&)V!HDVR;DLvlLR8kCwTy>Olfu%$E@kYf59;#Fqigrdi-`gtwdd z0>Ih@Sj#pTivoCml183#zKAa;y;V#}d4;DOMpFTJ^Kn@Q9=tAD0W^g4{K*J5^A_Il zOmJyh5V%UdV#A({sSy*1Eg)q#XUS@!)8JZu%UY=|YApnAD^H8HtJ(-F#Z>3jc-#cK zG7DOhs_|5T4oh%90j*E9;j^LhQZ=4xUiv+tYceR5sV^qigwN*JN>Ayc0`zPD3VW}PFLbH+I^IZVQa)ZnK@Iowj8R}2mlgP}Ptd-cufzLx z9B7O!(?#X2q^{|y-vC#2Hnm`!tGHB)>eQG$+cv;d^Hfao3{2ZMz*O~AOmqgOhU}KG z&#*>Dc9UlZ-@s~#ZDz_MA!IXF5i{|ukm62bPltSpuIAYZy=$;N_kTEsMtop7{m z_=I?N@r`TpO&KeZDIb#5;OqHad~JFM)Mm8Ky)V^#A;+-H6bF{)I)qwI14V%BjuyueVKMn zjM}eocxAal9ATO82#ch8ngG=26SIv2sY3};J3D$&kW7wFaiusfBNh^e4H`j4OgRp( zFfrmdtT71^*69{xF1M2dD-RD0#x&uW1pkoLeML~)+|J2SyVb_B^{s-;MhraJIq_r1 zAFcx9+l-^McFT2Q%&Mf>XPaD2j;E97GLm5ACEMCp+uElDU9!}1(KcZnUz?rEr7v?m zcpSYrs=-k$9tSZi$3Z=}o||ym$6a6{uBN91H*J|f#gO$NkabvYl zkpNV!a|iv(Wupn_po?XlF))O5mA~q^ab)30*arBj4j|nDYzVJkQ}_XMT{W)?RfdM* zMa@Xp?CrsbPQKqBbO(Fqtf4EYa2L|;itLS6{g#f_-Z^;pRkV8m=>|l+!|xAO5uO>El;;U1qV}qx^Y>W+!q8CSzan#$FFq8x@{?ZmM2#w7DU8E9K?2dR)c^5Kt z-;>{Se|6&Iy-Q!_#my;L(J6pG55*ro!R^94sYi}m{C(; zB*-u68hrG>X$%2dumPEC!nJX8(}R5DHN~7FC=Z-g~w$`cU&ru~TR67oYXL_|Q^(?b6((;KgvyubU!lr1j=Y(UE)f zr{k#MNNn(x`<7$A!->MuH2p82!n$xfD%j!cPFN~JRdav0WZnj(UTnVIdb2fd-o0er zi_Ck)*~ZNszTQM}WoU4IQ@9Eh?Od`nBTMtg&7ZX1X^mT+U$PuPmIJ@xKA-(`Hg0(l z_f!^~MTPa@v#8)X=mGX+2{LaFzkv$(`+6S2uB(}=fi2!0+8sC6`Z^z&i>{5#jRc27 zS3Wi0X}>#&YI{+6@4eyJiC4cmI}|%H^rbWQx(yB3V(fTq%86Jf8gRxg&8|>XKl3W{ zEk#+*Fuq>ApD|`js3yEMZrthXTvqA>4WVLW+=i6feDVZj+rUF&stT)+u?Z=gaG5En z4b!NoKKvT8>_^J|K21Vt4Dg`~$Wj;XN9KJg4OK$>Xl>ndQsPL z^xSb|I1ZyS7QbH^JRYn?#x0>i=qTJ8Y5B|$X+_(gkJ`}ogUE2mC&$hsrC%8^2Tj+_ z^X8EJhGs#tw0SSuyf`;jCZt%ma;`8j4w8N2gwo*;#aoMF#e}3|;5f z1%`w6u>6n8M&gUIr=1tTJ+Y`KZXE{Vx212(-qzy_sxc_e4ju)#Ae8V^v5ZH}N`9>O zNqsV(eptdv-%@%semyH8ZjKf=#z;ag_l#ffm-=PHl8mb^Fu<|{j3I|>15D*4ZjW}V zYJ+}oFAT5P7!xmN8Da#=vGd2HbEn0B3y!M2Dp9Dg4>TH ztWDdKbp#CWQ(X8MpO2qOHgbddX&x?U@eWEkcrwfH%Eil4ZUErHl`QS=KHi?85M#a6g*ErhUawcwos?}8w#jp>?ZE8zq!yKZuj*>v z*+%qE*8xa~jo-tFc0u&tUg`t$_0suLxH6E2e}L-R;>9mKD6YsZN)&JYTE-M_A$;a0 z!@UTwf}F+ng@arZaYvv;bb-rPr?l#WUe&4VX|NVXC+k%r5*gSGc2Gaz2HAaum>8T7 zqrzICAQx44&W|$^*UV&n9_Pa8e0*hm4lb1^r#Xkq1vfKtQ9T6+AZF$6(tshSc}veuJ-x+KxjXp# z(_fsvuO0By50mwVU^}QoBN5vt=kJ`quRZ9amv#Dp(SI^JUC0(MstecL z-f?qBl#XxPpQvt1{nsQm*Cr~e|6Hn*)vlCLhN588ds<&pOu92vd82xvI@}(r{;&bE zWXB2gWkJ1aO4&1y7(TjKbUYr0cwG?(2XHJYV4!U}fG?I#9Z3%sk6?0ShZOsTw;7VL zOc9%G$kHscC`n_1cNb&e>B|mTtz>YQA;sP}@K1qI^2m9qhT`Qo&C0NUA=iB__UESi zQQ-F-#Gm&apmMR})ckUZ9(UTem+ReCioJT2yfQa+S6jH#EP5=5j}<77Jr-GEjo{5y zuT0@*RrpRJW3;(^SUF!5fsjb~$5p(N_|sHBl3>=p0X=f`kQ$8#-h05SGm8EGnDE;X z4fu{>@6N*B^Qd@LE=uzZD<=vK!z%Gtq(^$MsCfAp_TFTLjMpr3Ua>(71&Cs3MZ}hN z;L-m;I9|_)tXVPca7+VJ9}zp)c_^7!J}3%@ekdnGR}>R(<1D6neI7WC&fCCUB)vH5 zxFpEPg8)Hl1J@IdZI=a^9c0!txSgCcQ>^H5`Y{FvVo-$H3sUS|5o90{b~h8*At+oI zW*m-LN4=K&DQ@;Nh~TE48v&mN?o58!Iy{XZK3%rrj_?QfBsxx{#ASkCNVRZ{1du#W za$IzPmxm;sEbGGFej*-{QPPvK=1CK`2#x*$elDyXfLzEoV5=5xq?M|xa=+Zi1@_04 z<=*y$PLHp=-+2}6cde0#gO0G&*B96B@OI}Y0n_<}|Fs~CK$3KWC&AYP`GMduq}{xv zsYROFFdc4=Yj%3Omh}aJ8UIvBhIG|o3DRv}(l#S)Ggw^b;@W*nnnOr)D6Z*P(sUzD z_wO9PANyh~t{DV4P2Sag{(XU0-g&{>^+0F5de(n7XdqS+{fpvI`;DH3o^aELMGQR{DQe-feV3SbE<3FIqe50uAiMh8_%zYiRYOD?SU?+ z>+;WD_4qx>y1aVt;e^f{tP9!ZcOYFI_G6expT(i3`LfUz=!cJ7K()`s@ne7l$VB`Q zVoO}R2?zjVZS%JUnYgCZ+m!%M!k%DNOj90{mJ{JYR*|~_>%fH}vL}x_B*m9zM0mUl z>|;IQL7wYKJ;1WzQ69cJkw4u&fon(kUu6R^X;y@78xRHAA;qdalsj#w@X{39>7rg+ zhS4cA>SP&pF~UA5tW%3mzu-9-?8}^NE-hfk(gI+0EGeA8<}$HK$pzQ<#R*T|L@M*- zu^K#rQ-gh`_wdKekG;e{ZF`VQh;S3AIRX9gn#0FOPDCA&ht1;_~1K`NI$dyu;GgzU>2{j@j^0brt1e*R1H-BE_gL#Elc|ZT=VQ;&4CZW^?Xuk@c^5hj< zeZ&6-)>wZvbHlUXiJtuYwNGEWpV#L-ybRpnjQ`9#L+}s?`wjp8w2yvZFb8ISUgm30 zn2Ny%I~c5*YYSB&Q_T{%tBl(g_uPKr<_mG-K3^xe?m&6I5bT;80%uTR4S3-Sb}pHk zk*WFPJ)gXA=Y_bb-PeV+<@R|Q_}YrL`;H_mCBfsMlD`^iK${w|D|k=T5IqttL%VyB zxd%LYMP)&Es6BKgbQD!Ip`xb9iD*-FG}?yt97dMI;Kr-0yHT=GveJUAz3JiUv-STxFWh`HjPxTAX$mPiabvD_HN4fz%numv^-vOU2 z=dE_W4Z8L7Cdu9qWvrh7sjadI=Q3UyM$CL}GbQ`)5*9vnJ zLMTBalGr-poM3NKqOiwuQBbx`I@y^C#{upSkOihgxg@a7RwOh{|B2H6JEi*%s%%A0 zG0L|~-YofskBP=x4WVt{ouR>y8<9 zRJ*Mv&a^Y!v>jojYH9|ghEXNh(MZeAKik#*vAH!W*4?axG)Sxb7X~Wr%wPMRbN#q+ z&7<8pkaN%D`_A(_-}%1Z@qZYNdIG|oKmGnJT}u#uL799ouf*}}koEVh8 z?5 zDoBi4%8WzzgmV28$_-B_r^}baIE?SA26azZ9aH(yGv@N;Ln_81o0Xn%G1l_^GPa8)5nTX0CEUpw&n zt^wb%eS^Dw2iN4*)(wdC-lMPo%0#?~pQjk`{-=l~c32B7P*NuYE zmIj8ro~w^qpFyc8P}*das?rt-f>Uzz^;}<3`!i@&oUuUb;LMz37DgR)a#qec3vEPQ zoQpH#GIvCq?~Ai@?pZiqoVmbT!MTB_Vpc;Cvp^QD#4SJ$%1}EVu9BlsJF@Yhl`07g zSsXiVr;77%#^>TI%medPb2iTOTzsbd957!EM{rf3NQhr37-C3K3FG55s2#2bm7?-o zsTP$Y9?VzEkrIWBkHUP_C49OyeCuTEQAS~LG-n)AL<*}^PI=u>=4TMCid><_+JQw^ZKvy>(QMfh2GZoEm*+ zLGGhXT%(lZN>@o$plO!Pqm0VrYPt1XVK2BQMs-7d9i2Z2(UxcMt^;1YOK2us^BUe( zu1>~n!ws&9xE8K8zh<*W=&kO<->QfoDd*lFDd!tF1?Mf1SAAZ>xD6$;XUNw?H-3*x z+MQC4;y!HRER15Tk097e*}6E^ED6qV8@W1e(+y327f518e(GK2IY!Y@D+w~W_mosY zHm?s$`|@fJu^?c?3ev3CNKws1W~PGwlW*XfLR3xo!(*a46geAYuaT??s=Ovq5g8Zt zQ*3CM^-uUl7*Q2qM$U`MiIGWB<7fSo!$DC!8)QRaQ9Cpq@<+CA5!F($3bKr-WI_>9 z?~jBgMgl%R!-(o|B*X@x-B6CBW1F{4RQr8X5mr>6kv;+ZArFe=A(0#qb>Rp*!UTQ( ziDGiBkUS$xP8Lc7McTk5i)5cg2cCtqO|YPX2^MK?Q8@*ZQbXjM6?K!-6F!s#)_ihO zwB(Z`lfGd#G(9C!_=8r%Ff_u3BRL*gDh`j#2B8JqX6^%P6&e}}L#g!5=f_|T(M-dl z22VUJ8vKF4^u+WyEEp@2Vlq)5}Nb{La=w8+d8_oY=;`rC=vO_g6w2)ob87pq5TMp zAfGYfv4=1i=h_l&JekpZV!M`U)7-(tLB2jiSI2fgu-GnZ=e0b^_hc-sv3<)W`K-Gk zH6pm%msF0d*|y|r7R=4D{;bWzpB8N1*ub*gm7Ga+3+~N}bjJQt>_C=w^IKEZLRFhU zZ;tKG8ZArq2Eo{nrUkk)w)?)(mOL#O>py?(i+xx3rTvT5nZ}m|>sn2 z&(EKSp<6%V1garnSalMnovRgu*^wL~j0MKHY*+nY7npHu^D zaY9;l!Px{OadrsCj>WyV`~F;id#}*hcV|ZEJS;7|$@wck???^{rrMMh`bk|#NB(dm zeL-mHx?L}{bPL9w*q-}ZQ_}kLe_E#WbH;=*ImC}GYQCkq9zdQ-P)QeGb)|Np_J-n1 z{WX22a_jBpZ>io#AmW5bUKDlbF8JUyjRhG|6At*tgJBkFe^J8*Bh&0;KyH@Up#q!D zzXY}Yp7uTJGQkm}sHKu{8qUu9+Ng>n%hfU@xtH>RL~Nb2&JKD z{y(8!Q@G^Fr}HyPPW^O>C)5{uV6^a$Vib(#hVFXKp&ccs zm=tIIQ(%X{of$YQs^LP4%;2qHj{+~gVBn}>6T+)w_oK1_6dgbjI(uv*L?SuE9!Htt zGN3`^F3VS6l@_D&iOayg4b(5eU-*v@{VhhUY6#jm*OlnH)EDcGoy$_zIdj6ylLF=8 zj|fz4svZtAvHn);&DOhAcdYk2l}fFDOcQEbvQbdE zvy?8rBk4zYg=75y*e{=)Kgs()ItBgQqbdZdg12U`|rwHtlv2e zY7Hn?+8FD{>o4ky-oL0X`u73(c}0?sB<0a#Fr}U)(#2C9H9lRKnSxU)nQF2zXJ7X@oMQgx?a(d1R#=Wwr*G7>zlmwYEUdB<) z>#3+5oG62|akc`b9GXBTcu9~^IyeVsM}PyPo&~)F@Pyq!06>A#F3$Ok(Ux%<3uAC| zuKx#j1%M2AT)DBx9iws%VN*Gb|U zz;c996*-D$3!N%i?`2OS@as$Xi{p7denSa=0m^;O{Ll6Y3IJ=gF@Jt@C$dSle;!~d z2>?b`aZNWIdF!48jqssC1@P4j;dM$jFt*lMbg{9>K3G9`>#)?${a4t-sQh&lp(}C) z$@mH~8&(j*3c83^kV&_K)cy+62`fl7ufWO>E7cIOh{S7Uk@?M{)5ph}y@dFSh=?&7u~N3=BDG83gA5 z5~Y`1Go038fD0!n1O~C83tKVzv& z?G!9+v3*&C10aN%`I(e8wLjx-7YyyOzWWAqvNiEatS@VJ@pgV@VVz)ZO?NJOg^k^U zxjWXstf%K%5-rKrlp}rYuD;`*zEjY5es%QL;LX8|{?*v72PVthP+}-~?1Qn`-eq|v zYjh+p2*%n}k3hcwKoV__Z%+mVLlyL9b0d)Evn^M4F6_+Mnqvc5tMjsZ-u;>Fif6%- zu{Or`-!}tBk%*?$f~g5w1_%rshh!IjZ2qOs*MHG^wKZed2rZdx$zzGp&yHR>zHmI% zm#+WRmktXxTLk0Q*q&vhdG18wMDjRPpIkVZF*XBI;jG~Ir9DFR*2Q;(s@;NfckEEs zR?WX7*xJxAJp9{&bz^M*1C{2z{!9JwQ~bues^;ZfZjkrfRW;t%T9Vr{+N!MG!yilS zSnvt=V>>!?X- z1V{4?+n1HsDi=EzM>F1D!O?r?q+mZ7?^^~4t7X20Z%rLfj|fd&f~V^?^L4{r+o5~5 zBZBS7-J@^ZJ9=6;dOBl!E50jRQN<507}HLG@Z$Z;cIV};`7VBYDzNbKjV)hpzqUPN z?}Rp;ZtM=trjIQAY^k>WEB4mKn-}joUybisc6#^@el*>?5W45woFua@1YZR|J9A}Z zVI)naJ3b8wuC}D|zO#~Vn13^6pZ6_0T$lIE@8SDW_SCtn?uCQtT|#yHVtvNZ`A|tZ zUwrVCdbPs&;$sZ~<0;d^`i$enhgzj`+hgi!6gm}9=$}SmPy+@1Gzy~zC`?bIFl(LL zR+WU+ojmuEZdFaV+QA{LSO;^Sxw=kpzW`9B|K`A*9m3WFg8M+cKkI2o4GEr2i)}Si)(y>2)sjQ0xey!~J_~OFMU*^{Gk4%%69l?5urv8K zlop__f{BS~aCAMFWXQ2r&kH5tr+fx~`T8Qi>goK9mLqRa z*K>dcfOdG~LrOrzbcfggj2l0QNtVeG$fBY`K%$8OJG6HQuuG5$c=`rr^3XMepy`n% zfMq$%o(GyR;>bZ}7)P3|T62}HMaixq4%9xNWcO)DmUfi`z*uzWc@;$~!E1mAwUG(- zO(@3}4?Bobr%)6^5eB@l0i?$;m~VUpu26J+vZo;e`0Ljw{TmcrKv4-0b{v&`f+9Sp zG60xy&8*np0xfDS+yF=QZ!scku+HsF?EIjI@A$@0y`-vMGE`?P*IlVzs1{VNSbzNN zvel6sOjTs88)Eyj277W;Fw~`bzcDl|shXAyO<8LtKPXsR;7+nQFB|3!S)Dbxoj*7K z(h^m_Y_iXdCC0L5=Ug-qeXLbkHL-o58lm`9;uNCXxkS|>s)@uzk;-Dvg}_fG#-M*KLq@40`VV-3no;Jx z2%tFuG&(>ow45%lNRde&)q|4Oa`1kDa|Tc}#T<=kfr?XN#f`53R?PuIkAG|+?@yUm zY0%d?XxGT;8MXAYnOTr)@bF<35wB+-K5*<3}T;gLIg{-KP zE=*)%E5SWr-$yj>qUaZ>nj8_8;0&-y$b^wxBg1ISe{l>f@YpLXq%)`;dhTCsWg2m>pN!6$PSDOTP zD_CN#_M|dfRhx1ybf+Egj+>;iE$vB7GR*Hv)i3N#a^F}QA8H9#?L#Bs@Pb8RS%;8r zTkN287Ln1B56I`lM#y7Gz;7u@xSR`Tt(=Jg@=L-?m;A>_s(fR`FZ>uO*Gp9OjP}4N z`9bZeix)em<`7z)JJF|M98pbF8>OPUNCCNqQ;(#Z;*{5OMo|=<3jGKfHE;%yY|WvS z7z`Zz=V`_dR}a2W08)l~58oQ1Arb@^8NBPDuLGWC1`#p(#RoHq9dCRg;HMBTExn{P z0tJ3ml5U+|lzk0zfg9?BA1$2mN$x9s#)qBIsSs3%q)#;Fem@YJj!aEQL@hX-)8mnF z7+uM7DbIP1(!fG28w^D9Z&U0d&}QH-JOdHjwN?jQENfm|wteNiRdciEF7-;R_bD~a zTNiupY`EKeP*63*cE`Jt2U7LO*KR z!32}xNbuqmn|s`PBRCm`iN6gk;LBcAq9s&Erl-b((vEzNs?cJH25A|6!=Z78T|mYA zP_!4-Q9TnOW_mpMGKG!xx+^Q0# z>zclxuhLR_TlcN;o8wY?;P8L;{Cf{h-&gD3JALW&BQ?;kZX*&ZsG>B^>sO{)YX7kA;$ou)@9_pqO|kT(9sD?JN6s{|w$ zowurQR^#NI;J=Ok>o`s>{WKtbC&Tw9XNAexM~L>(77uBE*h82bGN#5$2E69~2cYn> A)Bpeg literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_48845.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_48845.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a14c2dc9fd02010dcdbfc2164a11ceb9de224cd GIT binary patch literal 10510 zcmb_CYiv_jn)mv7{Y-2pF$oDIkU#=?G>{P5LP7`-AP@*JghE=^@pXdpa&oUjz#LD_ zs1s*x}Ssi7LFL8{bj|13MZ()y0kB)ZE=NYg+3GXv5-{_K9|T-$Nd z8&W!Zq?~&m-}jyGob#RU<@}4uq^F>q8~O5L*LsTj3+`k}e!B7EiISpjP%OpLBh-26 zOP{AnUoomYuc9dh<)()Sr3|w4CF92BW4OmESS71s)vSiq{!Yy@EOTFX4_3kD6*%gj zfun?#UVk0+YvSmC#(wm$#_RHiaaQ4NP&-JmhA-&*#(N6sB;iiXQY*J+O>Wgt=5BK) zSo135d8>@)uQG0tA0N=O)(=(p(ko`Z=XE-Y-N@S9)KEbVGwCK^DEX|ZHJ}u>!o>=Jw-VA zGFUB-EoV1;NT1J#7nYwrf6EoRrs5es@9t+b)X(iT+kBfA-NOu9C*5w@=@C@>za@c)^)4@E46%1#<6;>%6=FGJYs!YK^Sc$yYr^O1( zhqpt@Q2Ie+`uNMz_tNu$b!$K>vXI3&ka$BY3Q{9m=J5qe)_~QZf-LO%HDI-#}Dnjc?YXe$`HuzzefsJS*Dk5X$J~#!bpb}K>htoqvSt=E%9B?Z98jA7* zP+$|G06iGPbT*?+$cE|UaE1?jMPccs9TPg0XfrDKNqE`a6TB)^j0%4eUiKsi-WF7e z)WewLt*8oVNUIuAs1nQ*^^t<51{CbQ7tMX+^5eM?=**={GBd8<2qD_wB%&#b1-F02&zF_SDbtO@-y?pn`71YH5p?c zXh3z#F@D;K2i{nN6B@mNcZ$gROZSPYW%!|U_-S0p53-W?umkNtW?3o7%Sb_BC&1*( z!KC+O$wD@*!MzggB;4;{X^g|t!_`o0bNFmtgMtb*$xz7M>8_PK1uo(Jv*e`&TF?q< zUpBiNHKGz@mIE%=3gz0r^zoopv>oleug|dRz`C%ht4sYNun|$BF8G>Ej|FO+Qb*+J65yZg3VL!dY>J zB~#S#KF-6s9nR6E=5V&zpVJ)6j=Gj$U1Op?O`#tSVyQhQ8ZzVl@!{MKU<`8J$#Dno zfeo+b*@((C+Uz3doHgAI$c6GLDk_r zM?_V>V*sWJKX4K01~5Vxa)}yp?>yfIw_EFy-Zj`EjSq-Eh(A8g}|FddKR( zc^&gL@SLl`=N=tzfL;S^sL3(v^o>rA)Q?|rz>ovP%I6((xV#`-O}lqDHN%8xl3*Oe zZf?vy!o3eX;-|$whgTM)UcN;cOw)DYx+tC0mjpYW7|qki!^fi)Nn>g7;1BuM8`>Fd zl#cF6=2r)gJXxc^Xeo=H^Q?)YEkfRw;NeAUNwiO})&#qr z*oq=oVs8uf`dMSrwl{ce(P)qEjFk%O>jYzc@Zh2;KVjP}m^R0ag0U%h@Ttif=@U#9 zzu$l7$n7I>=WJ=RvQ;p(1`jQ+K4Z2;Mg;Sw-=Dhk&h2;N6SK9+sy4yg7VKO!*&@#H zd5oNHVEO}}_rz4u-WhLPCloc#cFrC6zGALZXgV-|MQA!H6rB`IJ<@tppHIXpqf>%? zYrIjg*Ukd`n{$WeoBn8-Z+Up)kw@6mD-`t#rgOnVKj@9q)#2*Mu~==w*6;`J8~+3U zWBq-JgSZ0B9a z+@*Q;kL~k=4?hsJ|{%looi(uV`7cGufenCgK!ood*ZA-jDux^K8t3BExSU1I7f~6+T&ko-G z<-AR(KOp2Egrm(bj4EdI(;tLCh)&#^x;d51+Zs|mWelHk)0e`RBF<>RCm#sRhLCcR zF-NvVuL#y^fvNsjld7T2^{HAaui(cb_dPw|G_Mt^4h!bPAbWCBIj{QOnFC_7ebyfF zMLNTd=%pC@WqWM!<_ChQUJ|}uShcWPFjdD}X5SIEr(c2A9I5yfdcx?ZnJ^P6irQkk z{>*Iq0s5w}DKZds#WjCs8eYK5r$qX)sGXQ{dj>E1L=Eq9jzDCrnINW+Yry~-&h4Az z#$4b`WeawSNT_=bZ1i>Jy6y&rsG)4Yn}wq?aXzm&O8trk7;Mv5ls*{ympSOZ^vehFW> z2EP3K629tQ+F7S5q&W>~pKjv)-Bn1Qmxgs;^0?ftTrtC#y8>i3Lz;broz(2#&VfXd_N z`Cmcm31gl%gbfkU647qKzCFp*-oNzK)w@^c8t2&gR-yG&vf;GAoL*#%p?AW0&(+Fe z4d(=0;$YV>qRGR9p8zWkPI@Q7LEbaf0W<}1TWj=O zKOS@8ip7;1S43~n4pHfy^ogoLaBy>WkKba(FWR?HI)vg?x6WMx@OR;l|4*p?J4mI> zl+iTZ7H)fNC{3tJ6Nb`7gK@eg+;Z(eust{ds(9KMHb#2FdBOH4`n>7w;q8%$Xhmco z)_$k+cIRySH-{e_eyrac?0BhC8a2TKFBwX0jZ_LM`y!(YwM06wHrL>GW5`KNjWaw1 zU7|}ZpBN%NH_pwRi#mUEVUfw3HiylTi6m1TZMoHYvo&6kTwlA$7^d}MeWdXdQ-r_a zpYg{keiOiaVdd*gGV7vSZq?nai_^)HZL1DjU)uC)O{$18<^_+E+Y+^smgZf#f{Gsi z(Zq46SJT42EXI&b0w<%N0t(Q;GP)J(=YTSxLMlX)5S^q0cjNMN-4<+2^~<_e&BZ* zjOxcQticcGoq?2-CZt3v)`+>y-jNPzSAzs;&4C<(-xpVwj3j`>F#iu!7^HjEDhO7A zgs=*XhDl0p_$#{0M<;n-dB3|H(h(ls(?8-a@4r&M6d6DJUns;3o&m$6WIYga4!hms z&`9Sy>i2Otr^4ZcD0dN8_}xe;6Ejzd+vB*Zz?By&QFY1Vp5iv(rgGGI84?}d$uS7_ zU0__sMfD_Do!wkHo>oJ1=MJLxMg5YA0JT2_(`h|FMKT#Ary;7Qa7G1c=Vi$r5HcEw zV29l-pR4-I%O%xsT4wCH>p=G^{P7P!YJwD2v7k$T+k9j1%-$qZ5o}+s?G46|KinGa z`IBKoLe-Hd??^&z*nriwCaei@k++h}1_`M-+#G?JF4UYfRK=Wvp*nZ~(~=uDTO((} z!_hMkV}hG)wuXKYIU9a&!MI5^JQ(%gvHl&z;`e zz44yg=Vu#lJH9>mUDv~|#L1q8ljnq!=N_Bh0X9_n(8OmJu<{2Y6Yysb`=cE~-bO)H zzWOADK;|kzRfU5Eo3vuI9o9RPG;Ll!R&w)1(gaawi(qO=n%;y-vu%1PJQTG60-iM2 zE|^;cb4$|vCPd^0%Yva)Fq9?@8)BNcI_^stb|zFiCAR|KxhvPuWQ|k|E){;tr@@#| zLjbk15L2TLfI3Gi>}i2JzNCV4DW(zONol}smN8mDUnUd)r$0>vlzY|`p{$F3fO@x& zC@KoFxnRy9>kJ$R^hg&lAU&)kYm!jbc*&U?q|YdD;yWt-MjEhl#2^DpA=O7pN#PDC z_Q2?;??J|yk)tku)~p))u&jn~mE8d-CvCT+L5MEF$Io5C6}HckR$+)5k+jML9<}3& zIER~X3qd97+?U6lW30o;*BB*j!Q#1cToM0Jwcq38MMc*{PBxU89FtUqNcW2>?EQ*b z&O1eNbNC0VjNgdAlKL=67{pWNA^`Gh_~ZW-Do`Enp4cRCu>_5Y7YywhaVLrYrBQ2?nKYog`NvS&xJ(adx`g*$)5g&o)MvE zBr!UX?BRqeF2qFak+NuIr0feiVX62Nvni#e%*82&G8Rnl4eyQGZFuk~GA)xjzBQJgEE1Y*egTgC2N% zuSHLt^MkM&DVKh1>;pamnt(RI1a!Wv4+-9;23t&|yqAvr1MuG`NbjNHDH=dbHQgMZ zfK2`*>#R{o@`oi6lKw=%3_dsHnq5xb?P#n~6LBRueHPNmB*)*lir2O+LG7OQw^qzdaHUs^b)oiBalPZpdH9#0u5T~Q>EQmgVb!Om1ZWwZsm ziM+u%;sdTxa#+7Uqrk!7euoFF($gP?wwNdFxqbwGLYaijw8&&y9Co-w#yRQp`Xy^7up@I1{`NR0NlfC(pRYcrPa^@nPbS55mvM@e#LlLHJ#hlMxNlnK}l& zBkWyDGB@ckdeqBKj<{R7H~@pvruaLcu~JH!rvE}2|D7`Z2UVKVPzp=1==#QM8=o?T zAq|emVimKe9@!G_y`SKCp~#n_sQro#1$<5@PAcGW>3S54KO5Hi8gyR=_I3Du9lEcE zP318r&0npQZr#g5gF=y7M=6-=CD%$)ibFL0Myi0a72c?wsZAAmoP*Muxi(U;0Q;h{gtk53L*;Ewnk%mvh#>zT`PlY5 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_490790.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_490790.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ff14674c810652d38ecc0672550755cfb0b971 GIT binary patch literal 10687 zcmb_CTTmOvmNR;eUI>YY%*%kmHrU1(jE%v7gCE$5ZDRZmCZ&}%;vqnSnZd-Qm6B7r z70FJO#LeA{oV`^R+0?ZhQyaXMs^HyQWp94g`}ojA)I~Mx-E!6SkN1mQ`(t-McK7s* z1_>gJZ+4r~O!w(?PM_{Rea`7K{FBL~C!iEtzaIRzVuJVsZpg)ys=WNIf*@uHnxM%M zqE~v8y`dlA2s9cWxGV&N^`OI`G_e;Fj$0em!mdM0Gzk zu>9`T=?Ho=o#!Il`B`MjUk6^W4!m$3xXoptH)ZeZH_}C)sC$jCkpiguRdg|6nM7$z zvU8$;hE)9&T?%}2=yJO36SCI~`ej~Oo80SUZN&?`EZr~s@O<$Weo*kJRSCw!r`|ns z(thsk-edODt6JTwS|XKz{LtWT%U9>W9*9xJOV)i zic%m?3uL!Io)pwh`xwUnMC>*DO3*FNQ9yo?4&87+Qb2^k$c z1ns#qr`|qqKk3{AExF(xd^leA1L6$9!Dv!(!?9djXmU<*KZW7jv;nR%1Ggvx7teW3 zL28tDKQ*4OcmrGw%3r}PL4~MfN=*>HQnU$`P5}>J87f1CB5X5!+bL2!Hc=3&}3S6%m=A4N8c7yiSplwdSx~u zwML|_LPe+mQzP3O@hFh+p5mvl!0eNdtZRYGzwK$(0Xh%x+orpkN zz(EkdP|!rbq6n`#h1W@3!5z2<_oaGta*uCUhGbccF4~2*uHm|KHLoeqeOf)Ik&U{q z4v|P4FD(hJ#oB}P@~Eh4H4kKvdwjb$kY_g%$4~P{+g9@uG$ika>YLF-9VPc zH(XwfKC}qPF?=p?b7$^%_rsJ*9PcUwr$stX7(~88xBi4D17Fzq=qkN4X}u zA9Ub#y`p8?i+5pz^(vJ4rr#^;)q-j=RxHy;q>xeF&*zR> z@mao#_<6S3*ZT9UvhN^Lq1M$qjdEqJM+aA*N@S6Hd~K*zvcYR;4@&lUHJzw6?Z`$e zhmV5K-k-4>#p(T$Yp4y?xXa{L`kiq)UbEY%eNHc1F_O5R>Ekx1dC9f^-?3^9-3*vW zqLvg?jt_>|7dWDThMMNCkGTfqK_M;c9*F=l6nU?8hn*d>4@(g!ZUHA8jN+aWllO@;T(2wgI-S^Tn{{4J}hV<13iLCLb4CL7>{d&p+QC*`LkURq4|kdjfBAzsteXddKUHN{v%J# zmXI&#i&ie0tNkaQ=UQj9)7l6bX-nkR_}_lIrk%7^MTdA>eO#5FG+X0^+j#Rf|B0lv zIMT;kYyBsm<`ssoM-TG0-SfsoUc3KP(rAn9jh6Byb-Z!6|5(zL8_%oaO;s@?Z*1}( zdzOwjSbh4n1z*O-KD*$ql+q1rZ}}s*X5*d~+{v=?!S)rPV|)e%%@4 zV#jCg^OFnQ_dN@feDiUB>+vV-V(z(s`k8f8r2FfJ7#(ZoOZU!?^Q8wC-r-BScx%^_ z!uXlI!I|TJmOW)3%5$I{EeZe06M+-`>2i?Q#D1-p_($9v8$TFwPD%NhbiZ{;m@0d*3_+4rR-dNgXg-4TNk_Vd>L3x$tW-#Z>_AC^C9 z+B@IBaQ^#-g-Z{- zeATfh=N3&pOQ!R@>HOk_OG_8*`~`c$^k)IZ(+w$m7ESL46wgx2A0v57Z9s$N>-w@G zN=KV-zsp;9##r9c2;hQ}$k8t=Bgd!h(Q&?@Cf3R4*9CORyn;y4^r5H|^jDnQ%&zHO zpYM(~-)*05|I0(XwQ>F`Z#l4(+s5a%{qF3e-iN)3+!L@Zl;JZbbS-!-?3x*z9{de+ zXX^G;qHss-Si;=6WZuV{_kDZyJNE;3!raMIoi~-s4TSmNauZ?B&#LkM{et*}KB7Xl_-h390Vy-PxIc@l&C#s5E;>}xQ zw)uv8CG+&(bj}a(H66UU!++wLHYaTP_`|1^K4c1-!h=!eUHz3CTiz5_HcOa2r<4OipEjTi02f;I3FV{GS9?T>S8W0= zy$ZM{YVlYw1Hz5Nn>B@2%VGOE{hF12as&G8{b_o|2K3qc)AY&>=(G2y>9rfsXYWtb ztL~>Rxg>$K=ZIQu_Y15tV!(S(!Qj|m(9o`HLr#}7^F-i1#*{B{$p7bxjgogk5RggC2Hp#re^e`r z!2AHMcmrYZOloxy?hiqv$__dp!q-Dul^YtE@J>Kb3BkT|P*6j-#a(A=F_{NHaQFm) zqZrQD8kl|z!}k_u5I=_S16w*N>{pyT0P*ewH#WfussXqGWUPh(uvS9Fx89OgrD)x? zU=ZNj;m4i=MfguY0lwO33hfK-yVV?a+-!@hY=O3C`lR$LLI@5grOud{)3?+uBwO|Dv}0cs5#htt6c_&cAPtf4ExE0M}bPo(gxiMzg8 z-$Ku$%MULvn!5a5uM|q7#((q`MX0UeYF=dn%0P4YSfn+24Obfi4$KQX7RxO|xM$|_ z^yP@-^D9Zp9LfpigvS$9QKb1!2LvFMiIQDO%7n)a_uMi^3htENE{!_=qC81im#88h z5GxZ@X~ey>xt`x#KTjq$Hzq+tut#$ws7;Y=cj|7}#mGeQ4xo5t(Wy1d`GnExKP~En zGxHQ++bVjbC_F`P1c0^02WP~MvDOqw3M*EhH6Qk8Yp`%r^aM(u3hXI~Q%g*zSc4w$ z3pim-is?iWpFN+NgK%9c?iC9!Lm>H+#UHq_imfU54C6nr-AX)0e>#y1>1&^I-8qp+ zwc;c2eB^I<8T{ePn`%xiTMVBDuvSvTX|y$1TBJo9EQybzHHbn~Hj65IKw3;-!jVeQ5vWy9PQ~Fam=M48#lKMJCG%M$ z6<5J$K*GgC!)H41R%6U|OsED|wzrX~g1Y2CwnDAeEcp$r$b#O>3=Kfq0UVI%G{meC zKasRp@h#E+4mQb}peDeCTO(JeI}+3uf9GoFprrz>AAbY{;oj-Z z(MsM@71Q&%jq|;H?xBGCsnHTT5j?SEEa#2oOU5eRSQYJzk+FuDW4>U)@}MZb{b0h_ z27bk8fk(z@VDBqoEbwbx!LBe|Z~|Qk<90|BWPw=1$AYJ!tFqdEEU-0H8?3#xGot48 zWxQ&0Qg0471vmQ-Kh@=gN`fU}SNI}KmHuk6Vxm# z32X5)ogrik+QOEwd%8SYnb1}H4=1z143=mRYg)9GYy;PTrn-sz>WdCq?b=% zN7r*)qT5ly^+ZKWu0_MGxG{GoH(-Pd(jYar;vTa(A223a+z)&MmW_#)TX97+nPE)e zF!(6=qDRq6F>j8)pr~*zJezxwod7rt3XBe!Ha++peYPD84Vl&9s8QM}>5yVW96i#m zg9C0;HuGSrvR!>vLVMAH>o&jv*OhG~yyCJO6)y#A43uQF@X<|0v7)^aEfq6~q1E^y zhO5}zFg5t$hlik>@UY_(F4vR`Vpps_Ys1#-7?#dklatIZF`Y)d&o>Clsj&tY?`4L4 zk!*`$!h@}l7`ijJ0pr*3WB(Z*e&U(Y8afp`l?o;)eV`|(fzVMphY-Jyf2F=_nl;V$ ze>eJIbP=Dp^bsnlOdXk$J@ZHS;^w8I4!)=(QS?@T3~UY5ge`%ZUv>s|-prBC5**Fm zU~fbb=}ee5Ly}LQAF2=5N47*934MkCNYa!Ox){6|J{-|U0b^S%@0+rFW%J{S+7{l_ zvT$*!{WRZx`iUdaeug)mSu$MY4Hx5=F2{SXBn*9VRi6|qt{fMbofdxLk$VLd_+KW{ zr+oe03e=s1`*r3A_AG}k<3_I9fFH!0e-b}+I*Wj#!1Sm5K9Iww@o9Y&L{&INEEf3u zrNKuIDer^2;^NcOYNQ_`al8WRsp)3$0>HLhS7XB+{VuWK^obSj`*9u8r^Ch9Y@M*!x$rn!Mw}2aiAB)ccD6S>n|IYK2 zp7un@a94Eyw*}voJt$kG4*5G@)3a^weCOjGi`!1~sw(hL&5%^Cd}SmGc8fkx`?&PU zR3g9Ee|p(S=nBKWWwpwz@gHB#C5(CgZc)o%_3*ORDj6JVzaZdHGIue+A|3QGs0(Ik z&bvq8Th|19BoHXM#^7&x0_B+CyxfGxC1{R_|1u;f-9wz9gNrW9xjqOE)VNv z1SL+?As`fE1MwJ(OClJg#k3E2N9cQmJW{@d08Y1|=|^=jX^j@# zEuAf0CZM{Yeq?-T6syUtJMQ+)_AM&`oIG^&5qgNk`m-H%%b4offp(1h64!t1C?t!+ z7iRjV`<4l)#@LtmwcJQlZTnHFBi|<@>O02Uz!=wK7ryDc2Z*@-c;fHT-yx}U2 YBEPIICi8x5Bg|Eaoa!5})3AyE3uPAyC;$Ke literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_511041.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_511041.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab5871b441d8be54557b0dca944e6d49bca7bea GIT binary patch literal 11303 zcmc&aYfM{Ln)hD6eXqd}Fa$^lBoIgm1o9vdniTRNK%jx9&4ecPTzn0fhYi=4(74w# z?o3y;+wR2f%q-5#NO4D%>P)f-)ZQ&UeoFUdR7nG8r*Q;_Pq6e_nxMe?S!_S<;2a2^EGdU<}6K zQ`oSa(E}ZP#jEweM%?*vCO0S3fRyan0&yMr$cwhvK@5%!-&L%ix`Z0{5evRKW z-5_M{!#)K%O80=LDI;fwV$USIPFO=jQ`4&O&Fp-&Kfwxf?8KJ^+= z_sW^mJ!c(#BGq;JBu+ zg^!X@10wXeSBE4Jht?#eiS%hz>L_RMo5YPYM4GxZIn?W7LQX zyitMScW%HYm7c(^91i8tX}pau8YQ?4-zM_tx0{;GzSMDr3Mhl;P|kPK{^Ud`RE}d2 z{2~v#Sf4{7k38mU|AZ^LNmw`NBb}mO1G}F=W7kIB@iGlOhsFyV`4?!|h!xnw?~ySB zjZm^mr?BZ6W2^WkzChUtzJV{ylpCK|!`^4mQ3D-w2Dc%-`<6^gBS?!ZN1FLP00DP} zkxo~jg-00K`ab>zzQqqBi!t~?!UXp7&HTQc`T>4Fe*oRb37fLl+5G7zMUhzi-si_f zs|pvDLKM>F!fpMM1Y;Pc}uXV`CoS=PY`P2Y50zOSgYmfxkcK5wM$!4K6K zUl(Y7P6>ACIVD&dPw*6f=t=pLIYB1A?MVru^aR>B;_w=D2RFj=TGT3?r^9>+qnfM( z6_~4w@>nr^JKxG5zG+nMGLBu&mFKy-GkpN_VbqfYEDVj+&amDwmY#OF8By(I+?Pbv zwA&+UX_oekxkQcM#d>|BZgk2^bNgCEjohq;Dg*dlQN?&UQRCyhtV=Y~oOjyobkGb_ z-yv#f#|+1cnh`m3N||=VC*`^k{|qXyR+tKlE+30dzG(1qtebH;=;>^6GE*GMDS9%k z&Md6cBO22bMqugLvPUG9cK^&|ZUr#LSnu47!{>&zZf4zxsu^IA20jkIY%-xD42-NN zRD|g$#Yxeu6x|+&L|oL%V*m|u)Xn-h31?cQvO@zgiF`VPE%1(x`k+4^!h|`GkZnR|{*xNTn&7_nW(V1{uqNiXg>4n4W zCq=E#Nl&?aY&F^i-H5|C?jGetJ?xTPQUkI5a2#e(2K~~35OJrdlcoiGkS@T)rLFc! zr_9#@^tS&m(tn+4n)kAkz8Tu-Y8t1VlQcK(_3Q!D)x$OUSZ5RGnx1KbT9ezu?RQMm z95A7A=CT8t9B`I6ugBr^!ihT2+`M;xGxUfi8OSl|Vm+=Y_9Dy>T_NA!!EFj+-yFn9 z)0I8pJ+Z+RV`cEz1FGOkf4Dz3yFyh3PdqSN7j#iw43BlB%yq$@r_>X>g!1N;xh2^1 zuw+NvEtE7R)kSNC)}*~wD69>hTC+q*AsQgQwT1lGwg@ zg;2gnpc;cG)=cK4ZHHjmk)Q->fAGXZlQnWdFjalr^-0f1JqdcTBDM2~U^)^!xwd(a ztu%Hz(JEB52)2EnjofkHb}!*e9m%2d!m;zIgTsPtICvTne(e{XF)r2}b;K_V#d`(Q z-o=h3_21GERHVKE*Jf7cDdw> zrsUa+Lf1vXbTN4HD{Ik0U9=9S(2y)|{hYn;zvo}E_6GaDf+3rKHnK1l9gETNibb7J z)ta&&S~azWaK!rVaChWnydlATI+&Ojs$1_?3Dt*}j;&Cqpg+AC+W6F`&5O)pt5AFR zE-lm^UB*|;y&=s*vpvH7VlXlvEevVa%oV>FjH_ek1al44?By{!>JAwn7MH~)gyOo8 z@qxv*P!cVP+2gwtexbTe*wS|Q#8R)&F(7OiNWSVwj!mpsCPP|fR2{&*B~-TyrR{fJ zOSAVTmivSk&k3dHl5dPBXJ%K-EX<&2ORO{66SqWDfwRHPMb$>kh%X1Gc&Mqu!LYX|eWB{q6eCyYBbi z>rJ(E3)b$C?xEiHGj`!p^iqtz#$07SZuzABqxRIc#zoy7<89;TE%)2+wWs!;OqFyC z`tA_>ubM)x+!dmjKT?Gey@MwG2CWM#NQNkv-2+Z z*f_@yp$6@&q|DehslvLrIo9I@TQQT6-j=-9-+_{SSO2bI0pqcWOn0`x;F8WS)w9jI?iO8g)LlV_|8(Dt4_pyi3YqiN7+ z?-_VaurceUPUP_~!}_pclz{c*^Q^GEzLZw6b^@A^J_Oi|?j3zVm-VZ7o#Yu};Pq+1 zW>cXtcrA>fL@;v}%&0OJcm94&rXPO>e*XR}zUCSD`TMi@#53^o_h<1{&%n>$pT$?- zNP9Cl#%o@|>h%LGaycLUHzdh^luSm+c)U)CTE=zB?R2qdE1a1?4OGz8K!hG6N9gCo zv%)xGIP46Ls3?v@j$4!GV20r^$rnL^M@Jsoqzf=gz{J%9IR+UzlVH?pWu?wmqN++% z_n$c1nYAK)=>BdG2H6Xci@Gr|CLN<5|L(NPXI!&$w1)#r&mox$zLv(0DKA()z9Wrk zSY$x^+TrsbLF|FqzCcFF7-+=!fKa(FW#0eUr8{rkeru_DiCI1(oE%QIzAl(wPhO;x zBa_0#sbt}FQa{bU0qFGxQSD*eV9xo*=^2-3a{Is#fw9NtymLNL19qBoT-1Q|!d+%Z z0YTaT$r=G;OtOOOjjRK8d>?&~2xDpVX@d_~g>&4@94A@LERsSwJU};8F}YrPNoK{C zO*z1_0FI0x-#*y={|sVlBz5Ir_+Vu4y~DBAYaLfR63t(hHzd^!NwNX;wT4?GEn(F6 z;?)-u7NPuw1TB;|eM#<3s`noxxd9Hdm-H6&VZ{ z20Pburr-}ASnLbsQBcQmOYF5o=dDxMPu=aj-+ixp#d18P{WmqP@!;QRF^x5{OHh}s z=?$UQ$VsHLKyevD%^?~xNgY_;CnJLk--~`PM*sBnHGScgf^b1(Hl^PJGd*(kNTMoL z-hhTjs)A1GOJlXy_FUbQz*A+rziYGg8=Fz1UAJSDHP|n0w}{U>a}PPXJd!T4L%c}R zBe_=#I#E*Ox2#VxQXJuJND5>CX{0<6C?w?rL_o!>cwCBq5Krh1WZ!N`NZtZ0;PH4y zP9y5b1=^4poKf9WE7vrj1}GH#-HZ$` z>Xrq7S0}A9dOq36Aum!&`Cb+jy%5krR8pGJ|LZLDyndoYnQt~n`=P8cVBmH6VHjNo z#>l0mJn%DL`~c`N(*8c3JFu519IzqX)Za*-vH;0#RXQ1orgB=2V;4`szEZqFjWGn-p>x8T$ByE9&gcQ_~<3Cn*4jM#_((bF?*Eh6GUz-rKkheY)yw`-pDqE=~;ov2F#T@Xjm`DW)_ zF274ukJD45^%mBLz$9fvc%35Qc8cnGX`p#&ASs|As*%ktk#Nk51whF$??}&!5lyHw zTQxYnQ;w0#9K1^56CyF?6^U`Ts0G)5(#41>(0D#%7%HkQlgk>;^n_|>KLp4(;ph8D z*brH%9S9$IzcoU?*O640ggTJg(1msH8$b`43a*?BpL_dcuq#BYk*0_-d^ARXNmeG+ zl}WM^K^nq_NOfc+)+^X+6ULN&AGkzmJxRxlLI+c1ZSc6FIdfV}mdNq&>#^hD(jZ6P z5;_q%5$<0lw+rO<6j{AW)(K=?!jkApkxi>)i$Jz~cHmC??e-LT1pJ+X!pP3>c&tV+ zRjyLC0#%!$>Q<>nhzfuQR-K{_tWt*r>d@!K_si~;rKsZ&8z?SY7>*9diFjA4cvo;> zt*CUNC)yJmiC3qJYJ&Y=nTsWprVAkGk90TZZuzhKm#ThS_eI@`xeo#w$RCGwR{g{g z!q-gY5Dl1x*_*0>F>)>5e045%N-)(xi?t|*M|Z`BqT9fYtgN|q`Re7H#kb0?m!&HA zhYTN9MFwM4(bppNU+OE?4Oqd}bqXVkt~7=lW7gOZ5IUDK)~p(L3&!0y58OI@{cy_I z8a%c}7DO})rl=`4k}_AVnrj7fEd*N@2 z!Ikj<8U#{+dIYC+<% zzJUEHDrat$B;nZ~Lz}3F*uEs5*>_M?Gmo4oHiE$a5;fo>Y#+*15MGx5V<3gMeLnQZ zl#H~N3bM#1=RD{VOTHNUO91*A{CqvKz_yhy9El!z`!vX8{hGlN=?qr{yZ%%-e??iY z^o9FkM6By$*C!JnP24TLUwN-`h3Y}_w)o1$@Wo#pxYmBPJ!Pt0HN7C1UPyE;5{vY0 zUCMMQcoNCGmY5~l`ja<6-ci;Mn|d3JI@fGPF-vUlYH_q9zE7~$ zhK{XO*Mv?;(#;Zm6~*)-4PR1~Uu!UH*}4WRr~ol%EsFRBOI5IU4d&My?v1dq_V{dy zs$Zq{3e?_3>*CN|B(Ybi(*kvRc_=yfN{TuQQzi?qw1wLuL$R{>V2a$eN;V2)V}f05 zS@fsKu2r&EAbXcvme~|J1PVu9{QLc}bu@|qhBD~lIk`|S~}DB z1h|_c8NY!F{5uQ}sGjbqgVl^f>?Zf+JXg|6L@^BHS$dR!`Wf`#Hhw31nrvhW;zT)~ z$OV{3KpW5n^Z^5x@mtW%=XJauafcLIBhGtz*gf}tC zkRieL;2cic=W;aHYh=p^J$W&n^b=U~LCLU?eCB#e_PNo+mHdQ-Xt1ZzCk7wc>KTz< zQ_$mzNR2oogJ;GI_eR7WqAC4AA^Y-@!8SL=`F!Y)3Ay)0zAZ)aS>5uNOF8r|y!~TRSS}M#G9`Dj)*d)ea3_gUN zQNQ>ZT&e4Z^=Aw?ro&i#K*6L9R=-L+^>iDmF*-)i7#Jf%{DXlpF{az(EtmtP=Z-Kl z=2sk{d*u;kpzGDB(mgf8vb2tVV?V}Ne~B|xKYp9mRtH=vP}iB8GX<<+U{!95T4oBL z)4u9C?Y8H%+w-@3ElklzhFh75weK0T88c&ygDGaQf#N(=l|XxGp7zq`z?VG-zMQ3) z)z2AIk-yzrz*K%@JVP@z3UzQ|mnR0WdJRk!Ko!C&YM580c1`}75nkqbR9Tp6m~R!c zj#>K=e#Qo8*p}s{{R-}^eT2yDd9(-eM;H#71C>a~oCcZNd-&+SZr8CxXS!TRo~rdc zRqK|igKr-_@m2=%kpT_Y$rk_Eu!|Sm{VXq=QTdXa8{r>?(F{5SqL=X_GTtNO-Llc+ z8WA{XF?j-hudkn%^@HpcndVrY6^i- z66!|x_HhpknZyKGnVJUB#LE(c-`3sAK!nOyfYs-~uFio)YhIQ_qgZq+Ghe7;1z18X z&carTrDA2!h+&~Bv0SVQ0*=rcagA82v{eg;4N4L##Ofe$K`hOVT`N|@n6*Iy!-4=4 zawr&}25ksWjpz`I5S|=*5J#6K24A9G!Lv@R5sP1HTy`DGxLR?wSn^WiO0+d75`c^&+nJE{$UhXyu|sgAr?&jkan%q52%I zpkyz~230YnZLfriS#obRld#1gR#(V>mjzWcEle*P2KxdyR8 z<)jKH)+u~)ih9v0+LaX+XzO@pg@syAs8KXz;a6#`%i%K|vuXXIjXBik(Ffvs(YcJL zra|C2L)6F2;L3PM}NGDbhR zA9Tp!XQT{;xVL=QMLK&I+H<&?t8+3tF6zW~MMiCT2D1^}iDxhqs1==_+ZOGv;MkRC zrQfU2Rgpb@*}`9oE~mAJ1A;LpE*spJef%S&IXSozsKv+#S4P-=ZBj4-48v}IP&NjH z3oOT3q2m#fiB6nDCX%y2)(JzhWrPd#bM9f6kC6=?#&=QH5BvNw;pW`_epWUj-OtM= z?@+)kY}qUu)n)@!85kds^-Ms>bhmAlEp8z&?DM$X3?mzPA;7WD*JPs`*%z|04}Cav zO=Y5wSK3T{!4ahu2{z7#GCS%w^8)8%SeJWvsW_M|_T?1)*;daItj8}~G8p=R150JU zY}ML>BZIjW7}3uK#ztJc50;Fy{4!&QD?^0oM}SX8$Xq}Jva3)LCf166d6ic5`CSS| znN$(OK&#is@j_+{wKN}KP;h2G!2W>O%R{UB)#`(=Er?q@@K60(-5$@_@YoQros)6s zA9SISvThI=alsK8cguJmY=y@?#L5P=nX;+R#b5AwVb_A00_+UKj>@=4HYxLZhWXcE zv5p^;{`;-gO95_>A8~uw)(dXWpj)^Q@Hc}Y>K9sh&eJNe!y~OwYxVhscGs|50P;AeN8+OkR86?+!7BTtX~Go8!rf_@Ea+6Mf{v(cZRzk6qUxVBwm-wT4tzZ(cgp*r>L^{mPD0Q*(^~l;jR?DYQCse zqHCupiE0mbJ*4fia}r(i#ojLu-8?kqo~cUKbx3qa_(1AM^dM{}Z*<0m_`wNR;)+zV zNuoE+Y@Mz9Ek3*9?(VtM(v}{H?op9@Z}cV%@pmQyQ#Przampu^Y?J71v-NYuza5<` zyVrXEkhJrJM4t#BcwjG{Y?x?>AD(KOuiW+x_igZAaKZj&xckpAU)wMGCi^G)6z9^iqt!CPoaYqSZHQ;>Y6~CUz%!rK0+&8p*yfVuI#)-7oQHn~US7iCu{; zlD&S>)+pH;Z*RV{b9!gewsp20R=4}UN7~uDxZ{Mh!TGuN3fHh;PE=FVhw%Z%}B@-F#}{o9Ir70FHelVx2J*%h&- z$bwi?qD-%E3tx|=_<@zJYEnjfN6pg!6dmZ z-1*Q{5VL*wla#&uhAqB7UN+I1ID|$LGs16rXD;0>oO4Q>x+Qye#Poo)jMGs%#>V?6 z0*TQX;#Z@ybqYw0v}xac;=b_xiIev){jPBS%vtG;vr_B1Kal4iz9eEYO2$fH`n8Fg zsja`NnK?eQVS4v$ue7maZtMNJ@89aZzu~*x^CwP9`%X%Y@BD#0^%yQRCgYc7)95AE z*MC8f2}Ndss2f#O5cfI)AOd5;$QX2TtT4v;JxKLw30X{v-+3HVR^sM9-ol)O7g|aZtFGMBxH<`5x~M2J|aUV0W2~#i6+HI!@&PCf~}!PbBP4ZmGyx% zNU)ZTSvxX+f7Txt@x0?#>R-Zd%sYOi{w4gn6~^c9UmCApVSN7nrSXPa8J_~jMB_2c zN%REEeSSE9j^Qz0aP+Yb2$%@I{;|Lq&mpsKS|<|>d(r1%xk{*dat|Lm5(wR6a)iU&Y52(8uo}kPz=0s^QnWe0xn7^_;nq)hvGunshygI#mMJpx85SU;W%vJor`;R?qf_(qf=O9AAF zVoEwK+&R>S$l=~anG0n|s&Mb2OtCu+{a|+HtRKZ8G%IM&9+eHNpSyT20N_XA$Nvvx z{}sm4t1yZl-yPk(V6B=rY@e^*o`l?5m9kRf+oIb(*dOi;kAjpRr=nD>H(D6(OquBL z`wwiTla&*dpn`4jlT)2{-k5&lo33w<+&i*hdow~jG3Y5Gy#EP_8SSw;$xxOe&5>=f z1Mvfi1E|;*al>aEiD26`Ypi$j)WoT{`{!p5wj@hmZ`egb?t-2G|H1L_7$@&)^=2_-EKZW&_)bxJOhiOGtbI2EZob zn%Esc=H{~W6g}2uolr$j2Db_AE~+;Pi0m*D|XCPld0vkYjVOd%VHiRs~8m*HtYh2rpod&WBypLhFk~MYGanZFQ!Ao41qP!Nem^bVOO|kL7+cm`E&;abu2%6 zw=6&@A)38<&@y!JHx?u^_XYeHSx(>`-jTMgt=^F>n;oiU=MY#wAK)ATuR{^vj)v|g zM~~A18YJL2)N{Pu!LcLg?ZGd&1z*75qM&eX>}YlsGji+%cen=9f3l>x*;ulxN0aVk(p_oVhbp0 zsYi*zI9JKIAX}F77&idGvf-kSy(AlzNDH?PwdjZ4mm%U381qBS!2|kqL^koGV=N1+ z9Ypx+5R_42g{t=IQ7gzj%t@*0PSI~maTVib7mDtnkNdKs1|h;T;_{Bj`pd3S6(6B; zSjK%k3bAYI4}DbyNI#|Yxt{{i2>keO;b?)*qcF2&Tp!h6GcJFk(b6PYnx=YZbV<8}Cfg)r+)4q8&+k-6GvA(alM^6(*rmq#WuzNvb*Al_oH2 z!FWToA=VaiUu%KD1#KTc8$BDZiT5UIuAhuo-tM~7J>5Nr|1fe|rji zac~Ux-N55*@%FeIex`}TlSe0xe$n-1_s#BP@urzAQt{T=&N=*EmsGS@BKJj!v>q!c zeF!w4ik`Z5W|3Ye(d&|Q9nfjD#kR&@o3}dV4Gz`VLdVLJt7Eg;dmYHUm9#gw4Y_tN z;6h2$PvAzqpcgbau4s6eXL(XIydea7k-?Q9JMm(N3m#htkk8q5=>Yq<$Z2P(*zuq}-}}+qPpwqz{q$#Nr|R!COgG$WoMFH2zuW&$zSxP#$?(AxT|9n1dj5k0 z;4bhf(i(9`P2tWQc@Z!DV(iP%&Cu*S-@beA-39t^_<$11Q)ESA zEFKr5V!Sgx3Xz$EcY3CK77Aa7=vZ;ZjkY+G7?7$r%~x)cirXUQRFM)yO1vpmwal+> zk&0R)WXe_^Z=a}6^iJ7so{&~;Tr6yr3R{0|{F=T?Cks0OQRUi1!}Z+|eXQ6PA+CLQn}7N%(sKCn#6KE6G={o2lXLy77lWUd*o0mHIF<_T2bpG9~`|6-RH4n7XCklZ(6 zgM#U)EGK$mSOFg{!khLh;4|j@8^9>)ZyPiN@*L0+B0{DR2|^+3JRpH2nm`QbZ)GAy zAq!&^Ej}DBw-v4p(Ziwh1E*|^adRFh!tq82Q?{a8;PSY6*45@TDhFJ7iedb!r~6X? zQoK*l`-=CDyjk@KmH12VeFx;IQ-_HNm}v!d*e6+cenn!5LqvoHO*Ds433RxsRmxAfR|#8O?RrMtNx-?AL}t4^|%bpQm%Brk7H#O zDM#H8bQNw-nW#lmxnwF&>!AiuGUKkOD{X`_m~gaxky;~BYtklEC!s!mG*Vm>oD5lESES5W^a%Wl(wTER@i)9T`Swq?gHK1sXBaH??Y2&8F#@D6B*B_($ W7Ng?tzt*A7p>DK_Eo(+U43Pb@O8l?yjr+q*;p%4qVE@0rjAO9)O0ouNw_RL&f zl2$1%=}&TccN@kaX-y0j zVCXU!dS3qw7z!BWc`;PaiJ|%#W9XiHj2wo6H$KD6bs$@L=8s_va03i)`W1Vp=(a+V zH>_KcRLcMIX18isUjMufzIcsx%Np(0HQGxGw|n$_>D#K?nGxmuqK+?fGsC5MjAV+u zZH@NwHQLv$(O%&;@Rf!8JVw6iZPi61U&~h)hTzvzY*XRh=V|pQD_~|G4Zq=S^~Itm zjQG6slQ=9WLy-AK7|qOY;%nY!FBZcIEY3>aa)Y(kJtWdT^)LkAhYSm!1Ej>k$F35M z2m4R=op78ze(|v5=Pn;doc1I_`rU(}KEHK#D{ ze^^6CP?3+Ce8^_CD@FFC$et3}6M(06`Nlj$1Ab9C;=Unhp;2DpCoB_lUI;fVklJ-oijfJp}B**wYxv@g;kD=H;? z$k=jzh>5VT#bISPqs@|}B{)<|<TS{kE|ftmMKF$*p)Yn zbu!&tj#^O{+K#Mbz3&*a+_t_yqhe5F9&Tk0M=qge)GF`qDL5m+M%a&;{g0&cPFgpi zI=MCfj8lv{R4?~s&NIQNM-ACl=~R=}h6wxq-vL*pymI6a+!cTRGj(0}&(y6ukpdaf zPH7EgZ71s-WJ2AKpJ2HsxN9{IQWI&fhDRFGN_Rthp4_WtY3FjZ3pJwl+-kAR4d(xa zEDG)401!W~e@1tpY_4^cy+&L0qlc(=njnk$ZEi&yfBW7$?$ z4%Rg$nlcom;Y4GmN=GbYwFyLhiOO*hi5fnRDN!>3EI39p`IuNtj6~8kH0Br(e3RoM zM>8TtgJ($a2Wao8Xq1}$LsM=*r_?eTNZRM|_z6*(%k>dh0U`#!sG-90i-w$-{GvW9 zC+KjxT$7`dUXZjPvas$WF#dH>LpFh9L}Z;r-lPRG!Ff&8NQ8tsD%L9w-aRqt91Do- z5UkuJfI>DnV`V`Pu1@r(3W2gc1L^ z)8%d*bh<{IfkEHcRbq|&nmIXr zGB&YftPLGrE;2`2aM6ai1s83KBW!999VysWUKjV^@|KjUblGf8*&4CAF?4L%S{b{9 zt@hA~`z5x>)wltdZ%JHAm+TJpFB{8a?eS_{wG|t;!8oSkR7o8+)g_GB*bzGXz+{bF z!lv4f4}Nm|*71ZhS)H!$!KNM<{b%-gK0Fgw#V$mB2_3ezC66q0{F7y&`?J%FL%8EK zww=bN7bW6XKXk@RV*}AQ6Saih8`!pUp=+`JpV`IDU+lm8CwR}x*!D6uT?id{V6Bd8 z66(7>Wir*>kA+Ym_geI zS7ryI12Jd3I;q9A-D%s-d!}7scDdLR>Bq%2@orq)oOl@*w};is7F#6n;hD&_=!SSL zw$z0+WQ~qaAB`M|+Y{_%e759bJzqgdkmPgBDRf*1I+x)JDz9q{6Ae=$cC(4qR`HF?QrQ#ldEw#-a ziyn)46aA_6yS}!5Q~6cpQfYr!cON*Lh_(D+*&Oa&E-i~$qDR7qm&|wNL1=*kTp&3%@>-R3$FphE813i7lNA zy2Yb+4}JgA*}F&ap>ufWIczxxTospp$i_NiTCzjhaB+R2JX!kp6UlOH-?h+-?fbB~ zC#<^9>1RySridptl2rY`bw2n%!xcBpMQpI;I)31`KLQoR2rt4%)LwB=^(&kuO++V4 z_nk&?2PTCv7g)d9L@q)aF<6HC{~hif-7JHc;cR!Va0!fsce$V%v1|1P)$i&88B3mF zhDms1Rhs6ebsh!KFW{v^YXL9j4x6>MHDN9c*u3IxE~pI{WYS1W{b>e#WewP#N-_p` zy38bdol)Zs9`c>I}_9^&<`*Zl3r{EXv&*3Ycg0FmX|Lu$)%rZ!Gma(f(%`1c| z5QnJY-B*WPZdWGmLnTC3gIGn6#^Cc;il@bmrdb4DFOj56?{ zYrZjjWw_gWH;_7a1)p}M5Ae8xPk9FLx?PSzo^E>T6y!MwQ=hgLRy4jhoCT(Ravg$V$^1V8^j z!skCj%(B5a(>>k&)}c^uXkyu5o!K|NPpaJKOf!aQL*z`1g=??&lg+m_Cu=`%`mAY* z+Y{=&Zx{?6y05ZC>anVPnbU>4BfUhot4X;n?1XHFC?CsvKwQrZ%TNi7%*)rD>pZJ^BpUj+E1J{kw6gox)Eo!g1 zhXw`%V4dd7U041dBYT%@h(`+0H&l1!?`-HGDUih|q?P0|A5+x z*%T6{1TD{U*plN;F2TE#R_1Z|$?_`@JLyEQKt)Wfk zJC}n#FW~py(cRA6LP1@Cs6U!`^^-@FZybnyg+StZ#7WT%q?e+!Pmxdo?y?dxAU&^` z0xq7!!3gYD!JZpc!ua2gJ;w6tpb;5(Es35Wqm)Aw4^Gb%JTM@gJWEO59>reZ^Zj2y zEJ;^r_akLy<&=_cm$2kr6x1QjvtS{X#8I!w3vF(7c0VYTz9s*BFG=xzkG?VfO`P7pSv35kQ;HFw$s3kS! zs1rP2aL&iTjd#J7GA^no;jY;!*;;mjMs;^u|(1W^Uv@^<=d^Apt}-w8CAt_2gd`Q&R)cpIR+@bjl&x#1?z zTV|T4n?E=|`&#t1w7xFXw`?@e9G^Z8zM|eVvt@cqCspKoK5QRg^xsDo<1Eq3~t%XtJAN>>f-9TEeRX0Xib~8-!pY# zQ&)0gVN=?)@1E%ZHXZo7?3=2us?w&Dp(FRrr8BQjzaHzlxp!`Ff}Pu+IPz)#eE))F z{`A5GZaMH<^>>DE4XI16dzXgsrQs#>Naz^#{vw`PZ`2!ai~8adKbY*`_nC?RW{L3A zXy(o7H$R%Z`TE@JY4et_YI#R{viGyP1>2X^UsNwTzu5SF+p+Lf0vv0bD~Wif-?(S4 z!RDHe%RZ^PRkdW^de;d^>uckO=3Y(k$#e4~3r@UsKdwFy=6-33^hPG4ha%LtfX^(YudaMJa|*d%;o9JvBUA&xnps6qW9KdvKcq- z1!;irj20RyXo$0OP4UBt+FQqxdR)IdZQKLmH*C_UH*aL|L~hS8EYlS~~nK6pl2#7U4u z0^lX(>M&WNJ3;$X=}1~1BKXgN1Fyh9C39b#Rx$5GAPlsBf%OL&NUH?@b|IFKG^VWm zL3?wjgh^=6+1GjMnQCjQqqS6p3Lw~wqOK#Qdh+=nYBy(P69p_Gz6B*wf6XyDmak%U za-8+L$PNpy5E!M`cTH3cf(5#k(ltmlsLmGN1DrSE=NBZEYb+th8IG>mMwC!Yh<7He zw|1q?TXTxwU|jQw@s=_160X^vHnruHz^0h{=E&SgV$+;2O(J>_8=9?(R>fYqd3o+~ zq8G2UM5Le9ELv8U^^Z$#gSJ8^l_4;DLwpdj4X7IVfrqT3*p zWG#!SqRlj>6l;qfi>rVj1fFW^<5wXdR*~@H&3ka|o(1<}@8XTce%x~g*PaQ#00FJq zSnph2)#t|el#%F z<`p$l)VNa27;X0q>#<>d%pLEY8%(GZ&Up@^ZN{{rGo|X3lw$t2%PPg6U{&BRUTjeL zWLFg)>RA$E2mrj43ONjSl(LB!R6T_rD8lEWrw$OFX@Ucu=Q>;hwH4F^wLuO{*{m}H zfevzbIi$P|2b-jz9?V<)5DSXH00&BQN+1FeRjx_iX;(<+gIL1wRzockL%|evr ze*-=t=7FUYg2Ruwvt9pC_WAnH)-Q4UL%mPvX#`7Pv3aTSBv#dh4ug-|AFut%$dqk^ z{am)7U97%4l`g#)I=NzGbhb!vMXfSxLPyCyED4>WyA7mB#A{ay#Q%E8K-5Kukp`<| zK@wY2G)r^d=7rakLoGT;kLj=4omAN|*ls2m;&h&uQSkU!wQJ}zWp z2i5@AjdM0H?t-0PN>hv*(E8PNc?D#+9^;PdkH3}sUem5loMNw2#{_!xje zMKS)rg$rm!$+GP47~{V(rhj9qAF#zC+dCWI+W3Gg3u}mv9``TwrOup9xjoo6kP1w$ zFw8+kp8}4Q?vP?R3neVHr(S$1#k;Z1lkx`us#kH4LiOcAJsU0m!og>t^{gX48?7C8 zg_MiOHg`(!1M>$Iy_ETbdCUt7YXYrUbSBLoN*Ud{pURC2#Y!ck;NGcxt8ztggk{@c z%Rt<|CEBuL$y8h0KK0Ifv(*#qN%p&zZ{>cg3|>OPtxR*vr`7Y-f1uSzN=9LPR7FOq zrrj?pida*;?33zS)hi4X7t~)GzcA9`UH5n1Z@sjbat+>d`LN3;)dJUO4l@n12NnNL zvg3*$rR))*@?(vKt&4RMOMHcaV)87_7CY~jeOLW$b?U;Udly{zf-B`7zUQ96?upd# z3EU?8j(na-e8R-7~Z|cU18vp_`dV_&N<)tzTdh27p=C0fQMW7?bSav5X4_lBKv4Fk=5Ty2x5ky37Q-w z`o%ZdPvSH+Ch3=w1VwnrVO&#|S$yMs&DuPaqA6NJOKBM`|Fx7>(257jkD(3pY{(%2 zp6|v%eP<3yL9bM_EZ?gk=@3Ded`QyjA@YGH+gBK~PV9AINNYXPVMAe#7^ig`#x0WHbePLir6WkWIQ1VM@#`*T!vBmyr4oTW zBapoUd0LRUof82Daujag_@H;lFGxl_Hw7)@@p}T!fX_MQVSJ1NxERzgMuh^TKm^74 zv#*}&bDnmWK`Q$o8l)gr!@x2_1juZ`Flu<+Lz5q8%0p$FP*OR{@&ZaU)O8OlV@p5I zv=yq@gjUX)bF`JLm93nT5k$z&ma+CJs3TOxRMV(sRz>RI6}<r zftlONZWSdd(uEw~kc0zCfK`m3d2acJ^kgWXZJB~~9@@sXu4kfmc zb+9?9v)ka%lV^{&(DqGun^;HTeBd6puBoAIgK?Sg`fcB&ehJ&o>auhCjWOzTC2QN< zu}Q5`sMV0;T_?HBc#OI7wRLwC9}jTBuA~)+E3|WcPuG<0V%ylAQ=sl>C7z$%YzMom zAiam(&F;Z^UkX3^y4?t5vubvi|xtu3_g2kD)-bAO1Q<@fl0)<#n!-Ee{y zo??aPXJ*RTeQYPY|A8vI!%5<1;n`7Koq&oc1X43evG*SFygj}n@-4b!b%DF z#+_~-td~8zc64^^g%Uw45;;da%(!QiaYFx)h3g*#b0JK8y@OC|?`)55=g3q^Mflhg zjqc8w=ozjqrLl*PKQ)+U6w?Zhh zvhG{b?vIdZgE96hZ>WuT@`jd(ENv>|%&= zoQRNUF2yoWeDdya;uK%Chd1w;?|Iy?sOgQ6%a%&6`@=f!#I!RWgR3!Nw+sS8CcW@BUF~-I&h5e z_4|3-{`sK=`iriGA^vc0^7Zrl;q%E01N@-@-Zqf*`jXeD7WHpOq|4?CZuf^qu5+ zd~0g^3Eq0*3*+PN$2Q(_F4=d5w_M@1SCX#5q&LXBf{Qn%{)ZrLl1HiM;%kKd3+#3kyon7WM0T;vs$dW-UkDg4rTDOj|6bByKn4Id5M z_uli)k#h%AO-FeBkw4Te^er^=x>p{b<8{5szKhArSNXoHiz8!8BLRLSkPKdf3U89f zQ4dDR9#roe@Pf>f)T^imz0_-B>NQl>N1jEg^HiUh!Y|z$RJc2>(8Y{Defx>BC~ zjSa$@r@vHgdkSe~R2j2!#<=-QWfL?52L^dVP)uI;c!#bA1i9bs8uj>@MkrtoqEb2I z2?UvOH`tlE#|AO%Y0twEe_MH5HAAq(aL(V((Ggk5!B4$H{FG$L7oeeDfJX8HG(<74 z|1$LVU1dndQpHMX5_z5Pssg#O4kN>5-4j(tRf80aw}=-V$a#9jfX)Y0)}{){BgzQW zLQCFJh7^G;hzhE#XDiFgtzcSw$TAH=uHz>~`93b~^=PzEK=a+B7 zU%WiePi?|qygbh@*@VA%d7d9Z!X)^t7EyX6!OG7Q4n;46%wMoV<+SIT*X?1DZR;*v z#%TJHbG3@xsb5l?fl5)51^gRLVb<{}aITOwyzYs@GNLimk)=nD1muaBAqAb3_7+%r z$i4%3O)lCKOg&0U1nHUMeO*k>JoTfUSsM;BKZds;9Re$LsySn7fL#=HjR(MfbYe5b zzq9S&s1LTh|8N^3bOVbYZTtTX6Hg4*8TGW9>*50OUcP2m%CP&BYoEUT=rTd z)U~|2c2QlQl-4KJ^=Y+nNnOLMYZldYNoieDU6)oHmekd}x_VJvo0Qfj)wRoN&7IC@ z=dJE=S9mh5)ZNiUHL-J1eYoqX)^z7m^b%Lco#U(t@}Y88`8(O?+RwC$+MaOFU!|0+ z?rS9>GsU*>Qd?T7igd<$pn@2jo2p1h#04*`){Q??$Ii|CX!=K->*tr!NBPYQKF$ z(Rr*^$Ld(@<%jeIJURUgZD$@pKGoUn7c#JVzcs zSR>+QO|%4XK{P^mLZt;Y6gf56QqWGaTp!E@)I~ALHPjZ?3@9Pc(ef#9phH&Hq6T~z z$8FFv8At(C4^}KXJhDQ(y9k~M=S~?OgSD8}Ex?rw;^+#)?1_{e3t5O96l+V8Y0LJ5 zw1rePeqc^Dv?42ijHp(FY|7a}?4#&f0IrOp#f8eRnV^@C|*HnF~p~eR%J&Gdo45mP+^@FqFf6@S}{Zd z%ALz9c6IomBs&M8>Og*W6xnssOGgMR1KYOmiT!gKJJB}*02HQq8w20ut-kJgo zgsslE@Ln^Zi+Z*kOH_f;*Hrc8EN;cWV(RTfO!OCjpOtuTBamm_s9CTG(u0 zDuI?ky>*a+)OEw_N2eA%ZAb92mq{(Z%+x?8gZhpEE{C3}Me#ZmY(c?R2msP@-C!C} zaytqfD8MIRGb)wAX${bp0*Jh0LBL!90K*aKU~qI)kX;A(%;TsN6o5pCs1=Ps9DuY+ zM1vm3RTwD~qZy3J&*Z=j}#&z1|vwXfvJIq;e;3-29C&ZF9j*;AeaPz^xZ&~0fKrYC@vWTjEuY{ zkduNO5SS4UAbVgg;3hycek9E75x?PEJo(j#Ka08i3ixiq&tD3Q9xxa4I?Sa+lSg2V zE*ng-{%O4kwy366a9&?X85#isQJZ7e;iwOHFRS%22d}OScRvB}22T4}CwD%+|E?2! z;zWQi-8H{u!M&({CEWc~XSg#M9gOwG$wgg#_{5WS>9iKrsjZH0;WbSNk1udRt&?w8#wdr<$dWws?%w`Wnh zH(P#DEZ=g!`CfCTydIT5HCkhnTua=wWZW8&0onstiFeu?C*vJI_wl;Mh$O8q<90=V z5_j;rZ4n9T?of0nHWZh{z4!FIwSCUPTVI-Y^VY7%lH~a}7j+i^#{gvII0&j6p)1FD zee;rT8?Zk(`taoJ$&_vnAR$oCaCG?BrVlFaR>ZsFOsc#k(UsD*MaZYd(iz*d4WwMk zxjtg<-?(>U5dkhwOg4`EhpmyLPb@ZU?6}j_TM}fVepWVDKR5ZPiQl@PuRZ{ot1`;k z6zy4&5T%tXC4{+(3-YB~5=P$C5>X)E82Q2o_A0;oGH+~2T;Pqn=LYA`J@S63Jo3Db zFgZlCMlL$X?u=rD^+FKH+{OmXS;)Oj-j?51Y~}>hzKH{>h{|aS zsW-Tq%pGFI+);20ix?0OBsrByZQfFp{6imnO)2s$_Y~Cyk$Od$es~@Q-7^A6wM+ zg?j*4#=aXODdB5(Cd+sIZqMh3K0A~=cVX$=kNI;yUer1<*^;)3*H&?ZDQ#o;_`Ul}>At~TzfTtj}=lHK&c^$56B zfUhZ8z~7KU87}loGn`r6GE@Sa^2I3CtcsOjnP~D~fv_6LgF0Dx&Z^deXM$B~NC&p3 z277K&Y_W>D`hpSZ3;KN?`fUJTQwe@8_E2A-_r`o(c^?{Nunp=mfq!eFW!UybUmD2q zmjEnu@3na>QieWhNwTEnKX!@}H0*?A3swF27w za5)KPonO>+NMQ&h+%*mfq+cMX&^<4*NF|A5peC*hf{?w#*z$<@llKf}*|7*L}yv23cGX`gPtbsE%2BVa+Xj_7_c z3iHZ>vp;y_;g#7d3;UOjo#&68$I1lsO>FnyUjnlT%&}`T>@*win*Pa`#^y*@+JFqS zx>#UZvt+2_4Rt9)1Io3;T%0jRPpg*~V5>DsnPiZs&q`E8+I32Ci?)+9Ic(SSz_xUs{`DR2mE=SG+87^j=l68t$#y zm(06)^X`;+Z;X6WUd>JZrYzQ#uC&LE@pJbqcMm3Z@fB^cW9f>Tczt~GUen#SLlUv|>U!Q%4`eX&) z?8#V!1!a198Tt{DhZG@YNEJXR7=DnRWYJENd<#+G==keXG)E|T{$XaMm2 zLy&ecgAj#pB+%*?)Bz8;dTy8B&1QIj@_ybJmgb!9Ya}SYdGdI~H&?+e4OU>YD!0$&EVc{+jH;7s6 znvY>T?m!lBX8r)Q{|P@oI{sm&7|r2Vo+`C>)KNe$IWpE0-}{N>)2c^Ri^`V)8{N31 zY1drWLd#;)8D82DJ`SMhtMR(8HH4)t;hM9|D;DgJr&8wr@R=12p|Zw8D>A8G9zL;R zAT*`nUVIjU43Sj9UO`fSMgXG4piS&36|EF6#09-L=C)Dz+9U|~{em(ZV`@>k(iIH& z0>N>QAU}pbV-+OBUid@-KyH7)b7O+ZfY^PWaX)+jAwj}K7&EqU0Qe40jC!zXf^`TM zkDwL>>m2fp(o7r?zl8#f=qnG7`RL%N=P&~YAOYHj@SlRk1^7Ej{)N!|JE8p#!Vc~Z zWei(y*W9XER$3x*g!#sq`CZAguP12_Zyikf18{ZNMIEJJPpOVk=^T}R{^$!(G4neg zGf6kiTRlm}4^ZdG37DGss(ql7S#l9i@(8D!6M^$+YsXU~0hNr^z$qpIiTwDYa&4ZHU%N zw6s37&)UC2sZ|L&NUGV&I^!KAUnzq-h hAi{J@`@`1R)>RZ=Jy=1OK0iq48&bM0x72v4{Xdr-ts4LU literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_635842.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_635842.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..699f1944f23ae6afb3576358c1937be7ea73b4f1 GIT binary patch literal 10504 zcmc&)eN0rmow@@DgoQ(2F5|Kf~%_f@3}-Id+)5ZB%-Pc@!$6%)B|1OEY50l$9J;uugs9 zI`!st>Mi-}hYVcNyPB_3-6`(@16RcryND61GLFofD{-+y8qWGIHDJ^edV;IwN`bQ? z!=dE3O;Ep8DP`VV**g5?>+nO5+~)jkhFGowMp6jl-2xo>#*2E(SHE?gdZo8KZD6&w zL2qVmJ6HKGIbea4WJ#}C(KKnRd4i_d_vAHrJ_(`+jUFUCY?QB!g#a-0t%`hFv~MkMdU?{DkjG zC8DY&vR5MeB=V$0IqegE9%^(>uV=_T?32`^u4&20yL>Lc-S4%}xOgwGgH}8mDz8Tm zbaimWCXSdTLdev+v-Qt$4Mv`w$8xp1#7zPHeoBY z*){`no&lafC6)jss6%q9gi4_Z$;lcITB)4Gk(se$In_dyQ2a-=OV0zh+b)y{rFh&m zLYYwWM@gtrW`iZ{5D1|flq~U{iX($WjZh2qtA0ey~3U3MJiR>jv^PAg@qql`AiG z9?S|oEp?aB0yMZaoDR_fyYsZnP}_uNVfPHEB96xEyGPh1v}Kj|3VVdT_@1_6Y3)LN z)~NHCQK4z1L^&^+ch+gkqwJNj3+*r8Ryl7o0C@uSa1Lq&`97g?txzY_Ja>aoZC!-? z*=Y?S)W3YMfqgHjlU{vEowQ$239sg86jnI{!u}kcU{ol913zdb->TdeT6EX&N$3zt zIn`)An5S%8k!__?+bbLp_6Z#s&_@!}*&4b$a|)DPz7{q^9}is4wI<4*%|55+zq_pvXyT&g%yu-W0iYuSEI$Ufv}c9DeV( z+i7=job8}QIgqWFs6q7Ni=l7H;Bz|0Ts~Z{$6a8W8E8DZCwvj&();|pn{(M6gADi-k~8MS`9ff2(y-< z8-O=Si;1JKb7(1jk`^yF;3|jHDUqWRJ!*G2{Zo!HNsEpJj1)NW0U;fdZqV-bp#5cK zDR$pw_mE%GD`R|H(#lK7*PxA|WHLPi3xgS2^`xaddP%g{f+rspStHZ217d0O0-Gp07(A6^OQXAE72@V5 zk!=okCyj+mMcYN=_BboDZNcspQ&D70G*y3o_KS-*FUBVq8WTGXi>AZDp5*$J!iwls zv9LCHJZZEw9eNnt zEGmvn&m9cWNpo5BSgh?!bFBTwY0_zoez4x>-zrex$k%$T|4D!3s#vfs)+!d%#7&}U?^8-;Y+qAr%!UxP$`Y2XF;X;F z!LSQ!LsYVObF}NTs%Vc`ToaoVi|gZ^Vo_6wUa^!#FQC~uMY9biTv!s_Bop=$1{EVNy621~~%yVNUmEceUrmM7{v z5(NiE`d~=CLYpGCs9UtuiF6%^hT^L+a_&&E;i#tGfwP)9PMA&(3PYT97>x zo=MQ%p|xc1ARwfm3BF`N$D1jtq;YtuZqk@YvP3 z|Cjyot72XIVwG5TKs0v1Db<-GRX-6P(S}(%Oh>FybIkf%y8bbgjbUSC2oBQFZ|SBr zxV?l#UgZzK+lefvl8^z#&f^t$&JzMLl6D^z7J^WG=>Vc5?wq91ITAqtpXvnh6B0P{ zfc{I+cR&?T12yg3WrT4maAnkKIF5Qles4~KXprVMr9ilV2Dp$0R1r0+8I6s1rK~7R;G?I2_G7KncrSv|6EOt)vr-AMWvh_Jhfv*7_xO=#NL5Ez2 z|DRE8F#Z?7?*J{(DQKyKH~F4gaamKf>g-xhKLh7x2KIcNeOISf^hM`SFK9=QU!0RT zwX&)&JiS^$_h+=$2|9=56OZn| z0Q`I(!@>oLB(0xi!|aD9P_zs?yC=No!}f^dgN`LlX{ZC$w1?ZTcLh6xlZX>EbBK@Z zk4`4&D!IYl@ZN|cRvg-!U}}S1s6kGR(H!XtzZLBg*-8W|%%Sc`cer<%sT7&Y1XI1t z)Qe1g+#ElWU|N=$T_UsV-k$ph?jA@mhl5=!3>%sXzZ&t2%;sf7wP>hL7;2Uc4Wgj| zFvMWO(7LSOBkK3ubKM`kJDSiRS=M)p`tE;l{bux6P&^wvvSKRy$Q{xEY|K9iw-PXdPGZ4!g4%#GRc@Dd!jpI{v~GTl4d7Ah&GLUThdKlb-9Nx`}qr~ zLOY2a2T9Aj{8PNgiR_7zkoG{Dm{`>m(4I9~)l;t2_bc@s0CEJRndPLUc13l<= zGGFKu+aBEq2Dry0a&-kQwpO|m1gBE;J=iTOxZ96BN}NWr=aZ@n>%)d26Sj7-#Mb7dZ^)MVj4#$WUete_}PK{&?sbB({zAT3wsbH`wTG`>_t58ubX?nZQ zck=0|gAd)-%3x4d&~yjbi9ybBXl7^19|Wg1TW8ySyP7-3yb#Io9fFL^90U=*c6j_( zka$F!W|=SpYR(_+6H9j{=*HVu?p(WdZLxKcdvHiRbT-j)PNdHzX*P5*T=0}qmuaz{ z$pm+7ddcYafrFa1ac+>!d*LBb;Mn@7`Rk~n0apyU$Hwf~Uj=_1$4%GO;B{;xplIQ)DKa1#DINURnT zta0{m`0zufVo6i6#8f01i)>;ZGSy3(>LsT7d0WO7A*Fu3J9q>fKgJj#!Ho5Pa9~N( zvb3osp=k*nSkW1SfBDF4o!>mS8PL5s`bNC-&g-{c|EBxf-miNfnooqa;M5@789Dcm zE_K@moo{ zVAd2iMZlz%McY3;bmLIGDzUi{+F{Sak)St4Yd>we(G({WOBM(@mHJlQqgbAKJ0v z&o_p&iz47}K>*iD1+dFNJ)jp<`P>`D6mk?YkNlg!g%m_i;`C|5A@d#08fV%8$}|$l ztd-2;7Tq=|dMscN;7D+2L(WVY0RE+!u_(E$ z5#@p$Q*yq_c2-Wu>e!8JSzg`SZXt z2R~mA+zPU$`%&|7b&~KdcRQLSSt6Ex_tkfMBhkPknLi z=C#Fr-yZ(@@I&KC2$Wd2#){@V!DFDx+Cq*{+ehY5`}NZ(zSTMxh^s~G&Rp8#!jgG* zj*a%m-u~j9o9{d~}=*2Hdh=Q6LAa-deoVA5*KkuIJo$E~) zs$*|{F>rI>9(CV%*O=IGNHiRJupJzV;1Mtq!Q)tgI+2D{5(=(hjrj~R@K+u(pnfr) zfM5g(P&4~WZ(8#yLJ1-Zz-l&-M+)$z`cIHYWqs|Z1a;<{@+BBYKpW5n=z!j@SH>b} zK=bJY8p1^>upTgQlwfd^0Y+fprfK_lGy(XgGsQXhAY|b_VYAQpUGTNR>F~Mitv2w9 z5f}Qb!+9`@z@8{J*#7?t)5Pdq3onRJX|JdpE!j+LoyJSp7 z*JYR?>EOQ|Q)7Og52^9&Ya-v0VtH6zo_9I@_^trH;{OhK@4(MD4UZtPVlEDzcuX5H ztm==F5Z&E-ujKxgyIUU8hvBBIPpRF#(D|U@VQsId*&gf$2yiM^^^7G-n&Xazl11Hv zitlC;#RI`!h|21%k-#dYDbNOc5P~TR_Td8t>xWhpbQgN{+0;Dxrv8M0C;vGj6yvqz ze+qf2K%RT^7<@&Wf^QKLt>pM^hz_6T{Ff!#G3EFAr#vo6dldf-MpBQs;j;_CJ)ht8 z_5`28a_3zhpO=@^NEe|%fy5j`-{Vr%`s1Ed7H=DmR197?iIJDwKI|RiZV@uH{4r`z zcZ_?vsWI0f{+Cb$O+@$}!*f+llH~6R_IHHw--wFe5zGo%7_`2>?fSMAx+J7U$S1}x z?p```c8PN>@jlV&UnPhms-r4+omQPu!FAL3tCBCvylW%Lt2&KQ`v(J|-;^vWkJz3S z&??ocg;3G&mtQYmRrQc$^J+0sR5IT<*SK1eDz>!V8M`%>E}kTN$$xG9M}=-x*h?|H zm1tQI@6^6csG=~X={83=RnDuZ$+ zLdz^uR*|x<6+rpv>tqROithNd@kZk+0oeuWK6{tN*#{SXbLm$w7^q#c56bUxetCRa b9G_l8wQC*aWYNYpdd!pULigtU96ws@GfPs>zRkws7pKeSB=~ zx!r0ZV*z&dR;lmnoO91T_ug~PJy(CCX#)o7KZd^?U97;czo3ee3NwYr|Et2Vd5pmr zd;+^D=lDfjsuPo{i)tJrFc&^1QBsiQTxzGDZbLPOU{s8n(Juk!ETn?%>p|PYh5U zQ8SeQWd^7h<+UpRxQHsV5l;rD8b&K%Y8l)6_{Bmvg@su;TCU@byB;DY2Ol1R??ViS zPcNcF$*+S%7LkPp=eoe&M)z@(k?JKQc# z)T8oM2Rp?*45JYU^g+7eA=+4tDiJ>=;)5bSAZncUDL)G>I;YP&;&F4LYTR{Qq*)i| z^4tABd(g%DSRG)oh)q_HK1fH2y3?&ll7S<&O zzeC*4K!nP+fK}zdZp(p1vtO5Z4R5`j881}61+11Y%EE5vOZe?U4Tgm(_%gmC2sA=F z_#J$S)K=+7VvrJF&Q}IO3VcZxw~DU>oT{J}!-4=4s+K4~4cZVL8(+;^5uF@}sfG3mT2gxzPcFx@uzBxug}XLk&nNzT>Fs`W|rrUXLJ#TX8+6@ z(%EzH(lZzuv`3!7kamu=y8HQ3Mm635cC8{uKILocs+4$we@Qy)9sK?~2E~iRvFo`y zJ&z}{FJfhH2~+H4fxEh?-%P zvBd~cJIqOKy5ZoI)M|lN7M&FqIgp~B^RphtWp_-j700v1;hds3+v;3{b$UfZhQcri zVXf>HjY@lPYCN|B7;e@Vn6h&on0g*ZhD9b$R84`rH1M$l56F8$MHpBqdc^{z=<(Vm ziXtf!21MhChvobk9BOSmpdiuAWUvgrkr57BK~W>=k?aBK6ad&ayVVBqYme|!$0ctHiIkb^|x}Yv(@jKoN9$y ztH5``=)Yl4PA?o1S`Xi=7g~=CJC1)l zCRAKVT^IAeeSPQsLG;5#0l8d3@$`N3tm4zVYrt;lgF1 zt@~cT&^91c4WwSbC{$idz2y;hcm%6w#atX?KIx9Rqwgkcg1IK~wqR-pGDZ6$L zL>-CXIo)64$&R(|?j@=Vx@q(5`RMuB`NaNY?Qig8{msJ*Lqg5|MduQIG(y}r70kM$ z?$~JJT(b8!wq*ZJ`@*!a>%cvH$#gWLS}7=my-;9FoEFRtsA`SzLP1S}6U@7zYAK$t zo2&cO@>%(f^0cKk!F@4vbLP(Q-LYF^>H3!i%gc*z2!)523wnfto_pPY?Eig#y5M9) zvqD;b$nP}-a(Tf*HBOhrE*imRpN(+j4?r`*v?b5q^ey5--OB=f5cZ_b6tn%5zfT%w z$tW2sj+f3=C2YyMKa*|WLmM5XW3IS+&X<@@&it9&{}_%mCgRsb-SkzL$35y7wVcy2 z0XJCfv}CB*pCACOFW{dFKqu?+2UxEYX#yo7KSsK7cER@p@&o-m#$#jId#3^h{(Yv% z=}tflPO1GA2WVu1*QLwQ_8}S4@ZO1bXjIR1H*K@_&LeB>2)=@J544bd1OKy>QC9N!YPf~T&PT^qIrRaK(; zRNt9yHtP^`NNZJx!|WBvMe@*uV{({rbOrZhJRk7W0uGRV&}EmrCa%5Z(1Z^J#&tnK z2EGKSlH>MYVeCF>nl(m^vF?~ZJ}8v#Ns~=?X70Xy>+Qw1Mdn_Y(Dg>T^|U~qULmQ- zn^E&4jcS{gwIkL@k@j%lfGJB2-acADQ3KAWb5zuT^X9+KjsmD8dC3_AMJ;*qjRy9w zP=^D3hSA4~K1gk|4EjiJzuFC6f6mmA5>Q;tBS&NjY};_&EC=92@Z&y)MgQ+Gwqm4a z4@VC#87ori1F6abY4{o|P-|zj^S$11cX%2M;Vc!UVnb1LxEp$6x_Nqzjyr^c>O`AR zQ2U*+Hl=e3ZbAL)!8iyuoIL&dg;12W0Z7Jg)m4b5MeyAXH$;?fFf zmTdfVn%oxe{H*InSJIZ=-h?EA?5`tDmc{EnYrfH(#M9-wU$oisq)@NXt`=dGC45TS zKhE3}gDzNTV?XYIVB1!vcnm?VONyxpAv2yQV22EIhzO}5D2Gb{E?hEo#0g9?<=GJJ z8d+p802c6gTro6bD2}}@JvFj%T_;rGF-Ubst!MyQe}vS)W!6ls;pBzYfbmiz6U}IX zKp&h7wGrdBsGLucEBuTpPl#*eL3=S;g_DpD45LJa*C99H|3!=BwfRCfOx9>@=2~sK zo?2eF1!WXLSeLRg$kp~4m~7u>yGX~AV;k&=to%rSsWNQ7U=-Rvj1KJ|2tjAVW@GsR zMA%rK@J~*FQ}3@*Mq+e|o6?J2fxCCe!0UNZ4zBYC$aOzc$wBrpn}V~JZ_bB88G^qc zH%6Zc!GoNFppHZOk6(rmo&HC!q#jl>pG>v+&a1);>{4*qHamB`ZbSu^)F6k%Smc_W z0V}FvJd>hs+~t~rf?F0r%MrF3N|GmXi;%67WS`(DVU2|%4$-SMd&YX9RaDP-TvtVn z6!>CmQHyHQaSZ}szJM1(KTa^IQ=*QW4!B%Fmk0qdFT`=sAaw{<12%YHyKDr(p~u+W z@M$#4#!s?#$p4V7o#vVyh1^i)xF#D#iH7PLL}0WKBpG)xBJSZ3vx*gjTsDLBQ&yk- z0AR-9$NdRb5-cAmd{tESk!G1J6UeeOSuR<(9J^OVLYlTYM;CyKyeQ{Z8NA zfm;Lj@IR8jCzt5s;bQ<1Yx~;^5axhebI&Juye-}lcfe0KcXHWUD_CoPefEn>H!r2F zZ3{huwR3UkUg6hg1xt@W_C&R-D$G>!0GPiJz3|b+WqPMT?@ZIR;bRYg;S( zj!#pyNDPIs{qu+B4*jAl@lwjzkWx2DmQ&hV&Riphj+JCSt3lmtWj^8lvWZ%R8{;Ou zU_x=p#KJi1lVoDmAuv}AjvU>PcJmt_yk}`aHD>4W0yjTGG03^MJ-D4`I0x1EHjGa) zDo~BmZ8`|j3F&kA$@32i9*h{gqvpSBYHnxv1chthRY~xv zpK77O$V#t4Sx&%($Y+{2mxD61kcl_(v~*vWUo>=z>K`Ev!Ge4o0u`;>Yz0A8Y6m3^ z$X*2mSs_aL50}(bqn^En+L7{GQ)FyCs!F<2Qe(SNRnleV93LL3E7^xfT1`|9*ssc} zO)H0K*uOyl=+yz<73H5X^c3cx^Q>SfsdE`I`Bp4W@z#JH*bCY))IXZ;-A^WZnngx<85<+=sOYheRI() zALZlS;MhYf@9yBO!6ox65%o$zQQROD?1a!;LCMdjL-VQ` z6vwIs9O&5t(E}eHOsR`yKN=l;C1n-C6KH|Ijo~5H=66&S9D%?|?jw2jKq^Er?2amX z;M7~t!+Ch*e*t=$oWDnC_#999GVl@7hIAnk4sO=(MoNPR?~GU7&V-Xf21dghJUARj zqrAD%*@g`kU>vL)itx~dEUjqtyWp|K>EK-Uwnnhus13biG2YC(R@;8bZc6vYMoPYS zqI;p?Fbuuj1Qay69r~)GCGPOJbo_ zjoPdY_pcUUlr=mkU-t7R~b9TPEP~FaYm4 zBB_*E8|o(=0l&{5@VZ28kM#c`Q8niAi+XUrIKS)K6e|Y{jK|P(;~AHigXb*$aj#NT zfrzD5618%tylzMvNFo!Bvas!L-vo0Dla1yTY2ZVXJ|-~X>SDiyDl!+`pFx$cs&E|t z3r77LM*llj@c=IfmwZt5UeyD#IHEKF5gc;`2ap&jlsXI+Gf43ninezzoc>hv=2&*q1ZM=NnO>H7b;NQZrY9 zRxi;f(K<$~S`D=JhZ0O*wqh-hH9pyHB8b&Wj37TKf3JL%cp2}6FoCsrzG<#$wOFpT zw%wh$HKEi7@KgA|HvMy^w?u(0MX)8->wZ|%qpdMEt`qfZt<8 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_731602.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_731602.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f91e51e1e9bdb044f9ed948c62ec4e079e1494b GIT binary patch literal 9375 zcmcIKYitwemNT};9^2zroChQU9P%KBkPt!%j|4~{36PMMLJ6dyUdLkxCyt$r9Z0g{ zyQbT!HdGs@cdugZ)e3WOCBg-XAXRGGmA2_>rFI|Ede+>HcP=YMy8PkKDqVV|-TkwB zz8Q~gT&L;n?tYTzd!2K>*ZI!t^WW=q8Vu6MfAs5XKP$zs-=T_>bm_v&H5|hhFdCzA z4>l;~_#iISW#jTe1unJ`gM9m1y|}?TQWBi4EG7h1*>kTJ@gd zVS2bhQjO6?H0i`fi=-K3b6UMYyJmxS?FQ|XQ$yZ)i$wksepsJ&bY14c7pk9VK^_iV9%;46N7ics3d8)JT^^?7}3ugwqZ0BC7z1r8y zQ$26K{?;23r0T*OeP=G(dL3J`n~IT8mc#GmK_z0L05Fo6MrBedo74Cs*~2tMa7zxX zA`2sxSOaT*n5Gfjx(Te2EzZD}v!!hLEP-Lc3U&)yF$*+;m24$jDzt42s4^pLWw*`3 zp;uV9jXE92~S;d5Ole>qKL+ z1~htbCu?Dge-L(N(G#L=GM);8o}0sb0aubM{6MQ?G6gYi5gGvOuP>kmz}&kX8_^4HzU) zI+t=G?lQ3@Y#Ex1ge~Q#uqb{NI67Fz)`@#4!m%~NN_VllXk6GA)7TJ(F)|$U{#(u} zjM5CUIg4d@zJ3!(JOgRXj9w&g-tJ8>Tbr|{ zVriso&epSy(z(upvJ2J%f2Yr#6*(BSBJP!11+Eof@hJYGQh~8$w0!Ioob)AD_MN+- zmiCCv=W~tJ6Ku+vBPihf#5S?Dh4zb;jpoYPG}>r+K0a`8TB1rT#jzP|qq%N8Pl3B- zIdiT2?nprO|6#6HvY$c1)Qa;8`?TBt0<8xs1Ra=}aK1pd_XQH1FVO9_8hBa2!)qoO z-w0zLx4CIv;h^2udHJ~8%PZ}S-8VO6J8X8EX0$+?*El`Sai=$63(W9@Kj34W)_pwQ!Q-7gev()E9d?h?&mczQ)&2nE zrkytXc(yo}DGqzH#cZo13+wRmhBO6{s%<2{!OQE@5W->r;MKa9$ItP2KacnFni0k~ zHDU9+0TL-=2JIOGR|I1~mtYK1dAxiA_Db!Y8n;2g22+uWUfw7b-CmnOiYG;ifT4A{ z8Gj&+LuCg73WP{V$h=~0j92lALHeAhl9t*cw!9BY?{aIJTYuz>^FGD z2s%VfdK<3WJ-k|6Fe)mBZSDg+KFljm+7KcCK3Iz74^02{M$-)+Gv=SLJDg3|?2a*e z;F{07*FFMA-Q;H+O#$clL=)7S+}^+e+qgY2KILhgn6W{V4UReB^V%FfIPBJz=7an8 zLl3VLfox+=#_RMjr(srzMg7gN{UPkza!jk6+Z)*%JO5O(BXshav1maxuZrQZwxqE> zbSA%^DyxpWxw58&qBvzJN|e@dhPu${R8e_sh%2&&dY_p~qcicNTv_81l{B}9`chO` ztR-H_RqW-c#!y#EXH1x@IbHQ4#Zd=BUC$}~+_}iP*yK~HD%7>6H%C33e&@$;-oJA1 z%HrfwL$dZbr$4^Z%;`>rx>K8U>x!a7oUZEQ&iiNXomsRmRVHhWak^un?ln>!zVuN? zG!X5H*kUuBu72_0QqAZ1(yj-`m#=Vpx;b4pNA`p|Qa=>%xrUmnkJLx|;tdIN(7iBK;V_;unxcK2(GqXvjJp>vamJP~ zkp%-S#jw?&wNj;D}HM67*}(ct315iz48v%c7dzB zkSx9!Rs+$ZvRHFm`OC@0n)r=-hNUXb(!y07;7qSAn>l0Ks^J7@IPo_{UzLAZo;37@ z6)Bx5dL~ByvLiOa71wdPy2af~7e8-a8sK)fb5wg6Un8|2GIQ4>*Q55B={C!e6=6Br z{-Ib8oI3E1u8AX?U@yMYb_-v(V)`TN4H#*Nn%@266s-5s&pTp)SkJsIKEsvl<4pUO z4ldXHHNL#-$?=sdTuVP^>JKZQk(xPOL>G0%9CuxJT=Dbw2ks55*0yrBt;y|&enTE! zqjb^IpX0FuG1dHz_+ie}uy}om{=8#p6H@^P5@u3D zicoMoJ_g)yaQg2Q^1pT)RK_!cC*lP@`SQV?Jab#+Z)x$kjoyBOGOc1sx<=BGj05Kt_?s$V(AJ#ORX2N(!07SV04r zc$gkPNYKil8h#r1X=yx2(W;;>YzPBIntYE8>azwis~3z|Wcm=G*5xpwDQ_xJrRCDx zH|SSo`tgGJoAqbu69w@%>(A1cZGvC8KZ`Hl1ix^91|OXpt7H|dn$@scmSS~})DP1J z91c#@1bokvquVvirhgQcmOFihSF(+*m45L|V zuVsA7AFvEN0rtAv?;iFzEyFXGLl&2hv7G6@c))_--QJN0GSQMz($4E{hx5e=7?44% zSA^h!3rDUU9DLVh6JaDog`>{D{ddp%BvpIT#IKea-GlYhl_s zIc4_-!2GcZ8r;9XvCZQHYt4VG5dpr9@Y~??A0g~HRW#Qd=>=VUAg<>snv&GMPY!;1 z^vTiX$(5RwpKzU*k}U%qHISl=(XEl!*9o~%vyREt8s-cTuo5}*e$~LTs=;QwW}k2} zZ=x9*yv}Jbn|wE%3|KBUutC9|bX?;Jur&fR!YMJY1E63#fNd$5lU5UR7Pa-Fj|F{@ zu_*{arVo|&!iSgl+9r4sK47>y=?y2uTtpbz4Mh8zXE~iXA#Uf^4L6bjCl@;nKYs^s z(0_)obrPd=bFGoqo2Nn@p~;l4cdkJb_)}Vav?@9|Uw5lLq3B7J_aqfP;r5i4 z3QtB}i+0>Pl2CLe$~u#Z&hU|EhLTl7EoZ2W+mnX+(CIamF7%UUrqYFq`3lhSrr4#$ z=EtoMT9=!@IP}?}r>0Y3<(jz!OrM-jF;w-2P|xob3PSTOg%L&38ctD$v|wwrJJucV zf}*HhA8n zvv(0smhaxMSvN;TsOb6IhFHnn$~%>D`!9B&k(jDi=|k+8m!v*$ak0@{?U{x{ zx-lpV%0d3)$c?9QLBHvd_G9J0fE&x=@KyW;AyEQdh>e1ypPvLBBIr*R&*&pGsx}`I z0%Dm6a-o>9q)~zjR#8wUgO+=wk|c{Lmx2VqB8irl=t8hckkbSTD$wLC%no!vp~YAg zDvy@r5rC?+^3k&I_R%=4mSzIlT_I9oN#sobuV|@R)hH5{d5I=6yQ5;sO&~~Cz0pb0 z8f58YX08q(#U!Lc0S0uoA_fHGvS(K)K!A;ImndcPN7@WWXaZGIJFSt{atgb80M0=R znp2I!Jnzd`E##W_Ce!{?bA_xdgMv(?tZwn5fYx772v^~HZfx(HfU4%P;*8gn`)pfDmH zM&k(vJ>CfB3Db{C=#;G{9)}=_!v~smWXd+J@pfD6>yzH6kycwoUR%(MaNPw;#0V6lFj@ zX-pPB74At=*2OB0Y5-5Za5Hk@ossca0K7fHV-H`7x+1n!s+yy!lT__0wTGkjEOsn1 zNoqe(0XE z5rdscYFnraJXoz>@Wh+J;cT9F2~&QiE1G*d@^-8$c0OKp=TfZVQP<<%2fZu!SLB!E zQ{9<@l z@@Pc4F30qxFeX@Lha!h=A6eB}IIShA-5Kw`fBN3(U!7eAH@oI=vgYXWTVD)(Ht?5M zR@={W?dOy27ZTbF3B?6btLFM%8Lj$PaMYV=RTKdJzS|kZAsG8T_dbaHwOpMJJcg2> zT@|342=FXHI|r44jE*j(jX{4XLBeL%QOG}pkqUT*@a&Yf3_zPHS-J3JKxU;noq1RQ zS$hFyqDIy@D~7L8*j|Cc_CZa~{E$G%OW@mpkG#PQh==KU3vwKuk3=~x&P9~tLLeW# zDhm?+l^)m^RwebgWNk3=hi`$TR)OqOihfg8FbTrSiVG6NN-{J63qnbhM1qVFBfrfD zvr>jZ<^hPBUIcAH9|S0x)(xmMh(3aStrDXUeqJpKQHYK*NNzwh(jD;ga)@7UAaJx( zUL=zQFNB!}lw0ug|1*3*+~`Z@JQ2@rZ>T2)cMopT?W^L&I~zHByAR@Yub9%LHf)cm zLLJZb=DD%RSgbi_kJUXQ9_t?HmPfwue&&6uKL?Z`_%%3h2%QlkWWv=mmc=^a_?<4! zxHInH3|0tg0SOoxk%*fXTOS{LaO|o6NLZe#*b%S&HNM!gw42-0mZ)#zcD8dB?JK2! zhkxCYcyoyB9ZH-Y;ySK!w_U*-Y;iFA4-&oH=R6MAr2{~&rqHXs9`5P9}JaR}xSR0dT+GN=w@ zOkQ+JAy`DR@`vfm4B-}n)wpqBe67H)FG4VDfi0~vkpaQ9gcl7ln57LkA)fBA`<=FC za3_)W6y7IjFXKmeNJ|M`7clD{R7N~0>P10A9!J$K_y~&BHq5JrZ9?#zr-p4Pat=-> zXjmS%@w#;IS`44@D#q!b@&x=FFmJIqF(1P?9RbOaX8sgL`W5{AGw=yvYo_AR+2^Eg zP8$J_Duze9;|D(}`E=WpZBNPLz^xni)U_;itn7YTcaBq3hq}N)>x)-`&sNe1G3AnF z)k@{pv&rJY(7AOAQx9Cf4E3xVvy5iN`(YE2ZEcmmut4?#gXdP}ebitSFTdz| zi_32Gz+2E1MEQAADlzB@iC2KPdySWmx&yo#tQ&v8IX%Io{qTzriSaRT$GrjJ5X9p^ zS6SEzUMudQZN%rHA7BFii?#;YHtwURJkDdx8~`J2;V*~px*W&x-(l2mG2OpnmA}Qb zpyy=d2jw@**JW+^K3JZ)WT9ccVcjIwnwlSb9(bf$FMbaH$A-U6_ZCU8W&~S|>5X$k zk)eMtf!a$sCZj}(oieNw@5a9&YVo3!iu!=P$*#*$d3*I;&mE6gK5+Qcz9)S`c};7a zYm2n4Blxx2dNJSM`f0X+>{30HGYhh@9XdOf4MgL17|K@-ZH;dlyItInX qsmJUCRw%A?ebx76pHNJ^V}FTyU+yo*&EK?PhU%og=B8HIyZ-@=kir@O literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_76683.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_76683.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03b8bb06166720fae17136cadafaa0982e67177e GIT binary patch literal 11272 zcmcgydrVtbn!nf2YhOR`gEvVGA%GJCg%D^HnuG*G35A5brp;v5V|)!UFT=ei1owLC zj@qr=js~+cyEvU$)g8@loe5INXtbr$NKHp0wYw|r`i{|n?q(&VskF-eMJAE<@z3se z&h<48W1!7SdnB&UOTW0@@N69_F$zY>s2DY)VYEM0GZaHD>h3}> zC~eE5XY@Hd^w;81Ji}x7T0E-f;L*d3Hs|y8`cXW?^UQpPHNymB{29rZhRH>TTQ9+w<2C>6yZ}Rd+KT%fHuj%zmcGON`j_jHuX4 z52=`vx7F7SOf^#qRAqUn%Ag)}acg4Ax1q1th8}uj_T=XoGBA5#6lNIfKA^}qR@7U* zddD{Pa&LKfV5KUdw*uw>Q}s4^%>o*3$*z}mmUPxULX+)(bPB$Y2ogRHl%voGN`zA_ z8cx4@{!FjyQqQ$6SKl+G{%1O6{c!ul8Ja^*&M|ISNrfF+J|$0`6fi zC+bjs#?4N0kHV-c1^S><=v0bipGfwLWUr|9xF!QE)M!2ai6P%GCn`t1v!appa^8R| z;CBVRte-_gWp(I-ZUL4?A9U_LCGe3GB0}Py{F4SDzoJLU3TM|PqD(+?|LM6M< zDsyN{b7;||wYb@D)Lj5!%D=;frwH-hfu_ zh%e*!24NZaqAc$|elPIs3&OmEKo)Xf2~dJMB&U)uYI9a4kGbJ(ynRR!$gb5$9@Dd^mgf!cn<^F7FG(QgoRJaj0h_dHiQ zU&9|_G>cl<5QOS?p|8!JoWszM73)~TE<8%Ufj7x3{U>{1mRqut8+VbX0(k{8*H(F% zJ-5j1S*b7Z2cZp427|RJ)Rc#Fv$dH&%r^z$-eGij9gpxY@Xa~p7XAp|g3n?j>M3-T zuge*=HLEDu3|~J|m@Q;8IT}@zt;}ZRvGYftpI6qYpbX}o5 zYHMTU4?0u?U-$gIhFYKVwrP9L+vY`H!MAO>H4L&2@GovT_eQxT^wOVjNoQy)hbDWX zj`78eaH^VrOi)uHr?V@@R{jf;2MLp+nk9j#jdc!)YK z!1@@k%RRoC9nEG3bFvfJTF)k}XF@b*BxJf9-dZssnlcp9AcnV;Cq(@)>z|%l zX-WVMAeyN44-Ii>ZKTW~Y#Oc^h0Q~&$%z`g;+#mkJ)Y_D=`q+9RwPG7?Wl_z^Lf08 z5UU937PW&e-vruWojj;pqDERawgzp6n!Di}3W!t?76a!%I&idzszFz9674SIof1h; zW;n19(zsw3G#K2$cu;394|g2oIi5fL_xZ*dKRe1zx;@^;8*b02J8;84aTvhnM4*vl zJ&ghH_+%rL8qxS*h5?v+!{n?Bs$8H^0sn-{;|INJYi)h;2sDUBiN-bRWhcC2tQ*FM z4i)!Pz>#6%X$wIc=MF~?#|KvQ`@&})nk{poXed#+Vy+Iq@>PL#K|8OFlkxU+L0!1# z;g<4R@&1HQC~i!t>}zIgs;EXV*MwhLvzEoL3sz^i_o1yQHk&vu6gMmx(zchwudW%2 z<3|z|Liu6A&=BrgGZv(5`vv3vq(LyWguB)YD&jLjK~1=4&18#>38t!_U;Oy>k6uqs zE!C$FoD@ta!`(Zy7z^V%!C3k8(;xSI)RS~Cm8PqY3&!K&v+G7{>{}lUBvkQl355re zI-#&hFg7inU2ge}Wx4JC`3F9s<-A}#A3nPVltU8bygz9XiVh}yLSdU=Y+J5ZNFwe3w|10v7QU}eg$%xemuA^JU<+F zCrZ;rhgOaC5pvCHi}~gcCu)-RUrZ&7KWY@LO%d&ywe*9^_~m%5V693F2-ccp+tR?D zk+V`mdZNp8h(BtoJ-Yv*g@H9u+uR(p9oDA=75-9w9Qp?JPHUY77J zH3~HyLTSf??k_ukck$Af-9qQ()VGF&&LN?6C~X-AsnSNS84A}-#qkEgRQFg(8lCIr zqFCUAf!K^--j`?+%+*PgU}}A=Rv6o!C{-4HM7?1ktfld$`SJuQ6jVmksQ=ENRmQso zdv#(;u-7F!1l!>VwGLEQ&~!Y4#Tn786%@uxfFW^7FxN#?YnI}8_57YhjZjb*QLj_B zAF~TL=WoW{A2PR@pErMe?4x7py$wt1PpJFUudSbzeOi`&p)+0FB~V=vdYv-GoN=FE ztre(R5Dmqb66E}e#Os2!L7*Dm(`+;mro$U8gw8%^ikd)na6Xu(s=^&R=+5@~mS1)c+!%79_ zFrFhAsQU@QsCe=x#9wFQ7+6l$mD}u=QG?^6ltBU>7jk#EBoxnDq+~D`Qo)#zE4(Fb zp;CiOszM+YKrB#)G`xx?cn!*rAZ)ND$#WmcIxm|95K+=+d)I=ChNbXYiEPT!vCy_yMh~lY{{OhUAj+6=@e1X$Q9)Qg&7#h2X%ANa!)y@O(77ceZdJ)nE z%H@6-S{`91@tu>f>Uv(sQ@kD^DE*ESxwDKOxfCONWRA^Wpu*5pmhk+mgWa?(ViLojI!^8b*|V*GZ`dhKAE5 z?~7rZeL>E(o~L;OW8(EvDB=HYSLh#{0_Td}W)_)s%4#sp>vp=wc+HNj^iHi1UQt~l zoaTNO9YXA71QorGV5MVxnhQ7vy$0sC|1`%AAp*dg?n$qx zy6N-IfVYF368O4rfeYfFo?swiIOZOoWDj9l0JfeRq8cMu3|Cnn)U$AF6QYv4F(zuc zscEk_=ye)d4{AXM13MYm-4InXSpE!_k5R9vLLPyrn3)1#4;a(SxMoDsCmJ#Z>@SG= zpbH|DXrK&-F41Q|E&E-l5sg#RZkSoni_tRsJ*c||e_SsdgfOv2>E;Yk!}}&MytHBN zX!Pj&Z87(I?I~4pq#f0?Mcdx$40nX5pb!a0Ps9g)NAFFkt|R~jCC;R3U-`1^OGC56ql`&Ma(Ks(ZO)hb;l&~-^m@^qSR40o=>NT#DN#R39dzN)Vh^i^qn^{W1$pg)*w zN)D#=O{==2g6`PVWQq3+i)f#oQBqhai+CIdG>i zZEsw)Hw*UWUmgAA`2FK)`$>U18KpLqgsEslLl~`debK&nW!g}&YN!#gtNw!ioK9W7vU>TtaQXU*?K=R= zErkoE^QH0TcrZC|`^_bCsVQwa8a|gZRyzR8Gf7M0tWbI==@CksR!drilGdfE<)-`7 z>5`XMOHK(Tr+!=bMfvCD>5@Ky>WdoWvDzK0wrasvowx{7bKfaS+g=F2x=!0x>AeEI zH%&Vd2Ue?^g{tOsRm;-YYHOd++LvxUpBjK|zXJNe4kBeEZ;9He8Lw~nMu5G7Dl}7) zqH0)gV49uqpd%A$QKB#nugT&n6|VtYrRKF;aFsfIBmk6B8p)k+rY@?B!YvQB1-bKE zQXq@%^1U-IYO=^oJz?S1`LGV7m2D|WM0648iJ{)6Leypun5S?Un*gMz2E3?5J>(aF-bY1e1l$zAyKw{P;Rcj3?9zR9oHg+G7$CcknQ{`~Em{HnVdbVm}r8Y~&z zA1v#cn4AtcZe>l_ri}^iAIMzGj6*svd5EyV3L36VNO!)Y*HITNc4i4*JlH$a;OO=T z9KACygI&WW2N^dOp*c@3ltOBVhW7-;c>hwd+pAO*+zf_9Y9gEdoOm)En5`)a>dfHR z5t$wz7H~e0j3EUqhpIm(e1#q?z*cLdUs5*F{&NR z?I7bm5j>O$mO!{C*Rr{|=7#n$Kg9L86OfUJ-vJR^8+`sVOgyAab99uBb;JVkexdkK znyO#CdGC!oZ!9-0GY?J(CoZNNFA3D8HOdfqJ!*ceR+eh;t%kt}1{hAAX!LOqyvag^ zQ8d~q_=su<`UGaBO{m8eL%y*w7Y^n?Bnw9poizI+BZ4Jweju^IPC2&A?vt-z_&U28kAFm35r_7M(2!%8|+Y-39 z)j2jJk41rdJI@m$JS59ri8he+zQML;E)DP&E1`D^d}~UweE7LZg?+_59=U8~c`eYx z)$xw99hTz<`&C5F`13vaA-j`reA(wqUWuRebYQJ?=;1_uYz2==F_O7)p{s1lZ^s!8 zx`@CwX4ZU*Eab=Kc@WFair%ReKfLL8q(D1q;0RPe%GBg1eXJ?g5_88|en*vHLy+Ln(A1zixHrc9C9 z4=N+PKvjSrLK)}uQGKj+p?$s`LXjsEr_)qzQuoUyh-*_v!yW5-`>MW7(3hq474fML zXKv5@WLD5atn^;@o$k9m;7^24!_^n=!8d^i9d$XO5CA^|h@Ocg_t*fhKHwm`CjD>-MA9W1GcPNW z85OnQ7f+7`I1XWrtl5N&WH;^7u>36R^#rgw1Abut251-Ik3)7dOsre%;d5Vs;vm>C z5GNtP-}xyEb4YkUw33qQ45C&VgA1fHFA6ikT0{`?ol`LBMM+8)|e~TJycrDpK zhrDQ(=H4&{Z>H1m4kc1@j)iNS5GnU`z#o{NfETwj_}`mE<%lmJ>Hx@a0q?CzHiJ?w zdnY(pk{W6}A?OD59NV!MX@W6T{w0!`n=9k-ABndvd_ z3HFyzM8<_fXlg@AlH?xE%uUq(f5gQnN<7Al)J2q-LHLH9V$rmlW__3E&2byzA* z-TV(w6S70`So*=WcqcEy9sw)lM-+YqP>m?wKvd}G32J$K%t)Hz2R^L7UB5v<79N`n x_YFAv;K~=*KZj94ZOSz${f@J%=RV`q`r)_J#Ht=`_rcCw`i=B{|mvEbh7{e literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_769812.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_769812.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9af62e01802802e09367714f38e26d660ba7491 GIT binary patch literal 11689 zcmb7Kdu$s=df(-9xqOMDUevTQvP#g5PM<9HO8lBfqMElF8c zmN~je8Zu5#RIfczxB#ZF2bii|SShZEg8mVwxFP`xv|7`x*sTu;?);DcQ|bUM4h7n8 zX8DprMac-7{buHynQxxI@B4=Ttk=^71cUK!#_v=U#2-*16^u#ZrA|f=%LGfX

pL z`jS^koR&??uPR8wK~CUyQdOldD|09&bldY#ik0WiA62l5cNGtk6|Sl^1Y6819mGUl zjy6=g;Mh&d?UFB7p)#c_Hm9hGF6<4*aA$K`z+$BFRcmAk~ zH9>nin2(a(`egHS*WI-*^Dg`aFg}_sgz@RweQeRY+zG*A#w2ir4wkeOjB5TRiK+EC(<$wZhBrlmNn}ChAGqB`?WVyd$ z4^CN{vvLnkd786|uktfl+#u1;z><8~gQR5sYM$ULc-4ZEAoxmNjlK&|foFIsO;_`p z6phOK`?K0t?Q9+D?oUfolf_Bp;4B*>z3J>+N9>^@E$IQTKE;~Vsv(t^`SRtN@ueg= zOWark$!qyqz6^~etA#yg2#@$6e~>SqggNsc;z^h<8nprQrA^HXlBj12>{f+zqz7flC;KRRD5WE6rB|jw_jKiF!8O+s`)pveo z%~BNK#2=U_Or8V0Wa`c|GJTEArAFqg zlsTCzb1qlrw9T`f`I3cUY*U=W#@=G=ZkN~Qba`;MNm+4)2FB*P#mv}P*6A8!Mx0&` zmc{M!&iK5!y1D7*COtDYy8{)1s30(Obg9YZ_BuR_&*NYj=*P)Aa95KK&gGb5IQPw) zPLBgO&7n)~1-fltAj(3DO4fU8#sP!^i8O==+q~0b(TOtelt|BT?lI0bZFRDu%ErO8 zIz+{k+s2B@1qbK$i0V;Dd5^S-%4DIPbjwp$L1%314- zmlKtf(q|5TDAYNo^yx9;BrQz_rnR$H53DWgu#QZNS&8xNgf~0Ein__UbS2le)cCfP zcUy`T4OpN|ciYqBJIcK~(k{`Q8Ahs+@m!hHxiVd%9)_KkdwP3ve0$QnJ;{n11k2;F z^E@JjJ^<0v90Fqw-Rq)e#>rYo(9sboL?RYZi72nAagUCA99|EQDnweLWY!84`oKOC zyHZ}D*U7O!J^rUg_LA4_vfACSqmH+~a-y{jN<_WHWR*@2*9%f<;NQ~$2PQy#bAT|Jm-S2fh%KJS z1kOJ*nil;be`GjrtP1o!(-{^ALxYj(xUM|Vv!x7KL0PV&Bza!N4O1yzZ>!iXh0C6uJJa+kM(b1sx zIZScUQc*-1VSm*X8Cx;L3Wbv9wc)3x;~=W3IMN`PYJoDC!o8uXkbG+&VK}~3Nf^o^ zS2pT8go+M9-?8r8xNt)_eM8XS*!YQ4&^rUY&lIZn`)>CIFNZHKUtYTW(V4p!RxZQ~ z4?R`XeWN0@`o%+`L*avWS_0h(LrLU(OfQrlTc;kM5>C7>l)s)Ruf;hadqsQV#OcQs z!igc2cnuB2_&OOl_dMCbkHYrl(WTLk-ne^hAefb4PW|pq&TMnQmC;K{UvO@_cia+?`xL{o|s6LXML~3(f&a50sVp2uMTR0bWl4g zXUTUdKjlR&;J{N@bs<1qlYv5|0GKBx0OWTDHmzjuQl3_=cPUpqNUHO@lq=Ka?($6k@;&;`U7qPbcX_7&+~t}66?^obyFAl>?($6k%01f8U7l$_clloJ zr^gqfdG%-+uL@~d)km_QktA`8u&9UZ$?btR9@?H}_z9=^ZCaBv1K6g&CA70#kVg6x&pJo>frG2wAF#L&lINc0YOom zpfthOpegOIIScpI?hs-> zB5!S0Gzk?=FY^rzEbcEwK*;lXVgX{Bl_k0_{26cvACh6G+vk(}SQ<wf;-|qmgDG%0Kpnr9 z*XA-TA0*q$*|wZjJ=COU0kx2Grc(V+6z>oYl3A-(^6k=8O~2Ng#@C6`Z}R~L%aIw) z(hIOpvf3=q!Ane3emnP}_CZp1zs_5oDrf1G1mNTWwc4-eb-b3>vpSZ3PmZ+!Ju5k6 zp!P&fs$S;%-Lj-wUnqdKbsx4V2pREHLoV0e7{SYaMY6Ib@^4i?BmCf3PH~<)IcQW7x@VGhVC-Xl6``GJd&JSHW%G`8% z#~BZR0`R$52HCSsGj1qv1e4r8iA-|Le4=U2Hsy18n&8H7f}A65&A69099|y>u7%l~ z4&+*};-W^W#YSdf#)|t+x)-u}zqD(0Myo1zlTe=ss-E|*A;`doJ8h_OwHa)4&3F_kp0oE_Cs+5wkmz>VPcxmZzw zoD!lEPZNg(5|w!D+$hi-x_UU|60;b&Yp9UMqX*<)aEKJ1PAkWX0U|A(4>rZX+}h59yiRf4W+Q&%JCYNA~+GIk_7AJ-iQGf!_= zyd1h5J`L|mrD z{!oA9Y_LDBtA=!5VQ?;@480TR0~1!K5Bfr{L@0q~z+9sBia3nMX>2j`L5zF};Lj6)QYJux-}`Vx7S zVa>8($*_^v6r~`F?SGQj6gZ#I7K9yOT54-I6t$8{L4g+wE(r(Ay@Pg_Y+=SO2O<^H z?;rmHxw{!{SiO+@6}zYxB%ks0f0ufpgccO63;~H>#mW&l1H>Z5gUW)R9Lfb4fT~eh z)>WTF*#Ox9tak=@gGi3nugSbY@M*>WEDcO4kT0vyp?%1qY0DOi~|JFL(e_TWK#)5M!{O0G%u#!|eh6x2JOM)MaYKdeED1t^F)-UdLi zkb`0YXGGfhtQ;xVEInBV&>&-Jm(rV7%#fm(R6s!8{8o+Q2vMLHv>ZDEi}x;?_lpFp zmA)3;!U{&&vyP_JtuHAQ=t5vXyXbwLsT&1#W6r@jN1b*X0=}dgki0a~;0$OF$ew|9 z4(k*ihI7E90T`A(yBB==N%h5Jn|3f5-58sPnYH>{(iJ2rIans{#~>@mqUZ*S94JD^ zjkBR>3?gvxOI&CNHe2)(WUIk1i)~sbxbrQCP7& zv@~>k5b&5w7>hTJRf4f9DvO?p8xH{*=n58xLPNKE0Ft=`MF)pNDgbIz0U){F5Uvkd z1HI2_<6=Xo0erjF;n83B-0#2FzfOH-{M7iA?hAB-Q!%eFJiIs`lqU+xBNvwjgPLc? zg2lH(Z%0}q+)8_Nc%|c^f|x(s}``@V0nqvtVdO7JHA-*7JBI-gaKl zpAX6s=ECKZOD7|axcNZzji_zITo;rFyYEn2#YA4oRw-_dygHiS3Wg zt{zPp|bZ0eGW7WtpOHf_?{|Ov+~ZG=}A>v(6e|ebSq*C@q)Hv zL(#G3k-icx>$hp+B8{|qE7s~owP$9^>9tPU;WA2xG*JI990v^4$%{SB0Yu80WJz?r zcw`rdKSdsc^zApdZGIe@u}DN$7Uj5q1H6 zRaz0zdDDIbHtnGV(1%24IN(Qj)_u(HDstqLuKF zH)SO87Ob3H3+os{`oNfNnsUs+%kfRlhMv?Vb9X*<$->j~U?bDN&9Fu0 zW?g?v%Ly=GEzF1mz4|$r6ZOa%0(i?pILY9KwxB}BZbj8_@7Q!jJ>@QzUP+{45F%L# z`)MV>LuXRLKa6@9|1c&Z=))j51dwM%iOVQLaL2uiB6LbH)QM`B8_a32QZcAu%)=HZ z0uS!bP=swy1GYW6C6s~D?k^DabBIK8P*n9>-SCqG40R&e&tcpfK!sH5iV~?5f@7@n zAT<*WJrg7p+n0QiWF0Ae0|T1X)5iT0s>I;mW0f$i6zsCaW1(Yrj)!d@bZ#h$gPn;y z6MCb&eI8(JPC_rODSScDRz|CzXlphUH5=L*u$44Ln^c(q4%se&s*IWhYCpD0!>SM9 zmMsc)B&fnHhCtUdeg5LL&^6>bj}G7e(Y+t74S#aw<10_~rvts;&;$TuMX2J>s}g2r z(_AZv^V4Qu3D&&L{} z=H$9v*xwcG3%?n@66p$Gd7`V75YAk-d}`@b)D-QCn-6Z98wGP?+}s?b5Y{eiRvi(l zj;wXXtJ;Nv_VwD&T0U+0vO0cpKrjy=mcj@nIjQfTyLT>L(kK)(234Wipyt1g1icR| zc3pm8@V`wN`i~yuEB(jPZhcRy{GVEtJss-YCf_%}hP}%2tRKbXzzR6wB>Zkrmi0rP z4(`h&j{+H%a2v`W zBoWk4vr3+JlE@1S&o!zcdAl&@Ua4o6KQ;+7JjATwc2eShqcwB#Agrx zZiICue@D{Vj$J2sTf>9OFHi#!6s5IBk+g~A2=)e9M@03A72hYKZp4~}^rThPCm;8v z3qwTet#8Wf+5XT12TeL*1(INH4)kv9$UBM-*DCb<6#hN8APNv)A={w<}#s0y4zd(a#h!g`H`eDNv>0TYR?HH-WO1D_WJ;3W4DW#mbdsgg!5 z8l`XZ6#QVv2S0`osT9R&fnAhOIN`@0*u-$m&v0BX%2Zv#UQ$sGmo|P}lGXyL&3HlZ ze27|UCaq)cDVBQ;m7POT4{CwxoOZLmDaUEl|3K*elhFT;DBq$9 z8TEeY?b0n-JJ|q_R#nw^r&gx6j8dxe^kwc`*Fu^7Wy literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_790411.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_790411.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c3f16a61352b5805ea4a0a09a88f1db5cf2bb8 GIT binary patch literal 12519 zcmd^FTWlNGnVumz|Nou;@ZXF^0|C#d@7H60t|N#)qJr{RHZqU1YJ#{; za0EwA5?AFnd6mRv)s*_GhLr25tCWiHkmIbPApgeoCHb-&qFrjd_N-R@#we zo@!@QKT2@MUy@wOD0$DMtQxF;sl0xLDOc*zj4S)8fWVozsJCoUZ{4EaR=j@1z?HqN z`FNv6<$cwlC-eljo3nd}aeF~`HnLnf)K?U#uh@cr#}@pR9wS#(TxNvjs^6xrmOwxD za65sgrbv4rKceAkfUgwu)o#&FUGX-!`Yrg8UYgqldd*xT*YGxZ)dF9eB{x6T1=6v1 z9d&4My$8~Df`p_FZiEvO7dvf5ysdLUtFJ3+EJpWW_@Tt-{iS2*s!pj#GPTi%8 z120~7o?F)-Ju1mCJwN62`CX$PpQK0mn=XFVw;n=GsZfF}f9S6XD` z{Ag(S8AmN|dN}f)>f;TLV8u4P>H^*!1-xj`n;wA@tRHW*7p&ZdS1Z_Zyj6l-s0vU7 z5v&#}gz5n32<{Yi3U*vqpagYDPrXnl zSdpHB{-D{~^l)Up??{h8A&tRZLcL)7PSUy2U}-IuZV+||W#37nC9y_6hssnUwj0 z4#6Sp#&fk_*uQB+x(T%IO$BWhG3CM@p^4Mp(;A)!QBIiW=`v0`P)IYdxS=%_Ni|M z(Hf)V`=>R8us3-8`&pA<*Y~p{!RG~)&?SHEIZMFiR_6IlJKv~Kg5BScqR>|%Y?jcX zTD&qRgmO+ju@9U}p_MQ6F&mRDbPLCY6ZZ_tmqQW@g|@Xaza~nqn1q8FQe95;croJi=R&cG!pObi;vJTx*3|-i8viMJ2t@ z&wDwK(>0aPPUNz~1=;Cbtvk=_o|X(78irx?`TVrRDD{EaiNXR%jPf({vreBEhPxGX zkkqp%jd+q zfOJU)5OcXFV5M(Kx?$%u&d9o9Ziby`O=V`(0B}mwDCSWOLz}QzNRY;>hKQiTH|8Dj zOL|l#XFlg8wq`d0@qI7Df(w+I~8Gdo(s0cO|Nq zbYgvH+J0o!cr;9AO3l#$v9vDHDVFXJQ<>7TxHeW2rq;@<d6k4*!Q-4Qlqs&*yXe{BY3RZS_nDXfb&X7WslInll^ z%sez(ZC2}RhTB0w{;~6Q6(%iL4`e zX6c&P+#|9*NT?TtY=$rmLbe?rn&W-(@>qAGSG4X)nnlb0ur55ee7MwdjJ6Xl|1U-GDEX&32s&whFXqz#L7gpS(d=7jC{bmP}hHb#um5g0i4d%Ec{eA0wOf@Pe$>G6(^`6aE-?V5yB zsdf(gK)wqxXnBu+o}YFjpQ}8`tKnlL0-^8F@91w60x_ODOe$;u)Hkv|-C2mIDWxBg z0=X@>s(_m)c09cfZNEzgwSuY`l~beR(7Sqp3bptFGRn_*;V3rhB6LJQqJsVxIqpGi zkt5DMjRIjPAwtOp?FG1n|$Yk+O;F^K|1;$PsW5kCUb7-4|*!6?{a;1Qzd z$57%$$v8@6TQoE8pPl#1s97@#pt)cQ(1)sEsh%ErCU$H&ejY^yP(wZ&EdGCmh-@jr z8W&GQPOLE1DNTE-racWQQ=MTft4y89)U7bPQktIB?w&NH%&s+tUF?i>zSS4%4b6dj zSY#t?bSPp9^{(lRp&vXn+izFJs=(HoY|jlCLo{p4DPHZ7J!N}_XVdPltT z=iPU@llAGUX4F2i?XEOk5pVqYfjbA1WV&+yvsPQ4SPhgmYbRJs=saFeiM;79+{EbP zz?T0Aklo^%f9Dj``>Pb?5Znc=%C71oUjGQz+#t|$U`2qN zIAk>Ba4N?H;5#Uwu>~|Eq6J__7SKxy=%ym{(gM1<2;EXZw-%w>3g~nZy1IZ~T7*vK z>8R)1?93uFo!gm=zzF)>oNcTUvfuCD)er@{w)nbk?1aD4>ayFet}fTpZ?0~hR(y5y z^x~_Vrx#z{JiYkp=IO;(H%~9Vx_SCvZ*`4=QDAd(wrO<*K)K_{rEhu^EHAJlY5+K8 za!@Z&e=Y~mAP(pCd50%Gb;ApF`B)J~YnK84s5vi0Q6@Z|S;&ma>RWpFI>_O0phG1? z2?Q$EVgC~A)hI#!Y8@GXb!3~?am)?J1yQV@M}gyY1h9ByhaDz~!jT&uAwNlT!|S=p z??zeml*C>e82jyOvAH=%$r z48|@v_9tm>qPQcZu3G?*(9uiMpo5pB!Eq!l3P^b%=IEZ8l+-s|lM;idI8FuCngCS! z(7y~6&?CQP;A4CW@9}ql?Opi!{ss&aU>%y!8y0mD-FtdCvpxl67&aV;bjODtFf}R7 zsZ`ymG^9)o0wdapHar(S8~I_vknBy9yuP@MCar4u^>pVGEE}Wl(;+e)pLu>W@##dGIT`Bv zdV`;g_uZirry$kFIf zoP6&D>XuU6A9E&LY1GxPjh4l0(5L$NP@?|MEAgs(r$0G&_uPH*U+I6QSBx))&OC(p z5-5nb{rCrA^%`6HfjN3AULWg^R>C1Xp%U3$nPQcX)TFVYDElKyWo*c538VR&svhao zK%38 zM%3;KYu}xVHpI=bec?AB&{f%TqO|5oC1I_&-4ttz&n3>@`5}bbhCbD%YC6A||MJbx z-%MS)yn5-HcS1kE5k2Mo{C6 zK^;U`ko=}c&~Yk3t00_WSd3Ex2FfApVyGJ0(B=>w3DK9W+M)j9cJkN6v_FU0i$!Dv z!k>%Y6hr0^vDr9>6h-6|r@EILlkD3#K=R7fIbg&bPKU8?1PcHk?!oxb0JD!yGz=g_ z`8=b;tHI{bAG2`<&*Cdh8eOF25z>eXW23!rai|Zf{#VCIZZl5f_Z;WO9}R?J<1{K* zp(b%IygOUN~h8O%EAOAC?UybgbAZ~*0o9Ovz84PUc z?-`i`M|b#-K>iE(`ThYm8hCr~3)aPf$iQ2tAy&a>XhV1?q6NpVH-yI{wV{(4*1UK& zayH7x+v2V?Ter&MIKC=*c3G8XkFT<)ME2DE-urWD_C<8zVqIK_EW|C~GF53)!>Vb& zXxg7Vlng8%N}IY@O=m>Ynfr5J?Mj=@hhBg<$f~(sG}k8@Kcjxbe9EktyH?GEqIvME zxzsCHQs%+5`PB!nT}@qoBlY?lD1Ge>(d-PJ#UA9xx5Da7i5cQ5V~{c1qH}N|(Hk?u z4M^?YUz5qhNv&Ahy3`@o9$7vm)|?0%K4^=&qt88{E3-PHv;tfyYY+8-u2oa5XsS(U z?|D9%xI3|8I<{&W5KRML4W%x>oH7lhO|Lw-@@nep_0+ZND7|uBG`$YJM{ev}I?{llJnxu_GPtLeLPGfzaF(epevL<7|Rw1BZFxO+WUVVxD)t0L1def^C+J7 zaZqHsQkpL85Zr~96k%9Ye*6Ry_&+KcRDXM55&jHBLL|EIuQ!Dr1+<{&Ue2}$HQVR| z+x|@YsDeLWQGz;ebFV``g4&=iNP{KK*>U7Y1f4)b90z$jc!CCw5)57vEGYwDuy&9~ zC><{8=Q$TY1zEVrN48R8;4ASv-7cTU+2)`WWDXxK+$CcK-}W4ayg10`VC5hox@wm% z-H{G{2q_|mB?>-Z=PYuPgSxeHjRJN?UU;R}eZMzye58u4r9t~zGjY%8o zN3l(7Xb^8QI4Gf69fY|=Z-<&kv%gM2f&V8|VZ-yu&q7`@$zyMsguh12!=R% z(c%9V{xOMm&HHEk^V1$ldkX(UQBse4;g%c%z&<}1H<$~=%$z}gskrQ!_QB=7{>5n} ztA>fjDH~%kCM1I-#XYj;X zHA#|xB-lR?#(yWO|3EN+wpHek{hivkYS-wpuolHi6a4b2)Ww%loF~QmM7uvr5GPfq zRPY*9T~NVCr@yGmMFk(P{(UDZ{Hx{9nCO7Jg$MQ6mBr67;a zY?f7ar^xQi>QEWZX_mzckqcQpDj5i~Ez6)>387_HDZ5D7ADdA5ED8D_QLW?wQ2vha zmXK9Ley!B{7}aGbF10=}Na{vGU literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_811684.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_811684.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3400ec71c5d1826df3d7a61d26acc9a365201dd4 GIT binary patch literal 13580 zcmdTqYiv_jn)mwsPHZQ3UL*waB!Q3+2+$PB8yb=}|$XfdNob-N>#&dzG6R;%%;(ZqX~m7*P_75~!Gkyi6(_dDm>j$<4` zrz5TQBtG}NzVDp#opZj|Isc;7swfCw`#w54~OT1&^#A}Hlq&>qj|1AwCPbjAB2`! z2Vzt@N-?@$(u;YwMZ(P1p?rRB3N!k3$_?w38`mi>Sf{)&cln5hF|Dp_dWF@D`6_ir z!&EXwc51vtm{DfT6x(&MD@v|P&*VX?(sf#uty8{fopOv-!IT5md}a$%ag{!!haI9% zkIQg{uHX6?4`k4m+(BeiIE{Z#f)HJ&IEtfZX^I+$Uv_XO*9U4P*>x0kI|<}`b_-fUs)7iO#u_t>o6LZD_GfPR9- z>_HAc*CB|!)UrA+Q0;GDQk5cA=?9%R`sVxMsw!_+LX#iZ?LQhELhAA*)fS}M@~f7c zZ3}I2RRdJ2bpf$&ulMCAfP>=Yokw5p>bDG}5`(r(4u)q*)=&lUb;_MyyYcjjVJt#m z==4gdHfSkIw^VZ~PB9DpWI*KD0}Rb6Qm1$py7Oo_HK%+oIOQ5REvEsFP5ns3P$QyZ zz{E*07Ea5EXJJmb4%Z5Gq$k}|5V8#*@}7c_Zvc_cX*lr&^zYGgT24Yz22P(fQbX40 zU^P6(47$mZ3z6}G6Qjek;EW8tC`oa6Pr(LIG?$lzl4d{^ZU80X@{>@q3@8(!EsTvJ zRPdNNGo)Zlo+3A#Vj*w^sU0zlsZi`G;flxAsoK@AJDnWYq?74H3I}JJgSi)uGEM6mC zHO)FlS?iQ#jNxTg*6J9w^O8xYl_4jb#ix#!k3fNYS4;gSUOeq)d6mm%owU0wIIWzu zvrZvN5AgILuW$h;0UmN8opPm9vur-#WAzxRbb5@jObH(cFJhDl5Zt1O&p@!J05q8RkuPhNoRTJiF87(f3GrdJW%UD7(5y3{-`q4H*Z1e zMsN2cnB^(|slc(|zBx;{HFDvBcAvNRp+a>{?w1Et!5^UfhDcS^_CV3;?MM|Xg99kP zJ|c;BJy0C-b}Z|S0qdM8*gU6!N#qv?JLa_BBM-Gzp_6fK<6FJnPM;{D(p;;7 zWab9fWa5c*P>Le2isKy@*#iW~MAIH^^2Ouf*v~R(9G&j@IKtxQUppgb%v~ zBA)IC2)WJa7#SON@!|>l74`*O*^Z+ZadZGTlOKEK=#k@=eqPpdvS-lJZ^L^!74Xp} z@-G0K>O06kQwFERj3O&AtMbphK}&IprvbOb+=hf)S`sJa@>7z|Q@lZ28K+OTEdkje zny3SMNEzBw3coUkiZgLWEDwQumO={i(K0tyJxJh+xiZ3GjcF{KxH5p*Gz;UHw+e@ZlY7cxG!=k1AD3m|!RwM0Kto8+pNwELXXf=_%b*fPVd7Vei#(g)Svm#~BDs$|p)Fs9}GeF$yf>vI4jDDcYBFb$H)S z0FAL_x~RCF)HOZz8{n$WrWTBI6_;{Rl^U~W+Xk3wo{348fvI5wOjXasL}y@X%x($$ z3~OX$H+i;mjf{raX2vWMLN;R+F%!=UDeg4>Y{+NmYMvd?yBgbb|A%8};&yNigrjZ4 zC&aUpYg&_U&RB^|`H-XrSI_O_YSTNQHlua!eW~IKSrH?0(6A1vlb*#Vpa!t7*=1{Y zY)eKjxyIDOHD`_qPJ2d3?wwU+Ousj1~!*K*EOjc0c@3@63>AiF&$!=?CCc=lvax+v=a z;5FLK)#lp&-0oann(?xyF{Dlrw`WnC+C?;Vh5G+NQIOe!4NdG~ylcNFjmw{=5TY*C(&}=9dWa7SYsTeyq;;AtQTZ;2iVj;2Epy8#&lw>2tW27H)~VIxL^^pcBMC-cvaM~ktj&zRdRG zarEJ+21m7c9K@`g0QKB*ZqjL+aDj!mns)GtQPw##jg4LP$QbK#TLddqBNV%^murfb z10)%PVj^jc@-l2nVr!Fp9m($mn9P_gE?!2c4xrEOzS`mHt_Pfzf)l~C1h|*ULjbomUFyx0)}{*mt#^)Ixq1GoH#!=dY*^nzy;7B$;z!3*ZHv#H!mjx z1BxaSFhsIqFfChi91f(5J<9EO&p-TM~cYlNSx7Q_bFy*r;}_pY^q zAL)(P&2#3UK6o^)ulK(6Fe|?-HwG?zB`-=C3w(-%$?R7z8;gQRk+BYE%R^F4ISEo+LhE9Lk6dN8xFO4C?n72QnF9}}!r8Qg-8lC^UNF}P+74e|*Zlv$I zFT3yl>cq+Wm%hx4of$=kN0ENi`_dz@r0e{3fuo_un6c?&_LJGWvkz3A-mXV5qvpV9 zkXzC=`sjaC>jTzcBQn*5YvZQohxvx<@;P}>7CIU3{&i)zZ^06|bWeZxO;p`^e{0M# z@gRTF*YU8V>_+W;ZKxsA`&o08iMHK+`QB*km7y;tQ0q|awc!UPulxER7MI>I&znOf zk($r+(e9}EZqvQK`_?b|QOnD*Q)eC&pY^@;$XtB=(%hxs#c=PhnZ0OYo=3~CYiNew}{V$-xx^M?7*zW5|m@7h6bAPvF+6JUvY`N2Rt1WKYwPf0Z zOnZdc#!a2RzC>|lXlQ;@xC#~RSTeUDbIZpqpS0a=i<@6qG9N(Z1HWNEpZ#<;Zhi^( zR2G~?h4taHs9-Pj0Q<58nHs`xqQd>Y-bb+OYUXNSi}!?f#SOK-u7{?g>!Wj{!I99F zPfd3_?hT>ZK2+X!e5RQ$MT6EDGZAw*5#vOI&e)~d6^a@Vy(apW zqReMRzCOI4F=lh9CcHIn*x~D1R_Fqap<-m%h7{X;vIJz?!9!xK3M-MJ87Z1^nK7sd z)2OIE{5mr4M~eMEbwXhXaG?vxTo)cdrhQ1U@5ic@dP=u_rIAvABj|@*!eEBQHq(cF57r{Xme3G%6mE;Oex{GKp@tWtR@871=@0p2*m&wQ7mYyH~5 zNYECR{ZY|Gd{MTv^8&ah7WKrfBZByC$=lMmb@+m69E!7pM-DCs1^kqZ$fIJ!Ki2sq zKB-SPB4#946&|%;$B2oWqt%Tul90L)9!XoO)f^@!VVN{5hunY=VX8pjDhUH$kNJd?)tO? zi=#Jy&t1O?uNhCf306mOJDFtU#D%5jG|%qE0aqYwnx9Uaz1#DcUEkQ|B1#f!T(n7L|Rk>T221 zPV`RK0Z0gq-@}M@LiFEW>I2h_()m)jGLVFSfa==g#VdUpWcL+fVz54p z3TuJ9OitP7{}@s;s8xKy5+X6-H)+|0-X1(|&X16Sf`3y!*QYq`xL zSW9c25ErS*+BaE!wtfzZY4l9&Anz`O%L*DJ84{^F$)y1Z+n z{!wpV!e|2JQHTm5rOWg7{aGT?Z1HxlD6lBRTBJsQt>51dQWrFYVBX@)9nUS#y`kTq z{^Ill&7hxtl&seWJ3t*8jaWZ9fA{ZtznqECzNRsLh?#tn+fZ)oSWA#1#-E?jeG z`>pL!I=*dxqPjWtUz6Beo2aP%bE#HZyHZB!i-OJXX?)Ew$&OIv&FY2fa7U>6!$!!G z9mms`dDW^ZWy?Hb_~>HM@ni(zbwwN;!11Jjfwt)YzF0bSBt2L>g2|B`QtTJrW{Af# zMQpMmLo=d9aT*J}yBG&gUv|i9C55{T3HHW;e+q<>N5)Ci6eq)JMvDClx$bkZKR4Zv z9KY`%{=Dx1m5Uvx=9f$KxYNG9T<@+@?A4>-6uGIp+QOY?(c?LMtU!6}vB(N*1aGc- zMG8No#CHlAqs`^R%K4%Qgha|euH+QNpQikg7_;^b=#is`)Mz~L-UD8pQOx(pgx`*+ z!FLRMcNX@ZN69I3QJNDmGNRCk7zO?c^hoa&B_|ul-kYqD@tQ@($v0>r2T=^Ih}hB$ zKK>sF$LmD`YnD&g?bE>2M}-b{97-mZ4+_Fz0LqEb6~x5bIE$%Xp9fB(^Hy*dNiL4r zFY!|HAb^)x!S#e=>t$YQ1DQ1qZYSrAgAqJVKgQrd42m#&UV^X1saCd$0FviP z_KSA#@{pvHVO-eTPsBqqN_sNZJZWMVq0v9U&xN%EkPG>GY}LYzv_g4J=9l@{!2YhP8KJFkKLt}zgC&>5Ea`s14I-kuyKU^<`hzaC@|NRl4#B=~wEKNviQG@F;y zwMbnXro%09^$u_MvaTR7<9CFlNLw8iBW=TyrUhwQz~VX=*X&zTA42LwadqdCx(BIy zerNyv_!r}G^$@_R^RDgl?+d*8&Wqmehg!q6v;MO|J+YGLUlfNrZuTzphMPY;645UV z{-*l#hEE&rTmNwW%kvMkr(h2G1#`y&7XruTl-D(Lnh#FgI6Hqfo?j0W&oc%(0^Ly8 z?Vr8o@q3bWdG+4I39Ttu7qZT8N7_2<$1svUi$l%xWuYt34;Q(BYWK$RV}Jz6MEnq9 zOI))F2moVk_qPW{adoM;I{}`A-NCAux;!Q+C&GiQB6}0ofek}sOCEPf3NOuw@OT&4 z$9lqpJlBzUfMvs@JbZN`d$xT7*N(Ek$_5gpSrN8vKon$$6sz`7?zEl4Nm6X5n|gg2 zMyJTAlV;S#2>YP0PAxqBg6CkcFLScFw16E?3xL(Jq;LY8%fu!j7hK;LCp>u*sLYYa zYVZh71@@VO(94rIbU`3nLc%qzUc`MI|adpo={357O5`<3vNC$Hez zoBlVk#`>$7o1O(v^yKHSfBO1^yngTDW#9&9{Ab=7hKE4dZ}|78ee^@UDKPW%GG9l+ zSPVYcpt>VE5cGID-mnzzbKfW69Wp zj4dDU{^Z5GFUE}>zHY27cg#z{*H+ZvJCZP$1doGC{#vLJZEC`<;N4Mu^hmS}?dnCQ zUhwD@l?C0Qj?k6RQB={4ikc%QqRr8ku{ zS@7~?skHEF88`=vYmnR3**<>#s$p-jn4jz(z2Yk|u zCpJ<}438+2yk;2ow@2oYdlVjJ#-mB#rTFx5;@e5p96VG~PBlh*)SOy4C;MO`aA!5c zSV12l#^Xy+0vB6E?yKE)xQ?+|U3N=zy_6gpa*4$_SS*`~kR~P^XbY9NpbJM^aD?Fn zdy|H%6F60HgCakdykOCsvtUgDuGXM}r!BlzFy1ZB8Ta%IT%3SmIy32Zt)nUw2*Yx+ ztR1fIh+g4pVgCtwdd~M)>ul!a+74Cpbh{CA+UfKQGc)=O(z>0=angX5`sU%P4?OVyGG)C_r z*=ImY%yu5HKm@eJM25>n>^^x+fly%o85bA@CUh0je4a4>9h30Za|Yh8@rqPRco8Zd zA9It(pS%)YthwCw%hN1*JFHF-_BdqBQ?bY5#*NTai&B z#oHxsm3$)-(*;k<{CJQ$P}-d_`138>C}WEc#Zj6p2;} z(u5eI>76S&s&ea1&4Oknk7V;xSv|=bQduL(7NoL;Bx|C|%5SvKx38Et01XfzAAE#wQW%fp=Y%9qoIKkkq|Z}R zvBGY~GJH2lf_3KMp<+Ld}d16S`P2!jn9}xPHnu;+^ILQI(`7)HFWh<1EE0oJuC|9gduFPKEOVO&U zNID-C-!v;{^p((CB*t9Z8w_{x-`zWFk4UB7?` ze}q?_IM&&A($>LKdpnQr>#&_@>u$9jelB<9xg0y!;zP< zmPmpkIkR+MiR7&pm+A(R{43>AOygfB>-~8dkYpx|m}X_HLD~WS{8eB{)|i4VV9jj7 z1jv4WA)Ch*P5>7EBDRP%3uVP#z%n^PK3hBi*t6zT+Y+`I+LTNvFl+*V{G~zVU{t6oRSF(DV z7~GFxOie~_+17@T)+Pfk+sepUy)>#o(@GX&*TA`o{S`rDy#(3=tOC`r7F3GLlch>j zie@IvZXJt@5E<=MLc6ug+Nqyww>|^EEax|CV2!>5@!h*r0X zUCWX}s+yIv_<(eiNcDta($0d~Y$q^|=8iaJg?fQDs9xYED^2=0X3Ql!tv9AH-NaV1 z7J+g#UbBQ0TbaSpm8Mt1uubI5W(maw63DcaJPvSr4`Dr zW7k7okWE7Udf-u$kg5seAKoA_c7y*-mS8o@<)9?DSnYC-lWP1sf3%OX@Rz??km9ft zn4*litVz#7O0hfHH`!+3$K>4~$1Z1{Ru%~^9cOECn>xY(t78l3w6@~Boc3NGaV}UO z*9Si#SP7mSvU>)3x!Zfu$t)mvE6(IVw17nC0+MG7=y<$s!*3r zLHA$cWkdbLyu!}dhx?qoe8S1NJ-o8l<+gk4xAF2}-;mSA%RLwEBTy>>kx)rPZ*E@R zKP*6v!Prh-M!UTL>UBd}WB0m;`W-eqP2b+d<1IYCm&f<<_+cJD!s8vh+T&&VX{XIT zluiz&l06y8;Z&(34eLmk_Kcto!$5nd%kh-TW}e#YZXwxPsgPyd7ydXS492@z1#-P`MdQZbhD zgS^t?wKHC*NG67*#2_q!ZN$rXctwvGYwZq)Z^-8YQZl>(Ei>v*h=+MidTgE=74fto zOn3^RFVuQ$gS<+>W6(Cj^}+xwiu-dYdisP#IBMrt8f3+&fm8?t+WNLxlO+n~q>Y~^(i+Z=9SuBN&- zw$;A@HM~v)vJEhQOhOsW@_;Gg$*2kInhv%h8=&!IoJnZr&B> zOi%@p`e-RvvYDf{1X>e1L(E*s=_+R^j%okR^E02GndzEU#@9D-rlz2BUY8Tz@UbznKT67ggS=_^Rr5)ol6fnfS(CoMqSCXuP!LUh^Z}u^|4K%9%VE zIv73>&Hnvr zlqYgaBP^G@X%^?qbwO1kuPjP&c{PyDt()!Qa<>Q7^X5XRC|Db{as``$)Dxq5x?rjx zVg$}8vd$d(hpstt_6%3E>z6KckD#>4@YA^>*G56xT3z8d+fpZ z<=8v##>U@$ieV#o3$ch{t;7KowGr(E6t+SHxue7}l-n;xBls~?c#`Nw(c6TL_ydO7 zi5@Z6gCI{54ivd@2k{KUOf(TxB#Z^qMN>sLx7}*G(G)kX3o4$Zc0earSUt0YD{P1x z-U!Mc>y49Rp)ugPX0D)W1}0e_*Ea;o`JCLaC*%qe34LC~79JI?!a=o@Wo~-%3Ii8L>EEd~J%PL~qzm^(dS@Dbr`oy%f(BbKpwCHX9Xm7{2! zf-^5E=1#fO%i@zSfUSP8sD=@;vBO->3Wv z!6cQu5=nI^&M--37AMnmrb#MW{qxsMB@5%wGGvHBW*3j49ivI9KxU#6T< z-opWBq~=@}yhy<%gn=wi9-(KV9eWQ?o364s;)FCIZHVfX0o}5U0i{`t5E&*WRZtVu z2B{#~OVY}#D!7$rQ@F^$uNWDEcUrD`=8z)^#g`5_8c2pdQl<}!Lr z8CF-Qmq}MTpyrF)XRl8g)dAZVx6fXm9-kPv+AaPjIjBI3ytfQo`7{=jI`{(M3 zU}Z0blWcXILL`_)kkSGIj3IDaPIV@s(m$2U ziWGv#1v?TFbyh7?1`&g7b|S&$V2^NIWXhp{mxC(adznF(K?dClk?F|p^@1@$AzFCA zD%3w@)iFq{GK1*jLLVf+88l%*Va2ODY~oTg?NG+!4xa1-`_M89d63ErSr&edK9;xM zdP^iy(Tb#DegQ4+!_RXS*yDc!*rE#48z%cgeOC?!S^}fN@`TAeT{Ts8r313#iI>(y z&uh(~1Z|nR^gvs-h!Ju{LQ^$^&ve~ZM8{_?&0d`Cj;Y>Sl0iib02COdn`{gbj4#D>imRX27XC6@(&XI5R(@#XS9nvGmTX5Us}Sc%YuL2khJq!v$h!x%SAyS< zkOE6NpaAc1_`b+%W`(0vEs3yIsD+f%%R-7;DmeO9pcYL@)XJDIdaf=D;CHF}3?2CS z%PJsyCr83@0##$T$ zR9Y=VK9ASZHWzdyL zI4_fq0P}ko!D>JX4g*43urMs30}|vo63-xm1vzd(Z7-ou34D0+QonPISK7e~V8}ks zRG>0BnE%1#EO)?hG9t=(1&DNyn-O~AWgg#!A$fUaKN8ZA2~LD|N4g$p*Tl$k(aNYj zdLmZcalie3bL`C7nB`oYJQvfhnFqN?hR6?3Oo|2C*d33@)lHAoJ2~~vU$=hO z@oh(3eJHSZUPDb*g{nY$ACDO0nv$EoTkH+?i^KOSeq;UK`fzP0r|Ep8?&8#4vEwJ= z>XVPuXE^nlhi5N5I_u!hI^t>?`T*qGL+x<(RK%$=KufEi+yS_BMToez1mUI&`oeT@ zPmC-UM6Dww>43T^DLb;7loccuEnJF#+woQQ(~D#%x`zJg5|~&)0+ZU>3`-lw5k2c4$0HYz_nPnmUpb`ADl!As&2{Itt(?uo8Adt=3r3B#4$|QLx zK>f+|S>YLwdh7_kwb>smX&5+xR>l?19pDudR$0X?(mazA;zb`_@0ly`u@YO7*M20p zm=W|r)Li&qkb2UU(}GGb6@|%P*U0 zLG=!5ElR5~j|V-DCD|m3CarH6T}FlLHS-Zv{}Fzke}@kUrsuAUDnqy~L?NkSMpEUv zZ`wcQkL%Wh)fo0E*s`ymn!Ye~Ax>3-lSf(>eV_x;q|pZ3LNzzn->SJ$12@OI=%ueu zee?F+x9=N&Tlhoa!{Xzd<~X<@aO!ig>($!H;!tsT?+0tdOhKq1Yz*&>XyWQhQLwBJ zt-mgx)=lZ+nl;h9&)0%igDGC#z-b!h6pxw?aZQKro8nEKoVqi!hn5e@9;u5sbrImP z`NrnBx*BS=x*19>HqrTt>aijBE8B=0-h4c6<8?_tC%4Z#;`8!K z(GknzK{8F^W)zDYohu>+=Qlj7rDr4a_W)ZA$F1iwBm&sHF*nej@gxBcneAWaeN%R~ z?4fE;pydTM8|r6U?rnUy;V?%6qX!$qoze1V6qdIIT$%Fbl=n*SPsDS(1BVwWOr0P0 zFUm>1B5-igfKlea5$UaoZqPtTL}74fKrg)o48#R<3l-#wV!#7wUN264iwis!ec&F! ztE3cz6eqlFpx?`@!T-VIb&iiPg0mN|*eCcy@G9^O6CEQ|l2Jo^w<=Bw7aU#&1Qyna zmwSC9E~mI?ND2t_;gR}6`dA&0)%#d|wkts_R+=dy zpjzCB5yn9NwY68)MhZWB_aWYx_=~O}TE?bDQdr<&0?W;NDqn-^5=!bCdxc$;p>$!{ zE$Rj(rnl4!3?r29JM^38@0*44d7WVqwF5nbLVNA}mGetV$UWVQoAHwHsp<1m=NBQyKU&T(I|@+D=!v^k`|2@Eq3`rArjttw*odm;eQHwZJ^!$)o#V;daqzuK#82LvOqhMrrm3O3K&XN=qP^eNAYDWDW z6cVWAMJcEkq@cb>9jX_uXhR#(JdT76m!cx1f%(cWVCm6>G6Z{ zHauNYMW_g7Eu(i+-KM8=8N(9k#wF5COQf6gr+29t%UiNLsfyFzPBpWRDR2|rg?U<3 z2~Q|^&0KQox6#(0T2(M2&U-=flEnErMemB(S zT>(?i)V@V_7Q*-yW@Ko+Le@4sL?Uc|7zFr3B61BQNhe>^O0-v(<5Q zL9Bg2?69C|JO1*iW2ur4(cWGw5!8J{mmEI7)8qCDDij`ZvID+{6{rRZc@k7kDv(D7 zvRxn#?}xC$)!*0U_4ov7kNb+C8}hjwt|7+h@UYH-i>v|&SkyRH1)iWdaqP&!la9kI zI&xgO>lc+lt$6(Ra&%_-p=YA&Ms+T740%^_CsiU)v`krPPTAs|vZzCg94F(8d$`s|4MG_$vuRbV<<4;gpCESB+9pdMZ_gQV|WNTg8!Li7XvMp`LSUOkK5znqEL% zlOvTpdJJdgEL{?RMpoULOfVC;D@*qyu-c#2oj_;G;hb--D7XTy5Va*uo&L5E5%Sf8 zFmh|SHKHWNa$Gg;Z!K5LnefQ2g^;2Ljq$piy7HXP+$wI}LMbrvyNdLQ9;oALMfs-X z{(SNa)Gt#j$t*x7OSquV01-^;D}2 z>H4`UQ929d$r&^GSkgJY&8@cyQ^VR^NX_7R9@xM&aT`WqqhnN~usd<<5zoLz7##|& z?uBwn(QM)xxJ_A_z-DfqbYS z5%S$56v3?yG=I8@&NPE`6p_9AcTQUni6X14o&D^fL>hww?i7UCgY#zNcA4M(Q&cY%l zYDs}|UiSJPqLYb*0Vi&gf{gKB8F2qcy2{%Ufx6`M^$4;qV47(@@|vtO8C@j z1c|>_P!F*EaDZNNco~6mG2UT8ddb@-$epaS&*K(kqi(j}Cn&ml`}b2)q6J{CkoutwyeiN%`16I4Dw z>t);y=cR10ClkDo6YR^Ry0T?meS$v4L9E2#S)7n9fxF@GtQT}C5-|Z7d)WS=0iPG* z3pEJViW-VGdO!jXaJ7!{Qj6oiH#y#)pP1VcLL?E`fL7z#@KK~kl=ysXcU3tkd5 zVlt`^IAETs56HZ~tIG$e;x`@dfw6E5_*tJIzaV~Tr^_{TX{Z+@&I%+9aF1gE#z^LY z+Aa%nH0X{VfpiK=#L(?`To4o&9IlH)eLaE(i^?HpfK<`wIEJ0Qf_!uUe-*xq-Y!3M zN$%|%c8m@PsRXEI#NqAp3)G0CYd|1ff&w=e1|${e6BM{=lp+Rzi3>LwPdMLhkY&XW za{qIzaipK^@eMd#?#7EwSC7+wvA?e#w8%bxqmOkp`rVfX8X?vQE#2g}8V0U7 zAjtuf&EMbWaP`BK+uF3HY1;{W9+)+iM9=f4+Th^_=EBI8*bd&>Fri7BcLk5kYOK+Xu`<52p4T)4TW7U~gt>~> zR>d{EW>c_rPHT#s=e4$pIo@^8TUk; zv9bvTZ);8#Zk^F?3z4%1W8?^LsE9T5hPCn2ykTQVHft=5ZsQHrA=w;+jN`>2ITC+c zxGl0jRv4##?u;vLt(=rjYj%dnIZJV>h3(Kn3u>YHsTP`#hm;SXd8@})M{UuIH|pZ2 ze_@~SPqf{3OxdSR`$LL3YuU|?7!^G;{zhESTbn0qrYv6#PFe3Y-hY|jevY@E3uzu$ z3a__~w?+5I&huqWlk#5;PE}8i+|%FpO2!kbpdZM5N&T9^Ok0?DgUCVHhOK z`)~5fnh-s!)J5uIR^GG*k36s3^tODyp3rTdUr(qkLHnG}9O>nCl|MWF>DgOn@#u-qGY%S6-ftRo>|2S$s27jBaQXEwthk}N&U(>so>YYG_{Rie?**; zwCTO32oq`M>C&ht#{AqK^YG;jJlzm%nN#Q@wjXg1lCHnmT22sHh4>_;R=E?ZiXQ<}(Uot}rf-&P0YzzeAbiyX=g zDBh+y1tZ1!p^}qxN~Ep$w9pocQ+`M=GEV*>@i!S11==MGHu^!E6kT_2pHpdJb(rpw zKyQ`lQrR!0=CO_L%F}vJK-OpEWZj@AE_RF}ZD*52NF9P280A~afGTSzaVl)-5n#p1 zK&z9Yc0iAaZlqT%a#?OmeYT2776>FcP+SbbhZkON9R2;6@ zRWT~BtMK&uD?kO6Gwxxp%gvSnf2_K?a+e6wxKQPfKLPKrCCh;^VUmL-geF1S49PKA z@+^QbOi~8J1L^W;8Du9nOpOd&0!&j5T}PzLXG2{Q5RI%oM@hu`VF?fc8YUN&3w9Mm zL4kH&7Nj1xA22KGXzNLPR`KDBa%Zpe(gnu3dvsk&3t`-YLrxGlXopyX z@oi|>+1n3Fnr}A*#KOHm>T3n>zk|dBrEZK4(-C{bA8qHY>ypa4yThNoe*5*wrb%XM zH^2LMvhf73Jn=wDhg!xCgbzf{#>k|mGGi&|$BZB&{!=0Ko0_6cyAV*FJJjtqxFTErRtawQ=v2Or@6MnuHm|%Yi60@{s zY)^R4G+maUHYLh8CBdi5W@*z5UBS~8({yEm+McN1o&=w+M5UX<%~xB3_Tb>G#*CD_ ztF6JkpqUuW5$AYe@DLP`#LQeSL^UL@S9qG9K%J?f$=lkbomHIJVSQi;gDvP7dAMd`gJ8nys)&U(p z+nh;dQMBgc`Wy9eGFh_rMUzcW%xal@zL3zEf=BVhsFh$eRm4@%O2y`zD|f?0D+nKi zCh(VA!EgYk)_=hG?6u2eVOJApr<3+48IjEPKN^-IC2WKA3entLp-<#xu8OxgBW1Y zVpNW^0!&S1j@S}_buk?C6lV-`V$P*GGReo81dv(eEX(1HAZHAU0%zQVMe@i1D}fcu z@Rx8a^0d4hU~n9rEhItE~pbbOhPb2vHHgV6daOqH6$W;C)p@S9SqK1~D1T;wXcCf;xkk zcyhwo^wEKA`shFwcj?M<>FUew2?%?*&h%A5! z&i!&2gTAM z@c=RA9Iz>i{{}F43x2-e!U_h9zk<-}v5jv_YRZDGX^UTseLoA_{2rXok+!fSI>>7) zkU?#ZwA{4irqIU7{_q|!Uet#7jFFazHQX3I%d2aG`{wio?>VEEh$sAJ%*N}hVpn+G z`iU)*)nAd5t3Tf}b(Y`Q&g&%; zd15*6+?CsfpWft&97A(Ged3V|{WhZ)@RoEkLI;jCsNyu%$-$!kYNW zN#nF`N3ac=3MSp?&88?5ZRT}VV2SHigMlX|M>l<9%uSwCn<5@wy&`7jRn@`0vub1H z46iPWNqJROaPOSX5ZMs+M2i6L2=0GC$*&!_dLVQrGAvrN<8Q``dFuvVyJ6z=WXo5! z$+piOQ&;%S$9V0r#PL@W4%fF7gO4zQ#PAnJx&+My2Qu#fF9Y>zfjx%~H++oUglTY1 zbxn=!IY?aWuvI00Tg}NC5^W@JtNe%#yKX8S%6`MDuo|o-DM!KWV5($55mJR9m63vl zqs+OpQ)Vn238z4mVCfLufb%bQpqN5Cky|2PmdPiVq0gV6rI##2pFclKuULjYe}0Dk zPU@ZvtP=&O;PQ5H6Gd9^0$T!|V%U~I8t4CqErBfpX?u1v1ZAlo>c&Put(0X^IHNNy z(EM2+LHY23%KdtKjBNX#BfR|s@&!he4==7jqq!fD@JA}f+$QwZ(+Vj{e;MLOX zO%(nRc^@G!qu#Ul6p*)GLy<%@csn6l(n*{GK zpqU$t2sOQWFt|TtM;fm&tPa{A(1x)c;T?dbHo*?oHPJdvZwC!It0)&n7@jH>73Cwb zYEet}zL=J5nNgPV%F-xUKjksaOvOgNV&jB8S+QkSsS)j{cXUzeWBLXivwx_2qE&!4 zY#=m-;9=2BLLzY?xuzXb9PqAz*6?x&P7AI;ZY@H+nD%=)_%_`awcUHmO_cT(|yHp>r8DQ&lHQ9+1i z%{Rc;Qu2Ztw?K_LPRr>y9rO*#l@I|vLk0|-9&$?|*T@;*rWzAPDg;$0u)(a*w9pSjjdoU(h>)2(2nX(@Ii?R$}f zIU-HZ!Dn(;a;pDB?z!t|&7vkklv6!>R>2K)YX1tDv*W#B{xqTalarS@iD4{ZP6RBR zinE~6K)TBEqY1OXfZoPr#;*#-ZzGJK`G-h(iCqp}5Tmt~ot=5>kTP_R6teqL{xzH8 z32%F{yTFUA*VH-*6u^qmEg{!aNUk|~sXEj15hTACjS6rm%$CYx+>l=j`Btt_92<}V z0v4=jEk5r!0f->J;ehsz!Y&sme0Lp;5&}1}W7h&>cLq`m^2+sjSBM$b-w#`1u!1>r zmlX6tjc8tv0`Q!Mtl<%Z*XO;^>#n$Pr2=2$;Ol*~o1TQHhf>B1H`G0D_W%S$>l^8g zw zx&mXbj$Ut%yS7M#SdlIwkY@zhVF&9R!G^ye7kdkL^ym!-_y9!9M9YUlc8{P+-E}$q zf*k2}Zg|Au>hDE&`{*M@4+tn^Rtr+}D1aOmkzAxw3sSa!1fB`_270}IK?-jW*mg+v zp{0;f0v82Dqg%TG(`Nq#2>uCC4#7+Vbr5h)X;}J>Y(`ncD~poKlAwK7Z$PH()dRuS zP}3}Jh-~BOvKaXdZA(zCiK^Bl__PgOdV6{8N-(H_hde9~4a0T1Evc*q4glzG4sVXY zWo~G5l3oR-D7_}wvQ)AbZ0_*6=w4n^4%hA)Luhkkb9h&#I_s%6cj?Z+Q&J&~9EuHocr@Pht+_t5Zx&L`(GxM_ zjZ^S&Ak;EzGKcp4zMwcl&RR>N&hg7J`=_n9T7TLWcilS57j77r&P$2HlDW!Nq5YAK zQF44sbYo=4H=31?Xu{GkFCz>^p#SO%#x900zSBLUs{*4wsaq9no6{IG7|D#LhS$`@ z1_6oI)Xiv`c};Utvkk;Z>&JG6cSgx~_RP@LJY5~@092Y@J40{gX&94ckfGKzqg}yk zS45e+t3F$Id)>5lOHS`mCksE=1^qG=MtsqIk;~%^Gp1_ZRGoy&Y%_1FgRGU6xb4!& z)^9Xrk7R_!27Q8dD+!$?#T@YXq_%2C+rVoZlG^n%+ReOn^F-IAbE@fH*OW7^CeQ>_E*0u|Fs%i8g)k3W(ldj#4lv0EgUK6}R20jlRYg)y)*G=Zn@Si#AN0;ET3G zQ*8I#ycJCtaZt+V7IsWUSXQErt(e#3c8pAFD`&Lmo;;~-goOa(@^bj)s6FaTYAR+l z>v+w&q^2I8j?kt^Yjg{&ngnGNaS@%k0XW3lZ}EwiD;K~(yCOQg9{&-PvS-2}6c`?q z2#0{dkoT%Opn!|kT<*Xaxm@@o?}f6QlL369S*`K11N<T-bw?JL+`hDWR2U2I9PJF`# zH;K+Z2F^1r2KX{jxZv~+_4<9Njzy;vM8yM&7VE-SuIw}v{4@M~kHAA{Mfr@O8aCfW zc*U%?bVggrYb(*BGmRYyAA$E^ddt}B;n!jH4Mw-b3nqp!nu^{%ZTa!5=#f)70Kf(t z?B2%0$lwP>0DoAD;kM9r!_8aJ%N_GN*t08ZK3#ij?N94M2O=Ht?CDgrWBe@Il4XRc z1h!;dDeM3ybHqDd4wsSzrRZw#{mz+!CcdC)qISxdC}>I+>M&8mm zp_&{_SQ?WSd+_L-t}wcS*VV*fL);4Or;TG`KJCC56L$f4fCRe$ zsG!oEn-era7CBgdBtVKf^rO<0$PW;)b+C0NFBFoIpAR zZR*vccy}%+QnT5IOp)B_EKX#L*q(ltg~u4_7X<7dp{y3Ra34bag2bHB59+hs!7o>QzG7OrCuo00P7Q!%Q){Pdj`CC$XlBikBQe_(4N=elPZ0|y6;oyR zN0XM$;8B>Ms=`QMUPkHV!8U}TnS zp-9|e5T+{VML9I|!mH#Vc<(7F(;<9kCdl{U{{{&!m1*LPy-|rvl1KqkC z)K(Q_`~0vuu;Q}?dOv#CeA3IXuMApltePGCM3#hC09%4ONbTa zv$o1drEMfxH*X@$1=s7w>*mckYDq^6aI`QTwc=`rOQE4}aPt#s-J~3+9FRrW1W| z>O4WbEIFPDcT4&uP^w$;oQ_h-GgdvWmos^8W9rtUFHeN(tN z`?4-^?9|M$^Zc>%iPu~+uXXdUbtm?A^E-PUqq6g_ktVV*Ci_%#OEXVEaKin|-p`>5 x2;SfPyPdz;iGv9S1_Y84y@NBom-*hyk5TOLmJ-tZXeXfuQV{|gax*z5oR literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_838410.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_838410.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bba2a94fc106b7959445812b24c62064b3602db GIT binary patch literal 10256 zcmcgSTWlLwb~AjFGkjBT*|JPYHf>q5Y{~CO6w7uT%TDA+9LsW;5r>kf2Pq9n+16er zOpzdDc2QI5wxZMBhTbj^Iu;OS5g^({fy#>l>$V@|gu+UgriFm@kN%kTq96Tf&z<2= zv_jiS`na&}+{ZcR-h1x3=SBas(P$u`{QKLVj=Z;tApU?maxoSfk5`lgF-NcjOO6x$ z(wFQf#lB)f*{>oAFF7jGkx`|up3oEQ7FOw1jnYrf;2x`b%DDO|?8Ys-%h8Cc!B z>idO-`?U(ZOG}x3=vf0xu?m*{nTj>Cih0w0I4>?w#h?JgT)?ofmhXU}XZ7D3gZiEr z)DKw0{Jqy$$FQ=t^^*^4`}HEf91`zJYp;HoVC}ykS;u^Z%-g3)4BPmBEiXelwsKuA zit<6RTiGfvF9-EU^=%Hy*2PZ{6zlyPZ) zW%mxRVDr1Gehcel8{Q@Rt#C)y;+bu?NLSMuKIHi|9n=E-V(%*_&blw2=|AZ{_eAgf z6TP#7_TsC(XD+$V3iRmyHo#ZLy3hu?$Wer?Lsh75MokcrdQ^?-XMjdz6WWBT#IemG zIFVFPE!sQ-x(`(qaU0NPz-gG#5X1}sMYf0(pa)}^P9xfa?3hl;elV{#6qa1yyGUm% zYDA9jgjYO=i04E#sNy@}Rmf)$@tU5yzao6388suDc)AuupsipVi1!pMF{Hrqj&yi+ z)PzUzcwtnJNAXJHYTFPgL9FAkDoha>L~Q2J({vE^?M zksz;M0r_f0tv}?xZ<3Mb$S$<2c#r?j9EFv&WB7fvm2jYpy++$m+mkoVYUZ`FU5T`# zU6R~Hc0Y#>iFQl!Sd>#y(y%ms>*kBnCaql}qh!h29nV?O^~{|spM$;Qnb@il?6TT~ zRmqYP*(34@G!iSm%*bBUf%ZPhiy1x4>SsVBM>tc#)BvaAyUnMbfLsquUbXPx@5^+8NpR12y9QAE*%zp#4ZK%GyD60BOY5A*4Y2 zpX33e<@b&bJSlO=D9=O=qk|i;4m}Ag^6&_%WR+vv34$xNSY>=AbQtYJN9GN(Jd?z& zQae$(&Skk(4#P#A=in(3E>ci2H+{i1Y*{&MjA2C0hHg!Ihh==1LQqXG!7)J{2#t8T zHLP^jus(8;91a0ow4ex$3x-K9Fw8L%ZXYYCJgje8P)_*#f`;K3|FBn3&v>~&P|yyI z2bj>_J%U;qRza5)lx!d*sDq&Z=M@Z0C@|skxEYpp9TmuKfqYRQPY9ZzhZ*+vb~|>&>xZbFif8FOGWFVLg7qT%aJ4cMoqI@C&8_gm?e|c(M<4 zhq*v_GU$VO%o+}!0KviAAvo*=aQLkW%1Pjg8Y=gUp!bI-+_(w*kehzNA~$`0w@6UX zi66{PYRJb0Lj@dK8pi7Z1(9aq1I!K#4FzFT`pSK{xiKJfQ!vOB-D7Ys?#U1r6f}d< zmntwo95hI0726c!@q{PB<8UpUKmx7FKoH|&k$^AB2xw~N2CU9*dy*nl&j!T;M|t~N%_xL+iwK8vEU@* z@wSgJo-rmg67cV0hQ0n!dyw8o;k9j%2cbvNfTf|o#bivg}i9gj7l=1ei_^xE%lA$4b;%nM;`&|57 zvN1>3M^AojvCU~`wMjC0Bxl(kJ+raDQsqqf_^S4-%CTa$Wh+~Fb4&E46rCT^t;^HpsbI%hu`?Oma(l6zD2eBCadZi}8=F

ACvSA7X`b$ko?JEB5(B)k z@s}@ta^~)tG?S^%H67!P$D*fJe!vVq%bEC@1e2`K(Jio=wKD1AtA)} zZ`hM@^9?^fPbW^LDkapr4GYR8`bdmiwNwCz zA*Pl=x|8aUn3N+qe8-fo;j8!Xjy;*z7J7c$xN!P&_u?(S>s8+IYWB+YuikKH8E^It zFMnky>kll^lQHs<)jn4>Ta~O!@5?m(hRkfccYM*gWIY8NC`hp(U71n+ipgm2H7qKZ zEIl#xBa`K}FYW{R?)Y&aro?vwTTU{w=9qfL?pWUqZ*NIEK}=%WRdZ#sBK~%&g*R=F zDOYv&pK^24v(vzT`(M1p>zZTKs?L;XO;z!>?V_yky3P+Y`F6r|D8HM~JEGkyCVOI> zH*Njpg-@>Dy_%lNwC0+R^QPm`)6bfzNRsg}3?er`GOHvJ=BE8vhJ^tux%##dQfdf6DQD zUCSfr8{@{r5S)wsy>91Y&=|y86juZo-nAE?-L0M1`G`| zIpV1TD7oNpOpM82f<|df4pAEk9*`2^AvMwcn zzd#lHCoaYLZyygENAMu926a_&kXt0U&A4j7)fQY~Q@`ObToWE^##IZhw&7|!u3Wg< zfvcUkYQ@zqT(#k`MLQAP!+Bu-D*OdcK?Qn$ z3!%2H==8CU7z4LnQE0B{tcioARaLTKMOT?L!PrVgW8t5r>OX3XY7KW8uSwn%H!Ns4 zZz#<9J@`uHf(1gnn_9T-IRX)**lXKRJkSS`q;21g;R zW`+V>9nw%Zvw+)rq#K3r#sZQ|B9-h(mp?75lp__OsQ+J}q7PxGSuCR_xwZZ+nJpLF zD&Qpcb!fT87HsDMyQe8h-`t3oQWF6+G#_%U+9lweYK&KoMFp`Dd@_akC>1NWEI}9*eYbiU=P~&lu`z) zFlz=lX%;eO8)sR=4CsW21#93LcvsKaPZ^%d`x$kX*=u14(K2u3dpZI$KlwRC=kznR z5RG&g6Or#JB?{8Xo{0nR?0e(r!43i5kn$oS$F>}P@oWDD>#jI3^uoCcT1|^NI})3p~WN%uPY#y*bJFA*h3} z$TK3SJ>$&8Bt*4=FbwL${@_&D>xD2%j91lQH}>w~#_*~*11=cAAcuUz;XpVjsHT11 z8-fZz4hY%;n#;5!vWS699=^Vdr^zAo1nr`n_vJ!ZYlN>^w2Ie zjSvKld0BZz=?RPrfuKckY zTrr&@M2wC8kNB|2U4u0&r-73<1T%@8h7+oL_0e3{`@qczbJzm|NFr#C}rLR(^8 ziLUrj@W^QE?cR8AvNBbZ?#|I|%k(~;-j|tL*qWn{0wiU=eK>wNF`c}XZp=}w%Ty;% zb%KAd$We!*J&&xFb9J+IKfC<#wL90+*7W7i-1poIeYsu7dDrp9rr)(aXv_9p$h~-x zcU{bFx|DtGO7`ltoOK}j@~WkF*|L?lY)y5f7%|#=`9dnxGt9PWk ze{XIB-9KCy)?&9gLFT7ulem zbHM)DoryDvogYmlV0Vqnv=erkqnkz6&r%6jg8Ha0VdANVWvY><;0~N{X;gKxIpxSw z%~@5m(O3sTG47Oc(**oPK@cErDVAg!bsBf52H zmMnn=-xm8Zpc}<3tC)XPiPi&K?V?X1uLO2RUQ;8lAe{#?tS~dA5H)zKiulP%(BkB1 z@fU`53P{PS=QTyk5B~pmfKoto0Cozga1LDNp=?RTEHV7gfTzp?x54MrjACuHp~$BJ z{jNPPsp(Mw5p}hs&haHdKrw z%S=O!Nl;u2V{eg3NP)yw+{FQwpbPsYnG{GLsQ;0$ACh6Cf_gk~!^;KnjhE5Vq#QjE_wPkdUGk%^8Y$AyjuMB+13j3Z&=@nOvL% z7O0RHlEWSbg8B}ZT9kP)yUzN#4**&;*kZuyq9s4ViAP02{!84#AuRV2R04KW+=GI4 z&@DL>^q?EN6z<6YObDb~Fcy-cfpBOt3@Ki4i^AifU=Z)BMB2AV;_i?(7T`FqM|@wM zZU~+K24H^;f5BT&MTu3b19EaY<83NVfu$oMv9<4uiZ3^PzG+E!Jlg$?nU=ko?!_HT zE$4WZGkP+%Khc|N{1Z)7v_Yn`VnMrD|8OSf=#QQQS=Lu3B6+pStcjk^TL{`7JueCz zFcrU3QJa9~z-{;p;Ez)LtP$WSav$RXyTn+H!~wx9ecQ(2$Bi(gWdxnv;;`}(l%u|o zpa(B67y?I;6Vn`m=7gAr;oKNP$Am+HP}uJk??%vE^7@02VOL^St3%<*aj$f9SWk#d z6DWz@?%}{Vdyf!LATe5ZWFo+Z$Gyk6c>o4=LIlr)MCX+xN&bPL|CKQQ8&SVXYTvJY zr#4?r>}*|W-~CO^DUxgjZZtYSq3_aptJt;4UAx$I$lVIDTS?S!{`kn9Bl#+^TV0s_ zqe$wpl2Fijg*bCT0r`;VsrS#lb1t=aLHY1PmKo%2p6ujQHuBbz>g^SkwJ^IU6MXn$ z_DzPj4Q2zASrl1Py|q63Y}~yIif7OXW%gfy~hV?N&7cP2(vS1YI=tf1!;dNaJq8A>8`-=u6EjCEMDWLs;=wj9fsY$r;T*plUPDGnu52Pw@^R;&zB zdYd++H(gO}T2a{mrWYwp#|9#|3s?abjkYLI@AgMIqLC0&7cj8?#EId@^DC~b?8%jAAf4E^`U$ol~pdYBdc@;b)ULwdZH z0%={Dxl2DrFs5IS_sn+{Qv3hUDURxx{LRx!pBlv!xQOwR%o$5%rf{3~qHWrXw`nhN z8G%zuKTr+jFczll6;LXs`~&q+?kk`ww!yC4rX7u~W2%5^9%E&yKOl!pa89Q5>dn_k zTg@}H;PWDV?Ccw__u1b*J=AMI_d@N(zCKa=_M2x<57_%e{pq*uz3=v(xA!?ScTflP z5Qje>)}R~hBhdgTji)$3s!2IR-c#I7LHJ8{fK{d;OEVx zNW+%%WqkR(njriYdR^x@|AoUZmaTX<*@)Rm#^aUQTsedr8I?VXf5HZ=V9&h z0PMG5I#7c)M9s=uB&>{Cp&i?j7;5fbn+o zHGD2wg};{H#pmHt9lx75;Zi+M@U$5b15 zQi9mTV^nBgcrC`R+k&lo0o%qa`K=IUj7??5reDB*jaO}hZOn?D^8)ssjD63dh4_5F zU{v8tpJFzRbWbhsu)oolD}yp*2IXwLY~l;~A~Y8nTmBv=E+^r9n)u@PNRB37-h0yo z;<)Av`dRc7-@rFBntNJVEBRY?fVTlWnE}6d2l&_0x3+Z$$USMuHhymwJ^^lIB#zKS zNk0#A!@rl`JDxA^<>vR*Cn1l+x8jp;LpRvJk8jMFUlv`)xA6Pqy`MJ%MkDN9>TGz4 z&rxdwZL50n7yKTra6@Cshfx4(E>W!~lue8-D-PVVvV--^TMB7WKm z&*!05d=>}zB1SpUKoD%Ed@kBN6}iOk=NtJ0_Y5-6k;Ju3xt6I-QYhIlY}9!ch9+#J zsB*mL=AI!v@(c-6Xi+o1Ytyc=6gpv3h^i?EHzBIMzAG;F8MggqEF2nPGl~k| zq-dCCy<@Cn%I;=Fm6LH_6_rzNkEn654$qiNRL{FuFDGh8C%q2ezBW-UHLIY?0KQjL zGG3pk=6qh(B^n$)@08nVcQB0Yps04(r+uub9+5sy{4sD&$E1tHWj&@X6*(X4W?Xj1 zl&F`B6Y1heM$waQb#B5sJz{POEk%E9YlBD3NkK@Xj%}%XMDm1w&+a85XUSG=Ak$C%6GYvEXDi_bMR-2$~1x5w9RpK|!7W+$7cui2r= z4lLvIdhAXwFwVh)``cPOphq-HK=uh2>v2u8Z^G=*iE%%HvkVa5RuGhNp()f9>3?LX z4!r)D&RIAYIv24f=!!t^<2>_{c2OH4BVCEShQR5q^<=Rn+9?z_##IH$TywmrPROkb zoJyL@B7=gtF3|TlzbJeyYK=~;oE7Sh3N=R)`8|QNNxC?)FIpj#Hwkoepf_pEi|6kW zjJsmAK(_~apBT;IVZmtq?AYzow@$|#D;0^|hXv!|z=`A!=y{S;5S|H*2TmqmDxiVJ zK5;|~B4a{PU940nYFoLmI`9vzs~3ch-iLajqhBcM7mRO7BhP&@6Wtw|6N>9&t+3`5 zy-?J-dUEafZ>?)5g|4#?uL)fjg`$gsaWHV=iMb+L94q*ICRQw1TUQ1J>*2K)!O|Ze zxD+4$@%o4}KI{|*nE2RLVc@D@xhk0F0)0tSS!6Ei`@BCoCsenrbPCm7YdJ#I`FQ`u z_~7Mr`;X&;KNk8OafTK8S)qy*OkCj1Qv&$Z!l!5?0sx z@yOEH;#kBHtys|t)&q&6gI^mvgCvZ!ZfegQ%qmpx3u@5mpA4N0pNJO4xJUH9Ao)1g zv@jQ%3tx?%Sg}6JZ4as*oAQ^67mFk12~#a#q&UQo$lDcEKglbI)c?X79f;NoWv|8h zg|g-qt5CX6$ZKC!3b}`ZswAC&H(wW6G3A#F3C(sSC4uReav(ix?u?ts2SHq5o z>8HG)D-SB4=yJl1(PF{e5bG4o?Sig7s7XRw7mSc^gAwwZS8Rg$kf1vh(rlCxIX~K{ zAaV=BlR{3-XK&peyfqk`S!qnv9uabmARfryrU&uDV6^yiQ|vfCGNG_lFt)CoUei8Q z{nqhND;#}OXn#{Mz6rcQ=l?}d*cU!2&=rvxfwn|1#0LJhHFiO$?N}v++D?J`QQ&w| zn-jMFgnz6vEa*bIa8V>b+Wxh!;VIOOA!B$n;*6=j)-^qct4KUUM%lBYkg>?Hw5hVm z*=I=6u~=a#L<170oDbvzC&}132RqF@vjE7m0Aw~L+XNrB8D%Y~K_P0LUeBm|3zcC|8s$_01<7B7LZ-h4Ha-cqgA(j!HLu~dybkOL15fcZZ{%~p|HQM&c5v&bmuc2%Y${JDj;$EPNbOSR0supus1+w zf2jY^utf00xZIoZ^&#fMX&fN8CUlTo^^Tax8rYIBrs53o6Ot!)7+1Fi6{;E7iXE`G zpY=|dub#E?2amfG=BvjgvIRPa;goLZ{c6}t!^^vcdi|Ok242mO$b@GU*fG}NUDTd| zaHt&ifzwcAqrfY(wK^5=DG&{o3IS~%JL{vkg;S%GB?*-K5TD_c&G8NCLxxdBb`FBp z92s}>dru|127YZ&9|UYh{ejM}^OSgcOru)eq@U!=C z;%j%n&)&a@uNX)7bxV@ZjL}XUhcht2$z^eNH&^K)8O=5@nP%QEHWK*l<+SQ`2)ZV&wV+HH3f zEDG<78pd_i?Q}U)(L%ga=o%s7&ym0XSBgySl`YyBB6-kEh;d{;KEvZr3qj~kr4TSJ zgQHtAJdPK@WkNJY6|fl9`3zlxrUSZGBiZ6Cnwd>0s?PNe^lWM&4&4z zYJ_ndp5L>LcH!EZyC%KhC~}9PAOT+o5?m)_{}CWQ)7++RQL&zwZ{@tu@_NE@{2#CW z)BC@EKmPVW{N17WCAaYIc;e)QV3~+d&BWRFg(+UB_DgyMty!`TCYD6R~?B{N0Mov@x}3V$`V&u;*=#xnHIW2U7vQ`>|X9(ry61>zBqg5?5}#i zI``$d^~Te2)#*5OI!Wa%910!!wDxAxa??6hAM5zy(49lSs{N|z%ck``C*!J^+K4RP>%D2`*hy%#9xEibjMZQ0?y-9_fws5feKN`W(c2%T5s3is$a1_Xt>|-NOvgE^VIAQ z^gUHsqk9BZb5f@dwgw$we-=U!o1&(l0_M`i#fuTgUtanijTJRF&n%z$`Po?OdPQSU zgUwjSBVEagb>-a$m+xO*YfW_YCd(~1JC{3uekf-8P3Y#-7ljphlMJEEGWEd-h7my>gCBg0^M`TO7r@x!xAiv&cl{*>W2u3HLF9hei zVD}A-(obc;=;RqfKT!75;7;kJ&@8xC5|5xS@Fh>9CzDaEYIvG4;J_?!^E`y6{nSSm z2qaUyUdEMHGpc|=iH{CLpJs>MtC8ou`IYB}IpWp#=2h5QUZ4FI04JnQ7rB$8N>HB= z!X1DK%$KIj0H4dld0ETUEa$U~xGZf=$8m^vz@}i4EiP&&T&`(QG%R+S{|Ay}OZ*XZ zO75{41G4DQSq5c@(O47+7PYKP(qSscl;b@J8l%{*mYbP%x#l6b3vo^lBkE^8=@utw zsA+Z-5y0U;H7LS;ZIbRo+{d~gvl&@0DkHyARCt}D!tE4Qb9jU~G=dsO5JeS=Hj0RZ zeNI$*XMG~+hH&((&+Znb*d_^ zs)|!pPXIw5(uZrpBazdK-srW2t`(4g{Ql7XFht3M`x8_h1U6o3d15q$kB2TrjtO)X zg!AaUU`Mzk)D57t2?rBv5_BbWP`OCIMidFEJaGIe^!0`MBBsc|a!K@^+xA=bmA4YL z2ZfS@3A*ze`iMXuStHll67&f~)3k6Zbn4@GmWCFG60{}I`_z!TuqU(!=2{aam+S8p zeo=m>e6|0pi(g)RWH^zEenGGGXLiuKd8VNHv8gCL^VcQ89v~KWebN(|iQ1O=Sbyx= zolC2&ckJAvDCfTouI6N<4>p@WG;l+NLzxcimIc&IJGCP+9TOse+m4X@m)wlahz8h=Xy z<)X)WFP`}!HhuA!`xt82EQQt^*vv{p9Z}&@md)oCGcuN;E`DsWm?T;Q*nH{*N>dyY>GAwu>-yWlxWJ$$=p7eRhE~; zw4l`@_Z9mT3j79X;0@WRWb8hYLtBzeO>(V}IF)1~Vh1)Y7_r1;S!lD-l88hyQ4*0? zP-zrp*sRB*?LDZ9de?i?4i95@aFRepj+i9$C`st0nOF*|y^oMI@aP3nj^%M>&OYMy zaqOR<3b+vNQJ-j-v(I|qkz>NeaL8QB5POo)qoi&eU1R?oAXebV{WoMUiM(PEb|y25 zNnbw$qD_<3=fX`b3mkh{#yy2AdkG|!Ir7$Lv$yAO&97R&YWT9@k@5Aw2^24H``I80 zj+?{%i-i&EA{A68OG-bnM*5@H<#!{tWjdx7N}5*kg<=Rp-vEIS^no}~6FJ2jIwTI> z2)z+G22t&AXq!OWVs#0+RgwyaLDbeHsLF4sdI6jUO@e9x@o3Cn7z_=DXP4ezd_Q5_ z1CaC(jsvgS?XQpCCZ#h_6dW-1 ze6>f5V!9Aa$^1`+Oz8tuP`am-%^1IG2YPUCz85`p#{bZ$c_qk*l++rAdH6MctzQR5 z72I?wA=!f?&&s?f!p zUlJY;D2ut$nwux#oqQG^CPbZFVv+M9D#zVEQ4enroDW0_i=+Ia=5_og!44t_onzML z_04)*si?vLwgg2bk|fB_mre)?1H3sRC2gyH%sa{4AtaTGqx8D2DK9fS={n5*4FH3- zAUOE9H)2CclH~6Q`d4_SCq6uP{amzdg?-o-AG#!% zhvV+?WJTq-N>W|@L}mIv>fg7jcz7&+X;?5{j*m~gOh0@7Mm|L;0$O-4Q7Ml#o3nY)tFE2lRc|7+f{W)qGqnkgJ8bT0}H9f6;ZP>kq}Kinjx0 z2R2l=dw)4;j27OmxK*)1Kyjt^LF4^KTzshh7y36eF2*m9d~?|?Tz2DH($@G~39!HK zsvsM~?<@^34sQ@pj5(j9&+iZCkqzOFrLM)U4FZZW^%wLVpaaF#-mlJnc@`Jr?_7M2 Xnx40nk@>&xB64>na&}**Fw6We@pX2{ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_891149.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_891149.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2ed7e935efb75fce7b9bc99d543bddb2590e36c GIT binary patch literal 10306 zcmcgSZBQH6cDvF_T4`5+00|IZ!eAR=Yy!q#v7cjX<1Y+0iJiP)@)V(!umnhAB@8C3 ziJVEEr8rYWUs~ieoyeJVEXVE`civ1$C)3AEf7E^bQ8sd8b?eS}=Dk1uN8-G;^ZvBw z-j$Hx3FCA+y$pBn$2sSoz4xBa)j#R=Is(e3lFu*xrG+4Vi#pk&&orL?TtN_X1V?b> zC@~;E$pI4g6=TW)iX;?-lk{LpS!DT%$Jy0msK+VbPsvf7>SKyibLu;q+sX|wlpOt| zFqE9OFvhTQnBa7ul6RQf@Mfq0_eF9mzt$35Eyp@3Pf7l$+~)M_j2EplUcAnDN#Svq zjx)Tgx}8}#`y9|>o=TmBr@R19W$TQWuQP63XWZmuVEu}BRRb(n$E~+oB|uSJ6+r1Z z^E#A%}KrO+po?L@gOR zv{~3JXhx30sVf8x^qQcD-e-%jNl@X|R-sW)<5sgk2%A7@67MNE!lkG}RA|J5(1b?O zcxF_KM$syGO{+kzh2SW(Yl04AH?E<{tio;)6vA4FQjED`)H}U?}dWvil#}9xTp+a%3LmLU&vv7=;SNOBOr(JW9OrdpMsR zdH4mCyI>WzuG!TstIuc}Bs1S^h1xct*dnZDiHMXPU}8NHD&V$c}YD!DL(@a`8n;!Pn8cnc*GHp z1e`ZOC20de-px5}Q)9X2h@{Jmj!)*dy({AZ$8cUdH$LspjZgc%DOO!T=6w@> zTfhzP3Ies(Q{d5cmy~|kAQe=$5!wAfO(26j5!CcbCE2FiYr_;JS|$vLjLXdjf*5d2 zVsq;O1xz#ZfPMG5TmcxBpV>Yf;)rd`J|(Hqi2?`m6TomE!0=wfImw$14SBbrA#Zm; zQrR7ji7|;Bkx09w9G~<{>VU&O3f`1*+8>Z~qdw@_(DFc~6IO$(gUGqP!A?m94VPci z`D{a@Bk0@!2roV64eW#MvHWk#AI`Q-`uLH6-|ld>U9>w!?7@pZ?{@I=yur2r?`R7; z$NX*3YeRehSHZD~(H;KFHW;$ORSEjMHir)`%-((N?VbBzM$*ebwh<@qb&m2EfIW2g z17Ct;7bbovn71=GYo`#K#zjHZVVu+|7>&s&H=l~+TwjrarKI%_U@z-X3Nv&vVpFg(H`L$tT z&%HB?Zn5)>XgVY6&jNB$dEAtsKD8$_*XxsAV%7F}vsk%*!S!|jeQNQwFMUgA2gHK| zV&#BXbRm2kgMUgUIx+Yuv1<21i&%Md@zQVXOa15X5C8h@rPm$eafevx5R14U5C1zj zz7x?CvC+ioC1d+nrN1(NX@0;S4R=3+-E04FXl{6RIBrkW%xlDkJt@=PZ}s~kBs%ls z(c`gW3Df6fvNNfH78y~QJ09I5oGaMa`4JVEzJO9Rz+;ij1 zyz6iK7paBU?)k)~?)&`@iq1roj|^pV6|)ubio~VQ?epcy;Tvx+I35@dN7U=iC}G1Z zYUAUxoekMAW<8&k-ZbAZr%bKMfN1KNcZtRW z5f$KTi#{x!GtHXduI@@18brE59xjfxCaT2J7Ljg=sMB@=)VzBy?Y%XW(N&ZuU& zktlj`c?$tkhIigd8_n@oKkZ5c6UVRHl9$EWy`pjN!tTYU-;j%2zkKQb>*D@bMdPc$ zVrk_^hWOEV#cW${Kg80V5zQl7H=~d0W1ht3WT5K7#kvW``V zj_4vxgmNi4)w^^^8wB@IE`k=V4?!3SA(2v`(JH_{;aLeQvXJ2RDpwrU!t?5td2$2% z!t*(N#Rm9==X3bV4e$%k=kVzb@C(o9@HHIzqodRv?d=MA2T39dSv~A@lUfLJUm>iz z>A%+@D}LHO;dNZJoa{T-X*u5A+xw!$GG#er;hf_Wc5e`Tq~`HSs}6%styujREzP~H zmL99cZfIVXDDBn`>$hDT#ob!^~;dJDSJG#~=I(q7b>d%xL zWqYEe02v>2L9=mAUOZ7i0HYut9L^zR;gJwVl`aKrwz4nhX@aA!mOUB14yZ$>jOy(^ z*OliP+hI6>f@n)P9Oeh0AyLENCQon8xH`GR!S0$Z(6QqH_U8fzcObxb5P?pp{wqv8 zDlyG5vrN1z9!&I!^*d7~oxi;F#as8@T5MnB?jID74W#y55KAsBy}>ONJC|rDKM06C zx;Q-Ifk)zMrE>hsIW^=YTWJUf+~CUwF53MNJ3t`dxG1T>g9~29=fFD=0pyW^^NAgO zE6cxu#(st>8>-$!6?$ViAW-z|<@Z7lJmQIi$=*bj<LwtbS%!bl0^#G5h-mmZ*xzfizS3kf{-wniNx)X!wR{T%sD6 zn8q|yGP5tb@7n%2j8!jD)k{ounla8Cj2=ue)$#Ff7|Rl6Sz;_{hMn0H-E-w=xGOxK zW~&~u^&(q;r8|5$Lc&~35$m5-zrTNpGDY@3Hk8Jkb0f1OiH2FPXlM$b0?%dULi9qs zA>I$3*7(iI8iF8%GAq7aDhjndYUE4gFs_RJ8McCZiQ*>o~%O~5fZ%>jb za|>XkjTMQZc|KiI@p~hqQZH9#e2m95n`R8*Uim7a<8#+r1ap93*A&P*DzkFtSutHXpsZ03A>*PHNSx;cGbOh|ltA&V!4*SS$gCAIYMvF- zm0Jl(y+FAXS#Ds(!UhVQid8Tu{#{czl2d}=%HGvGpK1QkOF8r-Jys1?cqW#AxZiYR7uZAdFAsZH>8Hw6hZfFe&9?jYfzvI z1~+q8LrkzSJI~Qs-XPZr7AV9DjGz@*PRG&jVHXVAa*i0Vr(vf<%m+paKH&*Wp^sq0 zgrNA8sWqZ#-s}HBStfl-2M&q5bCDOsyprCtr z!Z#7%N70PJ3uzx0I66VM&ubrr><*-}s7r3=Bwvr_l_=Swg0zqql0J@6`56Bnqt&tgxN-LNgjqCh zOKDrfhtpbSrZQR?Ylyvkt^T35TGUp5-2F+<^`4ZrIeZv+(y=oY(TZ5K?(pJZspo%NL(ZXtJ))eWUITby%#5Tk! zR1JLt8K&;E!4fmi)z8*18MY;wph_yfF>DK;N;CS%M9c&J0(cGBG)nZiO%@g&st^?TBv~s65THNl>HS#8HM@|?1Qp@-4qGi_ zc7G-ul6B}1ghoiAy8JskGWr!`_3;5<5%TxwkyCnD=V3)?)ndiRqdR~VovfnXM%_)Q za-s?;E3D3}G#^3joZpSqCjTB(k`ml){^zJgPk7}!89>2)Rw2pyh>_=HZS(<5FT#J| zE2ywG(i>;`qJ8n5@xx1OePXl7ZUt%38fI#vwJ|=WHG|-Rly8e}iw(_PoW1x9PojO+ zd#C$u@2%d&P5)s1we^ASG-PO3gtOcf-<4>;z7M6d%EYqP2%WC9e8OC3K%jeX^rk91 zBFad2c-nqZ4%@k^AZ3{4`_lvLf--O|l>+%CeZFRjzhG)J3byRNl` zkEEHR@F^^dt0WM8IY&3;=?_rB4{gYh^4ZZJN@BR+Cp1BTmy2gbTCyYHA1GJ+18M_$ zkP`V(^i(;22JU&D7dZ%ggw!ET2&Gyn_2X=i1TK?9<}er_6i4J5cc5^4oS0Y zB@qq&ELwpa!e-VsBx#0ha(a&)vgN!4iL^=jOzK~b2_+3!+lkR&U?mz{eM`$_)_gqg zbl}8cCq&5nby)0Q;6H#2beMQ-C=Z|d9fU$m6vC!B3CZEzzchYPd$0BZ{ZhE=88cgU z&UY=gJlN7JQjOv6$R0?zHvGU6#vPE%H!f%vYwk~{$_K)|%PgTa#X`#}s#qO9zFb1E zW#K-2*I*BkhYSz2$5thdv7Zp|S&T>K-D;498F@WYvHa>gM&bM11biov=xmFxLG!eI zBIpZFc%70OyfpX?jHL9qgOV0ppg_<$<>xc;_c`phNlNf&@R7)egJc-*jKs)$Y8&>A za<>RMavQ)492oO)6Qj<9{9i#A>9@d#P%kS+JZT?P4R@AIfwJ z#j=S|&{xe@%*#zgZQXZ`eI(fd!G4)>u64F`xf*wCvfWzTt;=@nao3XVHsJ0i0^Etg z=-~I5)Kev)V4qf?g(~s1UqS3>`%c?JmL`lhYi`sm6VOa=@3=d7Yj9Z!?Z=y2ZVuiU zTvkCF)Y7&W?zY})UDlu;O*C(Nicr6MlWZg#V&~@uX9t%FXeI+s(PQ~E$&$tK&7ZVh rZ(Sy!nVh&Q+!Ao}zVV-H{;}pM8hzSgCdhBME@P>rKF4X5RFoSrlMOvlk2ecyObvn7od zs9uyt^GDF=V3Zdd$Mi?6#{?XB(en3p9CUE;Ea7ShMsEKpFuf10oSp8l$8VSzLRRxLhstR+e-UME=3A}a___m;# ztE=2L%5e4X>n<}ei|v&$IE9+6-1;I=jvUq2!7@knoD0UZe8ZTVt#C)>@d|US$VK4% zP8iL~HF6E_lb3C9n{7q@+i#MdT~83}hn|p-KY0yh36!}}MnmS&$%YHBpE-HacUoqS zo;lZl+IQjP=W0XG)c|ww{JA$?FJL|a*+Y0J*)%;j<%>l9W5I}QMCI##ekSsS zLaiE^JS&q!GI?6o1$;A69w3H5czQH67Lhd*!J9J62P46#FBmlv z5>PyvZ(ti_YC-@clAez1k@$*H;Qn-W~VLO7vcL>{r9WfX+zEjvK)L>Xc)SymHs23Vy zz&W9&Nb43FfW{pIF2w*X-iUjE8o*FLyM#u;iTWv<72@rN#F3@d;(nThT|(8jqARW( z(=`iqLiM+zt5(;5>0T0=K&25spg3ZbLYGaTQ7^(v2#VkZs1ZRiMT^o_qTn=W#exab zHE*CZ7wK9z(g_YsXQ@bMEz<4Ygw9rpt}qvSk=j$n{|c&Fa0*qURJ6D&;}z8QeXYu| zcw5w}qQcI(RQNF>O}?`acD-F_S9U=;Y4JV6o{gMDJQNHQ=#1b{k} zLWj^jUR|oIC81-yMlF}Vs{B&-JjTKpZ9fi*&p&?X$bZ&J@9N!%=#dG#d@s-;R49`g_nLlZpGJv|p6pbwACFcy7iQ3BhtVvjAkx4Dd(?nJ=c;8K?*YBa;Ia zVLEDYLbj>J(6kSCF4M{w(1&?6#7Cl-W=dvCLqlh{`@#oyAv`)70jTnY1>pJw7(2lk z#9F8#)}jyMD*9rsqJqXB2*|n#*w-5}H8COU(W&rF$fRE;M`UUQ%pceQ*tM&n(Wp%0 zj}HKv>ClKw#!$apa8}l1t0__-4@6#pEpq>{>Azm@ydLH!A~XI#u=A=vFyW704Nva{ z%WXQ^8Q}w+(csifC)7Ga)6oOIDSvcoZn9(MrVl_qI2Fm|(W{NB`Fk<6Rw6a9}E%lz5YS+OzCxDo@8ZI0W9 z1%pV6hw`@FiIa~uuV-7O?cTg?f8yj?ZBr&B)pq8zRjXDNz!Ys)#CB(^EbNX&u6z#8jFUy<0x zzd8Dglb@W-`WIdK=3a^IO&nj{v}MgvEAEjTE%zOt)!nUI^e#^3J^hlSKXH1Eas0%a z=F&Z>b7DwhTC?qo+&}a!j!A93%XL!QDTz6Srm*cJyLe2jg<&(p65EpXNKDt_=+ebs zd6zDI9)4hyx`!oZ7$#?_`>0X8D7H$LU72yovTyMeOl`U5f%Z542Zk>jau?o_`reV) zcVM>G`j5OKC-z9zrc6w-dKcf8EUzsee{kTp!{;CLe0eT+=}qa_o4G4vQvaA_8A}X4 zHZk+NQ@hh=GJA5)&R_Ch#6FKbG#!Oqve%?%MbC=8IjK{bkQtQN_9V6Hs7k+=Iq<1H z(<9ZkOOE#JtYqJp)IWxe9ZU_T2gHGevswM8vy08y>vye7jneM#EO$L*29qS(`!bY8 z|3l_rl6=fs=P#u$rN1kVF8H$DOU{Sv;Ux9QVw)dJjitvj2C1esJG#{O(9)aKtl6FE z3CZ4An1HQ4>z8bui;a@43ufx96PtfdimeMrfT_-wtX*=pCFwP*Ej^Hmi5Db`JE?)) zPoGHnG8&0#N|I~#s`Ttat$1|70`#=`NBsP?)U~u;xKwp;>9ph=l<2_^4eNUe$4l$`2#YiQ-H-ajs5rRb%iNTz z_et!&#qOo%Uy)0#pC4I%OX?bu*df?)OVv+?MQ!?0Dx4uDOJimVHRfK7}hw$m9*#Fnc{1 z8oL^m^^t&o5*#}HEY?8$L4=^^gV8yDIsleSF(E4i`(h)Y8b6>vFy1BvV!Y^)74Znw zJ&jPOi0_jE`7LmWO8(px80SMet`n$AtsIGz;fKa(agNAG+?Y~y$~a2FJS%uR0Igty z-=|dDA+Ar-NucI5@6&NZ6zO%G1OxU=2!RyzU;$~+7~m%o<3(#+#RTh6t6FNCv=@C= z;L|tLS8gxS>$jk<++L!mwxF-vUZU4*L0`GOM6bPvtP)5hA?Pna++qluUw0|R>d|w- zYoS0eP=4Uhh9jb<$YuIVYO4T2iWDGziZ)L*o6AgcWbbWADL6PBc0*NsWWoYWgZdby zV3OLQGBXC9F)$1DiUQ>uP*o#q&kkJdEihrKo7w$Lc-4fXM$L_5DYtiBz5GW*_G3}a%R`6*|B1FOJ?^& z^RAqBSI)d^)y&NIql9 z;>71C9-0m(`k!evj6QMf8BORM>1Ih=yGk3AJ?VZV;z67nlis8sK3Kxqe`rn*-+pu9 zP0|0;w?Wj;TT+(vY@V(Ydp_>H)0=I~Z{Gu&X5N%CrM*97(~;Y;g;=KXr*SkNB$EAk zdYjnt@!mUovt+)$ZBwx0nbV}xuh$TaHE{~>oNQ2JcO+1LXwhZB^7$CVu{M*<(IUlo zg=m*5tKt;M+zQeZN#wXDt`)Qbi9KZ!`D78pDYz!8)CY-AjU&^3L*g_7sa_+H>AG=b zJ#I+oIRbq9!W@;W8rK7EQTCPyD&iUhgP=!VH)n`J|8ZK-n^6o2%9W&(I_GiqnWfY{ z2nvo?XA(C?5od9#6qX0+k4IHTQn|QQua)#-+5S|B5vV=KO@gV?{%|-z9XFTFwnAWj z6Iv`71v96`aSso;Tw?_D@qai(k@6+CZ4HYk zau$SDJU@&;Z-xmssSI;`hjPlwbM=(MX*T`|dNB&jH;p2|C>%4Y1?6Uq{jCW}YE>b* z6|v(6Va63EHI0q=Pf0ObRaK@bP& zy&)XdQ9?mhz7@4jpsWdHDTRFw&!g_5PdK@<{MHCqFpj>PaJ*3I_BR_ftt6^TT@n%l<4|ZqdDn^1@xuZM;7P& zL~3Hyx^2bUAXyup8FdzY;`lQQVXK}OQi9kg@)>XTt(CoprM-tAT6+^G9$Bj@w+ya1 zwu4W2Lp+T<%`VB&kvP5TbbU0OF^J$^B9OCV(JMI)Bu=e1w1Lan(33d+p*ML|Y)nn% z%-bccYmIfH*@||tFJr%RJkK_v7)J?Kf9G7D-Lt~>NNi7@eFZf-<}ar%Lwv@YG2JH1C63?e$);jcjIQX$~x5^;zr{A3W6 zk*vc(dJn6JYV7!+YoG{i>hx?`jGgq!`XOJK4~@z6+;l-QYey$RgT9VP-heEErdL7o zniT>??k)h%-yvk9B6n$mtwIsjTLX!sNfPep{GrsLTZ+R2SN%%r3djSW7?W5};`n2b zHt(h06TPB8(|9MC?f&f0-9yWUUt7MkJhYxmoLDnE5V=U@&D#^l9z!_P`{!>$uU03B zguXkv)SqRVB+E+>RdY3beB#cD%y5?ctDz+Q6Faz_$I`R#tNkrqx2_{BHS08C+cqCh z#l_(and$pPm&ak-g7|TW!+Ex0h4o6Tcd>1`F~@rI?6HyrqQt&DyK{xbp|JX8dyd_g zXZt}2F;(+JsUe6ai*pjwoYhO_j-0jw%b)<-ay1pj&oh9)AEwB-=9`^-c-JH$##8?4 zT0xM2NGOaF*KW}tT$vZ@PgnM`uM;$-_uVTnk2r|II1TL;HcJI%jDmuZ~n(>XmbBGTccrr7B|KFk5c>IqM3abP}d=L$o zfq;Mk{s=NR8I45H>vs9MR}79SWqM(r4+ikHJK&4#=YIyczky%mCS(Z$0=kJ)k7*W% zK88gyeLU0s%jz$7e7@r$eI(KM4J|EQi+#&&4_nSk+NQ(+c=u;Ajn5dOx&vMWs+SDQ zt_QJv)#b$5b%rq3q~q&4tyP~GT(=R7GcknE8}Jb^*rVZ*i~NKDTao{31aK&#WE9G> zRatY#B)sm;!Mlr0t0f+p_%iLEi-x0f(?MC^kN+hjYsN!S*$Ac}JUZQ&;R}(Li&)6X z8W2BtKOitOH!~SjHV6p?%o*9N@YFXJp5*QlifS$dn5M$q++?tqzYEAnJw^T%g!{UN zB+1_s%)b%rzY{KqsZjPr%?Ay)8rJCQq#ng5GL4JpA2@UGelN#Iq?+hDK^&#}DfpbD zUZ>!48HcIWGTaLfc71Ey^A8;VPq;5OU!Ol0iAptdu&93O7+Txw)G%J#`RCU5r=0J) z-N@?yP(>KGtvc(|o@aF|MXkFCivFPfR{c74f+Y8?R}s$Y+j|!FtXJb|jase6)op6E z4p-~d>ULap5%8h{(*F-TQ1z*XpqQsS&=@;$^Hqwdt^dmGBCV?i<^$oDu&#kJJdGim zbsdz~^#s(;q|U4xP!05qQZu5OiLh6#BWk!Edh-WYZe4lGLhY;GCh~yz^2dkn99k!! zxOf47mS28Q{V%S6a^)^vS-BLDE(LPI@s;4L6r9aPuIEl(m-=r!Md)9hA)RC=yv4Zg zxYh|MX6HT=?h3fLJn-u?U!K9m+`E3|J2|^Dc|)26V!(ZRu%2{&b%?MwOHmBhjD`VxXf>pArPZg}1RX^4|l5ZoS zWwqZ^LiW8Sw9u~a*DBdo5@X6Siq-v;eyo2a6MF_Dlp}tnQA72^=VebH)v@^>C>|v$mVU2kS)+rRD9F+t6QmbqNiW!fziJS~8G&oZ7VD>W3+bI(q~lUG z>@FzP#5S;XAJEqdU=9nCE57g+ZLz$>%{l$@FwmD24Wt@>X@FQ10yA*(%#mU1g%j6~ zT2F0?o!%517MNq_&s;j2;Cu;Mld{MJ-OR$Y)#bL0Ib4Dkf8VrmbFP&_hxj|FZJOibn|!PiY!2ZhJ=Hm3)Du%RFV<@T&O2ByagPylTEE$y>pj z`SL|2MR_XuQoeE#+U%*~t9Ua>t9IjFlE9bo)r&9)yg4ab!&ie$%_8)F5qLa1NC^;w zG+a(Czk@Hp<)pO-4`@bU<*9y>a(42yeBm3(CHqX|c5Pcvs^4`y#qR`*O8uFPrAB2X z&?a7k>*VWjDo#(NYH=zqpVU#$(;`QjTnX+iuOrgCGNkpJq#JlSzl+b6(mWPk0dY}2 zKbhLd*FhQ{$6}nexs{&XL}F8Jse>A7cNtl~-f5+3YO&JZzKJVf8@VnFc+{t%Nn55RFeqZ0WWNBTz$#sTr=Z$<3 zUN2;abn_Ozo6kwkd=K9U@w{Zbcbm5MBun3JCVP{dubD}{>wPk(@3-Kc?%B(?rPZIs zN8vm8z0wRVg7@laW2ONzRmlHCo6;B zpRr$6c)gGn-p9*$CR0k98Qzd7Ddl($q>Wh?|BXMuxAA?CwUYl$Q@1FK?lh-|1;xRx zrZ{K-wM+{N+j~ye%Um&{WZk#s9RCIp5|(9xV%p}K6qK{>aR;|da}|(oVFa0bO3=-5 zvtyiX+UjHlg`IWY5aiR&89`;^Y%^mHK`}LJV+G}+gPU~;>QRWf_jCwKF;nKA73A!! zTTr^(vz)_nP*B>eb8b#hj)*^Y@}qOvZBq^x`PPz>@F#~G%VFP|!<$&py4;+Tby#iF zsn5ye=SbS;OfuD;;I4en$&%qcH9|AtC&@}8$;7=%6`kWN+wAs*>4hm9Ote5x3bYMRvJ=*VV#K=8i|Oc` zpcuowP(lN&F7ankphpsogvkx~l)T!XM7I-kjwoU2`_OqjlV7 zpR~EhXJ?whshe@Px;T5Q+c7=Y3b9t`OQ&_(=7z4c%-ym=k`@8OIc7gJxm}$Mt96$a$%F8^svBgRzp*h@$A3j=9_HZw{VY z*Hw9sJj*lQQ!lB5bg(a)*Wf*|T|8#44LOmyHKHhp%yUR+y;ifcIpKF$edA%28P}Vp_aMV}`tlsTLV(!wh0Ny+@ztm;zHMXUEUZ ze|GiZ)$sgEQ*_rMlyk^?EVgCNGehp(tNyEj`CwDjQ12ajZZHPMeqsw11jkTOUAP1l zb*#L-I{2%a)ni{+*KVP%v#97SGMw`c#enmM$hq`>q_Gc~`_{Cm=){v_k&9QJUVA5U z@f~#Xdc-=8PL89Zab$2p>BjQl&2SZ}>P5z0?{Lgm7BWYg4x*ZaYb~hyEHa++zV+Nx z5-D#%rk0hR$k_d#=+(|uHOfB>ohT@aRJNglwv|C->RsLWD|)pa83(+>|E$h?mT$Ue zUNQ&EqWQagDk)YP&EM%$#tKV=lc>-Vz5ud5Z7i=OIE3|9#1@{`38iFUoO_8#m zUvghBezmx6d|zS56sx>hf&GyRaC zm~_=AxlmPpQLrB6*Mu&hyat~lmREom#HX?cy8F6lUae2@&$-1hQ$esKG=)k!kZF%^ zFji0%+6hClqZbwQ`i{m5Dj`~lcJ!cv9^cVth30#gmo5h{eexsUP?%s{18$nwr{}Uc*YJK|Wnw&sW zs00}sk*3k7ib1q5Qrm@0U8|5dfHVVsRlJ_csfjmIxdq>QNyQk`U-Spufn$g%56&Z| zI&?XF@jtYOFQZ*ut2ElR7wPxG27)oZnl`4WupJrNS5B;{pD6y(_C$>ipGBRp($0E^ zo+(rx9>0Cu_jcf&ko&iardO~VDS>`ZP|x3VILF4_+&P?}nkRZoP;n0T0ykp^zcU%k z@I}sHr-6Q``A~b0;;D(`DJ;oHNheU2o@@XUW#Z}Wxm>chs6VBlJZzLR0xO;x=&X5i ze4=0#*ek%bxzulg=Pb2=l$hR!M9ES<;Z{kpt?IuHZMviJD4`$OTBpV>xT8&-;3l%1 zz5E)#c2oxa%i>MDu{T1(L>9k3g&=xVK8+7bVYMG6;TTGJL3d zH3j6aWcSj+f-KA051Wyt@zRzK9%=4&Hhx&#;zpJD@s=E=p>y)nDU1iVQ1lwO9Z>eN zTu_`kda*y{(YmmTsrGugAHy$rMc~*jT12nI$~xv3Y|vrw>WDLf{fw5rsaf!mT?bn* z>o^FxI)MClFZJorgW>zbVLIHol7n_1ik2V#hwfkR|7QP_`N;XVBbOcM{Al#Z7%Cr& zlsO{>PHqImEL!l&o!~OM#%*&B!Qga(b0ZD0pac)jJ}xN1BXZy3@Hi&?`u1+(;(!ZC z9C3?|vqL7wVuJN5=fJ$WlvgVH?W)-Y_uK;b^Oz;+)4Q;WkbJtshsA*?d(9AVJP5o4 zf36Xj@c#k+l@9!?yIuaSk9q>OJAJ`pcp(1fRQTM}@|K9AC8BSM>6yDd{+>XGe}Ay& z!GZe+!uhCdci4u?TA%9MB8uaYj^k0F`ZmmY$bV>EUmj6(Mk+g_K=tL%Mfux<-hS_V zOp|-p;5P(pNK@?X2g85&9sfJQn&7!$QJDT*v#j~0@@vCahIPY`cj#Xg&KHV;z(u4e zjcK&Lp1=?`onSw;zILAte(q}gnn3ZJ)e@i_h`r1m>ff5@_tf`v=1{ z(Xu9x6KjoXii6!^AEITuV->XzuHC=(H`l}cPb->zDv@VhQ@qliF7c*p8e`zCP|ebr z@63!+6)&Y2qxTe9Dcl4MxJ4_sp`f|EN!q%b(A8|bc_}1uTXZ!)4KXca z6`pdwoG-=7!Gkf;3WCp6@|6j6v?^ss`4Ya0Afe^Sc2#4<6!1|{W&|ayAtgWz(m*Kz zGFro%aM@{n#iN@MSUS~LTn-?GP{teMsC26LL^_X1@8pZ$2u-C^<3akc3oLhSCT4>B zK+4Xn3j9YsXQVL!f;xZQ6v^X5+61emaf6 zHEqY*qeiN5|ZB6KNlO9ZZG1-gBK1`DSp6C?p z$7vWNOS=lijsk}vJ`uuPYQE!Yu5)?CLwuDb0bJri?kdca3WNb|bA|q`B zQPYthj6|*h#xn@e;TCY@HrXON`SwrxgYMw5C2QywDr`fBwv}!X3;g2H+EujYG%}nP z!O78|oC_&}mzQSYd{ERFcA~-_WawF~TPyh0{F?cz)+Z;>zDvk(2>>R5aU1*%fs>)8 zh^h4#+?R`AEUxPYyhmPOAn-3o?u{*t0e)7#qDD16(W2fB!(Kqv@`?iPpPUQaT*?jA z0G_v_W92g1(TDQ-00%UhgEdPQpZa-WRj3OFx3+5)Aim?M@c5H%IKe%=ay4@P>bHj? z?@poNDO5P+)5QuaB~C#9j-kS1Plh5FEjBmTlyf)A-xW5Yye8;uUO`}C ziP^}lM!D5LGk#Y3ur!+66u$BKt>s(4G=5$BRcW+!5akZ~6wfvKkGZ=y{5Jx&d+ZYX zX~%!RVn~|p3Cl{efGcq=W_ab8_(sU;9$xuBqS;FTqENnr>{YbO_jw;?v zY`J#Xdop5r^?s;>Cr4hyE_v*#z}@m%3~euVKm2ml(Y0aw$U7JVQ+BrWC)ZvqLqNl46&|G6RO z{}uiRm>>b@M`h4nMH&uBKoD*Qn!?Qi*^K&W`lf8tBdqP<;_fXByQ?FG`$^1CaXHx2 z!q{KYgR=W&p>ttH_(!OrFIsZoA5_2Af1`iW|HK`+ID*dDqlZ~k!bV0XB9m@3x`2vr zhdPXE=14_z6sW!|rq93I=kNQp>%so}``7gc!tG((%HHLf zNaKO0`U4Tg0Yqp-f8d42a98gKs4hsu-Cp}=y$^eT(e>qlFAl70`T+>k-|!ATQ`CfZ zB1OFj_VovD1aEvlkS_%Y+U|`njR)sHnM}v=f^`p??>C3(Xldh?$q8sN|GRuGV0Z=i z?)f;GY=ORM$MY=(tR{&G`6j0v4o!nkz&axx>ZaPt9gAQS!bD1Zwuwv3R$1xkij;&VN`J_Bx# zidX94h5~<2z^OJfNIjguhnb83Z;ey3gkn{9){QR8WA(#b+8Vh zMdVinzYIp&Mu`IF8CHC42J3vW0Oa%^R|C4C=e3fo*qc!qR7!n(5c0GWSywR3L}E?C zOx21lPp8tFg7xs4+)+8XH6sW+0R}EwWSk4i5!90o#~j?@h%kT?l8BXDmXRwlQFX%- zpe#m_UrAnO5j8KUIfr=hqp(fe-h-E6_&t{r9yZOu(;Pb(r8#)}Gr!<)z%4np27-2B zMv7zbKv2NFRIm5~NC}^NI>h=g#2_ffZBwHb1NRpYA$CEK&DsT-(=I4(lG<*P+Q=&+ zL4oh#1zd*prXYuVB7t@aDtM$b>0kwV!i6_u(r(}n@gQv2bV0^F2CnDu=lTq+g_pwi zAbBekQJ8&wSQo4P>W{SE{+J=>?q&bw+e2Uy$j;g3@c`GG;W5E?*CC4O*|?>ibe>OZxi zuSWXnsD8(Wz5(eQ!ujEWsJ;~nVDiand`Fb20Cpzt?n(d2pfSpndk0@Y-mrf-m>;}& zzc_U1GwVa^%K7N7UR2y0W%h0`hY)jUjb7`BGD9283B;WEue*NJ_-$j9x#T_iLYI4Y zw|_S@e@BSEU;nu9^Rnf#)pK88`RdBLZU|l;kq0Xy!J0p}f_=<4;)7mMVE(U*ef=@U z7;wOmc|LUF{-3~`puNkU)p^u`JH~> zJQ&y;z7uzH;#w-FBF<1u?%ni!)QWJ0DF(^m`+Uy`p8( z+_Z|5&Ha0rKkn~<*wal2R`}TDy!r#oEHPCO$qdB6DBzTmIEQ#tZY)DWHiHDpU*%40 zPEfL-0x|=rqbz3$xMY&lk=l*-772V(<$|iY04hTRibgbaasuFL(^`{7p{M0$`S&95 zEXm_>QYN%ndPFWM7dC17vibfgB_oxRP+XTlanadeV=!hg(l}%@FMLw1dhm%Zsrp3; zOy5{u2G=^|0*99ofIW)|l3EWcP?b0X{1sG@hJIqTXJJI&HRZO}VdhoYIQ$wYf#dx(7Wi8rg4Xfg9I#n!E-nOdxZ(bLoR?gjqNDo|v#J0s1--@n zltgZl2cFzAj=-%tyqnx~zbz>){lW}h@8n*e`;Wl<2l#XS6A(~a`6c%%mnv_+1)50} z)8+7LmbJpZ$W&pZ!3t5?3Zeq~)ZoB}Wk_HxIcUnToEaBis(fFwPZC};c?-HWdK*@nIz>Fc9O~_0He8`3a zc>25lbTyISdMuZ`0*_*w@?iOXukw^Bm;sKcX&&cwXopATQF}DtqhNJK0#-4+n%D61 zM*xiA03M2x*EwmAp4W>D8t-;^P`$v~$SKxBbJ!0N^lk?{@U`1u!M0nJqW#6Mso5Fv zD29z4IUm8oC*V3hc8RZ~vAl>_U&kd8vkGr(0j|c+s|9Anip{NcZWh+9KwAYv;$f@! zK7+Ug3sY{F3u~!#t1BAa0i$LH zmla=BtZNQ=`(Kk&w`Zk)t#Q5X6jIcBkHV9QlcAdL7^<)Zj%9_b>b3GGi_wB>-cw*+ zv_%0=T&c)ad5^{OD8}SHO_mspARbeAiGauaC6x!G@eNKWBBRL7!ncqs_PJ#W9CYMWg!ZH44|L~j|_(1IyO7SE>om8;-G8# zrf1oODaS$XHz1Bxv#S?&PFzmY^uJKdKT(GNMwQ1^lq}y{^kLQQs^^+Q@i9nXFE~3lYaCDjJKp7l3O(c7$*Zrg6P#Sp#oW zR>zj@itAMd)$emC?JhXP%4Beem1#aKy_e7)aV5lHz#*?{xhk&4F%8w;z0uZ>+WKEgF-3kF-}yJUvOJPh1Hx z7(B?EeK;Fes1b`072d(Q_R29Djx5zw0#9`vRBlU0=G6OWYT>zrsJi>(tSW z!Cenp@3+P&_*~ibWz!c;#1+1$Y;bl@(Sm^sw<^Uzw4uNYoj?kZ|lj<`dvU=i+|T~*&j%mZW7w}zM0C1L;K?tg7R3IBn0#Z@qjXL-~aa6_;@8vB^=VKj0eh_&F5{ zUvg0s{x8tZX_RvKF%Cb@;YWZ_=^h{R`Ud=*bl7u+QbIn3dO)d=hr|h|Jayve;nU6| z?h=TJ4$&ZS@uUli+`t03m@$;uoMIe(N&JIMez0T>N=cToG>;MuYK5cav{@7cLHioC z3fhvTEu(F8*_0f^g5`7xT|Nc<4OY+;0;Cuv9$Q@#-jLZFtp%$ zdcD9&fiBqi75>PF(KuJ~rxEdmR-(KqSmVJBEBQm`QP3;-qsW%qm^Xp~hZkK>H`3X; zpf|w5s1T3qVAC398|cRTS>k)V@tGQ$SLwBhU)7svDZPm<66f+OqtJ^v!4_Hulm;=r zk}}f^qnICjX^oz@(9ia>bq)HLR?u%QJTB-5>{(KY`Vnj^(92xz7P^^kn*w1-68!vb zr8m=C^5Wa*t@JitO0=TdhRouc%zg7#bO9A0EAQIrb$O$+<;sg8y6KzOPPdCIehO5| z;JPS&Jt)KIb>Fo!hGjt4lkD=u>IK~;opl{5{*T@k<#ZXY=^-=x+PdLhHbL~M5e4Uxk7^fWf zdi`khLZBa}gijfUsX_ema|)iHeoo_ZyC+8{M}YMdhYxeauoF>w0-OTPztffZ_H($4 z!}~dv*qO_mQkVokM`UKf&q?}0uY(xmIsJn^Z-7&wJRzv?jRm$i`#D1R0+q}=0dd@& z=_bs8&|l!cf}d#UC*KRVe;07C`gd`^J<)P$oEr8|xZIwWL6>{j6&M^JYX*gNEYRYo z+${mm=tK*|TKMjbx&pwJO%qp~kmQ8b6c`_Ky2oKnZQtD9x_v9;a5{m=IqadvJR?*; zbP??)|9^m@8p57#!Zf<+=14Pxr_}b){)gJ4>0^;&j3cEj4;^@9Fx^njC>fmDnKG;o z9e%hXo-VG9`Pkx?q|B1mo07H$R^Jdhls4I!^Q@^cbmXDg7QGVN!4_{yXjA50p`&SS zG1C?+XUm#d?WWLyw9b$;*Rs0WxR%v!2_0A}GDk<)qMG+l-hSoQEAb17rd0iIwrF># zC%t-(&J;b*>KyOyyM6f9;kYYNo~r9&bzPx@OOP}0t}A9?23T7|yp*+VPJDa5`{$1N zo)4W1SJ{0eL9m~UU$@=wNvc6P^+dp>r@8P;}& z)twC;d}OlRSUv*8s z9>&vVOO$5KwXp%#+#Dv-rqXvE%xPvFYpRL$v8IOjwnW|E;E8o}yXVic8xP#W*^P%; zV{cfIHkG_fg*3=Zgz?xo4oS8rU^Vaz@>{dj%RiJzXlh+gJi5_Im##L-Z zD{E<;S1h;}wY^~+O{|bN76T5jCN` z&smSpU8!f@?Tg8TF{oH;E312ksaM!7upm#sf-JKBk&CfJ2iT&@m=ShF>{2}NSAFqI z?7H@O2fJmz zWZ{KN9^b$qNKL=nHRAD8wU9yWK)DLa6PTpN+@P&xT``1sYjz0^;;Y0})eVfshO#zo zmJSmQc)zk2S+$vr@8LAQCbcAwTCyg!G>=-kCUq8ef}UW7Eb#L%u6Kx_oR$>ICFRHh zd`A_?@?;1J&J`{~L{wf0)L+1&4P>pZLJ(x0t)d7MVW@>vzD)#`0TJE=3yQZvF|aQc zpyx?ZTVUa1Ls=Uy*BO!!b>>y_<=K3Ejq-)_bLACllrNm0D=%52eBu0DdFdME3+Ly` z%K*NU-%)>H6}X3EwBi)js5w4mKRh-u8K}NoJv!+RRQG$T0R;;91}4WR{ZsbL-YJw< zDr~1hStY!3Ss;f(whSn+3et1Y=k`!Yk#^_rEwt8<-Sh-mNMM=V&{l!w4Zc!O>NBRAf)HQqzhO&<#F=*i+}T`og?EQi~L;>5QtGxe>-^p z6T%)6MbnyyCb~BoV2;D7oFba;T)g|n+#B<)^W;Jo+jTP4a*8ERr3r2Lm5BbaTw1E2 z+)$P`9eAsuQNrVOykgr(a5|qK6eMw0L4hB1O@IjidZl}ilY?>N7M=T>=9T~x++7ci$H!Fe*XQi z8vZwgEh{jsZhCiQ_oAjeDeFvDcBa7Bl&3YO2byYDQ@yCENy_#l>-MC;*VLpn#s``z zR#Ua8aU^A3$=a?I_!>uAV|bvcWHptGn(Cx%XR>Bz3VcoVl14k-9%+B8JG3`+Aq@=I zMzqnsh(5IUk-(>!2|YSYqZd*{Dbs$lYql%yNR>6A zrqL1Cv^0hp0VC6pDmBMepHnvOlhkvZ7}|(SSwd7+cB*v zbc|mQoH_>_a|AM8sPRi+6QGcLe8kVGb8xXcD}P}YH1I=<`bo*xLv(e|-+g2b(JdmS zK|NhW^HzJ%kXJ5ydO-UbM;m|+!n9o6wLv4TrS&L{l<_Zj#=|VVf zbyR*eSL~I@RXl&hh1!y4+A90Kw3XGQ}FV+D2H3{6Qf;*`yR0G^7 zRnuCuXZ`r~qP*2oKu(d80UT~VvRxf2(Ca%nndHdz0ou>#U(o(2v|2+~(KUSgwIFw^ z`LXbEn#W2hIoyxcB5i4`{`q}^EBY7jQC(hLb=hHL1GFPAzag8P4FI;ULYO?<)bO`Y z;>s7lLC5uRz_=@cqbuhle-!p?h1meywm8IEaP0bh8QcIRtxR4CtwJuKT z8}m}gfTR$EC{*15{J?G6L!oWVgSCwo3T;}d3Or5;psa6na+E?d2B_QQ$OtFrL0Jm% zfvQIy8Zjro1Zb}Z3b-z#0s!Z_E>jMay%Blqkcad)PVNS!a)OhCvI)Shy5ASL+zD;8N(wqYM-^6&-p1<@nb3&c}{@V*1qnv3;?q zJJcfrRVepBy#BqTpGZDcf2>|C+8^q91UCC5Eq$!HFHAht8>cTtE=4cL$VBa;ep^`fFoP>@dEa!~e#;*3 zNt{kOI_G;*rtYxvIhhgVvIaA^JuSkZN^is)J#%eb=yf)-q_adXvAW7w9jk2s;8kl5 zcSJp`rYzLGWH3V?I@#i8*03pIWewW^=#?q19(wCg_#D#+=xpre+h=c`z0-epXl^J~ zzkN~GDdxYzl(M>-SoiIoTRnI7-R+(0P1Ut8$~v-KdGNb)1)-XUB8Q?R^YY9{>@;g@ zd|+*6tcbwgPF8Q4SwsnBD z4usW@N-6*)v&X;ldCAW3p|ruql&}UzcyC&7h<09k^I>V(&EnbOm?LJ6x$aPRFV9_G zEbWLY9@?v7MW5SSqWjYJn%K$N=IH)({Y#0~Th0W<)^DFbd%x}g*h%}}ZTQ9JU$uYM zo;>w(^32)A*ZUv5KFGd4xOl`D-T$zn>gH>+uf@0hbn{2MKHRlnUhw~{^xnQ?-^tGo zrrJ)WDo)QFTvlRbFD(;T31GtUO>BL~{9(5409$+@DnkpLm?0PobMapc4PTnE(p`_u zFzbSSqY#iiH=L1KhqMpgcJV^$Rs5>rs&Xv^kHm*TaPl}_J~FpDdZv6wk;BI#ctjzd z5br=Zg7Y`ZQoc?}FZMk=_l~3xUKUYC6r>Dkhw$PdYqJ!>&H&>o^bR6-u1nT#33|7( zr2-G<)@(_d*HR%HK`Lq$L0K3NE5b^zlvDuDCl4TWhj%oOG=en(7#}S|T!2LGsg6^KiBm>lJ!XNg&D}Xj zW=_JRT|Ca>&L1CItUTm}tehW8Z(tD#EGbNwCa4@gJ_n!^4_NbuFn=8Lm{TZ&IZ>0S z@v>)5mmI`|fSG$h;71k=z)MS7OImAT)U38f zt$rOHpdwJDduAw90to&Ze*Owj@qQb^0PLCWjC5Yx$+X@_@K9^AvNZ)hz(b`~2o_;w zw$P#Q;6tM&dM;LyGHwX(1HJw z$>i#%T1;NbSXfyl!if}&A_iEIA-aB5m}gB;C~dCzjZO)0j{)GG=r(36U|LC8wICK% z7cMykwZ$z1wuuh$87;Oa{|*x1#bhGY@ci2hL2^=N*m)T%LTn6uKgdKf@|Bl`u!m6} zR=7N5NCqtyqR2@BDYBwLl7KI`LN^rf`bZA)G&}x6k{9}dV9FH^-j!wA4=VC*C9>EL z*h#;D_VaHn*66V)pP=MlKuH2DMI}uLuj&e+`Xb+gDv&WRMyaL=THuR@)<6ix6x8PR zUPEii>^TAb0J9{hqjjVliJ4XAB+K#P^RQp>>t!ITm13C}=zUSHeP}zgKA>%|R1o$@ zD!>+lIpkkZDft&w2Cz8!mu>Kd5RJeH7B!FZ1x>U`us^{TcF;z+8B)?FuVg1sU2l|O z*M0E23%_52m51&%2tgv}8EA;9yp_h=V7xU3GU_iOm6P!gLvr$MJtF1h8?wGQnGfC{ zG@7XoQF2a-@scfT8VS-aib!UV_haP!F?h&20&4~VuACBV89pF|m>?C8BDz0AbOhz$ zO%js&Q;7P}k$`043MW_+n#nP=!TI|H>SLfxqaqaW7)eLV^sdORYr7fOeN9#D;O*X9 zz3&}O?7d&Nb-w=7#*Z8ScH=$c{m$N`tT(CYP3udh-;BJ;xKjFRc*FpL?~TYCOzX8b zV`^5vA#?~X_jx#cee6)OwE3qmfArdiuie}J;K-}&kyjUW=iym^-V#0ghu;lLA6o6q zwwpUCIsOxj97H??RS@y~3k-Nd{R2uc^IW7xAk687F>e}yM}Cv=fQlo; z5RW-=gljS|9+({SaEkr>V}DLMc8-pdlOvukDhW|^2KYx{11w8%9RD?@ z{Z~x)Z&>*fZV1_~R=riVL|DTLbRQO@=G&4dP9{kYYx5@kfn^NaE7>Q3{iND2NoT42 z^ZQEYqB+DjDLR_`KWnY%;!|}#t3ud<7XqstSwq~L&t#?P}MzYZ(_%ZxfO`nPR zHj%uTC%0i6n(mg*m4C@cpGYx@_K6v_V&SuUCGhxU6n`wO#7$|X_A31rz08LzYj2Ls zjtJo`+dn$`;ZY%c`1r5NKPwl)ZZH2`lB=(tf9w1cIn=Sd2mT&rT*2-r9Xpk@H6~mO zt&7BgEPo;SWYV@i(YN4OB)Y*^&?&%5SA$ZWB&wfk6wt|{EGs(jb+C4g<{R1>?XpgY zZEU{VG}rVciU|YLqcR4FJ`Uj3xPz&`*)-dKS8m_-M9@mF$csQklnKDgld$C00n@WJHkZs9xTe=s)1j!iv5v`@P1xcTu;OkbNSs(VYrug3oa D6DI&k literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_93329.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_93329.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5766b73c212db350dcfc179f885baebfb2081386 GIT binary patch literal 13579 zcmdTqX>41^neRQ`r$kaBby>1(owjUQ@}byKZ0oioOOE0*wo-=Xqhy_wA1NQoBbF&r zglvK}l^&+j0H&G(rek4E7XiX#3&iPmQ3s1fIi^u@-+BRIYhdAjN}K|V{@MLz9!XI& zWjRfO#g4>zbA8{;d^7W1GykH~sVNBe{U4vdQbkdJ!kJXiClgQ15{kM;u@p;>Q)k34 zeTF7^WdJcl;BB$ieGjk%e+nBN*11xuhB~4kdCJM%0R$MbC(ZES=rmtG*%^p zdpBkPCX=MR;OTN!4u1+($*Mk-vT9bjpt%L}!{NCgw9f^h%jiSPYM-kQU3yf{2ccuv zfEbnyQ>^}%^g`Y(Mx5Cil+VviVb-ulxp9qh(;DRkYm^t}E+5je=9QIAuds%-yiJ|a zvXyL+gBmRnXOtPU#ST5}ijucwXY!y`=^Cxd)+pbwMmff+WXl0-KD&vnc$+?BfE{8; zkIQ(4uHXC^52W9o+(BeiIE{Z#f)HJ!c#5ZIXo?zzUv_XO*9U5)*>x0kD+%F|01gJ@ zVCfmv_rfIQCA@TqIY!kp{f}{%^%8+`JK;AjNJl2!FVxEf>E)4~&4O&g<{H}}ewqZy z<&hmYYI4Z{Xd@1PIof#1$&Ik=4zOxSf+$yOVWT_wL-;yca42ia;3okz=Sc-3>ksaC=1cr7q&>PHMq4Kc-l ziI-t4ypEO3z?^U$t`+M@Px_}IJq4jy2O^)>@{%#=-(%o)yp*Jjydi6(#;no7 zYIsZ;bdx6+BI5%GMvrH~n^<~5nqu&tf_0#1J}(I+%YZ6e2a4hIlTh*uC^MlgjEyB! z@K|^Yq+m>*A~zgjG4KVc9WjWhQ0yt;i$^u7+Lf<6ogCGslj%eX2XCH%UCb9^s(VWL zlD`qAGQN~ABTPft&*1~`7@vp7N{DVUf)q0fLFROycX0nD4r%#7GC@%sVKNF1uaTgh z;+(^rZNfUj3UV7~n;dos(s8GaB`2K2r%q4|L4kW`bNvQEGUet3waacBcet!Lt(tLg zPBBUM33R`pbO9p)7IGn-a-~xc|wYJ{~;h@VM;;U?fWDX1=3okK$|hubCS zlf9(!IjG#)kweDC+HVF){O=bkX7vX!c$+ZCh z3HMwp;&d4G)0}Y3!A&~Gxfh@ob3@lpUYp;# z)U)cix-@7*>WWYkQdI%bo>8}?(fK&v9&aC)2aEIAi=&MYaajMXmkL@@+8jH4rVB() zX0Cr#u2<0Zyn6J+D^{`PV=RN}89_ZcJz;gZLA1IA6;5BWaZ@gGv>qSCO(a|;e7H>z z33M+&DD2M3p^;&iAQ^L9;anZ#VM= zfVC5_maQ`u1@QhPjXdRi5noPvtC*DX3QswVrULNhExh5m z;L@}paFu+;x;+_FBPI}=K+0~+lGQ|~!L|IR)lys3S_s@`o)&9Yv=LT{sm`nMxCwM+ z7PKZ+HaH$s5sWE%Dtb?iMnV94mn6|Eisp^@S=nPB^*)3t8 zVU3LJCeJp$fz=Yz%#=k!$Y!h}X5v{P#hpf<4fzaR&9fbP*I;w*|8NYA{C0jT;b>d% z3GwXU8&~Cn&1ju_U#j^+j$xTe8rC6o(zEyk)ByHXyKME2 zZO-T=*O;35rpz(HX)q>v52#k2&3qMaT*Ym0X%%0cE)}^5F0JNka+KBt==Wk$itu?X zJ#PM$e0t|!$)|Vm3|LJ&SM%L8HFf;1)toa`jhoJch+es+qJPW8GJ#92_?91-0Yih}~@+x4GP)9^!+-S#WtCDIWpn%an6` z#D0atE6Ww)2+M@WSR~cc1fV{jm~9+L9ZHbe*^!HaWMX8JE5&&kv5+`y&g;4LeVOaV z+_=*|<^l_GB|Rx9hdJl;6gGA>LnEBaZ56Fhtyt{BUakp2 z0gz-2%CV$1D#)=ZiLFiYolJhmz+}c`aS3ulC6}OCF^yefd1$s=NZGwi3 z%t~4cYB7Ub3QEy>wVo4XV=%7up;qwA42QGmANLFF9y3UUbxdjCo z7*I5sfFY62W{Y35WGV_CL8dyKEf19;^ERL6 zp~)PW3H2gNW27u@+V4G*&{~4KL#?Q^5ovdLyOtFCfZ>O4J~S3wtDUV4Rt8VSjnzIy z0#LQg9`Gxdj3%6eE|zr0z#!69{;K`P;rYX18{n(jk97O7A-r}?;fKt1)to9+85)ci zH6vZKw;Llm`9WLI9qgI4hOVH(9Z0t$vL{;gTRK{M=fK_9(9V9O>lg73eJ~J`22ag8 z!voxXWkmQbk}g_O;p`}Z*$B# z_8@=U*Y>ca?0W57ZD?zx`}3wK8*RCB^zLx%m4PqE(5`{lYl9C;UibArEH1rnnX`mS zA~l~IqMcF8oyNO8_iSJGqUNKqQ)eC&pY^@`$Wnam((I++#c=npn<8wa<>t}o@ICso zG1PE4Ht^~L%Q4@fL}6)~{ufYTUAPSuZ1Z&`EES=u*}q#fZvj#-G~aHy*%CMJTr}@N z<~`zUPolUoG%&XzT!o6ZFIt+BrTLTQPh0M^#4Rr_TJ|H${@-w4%zQQzx4evd zDhtk_!us%8RPX}y0Q<58nYV`DM1}i&-H%|`)y&qw7Viq}j2mlx9S_Y#*M?_@gF~S! zpPBEp-5o%+J*d3r-cansYxmC%#!d`=<&3>yL;bcGI~JRCBG!reov}+Z%M{heyvBS> zQI<1|uLtjEjM);Z32%-YxBEJll=?tJs2CZyAmtXHJOSBO@Q|3Q!YX8JLdqswW(sP< zG%BhOzm6>XkaC|-lTaE1eCPtQ)P?(yc`s7#{jqwvp3-kyZlILl2>KzPFj`=-jWr9# z+r>AFBTbRO78K zIF1a*VN}NA4=RJlgSE)GDKr2bgK(AF2DHnjBsG92{DvGYjjR|d>M({=Nl zIV8WKnb$0C+=Djmi5Is-hyJMST3UNJYu7yYKx3x{W7ap&sW)PF7M)^|f&DH+*ZFmU zp`bl1|D&>z_@eA-=LK+2Eb56{hk^JV={vG_^!R>i6pFKhM*%JfCHzz@<59DcAM1Tm zpUkHplCaXZl^%^>&q|1!W0xCaBq5i3#;^BF{jwoR#!VL(VA%o2ki)eBrg9RuM>|=y zPCvL8hF7eQiI=ksF#_e-`Qy>K)9!X&LoP<(z77;@hL_-x^K!rl#z1yp^ zx(8m#t9ZD#_v?Axg6dY%1f(glBe6jS=Fu?gK2 zD7bRv;$11%5AfhhmiBcWZ_7}Kv0m!Jn)^nt*DLBy%GPn)#5vZse`ae^i^{%Nbv18q zC3>f8KP1G)?_)$eAo_1F^`ZHC>0BvX8A!uFKy|J0;+Gy4S7a9@iZ^~OV~RHsK68`d zUIth}&SLw*L9U6o!%!l+z~w7bTJb@z=+yNzSPLT)^(qmG3~UBFsGo3w?7l)w49f>4#~F!hdZIp$bK!J8zA`=scghn}oWtdUn;E&NAak!^;7S~A!chlqt*~1~ zYiYG}T#ydKjgmmyMO~jz7au>Yw7&5Mj&V3B61sRz@^U``n0MjldIh%BU%XU8pLccG zKkV&En9QI&icuk?^m*Q%KT8?yCU57m5{p8tMQROKd;Ps2bwNW2<}FO$_T2Q`9r*p} zFHb+v_WS8a$$CSu4b-9Gi0#w!cg{c19`MmiI(@+CKN+~-KkMy&BsB!8kkk^SW2IZe z1Bs&YkStMD9Bc~O;B#F$rwkc>p$QG#I6Z$lZ2RbJqNr}Ms1X%4MrhD6W4k&RclDxO zz4zMUyZRHgbvLHxr$4$Bsr)4$q($HlHBrMCMV}Sjt^CK*^&1qGU)Rm)LbiBOUAX4< zwwv3cbbQOcM0Hc@zb3Jq}VS!&5(>{ zir8dBmS&j+Ng4}0yBGyeUv|i9C4;*RDfY&Je+q<>N6t$%6fehVR)+lxx$bkZKR4Zv z0zdB{{=Dx1m5Uvx=9f$KxYNG9T+gmj?A4>>mAR?A(!!l)(W5zhtU!6}vB(Om1aGc- zWePv5!gmT8qs`^R%K4xOgha|euHu!%pQiee1he*a=#is`)Mz~L-UA+;QSA4}gx`*6 zz;_ILcLw&JN5!jhQJQC1IZs zkN*e4@p?vN&5ALHV+xr1u-L)&gUQ6w0Z};gK{*k+qL_FGXED|5^T26z-UjX>>BSMp zB|%0W1PD?axSnuqyDZ4;AhV{x?c|)EWJQnDk1;q9gCfjckYewOAOnf8vzf>aLE*YE z?QqOE>b2ZYakHO61UL2EF!(fZXY$L|p(*_C>9Q4fgg>|^(QzUrE))DfY8Tf?0Lk+t z$3+Ktc}UX9vM%iHC*mO)B|RCdo-}a_(C8oF=fc_n$c20Zwrb%aa}Z&!{IFr82MUk|bfBuN)|5`5i|?++eB+Kr2v zTBNBB)8Xd0X1lj@Nna3{_D_anNLL+}Al=qQZ8Op~gT-|&uHCz+Ifyg|e`24W@AzbFp1-RPe04mW*tIAWOZ z|4sE5TR+=+&-RD&U!8xTI|XydFPJ?RxDYrttGcG0)qZ&5`q{a&@%(z2c%CWH7U+b! zPXEkRkKdE5%d7VuO6bhNx{z&d8`9NbKZc3)SsZGbD+^tLe)z}*RQp04KL$vEOvDc% zHpR6YfB-PoR)1@diEB!|oeA(H><(7NH03dAIT0RY6}cO*4qO-_d-AwLQhaGfgvWcp zKGqT*LUc~v{eJS0xSc0e+I84{1fk6!bCxJp2KOv2j7KM@*!2z9+CY?fJ<=$ zxCPc`h)xZ9Gv;M0r-1N8q9+To{`0o+A@M`HLbS%qzUg`}wyHdE30x38gMT`&ICiC$He@ zoBlVk#`>%28=iSj^yC+>fA;!=yk762CEx~U{Ab=Bgoi-bZ}|78ee^?vIWYb6GGAN5 zR17}Yfne2aYp4pDY8Js=W!$o``}RvWUy2*|`Z~aM2g>t>VCU>0ID-mnzzbKfebLm6 zOwFI{{`93gFU3u5zD}$yx6R4G*H*OEcQ|1w2_6TP{IyU6+R%ty!Mme|=;3G?+S!fF z-QdwHDhs+pZJ{fnBdDSY6*WapM4O@`(N?tk5V9NsH(q7kjgtA2g|6FuH~Zq1`+SE2 z13?;h5jc&s<%#NTH_pzV{iN{I4RFv3yxV>CMYVN?YJP zDtdV`)mSVA3I75o9`G;l;YglsiNcAchu~V)4eiB{%aa4@Y_b=RbRB#&4-ZMd13p>C z6B`*Xfk%`{UNZ>$+avcVJW7u$`3ykOCuvtmsFuGXMJpsj*VG~TVwY4_ALT%3SmIz8@ot)VIuh{JMn zoCB`zh+g4p=KcwKdf)1h{CAiZrQz=c)=NO-?ElcnFF3>nKV!C?OD#Jv?gyq z*=ImY%yu3xLj<(MM25>n>^^x+fl%Q785fvDCiFI>g*bsjZ4+FHwNbimt~L!h?+W_`$hq|Bm^$sM==Jk z`JIZlDwdRmKKWH}&4ygjj@ZDlm}3Z;hhy$(AY(hzft~HDPA2gkTXE5yu~%M=u?}P& zin-i?tBq+VxH?whYJZMaT(tQfhaJvlM~rg;s)I}$LDkMIYvellH+o7{^sPa~Fv|sL zLJZLKj%7Vnx%q~6Ub~z}viYg3fn<%Ttchd`QrSY1HB)8f*IVaWmn|e)M3rv1-Z0nj zO)<_ANe4zj|>mhYseTm z<1dg={cIW4FX~703;OxBW1~6kJLpXX+E!(C&AgKgnvI%cGwVaY)A}efG+xtJWcrQ$ z1Xu7S`Ka)rN?v8KmjXGL`{at;nxUdxik#<6Z=e^yfo^^Sy(B-q&&XLm)jZ5}th_H9 z341?NuDv4zJb2(4fL`%(97L4w?2PeAH&%`)m~;`6dUq$aI!R;zxL-)`_yGH z$vz#o5qd0u9xL)obAQYE8M#X6t&rQq**_&O7Xe-rWwmd*LDp42LyGTyb{L+|2ofHZ zC`X|Ol!!X5#I&D!zoW}>w)65)$LUu}-LI6oBIxa{GlWmika5jYZ>#vI%rY!IxNw$TrBl9Ny_ zYz%-J1Z!5dLf8m06#*SV1b``6iCcgY&`>*ep;9oTc5=ppw6!X6RD0wHj z>^!jCCZSBQ;Bi+A)ob=?QRabVHw%PN1*ahK8x=?NsnG5bDAX3J6*i+9RF|nSC^f+v zfs|`kDzPK&vSAJHNv*nR!NH_1F)w!-6Or?cLaB=(NZn6v>i*4Ia2jP{gC1*DXK)X&gS!- z9ohX{|JrPe@jhq})UVCAFlz&if@XDX#)3>)o+hDDXb=i;nFakd3QbTh%9fkfm~CTr zwz)dmxrS6zRw`FVyM$f8LPxt+>j>rwh{55#7HrO=iIv*jLcP!&06gXlXyt-?gk8e! zoN|k>XN{5N*+0ta2`#@|7USYhv>L(pge^H*wPp<`n-O*h?5vrD7zJZ30WwM#)I-ma_ZqNaMgaABAcUwR zC5`j?ppQpq9uH=RC|b^cW8B@ZGypDwX4L5$mb4!KfSc!GsfcF?b?ZpphF%4d%0D6* z$9Ydb?;Ldua+1cy4PKMfql06T&dEE+`rVQ?;O0F(N#8f(ar$@dl(aHg14RzxJ(6~C z4AIq`$1iDpeh=@Kj84C2bkOB+a$Ma3N$Yfs`*}%wRerhf%jk1CN8CP~H{dQ%mPhU` zZ$S@aj*`LW=Lb2r!#TQ=9nNO2=48jRRM!fxYfLg`TDS^&Sjmq`v_cPz59by@qM!Fn zj5~aTFe-FzFe_WReq023@#5NZ;|(BsKc;SNOwL-hc)a2g}Yo6qY?=#@0#FYggRsf z**7$c7SG|7sBx!j*v&bvN(NP4!Q0LQC<#g$$_@G$YFs~^|7Z%T3YsA8u(1~<$d8}6~t_yWNHCv-M;`_zY z`dKDrJ`g&UW=dnb;x=*p4w0!39Zj=E^X5$=yD7nlOmpby5?dVY6*mB6~1&EdAyh^SW58Xs(6E*|IO%?{(hoj5`yylx>g5?g<@3ORi-qQ4^;_u5JmoG3!LPwuMe@%C;&h$_B$DDE7 ztX{OYq^$2PvU|g1y2uvu##_aThS^%NVy{@V_fPHLcRuc%b3U=9S}%x27sA@K$r|;4 z-V>di-Wa!wrcGg8+Eg0TeC~|trz_%3qNzHpdulSzlunn%tnt2C-5efrA7&fF!HpZ<9%{R`3{(i;mKGA+aEIEM48{##Il5f0;()$gwJz~|K z7`W)iF?JOEug28@>uhirdZ4L`*EjO zQk!TJ&Cu79sVsJ0G*!i2Vo@D%6>o^I15-?#6-|wx$^!G$KxE(!KNFY^qzbCy7ry@R z{)dnFZvzhksm=Suf_+K1$acU=QS6=K8SAt)ri+s)Q2Wk!_kio;>n{J$YPgJ~DShY(68hXTWF~^Y7ZC{^)U$ z0n-wh%J{{^xj!@|E{dCXCrNShUXk7xIImsDndoAF-!d1pQ)`sL79!P z(Y}~FQT=DC;RWCVA(7W5y?4?**gpV>ET3w8HB|7es8+|j{S*9{3;eBYLPs0Eg@@oM zzh$^(oFN2aD0>&r@)1fz!lyb(d_oH3JMa+sZt}g*&u5097Fs~g!>WX8gSyX*f=wc-f{^l^2<*JbO;HRSWxugL4ykk4Pg zBClFQK7aj+ym}4!{PipHnunQNBS{GQvqYV~n>T?X1GY2xSg({#xzab_9Cufa4*K9c zbPYUENjlDbZP4ZB*8{UF*I1D`peA3StLHzc)(R^{NC84iXla$JMb0H8k9kJGtT-Ix zQ23F%4m3GzsDdeIx^p}=bZ-LJS0^9uyd4$QlIHZ$b8Rc`r4Q}D%20@R!ds$RN1UTq zIp@K^w#@0zxxEw4F+cc#4(vJjcGkC!c)&689fXX`905^Z3q1ZiL_DPmrsxPA1#F3R zi>2FA)b>Z$zWw;Y$H}H7H+N7x_(7`StVo?rQ%v~7Na1s>x=hErKo)!R*o&{DBzDjT z?g(_-KjE1GKMt>fqy?YaH6UrhmGR%;hma6j1MHdr;$RQCj^;T;bE8KedSoF;_M}x^ zli)8-_{S&wlBOS=>Kt95SYg%+JA7EyN!c4U!A}A|3b6P}!3_Q8d4~nQDo&+&mI$3wZ&Bk<)bZXs8{I4w{XcqRwgS zXRY&^mie-ll%^%z3U0yFmB^KtJ=PPmCdhB82UL>$UjJCXz#a}Ad#O<~y3moAMnYR0 ztrj(JIG;Vcc!Ryv6_23?(RsCsq(FFB(qaYgo*atE{s|3+3won&fiz0sp6MrqgJAx2|id+3eMyWvSg34w}#h@f-$d7;TTg|4+0MkM3BG1vkB5rRs%&TGXi|dLO}+B zk>3zX0=o&oS_N5w6{w-oRfC5jr&d-&Fuc9KPYuXvl)*m-g-E!`4~}Mn2=#|EndL*7 z%+eu#BXom<33UvQRLqaVL(<{^fuy-M=$@4HP6#24I*SF2K)m z9$XB9uS3;(pLfFT4!G+I`0I#>{3Z-qI1nIz9>LdBpQPyOR+(V^*z(8(p4@AjcZxN_uB$+~fWtufj^+tMQ&Twyvt-N^*;7^ri3!_z$i!lv&q;MNzOjBKvu9!A%x@(Mg zBrGYW0bTnF!yVDhkuGS$SP-@R&ilDNb}kwKzg;v|#m{|X|E6bFlh_JQtZ3YwREx&_ zi-va5(4I0JS){r|s_RMH6Fx zX@-rsqGZ?=Ifgne2z$f+s6FhDTmJ#*;^)5Zz27@)|3}9>-8`>pmXVO0bZJ1nN_m^C2YE zt2tek0|oO|C3Mlqq97@#{$Ki02}VJMgWuR1Rvcd$`&K|rC3{!c&jP3=z_r5e7J@Pi zSQzhQ{Mw9Y)#!oQ1<0Wr-?$r4O3%Bo{|55EMI5hsWs*PN{!m7ze_%%t? z#8?Ku;Ec<~FB!aOxf=VX{2k!<7Jj~p40K^n>yCS7czXCp+CHzb&(rpFVbRoNWb)>T z(9v)c+}5LH!~mQ9EhuKB0Rw8MY!MqYH99?-GE|4!K>#H8M9)U*7Y#PiVEf|yy^rpG zlrn4!wLdK^nff^LajYrk{A%Ls;Qe5-=le^KFD(=vgAiAVHFP3!7Qp1}&C_X1Y3M|_ zCwvZI=G@KhCDx2xz4CYucwOw)MHXM*?X#EXOewZwk?j&$1Yq72+kMs^5Mefsd~wTvl^Gb>=~*oa?P-jrL#%{H*r0IUl;U0`k8yIN1y1 z2|+>WL?ZRRkC_&+z6`5w@;2OpgHMgZObj^3l_hHTcj4bsVfgeX@~gZiru zIgrd;b;zM)4_vP#(jl>#Xl&++f;P{@h~MW!x4GQoTeklad0-x%cf0(`4ULZj@4v#& zcLN?FV#!nzI{6dD;)_{NjEo+O@A;$UpErEBVSzdnYI{pf&92$DxvdK|r$x;s*e@;7 zQ*rxChOpE_gw>MN&)J>?QYDu|r)^CE&(5F+Sp!mjs@<%x3+fv&3fdM@6c94m4tC^ zO}0O2Li*~?udlap?X6&Se?3eRhIKDH$SRd;`562MlDbuXvwT@~k|Z0KO9+T#Z=c@2 zY{5mVQY^*AbxN@e7t596dR(*-1x3g|{p$u)e4!>(%!}i=mlGsr=u;7;<Z}r~neZfNM zc{^E5?vCxg*Lt^gnSkuP5p<)ry}D^a*9 MRZxAC#yjHw0P)K~sQ>@~ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_977481.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_977481.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b7ea9825e6090edcfb963b5441868739565fd1 GIT binary patch literal 10397 zcmc&)Yitu)maej0cG*?7W5;pG14u$hVhBkG0^t!JfrMu#5TKierq}DZoZvj{PT?>BP!S$-SfZqKmcLRt`(hc+aVky?KMhB5+MiLJj-zksZ>m@5 zV|ax=Dj3C3e4HWm5Wz8Dk{mlk-ZCnq!#qqy_%(Sn!zD$^&C8UREK_b-rrcV*e9*v^ zzOT8N9a#CE({r_)%|i^^if~lsVFopv{e9{jt0(jXSI3n@%gQ1xS3!Aok@AXV+E*^q z9!BJEg!&Pikt168hKB^6Lbds&X9-q^E&9oV^(=Rj_^g}8Hn7gAW{OZ$5-gte`xV_ z4dPizVHoqZByi--Y<+mu3N2N+mK6mp@idnNff7n@X2%LwuFzI1*m7;F1iMfbq6i{f zEvyo%LqH?ET39XErLr|aJbft$mBN}3%wDkP`mPn$K%ccCm~#l)gli-Ukb^Qzr&g#D zN->=Re(*{y2^^W{U7}Mb)C#sgt6y#&QonUVnNTkAT`yD!^?w$HdSx~eh4lg<)PZm% zK2UMQpbDfiA<&o(t`*kf8eEsH(JM9K27#1o^ZhtAUIW3Pv=P=V>CGrL;l_gXDI&;) zQo%N;3g%XAxQI}{rBUV)ZVH-}rW{#7x!9@E3Ay=W%u%lN+&Z1$r8wSEBSSm!8>fp9;I$Rgt zQlw|Twq0luwuC?gakR9)TZPR+dqIAiuvOS5-Oo0OT8Ge7z_nF55&5sMX}Bz(Q&I(| z%&wH>zY6a{$E){M&RGZ`PPhrqLA@Z~Bb;v#8io27ZVoPOjFX=n(-1<_tB)Gq{))0_ z$1BRBoq|f(u|$%v${7%LE|CRBr6#=V2fgH5wX}y8-!b7kn6$tuyk7#fO{bOFY)5UR)-6G|}I$op( z@Q25kO*91Du2D}wD%VRxNOcyd-LG} zEQRcFLrV*MqAA-#W@2bb-X|J{c>lz>GvI{=%fS#)Jr1i!LF5FUFxbIKU_?p!#1bXx z^*JSqA}te!j?AE!4+OJ)*!*}vL86)c_(Yw5a4>*ZLrx69tYydspiR_Dl_Riocqs#- zR$6YrQ!ck#Bu7Mg#OZPeCtRbV79R`X6k19Lgmj6z0jJl8_m`EaI0F~FgF#WR@c6E% zm6woThc||jtLYFdjMNYYgZ}NJX22O5$2-k=E{dc(%P8!L%rWeURswv$p2UPZup8*r z{59c!PP9(?`H{f5%k62s;Bt?+f*1V0O#nQ7!PWrpZVh_I##WdBgFF-f-z2Qifu}e zX+ve?z(dwF)f?+g)TY_$NYCRE>x^z%mmm|J>5`_%k%vq2neuf>FDh@H*Vr;<>%6@I znHwU9GuFz)d1Q4&jy){3$1f*$q4JhlHeI?q@_L3XPi#q6qpD4aZHe?`j3x7>>yU9> zibZUDr00>TG(L(GgY&X>a6EX8D?u>ck`}X^0%j7_ITM$hhV! z>(>?6D^hJKZ@P9nGH#C?d;~RTKJQOz5@%4^`jj4(wIO5M?7_MAUt8un?i|18MeWCt z@p$AQrZFhDoc5#+4rhIL(;pdiAcdGn)D=Ke`vKd=V@+h*kMD-6Xr8DKz<%!CqceWKZ zbfJo_dx!3K|Ka4R`-f2X>G?B*sCy7q45lqZQEhg6Gi+JLRGw%-rlzNA(&%_(w#S2? z_s1uZd2O-{nd?&~WZL$WQW-m*sWlculv-p7>zX8qEVaOPNkf#%*s2oUU)ClLAzOX& zBC<85x=`t+DE-JO`byLrl=-kDNoc-uTD0gk|v-*GoSEN zmtvRVt{HBc``P9%c3#_=wl}1x+w=|k8|&T5JC*6C&a`}}Gv&hnp^*Zu3- zxkISE``%^LegYX!z?m|Yedm?52Wr^>g%!`` zIDH7tO9*yvnSEC0RzY%9gX!b3ZT#eaAxQZ2G$2tm=LHNf;_-|_Vo08F%RU*y zu7<3HPFoz{_JItB7p`%SsJZ0zOo|#8#-k9s-UamOpYQ?Fa*w*k#(B7n&=vrv`+`Wx z@P(Ry?{?7u&ZMAsXu>}M&^2&z!s7{f91K5zN5VotB%Kn>6g86)@kxoegm6R+MuMVh z@}j5(NHO98e1=P$awdSUQvqDzotT{Q;)DwZ40$9J#N!kHH}DI*2jV0`WEgg8Yi#Ss z9dXx3o%5RVXeTb|h;>};j&wyXW@tSYu2DX|GjTCZ*Me|`>TR)YaaYn7-Iit=BHcww zjFxzJ?Crz>#ID7tz!L3=_r!V^m>R^?q?x(}rU@}kDNAa9nrU5NHX~;9H(T%Sy0a_I z?2U9kV%X?JY)3qZn5qRs9WvCV4fP9#4al$oK)^uS(6*r8iu7B*@!TD`Gm_TtU(oj; zea~+^za9A%WKTx+KQfiXEpc1CD{hNjNYo|z6Q`48;`Ftqr0d!yP=gD`7G!Km@w07d z&1zCRuKVXKs9`56-L+sngsg|s)+3Q)8O9Rdn%I~O&NCb5H5+*s z-XQW_QFn3D;~lyXNi(dZ6FxG4 zL?)Nj7`l%bxQ(RbGN9om(UP9g9fel{MOJ{up(0y?Iju5W$ZwYKm?6dtsz6_0thyuw zbB?ECuN1fPT8Z*m1gfZB&?r)gB%+2WjKopz(_wwy(Ix05NlAc~qrmW1V}1l0_`z1) zlE7&dYRlB?a`og2{fpP<`%^3QFJ7PTuUesh@%ns!^$Puq*XR3dZem*)0!auI$XB{g zGWbJnT{T{x&%@V%MD_RrUcawqY$6b>0o%oWp=QD-%c~ms(AV(($-oVjsO3DDylxMV zcZ#os$Xy6{@S(zkJ;NTTzgMl~LyrLr;z&$NQFjI2t|IU#28WZA{JBbdeB_}_Qn7<7 zn4G4sz{tj;7IcSJHq!VSoK=gO-k#H4e9lu4z&Euf5|ITl-3gjufOG8$ZOqzmV2BpV zW@lh?OXsK`ybpmrkdPby7)S&YdH!#Ml`tKmHScZ66hAz0(o-q_v;O!d5`dY-Ay7)?`WVrQ-% z1eF!Yu%%dtU+szP2Lp{ZMEhe}thNA-f3$sGb9lb&a9VRXy8V&P82P)07W+)qbQOSc zOXAH`*X_eM4*#;}yWVemA6SmYv|xiwu`xD&>H%HxT%%^Sk?!X-L0RMVNK+1~G1?J7 zhz%34PV~{Xs0(XL4E8M_Gx7eJx2E4pxPJO}hBi-`Vx~B#&x%CH7kjSlN!6ySnxUU$ z2)NSpszk#Vo33q2k?G0}e`&Gxd8vWYF4_s!8tIkw2@TE_2^>iN5O+?x!V+RkFKLf| zf-A6+<_KdepvYB`ED22*P<>V}XkRY-L2FR(Jdxw0m=1uEiac!GQspcX!0h1mg#!TE zNCRk>`T(Or>8Yg?CIz)x$=R*6!RT>Gn*uJ2JdgmF#d*kC9D?*<6Q)6j(8}0YzyMZO zEh#Y+j8){tTF5`}F~OV$5-3r=fE$2!>*7u8b}@obhvlyrPW4gS_VoCvpz+gLmPRv*3gWz52rH7{L}OGJ3&z zc!{ke7^No*=>Z^Y5=%u&(}lbF!NJ5Ey=reIECdM{|S!r9%1k^Vua71f?`n} zaF2>)2#7j8?~c2C9AI_QE9yed0m)2I4UPw}1)Gk$vORy2mADGwd{*=uf`;^?TRfdr;Cb zw-c50fRD~l60i8YD_)HZYa$0Sv@z=a+!7r{bXBD5vBkb%sl`sDhP1^IISO3G+Y;pT zwx7HMg1NB?TwGl#*DQI%o$5)gx_)K$;xD&;*ZFPdee$38F6`<>yL#`tet+S27t*`V zppG-B@vZsuPSoIx8a{1{yMT^6(U$1Hwk1hkYmAS?_xypbTT~OKszohfu1HjT*_Eh9 zrUq~|8cL=dF~?^;GriNjX+vG|3^Fv&s*!$6WPgS+Mvh2Q(~ZTV5^`elIh4SkYsj$r z<>&z%3M7Ee!Z+BQxKoG{L;;l5tU!+x;7a`k^r(V&cuG*`Uwh93A7O1+7shZn=d2Ml zAntSkfYmp%usv+xD8b+*!;HYdCDZotcmnY1GQqj{2qfVuNlKmxdf+|4?Fx9DZ4N36 z9N^)F^GTRNvM(i-4M;%(Lip~>w?w)$CooxZKr)J*1K7DQo3<=?=Ot6tkM9tVTN$$+ z@GPv*!T&BMMuUOCiwL5oM!qyQnK3`ld)z_k>HwbNzXC$H;TO0JQG|G8u|Q9j1gWdK`+1=H5RzINk!jWCRreVu$*S!r78hVjtU8Dyz|Le)x=PXgy0!O%PPIs^R zekg4_7wHAxsoowBFH#z_HgX7KmC{I`bi`l<@tT6*!k<2en#b;{Cj`X&Ki~?Rw3<9d ze4<&Nd&?-ibWOl}fk-PUehsdtT@yinaKh&iwFjg>r-unPo1;_zA276kb1 z=S4MMRIEBg<{*A8J1rOMkNC1lY0E^dT+U;eBQ2fC$cyeA@{e*i2pLRDZUwq?%+F1X zdiL=HNYU9mIC?8ODW6CE~{SxJqszsf~sQr_H(67pr zR>mFA&9q9jXeCtihm}_=7gdKzvSraml$OmjPd6`?WwWhqw?}V`=Ca4gUh?Zbs_@yBe{p7j-yKLw@Rb?D(P{=M03! zw#eX=iO@0&lpRs_XJ*JhJxp3jQ)2xW&DWY22}sUTciB6vl)U%GZ_obO)&Kwi literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_991002.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_991002.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcbef44a43649d32a456eb597f751f61d480643d GIT binary patch literal 10974 zcmcIKTTmNUmffvaOWloLJPZaK*?=v;U;(xVKX9;(iC@GRCw8!B6rmO%9ul;KiE*n& zo=VD6l3B#gY*|iiRpe}{B0shwSEVXsf0m@SD&F0XRgb8F?e$c-YW(9rGWN&reC+PI z-D)8dVdCuUt4}iSdu00>+h*H#8Cs* ztKz7B0*(e&d9`&6KVd%x;J~Zy2d(q!IP30bh#bOL>QkJhhw%H1asqIcMFq~R$fF)H z73epv({Ea*-@H!0rEveCfwjJ;zL#BCDf{($jICyiJ=jQb0h&tI=`UTUzigfUO&*FZ zFFZC#vlZ`Y`e~qIb72Tp#%2nSzDh5P530*yn}b@m64oqw$(rl!u&VHSSsDv+5jejE zRx`3TcI$h1zX@KmDaU{FRouSqDPn!kQylWAuc0iAvTBr(kl8h&uJ7#W6X%_$MEc0- zH;$fi_MPbOa-Mmi)$>9NAiK}K@#fhqJBiVWuoc!JKL-#O#q{G1-JI3%t(3S~$@ zafn{`Cdib8I*>?u7S{_9U z4Pdo=aSnF#2CzE5gfE-XU|0^Lk}u^qqn=O|zlpD!fmK6W_$_>?)VDPV9FP-V!Ec=b z&he!=+-iO+;8f3OF>D5aLN;jzXh9#EXB%(ht!SP+st|A2Bo<%Ytu)Veej8u>qwsP& zm+)%x)+wNp`0c!A6zJl&!uu*uU7UE+8I=g1278sN~Ql87Q=sj4$j=p+hyL!_bgP;*RS1~#EXVJsw}z1 za}~bs2b_}6u7TH~^O%7*$7=3tbLZH&!F)UTEGCQRVeLZHA$rzu&m!I5CSJui@kI(J zf52m-G7@U$)c|Qy+G`=Rr)5^gL%TPiwmCC=UVbg8w~pVT^uF*)(AbU`{?VIi>?g{l z_Ma%1_Vews_1#6}bb}tLysLTQ8DupoY zR1XIa9^;V2=8!tF}|))rACkE)@` ziYnG0gt`12lH#H!05y+ja0UJ2UboZ5vi1&9<3bjRs2M;7hn|3_9SBH$x`CMqsTT=T z4#_{xipKSUAm?R0PS^Nqbu?EU$gBEtz3x?5w@)-=XBYr(tk!)ZrS#8CjOI50V~F!l zO*jKy*zS7fk*JygR%sw}&d5d>s=~s^9z#_iN2!jACZ+22Ii<-(QeFe*pa#8MASmIC zi}dQ!Fqt%cwt&O%4-N*PS1w@#QhO9ELC*x76yjb0_$ikvm)k9BM&Uf)6^YSNQHySa zb5z7#B0eDMB-Eh^Q75xMAd*@B2Sj233`J;VQG9Te!`{K5NJ@nh22?)pfJi_^JviZ1 zCIhC+vjWTI#eh=D0u4wPEzkkSS^a-Xe?Hqh?dL`V6E3%>dD!J1bp?n0zFlAk`-05@ z&fOgJj88N}tJ&)dwmHXL!SSiFrirUg=yJj}3i^Fcw;wK3d%I)z?iLsk85zhq>fwBz zG436p5t+e(60kwT*z*pIVrF+mcg69Hp(1?r5nVKUCVD1r%g~kKt|unTP2HR>j>ivX zOgqCT9(z5T-Beu3!^AAe$?XLm+-#!e?1 z)7IuMxvyp(%q$v?0RPRUvB|i7$y^iBDBDYP3rs_VST> zh+jxt6RKJSYfEZy{`|L&`3n#H56!~flMhFQJ#P!vwY^tuAAE$(MlT?%(hP*2{y4$Ff{-_ zP0cBrU}^=XSj*#`bNdtcTt|d_Y&6BXqBHS6p{P2dS}rojE=K(ctH9I%yG+*Dz#JJr z68&XFy-ZRcakE#VS7NSs@h4@s$`a3g{wl_S%EnV^HS!&w#ClJc0>=y-%1?2YfE(Batc&q_}H?pC}#T= z|A;iql2I~N8n-6eeoyXv0&OPBz*})A)xRfqJp<{0iTJysZgSe=9U2ab+JM_N2A;Tf zQqqoGI|9&hp5PScbA#cROQlegoaj$Md*2{$=x<^?Hj;Ctb8wgvR|pU%v0vdl{v%+p zLXP-lSmy&Wq~VD|qb!aT^#^(w&8IR;hC)M#;8iajSFJk4BL!U9C>hlc60pxgCLK0X z!2P;UDCSW}8zCcr&8prbL%JZcTBOA55ar-H;=C3NI~7_30tDDd&JoK3LcyvPLwTL? zT(AQ8{Br!l|H6VdgdwD7KFI!t*eEg-#d+|N(cvjxOP&>vOxiUm3Kxf!q5=v$vdmf4 zAOYm5d-9Ak^sd0>)yi6)+lrbhQGKTCd?%N)C<6$pIvnPlP>STiG1vG2>*|5TFKHA^XoT_J}N+rJ_`;pr{@K8#<2< z=q)Oi*lVxJ>?mKC3EUL`_rotx1!wWWm;XCrYe^y%Mw*BP}Pgn zwzPU%n%cHZ(X;!a`>r1gcZMfHRm{>+I@TLChC7!T>k_j?V76qKY9#3MTaOIJ*@kEX zD0y3KFwyyW_wDXf=U1;kczw}uD17uep`x|nW6w!UV~N!W>M~$#WMAwk(nO$+^btqI z1%)KP%^y;+-kTTaF2-HIxwK3gXN#gmvB?Zs9^dy#$E^;qc{VqK-kUW<4Kc@uOe}D7 zW^N{7`%MVZiUhtZLvD)KeX{G;F3>v__3L(9o?8tX?Mf*|8^b516Bc!{W(c_RuQc-0 z(EWRM7xMKae=vtJQXmMbNHUPg#0YwNYZ8+8Ayr5XK^j~NRN%-d&1w;11i3P6k^~7@ zAW79ZsjJ~NBghC|lhAjnys(!RppY}UCar&W7 z01;#XS*+-!k*u6rodf2_S*;=@HdsX;MCTx-Lf>#elSmZHbSOOEaLs(Nn4LZEGz)}# zl$k?@JbI9guqN?_)t!Y5ykVVlTn$L(&sQLD1ZS>hbts&dM_-{}1u;?}AaS3(m%W)G zDu~3alvtgzn`79|+u#<`JOy-10T-TzQujUxZ6x?vJue4d;0y}cH7}=0S!D+J8KOSg z3xO+Y#Fix@uTX9ZPrh_i|FTgCNUK5381jNG=Rw7&STA_GqaM!$RI+iErq?*cDk&Dj z4Wk-*8Fn)V&ky4!QC5R8Bt%8^6|ZMnRJ%}o1#%adl7Y!7k7owrD-d_^K~%vF%4R~; zOo0~qnMe(IgVWvsgg|^Ow;j!(f&SJOSsB!#c_$#VGu$+)BhfG7(~|U;;wU09?G)A1 z(sI+VoUA6KiS&xTV?`}E)uSF(#JvGTyP_SWQdSU*0xT&W+?_EwmQ3JL?VHEv zTQa7OB~zzh>innL?;5^o$e4PNgI}^_0!Pc1uqXQfI(%Z;RDQE)t|&g4uw_g&;S*07 z%j~7-rMNlXnK0ive!upMhC2=Ou74c0N5ge(kc5qqnggW7n92aYC z(w21HsTHidm%ziaw0-&JSC<}K%2+xDvNNh(QDH@;k16vKRVh%F8EQ*HyHxEEsvVi? z*3_Fzd%A=@U70=I4~x^(nY8+hY-*rGcIPV+q%`2+VbA^>4oy-*&{&%!HKGQ$LCxY) z$SB7zaFVF?5YCUya}1cr41$KbvyCrEUF2O@uMS~cB6WH6^{R@DWFnidcP zdYzq)fp^@Mb=+k$)H zH4022)yAt>9f;Pf`e(raufRe$sQJItTm&d$Li!pZohb-5y=>k~LJMqs$+Scg`sx2b zk_4D6)DTT}oV$t&y(Ge9iH0uMDi(JQHIXzDHBxw3G;^K_k1L3zg@+sSxUP8Ylq{lH z1S+9R!iqRGnHTUSr+isaq@8vSj=6%|8wmU&WC6sCJjo=>B}zt+rP6N!;xh#LXUIS* z(N+|&yWSN(g5I1V(i_!=J0BT~XRm<(al~B-+bvJB`-`4CJqvZ;HGb2$XnX^raZK^- z<>=-3whU91|7MLPaa-JVt5z^pCyok5J0PgCt@d;KZTsIfL`cbK96$!9EnXAvjo18+ z+`6KHN4uiMOq*sy(NG*c_s-jzj1=)>tdgX##CtQ$)+I)YSky1r(u^a+99uZI;7T*E zgGi&zvnQh`V_e*kp+T%Q3v_eJuz;uO<_vvgp=F^nO&^D-AYB}r66nf=R-o$A>N@H9 zyU|NmQqKH*4n6P(COo8ix$lMCYjC~uzk?Qd`|yghtAx}W%m=smmFCl|1{T01%QHPM z105l4NEaeQ`e4qeh3}18xZ@k$g4+P-dF3_J=_J})aiunIgOPxK`6OoUbj&Ba3)jSA3n1}WBSrqXhowoYQxrRk?q{dQJRRKM128B-Ebs>qa4RYt$Nz}Y|A8@oz$(GNCCuT{8(Xh$eN2`_ zwCHIkY^ie(t?6I9ljZ_KX>bL@ju1x)DBd8>67Z(zs1b(A=U_I8%&mG*Y-g}wbH>0L4F8SS5nG5sb8o5$?W%Nc@Lsl*8KM#c4t z6`}{nn^uZ3Yst;VxyF?esadKt%cSNerCBaDE0pGDsac6Z{1{~UpQ=#vnF=H5XH{s8 zEz&S}8)X&WQ&qTeSx4XCuk$M^sKb{HgtMZ7I!JG5osOPf(V-TpwDhQDz|6%f2pe99 zmb!8I`sHT~w7!20x8OZ^e9tEbZyj90P(9TrmkWCymi$}gKUbzNTwc217B0Bco{=Tb zq~MuM2d2{}riG*LK9dl>ucz=G@sdv}Z&j{fP)$yK!QbJf>O$9dr@uKZRnxz4$>rGO e(%8Gg7$5`evz-;V_4|XEaeJny<~k)^|NjM^|2(Sz literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_995030.cpython-312.pyc b/src/temp/gen/__pycache__/int4_matmul.py_gen_triton_code_995030.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4f8863a1ffdc7e2096dac9f2fbb4c0c3e04c479 GIT binary patch literal 11463 zcmb_iYit|Wm7d{q_!dctdds4%7bRJ;B-^qp*^VN~j&0eAVmD3Xv;@T=W$HmmL&}Pk zVM7-wLdHQvwcUE$F4lB5fayj+l%hbGVu7du7WF=W)s#X-%xbX)nm_(0SN^fPzxLca z9FkTj%S}6gIQMbRx#ymH@44SOqkl6Pv;;i={MN6oChP?9CsdIhYo@SmQxL>0f+lEk zk{A)+whaAU}BxdoA+x1Vkp&Bit<+OrU(kfd0_ez?gsb$S4@;!L8 zwDw!@$o^s;*|!=;@h!&rYjxAn%IC0~46H57+FJiyJZLuBERXKF)<@#e;Wd4VXG}Xr z(1wr5W#cC@ag|}e7{yZlFKzND#*6Y>#4=s9OS^fOcFQj9*23*xEnWPs;*-q4(tAWh z*U}{(V!XJ3MmCR*F7@aO_x>fVUU@l)_p0c!cadE( z5`})CMmUs${>7n}`_DTsogEo;UU;f@@u}K*L4W4em%sapM5(^?%Fx-%&hwAJ`w$M9 zp!Ls9IRio0m?t1;Px3 zYvu6**Xgn{I|oP5B$3i&p5#KiOwHCyeA0V>h>_p_IjmPR+sxJ@9YZbR-127Lg6*N6 zKr`HFFC}VSfHM(l-DwjxpKqrVaoPAt#L=57KNlL zKx$EzcZX#&JJFo&e_D2DVQT;voLU<1{!ka&&UP(;=cF}gWyec$7LZEM-k(_?d3fiX zpBm=~+sPjJn~ZamJPv`N3@k!8!`rgq4lw-z=HI#w&88N~NrQj(s~H8#?~&yldz?tc*RL+4GXe z(`R?ecXs$JgVYo1`34t0_v}uN+^OzmOKJH;6I{9ZcVCrcf!v$zVUMxB%UbDtlf?D> z8>aveD#YQ?U1Z=m6An^PxZd&w7a>kaWt>! z&$haAv~Itk%}5vpyX4A#K_|5@%uM80fMblAo||z7e4y2CP)m@{fc2E{;{+WLh(kpf zSStF3BB|)}JF!H85(NXJ&g)|W!3>W+Hy%j9GBYn&lxeRw0IlL1w8r%bfRCOTL6v(? zx&lF1L1e;!pu&a*u;_BT1roNygcI0MH{Ko4jo+k8P|dj96CPOBF(3Vwpc0ovkdH3R zz)D{Ad4qz|2MaqYPzzuX+@L}Y&`*vEq+3v9eZbn_hhoDC=h(yi_599YtGHy^2!RbIJfQvBk`ji**Sc*8*Wbb60&ODWgNTN;)vpO)V*U+Gwx zOgZ{_OMm$MLqj=t28Icr-ZT`&>N)$8gJXGfOJX=>Xb+!$Xt2ay`$ZoY4KQ)o^m2?>h6q!FOHa4VORzpq%)~9={xKRgg-UH8%DyXzt-v(o1#szp?K?>wf#5D=L?@LtZQEk_kRt7I({~KYiwzZ zbH!~bThoTY5h2q>=Gf^aZA6*2RCDx4ecTw#CsD#%+E&cGrE}%lZ(Xa!$+6Ggd0^(d z&++E-5fxDW!WA$624#}|s_$vauVHP^M9;)d$8Cwgy1pwyZknyPN|#DGTimk~@ed~l_iFY1L+%}JU!^+yz&CiCKabUqg3CKG*Y)om$LJ2YnG zRK)4T2;b1nmv<){R*S!yT`m2h{lPi@JFoHOucd415`9nPq@_KgdZHpg&RBFT=8M-Q zD*63~dE4QnO0@VFrUxDTv7xnBuJGMgc-xh9r9EK=MpJ7{m$SJ0s;`H71w zj*aognz$JbSA0GZ{Pl2Ro^R|<+WE#F*t>89)W(?oN9-o0U8JH^%*L7Hwm(u$Uqjsx zHN?EINxgrh4sL_%AO!Wi$2WF07+`u)O~rVEbBy1Ou5RhcM&vZOnOg))jAvtViG(1e zqy$YOj1iEXBYs2zWxkvY@te>Vl7-}ed`OIOWqBN^4fQL47jWawhgCqNCzp^(NXaV4 zvk`TUmXNrwAPx#eh=nk*v`*utc>`pM zkfhW|7XzOk0O^ z!F`g}Fh8xN!2t1nh(bboR>x{tJ>YEZk8v7MjIf|Z%nv{&Y5`DPNcTY*L>ZcKyCf_3 z&RX<9u1zXs`}VMQ$gqb!6s%md$X{&_2Z$Nuv~MaCF|@sqDQum(FB7r6 ziuPRdxuuZe;VGsX@uP49M-vVUa}~-69J*#aOd~3*T@XD?x!w|#0R7!p1?4=%9v(pz z^!NkQ3>M3bBSOubKQKGz@hm`$0l~#IL`el^fDTN z!mS`)&}SkL5f01z#{%dy?1Zh%DB?2A&bjBhP6K%_2hFvkX>SGWOQUmG6W)SDO2 zN6&N0xcRnrv#k1#{0R>}G}KT zSJl5a{?52=cp1S?fO|35hgT4WFhX%Sc9|n%mzTP^OG_u>u7v%rhqtuyRBKcPV;IYk z;-!nBXz1oU8>aodX@AO8w_$4LP0fjp#AwRY5q@dYT6U`kaE3O~w_?73dTn3VKOOn9 z=Zl^P^Z)U~fBE6M^#|dhO#sY&vDptNBC<49%r)IPc>5sOd-%RFRneW4C6&p(q>`^d z6R#LZSMQ74Z@>0fPFB=ECP+;~R1d0{DnOOu+6`+RZ>@`ute96^pV?B@W8tBPI&ft+Vv%Mzi1bFP({hbmkao>|dVl72xw;T)Zuc25Y@ zgwyXE9i8+D@@uY1L5UqO3VHWU-rBhWk7;I z*pqniAANjz`^wQ|Jw%gDd{;kT)*tDQk>Cy$@1K2aCM*@m_u00J3B4(N4!abAd`nQz zit$En0o|~PvJc7(V@Gohj`(|;_q2P(8`%T@jF@DvXyknjIHZCuM~oic*8(zE(IndL zxmYEb)q6baD6;b)qME4ID+4!I$h{DS9147$%({+9Py2a8vugmv@2gcP{QqOfad z?5ZVt>}DYfM62(WfooM<%uDd1z!Wm*Dluiq)S{|fP+S6#3jgn~%?8xRrbhM)fl6PuKAQ5V$#API68`O=mY z)w+D`)3@)xo$N@`t0(!Bucq2B@zkX>rH{N3H9c0!%T?G1CDi*9^i$v#5M<})A?-XDoS6%PZ-*>=o)vV!l=bZ| zydYYmd{-MV{{Zv}_zRd}1N?88*wPYu!{VvvsdZh|nxbk=R|ODxLs!e|YS(r9*A)BL zbo-I3?~Zog5J9uvic+vQ2E#7`TGZ)dve@vF>iwQIg)PzpdBY6U9Oi5Z@>A+Ql_bAZ zf1zGCybwP9L?+j(!UInzLTQQB@rqKoP$J#2(+G1aQMn`HLU#;G@tAMwV#BvyUwWN$ z{q#zjGA$aT#@K9%D(AZIoVRr76o3s}?}BjnG@d7et(xaav#{f>-b2;ok%P0(r_D5(HO@5=WUM`6lQQg2YjX3Tbk|IUqw7Mm+LQ zGVVVF|E0xheI#Hs9UM(?9w-8X4^(rsi}@K8;UgA;D1t8NflnuHSHR=!a47LU$KR-E zKqwg4bR5Mo8Qydby|`=9$0YHy5{5aKQ1wmp^9B;B9(9UgqJGqgL&BM96cmz9!I1fY zkU0lx#uJ#E3;Rgalb1{Gt%7D-P zFa$iO<41o}@_F@V)$7!$aNjd}8oE~cR-4xwF7S#2;X&BLLvi~PJyFt@aIKUi)vHwx z7E;9{;R{=OLSu`Cwv-B!Dtua7_j<7Y0534m*e zM;V-(ne>S3f&ibOJB>asiX?*OtUpu4i-s*G=tNsP$EGLg`-JH9M)CN)Q`7X^q~|2_ zIaJY&68I7}+m@Up$v+YL|0E3mMO1-@k(tA`_x9b`_mC=ysL+v(Gs&)Wqa)ezzzhLh z=@6NI{f#w`7qaM@t7MKUknDdhsu!Pw>Ul?eK3ScAvk{r(^T`6L=49UkJF?dWGX2KL zns*G@%a>;_ENd64t;Q0A>d&o&rtC?3nM}6TNyw=8DsEJ4$%aU>ZOcN~58Tt=)o)pG zwOFc_;Hpik!cqZonN%&u)e54j=1%YJ-al8O>b9Ja>9@Pk7>7~)@u;ktY~+q0iY)?) zE0^$V^++0y9`eYZTa!zZ@xgmTcZb%fLomSsvU3|zK7N5Tlct1wn}FBWA)@2(V@)M# ziI;5HJ9&HO76FxXW9z+l?n3GV$`2b_HsJG4L-&>vYLFRdXu0RS>)cYK8b!3UZ6S6j i9Y48o>@0um>^7=zzeHA$*2les=|IX@cSDCa`2PcGYsB9G literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_143388.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_143388.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c54a27ab63834c4ad63f58aa3c412d6b777120 GIT binary patch literal 4463 zcmb_fZ)_9E6`xtJ?Om_`kOY$^5HR7I)~ZNEKu3;pIRXjLYsi&9hq#E-%6K=96Wbeh z*N~)ZX>*5!T|RgtT@^=Fb*WUTTtRA)sNiEiBJ~SCMV2~`-G<#=o*t1I2ju7|Z;g62|QXGj^w?)}bcpUJos|?+ujwzG22^eo| z624KH&$ta_xDCfFte^EEltBi5VT=vtg)vj7ncb)tL0}Enjf%w#XK)5P38M^Ox61YM zD)51u1-8LU%uUA)x({d*eED79>$O3nW2)`S^F8+PTec_E*=ul343FWV5ti^xwiX?) za38NVlc8}1DhhkOl!wg}Q5p9w<6@6S*!9|&zU;7}8Rd4`> zdBW%k;+ayFm>Nlnaf!GhY9tkt37b$N63i;ZJqDz{x0g7ybds>zMO7zGO;=Po93c1{ zVFo{e=N$2Bx*C^caavTB%LI=QhdiaJ^fnO(<*-pDq1k2$o+gYmLmZbPDq&Sk2Sd`d z768%Ei|s$Z>bb0_liE}yD))>>qLUGQTuJSY#N?FTqp8szT~1E*z^EtDE2b1RDPFiN zbx+NRz!Je7T}g>i1)SQmZ|~l|KA0gbVx6csDXS?tq3!_F^n2E}z?W&FXPc4W%kR$Z zF5r@<-8{0wI!pUXN4)9U%f6=_qI^SAo zDQJsrrKVj)VV61hh!?U_uI-Cc3uoug77i~pEgrqzv>g1|b?vXwyGL%ti#>;m!GXUI z7K6v{?I{L_i~KO?XI)p1&mGU6vfjykQec)g|C8-nhwvb^jdty{+X7v*f)zj_MDfDL zz!Y4?tNa2o(1bnlVwZ6MzkXe18nCOOyo9OE+II^G`5miqBWoL0QycW+Z)ndESz;#l z!{#=DTs^*5%$waB{-r{SI)gy;nHU{z9L4fpSmmv_(m%Vs9%8x=CX{F-q4fh% zp5F)38r|KuCMq;7vT6uu8^6O|K-TfdluX>JTqay)R3)CMlG5X`w360{6Tmk*PCS5p z`iz)}Ps+qAC*&l6qNv4ZW#UN2Q^ZxqKN3jE)4DjNT$WW)854Vn_f#5GCFLPiRa6`L z!k!HuHt|f0wtW#2LkD$_o;dNJ7&s>${rL3Ipm_QhqeEqIXJY{S>Q7-_qtU7X_@Q!G zy!F9Q`#|CyD!Lg$4EtYcvN6m2aE0UZe2%v|t;j-bK32FeKUv~B0c!yd+jH&K;R5p| zx5zCuFM3P-j{$Vu+|}v1X^Z*Xur7Y>{?UrpmmkgzUmMBd6*qr%W^Tr6{ybB_U%T5@ ze1ZI@xlgZ&S>_4vf55jD`L@E&5>KI9@aFgD_RkHP1Lnn*V0$6*Woj{1+TL9ZcAFy) z{Y_T)AG;oO_7yw(?lzS=e_rw*G>?Jl?EcSw{m|cjzvIJu`|pMCpFVSc>&GSk+56nt z$HF#q@FCk;V2f-AY-)w`*grncVw|E4Nbxl6vlP6T7yavDxnHDgX^k}AFhqjYq$yDI#WXr zhXDvBIS+!b-(wHT@we#xn`wKs_uOlE&%D*xznQjId(Xdy_nx=v{hMifwf6$`-avrC z-jP*$byqp)pr}ks5kO>@Bwvb0W#W|d8Guhgmo;6CMl`7WR_Uc)rS)_@ySIL@SKYN? zdS94TyD85J<*<}eX==5>E6m#1oKEOk9cEI9)Tmb!{mQtv$#$Wx)>Mzd{0S=94cKR* zNB&Lbv1eYy`}0Chum-G)OIvQH|C0GCv&_A34ggf;x8}A&8Hty=ZaeSzZuyq^eyGj( zpcN@J+wk?MV~2Sxd!op;*XvPh^0c6Qu7_EA_jd43$E}WK{)5-i^S1o?A$rz0q#C)nQnAME#j1poj5 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_167554.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_167554.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4158463d5a77e6c68f6ba8560493e1c46bbfc1c9 GIT binary patch literal 4386 zcmb_fUu+Y}8J}6Ny}RDEO#&g80EYutu9ixo3pwB@mzEIVu0ZYx*Ao|UUHQD5#EI<< zv+K*Bu3LvgidR(fMy(X{bg9$_XpmY|D(?(UwY);3aS2!b8=U_M3HVoSct? zUPj*8`M&x7&3v=J`TX}lz>lE)W%`f9r$vMwQ=-y@^}*T$7|bCBDL8=!%5OY?U7AUH z23U+#JmT7^sFvRfqp~B?MjoXG8O3vrU09zO-~^=j6;?$f+rBBQc)x}3Qw4>qKR3uL z{59`@zZ$V{4GoA4QpKuIX%a`o73!|f;d;(}aPq9&7vpN&dtvN>XKjR`i#Kp(j?$Z( zf*023Gn|P`&h$)y{!{>=3^JJwG-KioG_JACZnleH%T3?Hy2lJ}@+Lb9R7R-L@_kba zV&JEMZ3+r=+p|E|0d1s`&@$=vMs-{-fz?&{zr!EiVt-Z6WVGlo^NLeE(Q6L5+mg2NaI$3I52Arei> zamDdPNi;R2I&4CVDln@#+#n2%L!FLCPbVE#zd{Vhs~Z|oBOwQ0a+uyv;JM^Tx22Fx2E&2T&b}PC2EKr2N^q(mpmR151W*3@s(c zGzjXg!$;nJy9;KV?Q)H%JgSnEnjm|@HT??o5bRDCJ==lAK>nb0uz*Yc*6hh8)|=a& zGmFBmBD*VlVu|(T_S;g?zq`op&h~5~pMdH+);sov0#_2l+1@2FWVcyIv%Qaaf1a~A z8{56Je8ISRXp!%D$afa`&byuWj@>y{;=8ja9tmQ;*=n{MZDqD)e#gy=i^5wEg|4E| zb+`ZCz@34T@Y8J16Tg%{U>&d<>M}MvJ`47if!56WkJj- zR?F}D=RThOxX?4-aO3oEe=&dM=E*^O@F!MY9!5&zHeL!I!8bkko7!{mRhZEuu*7X3R}1Cq@)hg{NX9Q| z2fxq`HmZPH4?N8(*ksuHyIJ0yjQ3k+J#_CfK7boE0$0iAR$0#!y%0Eqet&K=IisO6 z9%-zMp1)0R1q7>k;Z3i}kM_deiAGaJQmM*ag-d7C7uj-zZ*W>b+)HHA_3kFN5@o$BWiizD`TDGk7+3*K9ttdy5j{% zj14<}KqO;QPQ*u5M^Y1N5}-=f<5Q~RNybx-uZ&7gD5Xvq@|ZTR5?LFRI~}P%4W^Rn zDMB>j0#d}k`K$9!$gY1G;Cc^T8Qp3JpLc@iPh65u|NP>qUisqt1E$IwR z4yxpJsv*^4AUOoAHhA<_7yMTJsB!yT^K5hWXG^@07c9ZvXMZ*~JUd(dkV~#{0)A-@rG0qegH_w@iPk zFER!EfNNO_gz}$SpI(=9%o8E_P-rO%Erqs{K=Bm3Qk-Ng5nQbk*zyC>)LMvsnYxiG zz207IYR~pP3O3m7e>m`P-_hc}qhB?Y_WigNJf1xRapaC%{pF)z>-}99O2Pj7eE+f( z$aMh*v&{vz$nF8NOT6Fx33=Y)?S|{3%Y#1jEW~?#tBpts=8sv&?2(e#cAsq{Vc4S> zJk>^T3>WYi=;u%zYPFhMn_j}3nt~N5Tx%+(I`!IZc43{Hsjck3zct3FLK=K@O!d?C zm;lL&_c+A*9nt#~ZVRJ-HES=9o_~qxnJwM>SF`rw=!KVv-oK^MznZldM=#RoO$4AT z^{vpOyTU^U?e35OfwL_}%wA6cUhnYWox9Q%VJ9`#|pS+Bnt%~RS*jeR~v_0GB zqGALOJsp0W`|H1-bVmYicF2!lJDpJZvTIMkUy3vo?G9fV5%-LEM&gDez#Xa^>cm+2 z)^mIxsVQ9}j@SL);ilG2kR!tF3)z*2w1jdCxsJ;T<&P&dC7n>a$yx9UDMb2J=vF-# zw4?=;f#6?_JKJUhHVtd2x z`tsw}lEWb-D^Bv}3G;M0sSjL()I*}8l`8cQ=nKYCv9?r-3NLvJm#b7&mD=yDV|U3J z4pchIPG-LI{WahD&gVP+Qy}0+(Ejzwug7ydLXXL$*0`;~<^veqKoXL03XN9(<5BF) zskCR5#z@8!j-Hxm^}j^P^n~y%jm*X<$#acf*;*K7I3)QcT1FF1-{zIP&!PKdPGTCb zjjP#Zi8pm>aQ~p)==ByeE#t8JBr4j6Lve{*5A=*Kuua z0_NMBgjcrKbBuuu#_-I6{9FK`95Se9=9Gb-nPZn`db?f(Avb(0TNZPyi`VyUUa-v0 z0oh;;dJ<$L>b7Tv>;c+D?J%oTr`PM_Mu%755$Cen;4o{}3{IkI)^cu{VBJi8U~mRU zB2>2l$M3SzAPo$^;U#;kNz6gA=J=ZJoZymf6dC+~iA=y{^+k~hxFX}0X%v~jw#WpH zfWebk89|AzMS<+zp;cF$SkA!(%!OPYwba91A3!cvGbCgPx4pH~!e|yneIqJKvqX(I zO4)QercK&jMIV>db&}svtdeUH_1l!5vi(!4l2Bu5F)7);m>SC@WSdSYF$q=`n;Ci`lLVFim7mX6Pt3>b04_0X3 z;`c0}n|`SV_|D}i7o2Ov70luv{m-dL$<%f z_TN?S&E1(RvqSko!pNGe)mFqyJBtU4@iN;DRKHNzZ|=96t@BIm*AIX*yzrS?m^P;^ zt$1;1wjAmz@m=}h3Lmn%%mJ{+@r4ev!)mdlrOxHHTlkBeU$w4sZ$0D&O5DI7TK*XR zeYngG=Z6U6PIITVuh_iwgW}M#r_AlEQg@g;EDG%TuD=1>6%?6fWv4=#I_<6m$$ z-z~NO=$2G!@B8X}slC6%^@9}cTR1&`dhvp_&-|cBEw}xd-oFU}MQWIIdq~#-o$XU) zJ*#HoTgV0^gb2DhX#Gn_;v2|76At;=F6RNdTWV_=Ft#DTggb$wj=spi&ymKBtZkDf zHpu1K*8W>LBx>>qY_SRW+;wP+zgD=Qp5T|%gI}oUIeWeg@ypRUuU@YgjcS8Y$Kjk$ zB2k4(x4qRoo+BC=Jycsa7`?eRo@j(o-WtZ&BhMOMgPk0P=(TD9Vhi{~MZ0WReJBG z^+gDNL?lJlGH>b;Eju-(sCr~fjzPZ45hWF&??Gpzghkn(XlTo2^*hPcYt&1Ug+%L*}+TV zPkwRn6R>b+A<QlqI@_^Y{5dwuX z=9%jwi@3sY3$ydHR?DZkB7VShRsx~IhvtXZ#YJk37YawrBlE-g!Tgm+T<{^+S>ig2 zU1bg+cdb2KjD4QDnJK@~Q)=(Yk30%CTRp$p|FCv#S*2Y8#W^A3%#G-E;)-FTZ`*`ns$FHE?L+na=6 zNgJit(FtxCRllPi18_j{9s~8RH%KFUK8bma-oKi*7kkgXg!j}dMeDm_|7zM^>^=7q z-uqvx_phey#oqJ8djkQ!3M1M(Yr&@csc!_t1q-~%TZ7h>TV>zL4?Y3KhyWV?&JAqa1cz);!*KWm%Eme@(vp0Wg@ob3>yY;Abe)hOq54GHXxBcFpJ9}2S zcV9-&>+KI7jOokj+Dph0_%iN=o_?IfknT${tHA zQZ^+YSI>Y|I6|adg>J)xF@B7Mzo5Wh(e5?OEQII78&n5=b0dg?O@+(m<*!39+w>qx zsL}^e%Z{f^1b0^$VZoR;Hasw1V*>@j6gIpt-tZxRXoGOT{(bzy<@w8-9LzQ^V1xr- U3;nqD=?q4pz2#tJo_Bow4+H~xFaQ7m literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_220059.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_220059.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93b2a788243ed4f9a751994180d8604bc109090 GIT binary patch literal 3963 zcmb_fZ)_9E6`x(N?OpGB?ffUz`YH$nNP=_)B{BchlHTbM;-XZmt#|X+@jA1c z!%5fD1|gBveXx-#;g9}!97sJRBwDFb`@K?iR-zJ3Qz_E%CEwi3^-DkX&Dt>wmm@@F zaeF!Jp_4>4FyCT{gjM{9SG= z((qQei^fD2>0-m4w8$hPTh!br2@n73*CW4FMpJ(GgCGjPxBYvTF5F`1SXyR)f4ykT zqq;mFu_Jb1b{rwhEFv2L7g_>cOE4ub+6?gF1~1tx@O*=h`mAuT+d#1Ow(zKKJulm` z&CLRrkNIYWemu(-ZJAoPV;cK7@QAJn%8XEJ{k7F#c00(%3gdg&lisntCOdA&=^o{q zeU^6#DZ5!?9&ksAwa5O#|lq`T8tT`O2?^{lnoO?9f>r(-v};31vi zx2^?uG*d6Aw5F@Rk`BAw?pT6&%OCLR-X;9)Xni*i*s;fvowX!b z!p*7XED0Lcl#UY&;ndu;S(jHYCKwEEpoBnJD@`IA6b*7I5%!N4;#O6G_Ea63Mjiw5BUL1sk_X zw9aXZud<#L(glXMZ)D2bDjjgP%+`KX?ZiJ zRPf5u@!aWJ(+4ASW+y1``Kj{XPkL_~c-G9SDZO_}P0gy-l#w~1PU;z}*Tkt_OP|a3 z0@s`FQ!)mgQ+|6}JDFWnpi6-TT1G}m8L-U0zEh{pod*eNQG7#{SsiEeG){mMVff4e zKn4f>(1pZk;e>ml#H>cToRRf#WJSG`Tj%9XDOwB^)0OyuHWKPW{j&BrRy zW6tH**cYJWynDX*X(_xa_BhFPF;+b64m!yfe54R|!^L2UyDyc`er2xlXP)!@6~6x| zezx>vX_X&#hBgGT(BXCz+e=69cb9K`)3PRI|E23sU8_RU8Qv77m3jA^ zliY|%h2!q=Vryw=E%JW3sS-KjjBJ1dgYIB)r1Zi4Ypde>PLlf1xmTS_o042?b+0*> zHYB-l*}YusD>3)`%gmZ|q@4Qdvv1Bk{Je5}pdt_`YkhQr66a;O~4$daQuxQbv^vX-9zr$D!5Zw(h?U$euvZ#k% zu>%tf4Nyo8j2*HAV1osJV|B5w(byeNY-SI6Xb*W1fWMyH7-WY&LkVtl>ELBug`L!U zwCOoLV@?~H9?LL#(khR0T2X8GL#Rx;-Sr!ce+D#> z{C5nWg|1U{ToEb76#mb<)ky@%x}o8&*h^P`DppmRry2b%6gbz8MyPi4f(*g zo&Brwz_B;kws(VV#e7KOo>ILS5ipzS#67kdcDm z3dO@;N`5ll@t-y^7^-$5Np`N`kHD!Ee2qf0GF{s?3-tpujq-cGCp35-`SS5DfeM+f z-kWE9Ii&tRXxIiqWA>$ueHrW!qU=sS%?BV@krw`K(_1vCx9IE0-`(eD&TP>KxWDIYg2Sh2Ul1NQS_V@D_JAY7kucWH zg|ub*pvZ$A>J^2Z8s@}37ks%LCqaG-KJ%oP)Hmh0bL9sK335Sn#o|zLzTEzJ;Sc$T z`8EC*&Jci1;h=jEk~LF4{xtY3`Xst0TmY~U+KOta*#n-=&ABZYH>zAR7^y@Bf%$O4@s<+O-`oa?YRK&FT(Fo z*b9un%(O*>l#wwlJ(tDxEOG-<70oi<1JyfE?^K9bJ5puRNNW#~XRUUM{Ks>Kwvg5@ z;Ok(OMv3V_RShr<^Bt1@iK72P?^Q#HNRAiA-SO|+fZGlrR;u-X z!gMn7dRV$=FWc1s(3^a;Ah}XC2sG%3#HwLR@hBwT8($vZ7J%DMGKh(OFZDAm|IRTe Nc4$@ZSr$DH{|DDGAd~S`e+*+R8~>?S0H-9ol#(F!7@(cl^9&+&vi z6G+mrG$176?t>=}Rg6}Z+^0f8Dl4tJJnnn@fR(5l-KG)}yyPw2Y^AC%?YZNh9kMAu zTCUBUzfd26e|YGAlj{D7-vW^oH5jJM0I zJMOd{CZ)Q$E?B#)07G7%+s2oS9{=I!x!qRWYGjzK-D;o~ks)&1Di#-W1?K|nfn35f zPP7iEE?BYZ+$5~XgJ5k&F`7ro&h+(Rf2XHcYF~gkE1ku`dEdTkq`?^l1j2R`Rb&o5lDM&}p z9#+P*tl4dlRJW;3&vb*QQTvw0G%*)UuLpr9Ol~Xz-)$@yE$KC}4ZInHp z@{~rhT81110PGwX4R9uH^j!-QrNZI-;UZoMx7erFLg7W_=G+<|eI~^Um-3gKtHnbr zQipwdO^iGJ`SClSf*1A53NZtr?V)PcO5KC2f^)m^%);7gz_OLPz)|b zOBWtCuJS*(`_~1ruqVIAX)1OtwU*3>y{p1;d*E3(Qs~NeIdSLwQlfZ%HQZWiFNZtr zQ|kcnc>Z|dWd3Bav1r_BUJ*O(!Sz_Ya4mn$xl%m3lr9}zjUBYltO??xns0V4mKyK+ z@22hv<@TO(qW7U%PQ3T%d^yqgb-bJyC<_Djz&aPaedflQ#ZR4dUM^lKnSbNn*@B5d zWIwA~S(Sh)RpkT_5n|`i<-Nbc?A;cYk%cB*!}C?n5AEK5RWop!U&YO!)y{toS@;z_ z@y=siW~|8N0wy~uZkgPP=K^(}2wDLvI5i0A#Z95sNA+2O&yni4f}f+`aW&tY*aR-9 zl%1|4>wESIS=?`gMInt|1>SmiZ zm(nz~Emgnn06vo&;kU&`{4HKad7Q=&5RcqLjNr}X8!}J*8xfJ`gJQUhppkhQXojS%aQ^R-ncTO(Gbxp zLa@RJG?vxoOnFAXt`S*J%RMywsgjx1h6vFK4FE}|#%UOc#hjNj6F?D+9%!&ah{^D_ zqd<=p#q!TL2jRBvMcR)1QG}H<1gG#CqZ-V4FcpPG`PoaXG)*@12I6n+j z!V3i40~l+3*!=?m{zL|h&+|_FGZ4cr#`e!5Ua?`%zY<~~ zxcL?W5@UFWR)iLYRI)#G*DD3N)(bnJe)w(qpk783DCd0)u}!kAs91%UpXZUV!-EMP z71h=p_FU0FKAO*Do6jw9h<$bH@>@W2YTyJ|UNRoMINu=p*NgUI^!zKJ_r1}%f4yig zMlZYq`tTb?|9a70j9z5uEd*Q{8QEZ4cY{5eDVLg66=2Sws$H8%X*8gk^FVE)sTro6 zQVdN#y1_Q}2K&Fq+uY+-dN&v`dt2U3a)ePSyb+bGnx&yMJRZ&79+!`joXefF8Po8P zUMIL43>hP|3Y9guFgy<{at7?bVva{(L>oPcHrT`8MUW6JhUHbtHgOzcz#TrO&&!W?Gcw5lajkEzo;6HuZiZCvs|#>bv-`|lQFlGvRiPqWhp?pbn4VF|Pa%Hg>O8TE{qBW&^5W-+Fq+Vo?LpQ6a$k@OGr_7fbs-EyO4)3+ZV-rS9%4Ta13%iktI+wvn{ zq|*KYZoyG!X!-2Y*-ZpuN&7?Q_Ze4wl-Ob|Tc0N@jH6=RrV`TTB~l;KP5aV^w&#vL&eUuE z1Xs@Wx#ymr@7#NR&b|4w+wG*F4NZPF!@DTzdos|QwM17^CW_zm1g(ut zbLI(#rUW{y+0j*t-@Hj+!eF(sO__M}4d#A{Jz?V+feP1FN>LPVc?RDqIHs(;?FKX9 zETgjDpe9%oC9q|?u0`7K*d%7hux)VUmuG&#jRlP`Eih$gQIz z64Et^i4R($_672u3i&#vHtZL-Dw7Ae1N43R@QtxNe$n3)Smv)CfaGXnC{mX+Mo16w2?y zCOLwgF(gi-z#JFiu{D4K(P;rQ5i!8StcYz>K+3&6*eu28h*^0CTaa`a$(WHBV*+*x z5n)b<%ACB2Es`uE!QY1I53p(IxA6A?rbjV7hRs4u!Y)ZhAzt7XI3&(vcWJ~g{{K*$ zMdYMoE65R8Oi+wS8fI)=;9@e8unp8oM2SN$U=zPc1g2sH@09?;43cDClP&UiQO1^_ zC`?U(>;ks(3!E62r8aQpduIFhue#?&G%LjdL7{sl5S)cI5Ti!|(-3vHgo53&Fc<5F zQFo+=i;8HDo1Ev5#uhnXaS*XAM!BE}argF~IC;DeX0Vqth;p+6iV6{Q1W=G4zcdP| zNu@U2HK~u1A7zi_-5u(=g3X<-yMDD`_iV84)XC(@^u??#&-&D%0$Y>bn>?WoJ#~B1 zP1(b_x~`nNYqcTgJ_2f<+H_xbH0Sl_JpR?8oaZg|d^O&N-I=l_ZE0)v(yDjOencHu zcd)7EWOKSXE3fvhIl9$Bu%Eh=ymY%Wb7=XH25Q~uN_8c>z-!m?o|VgMPXFpa&Ur*V zvkn;hll`f8lJ8`_S?Nw|p4|rkJ+&zzDcl~+Tv)!4_v}&6>js*$eYv{5tMr=lkb0)z zs!8`|_pO-k_;cR`E64xB9NdOgrBEm7_K~g?x?m}n0!dLMtC8{cXV7OT zg$fgHMN{Z%+}VS$W3~&B{q$0GovyB*FazJ{wPat{<4vk)LNpnpySa=IL0jTa7 za#md~N|08wWYZ-QcG$aW50-G_E{9rC-TV(|Lct`zTPTAz>a{iy0o*M1=*Ji;43B;>0Pg2RnZg zh{T1n2#E+=pe_bypjOVsuDZpRl`f?5;JogE)rz`7sB+I``c>c=zCH&nJgJ-}|_;H+?K~V);aN;P%fR)_vRZ zb<201-?XbEh5cRkhVKrqPJS}_`NhZk`_hA%;pO3M;P%MF!EcAZ9{z6fn^AT2<)k*4 z=4@S-$uS)TyQ5%tY9Aor!xrof1$!g;)HeUCk+oR2+9)Wbq}cAxJACROM4=9&qp%x6 z_$xPZHy~TQZGt{Sg{Xw_xk8S4MdN86-e!g?c$Fs&T-i(91Z}u_zP>fnPO*TWx~W{C z%nB4T-tsnJeStrC6TU(8uV(G}=H5<0mHP|#n&Xr@ zxZ!9_2eS3WWY__h>06^Yrp;&vmF4ZHjCM>beP1?y)$v8gn&b4#*m+ZUUdPUsopO2A zF)b_l-P)F*epo|XQ>?#7Bk*vI!azB_ z)%H`{70ed5l$#bK{HK)Wt+7&-6`olxm%H5MQvb5#D3*Vslr3yEk}WAfYC}J^W7JU_#4r*;#e(Ioekm?B zyNW4w39JkQ7MlaJZt_?H>47ddfwD$r<2nn??RRICZ{B<}{A(lbFznl55gwQvP*(+j=*!>$2_mPftoJ14VFP^|2&n*Nd zc#I4@=jqv-TK(#r!Otl#_c3mY(*t+-hc&@ONaqbS*YN#udhiGO1w+z92g{~JUA)6j zgzMf)chH2)Aw#a)QgcZ=W0#mZGgMkm#dBx#?7wde)XNBKy13~B=>0!-!R&`W| zp9d}*_01B#Kg*FFm05S9I`=5>kgW;I9BlOdV%Iwk+Q|9}8xFCjykUEdPJ`3HVw7$2 zSw0}8oF*MV;`but@GCORY4*)g>kF(CPO}qa(W$Sl#c6a}JZr7KnFm<|zt}jxO+C1$ zkJ)C@s;-*I>0vpN;xB%Yc!D4#UrI!H|#ngY`9 z@1p^0X@LfbwMc9lv}}_Yv11f}O1ZJmpngj6M--3KfRVPSV%cO?H?)jK%v&^4?$JORYzH%1+9p(han)~?=M!eq^8BSZLpgn!@)^xqq=C=lg!06)VWw@6h4-qT zJ{Xd*j)B(hA1MF%W$!JM%vZ0S>t6yG|rW z@@I2ri+DZU>5guO!fWx{nN3mMRw9K!Az5xXRaQ>9mp3EP!s%k9-1JU4@{W6DANw4X zyq|l&@NqG;E_b_Qn{u>pGB@OoJr~3IP%cym7WsQh>HL@0hS>j194LzePsr1i$1CgN zh&#L`$@%tNd!em(=3ZCn##hZ7(uHTzP+1!KQ`4V2|Jb=MjkzP+va+_AyWoy(g_V3y zuBXsa9Nq}OU1}_cPr0L8;J{FBs4!YQeec@3{I)x`t*C{T+%@;|mZIjbW6DL3wrKX z#gWpDzwoeyLKvIXVpZB1{JHh!eFf7|_2fgwF&u?S9j+6(XqVobd1e;ShNc(7kqpVdMy(pcp^Ca{P)Bhe#UT`s{*{vS!Uxw{4o; z$vBxdJRkvlV4;eMaL{MsX@d&Yw5QROk+HS3dCMT0Ii>Z{@XdH~$+%32N#OXIBr!w7 zaC+=zEjbIvfE@uURL_4bv~NR63rI;TYnV0C1C7PPT|2krU7)GrKU$<8y3Vn21!!DU z2_q$vAIuG|j=96`;+ER-Og&Rp&y+r1SKoE7Zid9Q!P~#tQjdMzF|e)#zX80dgu*$e>82c_FjAq@3|vI>#*1#P1~!zmtMpB z@Ei61Xxd)wz0ACK5P**|zQa;?hy8<~d}2wD!&61jjn8Kj1`X==GGL8t83&LAQPZ5T|9&p5X8X|_Zl(a1$6nU`2ykf9Z z#heDu1>coUMnQfPYV9uHRvX-_-zi8^^Kwou3>OwlZI70IpM8+s5P#+l1IXl$=Z?c2 zjY~aGf=?rlBOB6102`^b5HB`)z_Zml{=;zhq6t~BtZ^=a4Rt_|tZ z>*#q?{`^0BDk4&v-Ej{UV-asl$XSRXjaF}yb4hq=F2S>lihC6C5+g7-Yf~vg28j-0m1gIP$GBfSdo7!6!lFXz4F|-_ndRj`R>nr_xNw&a0o%^{L}9zQ5d1GSf^5ia%1!F(0G6}q~R1AE4}d; z_WHiGe~iOO$5Wo4%FxnV^Xc4_{B0TQjr%nJUG7m?F&5M~9ZfZU*G>!k4>_-kT5xOK zxS$Dlxv@|ct#}uWNj{`Y)p+t8j)m4)xKSD$y7bex-cd&rweSO=^24*)>|+zRFtcT4 z3i|a$!;i}PT&Tv9tG7e2e75+g>>(%FlFdy+FDKi`4%z%XM4g90w#4dO*v23Y)=3rH z2U4L<+E5dh|8876Y|EPOvHua96O>%z4)#i&eN&Bntm^27u%8jxF z&r+W@NBy_rgmT2RG#apUl0H^?^?`EMYz3DcTbF&OIk7fL+AW6rP1v#v=<{?MWt)Gxu|qI7q*hUf_u!8Vi`7XD~y>dTVxUw2LxOAu(>MY?J$PX-@$)71i z3g*XqSENo4RJg%y&R=v!)?{VzeEz)KSHSmAEa6r8@Y1!S99zEn$Nq=!73F~^hl=tk zXXLr4xY2yO7b_6_=XN6-$~!RcXG z>ISQ+;z#q_flrXe{n}r#{}5Y|igv2G(t7o_Us0Z*fyX4yi8eQm;aXwen^)`*TPwIQ zur08)fa?HFF;QL7_)1jRwe9u7_WE%#iEQyx{40!+Eo`|b)fwglTeym1!O`{qKuIEm z(fpqCdK^xk-k~Madd5r|nU48zN7C%D45K3zClk6yP#&Ge!ZegHGFEaTYh+E}z{&WG zP6L^2T2E2#R#LxBWj&>XA4@fp^E&mXlNrjFSc5iX^jS-tF>dQbHOAFG8v1!WmDNWG zF$j>CK@yWR1mt7Qsi`Ee2xAb+m#8x))qgOGL6Q^6OgyD}C{*uHV1zj6zckF%DV~-CryX6{AOs!cj05T3k{T+E`0m|nx#WfMQ=2LFg1d;f6Z+kG+aA5} z8t8qyTCJVX@6Oxy=*8DSANpR=@6Oxy=p}~UM!?7N=sG*P>nwGL95AcJAzkG){Z=vo z?5bIFK(dmho0gi0o4R^zogM0R_J6Ea+*jM_UuUU8t>k!&BgdI60Ew-Zp}}#;Y_0x2 zkKSS&4HDhVrYy6DXBlUj3>hQH5;Hb>Fly30atZVqX4q3=N2SrZ@KQ#ivMA*xchJ4L z)bcp{hup*5s&K{`{2Ju@^7|m|!%JPC2fhe@7G4$4IzwNJQ8!-L9aGWd_`V^dz)RM-i~_v>Fo-YN?WQ<7%A-`@}iY? zQ8=A8v}{U0OWuX3?4+2pP;K}z#$O@%pD6q$F4pYT!Jowg0%5&Q;_ywbS%e*XOz zQRr>@kxwqkNAZua;*LDHbpO%@g65L`yVP$}Uh_%Rd;R&xe?9$|)0<4WNif3U7xD=l R`DPZQhJ!0g#~sN-_Ft~Cd!hgU literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_404776.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_404776.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5122b3667adfb8a426ef3396f9d677c3fe3c604 GIT binary patch literal 4709 zcmb_f-ER|D7M~f9$1|R>oj90~gp`Dq53>yc2a$q+DwGyl+U!CrXd+gY@l1jfe%xsh~yTU#PS%7^5Y+o0TFpFM0De-ItYCyXTHQ6K^&o zpv_2g=iGD7ob$UM^E-F`?D4oMxb9E?@!G#Y`jK?3%d|jeg3K%K zA-62tWCukXvHK=9C^3{G**QH^bWmI-Df_YO^ryf4;C#40D%3%6g6x2Mg@N`CrPG#e zgt4U6r%XSn*)u|&C%fulNivQ22ek+pH*zCx6nYsEQAj|JX~;SauE=vTL*+ z#DLbhM(eH1%URi0pGZQY%sg~FAY+1Sq<*Yb6{XjnU|nZ1?GgGm$>iCxObvCJ*4akL zXJJCzv@-hNYu@$Ojoxcdge{CT+M+EWKXQ?EvLa=gHp$7 zlQ9=jBgtU}v+-0!25SnthCtMNLfE0D6C|d74LenBT-7nFPmL+qt;CgtlGMZc6n1KQ zN>zf}G5sZG&J4i)C6+W@jmb)QGOVU1u&36Or~bQJwbK3Ap^RxN`3kUWEGC;yTxw`Y z3nxv?n41ib>8gfZ5TPd2H1s59t6r-Rrf?n z9o5DnQKkD@BsvO%OeMP_!!WFFO^tTzN@A=Vdfo9*IGIut;j0t!!Lg|@sKPKhJ(Ubc zQ!vIOhmQ329t8_-3fH-YM-?@x#MK}mARoW>3D8lN`q@qS8gmIFQRpjs-^!j{@c42A z#z20&u&?ado&9KDY|Qr=$FpY^J>GnC;Y6u%U&*ttxVPjv1V(Q|zPIq}QqzHw_dxMd z$=j1XyPj@|cjsJ&E6)|aC^o&|yI%5#O8lWG{->>vTg&{3Y~O+)=9-OWNT3&cUI^VU zg-}TdJqbNM{`h!VIF&uUBzkg3jidR_*<<-}<7D>Cf?LY%HTLE^^YZQeh4I_F%I?m& zzLI-S_JalR?KOIHCykSZrh<02r7XUkJ+lNUwj0~?O#Z@cZuYJ0>3Kn%k&Tx8#X`^A z+eOE>!BX41r9jsswiGz>XuK47uOz%@dgTH}U`EYN7!xmjoh4sqS=bGLSnk%j8|P*& z=dTv@zp(GDz?4#ICuw$*W*=!Rlx%hDf{(XaHj3~;+@TOPViyj}1qh%f0^x&K z#tB|ET3+=5nzDBCk_GkW*%?RT`o*ylxj1#J50bH~;|tSL_6*mE{#HyJ}y5HdpNq z1h|3PYqLO8i20n5`|`ixS|($k_J5>A^t7sU$T6T@EtX1lOz-SaCZmcXYaNNmWGsfy1}grbI1L*Wp1Uy8)j%7?0&0&0!QvFj$G1%+3BdYJwUi-0I>q5%dsd7T}GdZ3XJ z+GMQmhaQ&dIa57G9N^aWz_7Da#Y+h$lg$eG3x+q_H@~ChK7W^=^WT$-^z$9T89K)q z+{}2s^J}R%4))fa_k(wXbEofh6%RjeeRsxh(i~9g*X@sXKka(l^`~y&7+%OphBQk8 z0!y_U?X$Z;yWzW4W~#X^`FFhZ?=SiH7aPm|Uu4hEds~2rUHnY%SHD^C?s(SLRg4rr ze|G58A1?p!@w1joW$)!@{N*LdpF3}yFVKaa7Z7#4SwjB`7gl1F%j5qjv z;Df%~ywlR8`YYx68?|$hcHL^d z(OO}e={<|^8*s(A@{<$f6$iygrnVAlqrLMk=@zv6(`6Qj%)bck literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_414029.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_414029.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d98fa0ffab1e5c77ef63dfecb3302080eefc737 GIT binary patch literal 4266 zcmb_fUrZax8J}6N*SlWZU>w3hLPJE3$kzgile?5tLIfcpx+7erG3 zmjt-Aw22~}tvq0^mGHfB>cf#H(lt`5H}lPJzM1{aw}19}-3ZF{)NjXvAibbUxwuNg)>9A`k&I*49>$i= zOge^Hj1)X>`6&-AzGX&X$33sMQEh~g9e3FWCBv{w7GzdIfS zeo(R}-MU+MOk6{V;92`9cKa9vLDPi?C2vVl7j*RFz&C26+!pwP8#3J>g)7cy^i|d1x0NY6m`XsrP{(oJr#S- z!Rbl{bOvI}Z;S0d+gx=OXca-J6fqPE|_~C0K9bPfVuoDmzeSO@}h2njTdnQCX2@B$BvgdP|x-`>JYL z2|b{ugZ5)1iI{4^Gr1XQN+YU?XG|9ivt&k|H61s@#AJ!80YaJDiv84jaMDt({QGB} zw-RJRoeD>k&arT00v3^ow}(fSxYntXNT;SuPIZFT84F4A1eugRy(M={%}UTD!6LLo zT#6)MDJPGgJQnH#iy4q?U8M z-5FVZS=0YJ@akpwv0bBv{0WNT5%s*I-GO28t0w^wo~a- z^Jmg$vVpAnWn*4!1z>`hlGBa9{&?}q!j)|AlK;yK%l?)6A1`aalkb1_^m4AV=MO!( z&hxqY3xD+I>OaT{9~iw=fwA1@7jIuo{W9}uR{JY^WDDYj$d72%LaPR-A|<{85g{); z7Uy=r(OyJ48n>EXt&$FCw~s2BflVDUZUn8&3Q45n*X)U_&vlLRF`IOPm3~FrO8GKT zYr$T=5^a6Tb_>VO6wuzyr~*_KRpE1wYthl^#I$O< zZL$*>FEVDjfiJXKDHfejOs+@}W=&j~(WI%wErm#l5h-N4uY_aM%6UQ(#9{`MD^g&P z-*()|I{<>v$dLWCYNZtgrt(mg35~oBis`ZQ4mk-;hiS)uKxLqU$2;Gj?l;aCcwwGP zb19N}H(R&RmFJs{oN3B`h8C{og+oSfLG;jX{dS+xW87E+sRgzV>H)`#Rf@^=H`}+YIPZA;5M#6V3)GR6fmT{P)Dn$BS1N zu4Zri>e^E1>#nc5mWO_Gx=`EjPZ#UtHa*DWqgXTV-s^tGHj<;T;s`vILh~OQv@Hhe zMHEFz`)y6H#cf5#G9;F^?4U|E2IsBRn#8vI_=kIAv5tf$v03Q8g(NRu$>^20hdp^ry<+9I`7XHs}wHz2#((B^|WQ33;O& zC-kDRSSiSwE~2Q@F-^6xT`QPXii(n2k2H$vlPmlgqX(F1{y_Qwq*8dP?XmNT_mOwS%5Q{vpqDy}oz|KU8U3kC zIkw5}hgACW-F83B(tD5VpEN&eUJ=d$v+e3&6?sx)e(aXe;RUMlP^oeIwOwpsAdXq*{PZ`GgZE`E-P_0K}@Ih8(?j% zwLzv>++Arj5tHvDYvSUD@~0;g@^nn;CYNAVdSFzVv28jq#xIcPZ^-+1ba0bH4F7r4 z?WRp;KYqLL(a>M?Kk45(K;Cj7#$&a<484V$GUpcu76vvEh)c@Xv9Ds5__ThDwrurc YguUN--ot^HGZ@tz%KKVwix!su0@oH*4gdfE literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_419949.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_419949.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21440639ecd5b03909d27f70bf4720c913c163b GIT binary patch literal 4266 zcmb_fUrZax8J}6N*SlWZU>w3hLPJE3$kzgile?5t{s=-sbR}H5rs)OcwDNfuFc_Q6 zE(vgJX%j^{yLqr_E8%H}lPJzM1{aw}19}-3Uti&wn!>1nC7;D#cY6ww{8ph-4(=7#busxvLLexnyCMFPucl@+BroS~2i+Ug?U3o}8~pJ6kEEf99jFsb4tTbH!O$gYnD(V*0`%=g z#t+K&q+56Ej>+o?5jY~n0f|m5?2z+q!fM6a5=_0L@UK;}L z+rD(K?va^?jt3MolnJ5qw%1x!skJ5575H=v^SidqU$@Krb$d2{y=|r%M@0AOK8j7o z57|oebaq~Z82N20^i>yI6(;;uu`1|(-AUtHU44UIuQyQtyP&AAfuf-_vRqr3sHalT zc{p9kfX+Z{`E9Y?XPc{`3au*W`}BbBdFZU{Ek<)F$PE#CbjB*Zf{e*dh1E&Znb5`+ z62vAOo{g$OOaK82x=lulneJ(l7$e~+DJq*>n1thFipjY)2V6+q68aF{ISXOUxVjk6JIv*km*pSRnwztBq}S?tV9yGOmA6}=U!DU zE1?I}bkKfmBoR|BcqTV1O>0Co@vP~BVV2CubEe~Fn3yb4H9#m+TdAMg2u|8+m4E-Z z`&NQXs?*_!(mfuIOu`}(@y_s=64$y_66w~Isp)Rex?>?Jo*+}wr?=#;={X6SBv^!& zh)a85N` zPFp^)B6P0`p}Y`!9D4HpqxTEKS)*^=?a6ecJF@=l@Iqs5c*T8a>2TiNZk&G(*iNNS zWzM9}6LtU?;m>e z-52tW7yszbH-3;8J}~-f0%N()FWtVB`dRkVoc34t$QHy4k?+x}jaE%iMap~yB0^qx zEY9tMqrHf9G+{NrS|uIOZXZ=P1DiTz+zeWU6_QBDuh|pVp6fd0V>amoEB%VLmC9wJ z)`PurCEEH_>=url$KxHxW39^OwU}G~AA4~cjiFU0YA~Sgtn!>ocj~Ume%J<4tMa1E zRP1-K7N~gA*-4WNxUI;J%6N$U&US7Wy9?O8d>><^bG7zSug>-p z6m$-4{0GXsf2d8FjVOw&ww14Sizk5x>6r0_U*P>%Hi5b;& z+hiv&US!;K17B!!QY<>Dm|Te<%(}QTt4Y&|TMCg9qf*FpUk%4*lnaC;h{X&hSE9fm zzwNk_cK`&TksC;De$dEZ?Vh>4cVXOLJQXl!XcxtD0=9(e!Ji3HEyhdaxQ&tMLbCT45kO~ z4W)2V^sb5r^5TINv9;)}%Y2glNftDY}=mWO|Rx>(=zPZ#UtHa*DWqgb==-aGq@Z6-%y#SwU_h2}prXj=@_ ziztec_S>3Xi`$BfWk@V-*+GqL49;7*HHmHa@pt#eXjN{2?}w?flS2n4lbxpl>n`&_ zI=hGH-^|<2==s+`&+KXJ-^|<2=!Mrn@7`1NZ{}@h^dd#CBj7&I&;~ua8H(Y|tMXd&|inOS)*A6Y@qm zPUuBru~LvVT|`l5Vw!4WyLK?E6cr`45~0;w5Nyd5=?D8I+Of;Zq9BPd2EOqi!Iu%! zVzxJXV`=}xncpVwCs+70MlUc?=0N%Yq*8dPx?)iqM9kdWv6P&>~!VMx}wC@1Tme~Z-BM6 z)&`kkX?LZuL`=Sqtcgn-%AcM}$TKnJEV&G`(gUN?jBV3_F@AwOe?#8Cql23qV))Nn zZntbQ`|;bw4~PGv|4IMa0rHjuF&?Y+W#}#3lD)7vxG=bhKwMJ3ihUWg#HWp0v}LOg YBkcX!^BxYooW-c_P{G%BTePtJ7wzg)<^TWy literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_433589.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_433589.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a198a06311602f867f7faa211235bc7c199013b9 GIT binary patch literal 4625 zcmb_fU2NOd6(%W4B1MU^BP*_*7**`XjXESwEO$*|PRPn&GU0D==~Te2i7 zmsDiC6rcnFVxbR`fdW!qHt%Wa5?~9EA$iHl|=F}T7^kqAj5^2>{ z{?=T8_nv$1;rY4u@;irr_4zymw3E?4O#KF^r+8vDcZ1k^3`Cib2#HJ(7NN>QMSp96KDrKTPL6|6B=o)#VL#0 z9wqW#V?D< z*S&+wH`#Sr*0x98t8)@{-*FGe1TEeiYn?@CjS*H`9GM-V|0R}uZ_H9#lcg4yMd$wi zDC}t3Ut+_r`|(aoF)6#~$!8+=i>XLOI=inT39ePKrebTA;F55Aihp`8n z@3QU~Nz4*q_ZY(0ebSN&QwE(B)oH`2Xj3u@lLjr$#MH+6=0{-?K`2DnYfxIk@T8G4 ziNvH3lMI)L#MGp0&`gz;bj{ zNvNj3hHFMhYe+TR@J@qhaL5^hl4cDOEL|6oK_gWI_oW%bCCvy*MpN-thFZ37FAQWA zG_9t^s5~$wMyDY{C3R4ogpdYQ6dlmyWO@K*1Br-`QczO(byhl*o)v&4KtP(35~2!( zb#!p>$H#|Y#Rv#ZqQbO{QgQ-?!2=8$>SchjJhATa77rE+#CJAe8K=PM2shDsMJ z?iv@$kFIk5Qg2}>Kf1>IOM{iY)xiELzyEGWl|Pt2v&y#tucsR5uk!tM-q}{Xb(Sl3 z7CKA2E85+`Wp-eNja1pl!^oo(4^GtBQ~BX14S~+*2#l;Vo}#PZn%9ciLiS5ko?Do! z`Fm?jAG{0x_ZRw0?WKzg!OF#D&)bXps-C`jD8~xNiYE&vD}jo7r>n;GK^!bMFBQ7J zyi~rjaHTS`*na2C-S*|+kM62}iatL5AXXjtNi{h9m(gnQ+}B5{!Si2Vs|GJr*$erR zRiD53S>dx%q&ZEQk3X^y~ zTId(Il7G|NR$9#w;A3meEq#9t&CiMDS2%mpIj3glvsSrFciI@{(w(|%deq!ApA~}S zu-0#qn&0BKDm(OT>~?DbOH!vNNW4do#Yx=_7Vu66t<~*OqAfeHPcL3~zhvEY-+8YQ zGfC*~c;~hii_-0=E)!wb*j&%$DKRapA!Q;o6B1KWNSX~LGpZH>;EYaf{1@6>&ly?N zGDr?dF~ChVrldk+Ax%+22@y@oCX@q0> zqhwUWRfi(OpOR-ZA+2O(Bq$R?#PD9qfKamWt`qP%It7@WOlz}3A_m}ufyr>y(b%A` z$K-67+3_72ynXyk*P#?TzI*Iy$QW@whd_@PaN^dg?q7%T{iv_=hE>Enw zncLYL*-~%$z`}vQxx1h6Z!L8lukk}m%+OkJclqkV)%@5h>nU~>x=NAKwQ_bLTR{t- z*Vw)l_E42QRAUe4hgSi>x(nTyEd^?q6XKSJ}fgHj*D%QV8cL6QIJh39gxhrZyp@u-ut>^9h+4)bZba%5JWXm&9!ENiI7T|8a1-nr;t}6Q##F`7xQS%12Q;Bop8Fsq>_-nB22^e17$u;DpDpPpF)>F zHNG$^E(A9nBuPFceE%f&{)6zYk?z~wH@Y{e02y8qqbp*vDki@nV7BETD6h$X_Ffm+ xSNgboe&PHk0pz0m`^4`OCiytHg*jWJBtiPV@xD(6p3RVi|Lqzdy1|*g{0|RgmjnO+ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_459560.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_459560.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90896d50ab2754b412d09daa76d84a5a8bd481f9 GIT binary patch literal 5092 zcmb_gTWk~A8J@An_Ke4NF3z0*CWI6tAQ0FD+66Wtz=kGi>9WmkwA+#KOcFbe$DA3H z3uDPpB(V0uj)Y|M)S!JJyC8K}T6J0N!|r=uFpi3Kno3BUmq>j`H|Hs9w{6Ap^{fiz0RS|m3A1Vow=DUl9QAXu#|6DHAogSlU1kK05>qJp)RQWPaxp24?D+=NxM-C)M; zWmN19YMeJw5?{9Kn5VtmIx#znZ3Dx2j=@^d{APwFOYXt$k%DLL9U={tkrU(+Jcv7sNUeM z$_CKtwR$ohk$%9GW=CTdc);m4cyN`Sm5E?onJg2UTeFb)D|>Iy>RtnT;|s7i{uFzx zNox`$9OC^kzG1v?ZK1w=`h6OcTMoyC?o|Nl+B z7m=NgtsqBWF@8Cu=$Nr}PKc;T!8TAU5=9O@hfU%<5txV&yjOe(Ge}WIUADjzWffce zvNSOPvUAuf&IxiwPemawI)nTyNr74|^9 zf0g5t&57nzb4Fc0vch%82f%*vV&dZMj`Y68eLARByCd0^=mM`@i#wMtuh_lI{aO2g z_~})^*q7)_o=BX?xHHO~)*QbZ0J>_EQbM{tkUqb7KIhsQKUXr)oaxQh?OLW+?EB)U z^NyO-k<9KT^Br&2-I;alS!S}1{qZw-t~PZvu_Lo*sq4W@5wp4;%8SG>&8@(3om-Ix_^dNvc2~y>#bSJQOv4YG^;)uUWNEsC?Zqp zzfP;p5;s^SFKAoQ-v* z0IGY2oV6quCP=GUvXvxaEUaBMiJ$taE@$*rU1G}@r`cZ8S9Lj~uj&$8zU*pqY2Gwu zf+qVFPE53B(-_z%CXLlNa(;Ti-iKUD3QYmJD!W)K&qU0o+1{hP+~~s2_ojRi$ukpB zAUFA^JeMSOR+7S=Igc+adc^s4k{Z|l3uVE3R`RJ)BzeRDWGp2hhdrYnRhB&=ADWar z6EgA?50>lXaHuEM_5Xk-Buw(%Lc+8t#2Yq)16>$r@ zWE7YTgnc1F$AZY*q83sSjA$RJNJ2x~NYxINLOwx5dPDo6jUC1ugx&#GCp^jxJ^AE$ z2MOS*OX;b_smJ_|M@;`CPk#<-e#g_sru5k2SbQ{ZV{b;UMSnl5XC{X;c{Xu2)swoM zo?V>Hpv9{>$L?En9%A%ZVl4HWO#L5P?zAj*+-ZjxearN;;NjX}ygz<<)z$FC)tPm5 zE?vmEeg-MZ#@_0?es$H=@~Cw%=Q{IH-i*=)of<#J^?T#Ghi4VXi;zQ^l ztcD++av>K3vc#JvsMAz{iW#pfWSdtMkrv@@W;lXZxzfOqy~vHxhNI`}TQhAX3;5Jc z;26ERLW1r-R~LNj|}+`cs#enjS>|82c=?!k&otL*PqpOKgMlKE2fSmE~*am(CUL zWPD(aYfSkv^@U)#9S+mCMzTzs(GDuhvnP#qOiR6AHGaMQ%k3-Nsh6?yrtrLuoeh?9 zxZ|UGRP;LaB|&{Khq$JYe-DP>;u?iJ8D>iq(l>@V7*H_>*;P@cxd_seDz;vP{7>?- ze!JDzQ(qO#7nW3*ltbcYlCKt zD`jGDwq0x6FtyU&M}FyvU(WjFA1P>Unkkb*=WiVAqFt#o>EXrU4GPF5>0an=NGBgQ eZW7Ms?`Vp4{^;nX-T#=QshXWRm**O=YH^a4*wnwhY+*}_)il(B0^6oQ)|IZV&iW>EFlFcOd5?=-^?iE@@&R8 z$}vb~CS5-@)9PDcRc=yxSw^`rR`Gqq-Ps%%4GKsR6i!8xO+VFD{I8G)RADTj1i$f* zhU$^=-=I;EMXFf$EzL3U&>D3+LV`mVetYI)c_dj|7!~{B@xik}Sh`4qp+edwf#2E` z=FaALK4>B{X!_>Cem;y)9+}KbJY%vi@vO%)x79C#tu+I8HfPN9CU0_6z~u#xRp6&q z5CcCCVpCAqyS_Vg9ndC)>f75Jb*9ll)mIhznfaI|qq zP@OO(Q=;clUrpGv>IPS9vbAnnC(N+vr#o8r5%J8~I6~1AVbqM85gJ3j$z#_j*65a4 zo6Y8`K+j^I4^5dQJwB z4DlOUR!(ZLivEG){qMdDBcxUKsLE3+&Z=pA80M!aqYnaJSm=2h62pZf_K_k}4s}{* zDx81un6s}Wc9pm;Yp}uvJnYaGb`?zDv)^-mR}7ZLm^EAxqt1SNz#3i=#KLZSx6|w> zw>p;FzPtQLIQm%VFA4qkFFhE&H(C~cZ4K6xTDDOdT8%ara(1rB6%)&?<>+DS{7OhF zblcs|p5pmZ^PzI+uytky;y7WSD4eoS6CKCYFXlM@x`_xUOrZCdfWP_ zBDEK}lC;k{yAo+CjM-z(wdFmf_INq+OY2-kYIPKQqYxk&{|alS(QI3+=^XR}+tOYn6os8S0ZVHFJ#Tv2 z)UV|;nr2?p53{Ww#;7NWMhU@u0f5|gyOa07V$?_XE+0^+HKfO_t+pO|UiqLAt&N*O zqq#;;w$!K>Z_{4`=K8ttrr+eJhGFkSqrIl7)b+k%40>JMW(LOGGljItm^89RC*rHk z0Db{Z?`)n1gG{c``I~k9%P1ZgS^F=vHCiONaGOGiLGvG6nMmoe3w<$)D6!80c4iY; zi7A?@gUE>KGt<)=HezFH!kEEoOi5*8r%%O>_kTRp6PEx@Clb>t31YQM1gc0%LP;%a zq{e5o86B`s!^sH}0{AlK9hNVDG zbe&$WIsoHCR8{gKB_06LJMif3E<~<}P)qC5?pwR9-&BN9p~LQQ`kd=av$tl8_|}(Y z;UK^y;9aNP=?oRwZ~5E&a?5S0EF89mo&@#CrI9syb=P?4pn#lF5>VC#6!nW#QYC)Hm;=}p4OKVfWI|WvqiggNS);!K zUU8q7Ifmm~hWq0j?x8Y26v9dt)4u>0z2cmU$uns~_h2Q3Mx91Q(XR@8O>PLDG7+DH z@jp<-9vA$cM4GH~&m|;83ZgAKgU3G;Z9oz@oNut#R*OP`{mKmt(316cy#UkRHOc_Lg zyr~;%Zn~Q3NZ^W^)iq4~?thJYwC({B5i%~Et~{=#mG6-|arHp?lNn8!Nvo&vC*YN4 zdwmwVbsxhpPm%Ob6#f_5w;n(&f1`7ubDiyC-d>L&sD?kaKm9QZ+=dUaQdRyTgIaby z3&xlOm7sLPTrk&tKtJKb1<97y{Xm1GP;`T8Y+Pax6aG;;&a^(uF(~?0ITBkC-N60_ D0?>z( literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_493519.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_493519.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd019cfc55abb4856319311ddcf5abaa708c6a58 GIT binary patch literal 4522 zcmb_fZ)_9U5r1#Jws-BdO%e7@J8&RfUE zaRUcBUb1gz-k+IyZ)Sfp{=3WNK+y2jKa4yD>M3PP&9+Xg{TYZkBq0e$(O~h72eHXB zG0Px}k&GjzoYL0fTViB3B5WL^+z=yKuCce*CkE{jE2GGk|6eCrUovi!`5~KRzs3$a z%2sjL(4dn+va>8#n8X2og^KMb_G5iNKmLi>AF@}du>x&@XD!Ll18TT57XiMyN%;2q ze9EpP-L6}vVE>d0p%l{b1|I9o2A-*q%vQ@qaIL!S_PWLtr*k?x3S5e>*i~{x#b&U9 zn*v_vCFYLhHa!Ql2w!|x_F9?Q%t_@_Ilf09e#`ot75X*_65XLY=nYHw4qLh%ot<*R z%?OgUVz0Y=yfUE+x|QBd>1^gMm+sVERAR;6&Gb57rnk6zLvLC*1)-mES4iC?lp!LD z-g4{w9b3s;7)_#pyfC(E+hSl}?)uREd(7+AwgwdkJ zF{UJjm0(N^OT-pbg7IOQu+cFs-{|klZ25biREfgA*`Zmph!xp9$4z% z*uMKk>qJ5sRmXxMxpgEM8VzbAiTK{&upHM~RVCD_$+59k;98^YVmzV5#LE*>%h;p{ zLn4@?CE{Wz0Y)7-wEu&)Loh??#EPZjsI0{0sIm)G)2~|H0pDjDJ>QC)uI%2--W)DC zeCgxMtaYYtM$hvNdA1?lz0BHXnhhcE*pX*ojx&tU7J%pv1U&R%f((>=>hkI|Uv zO!qwF99etDZeXKlmdk0k_Aha*kGS?c*M6`4{^7fa3tU&a`!Vm#ZqIBtY7J?&eqrma zfhE4}5q}`hAGmk+{@~rg0{>C^*fWQa-ILj4)EH-Gy}2_>j$I3lc}FuC>U4iuJGXsy zd#*00&es=en)A-)bk7ssIU{B2zq~v*G&__#yU;j){#N7Cwx8U(lHa!PPd#~W=Ys=z z@6o^P&U-(8Sd;gj%JZi{3v0W6a{A=VSz~wRVvbqZ`cHPx8e9@m>S^br-8Se#>$m|V zLKMPHJbVrTIfrx+6!!! zn9&Y6+7?(>iS+C1OR2%M;`QXX5RJ90oXH^5+->Dh`HbSGUs+F{SfH2Z_mL_%UkQpVA)38o&YBS%IPXy-xG`+p< zUrY-U_$!?gBc0tW!u0Y92T(I=Ln@wNFuDkKVDu|r(sB_^V8G&)|3X`#cNd1T2u7NE{j=4gf-JQjCU2Wg^H?IR*$Rs^KY_SYqKgu@%9Pc;a$W z6UP!0vLYsi#C9T_9S2b{xmQsV0LGzlDM)ORd^H@Bi8UOT)8-lgk=N+4Da!q9Gb)J_wDmupNI zOP^Zi_$;5{jV2>FH#|F>yF5Eu;FGE=+WyY6;L7%8`fl{k;AK01eR6uzsQn_9!(ZF$mtCIhXPM7#h%?MH-u;NL&-3-U z#sW_v-6>={G9A-B>F)HEWv?$6ycwU57v62jdt1`|kKHv!%kTC)YU<24bv~#mH2t*T zK9W8Gs%JVr|K(%1Z*k|ThaC?Civ#Bt8$K<#&o6T4AM^e1B)x=OxmBqFd8qh1WD6brDuVnTs%~jI}0<01+1wu_N#TA;HO`r>~CZ>-^HMIg( zDT=DpA_|m6jND?HP*Jcd$6>yYE^IMz@`-y(`owbq0lGLdPNUnnvas#W_#abWrIxr4 z)7^k;*@jF5)S`G{&pqpX*Im~V-vy-|?=^zCS`+J@?A(<;F>^Z4`YPop<@K~GeIbWg zIC#%{f9KtuOZ?H-k@L3f`5|&vIV9Aj`xQUjTfk+yMA-{8@f5RSOB8Hnu zcvS9EsOM;1phlrvwP1{&BH>@i^>4Ie)rJ`Ex^LRI%GBc@t-6tWOZHOc(l;L9)+~q- zisPO2xXI|9>znOcML;gdzmNVlYLXAUYjk9-2dWX*H^M<&_n#z2o?QjEf7)qU^FRNw Bb~^w7 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_570539.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_570539.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..245d048fe1e8ae3444bd757a1fdb96720f977c9f GIT binary patch literal 4424 zcmb_fT}&I<6~1>o9?y7egCQ&>Kp?-=ssagV5^a=HbKdb9M1p-;~{s( z39w_yWFv)Xr1kKGQB^Jb)NDa&B1OtRwoiTN3zku1O;ah-@D`x-u{2m_q#yYKDy zSN-DeqkfS=idfYgn7|SLDplKO`%d@%^2{%!z8LQ^;|1CS|MkxqxF)LFV6)?^?l$xWOM&-hsm5h#lZ#`7j1Y}Y z#WRSM2vMtb?ITp(Qk57ESIH`2)EXR(ctuk>{07T4KeQ~f6%Et_>XQ+seY(!5DN#mU zX~iVG80mVK_PVMQRl*1w^@e}ZTRAI?CQyXyBlOw~R%S&Qn;ng6Beqx7hZIugubqlu z0)|nzr_Ky*DkVHqNxNLJ#5={;&Hk(kRGDub1Hvpu5sN42vsZpv_j67j`?nH^r5>1zFR#(U2>7ce_eTenq~4{GSeLGhj&;GPE72__RWd5wx+`~%O-R6!z-e7gN--5&KYaAS z;X@yS#IBcIQ>77wB$Wi&1NJG}X>IU=XV6L@l)YqL%Iz!!S~BN~zJP`AO&0mkO0Xt- z-Mnty&UF-myD}dY#jw?59?P69218aut|cGdn-A`N#N~sXAPm)7hjRyKJ~WfLJNeMw z%=uS2E4)AJGkq49yZWepncx4MKal4SJgs}S{mJ$Me>~IkLJ+f!W~0@Z(;ppL7P_7b z-Fcz=Y4@{ZPmUFYlbO>i{y=u0xz7q)muDIt?$7Pa`QSe2+q2&;xj~u~nj_)q*^@Q3{bLv(hvYTd&7kuE%;!Mf+yy*kG%8$r-G{ zIcHa=*H_laUW5G|k~xEe-uHWNW%fO+H&o^ro&ijw0{IUcUc&==&@jD38=CB?tcL8z zE%FQH$YEr7H#^78PoSdXV-ykkRwhZypn$m2)i{?sjW!EH?aa|fy?|LbTXrJN%~ zB@Ry6TnUaNz6~!Sp(sjdoXIZQQQtIr@yftTb^~n(-1tWh1FMbBYjJ?D=;I%xrc$5X zwZiv4Ee3+wi{?dm?RS3BG1oD_Ywl--z`ki*^a&5{Ox^ix+}e51C;}w*oBcT^clwJ9 za~I~X%=H$6k!j{7FF44CO@F4R=pMWfLeGV!ywC*jDzs!yzZ3)6Q|75RzlmJhxX({d*&|}LLGppK7RV>y-#0+ezMej;Bok| zXX()8r7PE$F843>-&kt8SqOcy#DB8FHs-=PHqW+#BSqft{D7dpEN}8w_@P+j#bRya zf5L*7Tl)Y3{vR`sS<%_SnZbg%V~O2C_Cl0m@UJ#na?ud4GtkeXI7+(aJ3lw%J%pjzYo8$;zF03?Na~ zCjc2mUD0$Y7S$B#AnBm9={KL7-R&M6B9RT^pGR1-pYps=*~&>mzY!b|i?9yUm%O{1`oJJPOoj`W#e@;KA^}_D{{t-aZitI;TSpLOAW3Qby}c zz&A(=N_LyCP=sbB+cO;3Z2?lCrYqxPmE?96a@&Gydpt|1`jl|LlDOgQA%4eOM%-iHwmeit}`>%SexsHUwD KYM&AvSN;zx!fmVo literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_597752.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_597752.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89da40d0f7260116caedebeb06ea65b0073199fe GIT binary patch literal 4266 zcmb_fUrZax8J}6N*SlWZU>w3hLPJE3$kzfn=W>)&LIfcpx+7ep$Ia zE+jfe+S%`$Z|0ldd^7u-Z~yG|x)GGV^l!(5AU&r_xwuNg_7f16kc?y;LnFmE9>JE* zOgTnaj1)X!`6&-AzGX&XCp<5=QEilw9e3FKCBukI7GzdI6Sd#(DLem9JEsVvob0;e z9C250@psUO$RI_mc=pWUpnH?L9Wq^ggCCyzKpKkJfjYtDfM@#)hAvUVv@b0apzka) zzF)E@-MU+MOkP8X;5qv!cIOxbLDPl%C2vVl7j*RFyXI=6+!pwP8#3J>g)7cy^i|d14V5W6m`XsrP{(oJr#S- z!|6%}bOvI}?}+UI+gx=OXca*}pa*o%17~?}F`7d`ZivvMGgj^uWK4D{tWKKFgf^~_ zAU4_XY*YUE--PE|_~C0K9bk4>ikDm)*X_>zf-OouY9njTdnQCX2@C6c&ldP|x-_p)kP z2|b{ugZ5)1iI{4^Gr3u5S|h57XH6Fjvt&k|GaWa=#AJ!80YaJDiv84jaMDt({QGB} zHxp!1oeoEo&hc<$5*Cq&w};1+xYntXNT;SuO?QIU84F4A1euaPy(xE0&q>fE!6LLo zT#6)MDW|%-x=w|_Vg@8zS7}lqaV19D003n{^*DU^2KtUiqIaP^-JZpHcaw3h;Bu$J zw`U8yZv)!9)7_cjtSc|J82tsYCUZ3XQ=@;~>&pZ(dd}aP^R_O_IqzG>g=f5b!IgGp zxa{SXz$)Lq#-GUXCmz*3ZhF|1=T95G>%Q8|$!v2j(3bPHtsKbt+Kqu{g1FF-Zpbub zwUv{rLg$(g$_b%Ip~vq%d@nEb7=0UV&q72u`*wi88M$pQvkVHCu#h$qOT-PWcvq>jd=~uL)lrIyt z7VPCK(blJIw{Yw}9`8CHYZW%H#@zb<*o)I>46QOzg#mSEh38zlQ+G}F!#0Rog%@R} zY`=}QK-rVdj#_brdkO2D&ca9rR!qOa9YuDO$3x_Iw{v^gUBK?8`xqmgtG17NWwxKA zpmS*RKTziVLoLc|L{Vh5rF5-ZJPABV&rDDRDjN?^1MS_2DnM0H6+ZX479E>O%&4Z@ zCOd)gBIBkT_(GeLV$n&(xUtkawWzDDC(Us0>u_cozoJ1IGCRFD!6r zE=4k@vUQ8ydA`}`EtD9cF7u0QXz^-ZIAZh_L=XMeZ}l6!#`Sej&Zf_Ezpd%DxTDBehQ!jA9aPE2;JlSulh}42|8Q@NR^bNtewYe7Idot$*?AhU?lB*v zv-^nt^}Owlo__`O%)Z9{^}OwlUU&uc?tMl7dfs+NFH-b60`Bt+ZPKH=NjU_vsu?*9 z*&rt?H=+^6bjsQs@SUhBswPFksv^C;Ne}fV{i(7yhwRG9P5MJ)Z+R!kk`CJDguGFX z6ME5DtQ2HT7g5xin5Np;t`*EGMMX)iNNBYd1Y0si`oVsYcI>dSC`clVf$uy>@GXdG zG1Hs5zI^b(%s0t<$yNT0(F;tpa43BUQYpOL_Q?6z`_Q{;+TVia|W~V*GjP_y5}H z33{#FAtxZ5W=-*;>WD#xKLfe1$(JQ!oe{@GR5Jy*>{Lyeoi5*5mzB7hAg0s$4Y0P> z+8|Rb?yfYJh{^YmHF0r6`O{Mgc_yaxkV`NtJuoWG*tQ%Po2D`uBWm&#vL@ z?n!Hegjad+@xI8BmconQ+&F#dOZ&~<`Pfi{ zal=bTH#6Ti-|Y8&GxPo3A0iPEK?~qt4bKAfB}FzHUjjD%1i($CA&p6*{^HE^Gak;4 z2l_b%>CC9dXYgQN`Xv_WQdyon&BVktDz1-&JKz3M*ALacgkSv-kOI&*{sQAoWHEFvZKG^) z_KzBKt3=O;wrB^&1`xu`G_nz3p_w2;Gq61)*$m*}GA`RJ;Cva6_@r=t{zX`0Y~fZ( zdPcDon;Qcx6ZL&XwSKDElC4nbc2r|;2X4_bK^ql{vwyV8$W{+oK4HZ+^5mB+uhOot zE9f0%s(g~SFe$srj{fiGq7pRbHn(}h+sOJBlKfI;Bsf>KPnJ@Cgfd`P+d=w*2 zEA1NZMrwUuZgnGYcXZ$X-@R;}nQBk(sCeFz9n{&icHInkYo^|J{w;!s9Vyvjy&b(B z+_ZfLO`}+-4^vkf+Pvi$OStj4IYxqpHLT+pL%4W)#EdZ*G(a(duvU_Y6WAES@o{xT zBcV8sr-pRGC5^ZSRD*;E0kn>{kbpTgPJ-B+#1;vfmVx!yF2YART6jd$)wGI@D`A zmsZV561W`4gu|uirOM(Rj> zNKaYKCQdY4`uIdMV9m)EHD%y&^}Q?F(TQmlmQ*;HWu(-E0mo}Q-rj!V1W-t|>f5T0 z={Th)@pnOoFm%&}FzTSM8;}&4KjI$AF$J-~=~@npbMdR`WnNj8BiTSUnXh;yFTdiP zS&l@r2Xc{o)vNi)t4_~T=wo1V!ab2ap9>eHCZ~H@ie~q_9ZvUSUYrlR;cPI+-H;br ze`zl9#~<-+dA{u)zCUw!roeYPr&k1NzTU0R*5wY~Xk55(r+P_v{gKd-7dk$xdf4!w zp&)cSovV^OH|f6abgzi={2}*HwkF4Y82%vqG5d-5kyzMuIA3}A&i*B_6&PQ>d-*ro z!=VR5zkBbC(a%QT`Qn#hue|u${o0Iwjh1S>0XtUY>j)?IkO@w z^F3})wk5~hXv=kf*!w~6$DNYU0DRaa~HCcwyl#Jsc9ikJxuNV(uW;i~flTgv=NHnFVEp@`UqGQz< zR9lF6F`k^#&tPm|aDD?PhKUHiYfY=k5pY84fFx9mKC$pqZ;oq#ifidWrQt&`nDov| zQ3JmQFp>R^jay;qDBbr97#y@NBBAJWj4L{)mwCZ+K-QJO(Uh7;%E7#HaN&GGdClov z4)b&ESAV{u>{{I2R#4g(`S!EsMz=abusDKNu~$S-yayfvai3lW?qU*iF_GJSE$i$PQ?8iR@fsKQghS6$$4p#K!QdDJT_R+S2;_iGsm%Dm)C+0)s{g}U2Q zzs`J`S>nI%oQ5DXzsKDJYaiL zkno!j_6#F1I${waVWdn;Pfy@dO1S_ji>4f}H1^)pdle!TpHv+(lG>-plU961{^Yo! zO(peH_#8;3R$@9ZtOpo|`4Y*0Ly^Cuz3U;w^4A+?8`jwa%;EJKq*TlgxC39+0=5xA ztXy1wn`vN_<*56N;GOUGM-6d$M%Go zF(lcshoJmz@-Guq4#q+Fvc zMlw#=dTLXvUx|^~g#URN<;EDvb&b8f-ZAQxc!`x!qUon2CHD*D9+@BWNZxDiQJ>=% zcMXjS43Y&$vwsFhd@EG#F!2ta{rKcZ;&80CEGl+GbwORT7`j9qQz0!0;2VpCZ?E@f zy#_M8hHDn)&jt|6B7=F3XAJxtkLxtEjd~Gmtl_!6ZZXRloWV{4m*wkbd44(zeBfq5 zZ156u$90>o14@Fge(PJ^VKzG^XH~ve*uz_FPpGrE$u%*2hL7&BgzvDkPcYuCR&SpdlCm>WvXJ0*O{fxkE_qntpqf?fEP`ifpLP+BRpn zL%_|co?^)0?|5pbh|vs+c!yPbDaLDKBMf0vQEigA6@5ZhBbcz!>A1E+Pd5>q=B5 ztg7h};lPe4&4{#PTBGkg+))1e``)XHI;l-XV{-3AG&Tu7De3*uaXGE`YHF-kms3-{ z!1X5k#I&NO#7kGDo~aoTL?U>mD`_#NfPaVj-+k}B!_Y%o#k!&5q^zdpq`C)e)2F0u zgQuKBKW;-pps?TEU&Li!XYOQ$b?002Mv321Vt3>QDy%2J$MTnaJ4@`&++YK{3{&4X z-?z>ey=5Vs8>$E)YqxnMH?++83SQG|VQXlfE9$onE^)n&xV{qCcdzgM(Yr^>+=<-4 zGA|U`&33EBlIA-Wx81t1#2!uC zF_VX3X-zP$J{YWzuSE|4b^Mxo@Jsbz-GSuI;MdpyCiZ2!lXbtK*G2a(>yhYe2u6*i zww_rUQoMSo)^B)q2P_jUHR;CJZ~)CY7tU}S+~g4KouIeXMs5J#7ZAqAFodo4TA#sF z-yl$7gL=2wskmye3AzCa0Rq?01S5mp>=qc#j0+SZSN;cOg$8CDEaM1aG~55lL|h9e zqcgG^rf3nC6j@8})5BV3YD!V{@R%IcGpZbx;;Hb7QRcs=m zv^=eg0IsqsDq~_F@sDIcRZ2dks)}mEQpDHr?1*n#wC#(4<-6$uw?jklITAcKa8dl= z4;M}ii5Gr1da4S;>^K0lM$@njY*oAI41YZms&MtuTTr#_HUw6hTNc{q+jAdQIKIG} zywz<*7sltui`);xP-IFBn{{`$<^jMegGwury?c2okP!sq7aH^e-%%m*Lw9VNb_xVy|#Oa-eH zD;ZM&Q|k;ieZaJJ7NcLKZ>Gy{^_1Fra>L8PW~=AVdmnWlDRm!t&|L2Rbvbw}cN+Z2 zAO7Nx%fZfvJIyl*%I3YYAc-2{_zFQDQbqb9dHtsRdE}%FheRYAI#_5J4VF~iwx^!ie#19gi zb!oOfMf~~N7~MfOczcZVf?Wn6v*bPozP`eH$ga20`#00}a__m<@SfSyy?--pFZZ5* z4exzh>iwH(d%5=l_1-{$y#C=8dUaPgXdwH`NKr^nJ(7GS9+Qb%(q{lb1zpy3F&5Qi z@xThb)GIXis~7jxGY8du4c+|_R{a%~xgm#@(kjjPZ2*g~HacgLx>kpp6eKn36-B@* z_BGiq)MbL|2hdNe)??$}<6u+n^pAeT2MdBJSOeDO#b4aXe3$(;yTl#O4Lk+;4s!=2 zd3bT}J@@^<-M|umA~*PyZ?mGs78}DJckRlZ&Yvr>o%MQDXFN^!HuNxyhwioA@4DNy z#Q)}X^lX(sKS$3hhy1O%VH*`Ac>LMKll))*^Yla{V4sfqE^H?WRYP4*68>#uAaN&L zjZ$g;LtKftPI!o*nl4XIRbw9Ud?KecMI~w3hLPJW9$kzgile?5t{s=-sbR}H5rs)OcWcj=c7>rG3 zm+&Ur?n{L2=#V zN8D9Z;ypAXb4Zn|c0KbrC~Y#cL)6tb`2NKohKHhdm`+eR;M@K=$2O>AHkOqMXzwfv zepuG0+=kn5Ox{F@;CcHje&-wnE;FQuWos$fkPUtkS}Bi#z=p5@2o_+HA+tK=wIQfD z-)7SA8Xkpv(}kEe%-#UUvFEfMt~ST!^f~G z_z_?6p2083;3L0nhra4&tHOl8Dpm!Wv0we-9M(HBdB^W|nIU6SGtr zxd73X3K$&tR@~v+1Gcpqs?e%}ali-|o=48g*=nTu&bOaK82x+$l}se76v$4F!<98;(eA(6zGO8Iy)qJTC@U85lCr$W@B%}gE+_(KjNJy_wE*XAT$KE#}3d%e@%6 zJ69BaFQC6C-IE>8xe9Wd*bPb3twgV3NG3q?>>F!SeN`>$$!a|7Vw1{cDZiTh)K7Jox0xYx(Zp-}mOb zFXbCA|G}Scd@nD(XZFpxN;#n4KC7$-Hgzbt8Cn%qNFf8i;z(S3t(%OG`IHm1>=x}PmB&P@ z2R#dh9i^hTaO^%G?>ZmrRW`51+{S;|OVF7Qtuj%A0S#xB=Uj%z>eKV>8Jajk;~J z6BsW#PTjy4`g}MZn^dV#A_!WSQ0MgUbaGZD;pAvIMBUdT@fr0JAxUB}0~JaX7!-Ge zJ9!I05atn8KEKjlU!)&MnO7a_7!ChyX$xQ&0h1?1~e|DFRaOjnVrG(;QgUR zT$H`*@}azZXiaV{dh0SDr9ZkKUgTa#zICZNFE!_m7No<#CcXpfzGHdcv6bP1@3=Xz z;Sv|m-2LRa?~P~8-!J$&pNX9>_@}qeLJZ z6cqJ#EUHqcqR#{0$-1iP;b=rt!zVXcP;at7RrcV)j{h_h3?#f%tVqNP497%FrxLtyny$`GS01WsYC=mA>a=oy>pZPfp>pX+!(+*~@&H+@ zl}=GSJC#&s;_7*F6;@?YqOmk>%YiZe3VHsDynjQ7w*w3hLPJE3$kzgia|y>OA%c()T?tpNX?j69t$f}E48|t2 zO9I?l+C-6JD-W2q623Q1eK^uYY9pme9{bp;4~=D1tZ6Dm3NJiVBF9y#`qF;0{?iTT zLZV}&o&CP~X1@8&H?zO__Rn6gn}LGw{$@M~(o3q8i>oATJq2NjQ5Xfqn33WejUY>B zryL_3VpKF?`6&-Az7GCe?GSPG4Ssy?$MR6b4%7)I2RvK9VCfQdMElY*0s8i$ zpa&&;(rvg6$K*8#9GbI_;BG9+3jy*32= zcYPUN!=tbd9Stid=#64 z9&+XA8Qi=CG4k71=&LNYB24%zVns0ghLgs(vidrs)~KWYcR^8G1w~zPWU00=QBTF5 z^KiP70fU9u3fp44&o)G zIEV-ro{efjgaH8)bQ4yO5%)AsjN$N<990M(#^LywO1M}etbjE^T%#cBCql%b%}h~m z`Z#f7?K;+pQ_~Yz4b~I%DPjAs!t*IXmk1gn4s}{19!6o6Ap-F~C=!v)- zNx)LNJ~-8N{3KXNK(=+2CsiC*W4H|fP!`mV!iR4%-wKT6U1(3YXHnkWWS%Ry+^O*G z*@EC(hxYDtcV;;2%1bR~e?h9r98P~=_P_A@GJ%Yd^S9=_t;x!_)KrCk|5 zdwC_WDzvW&$8y54$8}Gd9yR5K(`N4rUv1_@wmBDQ%lX4NT4+c&WE!&i z%86C6b4?88#L(l=lMf$#m=}A@zIC@}p)K8(@n?n?>$AhF?nBFmbM992+;hOzmF`+N zlRlFTWVJ6F^HM7S6Qz`rZv53JOIH@JWc!x=UtU=8uh#!?MgN_0|BI)WbDh0^=*@MW z&(&Y}qd!;wQBM5G?5hflh!=x@K&uv7H9!?9@fC;+vgQBx9gA?1`$+b&c{dmvn-aens0#`7+UJ z!Ct-+ZGFmi3&+mm(T?MhUSacU%x(OSy?C9*&?*yE7|?K5c+O=w4cBBpY=fj%cu`@? z_S;Aglsy^Ts1;YZmyp359E@aP#q=B8Ruo5hJVbtHJHLzF1?*nBj}c?=)%MY-%=S|z z=p5Sk50rWTP>VVnQB_52DP8LpPXZ6pGZU16%ErUfKzlc$Do|BagU>y#N5^IoGa7N* zWG66QWSqExFZ4M%7M)ZHUnB@p6IW+-c{*`Z#d2a)4iWd2aBN0Bk8uK9%s}`e1qOv} z#~r^1ATW&#-cPGmT2Ww{09BdL@VlT8kDYh$`_OclcKioaCR6Zu7Y5P;=J|pkF7RnS zg)=9!b&K73q1o&$lo+8d^RsMd@oHW?WcC##5B=6}^_#us^%tO=O`lzr4pKjZ>A`zL zDO8ZWYtn(7bYN9#E_iDeK1+XgPfoGxqHj%X%!!TJ!+G%_u!(Qqn(s)?cVu}u?>lM^ zyl@GruG?Qc_r3G1@ke=I`!k_^oomSYvs{jA26U+qU^|`(X9E(LNSTiT@^*rMm@ljZD1fEKv`40`+77O(f z6J?V2+nQdB+lqn|NGxsHL6vL_&ReN9iEQ`r_xHx?6>fm75=Cz?z2f5jsq*m+dZ;((PnEqnWLHjX&>tFm%lknNchEK`S4Gm!fdp)6tRj5sEuIuYTr({y!qx_oC{R^wU%6Q}hXU~R3n zK}0I7FiOib;;mtad(OG{p8Go|e-8$E1TCNX^W-0adP#;U(~47)s)_KndP zDR|1&Q(s#Bl_`au5?_`PZk&>R*XcX8fiZ?dGB48#nriv+T-pB$c|hT0ruo=7E3?=A zW4z}vcO8uh6jB7wrZ|UV!a6ZG>M+Bne{=kEX*9t!nD+zigKzT>6xqCi>tj>EZ*4NZ zQyb4QCNde*HxKLQg9zo2Nxj5VCVq*>4Vvk#dJ!DC8Msrkm}5=Wq^E()aSf{iKVAhs zu=5}`IhnfayF>N?ZK{5uRo3kdk8XB>-mdsJ*u(GHp3q>gNdTF=$&(Y7@m<WOu-BiZ_H5bx5`y}*Ux!)9Jz3V?l#W_kAOI- z-dWh>?)vL@hS3~~1x7Uz0u%MQG0LISaednHtHz|F#jr!iXOnsiYb0u8yh9l&ho8~Z z2`!$MlCl$sYw^s4;?OBIF2ktmFylZP`vx4Jo=p>N#-!ud^vjyz_;o|olvs;{FF4f5 zm+)P1MBUJmvLek&ntIj2qmECR(KQlRj*m#_gqqS_>kgiED0$BDU5RTBt?7pBuwX)x z=Oi*PtCKe#-O~R1+y1MnHm%RZ6H5PNJTVRKsF^+S2_<9n>sq4UP|`E~!1bpFq>QSi zrAt@kzL_}*L=t#rs2M4tf`11F_7Cnm2qR9r)G$?=RVH*;H zg+11uA};e?`QsJZztp~DmbmT`-JL&Hp#w`F+G2^{QKEO`hqth6u=JpH&^}va%0e_h zQW3)TZtFmPWR2wujK$d49$99K#_fHpZ2x0+pu`T`A9!%+-k~ykG=FT36AB$xhuvn& z%bhFRZl7P}_CDqYOWff7a}UPujg`4y<%idJvCwPv+O784OVM(?+?F%dPvb189BA?urOOb)U zoGnF$O56~rrvrjtR^yf>A*C1z&W5`h?c-&@k#5c%F z*d*RHyA@yE3g`wTyc@VvyAzB|x>*#M-iq@m#*D825867(!Y)=O5yEJG@QcZ$9!e-nYRWqXFN*p3viOR`z^yrc3(BS98eE_2*WMd*A;beS9aTrah z5`iiTI($OS7|DsOn$;aYfMsIR;Q^qGIVqK#Rvb}DDQQ41Nl(rzjxU|eIDsliIpK^l zYe+NdRYjB3acRI2&t*YXS~;OavH6)Hm!ccBSo>Smr(gU}V_E*@an~`r5QFKVp7X z5rc)(*6ACgOSr;ti*pNecH7svB7VelR)XQeSJqcIq$O&N3q9sKOI&AhcbOyb3RVeh zQkDRwY7Dk`U?N?`_&1qbneqpHrAS|XbS>0s_x-8&anFHL&w;dRIL-riFvNQ^}^?zQOx|%)CeM-%Z==y=ULTd+NQO{kv&%A9<_a*}D6-U=e=&rNSK_Qrx<4~UlWaUaSp*Viom;(qE3`IAjL|j*- zz3U{X*GU=JC?05J_G!Iah6iJ`_6d>sp^}v|8cA?2gvDqVpR*}LZ@^6glREK=U|<#f zT3ifC!S9}UHa>Po*j(^-P_Un;W(XeHQDv0;D3%R)bI{l zr!+T4d{aro;UI_VhB7--&3aDY3ningn&Wr>F>$Zfy&y+`>O|q9tzlf!+upKDCrFn;dYPr!m68 V@5Le9{(Kgr@UC(wx*)hd{vT_nZTbKJ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_843690.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_843690.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160936953e497388a3e502cf7c6217eb208cfb5a GIT binary patch literal 4395 zcmb_fUu+Y}8Q)#6_m6kiu@gc>Acs?N0jq@&1EGgFxRCHi6w=-uM~{e;WxN~W#P&L~ z8%T0%$simAFFv9*4zqJN)p8eURqSYCmO+eVCuqXxGvTOtHtt@b*n|n#k0){ zTieJeTi^;?q0LxSf6XKDM64B0m6l`{`BsnD<~oqNU^Q8-mi&;btvXFjQBi&nkqvmg z)+)+iHj^;MuwOSvG!&&Vo0v=)Rr~tMD2-rfDjL9ynZ|(}(yyaLCZ19;mq3Z^bq%v= zJ)weC$NVse=Ajt&850@e)*QjCF^){^H%uLA(e0T295Vw~;QJiY7co7EeOk`IvSFf> zs>LVcNWY1dYE7N`Vb!%FvV}1R?OY7{@URijx)3op8PA!>z&v!U3RMj~i5YbY5NU}l z_T5Mz%p$`C(CVbI9sETarvLn=`=*Y@j9enAb&n*HV=x~*yFYPV%bMK=N_LxCCf5zM z?sP1k)lnvX?WTGlHx&m}942Jy*?3ZiY5nr>q23P;f(5t4Jx$|d8p>*Ev=W+V`^$jSdhThD(|W-XFBDEaXSuuQZk?O?t$nRz{+->k0U@POgk+s0YlAFVB|s1< zN`cQMd=^fGm+-BIlwUj8c^}xxnb?$o+SUsY#(eE z{ye2JwcG9XkMzGvt`|1%H_85T)2L9BUX@mzwDOz@WK&6v?rO_lj{EO=mG-YvD>uS&H(4n*Mfc5Vwh4+L1f z=4pzdEN-;H#ccV0M}11jaqYj5raLDiiL4q?ry`jN!;Fk1ZfFr;$B2>0KsJ)dY18{J zjiii7Iu;=ZKcecI0g@RpCUQ9)nUP^FVNM_|qNXyD6CXw}eZZw%U{kW=fKAzS${@T* zoMHAxO1p{uz>&!j8~|1^r{d|+hKBK z?_DgMUJ=AuuE@QRrK#az(TKSt&Mmr7Z3)oLZ6O*`vjy_OY7wav)~>lxM$(xa&RxO7ca~l zz5V6$VEaMbAuH!#@@fu8S=3E*et2)HSwyaFmztt1 z=8rJ-)F0 z;lv;F-{hBt4-0*(px;^CX@6Xz7rGw%pD2%%W%uDDw%UnOvrASlI(8S%&RlfZcCQ~) zZ!eti`e7D&AGbc~c+{~ho`eU~mVVw9o}c1pT>yeC47#isRoo>(JrG0ORJo)Mq~Uj9 z0$#^hs7c7(7{1YziN&OzHB4u(d;6~AMT?HP|Bd50tf9H615Y+S+uWlOS~kO1P5EtF?UWeuK5;>54q8IuQN z$xtNX^1%~zD?6gqhTRWj8>9vamHiV|>X*1iC7MzR3BKf8y4gyrereCOJ&s9D2wJX` zJLlYU?|IyFeSY`muXei?L6e++8v33?=p|wFW-bt`e*q$e1SC)a)Sv%S{Zx)O1daVP zg+yvNr$=9!{|W|?9=5$cMz{fkV7yJw7XOB7pe^Q5911-GT+FE;&|#W1Av6h7Vce|Ku?mF(nMLP8 z#^`jcQZayx)p?Lvbvjm6;=i6c-3?-+LU9U)hsJplF|=Ve|1It1B2zA%#rUh=VyW`I zS*j_qROTQl)k-yqzvx-5QmfFDaFAnlC9yb)VHJtt$|6}LlsbhY*(hqOe~a~+b=J2C zw625rpwtuJlr5x>H7GR&z9GMA5`HkT#u5uvMMsK+Ql&JO&C}XREt_dQg5jE?sv@Dw z6z4;eeikS+f!yXkOb+UxKGSW`=pZkRXr{0{B;vyO%ExXBLqN#QX$Co`K3kaxR*$cVno8^N1$0i zG59^MnZpAEk~fr#OEZmoBQn;^pgaG~^VGOz5GII}gkrqTV~xg=ENC2rFMyX39-^Xs zVQ~N^j!R8o!S&|$pI>;!!gxfA@IKKq#QR2IYr>%q_(3rwdnD}h$YL_lCka z=)E>3v_&SoFyw`ek;5T~2DayC2S40@_#jBMYH!I@?}&&)VgT<2H!Sccorg0MN8fEh z&dOvk5lnYw9oyo^mn;so<(9IS%g**1jWtY>BH>yAmJ8dtTV=swKTQ zQ_-5Swa(Qo?9I44Gq%on?~>iA?oOY{RJk*D_uR3CUuX7qXYAdu^i6_Sj5TRVSX5KG zXU_PP+53#~WEjt5!xQTxYnJ(Wyz4p3CF>LQYF%2G+xe7jd&cg|u=^hGc(Uiwo-ErD z?|x--B#$MIsr%H?=>vDqfD=}y+LCa`k3Q$@$wP@l$&N%vx;ib*G-SD*@t#*zHL2F= z)^t^xpQ-y@Tm00L&8hBBS5F_#*qUc+GTe{mxD2;H-m}DVQ$nIaJ)iEHzMNs3!4z%2 zbK>TSsf+5l)P?B_e`x>w(9EIP1NS=OZ6A@=yJg@(gSXE!YuML zUDpI-dFz6)0cIAF3V}D}SC{p2-D=|DI=h68y}|ZXC_BGpVzf-oX|4?uWM0=4TCZF& zlR_(|T$GhE)Y*K4+DqRmmZX~BI&S2K+|DwYqHWD_Hc{K)!Kq;XL5VH)vsWEg?#4?;J@N5o|Ue9QlZHd%E{ zNPrJs6Zp=_ZKd_Q?MNW(;{#GB5P5kASxJ5(-^Wq9a)!nq0G)&EWnvj%nVcMeWm7ID z%?J@_^mV^DrkV4Ap*ch1xa^HU2-q7Q@V0BVbJ36tf%ak?#+ua^25vAKj!K#d(AhVn zSpkLQ32(p;a6+I&Gv|@l&8{nXnl+65gWjBJTt_@3-rzlC4Lt}zo{yYdI(1O|B)Zj}DCtu_( z-0X>W#YdkzYMwcEWgNR^&t@IF2|hAYhi=_??r2 zx%h%%leUEIt|Q(B5Rz<8G~eA0*ugrUu?-ovA>EQ?x5m3)a$Dm)&*}O!ouQkT7|`1C`08u4@7%z?-h*%3bCu^pKt}jK-*vYwr)&W>^CbD6@tlh#3_C9Km!1 ze0_^OcuT%V@83<^#@;iV@ZRuV_x|0qZR|a}3Gc1%)%$nTwz2md@m@ht!EJrZu5#Jg_0)#{uJbw$iz0NrE&DaQ4v4> zoei;$B$wdSE_HOa?qT%Lv9DrJnT~iD)U0GvqDeiQre<3oo1WMo*`KnVq?XmHe7Y)E z$6hpVj~|~plc5_+_2A-srn6L!VfNtT+9%DAnxC>CZ${7i^5=)>Sz(Z^I^LJ7iEeu? zCHOGxq2|nEXj=fDxKVhfX^c)`0?L|k*e`1=AgCmZ;}Kjy%kyxD0Aw}g-od$i=CY!3 z`ILGG!vO)e!}>g$7j6W>un-N1otPw#V7nBAZpBDZ)JtUh2eSVYZCx=V19PYGX5)&X zh3bHs?x;v!OGT8FtC#WWByV{#NJ$-rw0XZu^2;2|k$ggWx$;j$S3Q_iN VZ3n69e~nYfxjpM}-Q;ra{0{_P5w!pS literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_885795.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_bwd.py_gen_triton_code_885795.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebe95dcae1665005c2338324c3cefd86b6bdf5de GIT binary patch literal 4384 zcmb_f-A^0Y6~A{p{uqyKAPGw#Uxr;JvZ{a}Jx2c2_UU(x*9;&KRd+yjCIbJp& z+FWVI=bU@y+~58Bo#VfGy>0~UTlvou{{ZP3Q>x8b5w=!9SU@t8aTJY~fAJ_bbuMNb zg$0OSf zk6~x?o)KY@!5e$HsUK67h#f3@`hx{B^l^2ChJ=p(fG} zYe-9gb5Y*2hVAyr;3B&$S4)VzXNi_joq-|TcT{&CqiGayju3W`#;cwIj`FdfHc9QO zKB15RrhITJtX29qPX{o8X`q0Ia(a}y6GR;+!I%`5sUt{&@o|OnQ8g$-uS%U`AnHeY zs7*`8SXh06^4bleQ@f_CM4@(BpH3*$twfcW64xbtI?zb*70L}=h35+OXgUeYiZmq= z^(Mt5)TSgf!VWIAF$*73qne3_;wj3>)717wkWi-<3Pu%alJZ2;0g61OF|P2}wg3FO z>!wO3wL~zabWH?9lYmc+9}bQKlrD{gx^yL$=mM=P+9SnP5|ggqlsgmC5||_armJx& zqyoGXCypLJc@%o+PN~LKnp8+!iIVqWY<3;B0MJ1iz1W3BZ}xEJa2^-kZRw9oynXKd zoTuR4Tj2Mm2THu7MjfnE*WuHXnUlGTd1q1dr-w>nL#{n@B0aP&h}o7*ORgy=&$lk^ zx_5b1=zb#f7KGl1mmZBi7%d7Pr3ao?oOZmzX>i@`$sWucgh&_XoAVb}-3I`puPJvq ze`Nk@!FM2i?kRvco;jZF%k))XtO??roN4|2rG?MtKg$m;H7%aK*RrH*!v5~!?w*xhf9LzR zV8syG%U=7~s})|M3TJ_c5aTVTZ}q|fUO)zlFl${)*}z`=Riz&o)FxvV(_#wXLu& z%+H|%JC>W_&>=^aTVxxKH0;%RT5IPt?1po42<9y67K6%s)$c2;@3bOr!&nBYdZt`@ zOV!ih$IKmqyak!$N7mfJ0tg_7E?3qYL$2FyHw{~Lq{}Ga7}@wAv<;S}!w6wCbL!KH zu;w4`In1uR|2lBO%^;EevZ`qDL%Lr}CK4*q{bNc{PZGs1hhzT!5BwwK@ddIQoT%@@z!mB2 zZ!Vu1k}m&Z^o&WQq#b%Rb~-F}BpuAzQ%hy!7?_w@`!BrGsMNG$p=G`${mYWz&bDS+ zb3M74rH=XE6@`PqV8AkMnYP?u-o4bb=qn0`(t~SG*X^m!p9!4~?L+OF^jZ$-4K6p327%%Sc zEHrnfN1g&Hb^h_-la3RGjuR_Q#g3mBeSPV3C8uld_*Y*(^|d|T^V_2D%464+b;0*U zXe|h>`Szl)H$Aw{x8$4he1YFratU)ia|C{}u8b?!bVmfS2ohUNbqIFjF=n9-B^0>e)2D=Z&ItS57Ps{DA_%kJy3_?rGpRckGv1O zt3p38i_n}4=69GZv$hXvrMU|QzO6P6vHV^AwQ;zmV-K4j?R&6qRXF`7ao!c4pAu)& zg*-dcBPJ&X@c8qIXLH~Fd%Zv4HFt+}!gA7va$@U@!k>dA)YjBhm56ylY>}`|1xSvX zu1qC}na-%=QzfpkbZ7p9Ft^s+ASyyyh253L)u{XpG6ybisH-og%E_qGPcFcyETe0z zCfc-NjGrOTKauxeXzv7&-6`olx$zAS})Q=fOk{wfOtbjo*)w1B)u4C7-6uXVBHj*3ILV^XuT}hNk zt}?rfe!2wcz(BZg0a8U#(Wj_DfjX6e)_@E9)*f5*P!kktVXFcL?2B%4Nm+xK9wjASI^1RAgW;&E)# zOwu*ZVx-_{TTe~2@+&h6J1suVqtpZ=yRNdU)roPh?2}mqO*egerR;tN-=pxdxAEKr zCv#WbNwDmqubbvs2z#iUZdxFDWgYYqU!zVMfePu+) z>Og%%^?@N8BE66LZu~OqkZ2HehTm}0^Q%cL3UKBDLofnVk3+VRHQ{M%6`7hws>$%# z(E-mT6LjcmsXJUBK&({mF=+5N-L?H5kkXavbHwDSZp$?%~)<#pHj#sjmHQk zG>juYi_sI7Z6?niW}^C(nmQPrR8o4MMq+)s zlAP@WsxQ$mrBsrXKADqyXBQ-3N#L2Trlgn({vA8=-9!C{V8&{d980Acg`|`O=>psI z{b*r0uPl192MPZCLGxe{mwg@CGwZB7*P1g*d}oR6%${CnJ-P1`#FB4siQStW+(B-D z=rQwH;e(>LEQGT|>q4;5Wgg8AZE(K4*Yp-}VQ7gf>NgIraeeo>{u0-JyZ_GdTgS`X z$?WM3UdXqZZH1PCywtw3=f>C?f9O6xP~rz}kKP%-HD2a_kR99zHs|Nf`6644uC$he z2eRijd}6-G>?yonJXdOdv+O$nP7C7CNq)|pD`>^BrG;{^t0Z)RJG_vS&Gui8E`PZ6 zVR3Nf_3IbPp|@{bEQJpJVeIbVTbE0rQ-AI%h0c`tGhm7Jd~xo|x!h>sP4jY*S=sXs zwr3j-3Xu-lb<(a4I?F>8Jxx-vD#ic{A&N6LcBUXimyv;{ZT8b$#sz$*RCO8&ls}M&Q{yUBF!aUJy@GJyn1tOJlzP7+%-h3-vej34Q^%#!a>kO zHCdUht@|AE*93-qMT5m|22Xv1422E4H%jbifb7+53c3LSSF1OHk-;{)bHi0zKZYXi zk1U{$umwYf_Hc-~;#h_7dM- z>?-pVPQez1NyZewK$XCz4~S4lG5TfddaC?JZzacq-AxWUDE!iL6pB%gtb7)aDVAH-7XU>C zUD0$Y7S$B#&?bE%o3!k6viqIHVM2q}$r*^SuII;jUHy@ z$nDUbeYf_l@uyx!&u;ni|LED`kl30Xu~9LCCm&Be%>D7-k4{GX_U({fI8G~ANp!sl z_#Yz;HM_;tM#PR0*K}OBc*vicuFTKYGTuccrK!Yn+y98zS8HF8B|rj(+m$BOg!~n< z4OcFdJCRi7bV4~vM!+g98MHKXTP}?8LnQta`TvFXKEU2DI<9nVF(LfcRsaQ>@*kNW zeH{d9+l3g>=0EOj$K8eD@L-x;K4*dO{zrM3gC%1<%!v3$t TBe?bPJVwF&wI6ct36Ge&Q zNqeV=QjfqXF<=pX;0XvI`Y8%fpiOO{4v-=a5Fmd*Uqn-=iLD9Urpw#%?&dzPCD90d;nQ-NlmKMJ%t8o+Z!!$Y?V^z-;?#AxKD6fhtr=f}JM>$pRW9U9j9P_FC z74N9OY?p8ajY=%iq_Sjro(cK4sMuk`pB{Q=;8)6Uq_izc_QKiVhHl5hG-eqzu(3h~1<@KCg5eTXI8RM?m#UfmqmaEKvXI2ki` z=eOSoF&NB1A(60VocO1)9>w7)C8iQz7>5&4jc{>2tir5L_%Z03$GeEfn4Kb=F@sIw zHB23Ap*q4`B<$dC;kig;!^AOFQ<4hS=Lj=QJleE@X*3ZJrEn2FZV=z)umL=Tcfaa< zsgraKB21F7K=gbV#)QL$snTWI7!Z?29f)dqwEgGZ&N&@V8q?v3);S)IOoq*IJ#j1? z)e>f>fg_!!HZ|P|qt1AjlF;##a(Pben4VXFq=3n$o=_q>7~a>@)qA1`W=Ne{!kzDYIeQwWiQ>t=W)4lfKhTzZeY2NW=Us?{V4&L&t3n%UhJvpK0 z8}?iO*Z#aPVE1o`Ql=r@;51}UEFW4`Z?&(B-FL;loY?nG+qbWN{c2u3W1qh7morDx zN1YmHWVt>&vhF`@4-}-}C$%dL%MIDOtZ}_D54K6|_TYUmy*u6Qv^g`&N3WgAURWLY zQ~PIc<(fNkQs=E#a#FWFSP-QpHQngEpY311m=l{p0q48==Hi=6zj4g;@3QZ$s(hVuUxe|4{*3IzsfWMNBL?a zi+RjyW2UHV z5TESZ=%dH1xB!qLJV3sepU;cvsi0UV^dnfi0O$IQ`cMKVI0+3#&p~= z)YWnsuxBoe)s`Y)bhwZVNJ}F*Bm-uSho?2dW3AZn74eezBYMJ&MQ8O{gLnZ}Bjdym zh-Jedn3mWlE#S+9PfUL<(F|DtmC*B8zVDi_`wJyS$(E((}#Kys=;Uff36 zC_Qv-c!??S;??=Zd8hW{RF=8JHx>d_ znGez*TvL|VO-au5q`S8U9QjoH?`cn}oT zH)X@0Cax#)hdOfg9ro}>u*T{5!_m8Ky}7pD+co*NU*v-)0jPOlspq5LZv>mx4!oNW zj;skIn{r_31VA;{kmYh*GiNFb4_w-QlYxxUsZzFw_S_mVCMDo>&`@8 zYF*=6@hfmt5qQdtCKeheJ1o3cAm~z0Eh1H@ek%&Heac5DcXp{MrlPcf8@sou%yUU& z%`)=A9c0Tl*<%4(sNRzx{weyP=6-?JKcBQGTQ59=_3U$<*3-`Z`J_GBdhr>o_rFl< zpHJG8t(U0v76Js8hqq|xZVB*$5^z=xLrLpXwGU$vjd)da9*|TrHN#XQw4VI=77gkx zS~ORZyDA6A@sIaRp9pdIHA?eB&8sFbO>-`ig*X?cXXB<(0iP5l4QdswL5m<*?V3RgdzM5^Y)_PM3=Ij*Tv4leTxrz+*JtKDDLf7Sd&^Sb!Pv&eZ- z_WVC`wgn{D*~2a>hM4Gs@q0^u{rCM-p@8ch{4+RCQdP`(9dY<~G7Ggm5lTJm#)xMk zW)cxnsbOl#>0zH`mzc8+AT?Y{fk}!BziRy9nbL1*6I#f70rK_`X?G&a? z1Q#NM)>+#ghGFg@`5!3oPjqmT;jcC=Hf^&_%xl|06s*o%N?-b}3Pw8~#L6!HL2onD w=A2m>S{~X)&|TI382>Epc5m13P|D5_gP6c~ayL`=AjzPr!}(y#qU750KZ99xn*aa+ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_212491.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_212491.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8900d44e604ca6492651399e967f3f4f1cfaeb17 GIT binary patch literal 3287 zcmd5;O>7&-6`t8$E|=tzqAaVmegaFiBZF~bMY3QwsH{S^Tqmfl!nKe#LW{+gvyw=W zTzYnCk?IOq2L>X-2T><62#V&=gHs!b45aln=%FakiwPRFFjWBq$;CG{N-sV2&5}#f zDS^hf4lpxsX5PM^nfboizeb~B1SM_$e&r8AgudVdzsODD@Q)zuAc6?apv6Xy7qK@M zvf`qIk%7~m9skotPXt3st1r&+(UL&KJJS8;!eWp}21>{OTaN^c5RtpDEy*N!M_LTG zqEhal#fX56NXxET!imsxZg!ppFJAr0r5|haNk1t*&1Dhl;a@OM(1x~A(tlt1#M79N7`p10aPQ8i-3jt`Biv!HGrp!lC+sLh zcp%>Au|i3Q8hvMLw+7t~q7`|h!&8@oaUxEHM@aAksg;3h_xQ#g)sc9_EqhV;BI@N7 zbMOl}^>ldZa3gLQ3)mFdCE6`*;h;mOv>uk zDhrtQia`^YNqRA5B`}3dQ6j(r6tR-EiLE8!a|^^S+)+=<)W6isl{-SiMMphBx>1&;S!Te-h#z9b>ti6 z-wWewCSA4ideRtQ(UYsXy<+Cx(3c^yaf>F$Z6li>htYUuQp=e%t3iaenVUG1FKIB{ z9BN4u!k(O%oO+9|umR0iUs@yDL|(TnZPlPTBSXIjQ2=is>vL#a^tFPbJzLkx*Q%p+ zb=bYUA3w4a+74Amt3Rv9Uv=NviyW=oEMIWvp2hpCYAt>erlPURK>3%|_iE8G_ec9m zcq>>AR^;l!y}l>P8&8!tYs#CS9eFhPaImgSyR%=0)vdAeSf#geeY?MU{Ym(^ducxs zuk=-{?SXn^$er6$LYs4C#hrcjqGd_`;O*OQZ@yc(zO%5sP#yj4*oR|J2T#`qPuKfj z-<|rK^!-D~0nuUJd>0xfQ^PLMT(XH}zKRs69%ncnXE-R`VhwNIYkUgYKz0+efZ8rh z;k1z?fTb?j{{?3baYKUDE~sz$3D^_o} zjR=nP5h9`^e}sM``MU38M_j@iLA%e_bAkgP0uC7wpJ@*=DvFzDe`# z{zmD(AASTkIL=cDVYDvHUs&%QF^WmUAl68e=P!6102t@{pHNviY3A(Ia=|QEEC9Go zt^j1SdAp=#QmY1Fc|`{lt3HaY)Vjgs1~^zx&M4Yi-dr=NW-e*K0@+lKh2PaP1>-WM zCiPIj{%CTy1T^Vv`_af@n`SSVObMCBrbFc5{&dl~qz3+s3 zbw76WY3yVzcJg;O?)C4!UQ5h;j_a>ptj8|7SN4L+=K1%3y&pSPoviBB8;?gOc2DgN zJRbN#JvRMVnch>Qn^W%GzVu3!)T9#}Toar6)=GJ$a`WzLb)uTQcm7|Z5ReaKq(-;S zmCsd1cTR1ey8FY&(jYwnF(;w6O6$q}lLwy^0*JH#)cx5d@VC(Mymf&QJXX!w4Qz8Z z@`}@bumyfCNnjf0|A#mErM!t9-bDXu0$2gAB={}7VS-;Gw?K!tK+Bs9zZHlm-@=>l zQn>{?yaij{6h{GXA@0pV0K)M6bH4S@`7LHrvOsirZe?QJN+khR#Csbfwqe;?Qnw6k z;yLG1)-f~LIXK%nm`X_W6xR;Gb4+qH3bKcWM2H2cVHGmA#f0Jn6G~8vlh8vgo`oK2 zaSHk=i<730u8{`3;$Ao%FB*LZl+Ho5{tS(ap2gzsm9JG4?%5hD4^^&L@yBDI1s+8o zMxTUc+>2j^`YXSxhX&k>&t5t1UfH}>lLr5Pv=0=d_PO&O@)A*R1L`Vs! zZ3^&5z!YDi-rf>#c03uqc0h#2E9kotEN{RbKbt`L+D25%1@2yy(@UgGqp!=DU)C4zh? zB0=?L55`nHRJpu!b^Gc80`cR5*X{N{8h$u@$Y%~;Lpb`CdItCXlVH?yydE359q~f> E52Bd3F#rGn literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_254823.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_254823.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8018d85b1098f095b9e4ac1811b37f21cd2a7fcc GIT binary patch literal 3485 zcmd5;U2Gd!6~1>oGamoOX%dpO2`z-}#_YCHerN!dbd%<%yW0ezE#2L))iCRsBu;FP zduN=#9ZQp{QseDI90?S!giw)?U}+_agv7q^01v#tLoHXVbqP|X_9bs&vq*>+ICt!E zWIrr!M+;e@tIe&~qf(Y91Z-24yeF>qbY~nV)&K&$0m>r}d6{pZ_jpAAC z%(=90md8lLi;f;wv_@4<;}^r9U1PI3PW9d9@7FhG{VK1a#n^wIQw1$B$E)IPem2;! z?!S#@LmbjV4f*gIj!Sz?e1`hRu73OS>+(!O^r#bn_d!2+h}lI<7;9~jbGbvR;`{ae zyl5dyw77MUUxx#EzlA?Dmpq!>U>z!ga9P6rx`lke3Ru1+nB{{WDX*7xNlTc=W<$#& zJj&7bfNN+Cd2}7~0G8hhFwd=^>U+T7XEA`54_okKx5|6vk9jC7)QCky6+E`0HBH{p zida0e(fv|Ps#e+oS~V9;~4cRi3Al_61=CzF3XOKjnsim`hFzZVV$QhlaWpLWmGyUhYYced?r*cAv zn+$wy=<5T6utPg#_xRk3D)(m>i&MpK{B-PRul?k;$H6}P@@}MQ>&?QO#kcRBD!0A-aOflBapW@C2*!#R z3Tz*w&KAxVTkgJOU#^CtKRB_|zTIAGDH$JhR+@Xtp&onuNmE;iFXZhjyS#Yk%Epz= z8>R1+BfUTG{iyZN{Kx@Z0#Uq5+JGj!na^^}!G}kX`F8?4$U=(^j9^l=<^oQh?R(|x z`*{x5kHvwdafHZe*k#c?(o`mOj5d|)gp-X-|1UP*0m=6Rd7l|@)gA}LRwFC(JXnjk zXAf!PZi4EwgmuuG7tJPj-xAGcmtQ>La);FNxmn-v*KheOzjJTVY;{lG!kd9Ptn#;| zykxe!tCqxW3Rug0W9*Tp3a)n0i7h|iLmf9U$Q&0_i*qI3VED1e!n^k`Xi0=HN_gLv z@h#8~7<~Q^@1O#n!ytucK12*)rYEkAQao)i6=!U68VHpxrj+!%sxq?vMQ`$)znIb! zki8?o)b_s$rw#Vz|FBW1`2(JO3HUmiIq(sW&|pH(n92E^o-+V;7L=?;`HY$bBw9{t zD^vg=N-WSI$#sqT(#Z@JYpG6~GTN#sXZ0100OZI6GCry50Wo4Usx`Xm)XD`iQ+ZV&uCzJ^euVvyKhX4*HDh12mB*(*fLn0 z8p!lYEg)fU7|1u_L>D{yCm7)7RVlC)E`;w!?a^vDvNcthvL~v3>CW=TauMHJXKRy% z$>Koq=FZCYN{MX0R|)rQ;%X?oHC!0p7`I35o4ZovW2vJob(ESb(kXlNNlWX_$o5F3 zrN_QnjkT4!%CQ&hS9hb$#rTh|f80G-?jC$NTIs$}i4NOSU~F^v)_c3rjz?0*(@>W^ zzRP!))H2@#u5I=+|D=LcY`z-;wkn0HvG%`*L_yq-A@J=&;X?6+yWe`mcR0_NfJ-$R z`yV>U@q73m?Vu#e|Ig||h2qEaeLn^-;7eU!&118H7}jB3!dZYBs&EmEd%mjdA9|#z z;&WN^d8{=u8W02WQuTihYgon4aW}zZt&!nSb3xUmN}t0T_gr`Ld#w3gbjV9kcOW&u ztXT-MJ~*?-Lchnp2$WCcR0S|VRJG+~LL;n<>L5_)t3yDC42JQI)T?D)%yGx_ z1t&9bXbigX2N-PhBpS1)J`JN_)7GiNsp5A^_}76rD)Mh`wz*b z*RG1jw+=A-|1MAJ2jr*U=c zNqMIQj5Iu7wd1JPDDh~*eCWwJKAQ0m?=4}wzA)`45)m{sANfy7;?n{%0ugWdrh^Ut z{#$5T_8?7eSO~4+XkeGCPci?+E8n^FsyY=HUFLWyUg-Nj#yk}hzO^>*ak$5n;O+W) zPPC9ETAnQEXF(w6x9}5V$z@4+Y{B6OwtU-l54nI9u)GT}$^~6o?kMYOmfs3+Z&r|a z?s~U*+|cF&HR}GFjln_N-^fCC`8(h^EZLHIh7I`;5nLMKhLZ+kg)D(b=xAhR$Vve_ z3b&IByR;8DR7JQHDt+YNAPm@!aRIICyO`cRJRNmZYjF*MVtM#aPx zbt#rKn0F~gnLrJbFp10aNuvp@_CDGEd9{C8rwhhnEUxv>#o`Mwb52j6jLm9kv)`cc zep5>=_QR+@IjE*}no=QLQ%?_^T3l6Ox;|9nI^-}k@Y2iQdKp$&r|O8$EE9ELF=iO* zf=1I?l0FYX_@NuWf{%^9lu$Ubel35kI8au)?eR*axiDUQr_>rPMWXgAmDaYRxYbqa zI9X~vX-_^n)Vdknh!#hR%jHAI?N=-EbH(QTd3*Gc(zO0&{>{SMTSrQ5-@JGFW8*XB z(tTyTq>TTb{vrF@Y+3o9J-!nRtsl!DD?C>`w$W1#zF=Rf$l)Iz+HBuwFSZno4?4@u zeI>ci9(~Z%Ruu9%`|^$;-oAX}^4hiH_e;u&4^MpD`X^y{A0kH#uh3SYad#O5j@62t zK^fl;Y$6NIHwvC>)tZAIIlk`JudnAkFn?fCFN`8YdtjAs;Zu!ptpltPPbWw=z~I05 zdk5t9i?2nCe8i}mNys1Q*iM6t4O=)e#R2TJSjRNqfgDm()ESb^9DAj4Tw3%1~ zv}lm_@>qC@&=>iKZvw-sw4c+!u2+4Z;aX__Vpb}xi58;q>f^9Svq2mJZNd%9Nri%xAfUpDv`EgS1EGT zeq|@zT!{YUjr-k0rS747BjxTh?jk7( zeGuQ;01r~iONHh;GO!g%u0-1ZCW}6CFM=SzGx;-xqj$djsnAh9!#HHtX#CIQhq?bd z_?b3Q0_FbaGlIZ7$8|*?fF*Z@?fPsEn+-Iv9zQ@l7T^i-odd_NOO^koPcHzJfQ;Gv(%YdGj~CN$~7IB*48{2ueRV zwad4Dm){mlh-XL)fIuYLQX;M~A2C;6&Lnzp^ek6aS@ZBQbI`#|Sua%!Gp9O~Pi;n_rg?Gz=PW$46 z?jCz$ZL%aB{{NBQlTfI|o~l+}6wki@*I#b^{gZpcQKh<Ep6zO`Kk}W>OoXDj>)yNP-pcdp+`Z#hTL`ZPw6C+ z)Q0Knu)yyRLxs7hd4^y~q=) u&VK&&cHCJQ+q|-IWe)+nmD%TH{{lu>`66@*w|q`8YU(M6dvD0q0R9aj5b7}i literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_336206.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_336206.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5dda660cc5a06f78bbf03402f3dbfafc49cccf7 GIT binary patch literal 3591 zcmd5;O>7&-6`olxcbET?71cKF)U+HZU|mbr4{RBg1<01;w2>J&R@+pU1;brYq)0A3 zyNoGy39M5C7U6@albBO-=pm^Lv;k7YJr(Gohvs62Ml4KKz<_npO}2a}dTHM*$3rJi6;9=|!I$UQXn|8aclgcr!n~;X6kbKyp8wQTysA$Ty00w=ig?F6 z?{{!XchG!*Lu$aW99+h6|1L9|pyH+3cQ3yqCsKmT4U;`k_n%|7MFZMa$#Na(pjYr_ zd%YsKH2DLKj-+7IyVJ?kcqpTK-h)rkYmOX%&-~iwrdk-UvrckA^SZ?bc7BPQ@Zc%oFhidC@v%jyQXwU zGtN;ypVXJAS2Gq>633KJmNR-B6F`sRLCP69>Ms&4O_F&zqfjA9l7+NN`J9$iU{#~y z0*J=gDD~*2JhNsjQZLbO5`%hmLnEMTEEiSkS95A!Ef}(~9FI`^KE=PFLERu3MU~4k z(QZ+kpdPiT6L#@bgc)#HrhH1v={6iHl;xs9bc)N=^I?)up6G@`@iG;b?FBsoU&dZ8 z|Ck%TrI97Qm`tg|i^I=6Nd8C-sb$Ll81vN(o00g5F{V&j1=!JwLy=(buzCPUy zf8V;g9SE&ms9vaDuZzt<%$nH>L~19i6V}X=P`GxuKH2Dfvk`jpvo{){A?vMe$-gF6 z#hOr``>gMgbmp;iwjrJU^3Ye&FQQFp(wg4!2iFFxgSFn;_4WSx^+)~_)@A15eD(a= z#p=a+UtRzBXftpM{7AmlnW|(>clE;yci+DK_Ug}S*Ei?cjE=0Eu+f zSMX6#a{`oEWk2V?l>@R;xTW&|vIq1!jPBCyaAk$r=|U|Bb}F9#qT%U;SA_?SfXjGW zbbNRf-t?}(Spb2CI(nvHu>aI9-QEDxvG42|2-OsqW?)wVqp$O=;_cx02pj#K*=}@T z9#E5H@(UO!m>sEIOdk`0sDZ4zo~wffFkn;MY*RLdFS^~L$#qcOV05$lPbj)Wz7j@pmr8L%<-5N#G8Vs4YQg5eD_A zw1SaIm$Z^jc|EhD(qM;P?2#v*DNvym$uv?>%Z6OkZmC4p7UWSHe6Ivw=hZ8OXoPx! zq*9B(Px+#;Eax(Si-mC9@A}#OW!VNG0}&pEMT)PHGpr9^n|@z@>!&xa%*Z!>Jb%UJ zE^>mc@M)knon;9JghWoVm7tqD$OKG9q3S!(SZJ&F(8l5Q!`4qAUe|n8Uu~e4+(@sd z>mRHyHKhS-dP|h+s3Kd=(p;ojQNuLmCwOf&{29`rT`e$)(4T35lt>iLg;xgCx^ zJa)Mmo_Q$E>`0-<(vgOAq~6<jGTgvrl2~-$x8kNKgtHmpvwQuOcY3|TrV~6*m4AT)7!HZ zY&b+<9^^Ef2}GoLC*cd%d+esOwk<_?C2t;=Hz#<2Du65$@jH0K3VwyXd0pNdLZ}Fk zbs$Ch4&Jy|+MDR|CUycH@;o39#mBsv2*Tc<*kwn*%bo*xze`FI@(e*yKg^_5>Q(G? zAYiDvA*YhMDv#{4L#M8pQTJfXJs6MkWRz)pArDasgt4j}LUEtX5~ZA>Q?5Khxn-#3 zQRtzT$DoH=9*2HhXSDAGWh*9oY&fnrog_eM3ab8RXe{(3++$sR5k&sT+Og^}_`&g~ zgI{{T3VjiJFV$Q zzB3gM*=I?HVTUx*V#}c%yBo*e=g3dV9P6F+V zun&#~fU)5&>%k5AUFY|O6D?jaCAE!o`z?;nGI2=23) Z{YivF&x1p_@0o&8Io&>t^#2rH10}s@=;-Vo)k(!shr58Ym7x-rF z-RvDkT;CY&%zX3B%=i6fX20Le?`1iHApPO~&!@5?LQh%4O`*DQunEE{Qjv-?XuMYA zaqP6YY+#(nNW*EzjytthtDMHCrBCZvYl2e)5BT+Z$9PB$tGtHN&Hp)44QgRkXxuj; zsG$eJ@rZ{@e1OKI9MYnmWoaJABYVv3C=Fd6`Nox-iP5Csa}xwP0Qulu%$8_ESxafo zl@4nauh;tvf^Q^$IKq`wYz5cr778H?S%SqafUgA@P!O$<6_|$JLfB`;kE(9Qf{|>7 z6;`>8z&gVTDIKoW{?hasjgG-X9r5{Va4c5DiZGNa-rzm-QJ?LY=f)G#EuQEJR@CAd zZqKg-mL+DnAo~!sdJ{as!v3AAuU@vo8$oyV7|o-&FiP0wPP+5r9Obi$F-?QIIi-;} zro56*8F5V5`NkuZGczV_0g;|NuS{w= zv&SIG9#hNC^gyd8)0fEUB%1(Ork?A)Ff*S3b-kHL>VWm)mj}Mm-w!>sJ>d?Ym{SwI zGm2p(rZtk&GNcnQu#h(%L1Cj^xoP>0;v41Os(i}6w%vMkHNFxr50&Stt!M3O&)c`QqQ}ZD#aHd&T`{s8Duzn(TC~#q#YZn~CZC9xKNhc6#H+u3^P}{K z>8dzp4{b-Jcz9#+#?n~nSouqpmhOt&{qD(%-2X_~RR6?ZIRMOvbg<$W6mS%pIZJX$e;5P5 z)q?#rXjhSi(q6bjPdyz5xynrsfVpNER}VdR469v^DKI+?D4@WB-epkVM4*O-<5k?w ztk&`VH>=-;{z9O@gN+6>;9&}yQP-s!u!0LP%amPBxR=fVR{&Lpw;w( zT;(nP$4CuY0+hl}g6^#Q*gWFdKk7<{wH2`V3AQ8ba6Xejd*U-HXl{hukOu+{OyGiO zo^(f9BHJsVRiM^}XDGV3U=2Vi8{zdjy<(Wt=u|AuodMoL-vrJW-TN1$J@$vB5W;97 zFxqu);ds1D%O^EWHM+7&zN>zvbe*F(4kQ#*&G{LP3aWN5mDC)Lit~hR2WevkTfD;* zKqf+sMrb6d=gidPtUhZ1(@ZHapTC#V=BT7)v@8UB!bmM>G>}c@s89=W+LY7sz$W^f zMu27#eKhhdB{QpCB}6AQ2yB%E@&bA==Seqnd6pF{yyIbCh{SY~noQ-COrqu%-*QmW zFmvp9b+D3CFp=_;KzauI=RB4or&)*OClmq!gN{Qs{tktWwnE~={9?Yu{lF^S{bT6F zmLxBa6i0qIx`cPyj;+4F^13|=5xy)Gg(XtzT^(2%DEF`%`=I`jUy_s%^LU$%#K zBzdXd9^U3pmDLLWB4A@5K)thar>wm7ovlc8OKRF`KK?g36cqMbQ8?m&|9(|G?Lf9~ zkS4 zeV1;XUp#CesKRr33;4VRyr=^`A7+RutkL|Z#)8}to^f&r} zDtg)IAiPb~9ASLz#8W&>Lqs!XGw_(;^1YOshn(+&8gl+3)R6Q2Q1>?q&}IjY@75I;QR7>@{{l-`|@tMt@LI!+-_gqIn`-j zU%FM{JN|!0?~6z}VvjmO7sr#&rk*bS@*lgG;~j(x#vk;3H2C4*Cy2=h0mMl)^|R9* wST0>%9a$OKMI;*j+#i)bk-{4g3m;e9( literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_392963.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_392963.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7dcc4cbc5a5cc8a346fdbb3a16337ef9f4bdf14 GIT binary patch literal 3577 zcmd57&-6`t8${&PuLvRdmWFYF8ZwN(T( zzICXXdGF2advAW;H&1^RML&Y_t6%VEK0_PYoo29%=wF(mq4{J+a!(;Gim{amObUA%%D>J-?Pv$oH z4dy45l&{r0Tis?3It1HmzX3nyM=9X2dBE`)0U2)!_6mXy{UO_wO{5Olq)iw>LtyK) zZHr)AG;N0r4Ew?F+mEorqfT3mO&DP#WcW5+)-GVQf}-viVVk^Q&5CkV$i}rab*cKI zLZX-o@nTYoV#3Zj>ZhEZq5eEk7f3uSC1vW4lXz}Hp+ZKD%P^`^&peR&=m_PtLY8UM z7paSAOGKwGO;?E$?V$J_if3q0(@9cRq@qOBWr|~zSMnNRheUZM5fW-fqwb=V*NJ92 zKs`n2y^NYjQ?3XG-j5S15KWhsZ@3G!Z6um%l31@(_ zkTPK)0#YXo80a7%ur9TX(;ARN`syau;caP^6g!4}CDpS}04(@?p?o%(qd_I3fO%cglB)`Jw=jZ+b4pQ{@?eul z>bx{U{nz7}g7O+6Du7g?Ajhd&R^Cr06cZJwy9ItxHyHr0T7iY%BO`>(1|qUEE|QYeF04wZ3jYYOV#Dd^6m5W;lBm?w}h#!7pI!g2KpZdg4I_muRa7S zC|0K{)3;-5mKnR;Ki0NJ-rBnQ&emHqTQjp;J?}O{ zb6eiIr@m0Nx6=E_*9*`D3eVm>`*+_<+umdA!|N|KygvkoT7T84%2e&@?Key={P=L+ zzeSJF|22>N$9FnWP%KY<-L*%5qp?T#Plu0B^8BB9*QDr40B;wzbL_Nt{|NcU9h5|+ z|CyP{kWe{Z=>u{Sc+cC?64q@1GWK#B&fp9P-Sc4B@fNe2zOTsc=d#8-tl2P95?G+h zp6_4{%lJ9g=5koG;iu$=j04Hu?_iC4uC;j_)@*=*+|CQue9W4Gu0i4vd&ai@8G}GF z3Qd$FW6)8Ttgk>A2XsZ#r9@m)q+xP`4LBqtPG;1}j70_VV^a-eF>;PDyf@=2>Z2aO zp+ZL2Ak`SA+{!RvPzl5c2u7iYS{#FZOk?kfoociqwZrtskr$BHVbTOt?Jv-j(UVX| z`N~&8v a?oSaGzYLzjU0=u;g^xBvksATing0M|Jo-`q literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_403404.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_403404.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0290e0a6b73b7af1cd358892bcc0227f28c126c0 GIT binary patch literal 3782 zcmd5;O>7&-6`mpYkN^6$E!mD~CzeaNu}n#I5X#Smuz4at|RRH9E5eGAO)w;T)W3}*c`jk z?l}%46(>zQ*3x#bxKu9be{_zG=3R>WI(Mh@VNO&y6(xKAS5M(p&pf9H*SR^bjV@kC zb3PYRefD~P5l5tlZ22q|&s_Y$^!Mf2xZp6yWH;2Ee`2-~9olwDx-4n8SMZ(A_mbey zw1Tm#2;ngJI~@xp(GU&y5{yccfegWLt?&q~fOg4a;77(@hgGMS9m|Ht5ZO9IQe2<7 z?=ZiiB;jx#FSd=9)Whhg{(AM@4L_dDzcZ3tUP_S&K?7=D9eKHD||U^T$3 zmVyrSK6`;J7~alqgWYy$;N*aOUJY& z%4hXOl|(QBFHyu#U3!Xob0oV!Vre;{P$5QQnFW<{sccMvuURV2gQ$;>QMZ;)Gb{Qc z<%#w#(J8O#SpvFxF{e_mno`qhMwj(sBt-F#DgFudYdT3Ns$7sscA4T?>Q-|aVTVfv zm|=iiJe$%?Zzx}ob2_1X(d=-6y063t<%p&$6c@D+EDJwc{&RVBIZKwbTr94RF2>?Z z;9NHId~5;y7}ZF8R9DlvQ5cP;#^g+vq-F3&&t{@0azz=YJ3~311t%v?j!&GNfG>1N zw$$gB6*-!VX_~yGl8l-nhrtOEs`eLX%IHBLRDHYhc0Jk*94xh*RUAr~Z^gU~$ zdAzG^Sy6DbeV{I754(}j`~adTfvE3;xj&-ouOG2&U!@GfY`b{M{93Dj6Q2s zH-0w+%Th_y*=03LH-?A(S!632LbBh|-R*5mkHQC>kRlk|he#0(=|l99{hzsp^+C%nV0AeaMH&KydCXd%Y!S5BC3vhTlI1gCVjvvP zoequTC%Y5kSpodNB1d3z^fg;_Y1XQ;e1Crye zB-CZag}K7x;hAxgqaz?libLQFts15HU&W|qn0S$$Y0I>Y`J#$n$ha(Uq7~XtRk-B z)#z%h)!VneXKha%*Qajr^>>@ShbuzM25;`IO?_|x^lQg!vGs+ug|*|3RTbeu zaNp+8$!0LRDMjyx2Dd`tMkw429VuV9Ck3{oeGO?}eP2^LQl5Gc9M}pz*9bnhaj6+R zRzBYvIIuMkX$(Y~1INoVtw3Kb-3UAj1S$q=gV)|~_3x|CH2Nc};(edLCVnuv$xYWM zZ{rQ^_UPSXcZW8IrkmXK>f{6NV7>71d$-=Z9cvsqx%tdkgBx#2JuS%tZ8cO0)p1iA zHiy1e&)$CrL|)kLLz2f4jZjbZO65u|R`0D(-F$yjdZy*`S5H+=)ef^L-F#(}8zx5} zlyRtbv6HYrLVVXz0+qfs5mO*fbR;8PQZV@0cVAoS_z7JB@tICub7#TPggZvj>j}rj*_EH(J?dT9E>{$6A_Mlmud5m zf+-oovR3nbMm$s`s+Ldb8g&(-)K!F97=s>aVH|p>g$d{YMzvLCV+yfN|_|9Rlkz!#p=o}u!Y`v(t| z&#%rjxZ(dF?QIG9`^&TD%Zmie0Fo21EgEVw?}-%Tvw6q}sboon{rssrnb4^Rm{il% zLXLDO__CVOfYo_3={2KoMvMB|2OuwGQ_8QAiPny;G?~sS`ILH^ya`s>pGW&CG}~^B z@t4T|cNF*sIso~fOS&<9b$Hv=gJ1dm*qb%amS5ZesgkrB#igh&aGcuySR-& f+{o{+Zs&&>;lNk^6S)6h3Pzzr&0zSd&&2;PRPP&n literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_466457.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_466457.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8019039a21291cd5a7bc346eb0acb8ba2b42d630 GIT binary patch literal 3577 zcmd57&-6`t8${&PuLvRdmWFLtyK) zZHr)AG;N0r4EwQeP3 zg+wtGVugej#e|)6)K57*P5n8dE|OSAO32h5BeCqFLWQ&%lVMb)o&_NF;UUUv`3%#h zFHslKmWfVXnywNh+D7p^6wlJ2rjvxMNCkH`{y59mrJHvrr~dPvHuBqM=;x|$t4lUtQQ++wA;3QmulJv;L9D2&i9$(p{f zB1?n0n5Ic7g=Cd9IRRdHplW}Hri8v}M**?YU+%BrhQGTsu>;&IM^qkg2XF7|E6b@6AVOOFEK>WT7LY4S;^t+H5NtS)Sfe5^eP4L=N>t%uHjLhi5K zTWf^IOP3${+Nz^~?#+?Dq zLdt}J2uPhUV4%Z*z`E2jPMwDB@w|>?&cd7<7`|cG?Le`{wKf=MyLzBl<3)71&3OhD zT#?u9@INfnjs1WVbkQ;)^Pn?qar?bxa}9Lsb`U2I3qRgkvkeL^gEw5MNrXtd-f8KS z1%v+x$u7eMt^2nw>#G}Bhqt9!lI$1`lw{Wd0kGinvF}SQjV~jN^d4*6;1)1b0Z=Qz zch6b!{a*H1UxMEN&s*Nj7T99;p25#r83r9!Q6Cya8eZWrngvTuS%IHLEAXPsJY|<> z?ez3egaulp87~o{EawwM(IP|rlaWL=m)9e@lGRj_K%ijl=p`ivH&2Pki42GnYBn-+ ze(m{4Iz|>16494p*+>TPDwEGdV(GNHqR5fc6h}>r0}NvqG#aLSTFC+yEyZ#Q6%!i7 zb5>6*=GDAL{SMj%U=?4YegFl1RZ1sP3gt72EDb7Y1c19C=att8Q30glc{xVivhsc+u9&Dm-A(X|x=9~+)eJ2BCSmu0;#JdCjcsoe z!$^e5gHEI&V{8>{pPy<#8|Zx$2v%M#zxn{E zpjeqMPv4qZ$4m&QaDB3Lv9!G76(4$!)xF1RM;hLq(#0pCBh{0&!QWrFbD`1mV!i#v zt!oeaC+htZf4SJ`ztjj_2H5a;*GF&sdM9*jyJz4o{#e@{dTaa2JKJx~ZqLqbcfH#P z&2M|>pZG$R?sE4-UpGJ#C>*;p_IKY4JKhr;gBvf@y*~wqntzq4@>KQ8tv5_A{OD-U zzeSJF|1FRFCwJRXP%KS-+p$l7r?E%()1xORdHyrrH7R=Hz}toG96Rm3ze2uo8zoTj ze`Y2!Bvg)9`jDIi-t(5Uh;nXK^+Yc`A&1s15X z=LcBBGJb}&xg6GP_$j&}<3O_a2Uz2tX>A^dH5*_cxATHEAG2nltB`oiJY`$|ltCaF zh9=695$LE()>k2n1G=K=Qaq+9(jYm-1{{(hCo}A1Mxp}wxv2)S7&%KA-kWh1^-&Mt zP(H0|kZKH4Zgr3_s03mN1jEonEsQ`vqOteHPBog5+GhIW$P38pFlhp+_Bk{q^f=U3 zy8KNL`NNf-at~wxc;nP3uKVIW@qzC`Y5a+=v-)1c*Hs#S{M_-<<@Gn~LihjA>|HMk zc9dqwPhcrg(cD4u63{f<1j93F$O-e1S5dD;5q5HvPbPHg1B}&lrH~^nBz)V1c7&-6`olxxx3_&`m^dtj%>wN+%kYH=qHxj+6rXLl@ePj1gAx;8dxyg6-A2V z%CjqrQkTFwFc1-GLF!~0C>kBzDu)zEQTJ4|DbhnPCYV;jR0RaM7aa;HkpTm_v~QMN zYOShheCv=q^WK}8_ul-xZ{DB7;UI!??9MOedPRhuFvcxHo!EZ_#C@b96{pcmt;I9g z;kk@&hQ~<5DMybxwbrVf#;4@xWsIBURNp;*quwzisFKQSDAoL*BUQg9sez;WW&^5l z&p#7v_$A&$Ga(LXp@wF88ApS=OzjjEE?@utm21jWJm4|o2igbU{?9R6q6u}aq&Sy4 zXjQyX@6QLUfW@uA?7ZNa!yk-wDHU7(jk>P9Xo;3@9=N>ZndSATKFvawumZFtI<=&7 zANe+zpHNa#t@YMsgE{IL8eRuIR*yIqD`*ATni}In9=(xjGcjYH`gLY^_U91Iu&MtGzSN$&sfYbE#@5c6XP~9soyYlqD3PV zzeDj&DjOzAs+y8hh`vPe6!mEf24P1*1(1OHT#e9YXQ>d^(@f$xLU9h%8@DM>3{$0e z*0gK$hBB{_td=Gx!3TEXjE7L!=xHlz zYAI$4neu2ga?&2(inN#IN~Fgg-;OjFlZ9j%mxrs7Gtd)iEln4`YmaS z7Ahwy&96Qjd*u6Eytpa8Q4!zxBloA^?}JtGnmxK5l#6EzXG<-m>9zI`dhILQp>T1y zFkHM;xKwT{8xK0Fp)>Z_mME=`6+{@%2i~8!JFz-lI$roudAJfjy*~PI?5}+PKKP1A z2dmnlq5-0rizFNO#wd#lCz6kW^ggms3IvBso&$M3j=X&krJRbH0XJ4t>@FQDF0ADL zi}qVAvVD2pbb%A408$SZ>mWA#D)T3Fcm#M+7p*Xj0W<85umWb2OCJT77VpNLyNN|g8|oBG zWF(-PA7-7X3q6h6@Eg1Z9N4IPgV|#PKX)}4G#qt)7U$Z%fzF_QWZ)$ZqnqFZMnCDy z`EH^m4EVA8FDOX_!j*4Lb&qt{Uxesd3ISkxB{>Ib}t&(R=6s=2JBlk*ytwX~Lj5LS%jibj2zWR?bMflHgR zTFz7!^d*fb`m8cQ+op8$&CEh=f~H<2L?@JwFRC#bP_;YBxaPzx4b*@rDtY1{m{T0D z6g@%lBw+_fUWSTFlcVn_-}}Lht7A2&;u8R31`Ck}{E#j-Q1&7P83GYIDB}rKHri@w zy?=b|xcwGHXfaR-tdi1q%I~g?RK?SqVsAz4t%`m2=vGK(L49}39<^_8i{VZ2rHc4c zxuq(eut#^A+siEW4=+5p@E56PTYP1`ry_oHOAgom z#fid1Y5Id}4%f6Lw>=FDQt+7%1z+83MRK@!zHq+uS~b-5nD2Vl2GYk;&sPR}1V81@ zTp96wI^rJ_ggEex!tNl$iv2%0xj_=;|7U)nLSo~2Gmppz;7zMb^Y{QKtb@#gvj7uR z|3xtDc`w-|JXF-c3t97dtTm!P&qKT+Rrng#u!>(`ZGMk6zZ;uQ)`V2?Ygpr6Xl;VW zn$WPuvMr=a%$kL6Kw>bp%eH=(VKV83D$0>_&``f>E<M-@%k@)k8NS>AwD9Eyp@lCu1ntmKq1o)X z<9W!*1ZchlpYa(~Hrk0a+v87V6l^Mkm zmv_2O+2gB|6~6QTXY`(kNxqfmNfuE{Z5MK9E7uRG*mkaC84L)Um(YD4FkoC z8C_jWYa?VDw6YA}fHwl|`7p*$ko-3k{yTbk2Mh0a-tFAun(@G%j37C>Rk-zK7`S~O z;$%nuMNd0ES-N`v`r7q91myan(>(J3PX4XD&m{X75e|PTpTlims2DY!sz$o+h8$P^ E4UNCj!vFvP literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_599125.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_599125.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8dba887eb2280a40e005685ddf21676294d196b GIT binary patch literal 3443 zcmd5;QD_^<8J^i)?XIMiWXEw6E562Qf^#C5ICh-e861Oix#sTTb20TDtqWar+O;iN z(kip7ST-xzlt3&3wftlqZD}6w>cI8j%B4?*K9th-A*w=SrY;1M7r#026-r;)e^%0Z zuNCz2)`6M%Xa0Zx|DT!tzL`HpA|V9jPV5(pQ5m79Y|t)JOW1z^!Y0y?hO=m{+2c9v zj=3B^Ct#%GjO(ZUX|vZjUC1b3onxbUPUCM2TP?+$qy;rWN13kww$uW8P!l`r=0#1q z9heJwdC9lYT$n?8*z>Hc;6!Mbxt*rcm8tK%d0m}OiM}vQ=ArJtz-);Yw9S&?+R{O< z;jNawEcz^k1Bx_}RF-XIi#E3!Kxh?AZHcwzpwE{-x0W+IXv-RR zpWkA6LdgW1y}z}*L8oKzvJUwobvPC~WQSO88on=h*@u0;qh1(K$aH(6C)iexSXN7_Fd$I8E3w zEwn8Y92Ig&bBP8FYf&c&Ooe1AZ6+{*>`_9YoRy`a0x=dyGN-0BDke!Xzo1hgYa}($ z8dRDG(Yi1~d9#>f-mFC$Am$BX(ST_g1Z=I9f=)wvR?q2qOSM)KF^a!O@%O1>S|qLM zYDp!=GR4!B*9#_L2TvuKfq*3{q>QZTrbES&TCj*oaf$LblY|Pyv^0uWsJP-P%ouz* z@YV9q*`Z~FESZI5N*`KGrj{T-BY!rz0I3X_BsFB|xxx^PhO#4S-XJ*@QnHNv@VUZ@ z3e&Bjnld2Ii(_NMBNstK`_#7i{IaGF7m}u_F6kt%XUVr92}UL6AE0s2mokcVujeYc z+GHd8P3PK9II=!g8LQ6Jq(*qaxw;*WRZmpjaIQXyM5{+?Q zp>?SuRmIxu$32havybKTb@}|~hrWn^7H`Po&g8R@vOZWDtaevtHhOC_k3z?tH(7+& zE3dC#u3WD5)Xb0i8sU==M-HxCt;o(~$2fv`=j~f>uYFLR*__>&t)2M!;75ax<7eve zGmYNYemVLV;f;M*4I=%lISLIGiEb51KGlLf(;{HhCcXUBr577zJc4ocZqsI~G6KE- z>y}9s{}7eAG7nnz6aGg5s^LAbG7oLm_StD%8vtbiEIJT~mqS3~?Z7Hz0TdT*yXeH> z5VcK)jNrtjtFRG=EcVB_MH6iPCrAs}BDCVq0&VLbVaub8Oee4((1vUa^BCx{x9!so z_JIf_3PhSY(w;r&57ue(Ks3^QX{$x37;*4brUO;m+>gH{TCFM4bli+-7|=^8UDwQk53@#UX~0ac>I2zyE^Q6GN^t^+fuR$KA(S%mR(hdm6ivYSViFK8 z+)V4sRA8S_rHyuXBsMWcV|l$~sRd(MC#o^8j?mBt$!t-7ix7j*0FXs$5qKh3uvXM; z8sNmBM8zgaBti#2A4!0MupOcJJ@OjsgSli$b#rzzB*)pLumJR8vYYyV6vzptQT+HN zmtZOm)%+7Q4g#pGD;4E#)S2Aw>fV$#q+9Pe@2^d=d9fm{k?Qc>2W=VAO{)th8H|TnWVs?& zyYGg9NBHfoBY&5KfVih1<czE*he97Y2@6U=^@{fb_^k`kG%EkkSAYi3 z7vHVe;irQ8vo%}BmIp}t;QheaoDJX@hctY5UkC7O;!8#2ebEj8D%b#fP5K(qu!dh^ zZ2@0258!1HZg`~0Un3g#QfrfZ(WEw=z?%xlKnt>HHUn~KdY5heE~8$!D~nnZ?sQSp zZ>Cc^4QTGw8@6=QQd3D&SBH1mrc>X{h<|XwKNw92P)noGkD82pJ@7TL*X5?;yFbYk7)?Mme+P|&oj7dQV15k2CEe zFM$`HFFsxS)jyt1BqHu!lC!Wv8f)_5xh%Y;itrSna$6$qIr5pbMT5Ywrlpq(q{V}? zdfsF_={~sJgK#&FhMQ-hE*M$u6XXVNZnQj}Gqhq>pCIo;RCaxt!1ZX4#~43F%3o3B zZ>WDyM4Wslek;Dm#qij6;>@T0pY*@rLEh&Pr?|7vPxaz>^{vgRjj25Z;=Q8VJqY|p f`L(joX7i!;fQZ%7$l{Di? z2fd2#x7W);hbDKR;YljC`1@@O<$x8i*cBL+MN6=lbsnUkg{x;f% zNfzg1=KGC+<%pYm)WXl?-qR7$A>V*6SUpy>TfP=fANG}eA^ilOMJk4JBU#U zC4@QR9MY1fnP4cFRg4wN>*lgX5}0yIF=Zq$fgM2!i85w}hVn#TB1%?HsZ>yil3UU! zm(dj!R&^R!1kpS{Mp>hfb*!1olqbe2F)44DIssj?l-FoT%V=3GXUb+N5u^CW6u(I& z!z3wHlZ!IZ*C?K&td=*3bL5oIr&KvhgZkp4p_zu9b0eZbP?QZ*0qa6h&YQ%bxJcQX z3ZWb^OqJpi6-w@c;h<~iKgz%6N7i(*V&s*iHnOZFS0G$H_mZ*%k&PH6Ibv$r{0NLj zGGlU1Cs`RHH1*u*xqL~6>Gn`g>JavY^Ai)Vz5*+BK=#xZ*Hn2luNa2BqLG}IAtxaU zhgXb0L1Uw@L=@?%UaMTIpJ+-)?aRBdLtDYkVEsh>N6pxa_Uk+0!?o4Qls)?-)?b$z zu~RS=iPi=xKdrynhz#4`+Z99AKqXKU>ht&d9*HkK7GG|NFMo9C)dj z!vdGf3WPI&vZ zTd!@rQM@d{*#KRTeDH5>{p*`Cij!0+q42cah3k!n^1_&eMGhTkIk(2h2Xtt`!&n z5!lelqu-l7=vDt)@PaJ?5Zuld!#C+~9%y<}_xc zwO@xRmk|K*&i(`9LrFbnrj`nN!Js_gH@OVZ&F0OLoJp-{fazrg&@6dqHd5;v6CNe8-%0!GZuIcu=&45Z)UR&b>)$@xNKF48H($Kij9#*@ z>;%M(iMM{f8$D7Vt1I;z4~Is#Pj3%A9QaN%I{8qX+>s(1cdV{9xRVaJMmLn| za%H);dS|6RT2J1a_y^1I!k&PnNcBSHLhZ!X>CMx3zWb1ilT#3L5~^SPhx!jLd{PJ? z;sdDrW39qpM#sa|1xE11wP(xN^!|2`b@0qH&b3teNeJVAcoUw>8{6TH^{ys>72v7{ zzJ@of;^)YX@9@U^-URp=L8|yQyfM#}TcE>R!1pFvB6tfr-Yf(l49z`tw*IMeiz$~Z zs0uv3f~wt2B>`2c`$mRM%`oMpVrcT{Q-@1w$IMvg;C$y`Ji(D~Iodot&T5WCKz7lP z2+{!2j6%jVC{rAzObKdn40@==^Uy;rjzd3gIHc*JYp4aUm>W*V%SOHrN*ACSpF(4! zC()RFfdd{)UbMZ7|~7@mujdCWc?#Nu2rM==(>%WI^6%5hJ;?Ukpoluy%Rt>gLrw o1mfL-+kL|SPWr91@67Cf8{x>8(mCAsg^E$n@n&@BR@e>Y-}i3CUjP6A literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_650964.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_650964.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a28d165781a3604ccc4d331d8b9839104c24e86 GIT binary patch literal 3507 zcmd5;O>7&-6`on{a+ltgqR}vd(q;W(Zb(AtAse_wXcGtPK&&AZsW^{jT0Ne@ z&X_9%W_XM=oO9&3OIy9lX?#w8e2$G~IW=&b-)R4s5mia$HI(c8Jf|vXp;=xH-WF!U zo__H*nu%~oi+B>`1)K=&vE@lBPELR4#g~<2I_ML}=m6A%zhUNz32m$7IF~x?RlL#u zUk&;=`9ltuQn4j$v?WwUOSA&>Fse!xvVs=32phNvtksZ(9~*~#QvFu8#abatWa}(R zc12_IqALeENo76>fVQJIX@~=f32zdPMFSkMsyvvnAc*-95u-0VRYaVMmm4ITIY^ z3n^os3c5L`kpw2}q9~v{>&G2bGZ#cG4p`w&bLWKpV!xb8MD@7TRJ&ehH}a5LO%zqk{Lu3W5LtiRe2 zn~}IZwH@iGpQ&B2ryj`BdZ3=)>g?Z=`|T^+j}LdH@QPRy>q3KHmDfjqW_&7*+?U3- zq_L04$BQ2>Hl;~>f@wNiJ6rFnzq;D}&bgMd^R@FUmui<9#~a4olg-F;ph*fXPt_!Q zVuugjd1>jTMk+Ud9*pY*KDO^q z6c-DtOwJSk*zgSoOo1wIdN6VnWLSp*ET{pH6#cv?SJNVs7&O^sbcz);Jz&Tk73jPh zqMpy9#arV16nGUgyInadZ}C4yKjB>Nx3S3%#3@!_*4Z@xSmD3ym*B1g+1L{A1vlCd zi($=C=VN)AzKarKa_`?z_SgfEMF^wC3)kkdMtpiW&Y~u+>Y4$R88^zMl1|L{td=s% zM2oB0Lj3Zj_}J*nlLPUqy7`T`gF1;IIS!%-12Mu}Bn-No8|H*Qz>Z+dr2sNSqP2)% z3nyqetryK~rmU9@Dgd~o=Kxj;C38W^XXiC4yAWe!7d09vWQ#P|ifP(W)GDS@(r;-* z(Px!mD!*0+Q3dS^Av$p)Hz6HfbsDZHj%r2EQ2a{E!qZ?TjTKTAWwvZ+s=~B578@)W zT<{@r#wPnwf{elxyV=H{p|R0+*Ri#p)gJrX+fryHR14Mn>#4QOYNqk#>U>k`w1Y48343RK2G|i^6$R7iw;Z|xBo2E1kM*(x@@u(u8~6wU(J zQH60Z!}rem!K-bc8hk2k0iU*j8&XvWP0%KO0c}{tPvM*3)8++mH3(S~Ql&4Tje9EJ zM4vX%^G&iKZ&O1|n}q-h!pS|h^?U5T!{c05Q;<0XRqd^8S|beUXvEYEQ%R=`O&Qu_ zTkfA3_76t3)VJVptHHnae0_D7EP_8mWxdo_|Vd$Y&MxcjU8HIk- zU@-1RVZ0TFosK!a2b+8gn4X7f{0hddFeDD7c_r8SW<95=Cy9CZWntQnX%YW`(PDGu(Bm>|f z?PvkuKpyh8GUPr~aw)=MhX!(4lZF6S4O6R>T1ZHP*R-OcL$>84cg{gL8%HC6$Z!@) zM$fB1M~>pwMoZ%bT`lLe%j7kX%AN`%1o`Oq+--@x*3Zv IMI00U0~k`|Jpcdz literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_674736.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_674736.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e3b05d47a05faadab655f692e17c6311808db1a GIT binary patch literal 3425 zcmd5;UrZax8K1Fty=!}IAe0`E&^VVKighIvfEjj9&4TDhJD47S%d zyAFn1OG_FlSb4xmC&hh3eae+Q+)34)Dpl&AhxVZ_xQ-e$r4rKRg*SS%s??YEo3(e@ zqs;6%{TM?nVJ3lcK#TPMG%yKnm?Z(kP-Tn58M)J3x^+r@DLG1a1PD1dOU-j zaUdU@5irv5g5$?EZS^Fe2@A^eIzE~Wkl;OGt35Fz5kW%>ou6MvL@i80$Is1%h;&bw ziFjGb_s~ob9)x288Z?h?GLBfhA|T#6O3B79E-KNh^4qdwnpFN`NFba|pDSW!#hd3$jc z@T@qe74v0q!ZYv$3qQB-m=%63y8Fgx86`t0$`5nSwMYh-kWU+nOw`SJjV3V@(&em? z#FXD`GQt98jzx-8pQGu#nk6ihrs=|*#)O=nCNQfrX%}zZMoOg6}Y>Rl`h!rI68chLZ!sWfr`X zrc9uQNf=%>_$elyEC0GZyrk1bqnOTU!}ICPBIKYKhSGCd!5lVdX4us7#bFo?=SJ0n zPV*`xV(NvF^TlNq)a{{~(ILBw<749&F2W4!Q{CmWOGF(hrVT?~)M!D=(bJFv|L%+r zps~@jh_XIZ9jfDIq~E@_)7f3iH9B9lZ|p{6>u*%wsNJed&1k}&+=+J7&Q!l)Pd<|) z>rz#!;o6Oj*rw3v{OY5z4~-}Cx1P!q4SC{M^w%rDTxrU0+gEnM^XT;(Nn+m zBy!rm#@&ur$JZ}cFW0;4#!vg2(X-%O4zEpCWiSy!_upE5YwcF8ull3qnG44uO>ogkf?k%AC0QYd0R(4%QM1SOMnrf0U2F-K)SiFEn-BTsgR{6IoBRNr zV#!{=cDQ5pZx8A&geK z+I*idbv=ZPBTq-eETVSCUWLZZ+ zE1POjU(%?m&#I#=^1XDfq+O>}r>r}rn>X^sma#@0rGr%#YH?38eDwL#LCB8bo0MM? z9pKIDym1Igm;NHUbE)D^=w|+`h7-#MY;(Q}*>8 zNxolNE&XJv)^+ziu6?umW^J@qs=u?b+*Ho2;hm_$(Y!ioU$Niam19rk-iF*;?`p~? z?JIlTUwSyPG12TkYftaSyWSssdg5H;#JNXZ%@eOby4{SA+EY7HymqP~4Qz^kl3v@5 z_ip!ntr;KMmWLqn+W6h~cH?KZU!88oZ*I#s_mtS$m_4~G4Ae&n&8UBH!9N&F3N*=Ai|~LT1J=F3T^kc{2vC~P%_oJpa!1Fhu8o!6eMtkv2d+P5Bigc`> zte%8lG2R^fh4@MAqu7)1gnji{xTp5xX1LG3x;Jpzo?5%v5c>arRv*Yn>9$i2@{%#< z0O=5Hi*>Z<@O%zlV7&-6`t8${&PuLvRdmWFy~OVJj+^dgpSZA?|bKz#8{iBbf}sc)9t)m9PE z_|~Cj=Djzw@4fkX-#q?50 z=LC!toHF%TpSF6LQ-oCTzh#VD;AH-mu-P7%^T>jNQXSu&C%Y7%>^^*K!7Y1k33Gni zEAK5d7vPW*u=NI4an$#Wshy#o@#&vWyd}*ftVJ=g3n~wF?;d6g)S+#a6lYQUy^J^8 z<0ZFarm#Q5qGW8iHrpCX9s?O}gIfbzYcRp^u(sqiJO-bJIVGP%m(#blGQ(^5WNwS! zWPUa^9^gb_AEhHuMd?E*%tDC(XeY?BwQSy7G(*|?UbE>&Mr zNEA~cUQB9HOxQU`{gl%))SoBnB8g|Eq)gp$63;CvRLH1t8Aes=SpZTW8K%5e$TDsE z5_J)6ndsD|=_*m89TdMu@hlB$I!VfkRFsIiLh%gcmAppSAyJ-5goK*WsJkfTb)uOL zP)||%AfqPIlq-UP595RiMAK#JS#8VRtLAKtMMC62`S#tR6_up5d_18HEyWXQa81n( z#ut^GKBSSvkgjC&L%j{Hg@vHY+7maC+?A*;iv@FhYAIYx=^9 zEXDG1O_S0J$tf9f0=)1*)qW368GYS}0%CQrGFZn=e{XqW2e{WNuhp*BJ;!nz#9|yv<6P0u2$)}-?>SATFwy-h!iS{rw@+fq+5jy)hd9ZeW ztr;3GUwQ27sEsvzz2)&8f3SM0a;nx@Kk?At-?B4S8LM8XT&Q=|wYxpdzyMeab=QST z4n8!^E`R*P%KP=vhIqVuc}H-6G<9QY{d(;e8>9aaPVa&Hh~gso0qfmF(F-J(a0WOF zDH8@FAa%lkfer%#>r%@&eHymM^Cprx3v+H__=a7#1H~HGI$)gb>Vaa7m(bxh=NVLR zC0@6~|FBRu_5)7PMazWDgU+zU?e~_=HPEfwL7X}){CI24HYm6Z-f*QS5h9&>x201S z4E|S0b{Q^c-M@BOU){tyye+MgV#jcxq6~Tj9s3! z)6>Hd7HE-HyhMnyTu2f{iwqA=Mv}RFL67K4PE$z|fr7Q8mz6l&JS8G0vmj2Yxya0g zwHG6qI9XIkL|=;MB3ZzzY#|$oXEN%FB1cYB95pcxFpOQ$XqfUDB?nlv6wfPEOllC% zIX$^pPzxIMJ7^bxRbq+y0TlFADU(bql+Pw}G^k`0Ft1Bma!sM`7Dmu;PATeA9&8dx zU66*U|9U)AP~IR!1&~S<Vxj_dx4DFQF-;oj_=Pxw=wWxqZGK`&*!YTbSB@ajFSzp#O0oSbe?n`a__C zVs*MQeS2mdGa;bDjmh$*^74*XeB?dW@E)rlX?pw0m!5`>)K1o8zrA?(Vzck%M(4|0 z*B=c|GzKUBaH%&x2vdTg-0yfg-z~%Np;nX2VEHV1X)o zzK1m|40828qYaGq&~57zC0L zXrde$g^s#peHFqupevd#CE}VQ#mFf(;E)VEnGq*58WqS-O*N3k$T`CB-i)iLk9q)y z3K?C4R3k>Y)fiz=3B)i6MxcjU9EE;VWABNbYP2G?!}Q0I7mzn$(galPFVK|HlTb(b z%GW{U4_EsteUJg*jZ>ey9*FnFhrWyD@u$A-+6PTvPkH>w3&+b>*57Ujz5hS6cfBas zRh}U~f~7=7a|g*QK+|vw49{dBCoDi-MZFe9*vU~omDH&ZFjmu*VxF{-@Ldz)smn|p z&FwdLiw0UJAT6pH`BP+CYwfQ0d{&hU8Ra6m23pyF6j~aZT^?in1q%KXiT^_VyKcmJ zKkB{FyUTUs=xz`}{xn~i|4IaIk4IdvC4c287HeB3ZB7TU&7&-6`t8${&PuLvRdmWFauGwfHl`|IAinsfL@5H~)Hh4+YO4rn zeCtp%^WK};_ul-xZ=OCEML&Xaz5Iuz-+B@Hk}+0sw}`zz0&xe)NXBV2*X;2eHhC_? z&j}bQIBDv!K5h0grwGa5f6ExRz{&hAVY4+b=aB^kCELC~Pj)Fj*?suff?M|766XB2 zSKeD_F2ErrVCxO8;;8Q#Q#(UF&Mr zNEA~cR!C@3OxQU`{gl(w)Sn~jB8g?BgiPHr63Z?sR7k5a8Aes=SpZTW9-_RK&oFKJ z5_J)6ndsD|=_*m8Z4_Upc$Nk=og`#MDo8|Kp?HS!N=_r}kSNb2LR?L2)LoErI?+rA zsHY&kpH|~3$`!!C2QfkgqUkdAthVItRdcq+A|Z01eE06aib_&iE*4h?mSXV~xTa?N zV~a{wAJ9mAKvy!k0pJGGLsC{H843K;)$HJz+^PiP7AwV7aC+?Q*^$vv7@=K~HGN@4 zmIiY%O_Ndz$tr1b0=)1*)&2xc34PO!0%E1V++V{De|KqO2e?Z}Rn zEPQC1UH<6D<@ahMb@6!V(vIN%aO%d?`nBq>Hb(v-oZbWX5yb`aBi6f#qUT9A?hJ4i zQYH*UKNISR=M5xt7UtZ<@D00e2Z}YWwZS;s)dR&EFQUV3&NHas zio9-z|6!qS><65ni`7ZBTF-yx~esB1GEtPD`gO z82rabb{Q^c-M?{JU){tyye-X=WXEuzB)bj>fCZnAeP3#6`~|{D@3F=WZUHkD0JQ>q z_ktzg?`4nmCHM{Syye|&fh}h58T_o3VbF0E^`Swe;S~;}S+LZU75G`S0x!zUGj@5_ zPEQX-SfE9k@e(1*ay~&6Ei%+U8A)Vwc|D>lSxqGg1Pa!UUQ%Lk^OT63$bdMZW+OA_ z*ItaIV`Nbw5q&9^jbs3?GWkp-mQJfHiX1sjan!^(z%X_}qhZRYl`LS` zXZ6HlUd?ON@1R`(R`Dh32T;&grF0^tP(G8$(x8%7z`QPLi8Y0~n;1dES*4&$Ij~72 zbwL`U{%f&xUU{7m6+kMUmt)i|D<34{iirx;-2}g=oAiNK&A`HM5q1wKUNv3S*!DIt zj6|3`=tLSa#zvuPUqMqsJAu&pa%H8wa{F9u@QXn2wlKB*;#33LK=0!~u<}~@wTD0j z#maPf`u5B^W+LUZ zU3=6&QSYDl^TkI0rAFv7z=p@WK6>-lJE3FSJp=dfC))PVJKI;TZ@)9UJv+DE^I+r6%iWKB-2hFXaP02b-+eFbcu#B$ZoE|Y{uCT){#B;RQ`IZC-!!@K&w*jbTg-0yfg-z~%Np;nX2VEPV1X)o zet4083W>+eGq&~57zC1G zXrde$fsVRleHFqupevd##bcTx4U$uAz#$oOGQ&=0Br1@fn`$77k+X#1y%|?gAN2qZ z<c6p5Pmnis8B>oHa?z$1@ z{jmE+_b%6oqq{)_`O|!P{%aAqJsxqvru^(tELLB?GyT!@E&_5RZ+7=OAN1bu-D8ry ZUm`4i9Xx|OzLGHtA8&*rHv*7&-6@Yh^%jNQCsXt4yEUT89RubFFDkUX~VJC%t9Jz_i!mUxeY6D`;T~VY+ zEnnLmo6hkQHjaM_jZH>oWRzg$Dm9uz zR-a8e<~YPCC~oPgO|8^2t8j7OXUAx7o|PTMTElHzQw zbyCY{yZV2Yw@Gp*61A3$4Ci)LL)K+525+#dU~3fwvVv`{-IwZQ!x=?-O%W3p>dks% zhpp>&8!p4K2$r%Qo7M&!VF-qcS~onhV~5+OYlajTE497+@3!j^@EY)=Ys$K9mi=~3 zmA2uQ(auRU+&29UHay8riP@blG2AY92wS?WO^8S!-29ea6|=PUmU#@eBl!OL%PB>@%d#r9861^7G6|vN=e{8@WBOHdkTfgd?YsI zU(dZ>43@$MT=Ksbg3ggL`Ok77Exg zMPI_Hh@B?%L|e?6hS5#NVB&S^z(3V=!Wx-f90zsHAXh!LH5pphP?x288jvW-C@-TfkgW6NKWO~{C5|J@feyGbPN%w0PT3)IhBNLxEpg)aDzIu5A_7FMv&i7} zrrL}~mx967NKp5aH#mb|oQ62H>UDz1PJ{b7Bl89irRyzcZPpt|Z?Eg`sI^XN!(njq zG;VYYPbC;VcWMZ_J3@^oKzqkL0>;yCz)FTdBMNr_+yf1F+FkPt)&RWP@oZOb1q5>% z?VZ8aW`H%PFm3QTjU8e5x#;Y;)$p!Jr-S3moy;Pf|OPWcQ2+a6Q7b$ zk^qme7F$(_=T1xyCzMnaIOj;lQiQMIgEXa-43J7)RxlteJVHF*l@d$JRg6_koIt_| zfDzE6uUO9EuTUQ-Is$Inx$>(x7LBE(M7ZJtx^8()+|&un1IurG0VWbI3P9E9hgv5f z?xlbD?6AU@Vd#sHwPuK;$$a1v9%R-s1@^~A;qISYXZL+#el|Dzlh8T>9&~hW-rBfj zhV}(_p3m{?xDebN-xx2BY+No27tD!$uaAOcZQ7hL?}B_w>!yFhUvBO-zq#*kDohr; zo^}nDx`xVK!=*FBCI8pWnFD`IVc^Xdo^}UI-ND_ya`#x-|GYT^eyoqb@xufE+2UyN zZZYzt@5RTRkKOOKJn5P$`=_4>7Ww>nZ_Zn|QACATw))-~cza;C@6q7H!9C%!MLf~J zK4wlIa6Lu2#61U2uLt+rJ2r1`+%8H#{r0zk_gXGBGCmKc7!By!vd{e*;{>)1*3@MS2}u#>vBrHQ->%&MRQt zcJGedo$dW7pAFR`7LBgvmnjVfwni3A;0V-n~8L@Xp?VYgz zho7g?PIwr}DNG@6owR_PxUixvCEy9eW`cxWft(qE8ggb7YRH)}sK*)|-Db-j+gk&N zK=ZdDYkz;_uMa=lOMU;3wO%yE_3pOo(ty8 z`t=fb?*B*nNML+zX2?2p0nrK|9)V?%rV2$bCg7F11aCDW)LNMGjX2^lowxzfnyzHh z*a9cvZz?H`($#v6TG6+nMZA>_2uIa~{42(?Rso}MC8^3w3FT#c3$)T}MN2?&Wif91U^A2YOl>|j`5W%y{U4~d1Vo3k6UM+`K# nmaJ;G{n7b{=a1>g@lC`a@nhdb)b=+SF-;fB{@yjO<;}kUrO*p{ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_846578.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_846578.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a836967dfe78b439ee2e177ad49c618fbe45f7c GIT binary patch literal 3434 zcmd5;O>7&-6`t82?sB=L{%kpx{3F+mY@)3M$*LW=wj0^99U!)nv}qFy35X46C6OYz z^6t|5=@MAE228?0NdX1KASjvw1)5eqxIl{}r?@?|hh8MmXoRT>2(T~eNsd~e=%szL zsD%fZ}kR$6U^@A>~Smib+X1iWprwN8ZK`$BW>YQU>l)4`^Y?r4fl z!BHK7t&8t*A?hej)WeQBxW)Rl1w86(nrRA7n-g>7`=Y;o7%iZ9XxwDmGv#Z=IVxl| zYle!3J*}H@Oa*N|WyP_{PDos#oSmV{tZ7V{S~ig)G^CkYZc3*@#?T1#8dRDD(LOay zc`Kh~*6e92n$}IzrlMsVrXFvn_)Us$P}Q={6wwp&3DcOPc%1V3tYxz8r6Fix<|9cX zlQ1SHE#0;}7pVAVV%9b(o~HuX2b5cyNd?og3FVDEVBBE+7K@?2&!vC7J~(HXGuEt@ z)CZ@vJQiyn5bYgXfo zh5-p!8ohhF$ov(IPt#=RBBFEja4JA?}rIYTZjaI-y=86pru0 zOpw{bvTn#=r*u1S=8{3400OPE$peo~9F%VI3G_6QPfi5SoCIShK`UsZq4gFVhXPmN z3j*}BGx)hgIA<@)ki$8=BTVuIF{ie+R}i7yC-i_V-k`3{-a|xm!tbFU34Z+F#&(k>oX{jb2Atx;Z{nYqe{}q z*{P|#k+-M_T#%fm3dEefkjSKFbgJqZ9RTcvm0E<_xALHz)i0Z-VS1rNLv_Z8D}e)u z-=R25agHj6nVL$};qY~HKR7}0g6E9IF6BKoF#A}a8jM@>G&Hdx)~ArUXhV|k=9lu{ zpQ{|Z?QC>*t(;gsQId+dI8@XiOv+bES1QAmoA<64@kaN7YDaana-*bd?eAWBWBHAu zvLT1dYDv8ptH&wC+|rmk>fYRxqo2ypt;^3LD$;I<%d%bZ>;scQj5L1CcpYbI9$D1CF{a*z^H#^xg=LQ?nS&> zWTXAS-=dNj+G$5>XK|=}x^%j7b}e#bO*mqH3FelBznR#XV7s;ZTQ5>msPJF!egcn{ zz-`vz9KpTUXf0s78EekotAT?!h4Xz5fCG0{i|8kzXF}rxXuN;^fjNLXi1Y=}Fu~6- zH!*-FHqkWn!{pO8Q*40zs(q!3duat;6@J-{Z=Ze zQ<2yU5Q-67x9mhxv-HH!He2*y%W$xADp(nb3+7jtwg@*q$(fAWyr_wXsbuO_K4V*y zn;)Xw0{rvCP{Kcd3QG9rN1zI;Lw=4t?D|^E=duRLXY>o^cVT_mrEIaM-wuy4evH&Vq3EB{ zp)D-k?OE#C;Z*$Phr_==_nUK{ArN7&-6`t8${&PuLvRdmWFauGwfHl`|IAinsfL@5H~)Hh4+YO4rn zeCtp%^WK};_ul-xZ=OCEML&Xa`}Q9ef9pl)OU78m-6D4X2*h0^BN?aBY_rF+*yOnk zKPzCQ;H0U?`n1{0oFXKH|1D$OJSX$FgpJm~tVb3Ulx+L{JlUoAWcR^i^KRL5OPKZB zUU_e!*#L)>fUP&Uf}_6YOzkxFj7_~g{-!h?w-&|3E~q@z-TRm=P=~fzlAJ~D^)lXQ zjThaHnZn)-i;}V7+GuGgdJJT^4Q>@|t-=Ju!`h1+1RR=1gh4#D=?Z@`cFQ4Bb29&kKHK*pPby@H@af5S8p zAyG_)SRtWBF=6K%^;1qyQ-6-A3nZ425;AqiNG!XcP$8|xWEfScXC6p>c!=^^KEt%> zi_}H5C8AT8rmIAWwo&{Z#WOUh=_Da5Qb8i>GR4!BS8^I*heUZM5#nlEqwa!~(}`v} zKs^QNgR~k?QLX?6K8z765KWh4I{KmvZl{3 z%hF&jrfE`2Az39&j)4~*sM?>PDWPxLQ9!Kpm-}nD;qNYuZv%I(e6D)6=4k{XrHQ9v zsM=9A>WBL3V&A%47k^f|{3H;r9xIBn@FL`-7DeVit&jXeIJpb%BZ>>;N33@fMbDFL+!^33 zq)ZryfYb>C2091`tV=E9)JfPL&l^bQEX=uy;Tv|{_7!VfYlCsNtNV&IUPK4moMTYI z6?xqb|ARu^*b6v87cCPq4?4pZx7S-XS3$RK2XXSC@Z+sD+o0ewc*B*NK!~*Kot92n zF!+y=>@r-?x_{%czPgEZcw3q!$&O)PNp|fM01G}J`M%WB_zQ%Q-eZj$+&pF~0BQyJ z?j=jU*UKL3OYj@udDFYm0$a@9Gx%96!=U3D>O+G_!^<2-GhnGHEATUD8D5m>=j`&V zo}3zrut1A6<0V3rrF?=YT4bnyB9h4F@_Iy9vYJW~2o$UyzO2OH<|z?5kpXc+%|@m# ztiBvc$H;<0BKl%18_57(W%8LwES*-D6*+Q};;4ynfMM){M#Gd(D_Ovz#aK?EVnTy> z&gzMUyqedj-$Abx{W{numZyz)9BDu7fxFUP1`Rz6I`6%!Sxy9s_#H|Ybfnt_GiBJ3VeykfelvF&YQ z7>O`>(1|qUEE|QYeFaSkZ3jYYOO@sF@}2Xw!7l>6Tf*el%aaXg1HDfI!OCmp*B${C z6f0BZsXNnam;$!cTy7x%!P{Z3(y7Vk`sCv9M_`8euE;f2zt+&6r zdHr$!c)fr8&zBngmm8rg02?0f+UU(+Z-jr27g=g=b{k!j#ZSS%5!Sz?_-k*X)&A-ZId9r%-&YLC|esZ|y z-=fFo|CUGo<2&srD3&I_?bxHg)7Yc?*Tcsrc>Zs^YeMwIfwv3WIdYc{|@Zs!GSK4#58*C6qje$KZ3IfFnl z3{8|HBhXQotgk>A2XsZ#rFcwJq(O3m4LBr2PG;E2j6?;nLyf@=2>Z2aO zp?q4`Ak`S8+{z$fPzl5k2!^4DS{Q+TL}TxXooX~AwaxU$kr$9RVA42L?K5af=xL~} zbmf~M@`o!u?lo>pTJV0qPc_QG|)8M1jAEl$O-e1S5dD;5q5HvPbPHg1B}&lrH~^nBz)I|c)4URdbvCeSUkltQ76EyDfTA zawLaa1Vij*StwAcx^mrV;*shMR zLXPZ&oDlQYVhA*m-<%P=p5Ebl9a+csI+hf>Jd(N}*MY54y-(wuE=OQ?cC-lpFMaFa z$9x@mK3^BXTTa-0T~QBb-wswroTwALC$(=0Bj(!$6!%RM_6*{fin_g!GiKVnagGWZ z&6=a4Y0nx2HV9fFX~i*NPcyDi&Q8-nj+iq<%cx16ii-v@E$W+2o0@%WjEY3d&KOkE znxg<{9hgkYTlow-+n$9z)=gql(Xvfq#JefJPVo(@ST;%OhFVaGIZyE<<&B(0*bAp( zE~%>->IZ^q*&6T)Dm#uhr>CiqFw>T+9+e7e&L$QcZ)s_Z^0zcX1!CDcl@?m01(#&8 zcM^LF{&{U=-XwEYPD>agvsz*f)HSnT)n<&WJz|l>h;3wYBd{7tkEvOcWK_`EHnXEg zatkV;TT3-zf+45JPaHdX6ex5+ZEv5R*VWOSW?AZ-L9#}g41*p%Xx7&uca+fQG75#4 zu2innzFb%KmM%6Ty)|Vma-j6;dN8zfs&cCOdiB+JUi-z=+fxsMM@tvh!@H{AuHo9r zkHY)Qd?Vcb)@v&_mT%Nkyuo?eTd{{2_}aOii3>d}j3sUa&>sn)j~UX%A#gv}s|9Qa&8pv7=y zxVo!$?m^&7rB|3%Co3nHUaq`c>#12kAE*cS|2JBFv5ufrw7+&>c?i^(150v6uJ$nl zK9UbV3WI)I;7CMqfegTyXeBy^Lozfsk64xU~+@GeuL#5sGWITgPJet#;s(+AS}7j zplw*Tn$Rpm9o=9q_hQDpi(}r!@wh;au-)tnpl1pDHM(IC_fsD+tbE$GC|4Mz+yb=1 z7>v*g$6$n3I1b}+i+ylARk;~g5my~g0VS7U)7PL`A3|3`A9qDc6HuB0;iWy5J**tA z4*y#GZRl5_2mZ6A^Pl>At3Rmw2TJEZ9@N zX6{9}j-$cmU8pl=T7M6@f}4((&tyzJpEk~tYj7&7kJk60+vG9EpCIKgDD+pfcT+-~ z{KLW9gPUA0eyI^Z{L8)X?fs00`4*2j#l=248o`6riz}CxFK;54uI9Jcur-Eo=riRA S?s=qR6y8_wiro&n^8XFfTL>Hg literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_973282.cpython-312.pyc b/src/temp/gen/__pycache__/l2_norm_triton1.py_gen_triton_code_973282.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a39102e1cd476a3b47e2d617125a53590599da GIT binary patch literal 3741 zcmd6pUu+vm8Nhee>-E}eJO7%bX&Rbvki?;q-w*&H8uD3>DMCRXnj$1*a_JI`S@c*7pOy7O$oN% zI;mx}Q~N*1+a$RYiF!*$hI6N;A?Gp}gE!bUu(bvPIl(s9?n`a5;fx`@u82tpjb@{< z%Qkeo4VU3q0!ulMO>2{lFa*Ozts5TMvCHkyHA71HsAGfJ$kDZmmgC232NEj^nZt z$wsicLQt4El#GVyh7le{s1-4)CL`*?f~M#i@zFs|m#`km=$JU}O4#y_pe*6ArppAa zYII-vKau`%V{}EuOIk*XDx-^1bO}6C(<4$$N$aB;j*jX|Dl-bb(PSu+R&go>(17bnI-FNEM9(i5qRFRaLsU`Eok$dZE7N)iu%4=%{sQz%U417Azw z_5AClV8z#GzH-ohW-G86C{31DD(yq&YeziO*;V4V&zH}Rl+TQq*FNGI(N`GG4;N>P zFaCV$tykZC^@(THoZc6^N^SYCnlm8a77MPttLWPn%dKDB``n}G6JhG9aHT9<`Q7bz z6YnG{!gX^JG>hI1DgSwM`oQBY4CV)m9i`wCPoFsrD!l%}M1G?1QvRh$h^Tr>H3+oKI;vncQ*8Q)mtaE;V`%b z8aKLyrxJ{wJ2eE|U7^kspuJ-O0psa6U?oGK5rsPd?t!K|?XLR;YXDyDdUk5J0)jcs z_Re7IGr*ebo4}#);XffA((mYyerku5tZR;N|L|IOpkK*G6-CziQ&P6Sb_wRhH2@%h?q~x-41!ENxCy+1- zUCuKX&F$Kq)z8L7H}u2~)vH+91D!15bkfQf{Q0Z=vip*9GJ z`{^HEJFM^}82T(^trg;EGVi&B``Ptuk^O;Dyz?j5*?q6CFqfbEVR!=p54yUyZf@Q* z!~24}z~}i5TnuhaY)+Iyo0lrW1#@y=^iq(l&zO_u9guHp-|}zzE3N(J*Y^D_#i>%y z)1HxX&q$?bw0vf??EkVkd*E*?4*m2iPkV#q-r(LqrFXpIf8Lx0KQ<=b_}+p4Y-y}? zrxblW@WP|+NA9=V9`{UF{4iZSpGuBef ztC&V63b~Q!M`~g0w-Bx^CeGyk=krO1SD)?rZ-Og$n$$*fNN-@vIC*%n1{_S;c^Qn` zu73K>pGq?SOx7GWYmUb8;^0UEIzELpB%^0oo6}~kfu$UP3akmA!W#QbYjfGGxf*K| z;7Q2HZfeb703x37Ar1W@y`BgcU6v)_gLV2WNQk5-_UDUl!^qB6FTki9cz?~DgH z{5+L*!ox^TV+wieqy^l>g%xc%2~QX{8zk&1Q9zW!Eo_9HLjX({yPd+Ev89{jcQq3?n3iTh=9>H~LY@pi@CV@|!-cfp+9 zxK`%Q{r^ZG35>VH3|ogT;IjgVL$EB;Ql;p{B)l@0;jKo5dJ9v&5l14f6E`4Q)0J!n zTi_)84JEBnx>~PMEBaQnh*;f#NK8%2zho?HRWJ&dQ>wh2R9?n6K`Xshv?LTq4usJA zjQ1~$@2|{xpc*UOKevAFi0wu*=6{4pIlzJVCz`_Ow4b^J|6Ov{Cezkgk{y!jU#xeMw5 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_114093.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_114093.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b53fb908fd5af6c01c10e67283e9f95f79ed81c0 GIT binary patch literal 3354 zcmbsrTWlN0arf@{Jia8_)WcNZMy_nek!0C!ng*`wSc?3JtrSR;Hgt;1iMNzU@zuMN zYzar;T(}S!2atrI5Yrln`cv4IfZAw*_A~$ek)ot#w@x7-{n1AOIr2vln zaRDQO)0Uq4R(n=BBBXs!uQ5vDzh+(7e?93`95;oPnxeQ1^twTZJ1Hp6n?l?}yozh` z7((jL6n9)yJaHd~2ugSEwM<0fOLD-A?&T@oRdI#4*W*FZ;9=4$Tx}&pg6ZzPyBc1# zS~h$3N|06(Qhclax^C^6O9*5Su=QDp+6qLMjr4JQZiaR3_4(&^)Y&ZrZlZVu^hD~s zc-st)1-IDtQR;mCV*C~93Uv-k7hZXJn0n7$er@PgiLEB!o*RSR8awk&+PY)C<_maJ zJ9k485CzwLEqV-R(QCMhqTw~fSqDNzpAOI-hR5(3?)rI$(NBK?e81s0e3PJ6!|!0S1g6Bh)xU zi-DVD_p^u%I3L&&3j)JaIA(NMZsp)x@#KII`q5Nx8Iuipk*CpH7^k^z$&oh@4A`J|jKJmR%z@g4bqUK0O{Wh3onV|gxfQcTrBhf$KTkfsk z4R5D8RP~3=5esaFB9+Ku(7gE2-@0(B;#q#_ul~Lgu8O|W%&qTNTf+;}i^oet3`s6N zR~p)kwB0`W;L!8yhn}wvRuUC&HTrBd5`FArS$5G9(+T!aOy8H_{D`OTt@$sK3?J9b(9<>TtZ1+5pKGRg5GFwhR^}Nu!do9X?D6@ zLvhrb!PoGe4L<)3`}gNJ_Om#Z9V<1@Ff#bpQOq?;g~XI1-{q)7(PyB}oeG&wCCC>V~uQ#}F*b)U(|*)c^Pob9-#%7p?^`_4f7ku%yLAU5j-wyh1ozfE*I^5LNm z4lTcVC;1!cUV5YB<@NRI*kxID`zt_}Pzn{{ zseYS1ueiD13~X@qbSUx$pJsb@jrzxgffVix@D%?4)*~=I27>f+K>BTkfKtx_prTGR zT}r?+Ck+7OFk>*;X2wZ)71{k!M?5KSJXwdehF9Xl=6g*ngUB}b_EL-P-~7v|KJ<-VQR#dT1Ec!l+i%5?=Bx`o0JQwZkLMXYctDZo$EYp=|`y za(IVL+sBaLE6sk)cU1jB_Qr-7XfM6dRI2xyOZ9I}*Hra}=8u(+-FgmQXtAU8o%zA? zV5PgF-8t~7*taGO-{8pkX$46p)13hpcX!r1!wO`! zE{R4qePDeO+Xt;_(ppSglO|Qu2OdrO;s%?NTbhvQ3vX1~mp=8}nVsDgX|U^`LXp|L3CvZbB9)1wVTUrqj4iVXZkT6~#!T38 zTvvyxtj16HpIxJp%5A?s%B%cEezE8<>;Zo-*m_}IxX2Gzl};dDM8gt`G^u3kn`R=O zb?P}t#FNK|d*lHkcFV_lj&w)8#ekI1J8|&eOf&ROCd|&5U|p_ERc5iM&vZ{1)?^ zpWter^W$AjCQxxUsEn<{@|SKLtgx)ib!24Hegn@sV(v}@#&LhNm(IG zWBTrxJ_!+0F;z}DBuPYlbkqRRp-CbYcSxC-cxW`rFjGbdT$`R9f1heUuj8{uN{MRi z<4W|bVvg&{Hn0iu<>Z&GpbacLR z02D+#tKp;;$Mvuy&Ast0Oj-129SVfz2XX@o&CC9M*`8H@utItQx}n@q{@ghqTi^5W%V zv&ZH>$}7L{dp6<3C|U_B+8}J`{qzFRGpV~hCbj|5gvZsEsoM%Lt1?MOEd!F>p*)er zjItH`ZPUms_BOXj&kiHQRm9t-xdGe&g#RHP?|5VclNOjgEWw61XYm#{5q4M0ROMap z{fk~G>G>V}x7U~bAv?Fl#UiX3Wbr3a#4||v=%lI;E}EK*utZSJ>6Au9RZGXB8m1L% zBRs;EF&C*K0w6g$j+>~u5hmh~8M2aAVsYi;xJCqkhMFXvQNX>~aexRj8rBUP14MMt z9udm{?Knhqa78%4DdESUB^Y#j2c?*1k$BEyC6&@Jehs9T;a^14YM^?4BscP9&sY6l z^#9s&u+VbwQOmp6T9>+(1IH1&TtzXIfzR^@;4Q;s zDk>goi*=iChJ$J;R&gl~FfMW|q_T%VR=NKjho^BY1hpazM%eXTAmP_xB0f_yOgX9; z(Ct9c1Z(T>ZBe`lbwl5Z2+_En)FM3bV9l6}n+DWDY(}H4T9wB&Ma8dznL$%QAKZEF z8heLgTTL6XQ*>QmL(gQ4z?4O6!Cl$D4L_>eHw*7M*ZV{Y=Ubs~O3mdUY?JVDAYPZlw388TB4(1RsGc-T zZ7PMGn&)Ck)6N5?hbKNK2lAL6S8pME8|T8_LkV4-jBAJSVX&hgRbw2c4US=$$H@O1 z3jB`ttuda<^|SRGYz?z_BS=3ZCvqn?LR8vSl%9q`+T;-HcdXjNOelYJ0|B|Qdy|sQ ZPQ(PBdh40pe{?e_^vZIuW!7h(Hiqs!()@|r5t&rj$zKTOCwN*cQ?#vDhnAC}S zwR`S8&wK8<=bSr#;CUZ{@%PaW2h=KrJ|l=dT$RPf&nSfMA_+-U5=E_9 zA(*xa-E zSqV8Tha`R>VB1#rzlK1zme9u`I`a@kG6qjrvlH6Z?fIuxbV;;>w^6hn?9|)5Xu}v4 z39S+J9_+dB`qhprLbtu>i8=4V4Kn1U8>4W#yGW~e06tVsON`hvB*EhZ1TIH&Bq=%qt(7UvR;&da~ohC8Rs|l6RD^!M1Nu11B zg>}l3aOkn7o@vdY=h~)MqX%~E7TmtiIOIdF)^Sp2?o>jhhxOV8#=e7GokMz;%GPW2 zI#a$9$6)FG689C8|Ra52m#w_GOh!zaplDgoNFqBBuLg97u_yLM)Tk z68*!OVHIkP zO{N1G!|vW!qZgrm+?_(##aB8N;15QpsaB~JcA&r`R zb54WTmPc7lQL&e7PgaVDSf)g^9_E9)j(_*HzMWBq)T|hjTL;A0kf;r0(%%yM<+RqS zDzR2gPGwsGYpv{{LDL!X%y$`=R{j@rHC$5|xhKxl7W?1- z$?X18!%GjFmuf${wbbzY){lEiFGoua(Ng%jab-CeF2wKsaE2{jdp|SFmFmB}7=PIJ zQQO074>BKfrL(=I`rcBo&*)m^s(*f{&{sH`KVN)frt`x?Gnc1@5_i(*T=Dv*Jb6!n zn>q5>d(?Pkg>~QQn&_JRap6```xAR)144$BIs(Ym#hm?<)PdP$;}Ft2=nnO~PMb%0 zDo$&lW{tsJJLUpyza%-@QFXg?qdb*|e)2W!F%pww;qRKHCP5QkU2o2GEuk6(u+=B}O6)G(lMt%~UdzN+tPo&@hy5z>K};l49znB({$?zO9Ja ztgN)0OJ-tXQf-HY#X1Rs>IneCF#DUCgG=(JkfS?n5@m5gfk4eYl1Ys9lv38Vp%){7bF zaa;B<+Yg;sCGGC5x6D3g9nDiYkn$biG{ay06@Y&kXoU|?UC&=HHqSIaYK;80?sxkh z>{}Xn%=bW(@CBz%immd4B7#*vDka2mLfG=puf8OYP!ulr9 zOuWAA4J>;(fKz;)FNkIDp6zRH%9r;Q+TPOym|x$HXx*T3)J{WquI z{N;7y@?>|3ZThNWUgWPadd!QB>`_j^J_@NP9JG9CNy3N7Fub_fYp)dY>4=%Zgodl& z+eFpm(X3(>1(Qf?RuCXZ(qc*$gst*W+KBPf1iVN9i2|lH%qVgkd&OiD;-<=&5pSxP zfe#<^^q8}PDD-EN(j2nNw$d3=8EH5vw=2D1h;%viZvfUAilRP4{=XppuW0Wo<+;-| z(X>wQqrSNwMuE`Ojr@)E8iLhUu+QrN+h7pwx1`!@sbHaF9lkoO%BL0+>cS A8vpmO{@78JW6SDOS}Qv!WJTTJ`|4K@-Glo-^}WM}-lf1MdG z*vyKtL}d?s@S-YWp$~DRM1>luMygbcC&X)BSQ$mu4UsG*FMg>ECRNqG^xT=9-8D5L zRj+o>z2}~vd(OS*es{io=LI8 z97Y&<%!0y#pa8;lFNX3@8R9;2KHq2kyz-4T{=}*ue`GaSjh5@7CQf&SYIJ3f^sJ#_G>u~1kWM#zgPv&&6E39~ zlQEt!W|9Q6dS+BtQu4S;1VvZUqZ;9o8AT<$VPpPXndA+XJ69 zbt)I#9+o5ciqQmq+W))#>w~V#8GX{oDhaJ?Oi4^C=2#|uKpEB2W|yHSx=byV?EOSeBl+!N9KC=w%y{ zq6>r0V6kI0yxZ=37LL?#>r!LkQc3Fg{KTJ5{qfXYWqJ4N{&-1>+x=^=m8?rmg%L-l zD&@ht6e+aa!NtSHz|Fmj@lr!~N$OtiUpf4^Vddz9Gmp}x!{UD7YA<3cXP|? z!=B~Qd)G>P;!nk2+C9%WVeaJY$@vcpzbWQEQ||6s9KC&Qxn+fWsH_AY>@Dp*{uS4= z1zUj1b|a~Lr|F;d`+S606>o5@FR@r`$7{t0SSAbTPl}% zz7}+LxQGt#2>+8mtnxQl_;d5@JU7qJ3jor>burJF6iD4>dGLY&I?zPOpRM+4z=y`y z$i<2rzrlWIdBvXYmcu&S{ko+xdLr?VOnR4q52 z&^FpZmY~SsH~Rk3OSN?g$S2&J-U^3vu#aDFI8@eQe9aLC!ox!lj^N!h2u~f)0V;8W zedl|s5TWmaD0CW+K?7^Iq3J({5fP6imDIf%W81msJliVh?V#u)rA|vI##Ez-34{kYOpFm8U_wokzzBe( z*?olYGn$?;TsRi?!rRWc7A%` z_JO;2we77nF|-hLf`#tiNd@E9%*~m_j^F30fw0!x0l)91055DdqHxsymAiY}5P;rU z7Xj|T;1(K3dAR+T=!1>n9#u*X?oXvG@~p*9l(wDbMn z1se*T#m?1W2b|`f_D?^${m~yT+5_{0C9ds!XMghahq4kqFK zpMocni2g{YUm=8@7&l2>B9k^u?Mha6tKH4t4Vr|!+~v_sQoV<~YgE3Xsf;?6)Z#i# zjGnmhYap8}#`s$l{u)W&pxx^@FxNKQw#l^O6BvrAsrlABH{aQ8a@k#F_C+&gw^+o4 gJ;fJCT5xnLji_@% literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_205496.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_205496.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..374962a6151ada1380e71971fc670a6ca257dbb1 GIT binary patch literal 2747 zcmbsrTTB#J^v-j4XI~(d$}>x}=#&SRag({{oO>Sk+8UOdHNcdrvqE`jVE> z09^s>3HXw3L-;RNh@B!U9>Z<$Be09%F&qR=z*VT1 z3bRp38xc87>k%v@)cBw(MT21(^O7pX1{KUk;*yLxO^>UJSc<7$%$y$8u`ep=YWPY} zgIQLtyhM*c&dFdvSATUW4y7$yiaCPlPTU(7*s}jX>w9g$!TgA##*~O!1Dg_` zv_~+RXtlJ=Jh|fW&Gx2x=jxU`)n?a=)T+yuuFSfsu3x$lo)0fj3-2%0w*BzMgFSaH zXKN2uWZTtqGply{WzFy7a(TZ)RGb2>Z;A6_%eqK6!lRO#0o_ z=W`ty`DeCa6OMz_9b_tl3G=F=kE*edT}Fh9NQ1G8{4>a(X*5k;D=x%hNlP4|6iQJ8 zv~E)f4R2Sh)Gd0^PShk_l-}EvM2bp5A^!_|%nUUH8ZtBGOtQKw?})*YW?;|;W6;Cy z{OT5z*?i0S8$MUyGq2F!o?p->H?(LMGKCE03=;W1%!Wo~Nu)6+>*EOp^RhA)4k;nK zVOoBYhD!pqc>$(x=vPL9mZ^LbDZ1nBVn%IA4$$Ba-KUbUfjfMzG{8MMJSH-j?d>{y z#7Y+{aWoRefN84L0I{<@B1JFB(xHip%c_(}C~9MSBp!mgJp@RORRw}t0H!}obiMpW z)qGXvlZB6#D)!&2zu$1T;laSeOG_;$vb#>IwIF1*E|nZyq=vr4fS{FG3&5`R%a1Rhwb`b@;TaF#Tzw z74MGO^QrT5^_lv|72>^}_iOIfJRE!C?K8Vq1pjP9sv$i#H}<%^{+4n(dNca4`-#wF zcC5Kinw@K`Fw;7DYT4;qcDiAn^`^XOY00_kMch3rq=auuao}{^&Ob7TURta7?Lz4*eudklC6L65~BqeBoLSjMdC3y)&Gcs3%6}^6M4mFU&&pcSql&PP z`UXtcpYx?L0?+Fxyh+%Z$5aR5Gs9s$m%L0kreprB64TGi&pIjO-@WA9GwHTA{tIn2T224} literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_216901.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_216901.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705abc509c41e6ea96fdabaede1555d807833705 GIT binary patch literal 2331 zcmbtVU1$_n6ux)<_IGxpsj1njMA~i~t=UGbP#VNyjOkBOinQQH7^XXuWaIA6a%U3F zGFv7fHNl5$sgT$T_CcsNLM;^f>U&>YVj`0?=D z`~%nmvM~scCYp6OH$dumsbQ~%pX12|o)HH*VTZ54y;LBE6LGlax*Byb=wed`KcTs< z##%D4Dz}jn4Kl=AJf&@JpgJ)pzSKAkM$<^)hY3^H*;y2Ua(016%fuWbT2amGR4AC5 zPDRT$iJ{084^jSuNt-5$noaT-RSURoT->Hl)0jDF2gVj7L`n2@>5#DlA*M&R{Ha!%mtI2x5`@1$c$^b^O`+wmJVxUM#;`t zB$u&`VmSj^rctI%w31aeEu*JAt-`(vW!q*+&6&EP9_c;y=Igy+q0)JSl#Bv-5fZYy zvx>l6^t1!X$@!tVp~d~n@jdRT)p(*sTay#DU5oqbJG$#~_pMj&_S6rasLLnZffc!} z_Qv8{_4agKPT$&f_ocehSC{+Tf#0L?`R=*yTKY=HJ-I4Kmj`DCXFsl;`B~`PgoG%^ zSd$nAyqWwII`q7mw1e5OgWivFb~6^5Xr1=Pu8!|;ODqFWwRPs7fxfM%Tbce-e9RDA zCzY!SN9oa}$Fi-8?K=lf^dgYUOav~|j*#(iO^?+j9;VJhS% zbxq-@sN2(JgG#zFmCqTu;H7}l8JybTx0wP49Ru!@9lRo=97Myz3P;$HNgFT)Q=w?s zn#DqOkq+QA+*ikcZuY)o615(A=S` zQ`e{Ncckwa-xY5c?+rYNzUQ9&Eq1`|UlpRWM`zw&2`6eiuz7h-u4&8R-HYiPhp!#J zg_lzYSHiLR=v=hcdqu8U3zg3+U+@2-x*ASgm7YdKQQC;3c+!2>yM*-~D}_ zx8gzSOIR$RJdC}L&p~sl7)CQ_C5$B3d@8HF!%s9o0v`p`{7Z3=&h^2r&j0T?0*m7y z7!VN#WSz|hc?}qi*@k7SIn6TEUNCeqgR;$V1V)N=gNnI=Su)^34-vzfEZFd*TKS5h z7GY%7aYNHd8r&>)73^j88~TEmq473OS*kZS(EAuY4Nw6?M!k=(mxNkOstxraoCPjuUq6toPQ3 zz&3b{@go%f17fK6rTOHV|(rV@fT=GRFq)SCIm$wTGW>24@hE)phX1Q)ylheoH(}G-2jQ! zVk1$3+e2I~n6?r*CL~p&o+u}`$8xC~R7FE7MJktEDwnjE9;&`sdt*Zrg{mX(%zJO% zeDBShd2jqk5L^h_*V+4nj~od7N*KFU6p6K;fVhKXBvVl|P#CEJ%Hrt}W`L!TLJe6s z_R_*A(+WG}d0s|1nb}0nZX)M4mWP}&e}kPZ8W?bZJtts0L#*t$!46ccd+fY{23#~! zT&CifA-`+m)MW_PVq$7@lNb`ukUn6F9?c6XRv&4Bq|JiYrjz>DmwIq zHbmQ0X_Tqi;(XGzg(khO={8)3bDTrv*lY>Yrew-+&K7NsgD~ka+(yMqt*dx$onUwj zei)FFUPCY(7UeTM5Iyt>Eu#><3(yUQAR7Xa5_wTd6RBmrvTT36?PF;)kWoqW4DSoL zzUL9gAzpKg?O(F0V6n^(&ho?ZuFMa1-H+-rKb$4|Hf2yowdF^ZVOcIiwmd6g7Cj+n z4UIx$$j=5;67W@WrhW>uBa$}k=P<2Du`8j*!>TkQMr6!OsuT+=n2p9I8FQK*R~3H^ zrus2+dQ`{W5lL4g*F+6^S-JKCKZf1$QGI=~j2TmjFDkhTDnuQrB+QsgL}YQKFvKt) z4}~;f3nTEvMZ#NS3^Nh=8g>^QkR@Hh4v>goQL6#8`hT~-K4`rfSBJHP6jWLVrQoon z55{9}NMR+Ww`yvzRaZt5t-!SwQFYk7mWXRgOJYm}xd_|T<1sN9mlg5Qk@kZ}kHQS* zhZQxZMAf}umh49R8#)vHQHKQIWPhqZ)3E5-V|Fcjyk*>qP@O)X62DUKjolrauA1&y z+}oZL+Rg5z%_dfansg{75{)9U;;y>gkUp1gO1+c$IOlGj4lSH}*tT$fCZ6@?S^_zD z!0cGAuD#Rrc~hn#6TI6r&ELPguy6Y6jPF5Xw)XqW*&W}u=Jua_Tz$&yUGZ(ZEoExc z;nel1-8tXu(^nSchsPJfGuN}bbFFOY*xUj5FOYdQTh`~Dh400qm-2q~UcvAGXdZmVJgDMsl8!fHbsmW(-iwkS{1Wr?y; z!cqqM4`__cBw6@4-F`OEMo$9x$lA%hg(hWiwq}S<^15K}%;2qpL>r7j5BY3qX_Q&J zlx{N5tqaEum(5B&laBu3JT>}l3NpbSirA+oaF$a|*IEXpOAUTE|A*dgE+hNR&D{5TBd?AqlS&m$>@&@yT8sc}9 zN&`uCQeH3vNnJw6Y#6GI25-EbT? zB>aN*tWb`dbrW@YdKYzu%6rNC)0gVY`z*H7W~(f=8hI-xFQqQ!Yb>_5$UfUf*fj>x qo+8ZQZImzFkw-90)vS?W?M*}p&zyBs&F`HQ^1Zg`-8bR3eEkQn`HqtS literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_369711.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_369711.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4183d5202246a2c4c08dc1e7e6d99aeb531913f GIT binary patch literal 3264 zcmbtWO>7&-6`t8YQY4oYMahk13$_|p4&&B>YC%m6$4)BBQkqCXVMS?ZyPGX{DN!Q1 z#O_Kl)g`h*6G%i3VhR+Y+Z1MUssDgFK!LhB=@=vzF)|8assaL%i*AaQ0&P!yv*b!N zt-vWdZyWVe;<+N_Bc&?7VmNj?G64k zRz}@1tKEuN;e+ZMVxpA;4Wm4Ya^pJPzHM3|QB1guY|KVIgfY{^m(#T=UCu};g$S}P zXQx!cr8QY0ya|*+c*E3mmEehJ2f<^6eR<9#!HjI`sjHF!y`o-iqVvSB&6#`51|m;{ zi!a27q2Ps>G&=lzOs8JzwB!h%ODR&OHh?P`NlPXTA}IBdx(sZ=85QC+^n{c%bwl@1 zWfn;A6xJq)S20Y7Wpn`R*g@?-uk^j4>9a;oPN;p;a$;6Cr?u=4ZBs48S1HAUV?QA0ynj+l+YAaIz9N~ANBVGhX}K(o>kNOcj3jf+KgVvxQ)K< zLSkrTtTRN4S)DIJ*A1# ziQ@V4WF>a1r!sO)s`|fg$2NU|l@rAirKc+G_kF$g5Os32I9j^+&X_&C6>2L@mSZ3G zlty$oC%>vb8?T1qYgZDVE9%zj7OGbks-XpYbdwX7Mi)nye_ooa4E%BMgTcRXPwzsykWP0@KLlN(4mWfa zQ2=*NP0OvvA#_S)K0@q;g0JtdPA^_3< zV2@wM%b)?5inkz`Ev_S$;J_Sb@fJ7ZcUSjEg?EAG8~DN=pMQ-0!Tddanjwp^*gKxv zb%4Rh;xD48cYOO8I7GO_oFYdV!YgJzr^0zu=TiwaL6O95G*JoQd zeKY|s?lvU-piRUe@7xQne3AhN+3|hw0RM#Ai3ER@)O?gB+}QBM^G+!_MP< z4~)cnHZ5mAIc=PSNsaY12pT7$`LH?XE!^KHm!Wc;?DE>lOXRvtxl5?Wg1)*08lkOF->@YRV7P2NkCb%e~|EiHD=zy z3?kGJHR{POf zHfXmgwyj($UMlxide=ImfA088_eb5U^Y_JZ`@*I#v~sd|vNT_wUpv%$NBvXgW@h!m zecy}r(ANwOoVR0Na=zt(#g{g`!40n;`W3M#mgIHs6A#b*E54$y)c=kMQu|ZT32wB0 z2mb%;@bkjX6DSa}f9hl{0bcGkdQ6=I|2vqxhf=5j6>}MCqX5T#zgpl9X55-6uogQ* zb()&{-9>BYEbtWe|F*+XI~IZ((+?=Z&ZJDzlF+h-sb0{;qs?Al=tAkxp;XeLaFvT^`HZ&T^2Ecn$#J@gzS!`?#1q3AKhUhI8++i lK`>Mz59mN8!_NBdGYE?hyj?i*RScuhlk36bi+;z!{{qE!(EI=Z literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_412290.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_412290.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e49700bc9cd5b37a96ca5e4047b937e1e50de97 GIT binary patch literal 2521 zcmbtVTSy#N7(Qoac6Qv|*^MzuV@)wP$zY>VLTF0}ny9fCHxG$vqaoAD%&2>v?aZiA zW?=*&E_rbKP_k(Q34JnEXbOeC=Djaj42_&x5Rw<)Y@|^7)c?$0+=wMDJuv6L|K~sF zKi`?3MKOS2bj81$xWOazf(n)4D>K{Q0&^E>NW)1qQlfYSdoq(^M>vdh9QX26LQ7O* zbS@rxT}LI2-CIA#Y1}-wTn-rVL%aaK0_c77+(=!u0)8Hi1Q?_Ts=mP)EcrKSWR&oi zKOE_ohX~&%5BB%=sX{p+C3H_LeA^#jx+fdv(ugw^Zii}kxopo0Z}4QQo^_7k_-A3A znh0(uXNA+S<6kbXISa~cz^Tgy9pnTYhDKAJ6Li=LpACV9FJ(?ZTvn{f51rdRoRA}~ zu$2|EVFyDtu96LM`6_%iQj@9ntjZ~!h!gg5(Q|`=(F~Hfs7Z_8q~{}H!le`|A^8Yn zCrKb<8snytl4Ba-6;nx%>x4@hibi~vZJ4^WkKiG~UYfK?IHlNT?3Qf7tm(IQ$ulHq zOxjgrjWA= z8k#Pj?&|J-{}fmVpU};;o;2TvHEFM``!KoaWhlHjlp87>UJbRl{p+tu8{+=_)!dEy zHy_0A#g_2Wnbmi?i((L`9VQ;EXm_Z_Rbc!Nc3;)bz9D{Zdo`Y(&kDa>h8so> zdl^aoDB;veO_3Pl)9jgyPIyh9imAFDuHfg9z(mDrB$wUvwqDcAEMTMq&b!H3j z9?nhL5{cP?ox~F#fLWa&K7fFhCjK#ixZQPz_-1v}u)M2I_!6)s-%cPP(!)Gq0f~eg z2jE-uQU07lu~e$RT%dUp^o{UYcVL2=*2Vpcqq)&S+rze}2c;hye>(KTp(j(%#Hc&4 z9*8Wq=i2jAg{h}aZ7cfssYj_N1J42XWmS0V zbw0Ql$OZCUcSP{2vXREMNaG*%bv}OU018FikGykHp{DAfx7$O<*TCFGF_eX$<#qZ@ zZ3kAvWj_@WZCNnbv*6DAV0b<37j z#nR<2FtmbhvuBr~8`@4AL40b`Nb3?u{HAVACT$Da8=KYXW2?y%x}urw;AYXI*y-f~ zZ-w%!IAzJ+oiX17pDyZr9VQoTg!j1@w?e3?We!@oFz_-^pKmL)tp*Oe7d8&HKDc)8 z+Bc)_#f70F*ZgnY1Qcp;qh6<^ka--|A$+L}5-HJ+B>b8t;cX{Eg)=E02^){uq)s)` zmaX5)n5C=C#?rQT4qovjSULlF+(>HX5r`^br3|%{p-m?Bb7l`1C|s=xn6_Aq@e36C z6^Xy0mJRH`-8|R4#f0$jr^BB@z84$JLe5x1qA$EGRou&-_CNQwPzsxxo#vj?|JU!F z^Zz=(`g|S){+L#eMCCbbFt%sa_M}3m(sh9>I;>gUI1+lz72_`_cj@7PDd2z@}V*d1$o-@ z5>0kf(`$Jw_c-iRV!@XqgP$17y-?V59K>nK@>&(IWLL4x&Sy#BBlH!kqAq$ux5;Sl2{+a!MGApi5SGe#Gpp_m>yJ#V3<1A zh+qQKkgEyPOSp?8Ch^CECXS3M2F$89T1Jl%uRdaK3QLGFA{H`R>{QNmo$kcc7nr6M z5t0#Ai91VzhBl(vnhR6%=sJV+Lsz0FiTDIAN=xZuDK)U>x8{uLFcJQnr@nE%GPJK8}e-9 z68G5Cmg!jK#p~xL&P`s*zB_eg=FnaFwtS~~uKPFsoegja!Zmc-2NMyoW{%)QsL(2? z6w$V8H?hybb+6IH*pkB8r6obRF=5uUx1dvWw~}NR*<#tGxe~nH3lUn3Vui<)6J?J@(!%ULeFRdNEmSJ=WQ{2(+BgZjHzwD=(1$#oS6lSvKZ<3PO$ zpYaV$P+21KS5FU24a_ty`|2~@tG=q~E0A@KvyG4T%RklpeBg%z4^xkQeIW7#rkkgl zv#FWXqkWBw+Wq)^{9*TF&xK6KT6ImPXN68r-TrCslsDTw)3fZBSGQ4lfs@2c=Zi`t zcqiK?dRE;2EVtsWm~1Uf{;arCQ@c}`K~0hlsq&0m>0o0rhOx(SHsWcO3go@^ro-%}4^JaF*t z<=dC8 y_zfnoA_ejsUH8}7>zd%XJ}@z`ApmdV41*Zob9XIM{j8HgfkVswh6%5Iuzvu~D2WLG literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_469771.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_469771.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89d6e09e15c433f98538ee77ab8d68fd67a5f1b9 GIT binary patch literal 3181 zcmbtWUrZdw8K2pIcii3H0S9itMmCL$dvR-Ihmy2%?YJC<{DG)3B@MQ^+1~EK!QJg? zb|F5pr!HzDIOTy$r7Gt%l`nm%9b2j`RjSn|t4~$?k`q>irc^>oUU(xWmD)b_o4s8) z6C$VTz|G7z-=F#B`@Z?j{#BMk2wJlJqZt%L=o3cxP4p%^e}Tz+NJARVqN&D+r?5+N zIetpONXN6T9bdIEYMd_2Mjl;bl*S*to)$FWE#a1@FcsjCj%HiGsP9W?qTZrOdJE_U zn)E`uo(A7SQ(>?Y_BEy=MVx5aVkXCF;0G_gJoKV6=1-ET<}U!_?E{ng&rfx?MjC2kR4oveVR0c{gxqG*bd(E_FM#ZDPo81manAO@`#>nM{H>h zER~|RYzN#~%tp2h@q)j%xir@W_HCZoW;@xD#=3?)4(D@9t#))DFZQUFMTkeKt?6Xa z7f)ul+HJRa-^WsWlQh)g{${uDMlNpShSzU%yS@p&Zw;@T&LjGqQb*Gc&AIlrvODa! z-F8sLf|slopkY)*31OVDglw^#ComOqsyUYkP|nKIP=OfJM9nD~jY=v}^V2#NvWBWr z(E`e(qG=gKr+6aKN%0uvU!J#UG^bi5b5${+*YvCV=pqdp^VTl2i6~O(((_ZpP{LAD z86CcmB+N^~iinDZjHcuo1Gth?jC9(hlID%fWnfG0sL`NFQcA%hCJ8WQ9!T&M)+T6B zGcA{8b^`0zLG3?3>VLx^b7nzJ>HRZmYEHFgjQqFNX+3ZCgOz?u&lUQC>h})nhN57Y z`l&)ufprBAZW(zcWoWuGF!mo0rW~)N%LD-)Fy<@43aTJodb z+Q>Df9{!e-+z3ThPn1tozFF&d5bAS=n3JRB(aNQF#+>2JSbJsi{R>sAI`Xbkd$Zno zrXD+UXYyY1Prdg>?ke}+tbgZ~dh8WvbVHDqN0&xdeq5Qa4g7ZS!@<7^PwhY?5Mjw6 zhoDP&*uYj1Lt7WM(h#X-w2T3_afyQ;-&Fg}8m<~bWmLv#&hjWgxr301+q(xPp~>C> zce{+sfExdUJ#huEfCjt(!IEUP_>S0;i%Ehl+QMwuU)>uu(MN?Z;Y+)G@iF=b^LO=G zglx{{ZwGF9*nyEPUP6iB_|`FShzhBBO-*oA)U0AbhpVP9WKw#HfrQ^!pyKe`nSt(? z+~b~=GG{~jSO9#GYzq28NyH%Ug19fA=DHau~`EhV?8 zOjbX~9EbG6NP}myYVN0+dTz1n4WbqbIyrSVYoyeyc@8EG)>9y8o`mjS4*KBG^~2W= z*CuaF+)3P2*1EpG(RFm+%4a<8I!~Sic^5?q3rd1(a!~@TG)QWOih!6}o(9qYpw{U@ zDgumHMHddJ)BwvwVDBdEnP9la5S8+PoK%X=gwXfFqaaZE^oqy=L+jD>6fjsV9+z7>1PnJ(s7OD#$AL_fU z|1NhkcmLvp(2LH{=Nyilcaon9p_PH9m)3*P^n8l`W+dh z_NJf{U2pvw{Qt8vEK1u?ph(Pl!Hrr9yxeWqZjwiS+_AN)Ag^q?T9bT1`YU;{q$o)59 zZR@#xj5EMssi9&jYvlEWKm$ZK=d%_(V`gSiXU~_W%;>5{20+YYp<({RUpC8@k q;y6|r+D0(c;t$!tB-`Beoo5l29|pT|{IeuRu_xD}$CtvcgZ~99zqVum literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_493615.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_493615.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf91c6fa8f364c3b5fbf49fc03be5da6402028c GIT binary patch literal 2470 zcmbsqOKcNY@a@+d|3Z>tNJwZutKHiHI*WfOKzd*q3xmSyj|Ndh&HV{tDTuQZ{E!8 zo6r8>@i-B{!7slX{Xj(MIbBo%))uVgflMO>DNGCv+msn*tTh|whIt05Ow_VdIZ+`B zJBj8h{IIBSszc%76;u@6^s**hjfjdk$qzTw{5U4ju#-iqvu4AUW&F|#RStwCN-Dqu z<=|cW6VNo$nTi}m*($eA6=tropOZGwWNVrYhGd9iU=0V12E$=-Cqc%J3kZRPqjVgc zd9YCP2otKLeSI+hb!DygYGscaDBL5N#VCW4`) zq6YULeD6Tt`>;dAF%>7&7~TR7={0I!05Z|C+dDN>7%J_Ucek5^i!ZH~Bu`QN%3C@& zbN=?{3r)V#_HxtPv%4Pjly@H~Hyz1~3(Z?f-R0(8v%&*c`JF@M=0kbMvUk%|xDYPN zrLLRNdGC%qwal}dZb zs^`xji`UV0W}})epaL_(>Tt}4=xb_BVeYYW)G9!Z{pb2BFa`9V>~VP}4<{6EQp)jq zgY5cedTyHgkUVllj)YkGdXrKkWW~PvfniWDO8Z zOUN%00VYLglnBspN`g2>V1DR*2Z#U-sb{UmBBDJt{6c+M+l@*%XnewlVXA2KCNE!x zwzuPURHDUuK;8_mb_3AwCR+3~PhBirEOpIvJ#FMnE_ zGo{Q^U)MeLe*A9y$?<2-56xrCuG8kgGVjdqpZI7&@-9d&z*C-rrzp=$TVLi~Q_h03 z*muJNTI(#F@&4M>Ac(6zL<*VXI*Q~jqINMR4dTH*e;9iG}T5LA*7@Uts)y5Jr_#HFj69t_Ev8tn74_)UA9A8gP>5Fr#xCRF%yIIyavW4Z?Ejf|@>(nTdWs>%x94Qd*_3Od2sLbGu z7PM*x9{`=bRGGGpiI%*Z%oD3_2Azf*;%oh&0-<^u**uI1RrMB2_s|%aCxuMdWkG373@+oIBL@uZTDsr|Sg4ADuUeF|umV;JT+a{q)pKcn^~#&NxE zqU{CS!0cM%5bIv0+^VmMX)GRFML=d+EHbtv5!yqoo`4`#l`54ZM?^2NLZxV$N=WpQn@c$L)Hl0c#~>=GI+owO|Gb&^ z*^jbB2(O zJj-L0&cB$S5p>~-u+%U}2f#iEx2oaVGA)$SV{mc(L$~I%X8Clu5LdW#mgKpi&LD7F^FP6-{Gb zX<7lk0vEB01w||BhVtHl4-WPo0tuBa8l+(4NCzCq9?`r3z(G$tQ2X}lC%-wls8%~x z5^p&}Yl)6$=+*ev%9quHx8v3KYwde)7HaK%wRqp%@`KaMpP#)yQaf^P`HM{LNTwFg zIHy*l@ybwj&uv_NvzB=K=4dU^TZ{JI{qRBje*gXLwZo^E&nmUUN-e55$JgZ8!pQtc zCA}hdI|FNibnV3KiMf$V<~N~h1A;?ZYk9L)b(>)D6`BA_8D}_)IdNC^1qG(#JKPfU z2&>G!q{Ofdh5sM?@j71?tTxX(w2$Dmlg-;)HtcgZcU|y{{a@r_Q_jD{{>Ac5dw7|p z#&E&N=Fgy{G)x6;Qdg546?JR6WKc;rrY1BaiPu>Rl3fhA9VJl_+D99wA{3oopn(iD zm9_s66=w`mG~ME&lGjQ}akE260$flYik}K&&}SwaX>Qe%b|wj2B5J8*5V8lT9q>25 z0^p!EdF#UH{ODp&wP$%}^5^znx_;`qH}y~+hI$Fb7WU5XtxPRWEqCNM+nwOfVZce>RnQo6j~1uIUF5|9fB~ zo$CWy=l{1Kf%&l!Bv%|xa_d`wB0T_T*fLB@(Nxn=_Jg1scyvvj?0C3Qw=oaUA zJ%T!R%|a^Y2L24iD?N)nE1}n&qmR0}zdwKD{CA_y@wuUz(D|xrK@{2M47-_1Mo1T& zLnSW^8uRvjIrtDv!X%|ZA0uoOQa(Fj(KfAEFfHS9iFjqqPZTWo9o*n)*n0zIteDep zA@>;X!oj|LQJ>5iACX@0#Tv#O2TH7510^+Sc6#}78|V9d@DXD1VX08kFe!wE8LQO<>}QhTdamT*!&Un zS#Ue-wv;K9X@!`rp=D;>IHx6+Sg1iNop>*EP(Mbc1Nz~E`wy6MWCsy=mesQ4-@gMi z7}>^YLz&GM=SpD)@G|Ig-aci-mkyiR{Quv1+*SxyD(_iHIfOZ}o5$cjaw4G>$ZWkg&sRBITVNl?-AO&SEO& ziiSmFwo@cnQz>_h3LlI*G?_OXGIBw;fm-;&DtVkL#ZhNP84<<_yLQ4drvfRSr-E6` z(Wa4tUUG;{W#;B{;+T=MZ58Y`N5@|$yU!QNIlE+-xO>AD>TV?Hz%lGn(Z;(<<2tPCkg!uM=w{Kv`tIJI-kvu=L#1<=6mX7gfDN1Av!{S~ z=+OqGsI%IPR@q+}n~!h0^6~X;)$_IZcCUXio`@)_&R549iiKpVvbCD5txeaG=^L## zKCNl{YRP@xp@r1i%8BaRwe{&*iYY#+X?tp^J>KC(*rRu*x6(VWyyzWR6s3!YCl615 zSaE(4JC-32WJXcRx`B_~pm#)GIcmk+U;xUvA>IiDpxI|&#^P@AVU3pgNW5pEhI0Y? z{~?}W2~(&nI?Z7UuE;9K}g4_&;0!*dOj@uF(}Z29JP#_JdIh-cnMR+B|FMxPxQC&NGdggRBeLa1rP5XZR51rq2-X6QF4tfKN@zm_jnVprf>e!tP>09`l z{LTFBfxGedz5aVj2Sadt`pwCs3vzNnR)Ef`Git?{mp4AiE3@&Lc%|n{6|5e!;FEl` zra30n*P%qpJK|pjm32LDK$p;$-hu3vJRpSOh_H-J_*MqHLcct&0CNeApfWV@Gf*03 zE(&gFv3%H?%YrLpna`>b{#IOM@q0mM3I98e$l|yNpcWg{_@x$*NH-9wIM{Y{)3CAL z1B%T|@#-qY>l*LB5{;R;VgW;^ZX(zo%{ex_{m2B?^8iJ9qusI4bkU>_<4Y0y+K@5;fyW)zw@l|zoO=0k<-h(nvrV!?Ls zg%SyHE{qhM;2r=K1tX7j{W0SGlasCBOtvwyorRH&Tn?_x#*}mX6H_I)IQ?ufXOWkf zpBMsX-E&(fV3*#Q^({zDOi%nQv@NJ9Z(u=f@doZi?tDSr zIDIlwhu=o)a6xIC-aXql(^pMj+i`iv)voL2U8U=e*mXzgT3(Ad^`YF(wf=d4L#eIv L$!(L0f5^W9Aqq(( literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_580037.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_580037.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6620a136d253d839dde71a531f280eea0e77ede GIT binary patch literal 3061 zcmbsrU2GdgdiKY=YkU2ZI8JC3AQ2`#pL$7?{s1BhlqO9BiIG~0T2pCRyz9h?<2Ab* znuImlNJUMa>YzKtlm}7?i9jk*kPtjpyz*oPp=horS>kk3-dxg??(XTnS$orMCf34xq5#R-ai?%0+z7kLtfqNLCug)$`ZDkHkW{160M-+ z|6OwahvY(*XmPWksT8(C7Vkh23t1v$4!ut+C`%s#{T5Ah6H8Wza0su-VO&lrYEccY zM~-#54UHZ-pUWw=Sm6WKA`jbGf;^U5JvrolnGq{$wJOXlc8k;pMk!`FT0FiuRvviZ zXl?UoCDlDJtQHU7dlQ=Htk_KGfC!-GK4~Vr5AFkQt=-*fyXAAerqCivaw(mZc9-0q zDOW5>V|Gq9X0cB*r&T>k={&HK0;bJ84it55N|)!PoPs%7mkU!WX7ie?V4q=Xx{7@! z5Dhr1VLl5CbKqn$j;Tq^ygF~<@SJSwxdq99RZ$lX;EOn@&71p753-24=9A$@q_d-^ zMs$*Vo!nRKE9MkwuCWy`r)9GSpp6}H5(>ecy@2^#K`NTMfdyj9f&9dO$qBguDzYi# zfD?~HHsWBX`%&XR7y7Sg`m9luGiv{|oSBu)X|3=lc}gvq{f3_DH`TdfKVbdMz@6+A zHAC$yE=oX`AQ7fkkTRO0N++H@di2;cz`@+Csu$F}{xldSSI|g6iEMOVM4`y_$<@h9 z&$ifQkKPr-9;_B>EnlvNdOmsS)A5_*oAT|h?Zj|3G;E*Q`BhFW6f0*}C32(*)WVT+ z?~RtqNTp+~eQS$ag*8j(sWJd9QlnLN$EB9@~kumS4Yds$y2g)})Ph ztMS2V#C!1F>WOpJ$T|CLE!zIU+m%H5%35TjceCyDw>LX(_E)0=_W0ds>pH*2SNM(G zZR3aN^Y(bHrQ-%wIbPw{p4>ROnfQv@Jbd$oYRjPg@?Dm@c6Ryf%EdqH>kDfO8~#tl zo8ql=U!*@vZ+Abt-SOO=z>n-Jdyp(dicPK=W}wD;3@CaVSx6c6cCvJk2;nDhV+CEBp$Vr4{@^!TfQ`x$@Ajr)`3wD9ZGM%uz#n4KZLFS_@rcjDDh21Ars|&0(%=b` zdhmGQt#geVh}n#idnd_Yc5?K!Q%-K2EEyyr-L|bKV8_BxUY>hPk%yNKUD4%YQPul~ z@>)jD8^ZuJScl=D(Fe=FY;;!$Uk|PZ%V#Ph8&6cbj&2JBk09OiF^4$47etb@$2ZEk z{>h+DH0pmK%m1w915wOQL4PrVAkuiXZ@vk8BLu$bxIO^fCtw;Qu>9LbbrH3-+Y>vn zxIJ;7Wkk*%xgS8j;L7pki5($aW_E=3N_aISds*|631B^jHY4 zgfEq(cgfF?ehd~IG*!ctGP0pc15lAa0&Vxj1B^jnkPKp9Ca)FLB#U`nHRkiCfjJ|$ zq>`4aNYkpUIF&OD>)Hw0rU?VNk zl$KXMLyl$lEzN1ld|n;aM?oj)3&z*5)ESDRenR43QRr`|t48tHI+r`^0y<5NQup{c zH9^%m6pODPSvyj9(c>q-8TxYQX9wM55G~ee=U`u)itLCj*C$pd>I{LOOnhv8XgP4L zt)uP(S}l3R*+FY-oM?;ylsMdg1Qd$a{e&i<&d2u%v^RvP(9c3A6}vY=p~zF);ly&# G3G2UCpRD2l literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_608628.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_608628.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2853cd70159972d6fb098b89c79e84578c8d8a53 GIT binary patch literal 2747 zcmbsrTTB#J^v-j4XLoi%DwSs-T69VaLM=6wXqAT#1k-9CqRl$)46v}T&I~AIHW?Bc z*z|*K(ne@a+)pefYGb46m;e5_!8GKSCM5F1AC`W!{pz{1v&(8|?BgbL&pG!z?z#6o z?mQAj0l}zDeK+(2kI+*>b5{kt!5yc*iNpy-tMtLT&<_!2x98 z+x&{MJ*ARv1F+l}rt{bqROm&ApY*&$k|{_ErpI(oz+Ocf?3T?urqHH)(b*Hufu64$q02e?M=Wu08T zL}4~6>mw3}X(NJ#gccvvM$$n)feb-?2V5a1!fU3j`_1E`@4d@nC}dp=sMbog=44Bbo2ztD&`VlB^b?t7^bx- zd@+q&4CcczaO)~|5yJ#w#^i`z3aq*p$DbcIUx{lYdO{AV%|mi%L^g)vu{Y&GHD)yH zTBzAjqlspqnw>)!a8x}J*VU%Pco5cua3~`l3x?u~8f@Etu&wnVa4ucc0a{EC$j|S+OqDy5P_5Zp(^o zR`;^MBz-9Bug&bpjAZLuvi=t9_?qCG-J9B*u1^nq`FiHS6Jej#vBL7xCnitKoK3%z z`gE=%qx{4+Zo+Ypwu4M%FkxO(jZrNYa>|Hs5os`Xk$(dDGmWOHYsH0FENP1)ltL+L zfHoWoq2cX{mAXYQI*FR3i_&|Wl1Nc0DCB=(kC~xnKto}s+)38(ihOvY;)*6G5~|kJ9*KvbZVv;JV^x8mUI)|f7P?-3 zqiVh?^U=ZwOBMU?HQaB!+xTGM;pL^)li6LTv|13dTbD)-F7cRyo)j9w98|9o!>$47 zR7Oi1=AgQbal1+}pX*N&SDecj;63h-l|I&%UOP8RJ7cl){ zp%rn*?1j{YxrR)`;|l5C&igfYYaWh05&Nv}6(KO&m}*Rq&5b=SZ@8u2j^2zu?0zEj zSRHHLQ=E6g02JiYApFT1@k&x$EAEibuuJ&${5g_MwP`9cJ#EfjS8zmQOYE2kGKuSx=ymenp}hqY_8O{g1Vgf}23sWBPO|m)USPBWg9HL|p-4QYN-TD1 zsy-SqbO>*FLWKu|SAs*TtZ2<3rjx56PeLx1uh=^{4Sh@x+Bu^g0-3jo@I{!Qh5Jja z;~PFyUOfq+oatT@O4AK<4NF43)v;PrcjNs0`EM^+$7g!8Y~?={b0c4w)n{j|9LV|76oKb;6y7B4&SRR3@R{MTkxO1C95XO~PL1ht&AtKKZS2jRJ~$YUC^u2A zEUJm7yFD6LMkDHBtpiBXDfEwF+F&S(dWw8MBk>nhy-K;JD<>%ml-Xcm=k4m7)th8x^8lj6U)_~d*|Sax1$Hm__e^^2jsF5Hh*@#~ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_619005.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_619005.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54544af5a42640f32c5a2d9879360753a8059ba0 GIT binary patch literal 2686 zcmbtWUrbY17(eI!YiWC15XUCdDU;x;0%9D5#HgG6Q!K`?X#|&>o8DU}l(xC|f(X%Pjd%``o6>6b8Xl z!ehEkhrJR_7t{uQ$tWnq>;UhkossFLQYz0&8JedZ-CJ}%;qTd+$n;V>)3eF1_^QV- zsMiCoaIS9>w>TPszk)shx+m!` zOI=9Drra6tf$7lPCyyHDF3iLh+?l%0jJMNjU*d#o9pfDn-KnPJXX%!yk^A!8_UXPm zU(8o5aF68$#}B(R`#OH%POU*HP>_tg-PI|=e-5sK$~Hs-~kP6 z1sg21S6-fAio9*rwAe)d!W(N)aJ0u(0Bj=x7SoB4Ya;Nbgh%Y2Awmu^ z{Jge}x+w=jyPKj3VWFJ~*9VZI(?E7lYrCk0O;=1+JQ#kMm`N<0e?Gldid4 zd#iRDXDjzFI^C15q$}0%m6+0Rj$I#{s`@%XHwZ6FcwSibB9G5Ht?hac+NBx)xo&?|JXI4sq z(+&Y4ZbQ`#DJbiz)BuV-AX_=HNpaZE*q223U?dh*{Ty*0hV>IbR#}Yk zQ{?#-iNB$p%h++Pa=h|6vjd;TYb;_stM>9mV;T0X9YI)p;jF~v&s#9^?ag|t$K7^> FzW`^BRIdO4 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_620806.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_620806.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..258478b2d0606664a793a066ffd1f2edc183a366 GIT binary patch literal 3307 zcmcImO>7)R7Ov{QY0pg0_=i{+(d7(aT6xjkdAa5M}wsi4`P?* z68xZmk%32DJ7v|cACJUE2qd;$MI-d9V z8F>dut|D1umG2E%6@%Ma=jXjW$*vu&3uWs`tcw}{|6vDd(+&Z<#tuxuIja1UDu<~| zbZiI8XPAQ3Owg8X|L@ZCZ_-ojpzRxj+sK4$#g^Pz*haPj5l8QEIvV13gMFK0w%Ja0 zkZ~B#4{}drmpI&)fwZR)~>a~Tf7Nv_F)Qg<&Lj>u1tNE zy=`j4uCwcH*Fu$_?hcQ<^DMG_g@(~2ipT?mCCwVkODd~sREx(|bIV(>6X_6WyT0)KK)HkA;nq?6xX+%p78&rs=G@Xj3l_Ex@f#QD3pBuMmD4|&- zc3Cx{*Nw|t=t&w#jawCF6){Pe14Mw1is_iHCQ3t+22b^!JN<&{)n+vX6`L9wGC{O7 zCaJ`{XS6#4owf+$%kBtUQ+(NMfIasAvi;}f);CgQ%uH)hqjf}!j%n6NDtSa3Hj-AW zNusTmkw~`!)#@G5f{QTIDbqNbo>XC7g(F+3q#8}>hI+jHsb`-$4jd|t86;`M$wAnb z)q;5fV&I^MjYtX4^=JF@%}c>O&Y6{9sET@U=;$5tY36ojadi1m?|bdnp1b;7p>Db1 zz*ofSTlp>Ks#2f3R8*QjI`v8K-+FIp3wxFhb`+Hkr~Ce|a#odw+)!3!BHmi zaQniA#jY=#7kloi_uec%{bDivqI33v?4R>xeK{$A;bv#SD)d}eZ%r)dUvw@E-+rsO zw_{m;-sxNsq?xnRXJ=o@{W(8zQ@iz8Vfgx63yq7y7uurl?!Mx_liv!R8*qsz>Omlz zBUUCZIFG>9&Y&4QBhE-WfpivS@eo%9wJi6xG#%gJ=Djk?a8-r54X9;t7U1T;u*d8C zUqtw%S#egH#k0tL#^j7(vB$=3wormEo42{qKza4isEcLz{S|%%o^5^y{q6Y`eRf;( z9=u^>^Dm)@Z-5HXab1gWRMf4>v_U1^n21FU!u*diq$xM{vazz;xCGt%!j7Tgvs@Vu zt1i7dh{A*AlPmTF2DxH;9%_8|3JTcnYx6?0`v$6#+q$74yFxv?7P*0@*;WJKwh?M>bR6 zi~MehJOP3(@)DLZ0IVVaW^{y#03Ldh`i1~Tt+wM-oH9ttbU}hjCGd>+sxT1{jo~Wg z8HfqP0IrYDWBS1KYp(&Mojh4OhtKn#JOdndGV^`tegcP;`nikQi}{v9%Qw3tpY8g5 z@7=xkCYF@}r+dX8o;#d9oSVo`eAC!+$M`gHJ8`dj+5dubdbMGf(|4bBKXQCfa=IS+ zkrF=Q zm4EYWq;n^Lr}O`_9f8@g5u}<3skT-N6!{}`G+-H~1@Ee9sBIC>t%0?ZTNv%YU?HJm zG@eQt5rO)MVUEWw6M`O_GT76lt0RV{lczzQU%me72N#{5+5V!?^n1l*6l`z? z+(<=&q!sodhoPh4k}pT&@JBEX4<(h$Gs6DfDL)#sXiYSgG%e$Dnv}Aek0mYl5?s$| zpmYK1a4N3fMx}aT84I&?$NkJAhQ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_671609.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_671609.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab63cdadaeef0261ae0cdc618a88392b28db28f GIT binary patch literal 3053 zcmbtWU2GKB6~1@=cD%c@i@i3AEhu(Eooop>ri7>!bust{3N{iD6b#a2x-+)d-oMU_ z9qi01Yl+I9_Q8v)iiJGjqDqAtsYa?4izmcuURW7L)(w>`!HZR?a$;Il$xF_ipWQV9 zk*Zg__ug~Q&pr2?bH6*EOHvR)GiUxf`p*DDPblFx-kWURgvnK;A{D35kUQcb?9fb_ z9pW(3@R(!AS9M2~(YUej)-_70?2FfjIhDK2EqV$=0hQN8RnSCO4_xMk>M9o;rF$K68AiS030x4JK;6tOIEl{+s{8^n51tZX09#S#?zL zqQ}n%Um?j_rXkLESpQu4h8lZf)s8>2>a7OLu~6fuvqRN8&mzq=G>oQDoEy|> zz}IO` zn0oS(Y(TGSm!6}iNGLmDR+u%!6cN1Y6nf-iz1=-J^-ZUxLHJx!mDBE!AwqU|*Z|TU zfu@d9+8Ke)nL6c*&IsEPe94G`o!0NSZ{F|tD65YfIVGWWj4Fw7#T?CM_9`P<#_TZk zM2D%Ra~(i+c*io~@J23cX!~+gGOWvROf#F26IoT055CuV@SUFlhX~`Ep3zeJF4&f4 z++ZN81@x>1NzwTMd!W>|65dhheH@O|P-{{{@nTtO`}2{%9RJhtTguXomEB!ssjJYp z`clrC6e|wfG8OSAYf_}R{W>lkCrhMRRSvm_E z(T3u=>)j=@)PF@@__*A(zZ~7abZ)ulLECcwUHSgU<#$h)qo)hUo{GWwfE_3brE@pC z7tDqJYx1qhCG}qS(#Y*k$~(IriN7p#KjwtlV>8F*J}CaCG zIdFGZdDr1Dxb95|1=49Q^hW3swV*;l(Se=}glZPe;#q!HcohWM$i~Br394=zdm)tE zVHUkY&oebw=LJ{M#x?-L|6q?*+28T-7v}gmVGhqB2e8FFXHqD2n#H*QW3d)97V=lC zqssf>_!7QQ;j^#Me=)zJPorh?ws0q~=z%^)7JCN8#X-U)CR8QP5MDK>avBj-Z8DkA z)?2~P1VseD(HD>=aPuN`f8<`7AkM}+Nf})*aanJH zps^de{}#|=F)|;rL&d(*$rUlamHLWjn+3fU1RaFbY2n0~suyvA@Bo8}QNjZ}s2LI% z25>Yx4-$S_)3b&H14M8kF&?O10E$@@EeT6uiW>pQtXehT;{$#8A%NE5!|s^^-bwVI z0_S!38=pe=EjTYV%wMoCl-d{CA2!GT-t^hdyF2esK9U9teUF3D`91cY;$&&^;r8}B z+TYT*)A#!x1y2-?uEm-PCs%3rABGJJg`Q_Y#E0e%%$!^mBSjXvdC8WF%8K|}sr|;@ zYkP0um6o5ZilO{&ewEZDZ6@=BMHu(Kp3h=^40}4kAzj9)h03R!j z-cRS2Ux&$6ltg*>gt`djkjJ%m6CT=1G2~f`9rNyQ#STpT4;>4s%pssv_W!oSQ9Blb z;^U$C*0%wrzY84+nVMnB3B}OlPRRU^VQu~JbBz7Kpot-TB9+Z(agGFZ&6r4;2BbbY zt!nKWdlp7dXU%uy;#Z(9MaMYa|CH^ON9ju}Yp zT;J1ReX+gNz7lMMZ0>BmasJx*PcIbu=LX7L%MTS3Q8-o@bTSnW>w91y{Y~gd)b(Xw z3SRpOco2!`&vg3DLD;dRN$L{WjA3e*a=KgNY%*gyL2x`LAvXZ>NH(S3M$R7o?M!FY ziImo*_kt~oFGdQw4Hjek1ckpu(m&9SH5{02nQ7Tzn(+|~6%}i``u3H#H)0Oi=#gJH uQF4<-OxP1V*odRWqZ7Wpy9xQ$T3Mw(`7?w{ZNEBuX9tr)LF!67JCB{2 z-I>|B`@!$`A{f7B?hd`dBlL(6l@Tf=3c^gD;}AL>4E8m>4PaKo`I0; zzsmM|S8ouwiu!#tQhckvgmEewTqZu^SlHit^uUMW@t*#JV&8LAJW+j17wl$WuGG;a zJ_be9$ruKD?U1R3vOMd3#ZG#4r_c17`~>)^(qPwGhImpYzfkr$0phG+`b^Jj>3N>h z^P7Unjewf#e59Hf-8fs?K72d3FNZYF(98 zcnxcZS!0HWeQUW0RjSCWyXUHSq|i8ua&e8MWtn7tu%Z{GF*_>hBiN-HLy8uqG&hK% z9HxyF_GUD7P?JW*q>MR9lhT6cypW0!8Inu1*hF!g8zQ+=2@I%eR&sAOo#OQH@? zRxYif$FWZxGggIF#5m^470z{uhY#-S(nzc{QeN1VNy_4=v!pSn4h-l(IxEoR^n|w8 zGpc0e+!uC!&wMc|7N2t!Sa2~}3aogLeE?%DXnN!4nAqty}+CZLW0l0mH= zK%)Aokall3#W*0>xNC0sa8Owurf14M#SPFheU-0nfu`4PDG1sHZ89cnx~92lx5=BH zd(1+)-BKVuj-qUQc_Z;xZlr)ZsU&Aw83OP*{i-%RO`C8Y;<3{Bs-FA{+V(UB`|YMz z**`$TNtpPv*8=aEb4m`g2|0No%3!wd;E8>9Z0soMBp@3tOM44e%guiDrEtXReabQdXLUXGBA0J!`(*DDAILFFe0@H!#D{Zd2+1u)n(qsL_&66=3Lqh9 z9kFka0As6zG$gyAZUQokk|+xaJWIY&RF;00JW@WD6lEEc874Yw#%iC2lq&_(z>B!;wIuxBq!{VV_r6$2rw8?sX0 zG#R%JO&$8(Ta#(as0zo4cf<-XQ9*;t|@)b0|8!jajn zGh0h;hl;g-ow}2HLZ}r6(Soz9EkXsCgqrKU)4e5z;BU6vGH;kRUaV^0fa^x1d$2UZ}P(f|Me literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_738982.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_738982.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39ad71017aa57d35f96fbbcd4584bf979cbdd0a2 GIT binary patch literal 2747 zcmbsrTTB#J^v-j4XLoi%DwUUkXwfMx6trk7(N-Ql5KOCmh&JoEGr+>WIy0b<*a#;xSv=|)W$~BFaP~%d`_*%2XP4E`*vC!go^$Sb+;i`F z+<7F50)la3>bt=oc!Zu3B0qq($Hp{(Ye+!~6+wMDNcB-RO-GqNmO?5uWZQ8x9H7vX zXu-kzc!g123JV{nqM?%KB?;BfEBqwe=PB&tnnZm9jZ~pv!#hSv?lmIZ8*&ri3=SXz z-^N#z?J1RX8-V4;5S_<1p+YY>{G{h4l1xETFg>Pw9QG>GU^i{%F@-kW3(lT!7W4(J zpnkdn*yHdey{7PAtPndzR(z({HHA@n#lAMZBx9(=1-F>Dbgvve@D-x%z&*^Vr!9zz@X3=lDOWrTEI1OFKguL zB?_}qSs#`-OdAm_B((T|CP#x|1@p2d#|Bi)M&h!9Io*hBs#J=p9?YB`F|a=>8(R2E zP={Gjue?ByVQ+lIC@_nNG0dMi(bpO5!F)&Xc<1{aSU7s>bbEJ@tYR(^R)WzSh+$fb z!WYxX#b7=h1GlbV7coo_W>k*orNF9vas2sF)8)7}tS97<+B7JKhGkwNo0Y}vnab0apj0Ism2!}G_v0x~!sKEpK4!(1+1vr==R<)QK(Q05* z;*>~dZ@@EQ7QP+wKu~^a;M<|6- zR6lJv6hcE=6)Sa%UT_jMNf)K}7A29QQc%eM!X7hCO@oHQOuCb-;mJE1+C>k)RbS-$aV;c&nIEo09!B_(S(HK|7~_qN}!xm)va^oiJObu9~lnTAwDdUSU5ae4hM^>*}T^kLT% zq1$R-^`5jkR#{!w|~j)g?UCyiD`M!z2kY@J0qlobn{gaq&88|@&8up;rR7( zFDwL+{XX_V*I;(x9ruN3UxpUv1w~kHsWZBApD&|JMGOO3lTFIwL zu#ff{OgND9r7;4}>j=C_*qz5T7vVEQVI!BkOgLs>{;V3))?u%Eo>~iNCrlsj_Ds6bkHG^zWYV+8h4`Kjc}a literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_74175.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_74175.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2314375a5e5a055781b2a1e0cee80771976cc3ac GIT binary patch literal 2746 zcmbsrTTB#J^v-j4XI~(d%F95s=#&n zn|`oO+6b+Q`-#OwZEQ6C^4}jfn1q0SY~h z79D(mR~Xf$u<&sz8Yz2Gl2C)Z!cVgUp5i{PX*3|vNEM1Uyc3k@UMIr+AvXce-~cl4 zZGJ`Bo>ED-0a$*F&;@J@D)gemPkLS<$rL38(_^|PVXqPmcFSe~Q)tt@=`ZTfGsh0`R~FE_wLo) z?xQ)+QS12fb|QC;A3xXSXF%=1}_6 zxz4Qe6Wg>2$3fZ-GF8BYc}+FOv{=ZgBEm(azt}bY38c>qnxU?j)?ulnEsjtcrKv&M za43XEwyRa@HofTNX_790&32PNG=?@F7 z_;<`+OkbR9%r-u$67TPPPRMxkRNK^Gv% z_)ul-6ofL>vnG^h8s{39ga)f~wXXi=h4~BLUbIf6`f_abKNWK$Uxn3gXRYYdMA%1r z1122I`_deN*L4h@B2b}z0o!ft&7WQxibs@NC|?$} zMAO|7jVogj^@!F9Ba9%%(Y_6WZPb@j s!A$1{0+6k+!NSfvwYO?F$;xIMqWr(QtEr0TT@(uLSqki(^4c5!1q7H`RR910 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_757083.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_757083.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97495590b317ecc2430d82ee3d75695e9414af2d GIT binary patch literal 3146 zcmbtWU2GIp6u$Gj)7_ojwo8dLAXID32BZ{=n5fl4X$!O^0VGnIOvatrcH8~)&UB@9 zHf|$EHa>KDV!;QiAptCjZDK;hQ^Tt-tR-5nm^6`>zL?eGi%*_AyVKnwq!@3y_ndRj zx#ygF?mgf1ds%iP7`H$9W+W^j^cxkNQB!8tegNh&(vZeP(O`)(gN!Y+F>a7&kj{+S zew?i(sfN|TC5?N^I{%b);mPY^NfR&dv*iGTE{G?AZz;@ct_%EN&3eYt1vKbp zk?vl%^-M6qnpL`UfQYA#4jxhZiP)v|9_j2N-a~yS4j)&_tJTs#sQQ$&J{Oc zIBbAeq9D^0scch_lP0Em$)*s9Fz1Z`_}TUE@%Ou}A13gakyJx^>xdc}Q_Yb?{1tUr zkDILq4z-$kEZGWDYxzJX9NS1H41I5MLVD+5UXfl6QON$OEshN?MWYb3+BW>Dr;|N549HL!H~c)Z9^! zJ2KtNPt~l*fm}GNP?IvVBKvY%uQK`eyzA1=nT|qTTS0D{>t1NTTeom<{={Ou(0;lg zpN1X&`rMhToq03gb4i)`sL-&#;NL%YW})kD(?ZX@viMQqz{!IDWTy9_`p=uDJAZkO+bFjJ@)R=U;Jw~4Z7NaglTcHjtVpe0~j8*|4 zc9Z5o787Qx;#I?A*}cU=GnQK?%~kn(!iNDM10d!<_~SJ`Ex=DS{myAdYaOma0m1|a z5H_)7EYYzHvuT$}JHj58%Z5JA;w^5p!C9?PO>kiQDSB~T&u_B-$@?-i>O_Es*5bq!1|T!rUL(S!juQqEOMMMtwMfctVdqa0I1yl>D}4g zxl}%NZ)?jf{oB~h*kbp6_wmf(l|VzLZ<$UHCEv6s>&bQJ`yaA? zo{Q~Mean(>S@OUhRu`vJ!3ZwqPcn;>i4 z|MtUEKNf;lCkzDHRoZm89VX&2b;DFbs-Y`w@WxtTZS|eU6#GFzuayX)XdSy9sW4$F48WUOD^4xlGT+{sQ0ltY#AO1~LQoO$EL9Wr&0K zz(o9|T=qucTR#pTMIt#IrVUEC(TGWELW#Ix>gSWV)F)gdZrVk_G;`u96+jtIM75j9 zj!`@{|)1Fn2H?3Fux)1FG&6sZC_zrpEXW37TImgA*Se~pU=K* zU(s)iwT@V4i}fhbaCy(AJ;i`6ZYhg@Hc)YmL#(&#=Cy5%KXTMw`x-26BH8Ar`iXDKI)4b8bhv_khCu(#$eGGAB^YD?yw8eP~#1A?>Xn5 zpKtFy=Ujde1Q&ww>o<3MzjPqT|LK-@wylBo#l%C1xwW%2X?)5TIqq53R8 zwrO^iX@%|cY?TpCW_Ff`Sed=f&gKld95SakWnOWD+;N@lD%wQgypFnDG*Vofy6#cR zU$jOn1K9CS$GN6X!V8P~n+LTNg&e>edF%Tv^hTAB5D!ZaBc7owC_hLpfTXM85mAF#Rz|k* zqu3oA)HgQ^m{H7MIMdYviODyMr(51`RtYwhqz2~VVObo=E>XpwiGQ1^XqvHx_+S>QSP4gQf-> zbY&pk09-@Ps}4t?#bcUMA0HJ#F2ax1V^J{}lNIsgE3Y-a{3`5VzF$$JN<=*bVM*3# zf5T*=KPr*no9s+>rmL4c2hElhkGFtZ6-rZ=lH%v;ozdH)(Z;z@Qm+X&0_g?R~Pqw-Hl605 z&3$GU`(YS(nSs4lkl{FJ^s^qw> z&f=%@HSKp|4mwA$7jsZcauhp4&`0(Y)OG=k1MMB5^&GKr_IA;q7|WZ@=Eebde79Jxxx9w zOk(l<2SSJ0y5jOp)+Oswp>$}jWxjK!bFt=u>x|j7>Mb$bmdW(UQ##q6Y)|u3jwR2L zxtjU9nYxA0y`H6`tv|Zk%;ra|YvTBL$FkGA>~zCCDI|rIwB+2km3B|MlCD(a#{y{O zNwD!Q7w?6?KP%iEzrGK7eCAoT3NQ!Z+Z-gxt-|>fh+8O(67b(hH+l1@!o-K`qwOS1 zFb2~{Y_<*ie+C1|^eNbrng8vFC4LM9iNXc^^ja}+>M@wGTURt)3`&|JHo`^~$kyK4 zws8!ShFlirf{|EM@w3>WD%xN~hYZoeV+!f}ve>IgvRV&n8aWEm;IsE#Zbis7=OInB zGERLRbef5BH%um4^%k3L>mIb{;CR4nooIdJDoNF(YnEKqaM90H-MM`G@@H4fwu#OR zTlutVPUI;w1J+IXJ!%caQIEoeeOWB^5qLre;eEkQn^MVhj+wr&j*Eh^sHQ6;aW$Lw zOgO4rhhQPc?(6}GJ+X*9gRB_YU+X+I5R(TZ${T7E_#)4^7J+GU}dwAS!q4@`2j)SKF literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_780911.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_780911.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3569ff1ce558071ae392f335db4280444c730b71 GIT binary patch literal 3002 zcmbsrOKclOboS%5V|#7K`2ZmmEmRREYE4RM`3q>8w4_Nz3zVdlRx9s1abkPT>?Uo( zSaPHyYH_g3rKTW}9uR38(N+RcZV@LWF11pjXed&o_7YBkz=acUcD=D%QWeC=JMaJA zeBOKGA7$B(V0@up8oTX9=r$wV(Nsg$e+A?m(vZfb(TD@N5zeOhj4&c{Nax0FKdz|* zHC`9TgLm^7r3v@pi}&GsHs>dOnsi#cTnjMbg?K*j?bEzE8jt=nZc@^Gr^ONf#!i|} zqmclQ^uUIX;4~NWuQJmiDjhvIvR@gbQm=Ag|7*R}`}(27`v!nfLwXa9d-i~h0N?r@ zj-8eTvopqd7pp^!yIhm!{SVOOH#7spPkb{Fr@=z3`*bmlBfiVEm@~l42T6c5{VThs zyX<5VB+>+U%7=(dyfzjlAvits0k5GX{|I;|VMa5KL(jmz@)3ea@E$#Zt$?kE$ah;1 z-j(E=3Fk-~{Cixf6Z%_OoQxO>aDG-&(kS}sPSq=iH_%V+Sx zZWsoKrcq2B!t9t=*-gecDrQu3A|_GZN>hIh8>3jwC@GCfDps?jIu+A~s!@+=8Ca*> zTd`Kk4N~Foq(ws+)xxP$iV3r(pSp*hrU7Hp+7LDn)0iy>vmC%I{V>}*_2g2Tl5qe; zH^1^YP89Q22PFAKy4X03!Bu(V%nIfsg`QCg6plDjKAKFpEU4cE)5n3E8CWWorV3kf}sX#MUIw_6_wBM=hL4| z&o|HaFFm%qD(^1zE#FVDBDa*2MTKe9kd;XL$M03PmrtGz&vnhWUV3l-p$qY9q^EFj zB?>C%-<%UG$Icq_p=$e{YIM)__C@h~beV=rNy)$TZ7K?Xjk<3L4nN#0`4>@mJ+=UCv~ z4A1ioRkB%4aRfTRfAA-0LSBT=W4Z4N+N{Q4ouXT*576X@$JI>odC8L9y%EXoOgs?? zKOT0ub*PDMPrZ-avmqBAV86M%ZV#tUMB-VHF4x`)4iVufig|~qn3&Ym7*9Q#HJ#I` zr0G+sguc2J)`?kLvYShcuUB@=M?KIy5@XZ@HKb*!Hwi7p>e)j*Gdea*DmkSDp0(7J zT{%>8dP>Y&ZxyWcS-YVElz81c6-S}_nD8gSt#izlLcJ%kn#<`JKLgZu_{?u$Dxh0( zv~;|9ywWw-b>orP*KOZ!xwd6->ZUwY=)2_)m%5AH<*CZljrOhu{YvIyX0h+4f1t2$ zB@`+2FSF@RFj^WY4pgK$?^5u|g{~{z7rU<}ua7Q0)psM%&uYPcX7}u&WnXC77l64W z7v-|Lu>jx_c{kPa0ujcU%{{~Y3qg>^zXUQ(_}_j+=7%6i7tc_PUyT5T zpM{AAEZww}glg(a4`j%bz*~Lw9>FdUuq&gUMB2#eF_C()Zce5x$P+U)qq8?aQ^s^v z!#$v8vZG-A-MRHvJ3_5BPnn9Hclc${+0T?uz*Io?Wv&HLd*>|N`kB5v{^oL5rEAH* z4K8}i*3XAO8UFNmq5sTaRqXh;YCaTfDGb>+6_arc;@}-H(Xg|YooRSWC*gIWJ{QBR zvZyegvS?Gn$eNaZDu!F_y_7Mu$+Z3o z?gO9fT{lZGtqB~*-A2Kmk^BqlT;aSQcFcCH@sD!xwGexb4iyirg&7rbsb)q+Q7BS6 zQ9QBM!l>37^?Mso>jL6~j)xaoxo~;k8UiraVuSg%E1efR*BNL11;oj}`#QLm-+DO| Neqt%KeKue(_-`X^mt_C| literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_783719.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_783719.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c0fd50f051137934be712e5279efe4b4cffda4e GIT binary patch literal 2925 zcmbtWO>7fK6rTOF9k17RexMM6NJQ$$r6r_92~`nLFbOS5Kq5^=5^3dn*ZK3GW;X#x zYsrzQh}DBlFBn>h^i+N#6(poaJw==<^@0^tMblJ@1ee@W!U6TvH*0U~M2SMxk$2{O z@4b2RX6C&&{#}xs2%7Nq!;zatk%AN^f(D8sGr*XbjdBA#gH&eB zv|~*zjtZ;tW3J6SiWF`KKEDHB*q$GDD7G8?qtXHccEx@J4LCOZbO5Siky^E(;Y>3A zs#U7rOKev@A2=`f5L-aLaQ?FZv7ha}bnYTBN<<)3xg!DQaPY0aW$0Px&==bnYfdj<&_MKe5mt{7X9CBEaWE>6D#RAV!Pu}$_((jc5JA)9SS3uKzm70H zgu67M6L&PI1|ynQ2cm{;+CN{mT#MszEfEZ?^-<6R*>4# zol83jR;1coI3rUHix8{7Ys~fKnlq>K{j-6Gjk8^Mh2xhC(&cpLD^bdJWIA%~v-}J3{d8cNx81xjePO0Q*LS=BR{w0%Pre_0fAIcw z*gFbUmJ?b!&7w000k>Pxfuv-@tWJ%XZEV}(F~?c~onk8zWxKRu3@8Hs!5*h@DIPvS zw;m3(T}pwUyp^oGXflRiX@=R9O_!{_88)*xu?Amr19*vPT{0_)-$4YZ4xteu zK*lIBVh=;j&|BMyFs0%+Pxt(C@X5i2$rn;@x^vm- z$sWlZ$xY@bpYQd}sXs>_L>D?=I4`Ept+>7Et|i*Na=5c%M$C2QyA~Zy%bU=ez`1Pc zz#AtL#F@6~?j?tN$st0Yl`>K;xaioonHICoj5F8ztpr+S66V0q1HU(WM8USU54k+) zFK{EcAA)bA(bzXL=PnSp(I`s6pO+!}*fEuH4|m5}5tiZ%Zj8>^(yTuZ22$9QAXB*i zZHK3J3g&uEO4WE$lg;46pMp*oRUU#ajm#TebD_Nl zrs3*mI$t@fbH2Q9(b)*MeX!x))jL;bW(t`5H=QP8aqseQJfb{6rupJV9Ver4 zWg?=U!e_w@eXzA_(5-O{!@NYUKaunoIl|WTriIrh-eWxLXBfmtuO0PF?cV_gc@8hSo2Et6v3~)g CN_W2i literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_81159.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_81159.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c38dbbe7620d59f8f2ed411b0e9d54f656ec3d GIT binary patch literal 3121 zcmbtWUuYXg8lTyJT4}ZVW5vGO#Kpb@J8GI3AGvS`Y5ow~aa!AftGPB#VY}6?Y}t}l z?5>^ItY8hf#-a}r2S?TRp)M4fHi7HG!M*nVUZQh2%+v)#`{FmX&cU@${bpAy%fTT% zZqR=7ee?a9Z@!uNjs7Xi5d^Ke^uN;h$FGbL@eeoNP-d@BjT7d2I{v(Cg2Ww4L@k<%aP|;GHpxAi8$gK%+uk)9mEVc4Ljn4)|_iFSdKc9 z6aGbh;X{716Lo|oh*XX_vJ>*SxD$i=p+7Mi$}+tW-HB6~j`;E#tgD=GqKB=@Nk=|x zSMG9FQKuB*)7^xI(JD%FGbSzR9^JJxCR{^zB-Ok>_u3rbESf@-S8d7W^1 zL)8dx*@meTJe%$zc$%;`mTeL%sJ5BAqgX&S{mxVJDv^w3yRGaHR!y3YN%z~NN6n;l zA$%#PDFweP62Zu3EfD=4WSS+Fy}oGDt4lc3xNQ+GW8^IoDk!R^nHIr!NSIFSNmdWo z4IJ5injN`qm`hek&FCWwYGz5b7mVT=bzU#pBbJ#NvGqb}1f-GHGHs}bRWdC7Y-v@2 zaRrub8$~5!Xu9&d(b3=j9u!1a(#@itH-8Ni({))*ATD~+hvay5x-wlK+=?D|$9JN! z4ry2Js=e8e`#)tq3x5*6KY0Jmt%36m`Mf)^{Y=fSoT_Cj3N;BBDf#EY+H7sOa;g5u z&9N^AH?Mu5G^Cf^vE2iWJGLz)YbQU%)C@GSc>cj5~_vj+GFvAdu2C~s%iCGjh<7D#Hst|zC7JH{YoS8ihF&B6Yfo}Pj1|-<4x`} z;S=HC-0(iE8D&}*m-2{$qz5TkICl5YJ^Vaash|qZGPdVT$>Ndl@Q?|PcA4q89!FfX z3a-F4{~zM98e8Vz<2UdIcxv2Tu`JlZlYtz;gAB%D9cEE#r`cvSJ^;IC=!LePeUAUp z`nErr1d--pY+#>dAVGHK;Qm9Bf`6ePX)U*sW(hYv zK6}N733_xaN(Nll9E6^T7xHT19ZkKs)_dDjOC{Yrdm(RR)Vy^O7+-aUdYuLGUl;9& zv8q&&Y7_Ob&A!I*m$$?-M+|+>v%P|O0$ev*W9bS@uWMsP0166jSBAAlL+FajdSZ$+hS~6lz>*{id<8-#1{_pQZ-VE)JETv z!K%&1ySL$&?1kTdx+Pv{zJQ`}_f@Zg8OXEU=z7l$6H3!u~3|2qyx<2VS8BtMg8%u^tFgh;k-*-A#WbY(Qn>H9;oxWO{vY@M)c@E4hS_-#)5W6W>LWwW0dZR%8%%w142^n;+c#=$3nJ zW4gih{Zca#MN{sKw^Qk;NpFey5)cyib2*!b*MAvaLm~#8N#TXCi#eNwGe*&}^*bfg zKj&<&XnRHQGAEK>0A=3DYo8vYR0LQ2&5#Hx==h>%~$fx6qSy(q#wFL+Gi0H^&?!yzryj_ i)h2>&Gqq2<{mU3(`3JENr@kMxvb(d(YoSshpq6xN<>A}Zm>LnV7EFCskgmhd01NwaW{wAFOQ?{X{V-ZJKER{50(k6O19ZHX)H8{#5$aubw;mU?ep5ag#mgp65OH+;h(C zucGKeutF2}h9n0<&nZz@On$RI0GnAPBN<0fPY&ZAY|%`V?cp#|@UZ2lV4A}+a}&+y zSv`Wx-b6i)q8AP*CxaAcQKMuW2lzFr-yL%1%_xGn5DUNc&mf{%q+`mYYq$_$jQ+dw)2zogd?|BDyu$L7Uncnx6$GGO=&2gMVR z<7BCU45z`+80lF?iDAWS?Y%a>Z-bvKGcee2HXFXU_u2SfqsUK|+cLgIl^JD5sm$Hy z=V^_=N>n}=wtNN&m8*d7Y>c1fTLTb54}cfXsr~$V8JxN5H2cd zBZQCZLy8)}ssN&ZgD`r8xDsl7P?e&=uuQm+93CUQD#ZpBB1k#RMdFf7cukM1N}!V9 zF2bG})rmJM>1y~!Py;9{H#X7Z#2p{ii$WVQPK0x(dQJwrh|m`7JbAp0@QJV-jOI{` z2!r4c$Q+Il942^}2=RddjdCpj!6-%60TRj#rpF?^Yyqo zq9vq|(mW)EMkIYG9@`@gDlxrTQ$x+V5=}G%)ttYa4k^$QaZPDTj0b@qgv;vjSTGcq zmEhq6`&y111PKvF6g8$q)SYlHO|5nq9wCKRJ>KcAsjh6pl4pDB?*jVWH}- {Zd(Db{yTN%wr11c zyztJ$CUbA6>F<1eaPfSQQ?=BYljx^MC7z+zZ8uj*#Cp6*<&teWkZ>BzRvRW6)+Qh6k8WutSdGxKS7 zbZ-AsPs>X`D&O&vMc%q~NFWMPP*}8ctSNdP(hAcYCy~q~v8`9~6Ru!DF?2W{VDutu zc#G*Mv-i0|FC`h9Kl+H97IZ&Mr+?v(mH8y6mlWI>9Gz?iW3UD@>@ILOuq+g2`rpVM zMLGKh`>o|0_K=zR{5;0UV9%j|vwQ7B2xTn=0H(mGEWym+WqmvWb3j(c!XX8w2F;)~ zW2{|3AUw>3(2zASn*(19DrneG;BA z-M;E7%QR*ims}0-a<%ohFW!kuNbvK+tgj^WJm%~jF z_}q=cH-I<`JC$|^VTZ%IRm9f4tzd~ecjw?>JR;vkmfPF~okyc_c{HLNQ%`^`{RV5t z0K8x^#?O)GHzfX!wy$8v7d4YLFPU=Oy3QiT^MdYPy(?nh`uhlrubefw;$<5~zMV_n JU6XFh|6k@`bx;5R literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_869907.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_869907.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f0d9ae735965ea4006f0ceec431740b66bd6c10 GIT binary patch literal 3087 zcmbsqOKcm*b@s#MXSpOr+Ki~kfgHt(8C90uG_h5SO2VqF+8SqGqmpb*mriIQII%0OM9K(D#yVn%MOP94BNa;ShFm?(!Lr@qM*7{oH;ZLZ4B>UEVIT{UL+UDpHZk6j0jPnKWb5Y>`Xz z48ly_w&SWgyUJodAN=MSrBv=)&iN;wX9d-BonP-7N_)Ys4{Z3dyeeGh(|#OKz4HeV z(#onYEvo)>kVV*!MI2mb9i_AcCn!JIHPN$CsVob;XL=61s2aEc>klRDgmLOv55LvqNqt4z7UCZwNODEzq1HtPk9S1vcmzREO zkx5EhRDFgB7hX-jBA+M1lzjS?lT#${<8!a2&dT&O zEacwxwdOa4TEO(yL^EKDOJJoIG^j0mnSL{9`rN2Zp`Tp@zGO;fa2~Ac%v*lI#U;RN zAsZ6_3!BIcnJl%n1Xj$j$+>tfVlqJUE)8-X7q9j8_zZS-FEg^^bF@bi>U{Sy`^>1D zBL=;7iARel$D zb;9Hb2Ry<7g9y2jTro7_rPe9JRh5Ds13N>HudiPpzf#ub^@@_g<8w-8UNPp%r7>j| zmyB^;%ZwYiSQ!Uuy!$N%e6e0B>-bP*QHFCFzRf68msBhtdwF8w@KN9pVIFHGT+sHx zweJ? z|C2pV<$mR}POqDM0(>!9lQY>o{kmOC^|*cVEqq~z&wYpflleRP zG)@iAde<`sncM}G^v)1Iv!E)sS>jQR#R?{ZimSN{-a?Ro%+9j}UemDBlHNWB;Md%D zriE+Rk4l4an>)+Eh0IKnC4y7IaJ8DNCV9dav7y*~=X4;g0G-}9WLcI*CkOEzFDS*! zsxrBhxS}bQ3f2xCFO)OTW|M$8tRz)?3c$ZD^ltq9r{8&c?I*XgAINv}8;O(c{**@N zjQy@`(GN=s!~@SqX3l<7)DrP#;ZZS;yi7bxSS#yxClSGUN|M5(7lH^*XA%zjo$#~p z)Y#7Jw3^<06I%X+pZEQ7-$&K^ z(u_6zAP{bjHb$4Lt?Iq_$Q}HL;_c!`)As{st<=N5J=U2|#gsMmfDhDvaP8ctFSO|s z0XC(Ew5)9S`oD>b%|IiteB`DCQjbv3fnW5iSimE6;wTul&e)M5K3Og*@RZ40C&0}u z3xGC6!F~W4`0p&58@o+^1IQ}Mp&AUSH|SttdfN-e2kU0P#+h86>g-zX7&Vc~9tWPv z{oi(YYR5!SZ62t$tq4#WEm|TPSU2R1qGS08For1OrCp5gCs{jP#FHtMOBkxtt6_bi zV89sAb4yq*LW{_ASW&ejpz4$)YKAdj9O5kTDTM;WO~-^&F6)FVDp&2VC!!NXdA3|oZzHD@2dSp-c(JT56!4@r z10?NB{a*k&9K$f5q2ND|^eGzb2#EFF7`QgjVH3=7CxoPM^IYRxCrqh`OZ8DIiuz)! z!Ij{bF`%|N#0DLmV|`3`In_b1t;M$K);WmyV14NmF0mf4!$o z@AkIoW+2i$*f@Ce1(>K}qW)5IvN72jZt3^M;d}hV2N!O|Zt1rZzx(;$4jgSCIJ$9Q gVuPQ!Ck}5%5hH!+8(?BzOfe`tyb*ftnrKJkKeIusivR!s literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_879575.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_879575.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fda5ec4ffd0cc227c56c06ff5bbd370e63b7b984 GIT binary patch literal 2996 zcmbsrU2IcT_}rh~cD=pr)@>+r1QyAlpiBrqg2q3_hK`L0h%!iSZhCLGZf$RK?%iPX zH0ctIG(NQYLO~PjlOT&gq7mN~o=kk988OQdlO^&ppQghXpZw0fy=%b{G@kUH^ZkG4 zJLmhp(;pQjh+urAe>Hm7kI=6S@kUb(SosZri%3U0mqNqt$_;Z4&ZmW8kwXSI=GgI6 z-BsreaV+#OjbXa*7`pfvx^HcIBA`p>#cMSO!+!7=0NVlGZ=kX0Km8^oJ#bze4z5Ph zbRG@MJTl}}8=)yK7F=SgLsUBc=I|kPkV?Jk;X|+YQvVxAjvgEULJjDRH^%G(9Ra?T z`y4wh8)kQm@g7uPb?#b?p9?-gl3$gSNst7l!B2yRShwjxI!6N6YCfldmsqXPB0i6ki#D8yQWv=qNJ5Hua(bW zgWWI;4o#t$IE2|TFR`19aa2re)_6>!yq%)KEH+25mR6HGl{BnnMhq&ZOiiag%Qmq= zxf8KA$_-NC=!8wfY0bvTbE*ZiZk&4rpQ5rkVXrb9kSWZzgINw>mVTJUPJP*=uBP1u zyqk%H1*p3M&UVRgXVoz)aw^w0khQTz1I(KPiMnQMRvSocT4Vh6PS05rk6T$SZuE?5 z@o~)_H8am^BSyyVv2eV{HqzN1zb>1{tN*4t_kMNHv1;U4{_uTSDef!mE4@@XaaZ1!@4Y8VA0D1LeBr&4 zUD^C)_ZQtiiLn)k8&L@x_5{wv8@mYWsZOu?8glPRy#{KLOYk-`=8Um5d&Vs9<#vv5 zsE@U(i6c<^|G}Q13po)!pY6RUXszmkeTrU#Ht$-Sjbh-zr~tg8th4x;~sb5s7bBx>kE7I7EcwDCQrcVthi^Vm$Tf_EgrOl5R{U zubq)R3N`{sc4=d*@#2n>Mg%QOPYOu&gI1oywt- z+f8ErdXr#9&l(LCpv3FesW<|C$AahIt8=^&VCp}MwQSbFco$%s;j_MlDUa?c(c2i0a`_B5{1p8DY?pn~^3y03$-jlWVCQl8RtGBrr-Uv8a1kX@ z4*uSpVUHcxSO9Q|yq9V@fe2$v=aJ$4c_2vVUk08o{BJuVvqKQ1i*G2#FEs;(cf&+w z+pugku33h<6Eb8A(3W0%gs}@4?8>Mwo-#8=Or(BnSQ9B5^2AC`8?5{5>ZqaV_yrKN z*io?l?%sNveJKt{tJUuvE$#01yHCpKjhq0Ou<{h4}K0N8gcisBL%PM1iUUZ;6a#G78S;l zHf@TV8Ot`#WwD$6LNa4h>4cH7OgDd>2&n8HggRoT^sC76}9E0^&oig#&F|q;zl@!BS~;mMiVoJFj-GFv`jS#3_FSI=I%~dpQ(&c0Rm$ IMs_y*H!IANyZ`_I literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_892743.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_892743.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9efc483c01020b0c1384a73f1ec0815aa442c2c GIT binary patch literal 2610 zcmbtVO=ufO6rTN+ENLbGB&IQOW0O>5^JB+}NmEGLGA<{s zGw;neGw*MHmLwm7)@gq|{*#E%6FPAlUu8Daz$_pQX-pgqS4L)-apr7-8|E3LGZPMu zyR|ZE>`k;(T^bfOPWNcM?iuAZ;U+)q-NF-ZqTxCg>2+JWz6>LK*6Gej)LR9hg8-d_ zzFA=CjZC1GHo>~w_NXySRk`hbfhN19S*LNf54zv-TD}=LuV$xfkL@GCB9?EddN#ZV zeZ*>Ll-&awGtljTRcHBM3dLWGLUC$PBr9NvlW>Y9Sspshz`N|AgwVtSMv+At|N1u8H8iO)O79yCMWSv@n zy&j0n5I&(Ald?cqGfsSIoEpPwLWyZaRI!>I(+M9>sTvUsGlg}zkuXDqyEtW%U_v!< z?1o|h)$|+B%`?QGnliV9HIobxFTFE-UKt`{uQG7{?Or0JW15nv3`ycScj02sAc(6> zgovrpQ3DhzV-i!-5HX>oO>7V^J;vdLgfgwhjYg2R{iFT)LFe@po;1>GRPP*Dqm!yR zo=P57$MmGxY2aw5sVCB%Fzc*7%Y?@p>6D>&q%#VvEAUJ+l~kfBO;=96as2h}6Cfet zq>ht%9KQ;u(t8;zK(gp*AUHpi8!EJy1ADXQAOCH(7OcJA#q@Bi4ft+5tp%wNu3TojAVcS9xrv25>}6wEgl zI!X-(OVYvO$>m7t@VSz74!}a7T9_>DmP;YI*ta}VI&!uYI-4D!x9!jM=ld3SJqm>3 zzPxyAV0K{c!~E1@_b>dxO?VSn+=$KS)D4<4uB%#<29n!og>kaK2$8jbEM&75V{&hqX3VC|)9onVk9Ttc!KA$)X7 zQ)QM2nwd%K;9Gq<7S%DW=1!%+)F-l+2$22JaUwugXpp0$P!r6qlSG)&ZKB{=|1 zi%mPZ0MX*q1=)k!C{{ZqG&OM(5XfGqcn}VZer zt!U6sD2rC5CYl&P9!rh$Be{`6crm=PNB(~I53S#}KAe6eML^^W%^%Jk&QBMnSDM53 z^=}gQ5)bm*C<=0NH_|If}AM?Js?%4~g z-e8^sGB4$%yju3YQV8EUdi!XRDYqP2_4?<1IbXi(V<~Se%zQGl*#42rs4<;1`3#r^6hk(AQ`aChY{q7{gB!jG)(t$Hv$%gbA>cRASA;yP_8PMqOWpmTyMUjvdwPRRtTHt(GU-_G?t z^)=+fg>c!|p6yv{ZM$>z_SMfvvi);ICBEfl)x0Rsl#Mt(lKnUg=ioy?NT_m`jyQaN zQ}90!ughT?qJ*1>nUxUdVoB3^vvXbIue?_oOU1Q&$N{cA(R(_P(x&42oA@mNpkJ?X z0muf&Fw7Ga_!UXNp}lL2=T^&X%QLo~kvBQS1~w?SaX?~1`JN2~!(!7W9X5|3M*7{` T!ZbbWWl(5;Ie1{!@0|7*iH<<~ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_917011.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_917011.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba90205873952bc69e922204831e4244da4dbdc9 GIT binary patch literal 2737 zcmbsrOKcle@a@+d+iU07ltw8{1z6dJBu$A5NGS?!ngo)#LIp~bN-N8|PMr9s-A!6Y zYvf2pjm06ZIK%~{8cqp`)PN9FJ@tshrIsrcJ%SXez2p`exNu_L+G`sKY7tM`c{6Y3 z&6|1idHbU%IuQ)&!>^`aaUk@F5cc58$o`jrtRMv`R0K_wW@>`6XgbPFuoP0M8B347 zwKOZV%Fej|x<)94Ie0z9D(rQ3y}V(

V38@Tvgoj_d41UBv?7I+}3ONOe|ZT}eu+ z+a){w*l~93!nsj-+(!JtTHAa;P1GKKZ>ncx}vh*nUDJ_76o_)>1uW!62DmbzLkL{E(t(R7>qESzGBri082 z;7`h9B3P%1c5)W*7#mM{O$un%ra_Lg@l<_{N4)JvbvC`GXM?e=Nj20+3sor_GwaO; zOP)jn28EJHV*MJql3jA^5{21lK%bR3OdApGOla{bEfAH%3g!b^AU369HWCjgnA458 zrb>;N8pq6qIRks50YeL4lXZZKdhKs?61(DaMuk~JBr$*St%>uHI{ctKa{kRhjo7A< z;9)KiR^(`DieWw;3h6+W=9mVxgD@YJ6NaW^fvC4IE-KFlB6=fm+y62Cc(3PbT$|Ms zfuPzm9SF__jOlo+D=?+Tj2>MJ_84k3(F0UZ*&PF%))R4E?M@_RSeL;)BOa53aYdD1 z?|t=TZ!d5#KdWjnHKHAbQ%RKc9{{A$Z#)vc%j22x-0`BjEq#95?Wv)5#K!EUjQqKF zJ9#U)+ORrYJl6NkJ9pX5ONC=+3gVgc(AEJGoA2K1txD{;8?*0cF68;umM^>o*U9wY zj_Ap@pxap_u~8W^kVac#6Dc16+hXJLWM(qw%ljT2 zk-k0j{n4+F-k*Oc`hn&2E_Y@+v-7$62Q9u0^-gpxdVlDlb2NQ!$5Wpk-XdVz-MBoO z8O`x|N73D};k(nh*0~wFJ5@X}^uRTo9^7V~OMQ!DTY@Lc09Y0?Vm43|p3C`ecirk* zrHZX5wglI*Gvmzmek^A7m4%xN`Qx9Y$Oe9^{&4>A?=GI>_dLk$O~0)@59SBqs|->V zR&9L>$O;Oh6#NuifqF_&DY{w_@PyEI?xh%$nJHIIMGqkM(`_P!?gv_7{cnK&AV3(olh8zs&s@yBlyY0Z(9RN_vI0ZXM+%Ol6#AB+&Vuz;ca}h(Y$&GFh z(8(o`$d`iNZ0Rm{;(1-RQc!yhWcrBk6#!{uMR3oJTG|#N)JsE;oef!E&R29Ehv*({ zzkT`E<hyBZp{bhXeR(*Z^@SK2)vMU@R(r1#x(K@V`e67 z;JRQurW@+DgjOm9CLA*?_btn@tK_;o6^|%u$U39stmY@^FBCm5F1jo${sgm2S9SmZ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_930305.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_930305.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a42c3f3df1291dc21b49e76f4e3b7de7b7bfc6e GIT binary patch literal 3115 zcmbtWT}&HS7QT1>Y;2DW7-%6u0kwiS>4xT4TPY=*4GC;gsG6pq2bI(2J-0%8xQ4AnxryD<>__q(Cujs^Ud}Xlt84PY91t~a%#@rr{VQ0>y z*)a|y6;C>LJXN<>7?qn0J-tR}3R{0Y!71ET?oLHv%*P-VO*Z~m-;+>ywLuZo2GH}> z=moYs^PG#pgSj>g2V9g%&H5#7M5Ur3U zOo5GD*b*(DGm2Qq5+Ppjca2Ff?O@qrsYi4VE99;#sKQ}9POi}kZ{tOtwlW9t$Tih? zD0nBS_Ew|SbeHvf%r#rs-Lb|O&fC%T@*1<{3-57q(VAWSRqNi!idxN9Q@#3xy=jJ_ zVKj?k+=xyCyiPM2!-PxA##BrojF}>Vtgemga#~6%M38kkGp-UYrO6864O7!qg3rZT z2p%Tv>(eF)r)5)5UX~1?ih6n5e3k^YX>*HNGnpmAo4*=62^A>}NJA%&59riAo#vMC z*`y++-41+6OIjjf5J9Q*)K_2&POlKZp~t1HsT(?6YD8F=62D@Y&YIBzYh88P_iy!G z()1}KE63Ho2{}F`n-f~*XY#n3G5f$opQ)y^eK6~*9MFVF7+K9w_hn}#SeM|yrk0W7 znxaat9DMoU%ZGtOgeg_es42Y*_N8TD{1=Fgo^&8FvNW6@E_N@6b`~#{Lp}D%)lhi* z(wf-(-q{=Hub(gew$!`ax~DAeu}`hkTU-;Pg+yMW+{$1r(pWfBeDNb(?7e=h^y_l; zU^#N|UgFa?9&!)fzOVfwTs|^dj*QwvE5W8eb{EbSdh^GMqosk5yGw&NrE>5kdtfyX zTI$XB7Iv3f9tC>r6Kj#?!f0{eqweD1b*Xfv+>9vU#!YR_>s_UhcrJkG<7|3B>Ui+B+<0KNS=!EEpxu>=QC zIE%Nq$)LAd?G=6t1?%&LEk6GY{rdbZeHtN)vDmx5I~9z_$l~8bG5^SV7dS+?__QL& z7{V*&Y*vMPrOqVdYP|MqQTR3|fV|H<>q*En7Nm~`zyr9Npw~+x26+d7J^2I! z4zeRv5CH!K_lAMLN^&N~5^nh9x#LbLIYnjA`suQ5eHZk^e>5ehf1}9#bM2ROIh$4W zeMeJTTuvGNFmPFWK+xC=@a5?3?PDwj4@yJCR9yO5To!(Sb#~w zjRS@n6vf_s*MTheXuisU^h3aT0Y2jvkne1?7H+f$S14H(o0l%+FBE%9JzuoPK56}X z=l#xyGmpd(`_yV6va~n9w=h$j`J%1ouKKt1?exP_j{;}x6HhoCI%W?%5fC3-JT(8> zia)&K4+32h^I}0>_CNphIJgwZ2MYWDAc9nN2s+`F#^>PwKdnJt*yuo^h<)0LTO9n{ zYV@KySN=~hxPg)=2Q_mMsv`#%zgjQwFfv|M`%7CQ9Blb zD&$9E%sTxgrXK-9f~E>@NnAEmX+LD2#((`6+ZYFcL6bxHcuLEtF^>3j)tFA1@M;;! zIhDRviZr3hihcmZ3>pfWU^l-l&JLC5WX6!3n$r70W{?VB1X9WMMhLa-oQI??p87h_ zROl)8EC;#)r8~QBoxgehgA4ZH;&7Sk_+Q2RC=|6voJ_?;{Uz8(-wlLBTwnI3;O(D= z7m@hAkxqYe5Oy+Yl7_gJF--MxR(I>%m(e8n*`N?MnK7McI7c8sHxkmgmQrpb_c~Q> zY+6&MQ)<6Xk4e$T_#Kc97GwMs3jGs_|3W+0uy3JbzGH*wz_E>H6mDF)n7_CYrL&gG z>~Sm1Hd(}kT&bTUCStcPbk2AFe%A)mh9iX&8wk2m^fB$IWRtVL*^jXJ*x!Mp-wa?B Nd0{#H;(X9?@CPAPs$>8F literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_953212.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_transpose.py_gen_triton_code_953212.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaed47700270b5c5e4c0c70d53de5a42681ab40 GIT binary patch literal 2569 zcmbsqU1$_Xc=q=;cei&*)KG(|2JPYclNhyvpw?qp;?mPCXe(~URS8nQYJ$W5dpok^g<>dFiUjKVP#A_dy@CIvQrT19b{!E@(H{ zfmhNJsMkXE)M=6Jpe>Gr6;@zH@JP6Co4V)ns)^5#O2f7tZzrxrRiEm>wQBfuXvlFnt;# zCe@6EO(Ib!@W8wXuyVg0DgUy9Y`3rMmspU?%96023FUz4qOR=New6`qpEgf5kmiM12 z%V$7Y7+8zrKI(06zPHdbxAkE#0{-&iCwGanWv=8pcx z?_GsEp%?|;pnb%x_7Y?gbR8y#bT-G-Wh70r4Fjrz8D_1u)3ME}mV*qhWo(YE)3;eY zFg8^5f6&KS4W9AMI-!kp!K<#9a5C$h!)%VXWKXZn(+*>^HfOVAVefQ()kUwL{R`c< zM(184zd3zf9$tIC))yFLa~F{!MORLPD=Q8F!pA0bO<{?kTiJ|3MBSK-#|%vSxZ5Kz zjgBG_0qQr3AZ-5FYFx!+uQ!n?ev$?Ta`h(Eg@V;M2QT7C!>fHa;V?+hKoKF9NT&>iCq8VL6A244!i-NDv|H=ysG;e&1K3RZ6g1e? zSZ{SL)Eerfsk#}4j|0yU%6t`ogWNU=Z`-nC8X`Q?^Tgj&h!i6gf4g&bxoy|&E4QwE zIpp-t43zoSS7nn>aElXlBc=Fp1nl5_07$s%OGg4eyb1Udh~#0|N9o*H+;WrKy|n8o z2~;nvj-(U%UF7Pl-cCA^OzRT~;}m`iWa&?9o&xZcV;JT!3jU7dKhTb4#`j69!{CA4Qc;OhSyia7fK6rR~#uN|+~HYBAbB!#pPw9v!~iULssDryL{B`Juu5(a6z@@{@@uQj_# zNVJw5iHbx$#Gz7S=z;VGO(YDe)T$@$T(E*rG^J7`xa5|S_R>?|tbbxcDq3}& zBQBQCO8Z}{>#gWYjvwxizT$M$&+P_xjyM8-9Et6+{nD41Y?vEe7-1tyvG zRy(dNNy`p?Md~R_U-2lECFx+*Z2&=Aha)?ICBAZ>P?dIdhDv5m&X1%FjcCaMg9>p|)2YX@O=2ivihC*l>8MSE z3C$)$S5ynAZd}>IPf*DmwJX9ZW`eL-2umvUq=s}gQJj)gh?#MV3d9_@h?j}QgqpI6 zMMbkX3u9W`vce$Rw`KhGd1TxqBUVa_8IeIPHlo>sX0k;aFp_q}BC&{VBvKLJBBiKp zxSN$SEu%Fxp~Aijx3JBm8Z&i6?L5}r(S8ghR2(r#(uk8ia3)KJ^#n);{ozM)XtsBz zH`lZrXvmy;8VFW#&n0=bbEY$UJ|`|qdo$fnxA3ch@bw+p?(Caj-%|*@vtZxuzTb5F z^!MrfSmEIDf^;I&y($N@b=S3AU3OruapCR4jz~d{FHsHhubLoq|a zguP9XxW&t!b&ah+w3Bgo=Jn4 zmJ4bs7VB5>RF$PzEE38?2~%MJ>f2%+->o@P4=d^&Ct50H5E227J+SpHkPHGeGDwC1 zZt~9AOEZ^p%?r(sb}K)=@lzv!aqO|&m+5)x56vE)Ih-BKjXm1cyky)>+(|s>dF=lr z^Wmx#V#xGQw@;m2@ddLykXd;~&T7lPw{p$5T5h!5#>@4GR(#T|f5x9}`&!Og^OJLv z3r*jonSr>nvv#Gn?oZk45!b^g5XyW^8X&ZnfX&{{1@=u?%%dR$z{Z#1G1Isjj?lol zN~7f+ez*i#Fsh95uhW6j=z?9H|KEND=Ep$*3{RiJt+94T+JI2WHY{6>X_le3fnXmK z?A;UPa{&2_DNi`=32?l#&5;s4s9GkSbpKF$F;wAV-$LK LIko5&*;l48=Nq)*1DCVg>(X~?ZjNc4p_6rcLkb7yv7`7zq`CVTHW_uTVy z&pGqm{Ua3eBN)fTpZY8xLXQaJ4X#M6<$<`46r@l|)Z?yH59RQ5is@k~q*DEk9Z%I= z75WOAEpGMjG*VH2c=HLuc_$P`6%hGxKV%AhLn|mw?*U&HP5cnAR)9!|tSDhT zNjAJqizoKCY52(Q800A{mfV=%eq$?ai3&5v&XWAY$VM#6wN?@t>-l%Rmqf;Sl5F`l zEhmO;8cw`UWECZmm3D(02(rqp$WC2qRm^c-v=kaeQNBwjC0ij~(OXnTQ5v%;*%-i_ zX7;Iil+sC)MR`n{N$k()TCXmrq=bTbS(nqjDrS?KtYFSCHC@Hjxo9P(Ix+LjGTaCAWyRj;)ht1ECbL(>E z1^PS7H}oMhvqgNNkj0!Oh=JMIpdtfEaEdvaQ8BNmLy4FQFhTO>;E6*4U~qs4u|CYf zH=(4lFAg|iHa25!Ow~2R>3z(*&=BP|DoMU#f*P0sbiiycK!-sP$oV?;TA;9RSeG*y zRo5GU+Xp}6TbOM0I8-rtVd6r*e!Bktp6D;Tf33Mwb9d-LsLO79=ocpsOdQA!<%jN9 z)z7Ivr*5b2wmtBFVxL+Lh<3|E)_?8j_~|7fyd(r*o(xTda`K|^%4Rw+>7VfD8ovv{ zsSOfz!q3XfIDWMX1x5Q~{S62w2EX+|a^rO5O(3qL1j@pTc@cUxOO^H$ydbpq8fF=b z=`X%5>vn+f{u5Y8p-%#@F#p>QOYB$(>dSRS>6Hj@dLv9YV5)$SG1*Y1MiA7(-pc!1 z1Wi%esRQO>NiD5LS?tqQV=!qNP~F6sN^n_``czrb4}z*eQbB6zmU5Tlp@@=0hUB!2 zege*XNDg0w$wtfJh~2&#L{)pop_JF!9{bC4_4)cme;rh>ruNqP8|QzxV7FiEEU?ut zdrUyVO1sOcR5YmX2S1p1!@#1u9!kRVHV9ac1&`9SdI)tZXas&u}=Z|B1S?gTvfn8^5Q(vk&706>g)KW!Z)XdWYin{9PGmqh$>+?el&t9ZP zExhF-8nu#0uvYi5%@AJeeZ0j9%{+MgZBuNDWtK)LZCC6yTGXRDkQYL+ELE({f;jEqiF8Aaj`2roR255zkgM}3{KH4v zZPKurcFNU=U@uCsF5!rv4WnP*f&Chk+TH4VXf>?q9=a5yIG_^H??@hn$ab)$&^THw zRa{ttWfciyNYd^fX_>|$@*inXm^%x(OgGl#p-xRHZl#eo-KLC)(oG7XxD<1B#~m=! zt)-ivVrhrcyvzVlqRKRd!%?*p1#dgbLD^toFD#~LJFxk6t?PBE6nuB z(`_|-)Tw!;RdFh9x6MX;1e!r!_JoMb#Y<>3FV9Q}VZ$#ksZnuG8rSHgJSK?uaWTRw zikpr|UP3hEsYN5@m}Z?2lVOpI2O_-2a3Yro3mQEuh)GE^hhj-i?(fs+GMj=aubHH& zxJFBIQWU)H8to5E$f8CMmWLsIAWmq+>zYN9#RxA1xcC-1UM0_Nk->zB%TUDirkN(d z5sR2S7m#F5lr<(93Q2*4W(g)^6_$w$1}l8g;{T$gGi?X zDIo}13B;#ja%7@sVkQ8x07&FyA`nbMCVcx}^7;BeqFLgpz_=hLgqXMoe8B%hIs!i_ zgTAxc=l16Is>G_bBQvz_YRZS_9hs4JN6WRw>+K8e1z#b&+TK%i^kk0QbGb7k8=%mY z>&hR^|NQeeZjRmP<<*_X#J+yo3*UFJFdNP{q(}=!sO!ftIzZm-F;bR z-PNozS0@Y4spl4)i{4^ef6>*y9J>3~-+Xs}btn0irMUl8(RC_IuQ#?`b*sZ_>wIs) zUu@jH*txvpPm{~7w|ejTi`}mj-3PPGhP&xX>wK%)TzGrAf6YCRW$t0~h4x}wchS`i zM%=CHuKA9-D_Prz;#?}1%7@fcf&ao^2rry3w)8BYT5}E**@2v?)QnoXOKr%}m_MSP zF7D_o+B-9c*BwoHE;pJvvOzJIj$JsG{bhbg4ShrH-h^}_jqq#4pvH)TJS8TA6Y)wmubug%Mo^-WK<&$B#x+p<)-?g*rq-b7yO ziTk(=M0}OcKI8YqrzF{PR`B!0E0$_0e5Fz@l?-&STYD`c(4a^I`O)nK=nV;T)0 zD9?zv5X64mVLa2AGWL5Ju?O>dao2_g#E^)SQFa<9PDkHEWyi^QO7vqMe9~9Yfv4*> z``o_VzRLr7?t=rW|N7yD!=D~4OkO>{LXE6+kE}v(>%xon=l17!U4B_5t}_daI{B&P zYum0B>gdW(kFG**+lAK*t@%Du0`J;Us#;{ za{AWvU1I6|byuT0_-`{sGbJmsJ2J=gv<4xdYQu*V=XbLa3g&5r5r(q(jDVihdz0Q% z&{MTvxVsO-KvLMNUASVfdzS>E@ml9urC2%$Y8wwJ8^VEj4cG(lN$*3KK@VAEX_&L;?0-mpcK)x< zuSecpX?|}-7++&!8UJ@s)>sr+e0z<33Fa+rs$2DcG?qE}(2n(R%@3+%VgDN2U0z=1 zq}JHZZOcn`)b{+n*3JhO#?F)wLo<2}d+mC(;y&=jh@dA1kH#r@7HOW$xp?EgW{WD4p*df9=ei=Ac${~ z{h!G3FSN77Ad?D{6hX~|sxg5D*(;fUjK}|tW literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_205689.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_205689.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..828abd5b4712dd798ce80421b8282028a1bcfc8b GIT binary patch literal 4519 zcmcH+TWk~A_1^J}J@$-~*gz52TLeuS%UDld-CW#$C!ptNj z=~y)t6)~(7XSY>Mi^M}K5lTL2f6A_^Zns)#KdP$sCR(CdDpgv3`15YIQq`}XJNDSl z%0NHbEBW4Y&+DFZ=ALt||KfBy5RCcAZ-(pb2z^a9=3w%`%AWz4Ln0D!2=(Vy+>ebt z6}I%#7)dy2=rKiet4K+7kY9}>JJC`eKS+!8Rr+CGpx-7k5(-w8k++t|gS_o3-ESA! ztEk^WA<0qD#Z6+b{W($9tJ~U7pFMuccg6&I{ancz5Cvp`-%1?jIizBf7Nkt5xQh5; zo*v^~!ATW3d5zQTQxM$}3A(7$ghZ^_ALh@T0%Xjo@tWhml5?z*t0+0gYHBa8np3lk zz&V-|ydj>$h&f*&Q=!p@Oht)I={ampPwA}Lhp<{el)Enr7SrxnrRp?UAr7?WFfHpq z7ykh9sT24D7H>Npe)tZG)kxpzhzNGdU^^TM=ZVxx`mCWyxEk<2D_)>-e~i z&*->Uw=1d~5G9`wE`pH)I8X%rVVR_pOmbUiMgyV`l#(T)Q*s!75#1vA{W=Cm1Tb?# zX9kP~3pyUq89&(=E-5vj!c%_y`$EIzs63*K3Vx|!Sn!Vs>Too&TNsieYJ(#C8&oMg z+5o$TumFZn_{JqaIOYqFh19_4p3zAk(0o8qqY8zW^%dR!=wd$90c7Gdo)zPiA&@Wn5j;%(Ba! zJU?5VYR*(`oo0v|or%t5TZ)Zv&)W@vSSEeC?o-){iJ=Rt7Wg4AYXiDXp zPYJcmlVy+^NAfPP;Qyf4h^DP*y-J=S#t4>aNeO0<#Jh|gfGrcB1=cL6G0QZ_iuS8) zj8+ScO3Qr*q+!jHV4xJE;2QJ%;wsW+qj~|o@VtQ-y+lE>(bHZ)G5_e>unl*5U`$xCw+vlCx%o742Ebls^%n7Pe)sb7rE#b#q4 zO83Ke!|Cb+8ScPi*C#a})qM7&&o6#@ajEaZlRhESC!{+DGR*^7gqSW0W(jvuP_JQ= z<_4az&KWMjCC?>0?%?~jyS7L8_v~-kCAK-<_7!&|-tmlfPK#I0F0=OO#%UqxO11$w z!zS3I>jsF;y*2w*O8DUI5@ab@tl28WJL6r(H}pfu3WHoyqdmAGhOfGb zLvRo8VVEgnwiq)FGc7XkV#}BJSm{5SG9^E!Wt9Y0Nlq98-rBgJi7I(9mML|jO=Jx{ z!~+uNM0k*Y3tTJ(HO`x_&++^ZB zCa|SQ+-l;sn?QXlr1F?3z)RLhr6xH4IY}MJZBCUG)#n!!c%YST_Iqi$8Cakg80E`I z4&zBHMu1d!8|an!H=Zmh@bm*E-$U$INcSMAmP<67aUf({7obVOQhowc9DT5N&CT#tt6{E`P8-XjFVajA+4WvsK2vpW&W|BDZjhTrz zWXGzhSYg9fak^SzDkK7}gxMgmeWQJ7SF1j()IQ)Q+D6l@v=Y4V=G|bWq5r@!Vam0hzSW{to zkikg8V}>4cYjG4Qi5YXQq>+_KuR&+lpj)cbhdGhG!OWEm3|d92gvM&C*yGlq+p5!t zSL%j356S#IV*NCtM@=Qv1XeqX`2Fe!nLlv-IPz%XfDk< z4!aZXie2pgThVMISS^#spjKA5CsbK$ty<0jc*kk>Ig$ftedy<(fGv;a(P|7|>$EzJ zGwj!ENOM6=p${n$4O5#yTBAt@v^qj1^b(aJRI=O9(4b+^VKHfm4Ha=zsXeVhbFZ>e zzf%7cd;^hw!k9iMyz6M(pw+LC->B7A?5H7euWXuE+qlXmL?l|Cux19QT#%*OkM1x? zqu>j%<>(qKG*_u~tx>Dh+)u2g?-*S}ey(38<>AM=MO3fFq!H5_KczEaUK!UdQFT<3 zUyeCE`EVd8>MSqwkr7F^Oi6N7(XGRwD6e+x(=A0Z z14~h-m5H!!QPij``8VmfN5{Q7KBn6gRSt?$fDbQEBRW^24wZ>QSSA%ClX90KH%KNW zsnapA?-&Y2h4Fx*^0KP4(cxhw5Gjpeoh9+Wn&t&Tw;D2mu#J^ ze6Ye_>4YznMBlT>;hJtuw`TFYeRJ~IV%>(Ub*43WdePN*Z{PjSyPdNGIqUqEH}kGH zlgGbyxu-9sFJ-P~o9A8IlgGbt)o1whATaFC>6Ua$=0v9RcfF5JJvg;sZ%ZC~S>KrT zWP7sje^uX{vMzdmFw5q>yHbutPgC~l%+?gU=&8#L|7Kv8$zGU==4^Ry`)8*=XTIP+ zx90bs{IWB@|9swiKJWQiQS$QaTlo#U^Pb%)_G@R&^ojI|%<)<8r#QDiXMM2svxCoe zeu1Ag|NdxxTkn@U7M%SlTp^ZI7MvX^yvR8|lBeV8c!tkz%l19omOXLzjhRTUeZk$D z=UUTr!HpVP3U#Q)o$+SAv$CX4 z$Pu9ufZ@0mVft^7>6?~2NJC@fG^}doH6Y^g7WgdIQ%Iyt+Apz&U%Q7c?YM4M=PVxKErJs7T!cah)G z|B9q=%Gu)!PbjKyNb>FLlZnifv$srY?=v8BYD;k%2NWX}ZrVGO=f{RY+nEVD&y&>2At z#toR30PYARDkK_TQr-=N&UJ1+|X{p|AjDq?mci0I4+fPK+xQj7+gb{9RT!g2Ws(6+)~`3rR5;-bSjeIR|{_Mr@a ztLr(_oc-ZE)12yh4f#;|(DlA#Px2~}fkmc0eK_m6b@VyY`Mjxfp6N^-eaX3|9cf3V zCp&Qe!rcqGfu|QAUtHi0CVO5u4k!CwGOm<({p=!VOSPwX7^b;2m+{?NQ5vmR$R9cjvA?oqRm`49|UF$S>B`Wqbc^vsl=I6FFVUlSV!T2&UW!dXulV zR0q|vUR8`hO+3WVQ$|l4Jp(;c`7}tZ^N1{JS+N5zqXIm#uvv2ns^T$G1uaCY$QgP_ z1khGkt|niUE+JN8Evqy&j%6!N4YC!u%_tcF223~vLOfQRw{ba2;m$W}A$ zOgn!aefWz%y8r(6<>wo&JeS57xKOf()HGoQa_=p0`(f=h?(UCAlV`tol6|D$d$ZK+ z!3FM(qIi)KU*LAF5QoBTZ2o5Bu5WFulPw_D!WuQ~cN)>k9pH-@!QK$O`X=CsqjTm= zCRc$@j|ElTCPX8ODow`Z5-^-!&XWI1)dC^mUr)j3CK$H3QosafB*i)Q${`k*Le<;z< z{q!H_uH?^M`Hn~}(TH*ydkcGOu`BhX={@N^1q9RVf%{!|yNqe>Cr{5jKC?uq-|cI{ R?6=(*xxKf8*Byqt{{yZzF*yJL literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_370413.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_370413.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba4f16c3d919952a7c60424b17593f1b61d4c69e GIT binary patch literal 4204 zcma(TZA=^4@x9&k?)nR30vHHYeMyesDg+4mB8L(nT%fr?lqO9RE|s--7qGE0x4UqG zu2rW>75j1$rq>E{iWG9HbmW56Tve%QRjF0~_76VL1#MF`Qu3qeFW+&Ms{ZP{wbwR@ zQ~IQxoj3Dl-kW*zX67$Wrwze)6#Q&_)r!zFGBF36N49i87LkZV97e;rKOV-$oQlxH z3`PyTwqpZl^WQMJx z^(Gp&QAn~C_OQ=lpX~+NVo2wD&z(R0j=z8m2J9t!gLe=O-_}upi%7*LFGQJG@h{>B zd49~U*){8&1)-QjbCgIi#nmPzV$J#>uWb&9F{kF#=n0_2DoP~Hdojh}uUW^iT9}9U z0$M;;DN}RIr55ah9U0B`fY=425B=~Tuvev4S+ZOCbJ|R8A!kmu6qcN%qLiM)P0tt> zx;3OZA(GHfDG`lQ^{_;<5N{E#<|Z7%%X3(h6RXlHUXprwEkpb^2e%O0RbowJq)wL564(Q3D8@@TFSiKd%v zkW9kchjojn&Q3{VrmTWChZkv*$W3mwO z2SuF~WMOhl(k*k69948~G#nMwjzhX7N2b)MZV65j3IiNbr;R6$mCHaa@0_(3iZ4f_onX(15M=?F2v;6=a?|CA~#I!pL|LFWX&p+aXzN*N8}2{yg0isyX;=> zU-uk%>}kz-T7N_Rmj9ez_jG35oe4I}qN-h~%OCfo)pY-Ye|a`j)0T0!Jv{no-=FcL z{hy!u)74DJP{uu!U^m_MAHR{ll-{3l?^zyQx%BJyl^@)TKEjzMP<0}~KCh}?+`q6t z-LyRO@Zv^QXM)QsPmgD6nlkPtqQaAYW1%6zJ}I}>>iAHepGnT7#+JDi<}+c1%Xr%| zwQUdI+i-P0c6DW3U4QuAABTT8yzUyv@B>LY`)yRYFS`#_xKg$0W;j;HxhH<+3B%qR zxG|8poVvJpdEs)p=kA%MGmjfvGL0?kH3uJF`I0%g1^!1ed5ba$37wTCby}Vb9K8-ajK~Oj7BnR0RgpcMe9Y=R+@9tjO4XKJ*xq9EVUVl7O_m0FARhqe_X8$@s``)LZ6V}}5i}%Fed%`>C?MZv;LaJ{Wf6CqC z9%7PlH~7wY&(n&^c>g8=&+Ko-`<^q-gm~lpCT~r&Cj{dMV4hF%DfeyAuy|$RN?Q2n z-4bFcTIt!W^!&@kTi9$haya7yhQ|Y7XJL>^L!7l$1@W{~+>Alge1ZW|2BZzh0Axyg ztk@3Sd``7K^biYRY}>g^#>%fIAxu`kZdV#|4lP>m zQ1B*fLq0m+S7lG-5!9VW`0fI;(xg?JNL`WHXwtkUvZu)0XVRKXq`4a&-n1w{t48Ae zg_i)*YQsyktCFJn1A-#?50y?je2jb&mcY}6-d@!faVYObwG6B~Ru z-t!znoUr^WmpAysFt4pkyVE@%jmOV_;~;V*{(IBZ^3e_cJGtdK%FG7ev~79Tf$H`@ z^)!8DWgTo5u@*L$H5`VuGRb$Q1qZ{>zD`3srt>BylccTFp`fZ;1JOxEm99_8`TTLI z=qJNA-JW;JSTro&N5(d0nu|om>9BNC{ywN95vj}q$kG_&XUOqafHTvTZH;`H*Uo2SKS!7YHtT`*o~w87sK~y!~g&Q literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_424820.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_424820.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e7e92508bc33ff6cef4fa2905a0e69904d9f22 GIT binary patch literal 4081 zcmcImUu;v?89(RxUi4uL=@L2Qhbw8*;gy-A$-+UDLH zH)PkUshDa*rP%$Wm<@@5DiIo_(P`4yUZz!(Hfb+*6LrxolP0wDd)I6bOIrS&+yssp|~`USzrPdQk(QxT&nYt z#mEJtYsk-zNaSevvBvPqwYV^0MB}G4TIA$ujfp9fg7l0WI@m>XvtikfB@03qelYAfr`sY=aDm5;D$GZPW@F{b^I8dc^a~p zbFQsArylQKbsPjSYkyCgPh=7hU0Z(eep`O%c4scS)V=E7TV(fUOeHsR?JhMy9kMN% zk<`hjwE5lB*H5P}WE(S=^QOi2FKGWdR18Vm$d* zs>6V69#40GnQbcm!Ce&DC0m987!1vKFj%hjT(SKo@NP0m3N9e7C+BGbW>tI4CVO=t z60FqWR`t!B^y>h+Qibq(&9g>Y2QtU01d{eZ1pf!SGBu3e!3t@F`k1N%j2UK}RJo3K z;MS?i2GAuo^mQ&nP|xG-8rmH3rIn`97Hw1s^XY!Xd#= zNu=wRk{cu% z1GukT(}7zvm;09Ayn#j14x-*fcG45j)XF)KI_iCB$=fG?HMDs3m#-D*GlkAGD=^zT$fm=Y!`BB>fz(x4WbuL}XUXH7 z?dI_UJzQuVUZIE6$H}%MnIqY&3$wY|&+WSk^w~n&*%g@WyDBC=oLNZbk{=5X#Cu|) zskdnFed7IN^Y5Gg{L{Z&`O}rvi|%C53UQ-Lq-ucii{(#|x0{rnnhO?RictT!6~a-PkwS??`STX;S9dY=2> zjT&Yx>1lx8@h_)^F_%2Z;Y^*@KUWA!R2`&MlJB9a|nX_)ES zmsPU%C0z&3%T_jNV1*oieGn)yj&zc$d6CtDEZ)Mix*d`Mj7^^F?Jw`H#FYF#uO@&o z6K}mk!G*Go_)V49TZbh?XLt8a$st?A{Q*Xz&`R6df!-_>#*SsU+ct3Xk?LW9bKl4VR@y&uT zy~;*X0dlU1D7bWKmF#$IGT0Y^NAnE4CpFffB+>^o(^Oc|tf5#`R)pEO zR6&}H8+Foi)9jU!OvEDmL!`SgFj9>1GZEpKbRNtR(3U4*Dw#0GUm?er$oUV{S~4Sw zefNdyFO(=R?p$kbSx|Fo$wcUV`+psI82Oga>n21w^c|%RFLtJ1n%|SzQ$j%I58my& Z-M3E2Z+ly@`I~->+)cN_*X?>F{{_-aN}vD$ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_554113.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_554113.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b967d602ed449877b4b04548baff0655b5c2ce84 GIT binary patch literal 3711 zcmai0>rWfm6~8kc&x~Jyd2fO??y3Qt4WW6^JVX;rc%(@nN?B2brWy@1gR$|0%nVMT zW7Vlr#m=sTX}gO_BjKzng)B&={nB<-seeGf;DWYjmP#uvU-HenY^AE7dhXa`4DmKs z_MLOjIrlv7J@kfVRMYZ=J-ep%SJn2;;*|Bvy;^ zIHNI|b)XHbXrVND3lx>~d}2jVze=Ap7aZR*EuWxD3BmATNYenKscgRsws+ZtK2 zWmg@h6k?(@dKG64tq=Y5HxQRYb7(drLZ{}`s40j!?$VIPLOq~gk|LTQ-vVuo!Z~9- z&50#gR*)ETt6tN$U`v|okS?rZyt*vA3e}`lY1N`-jWXkm*Jy<1G`zUqFw+hG%4gPU z_OcmOIO~H=tI=u?MF?4n*0s`J%()<9`v)IUNFz1Yh@wtoH6y!teVO$Nb*cp^449U3EWL)jKbvht& z5uJ`sOem#qShw(kpfh1kQF&R_?S?eW2{W=z<1DyvaRy%;^E$(GGrr9I&+6Ke%yvUdCYEbBp0sbf zk7oFVx0Cd?yEZ-g)R$2+g9}{tUap}%=WbuQv@U+?TleSQy_|Dj2Bq5Or|wK&rfH!w z>&w+1Uuj&Y{>ZO0FOTKi7n1Z&ZT(`~LR+Ra8(SaQs=Ww~!0H$_zvHf3Y+7i_G-fB( zuWn)GZPxLDJReKN(*7)+RX@3%z5l#A*VwtvZ&hE&u@_R7yc^ZG=j%~bb-FP%mKfM} z)unk%>`?UMYY(p_N7EyVqYI-M-}1oHz&~3$axER34X4&eKcl+$U>%X%gu7p0TxS93DB>>dOv0sk%oe}`VTepst}2(b}(0`5&u_<#99h-9WFU@<}V1aT<*E( z;V$r=FmVzrR*;I60i*l&to!y<0G&b$R)G5yGl4x;fO+6nxTXQTiVz>Sh)kSE{vdAk8v@y!PQ&I2jqQ?%Xk-Cp@F5o^i=@ zrlZfpM?_C=iF3N6&%mU)uD$>Of)a>EJmG)>XGoa5Wa>@wG0CF@<|NOM=RV+*v49}G>Vbh!zJGX)FqrOLfbK)$O;>gb)=VUDhUQf`K|ns_umMT{ zlXT$(>y#h{Vi-1u`rq@(M__`2Gxj9p6Y$a5?hqdy7x~L`_2Jm5Zh#&Do6A5HxUE

{Z(YbWUy@IP400bt^!Q zI%}+e{5_aqpof>;FM0VW=HhhrP3e?h4v2-2adi~IuBQ+<8IH-_z4jQh!z z4QgbgWn_~YNnY7uUGt8VBYivFmnB{>&zW`NPwXGqE%su<_eE7rVsIOGpF1uk`gSQ- zQha!8o3$lRCwb$uUDE`=f7`59I$wpHuEaHCr38qx)Nm!@b=#|jI(XB5`(d{~CZH#c z-eU9=^i=tE9xs2kRXC*uD+gu}`==A0MTA*RakA|4uW)KbMq~{;I0I-0JXg{$XcuFO zf16heoKuTvdqx^((`)e;tmm&33Di?0*q)NG#^lwTMDxC|#pHQR;+=ircTHZaNwoLC zfte9`cuU|qedWO&e*3@)a;TD`assbN+?n#hSuZ7D1Qj@PfD2LCC`;pQAx1#759`&% zOI4N>xNE`4A;wt+za?;)7fPHl973HB0vp4E@-yfX=o=Qzr~&?s24b~Ew=T5 zdfthe+rDUO{mMo=={%yXbm6o+jcnz!5Q~|?sSv!cX5fjTvnC}wu-p<1sJcyvMifyNUne6s-BDZ;e>5b%M1~s^p2AUaCL~>!Z-F_Cm&!fp@)m+1zC_N?k?UWm zB~K%geSGBMkvv&P{2)92BJ@0zM?n7hM}NEfm%CqMa?gTDryvTPTn9oG^aS@I$I5CltXuBCQ5;#UwyOWl9WTo zK+yqq=FL0i9Xs>h`@6$oMKFH$i%&vU3qoJxjWL)Ku&Dtsj|3zT5j0+0iE%>TlTm7% zCXh%>=;s)s#Z@3hdcwXHhj#*18y}zr`X>FTBrt9kOd^_a)sSaifj8I22N;39Nsn6u z%S|+HC6Q#)rb9_tINVnj0!T)24Zr@mnW81dVxbVGCy1>a^n%=1Vg4B7-~K$TS> z9+l`xn`%=nvnEt6L032pNFY?pqtZ#UfK1w}G}Pv|8vIE|6|eEvjDulSEkQyl?z9C*=w+gR;`b)*Dy|_pZ^2wa;i?%prvk<9?}}S}#z`v|;WhOa z{BS(Vf3M*+nzn%_x?TS#eT`~emBt!(Uu{quw?-GoQ7|nt#vhz@K!oZ(ctRtURP8ET zj&l?2^vG_B-?puB!8^4{t^40&rq(Z;jmQ&d26@;K314gvp_v3_W=aei8a<>&M|pWt zGsTsVD80Z>2DoVcdRX=lk{NHU8mUAy>y#7^N_^B87Bq&J_*hWXOtYdCmo;-B66cj8 zhc#0XO~Y0Ip}0oL(^1VND{)EmG;1cWZ%UCg(?D@Kr7y&=MvQ0{S&_nm=;NamFjfYy zR6u`J!jD11?@=>Ng#}+!W8#5;TwY@u#ryr5Syp&S@qwg&T9U*V1{of!lr`pxzUuB_ zWP*?A$l~hPiE=Y&aMvFH8SA_jmnP*Y-Y<5Bc>g4?gyOLSd=S)h%96iR5u;O`up)1HT*kcxc_)nL72netV9~d2`pktZz-5HyU>40{Mo$ zY0C!Jl-u`!TRgqkywJHcnBUR;*z)<_zY(9e=8q5kbu52;G{0js&%KjoFlWo+SiYqr z&vn4o)%4-j+}`Z9yUxYkdDk0DmX*Dq5i70vww}kA@@*&c&b~DLoU5O2zT2E@S`=2^ zSm%yD;ZEea6MxwA$Nj(Gzs8+OGf&G5;g#JVM;{+q=lY&-19@)X&xilg^QWFQZY0e- zV{IQubBRnM8^|RVg-^VT!G-tpO&u$j)}6=l?6C}0s7DQLg(g&2pB>GG^Np=}M{DZr zhKSZ z;rNP{@^pXEd@`Dl`Tz^h4#{7@1a@xN>~r0j?mNe_{O#UV`rvBI!8Q6|y7wh~`)W)3 z8r@#RAIluOaXRHqUByhW$sWq|miSMD z4}{oq=8;7Wl@)5l2tubja(y`1y(-?djMmL=34wk?vhQ(S81gnzfgkZ z9tn@YI^v;9eNUmU1fc#B!1kAkE(6zK0NX3XT?Wo=0DCILy#{Wd0krkQOPdyWXr=U9 zjvvko+)~2Rvnisi`24&q`VLobj(BLP2Uy^(LZ*mI`c2kbD6(`r( zNXm=vHzxU(M%URc*f;52!$+ai#qaD`4qyI(9JzROo!wt#FJcnw?7l7R1v}c_`mA~1 zHx|au6cA%#ifO{GdnEPqeGRAdgq$Lp1h@lA0y!w*a;dD1cjGwz7j`!; zq-)iwsIZTd;@qjiX_1mURcbhpF73CXO6~Xi5+Bi#PN@`$e91T8vJ?#lOf%?6_k)kB&rBb^a(TvUk`=r3ph8(JZ0y+A8z8m+&ps`JW1nqB%-b5U)Cipi0&CqD5*pQ9GcxMyw zqmn#n*X){g%8V*!usii3CRnpRDy=jH;-sU(LUnzcp`UbC=o+V1VM4P8v05IF!#VG4 zkD2mXF1746?1|NEkEqws&Z3_^gMHO#HJVf3VXfwX96-NjL^R4Yfre%Zg9lnI1A z>X;;N(Bua#R=5>Xd|0;7p^Y$VgxDtJcu*Fi{*XvGK^9^`iI}G(Ij#`PXe2JE?Z=3@ z$Yx|GX(b$+)u2EDu+bLFGRP<7~xCYkunvKX(G$C zmWX*GB>JO-k?9vBCLs_YI80bM<_{B&dg2#K{}I@xe?pZN!i^OF7QsIfi3h^OGEyRn zGXjJQ=o$(whJ7z<|9!9Z=W#i#Ob7v~bxa6^1$8VQJ1hjHnA)nyfmT(DPPD?PH7bB- zxBLky0G|7!lMyvEab)7QA7p-zsPULT5C_jsoIKjz-Ubq4jVAnINsdVo`7rE(#$6eP z!$_fZhij%U)0b<0>e!n)w`O%^y_xSnI5yuo*SRpTXnDG)EpKg0^**=SXPPoi*`wLu zQ|lY4-gTQ}=3wSv_ML3!?|c7n@so=yw${{{b*Be3gq-&)XLA~_*3@Q?WfEz2)l>J; zK#t80&i-ft&$cdzi^u=uUhMcJx+LUVI`ZyQX>QeBpTo0_X>QH!$&UQiojaTJ&W0B* ze2sDAeWamrS|2-_aV3d3d=O5Jxrly+HT+6Vra}{q zmteYbS&=c|3s$NoHyQG6D^6*lF z?OmnNxIqtHrq6^TNHQ|LiU}u5iBLe==!QX{^4tNk(JuHlPeJz^{JJMa25az0pph^@ zeiVaUreB9O0AQ{N34{%Zp#%lwzH`@lSn9J9lNgE;RtZf3 zlq+Mx1pEN70%ODsNT%M_;~*c0c{(T&t_aOOUT&qlHri3YWNLP6VZsIh*NYC&Crb%W zyNhlx(S+O!-2Lz=UqA=@U9~%APGnBp>&OZpbS<;qbXRrG$;`=Y^S#a-p6BMc+|A!v zzqB_mv(3x)<~3NMBhxW`Hr125`J8vo*faKQPi|m-aBgsM;PV>~Z>;bqQ$1_0+SEI% zw0mKHD|L3Ab*9DX%d5OKeKalTKj>!o44-x1vu6kH-wf>}w{dRJ?`|El55 zT%is*oT>AAj050PxzVbl^}FeT40>Ir1R+m8#n3Z)Z_;}fdZ@6iCqlCFK}r*3M=P$f zZ3b3p&gu=GZ`jB~K!tK%k5_qqiWZUA*C`>;Hp|uYi`u04@u|CaDggV$mZ2fLxvCPe|>H#E05_R)^M z)}Z$6s-&p?fS^eJW0iw;A1j{*6)09Iro?4Eh4uH27z0r~h~x$;uPNQ{vZTPZ4~y(z zv?~;(X-O2z!J-?4n&<;ItqbKB(529~JhHlG92v()@lSvGq~^;D?=RO6FH7MSK9cHL zhXki0|Kjx({y2;q8guSk&&Ok_%ilWaJevKzIcDMH3jf2RdXbY@;aj$;7aXYZz?!$^ zYb)pA3WzgvMVxTx-pU=Y7sG=i5qR@W!fS`{1|?G{C8qI^O00o+Oi`s<6LJYFu5HCh z_l?*~AqmDK;zOjHF@P-^7bhdqY5AwH9L-4O7IXy@#`pzt`~x}viS`ya#PD}_Pwy@; zUVLEr%Dd054Ck*57Z4obBO9Wb0_Az=DKup f@!aQ^9$wm{+;6(^Ud(;njZsa*z0kB>-`W2FR`x8! literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_686366.cpython-312.pyc b/src/temp/gen/__pycache__/matrix_vector_multip.py_gen_triton_code_686366.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e54e63b42a6a61b8d8446e96cb70bf3c20c62bdf GIT binary patch literal 4136 zcmbUjZEO?S@$GuOUfXM&O@d<}g@CRhJ^>*F0z_#@LIU9e$z6}akzQqOyqm<1Kf>-N zM{;YOsi?5wq&U4+*<6tlPLyH7MLjn?rFdEGNi9y1c zlM%}xMIey~8FI|6`M*GlREXV7!!yBBo<2+o)D3E>uwl?9I0UPRLKS5s&@Z9e%F~Bw zfw^HFvv5;*RPEDNRJw!O;50FTQ0+?vNz*`#vn3kJ^E*ubxJxBfR%NgjWA8SY zS;jzG++8AtHg#?MRr?5`6sIAkh!&CiWKm0nic}Wv9sPR;g7gWt_WO<$zk06W?B9N12#Xx2x_g zC#&9krfIkj!E&E6b;ms_0orWlN#ixI(NLpSm+lX_^EQ;iS9`J;5ni>TL|2*kRz3f} zbk)1XZnz)8dY?9LaoPnQulV>cRB1d`mRga=)op4;xl^cQsQQQ8A6dVz`v~1CC{yF&h`HBC>Qsc6$8>8<85N~x_z3|Q#a|1`K0?A7$;axX64o8# zQfx%xBmSVE)4aq-M?~E^ElM$2w+)A5ywZF`x8}(dOa<5!(=GBuM7PRHOcHepE7l!i zSd55K#ji~HcIZT>PIT$SX`Q&B6aBhfR-~XH`uWHP8r7Ksb*P91Bpl8Xextf|JSg}h zIvpDxmW%(WP6zzsiX`hawhrjlfS(TpbO-Mr3XkQNY_h^jiXZb%ND}@VyzuAq3uBQ4V_>_Ix{|)C?O$>2 zOPpDw(KdJ5qcvo?eOXM%xvJ89Y7jWtZ8N>8-t zo^X}vOWJ65dwrIxFWu$I$g6Dgw`@z6ZTX}0)%3&Z6}Br$)?JQ)nMSo0Eb1OkX->q2wr=>YJClNbkjbER20I` zYRfR|p+Jl^fS4+kcqJE!xn$d0kWtN@za&6%nM(LR&~Dk|7LZtGRiKP2VkyeAmdGO$ z2NcL&`4K8py;642S*B(IwM;((Ik4cFI0+osGh2p0nK>&?y@Z1-a#+ElSIeLwSH7CV zmb4EOg6#$qXOv<+Q5nN02&E;qGL5 zPiI|ZLRRXA#JVE~yX*L!aIsjl`F-2z^Kk}%3cF|a}k42$k&#KM6XRW!ZZv~*U0W%#>9lze4=uBKKOwvuM zR*k!LYMDB@TzhhbI+;9`voVRO$1FF~m+I4qMeg&OyEPBEue=Yvt86P|unKODnPoKY z6KCRFZqvkFF>@_-O(Sl9nCQ)wPM$hj6Wxy~S5mk>kYnu0rX-){(w#YmOP|z6Gfb|c zQtSG+jk3~f4#c_=y@u}tU|VtEa)JH7UI~%)x(FJ9`1yecKG1Vt*a`?4v)qzdrZ#F}beBdhwyZL5=GpU<$yKI)6Z;yAy!)Tj)IYb= zEWL(kD{aKEkLdTYhQZP?IF4z=p)fR_6VO)ajEPA&J?oZGP|@uH$gPTaZCo-^hfZJI zaFt=6?ksp^Bo-F#Aw!dy^CB@}A}qE`?}9j-gk>4PnuQ>Ur-=P0a{UX{LK{RfH+Np& zxkm0FPUUKL%&D{Lng!#B4u2WG7ycgO>lQ?^2FLRgm4qw#BBc>DtF0EPr9PP9vKlt2;)NxE!;fxH9JYE!n$Zebc_?+YqN* z*>}$So^$TG=ia}%Ty_NGi1O9A$A-`gf|!FffGrJxMI;~r$IwXei$}1IQ*nBP!AQhW z-HxeR{0fxFM4j7d1QY1$^e`hZx0%O=!iZI{h$vcDMW1~Q-CCU>wKcTYsTQHWX~~qJZxl0)A}pQx4Uk z+GZ`NQiItk4>`fA?Xh8L7T76g#R}$5RH&)8F|3sEC|3eYz!S=}mvU9us=l_pBwM4D zq6VL6)&7`74PyZP^ta$wty-(r=pNRoPN)I&D@s6Nsu?6yW&$Em>j;(54Ju8jvyhLJ zTSc$vSV(S&y;t;DlaumP_+O=-RFCQw=qHTnX=;yJr|PSw>QxM|wPdMb7fTIl{WeRD zYGv+fNFKHib&uM(JzlbwV0prt8J%@OIy{ap^ zXZyMYQ>keEtJG-Yt2DD3)uty_GfNoFAU`)Gk)rcs%_1l>lj4}E>!&m(&dU>;C8>;y z(k3ZqkYf0o5!sI=D?yzarNlJ*q?8L%!*P{)~w-Jl2BXoP=d;gzgfcp4fkodU&DhM9@1>GB1Hr-$j8fI zLgNhTXbB2Qq*x_V@%vsk{`-@TkCW1bJjsW|j&VLT!7Jm*#8G|>g6)u{ zP=_MMCp!S^i1QHTO|k;yCK#WJDUr$c$(bPVg1}LdiC`!RSvY;H`%EV=G+TT+I3Y?2 zF(!3@3uIH|AHz%(8p=;@bYMysImH z;icU<-7j4V{ptOD=TP1~G{-)7?^*ONcyop$l$QbFfg4YHG8MxwgFfVBU2wePP43C(CC>(ifgH?43)uF3k;R-(4JD z7|sPAUbuhZ?=463El1ZH+E<33F+E$5T_kzQ)C7~pN}@6)B|=8iCR_wZy?9AHgI%~| zT|_F1l9RJznAeYhC)!}ktXGggnY3SG75`ryx};OODnSQTyM{gl-|kvd6j-eE((p#G z=uMj{*<;H-DA2@;BUg%F(D@>m=sPjsYa8<2^bPtY@@ANA-X_;qk~==2x8B*KmlL66e3Dlp zqcPDZM4$=Ekz_*lg~9CLFvycW5Q~IBBAM`064?-q6~yUCNZjlNWYY&zsItz;?vV>) z>oiQi!Mi5+tuTg2sUf|c!d1W)+mFVCgvd0x8V1kb2uS-tK_kBcY9Wup4H}_HKu2- zX9vA&wYh7J=_=AsWlr51NC(m%5g|ab$1^|5x$mA^Wv;HaTwP4oF1^uSBTH7DGFf)};pH$_e zmlVp5QM9s61D87wJQlH8k149+iBzQ{308sA?T`pyY+JdSd{Nr)DfPNlp$nv)tf+5nb~5}d?s+P%zVS7wVFU%FC4ikfrl5DUg_Ns z+9r7+!WncZqO1f%@RA9hs6;ybjC2-6;Lz%=NiS&q1tuhbRldM#jW>lP%FtrK$PPlM zOx{uCfD{jhRbLQJ$TcAM!(aXxOz_s=kgaCknQ?xa{NfjXto_^N4^|sKTootQxmY?t z%ABx*%fstj7vPOeId?Ab>+$sUADl#woPoX^wbZ@N9WII&Dbwp*>o)O%6Ez+B*4z4> zjdijG#9G*5zi{f&N~F`75wyqP6*>hkER8cEi8OJIjz$#C7J}}oh&LxCqjTIS=Sh!E za~LTZOU8spNMFWmaPg!t6%)@&H^3a}xUvFMfyNlWK+b<6*S}CpfkhN|XYZ}O1*#su zz2R+MR2S3&P2l!pU&kKBz9;Y&jVPxsQE13?Xq0&k!49lV_75?DLGL znkx2GYC!5LraRj>QKdDLk;Xf!T8~tzsK3&wRK1Il@V=%h+f9GuAG>L*)v9WH?z10u zhzXO~*_F08_`ZA2x%b?2&pr3vbI;8`s?{X~q^jZ1Z#}Ieh+klUPi48n%YPCO!~=pP zNP(9a%zg!f0uB~Vhz7+1j1zh#6VgHHglteIAY1~Ej%Snqk|LMbW5~mNi4<>wmu!NU zZi1I>f|nP^kE%$;C*mhLF$PN%ME)OXB1>F^$HdE+|B}j0@TyJlYM0El32jtZ0gOqF zOIo}wN7VvV#g%j}HK~W60e+>$TZ|e><0q0q&H8MWeL@Usg@jAHK2F_TfzA8^52)8B zWaWL6LAT%M^s+bIF8t8WnboKW4U1k+P zxnjQz6jphrcc2SO_A6bTtm1TE@7Z^I>;qjqmkT3Ua6T!ZAcEh^B^9=)AlOwld4zl^ z|0M-aa`aHe7Bz+IH7mEMDO#^tMVAGR(E_@X*3L)>Vg@>@Ti+(H)H?@^W}dON*q!^k(RGJ1oq zx!K!ISAVB=d$t(+?$@{~uuag#W<&bD%trGT676{liM}8EXraxNh;C65p%%K4ZsvA` zCc1^QaJHmu9&?QNYwha62UH73p*H`AEAAGDZ77%_H!`agLcyUXgRn09H95-N(VRWONahI z*2w7#=j&32#k=JB7_&by37=6^E-MIU}*(nEnv3p9^E9d+=QZ$zL8QbY&oA5dIR(v|> zc05+>MR2)bRD|8+&~dt>DE667|EI!ip^ZyiW z{|>EA7tHb7?yMWFKfk}9dd)J5Sn2n)zi(Vic_~i6CPfkZER^CdVN>){*yVkJ?=3HI zX$*c`f{T`T!WkH6CH~+o7nQ@Wov;Zg4ODpXzmRMSRv7fMC6knYjB-vm+$1aZ`kf>z zby7~>m=dl0=Jx#E`agJ*2#+c-W^~C@3ErOef_Mc|4c7i za{AreMXsOKyn3T!r2*G%ugk|ueYadrGQgHl{-Be(>mc0|tRhQuP<~c-+J5?-F2}j9 z_j8+}Zj*GZ4E17mo1}RpCN5U0G{DEvC z;P!C^c}@{-Pzh)!2val73aJVBjl&F#J8n6>qpXCQaEt_4DV}HOOUZq6H1)KDayh*Y z|LADI6=VT)PdLYTP@1K(@Wp7h1WS;sT7rehl|_JOWf*Y>?(0%N#}j|KR#*_%9Xyv% z@RQ`a1RSy?Ck`g)3q->_$h)izcZs=zXt+M#rD__p;vo=_V+Jb2d{kqUe`?Z2W|2xB z!Xyn3aLi!DJdC*oSvgiV$BI)sZlUKp0jMRn9)hq01EYRwqIvSJ1Mm*8{-EFI z81cgzXx)48$iV|ZVb!_n<1WhQ@>1`ExY!Z{7r`jQ#D5$n)Y|(Oq8DN{X_Y11^-O1Y zpq^95oN1jkeCC-}e?Jrr#V)0_)!}!3t~6$J`q?va4Kg+(U30QL)rVToA>Fx%Bx5Lx zDVUuHQ27C5IFOuXPW7XN{Sn2p(y|9_b8WGs2?eTZP1;e_!F1`NR0S$M6_IBJ(;{OF(zPTTQVz85 zJkp(C7wjM^KZpzmQ*!3?06IK?3TzwZ)zD?P*Ozxaa4}+KkRH zdjsieqmosf;mq{S)=ghHfi7Q{(lfh)qf+O{Vrw5iUU4GhzPV}Qa{miu+%X-I( zei-S8mq%PHBNJ$3BCYpDq|bhfwyR~9*wo`o332@D{D--juC}8x`(n@2&ZW!!PkT`3 z04f`Zlw`hj&F7%svFO;WCn5RVnKUNG7Bn#54W|~znafv}-g}?9{66X)MD>HruxnW} zx}tF-jeFVSUGW5wCz#euMMTdG#s~Ft^|8i8*H<;kfn@zcZ|Z%tyZdPqb7gR8=sI)d zIy!p;HQiuF-OC2gieUm7CYF7ZE52#un@$_hMS52C=2-2VB`!^vk##T9 z?~O<^fT~COns|4j|Iv94^;}+ZUmlgunq&L^+!;5<#^#SPwxg)(DAFHIU0odbrgiZu z>bSsM97Y{@4ls++oaX#VWvcVx$tIT^tOE(}_p4&GS_eu#n}rVSrOWSh5yl{l0$?h~oO+(yxH5j9<8-n+J} z8(Pshkj}Aub9m*Z2i^3fb>k7qs(IUk59dCNQ;9=gwlB1&1PdopPSkW_QP0$!PMbR; zs*I`X!NlA|{Bok~%c_N{#_55eZn3eLlw;tO{To| zvHkJwL`}R0mG6$IS4(lJv?(!>8bi%psI&_#3_~PbV^Tf`M6fu2^;> z%kIRLFNYR}(w3uOQkkk94^PaW$cmS)I?Tw+-?$EnF0361(^rDzfuwAqI(2L@_|2uo z574nbWb4BV>(ZMS)|JK1Z)z5MzIHHuPSoKUc$jju!3po6#k09Edoc9^Ja7I3V5?4Q;sYO$`j7%9Kb5wzMzZpx}0~w%aFN! zlV0!xSjC9n=b0LF23_Dnxq?%aZ=^7yYKCsO*NeXfU-MqsJ<~n&1A-Y}Q45^}Jbjf%F_z3E=@WK`?`w_gQ(Ir13e!TuN^+y6)Sd2$; z&Kxh{Spr(Aaxo?~t)>;@mC%O<`ViA<=tDv)Aea7F!oLeaeL=Bz80Z^cWj3S{;j|F< zM3sRCiqJVJw9uO@&tE{i3e)0`6|^KQ1FSq>D{2mDg4mnq5-kd@zb<&>JU0J*lIxqg zPm7A_C|WiuDA0(gn5+dol;k<{r2=Y^JMO6P4iU&|TU6`@UC_D+Z~p%6(eu6Kzr035 zB3c7uE6r2*5?{l^ztr*s`0}(FHQ1TPC&4_jWm-vV@LLXN3kq5ZR!~A~irB)NWWpBt zCc2Q8Q9_#V8H4yX#-(C!sdWC8OH0cFJ2^Xr8KtUeEvcl{`PURr;X7iEBTc+ohu5Zn zfR-w3@05^K`LqD3bCB962S}5HG(G{T4e4kp)YEYkzMdW^x{y9n7BNN4qaxDqi9BSW z#UwE$2$j;M)OK3zsmh5_@Zq)oCfxvMB4iA1G}SP)7SeSEv0pDza!A4 zw3L?9deZcn^iKs~6WV^lX6~KQarb}gggY>S!xl%aGh%B^hn0ebLd>Z0;SX!WjLBx5 zmzur{y)+f13=>$Q@>yOXX!whqT^rY?WTSlt0D<+?rPrd*`;19$M2 zHGde*D02rD<%e#n^W7Zpu&z@aK9`GxwhhfMtRUNqlQ;|GjDP^s3moizfp^4_+%q{8 z2?B4Nc&m8X4%rVyTRpNzMX-nj*Z}#BG%AP+M@672VJ{atfyk9Kd)H;fBc%I|P0Wff zbPd>938zP@2HH|q$O2d(Yd9RH-h%=wZ}&PUhDqn~neEQWN$*{VBH_SNbGz3+;`9cN zLm^B53b(~;>DOW6GtS8EAC6>ElRz(+vT^{6MpI6nJei$FRpI1O_n>|heu4jiZRs~*V$DR<)&Jz` zqpL`~Bit1^GL<_a@7dvScvzPUaM5)(aC%qy0ZHH^51QPyPi)_5z@?+l;KD9d8zC8d=SUyq+pRwh+U z`(>tokhy$~xi-Y?xSkeYXOz}gxGXDo$1h>ImXy00aTn9vl@@oQf- zW51Amsa#P0)w_Rl?Wfn6dOOnE8M&R&+A|gPD;15XqVbD^UmjaHmab@9sW^oyPNge4 z!F`vP$9h*R&B)T6w(MQ897L9bX-gZD!(aI-CL(zwxZxS4ktuIjQZ_K+sl>UIBKZU6 zl#>~z7-xV91ewOEw0Me9HmsJY?>9%A(7TZi4VjBrtn+?i{17Msn+lybj6h;^QCk1orPts1w-cPtrq zMu#ENKrOY}7xq1!VXog?>bv<5BDl$idwy2BJ?cy}ri4q4M;~2U6oh;JU3nqgxvH;< zA7J$L;WIyLJ2~rl;G6TMO?4}#R%B{jHtk;|!@bYN`q`aGYz9d(az#e2^pL zKA)~RlQC3fjMX^TnP2Y5bq*6MLpZx7jlf!1Pk1lJ8`H~QL;iqp6Cv1I@l7Bk2nj=? zkeC#PBn2-fu!J@y&I1sEpCS4}{3PcChooDy5z-P;g8c-p6_x-}NcNpt$qUX|DJ`Sr z{3|-#V-%isPdwWP&f)lw;-iL?4MERvqru~phXe(^W4&xhV+5&qEa2qLX(6N}OG2t| zp`jA^|5aK2;PVF3mikeBC;b#}^_qUbn^uv^FH}!*x_N`1)SRCF;4K11t9iofesinP z{=XTm?pusjk4H-r;5He0ZDP)sq;3ImAozXW zT;L@yX#l(1uecsre?!md;uWXkhfs4Beu4i6M>kBYYINbAH5E~}zrg8O)t83PK9{TS zE2GNUuGpdYp+sBi(6ang*#5jkd%r2#G&>seB<%D4raNyuIi0A zE0E=Uye?6ns#})Z^OAVtjY+|y7QmL+^O9`9@{%MQpH?zX*Rp!Fh(Nzn9VCs*>O;Tp zP7}dp^?}zql@qGca4&ZY0Jk46h2Gxo?rxi&Gg@i_5Ue(POKJAPTlo|`Li`)jt#Je<3XYO6-7lD53mb<;Rt4 z!jM4FCjf5)o?M2aLu0O~|-=O~DaL2z^DY?4f1ef^>CjN)D&O zxY9ZGnwkSOgu%3?<%&8&V_4I3MFXKTu5nVV840QK-q6QGFUv6H<*1MlsDC9N5C}?t RIVmKx*0iSPV^ zKU@9PrraF+p?2I-m8bnfG;UQP-df$MZJzQv*U5%Py(*bL8x;eQusRGv10)m)w$fSM$f698r6Us zuzLY)J>D4Z0B)ECEx>N9xzV`Ee43}%4^xECm?kyhg;Rrk%5c1MGm4UZ z#k(q_N^NXZq`qU5T_M~_R>V7Ai*&u4uGaM$SL?=JyhV{+H{Pzui=3)A0jsodcd1v= z?#>#{>cj)AEz99eMG|GZr)qsw{ZXa;?0kAun)>*5J*}0kzOF_~eYjcSwu_t^MN%a# z?yBLw-sz}l@%5aMSJmRK*J3k&94!(_{NY+8JF1RKt3JjWgzZYCFvfRdL%e6hLt-QD zsraK~=rRYE*IQTTy^2QrHq~Gsp{x3-{u=uFmZ#9;{$OpSE&Dr+g3dBnQ`O5)D(h8t zLZOeRG-w`qwW9(F3+p5nuM@~i$(mVzY{qMnRbogsMuo_<;GgvcI9U^l_&HfU8wkr< z!5^OHWu3oNtHg+`pMq|2pkLO;#E3xpV&Z(1m$jTy8Tk-D%ZEjuIPV>hsUOPfqi4rt z_1MW#*?9DVZ*1t?>7mgvSziIvXD^(WO)*gjaJ=t*zc_hCHkVs5{(}%7mYH(%3J+Wv z%MC6tD;tG~=ojXF3Owt-?kf*4i+~T6TQOM|CL?3Aq15dYA~F*Qi@Xrx{a1O|?qnn! z6u>uL!!tIXkx!K$LHhv|HLn3L)9yaAbxOijgj5$LQNtog3^ zLsPQ8lvtgTsTo;215&;&kZ>z&qmdZU>nEWCq--QWX_^x9DkNw6SRh>LB|xQ@z@%4q zo%HpV>podW)avV(Re)1n?~^r=Ik8tM`(%UCr*K5iRAw@iCL>@H{e!)Gd;6eAHkZj~cp=P(gcHCW$wXuS04F?w{=ByT}6GLlyTl(}vy{AB13VLgL@S(m5z%@Oa z)3-fxx|fbG9?y*2X_8v|^Ui^!;gQ|B)V0`^>CVne&ArQgQuCg?{jH?F;AqZVk{q2W zdd=CEHA>Fzq@ggdKQ}P^N&nK`#l4xm*-KK(?iEMRgOMb?rZ=VOhx$gclC?4OmSpuL zPOP~-*~5#G#Mws{+pVe8RC*%&PTn$*IR40NPxqxG*=wujo+MSUv7a1SI=Of<mng>XIzVyvh3pJ?6~CWNe({(T{bK>WZ2Bb+s$`6SJ=K)c8|pF z`SSZ;jejwoXAdQhyaGwV;mYV24`jDwuYI;ta_mkXdmck9<3#F2dMq=zIFa3RXX+c% zThHNEO>eC|zektk8m+rO(vds&p5%HDx*OZFR;jUPxm#*HkUaS`*V|Uuo>jJAV*9`B z|7!mi`}6Fvlk4|;?A^5md-nk%cmsoJCfy{`yGv#NQII((H&6v%eNzBDRd0=y4eO>#^R&P`+WJ&~ zi8x)#2BjYBprrp=U$dU-73MYIx>PTnmyWr`a<+s}ux?&j$uHteScOrX5s3#@zKu*u zWGc9zsRC`(iY!4Bm6EJ>R7dls~wt1-`x1=UWEOnDcaaGai_StqII`XMMLtB*F&AW9@`J~;rzkJU>k zMWs|V2`kALpZQX143Bn185N~XsbPDn(MqOs2A-F6r9ciD78khc)ySGNL+206S|!U9 zw!v6gS3cLo*gT0uLO4#a^ns9n_I=KOaG~8FjfUnSWh05+?gOF7B;+Lr0VwrHfjZU# z#qSa*7Q)YSi-m&{lU!TqII4w1v5l%wCKtr)kjN3~oS)SJb`74p2tc5HV zl1?KeCE~m<6oB-CJUBr9>%T5sB9JxAUl)DR$TeO7-{k98QV(EV$yvRuFakuvDN>v! z#TinJlHx2Wsvb-DQ_?&Kg{&{@S5`xArlg*-W*RbJ0UolDA{pT_Krv$3PbXl94jn4V z$W$vCApnFI$%qBe>+ch&*oZp2e)5CQejsVKB!-fM>A}q20^9PC?aW+DjI1%P43#>V zIrorh&uRAMJp1xcGVKM%o(?`_JUPwb+?K<6C>c+IvE161+Lv)>n{$mjA2K_0nozFy z*!{7eU;3v@_fI{T%U$@h-1&*z#Jf3PAop%C7o5rM4&{M}+4(;i`)kJj$l_SFv`d!u z%m;Z(M@qeBu_o1z9qvzJe>J$;(kHd_EqC9izUaxf9Jzla-*EhaJ?}W3IQ=W;OybyM z4V&I3X&MT&p+GarpS?g^6~LgxC-yG@y(tKTECt$})}^2-{W%IGP&!HW7nB@3{1tK#aohEv70FF$IEuTx_Ur-s`s zxu(Du*KIOJg|(cPycjB@2nS@8{yUAL#qz4)d3E65=*pcsZor0My;n$xUc{Af`UzM$ zZdCS3$EljtiCmufrV(fO8%p7~>~P`?XN;R(K|)4&%Pq;L9JMx;kLf$fr~attU_IsKov2-vOVZbkMlwhHV2o88)8VYe*VEk+O)?W10e68F~=fLbSl z5{SEA_do&RagOuDdyJ0buLdT0g-~WhJ|_Am{W0Fx=T)yG@asaTGEVRT1XGf8vd)p> z0qOdH6f;nq_o_<~XbZt8>2l3{{SX6{FjRvZ033-=wd?Nj!t)QokyjuttOOb1J%Z4Q zkc&qIuUXa$d~7Zx#!8Wk1g%o|nj(-A`ks!CvJN6%bWW5_wU0C#%*M!n1w0kqv||rp>pQ6q6px3}y%K?7cU*LLW&S zer&YfdNcKAdMXpVbNF^-d1}S@y~NQMV9iaLj?7qgYu?CQ~on#r6Pdd#|)S{GX%w)ft1uds(8;Id9q)bMgT+j*z!UgrvZc!QH* zwtJcStOu~h!yBADgRO9~-2I?A=jT_(=a+ZIy^iKNf|NRa{5-vLU)JYDRT~9p=7!O3FM{%#ylSlgr@}(Bn5IbWpywh zD%Z-G$X|~N<$UkFV)O7m78YgA`;kbfbRdWs%1r4P_@*Nv?x(1fa+Xus*$6im;tvWp zfes86#TY1yYKo$MgG~R5%)dgd|Aw}#QTm(BH=2v8QxtWeNTYpk|IGBYsc2A|Mr5q6N=_bB`i)w(^43=B1 zsn((v;GzzJ>64b8v}nXSiv|K25!;k;ecF1vRkC#!8GYv zSTq-{q-8^PcabHK9a-!}2Z5Z(<}4Bpif*K1Zcf~oc+vo^Cs!#%nSV>4rzpp_Q5CX! LAgjBZ!VF;f@OVg(fbQ`YGN?Y8N;!>47+ zk2D*g9ck{|bMCq4o_p?@IS+oRRx2>L%vXOo@y`_)_B#ZmFLSE%=>pJ?DcsmpL0ZP_lTz-0SM7jTyJhYj#>Pa2z?jmwrTOOy zqgohM2|b-#P3hq_z-`PwW6VUE-jR%IGN)yJ2OHIjFt;|dPu&7;wLC-{vRg$wIXK5m zdOgk=UUJ$49bR^9VDNaqqo3b(YS=N{Gc?d+cMP65G2AmU>=@~>4-dY=%VFeI+bb}_ zD^J-+dZ1@}wWphxA0M>$o$j@b^rX4Gd@EM;66^T_vw6QM#9XQuv+J zxWDijl%fn}(K9HC3}rD(Qp6Ql*I&Y#DSUb_hS3tXh}B+^V3-G|aDOS>WvtX+E{wOs zcv}v4aaPiFj}@~;R{+SYQTBZFlqh!=TXqHX1gn;@IwT8Q%<9n-(fBI>t4d+j8Ehq6 z$nHW=&U(_cRM{tmv%Bu3G%ax3&F*H)r;!f!uqJv}&g^`<1MTF>>A_ivzX~{13B0Ov zX6Ikqm|ojo!@&6`FOaE(HQqs10arH`B3?fQ^3<}mY>mKoFT0m55oENo)nL`wpNl9g z`QET4f+qGNEW#(T!mNEy(8>`H*1C-rEX=<@NAi687P0#fPj){gzAH&fI;9i@eym|K z(3C3E&2%zWao`#25t;R*KgfD@Y(dh#Lpa^f?%t-uq>XLUA#2FkR{e9%C=qm&o!9V; zUE*woFr!{*_iU3Y&pA;)j5ZpdQ?7;|S*|A50;hz&G6Y+v#Qb$^BYQxw?q9SUL(j=Qsj-e1c6Sjl{8%yoj0M6?3$AoOaGSCMjM#<8@NJ)JZ!%<8EG( ze25q?FB^jqrlpye_!uwkwjSY$!P6r={!^YfJ~+%1!>8=L;`r%Q_&UsM9*<*rsn303 z#_i#yo(Z>;^6?7V%Q)!;2Q@j%%adac+RN*X+m4^?arF0`O>yFNK8Bv8+z#(N<2dhR zTob%L2b1EOi%KnE$e}_@qJ=am+~pJ}Z3a&9$s$r*(yL!~qkhE}G1fuw}=Db7qWJdB?_l?(}}du6f6KUg4Uj9Vsz1sqwU& zY1o_MDWYe2DLPnKN(OTrK+j9hJAG~k!%NX*$22dU25O(1$Du3GPV@3f>axS-^*MaJ zOhCDOJOMCI_$EEP7)Bd~wuu)(Pjnfc3-DZM;z=~qeGv{&hI(EKOoets3L~DJb&m6- z$9>5mz(P;J5kSZj7o9V{BcPdz|HbtmryDML>1p4b)8%fMaJr_Q%!Jod=Nxx?mPFIyY-EE>Dl_Z1qa|AU_y-716v0}YH99h?l=e|ygEfb?WR5M8F~=J zL}%)A!5IgzZx3NQ!;*i|A0ArOmIrzsX!Tj>)ML=e4Q*j~Xz|YiClgJFSDe2p3w4CA z-Z$2LIusa4C{=;(gx(V9OUQH~+uK*s%u9=xLRTWEI7{8Cwmv9+V76R8vwS8z8ntqz zEioro+P-Q&w9?9%j|LS9b76RpGw+X1tdw*0Cpq)UfE{r0k~+>@7bD}H16<<(XC4UH zH{_a-WwHD(ns0X8=!gtOWvj)FoV+n`ETPth_ANRBCw_bIrEBMx#uvv!p2z}M(72{) ziVr&jy@;T8@m%1VndQ$`E>WaWIG{^oGz+ zmrH`u4O7uCn!;4LeR(i4!I|n~`dI5H`dItDBF@wqlqZaZzo-n4glm_NM$U4^`dH11 z>5~g9mU|7H!4{M~(C7qVT;Ws8-e~Vi%bMosy5<Z@Y33fk0G!$b#8H!!HH~Q&b z?qFZs{s!0j24{2xyC3L{*DcGIus-|-SGYe)afQu^va0uoqTHmem)i!+dMW+0lFl9@M?J7#0 z@t;NHmUz}eX8eSiIjT?IHOApe%e)DlYGOoD&d-2(J0G8sgYxzo~JKA!qkF(UH-EGHg zI{^ub>Krg=4k7LIA zZ(oLXS6C6RYK@Km#?bZnyBarFbNXt)X>{-DL*4L^n1)Y^W`EQf6aUH?lii+*clLAj z{hX#h`@T8D^<2T;XbV?RpD3<;zcMlssl9PD`f9BElge1{Z3kC;1k|qW3KE;7Vu@TN zL&nfmSSCtv`uz#LDcqKK&c714yz3h^W<2@`zO@*SU#7d@NgqXkCg}i?KLh^=Q3ANe zBntYo&dX^BshspMZhFS;yy&LDBYEa#z%}r4m)A2jKkj7Qyp(n`^R&m6sd@nSVz`A` z*az@^zmBnD3Y*9Ma#rp$u-Fs~Zoke!Jm^>Wm4209?Xj|wfT$098{dHzIQun-k`+&- zYAd0n=ZW(iy+Xx}&{>|RZYb7aA$kjLD1rT6N`V24i| z$63)+$3;0A0KuRDOStmZdQver?O}<0T4fbuLd7wqAtK*ib5_U`t>=v9skD!YE~d3l zOc~RXW}i8;^YBhrs-CJ1!lPi7EFt6}+UyxnnBTyP-)aE6`{S%DX2r0A%x7jY6}Tx= zc4z4xo2kM=of7jK1*?Tqp(|l-IcuaOtUg;aoU)|nWuGkeI89nBP!Gw}w5Q-b*nTmk zX$-6&LKS98n$E$}Jp85{9%X6z?2}benY@ivu&Ob_Z)PgfqG9QkRoa^xS+%8`FO zD~IXD-&$7AO5xW5*`3rmFLJyB(hFw6F*7+0)GjaN6udN;_i`cPgt&&^u>@* z6qVkz-mr3FOW;)SazbAiwk?-H>ZCM>E%%ibadAalS&>kh!qWT7vbeY`t}H`i_+nRh zCR9Xu6F22IK5_xbN=4k2-I6t`@eh7arpZkF6GV zFKRZl=CFp-)}!tTbURzyA92>mMBau=7^us<~m^e3&yI zUNs+q1hJ?psQvBUhHIzS)uo)eH0)ecSHy+_Ctzl%_r9$7R&RpT1bv~t<^H&~oFgmN z$r?DP$c3meI(2(GZtdV|53iA(n;O%)ri#;4eJ~s?ie0!}zOrYvx@%Q)G|;=LGA#8k z_J^erd{wnO(i3fqE1Lt|n__*a3JweSCrEjMR0{uw1X&P#5Uv$gLx*WTUXOF)6^e zEkzRPRvV@=1YV+F1}k>Kow-mI2L)}9{vF!u@R{+$PwztbhxQ6j5zdF2Oo-w`QP^YALNZ9u z3o3XlQV0U&ui^In0rU~THZ;0G@0JRyY2n1X13kgEP+PctLvIZ9ZIbFG<)Sjw6K;#N zMcY@}*2rT4+onRhw106w#0gVT8L6xXQ@8!ti$<9>b?Xz0 zl9!s#r++Fm0KnF&O`TC&UB5)l49j5w6O#M$-$=BHK4P5qa(OX4ZqBn88YiloNFd^&N zvI^intYY^~_6EDPUjPrJN9zQzJ`FYqU}G9=62RsZ`1=+FJ|ZxYDz%86z#5wp*1AMR zRbp3Z!c>?jDE`Jch~tKYO!Y4N7Q00t0BZe9C5t6n5(GdI7G4)azV8hJD)!f^HP$q^hkegl1N#j2+AEv2CA{{hntGxGod literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_338032.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_338032.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53058ca813340d369fc681982eb832764f887e92 GIT binary patch literal 11811 zcmdryYfxKPdiU!6ej_Yi77v337!W1~KVytBegQGyN3hqj%oSK5ByuHeOs*PbC#@xK zW)PirEHXP|c_%ZGR@3=??7E?x&NSqCm% z2QJGGA68NF_e6IxY78pmMD`zLrW7t>)XeLceN)PH;Hq`tYL~>d4sKXr0mhWZCCyLE zVYPr&aW$PwP3hrhfS)lx#juGoy(b>jtj<=!d&HntK)AH42KR(#Uy z_IX)Z??eFJP7fK2{ExQ|*0# zVU668`brPHuZgA&lPl+p3hpV6|EQRU!rv}7#I zCX8VuW9JBhCa)4aBu5(wuXFZIk<6w$nYqFG1j`tcU&53zVy?B>Tg;T@aL6})@zdi^ z)8)|Qmulpt2CWB zg|v{VQ4+x#W-C+8MFJaB!+9}|E42XST8u4?<@~q~;@;X7l#MTIJfG3D=0I&1VFL% z8gjpT{ ztb+Chob*)(DafW1VkiII%i^fOb?aaK%^JBIvF9b={RSh|a3s>v}CpjpY3 zoqJ1%eExLB@Ah&PAwXCmv{=DZ6a4I~zz1)iot1&f($u(?p$EqPE6xd44i(xp;c^C8 zAz)c4wl7HXH^VTjKl=IpWaAYdJ?5Wq4!IgHIfuraflEH`Hs^@T8))>?LyZB~_(UVL z8Z$-+fQR{qee`(4#8n5t9WcuQpVu+ugSpRU0YcE zK(3szL@ZHzbYxjx7CP`iqnmL@+%YoN8h6g^S=MX~b^VjloYEVjeIJmqmY8(*Z2Wap z(2VrW^QRXEzO*l#Mz3`y51&D=ok99DFeH6ptT9=?2Nmr>`aRG#6vV_IIOC?+$ebpz z8EtAqhPH*)#g<>`7hBQpW5{qUEPbI)k3hGP$Vf_MOc@F@$~44xB18Rr;ezRlw-ySK zy?wD4+53@x062al6=@aU$_S-8Qt^J@%wS~jXALWmvQd4bAYj&C6)Vu{bmVkOWk?xJ ze|#b)ik_aWi%XE9I?+1c@`ZlB71ixt@S(bHr0+> zq}1AIW#mju5bK|lA+hpyDuQ)g|IAT zEROk-O$Sif0c1Q7mOYRuW@HgrN(Bo>^8zOL7pX|DfU$z({TQ-;#bhfg!dMV9#GG^H zxP4BQAW@+W8ElDmWY`v#K7hq{WcEnR882Km*%Bv^sUa-S@40;oqA#2a($CH&JNnUU z;4l5yabMiW^ZNA8#erYh7f++z-N?|ro=>mo^cWL=>J@Xw(Vc6SoHtZ*1LVlzLS4_1|2xLtT`1Hf|Qux1AB~$ zwazJ%)vah#D>Ae$RHTF1=XJ@$gJ{g0n-^tEAflY^C$tk)G-tNo;eBPr~2Z9=bGmnj3B7Q{25U`4JZkh%a-E1^K}y=zHVm7 zeErS=@1QEEX4I4n-?)PsM#Ct^iZXi0bMzob^JUEh`Ji|lv;}OFk0LK0ZsdZa1XxrR zG2{FxX9q=AegK~dxrY%3@Ii%p@@CdD_M%TP!hBE%}44l2(0?>sEMb^zM0kBw-_8u4r2d28)tebDrDWiX1)q%_EzDB$HHLV)51*s-Rsc1DXkPF+)BN+o;WsDZ?uOwu!MT{_F zjue3H!G|dg{_X*mR|p8k2=+5l68^>uqmVG#(Q=+o_Dz|1N$cPXHo%)#@w{Z>8Og*3 zlJn!UlJn!Ul660Rtuk^(!pIo?oy?ak?rSG_(w)p&2H(;46P1LW@RL^rB=IK9LmD~H571B)|PDttpS(U@1w0}x%*KQ2PJ0J{;7aR42{Sq=WdcW)evp6r3*PlX$OH;f1&R*-nI_!)|j z5y7yqpRfsfxcfRJNF?ctCo3AF+?Q=4R@B`&(9Vjv$)NEfqRXK0<9V}&LLvGfR9M*_ zk8}JY<=i)2=A4-DT!qvH&So_1@%Vo0+WAIGE+ zk=jxl|4H~VchBx$)|Mw_dl$q|aASJ>`yYPz;qkxW;@@!rzaQxHKxfmlA}>zJi2awG z6OdD)0$1Usa6=jcaw1+hABiE^1i2o`KzaaTI7Hzb6ufq{z){QR4Y)@peUp&-01&Hy z>_XtG!{Z(U${`=56kt9lDJLtTT$kOLo!d*frnvk8?Z6ro(p~V%9_Q%8xOoH@M{#is z7sqkY4FxNA`yFZLXQg~LNtL@+XdEZ$23%M%fduj?ut?MCA}%!~8i70(J%oFl#YNV) z@Ftfpa|F%$aYcvVtY9Xj$ABsa9Wnah#OG5z(o zld_>?$5IdXhviitS#DX7s4R3O z{7y=5joXpFDs(iZEQl2@DXmFSd$OW^8A_!!rPR#qjO>gFZ_4N7OUm-N^P`blBlE^z zT0XNZRqslQb|sa&Qc6=yvZO3aiuNY0dzYb9mOYZ{qYWQ!`KbO@{XJ=8N?UqgTZy!l z32}UCN!yr|H72!L=qKV&mA93De)5-R;XYW?j3w=^S5^RlHSqIsk> zME5OgE8?e+wk~mXN!yl`wI#J}zuDUO{>l65Qlu`8Iq#{h^Cv=wQXsx-Nm_iX>#Kv^ zVRKX%9h^Oz1bfKK@5^hDye2^=DihPU-%i%u&Z9cW#O(dTuf;;#SScK$|8c4u{3U6GFC-85|t}rLah2) zL6|m2oQc{6!BXw6TPGIDP}e^zyF&+`h)D6iZzP05J5v-Xx-Tz6@}ijiX4_ob&3$wG z{#IW5Kvoc?kgOP|GF0l6O3VE>r&L881b|R4fO>Q&(*2cvSEMJlBiz>NCNCKG4 zu~jUqO4qA@RsR~$*kc#t<@Z!o3wDq)b2xH1+V`{L>s5^pROaX(d^q_LbBkG4wQ%wa zxzhN}_9D6BTOpySTWKRyhEVsnJG1rg{dl*zYdO$ztWfytLh-S3DPB~o58gF!jC=AO zls5=B5rjsBi}*AzZhFpNQ_@O$VN!WCI!jBc{ao zIK@$64KS7TrKn_qj8xD?!blkze<;J3ztKFO7Ga~4;7yT|a+fXzH%?iQ%#n`mvgR2f zDETi)PT!0bf=WseRK0+O6!51FY5U-v5uOe0qk1X(0F%MftDGxXR*+_97%01*8KF8XoNhYS2yFYpm; z64!AwVU+=wKj0X0`dtpYO-P@_h%C(s++CFZ%0@x`@m-API$`buK=Jg+A(vk>2;$BV zIWo2b3?D9X`?c7_v;#$wLgN*s4h|mCx+2P`I!!wQ4xsT(uhRmq2tl&Sj|I~I@I|{Z@y77Te4Kv zv|zrc-yb^qP;bITjxNXJ)rp#g>U*;GtR|!J+IjL;BS01HSxweJc}?bP?-nJUu6yd? zJPQ4n+F;&vPrc*6TWBJ1Pu=p2P#K{zhI+WOybbak;0!%IhYlUG={ZNG&j5hcrteP; z9{6HB3E%rz8DFFEAO6A>jG^a7|6n=|lNAM!$v8qKiJC zC++>%1zBaVF_g z{*kCykq`pewW4>6Rs{P6O9^902O&L*86swF)-el&OWKkV5)vw4ogo4mp zVg)x#=1P#RdPT{>R2XKQRj;VIriL(>SF~JJM`#Qydai09bfy(fixm?gQC>Uu?ztxg e81n>XnN$8g)R@_khHN60Syd_{RV_39QD!PtK0u3g%c@K_RhxgF&9qvz{n77SKjWAo zlg;*Z!RMayopZj|`Of*i@7%vuspJGa|8nPN!+*Y)AbyD(@h3Z0dG?bF+%sVz%fu>_bl>Tg0kkoD{M07$+oKfXB1>tOi4H-9xejyC()! zB%lPF%jVz~TR7TJ5DYn;fTVCHiSwU)(QbR-Qs5lPm(ZoX3zEFNJl&rr~9!Q8{qmZLnBc{8;h`+c`WV*yht7ccMMQpT3DE zUrAWgzFjyKz6MUEuZ7bJHk%Q$6$-*r@xprW_WXjzvJy61xr>w}Tgk77F>Rf9TZ7jl z+5VT5CuRXWr%BmNyUJ$zAC&C?yPw_1?~?;;1-~Bn3_1uo|wgu(Q(tg(vXNJf^JbZ{Q_%I&stKChX2ronk?mt+|@Fmh_#R zjxnk2*CRktS zJdsln#*4-*u2D|xbPrpZ7#vSwG%&y<;{9#P_X4*E!$)YebidvGA6kTbm8XYv6j_Aa{^|Jli-yD z{Hho{iieXjqvm0YeUPsVPS|1n2AnQ)tXBwiPUx~ZIN4N994MHofo5&8S<4BgYv5PQ zNdOyT9pl7P{EM234XxpLAad`^MDh~qQ3sC#f;rz0o9p%^g)M0%|uo~%WF*;+YtN|Gse6npy z<~!$C&Q3vBZY`o}y-i`cYF6g24x9^BzKIIoMDjPi&D&C?Z}RSs166C%{b5SIEmP0t z`b$2ldvJ9AXs~lpy1u6d$!ff3!Ya*d6;c_!t=~vXw(<%dNav+N^-?}E9$U{l?pJN+ zRrpj-)LAn&zb!xp8iSVk6YJ{z-nOkIY+YUEZTm+>&UQ}ThuXl|K2PWq)EFMOMjmt;Tp0lgD zq0XMQ-fN-GYpDHtsDCgtY(xDv)NVrsBjG*8--!hp<3AhsE(kuU`}Ektv7f3VVnU|; zPC~$n{C04EOA*TKus&zD+}FHq$oYetf%4hOxeKrehT5gJ<+ERwF1MkEv!S+L)X<9z zy*}9%SWd6MceXQ-KljdJ&dj?*GbzPV^(wJ0a`v zhRU|7vLls!W6ZfZHi^b2*Hu$KK{%aU?{fE-wadNejh0n;sJrJY6?&s5blrv;Y{+1H zjfA(sHp=|h0)jx-JOwrath@kuLdj<+KQ*hLot`TR?p-|c^U|g6rOJn=mWNj8ubWne z(8-Qa=i8xo-bJ17qLc5UgXX`b-rIVmSrG}L%n3->6eW1?T%WrR+X?ney{~o4kaNFn zt}W0M6ntD8q=StMibW?fzJ&_kT9K@{zwTU_{6Z7Das!>dfeLSgjTIl?Tr3YxE@+mH zESG;tE?0hbYNZ#|Uqr@Bp`L5s3CZj}(%14tsh&CKKR4SIsQb_oGz5m`kB3ah(4J!` z``Gf;=rVq84qdjOw=5{z5;o)qB|9VjHzBFIOh#mc(iqVZx%nBkO5bZ)VucGYUSKBbCeGnB4z3#2D8`cO#fUQ~rhS{OHZ7(mv2Fgg*`@$BcQJVDXO<`=aTFKFpn$-Ek5OWd&;{;wBIdB6K4zJUdvAr! z;kGh%tL2uJ=A<^qxE-7vCmV1&Mka z)7^(Wc?3GrpMDtgLDmPl540W`OJ*Da;hsnCQLqYDHd+7_A zqOyWa&;h}e0LRiy7I1_^$+QRexx|UzBOK8{~;TBN;Ml z9yiawJE1XFm@)?kU1?7|6KASf;cr3_W~zN(n3~|ywS{*tFXlLW8%3E+^^^plnwD( z^35vZBRwj=$V>UMd|IDwP)Lj3r#xzh)F+Cy#kkE%7y&JDkhIjJaVPHhSv59%Rtfh( zpbNj~6;@cl3IQwO*Ma16Ryl$Zrag4xyTA_pR9uoB`pI4TWr==b7y6m|lk_w9C+TPI zPtsTYa1HuJG%-Q45|&~$kL8bIYeT}SbP*;+2ZN1{PbjMf65n3{gsl?W$Fmbq{Go8S zuQPtwD>R80C%!lRq+c*71R8=4K5v4gR1%HqIne-ZyJZq_qDw8^O`JHIL`kCQ7-Iw+ zE_`Abyc4@C2YCHF#qvDm6JN~cQ|!A;lJ8H9O)91jhH+xrJv|N= z25@ZvnYCEOy&ZIGNo!s`HEDe1O4TuufjunxCvXu>%G9u^Wp_q5q=8-^Xhp5Ne8LoPd`^eca&=5yGYDu30wyujNG`D~*8&Hn1& z@7YupAyrYpvY|3Ab$VOFK&ov`x@VycGFzW(_Waz15U@lUH>nClRRnJ?>K8{Ij)qJ} zQRT4>>i842VN+d>)aAjxb#;xmjTa>U7uqhrkLFtYw|bd%3!dPnWbm0o=P) zvvRje4=#56S<7Es{QTm2X?s{#7(5%&?GGDD!n!@-tlWRC&J#-_4TMtXy%ar!z(x`e z-fMh>_-iQd6E?yF=P8Z^J)}qA5qd&E_2|^ zEG@=gUHMs<19&{rSDHm7V;S%_Qa14|2jW;+rq~aBhU5g$qow^o86sLNC*>hisAIjP zJwpU7`)!SJF@BSuf}ZVj3l_c&vOBi;KTUPbywd z|K}|Nq*XtP>wage(EPuVHtQv%)naK`0%9dy2e^M6kv;W7!Mu%Yr*UC~0^+6#+5-3f z654vpHelrkatgQA{U<2K zOKhvNyloLBQBj}bjJLIV@A)T`YDVE#%(euM1dl8>E+5&T&Ul-i$Tc$u{Rd|U10#z~ z^UkHg4f$K%<{hA_{6J}-D_FL!s_?dM7nMEepYQj!`7R+*0g&8RjYOs1i?bDh%59~} z+ww%4bH8w|aILszIcGzA+IxOmYrsVcFBgLqi%rnRs$D=AZ*)yMibwePExrM=GQ3{X~{-hV`=% zQ?z}6&>66~MGi*WhvNN*qwT5~zF$@I?Po$lpp4TR6cRPHVbkHTu{@kt1UIMQoIT&_ zTS!tDmMZVDciD&#H{i3yOrgIpBE}84{)MfBn^L|>5!nTSvcSXxcAgCt97Kjg5gFeh zC$#y2><5MOg(#~cqTpdl4Aal4A}YSACUiLw4PVV7)Vhe4uj+^_LxkrdVjv`nd;NF& gpJn5oXZ-?#RQ-zTCQ1Fb;{rlsTvwOgRq#6hFa7|AhyVZp literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_344391.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_344391.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e42e76c5745d1b9d0285962567585127758c5fa GIT binary patch literal 8238 zcmdrxTWnKFmUaES_VxRfciQ2R;53kD!$LDOuLKAXmI*LJHjLYydlT$9w)^@9NL1FE zQA^mpE#XK#iqX!f-SaWSEG=zX>Q&RyO2d9FtNHQ8>5ca;qs?geiL@W?VOJWB_G3?7 z`#J^;VS1UBb}w<&sZ*z_PMve=oKwZG%w{8k&-&rtT;ZD#`YmZFUqc0W5~C1WLM&pb zFd8q{)Ho%#RkP}G4MlLOuy$5AuA9}5>nX%hA)7MIM$M`@O~|p)Rv=bWA3vpGwI68i zR>;PU22}fxb+S4Rg*?hE8#Swc0iJ#V-oTkyBlm*YKrD!5UfOGdUXx5`=FF@GepdL| z>aRBCVeKDi$1PQPIzB+-Ru$r`)t%brDW7YdY-rS{lIXKBAsCGW!jkqv1OUl&6p8O1;B_x0j%U?*PKpQnn-L3Q|oSRUp zs+6uxDAiR;H+Ev(%mhL_HIJ|ho2#(;#=e?aG#eZ#5LRR70yL^~d0vg4RXsJT0XJaJ z0@!+@G139tFbi6MJy>(2ag+HpPp==LFq3G+ZroVig9p3GT8SoXWYr5G-9&TEoH}JE z8#T7$=DQUwBw7T6(!)}?8MoXeT7&Nx%BU{S!)&d%wPrr~yp}1H62`SC^U0jF;Wpen zLpW*2?!+7T4QwR)#cgu07gN|PLmf5q){}h)ZvRZ9%p5W+>uuCR6%1LbU=^t9Y|Aqo zY{e?PRpwyZCLC;gO&n~;HeQXlGbphgTls&-O;|5W#D!^66JA&~$fp7)Iya*z+gGAX z8C7m$qb&6uo9qhVPO>82@mi$o+jO-cYY=wGk;0hRjSY#O4G)QpxJU6v z$IxXKEU&k&&UIFu%@b)pD61oJE71= zQyMgneA-c-goSkyi`NO{qa@92AU@+WNh%>M8Do5Onh(tSgRG;n50(#_1O#OB~x7BgDmHNFCa`_k<67=oO?gaMI@%uyutxj#!7zYNa?uKOzk%mU!Ul~!ERMaamwWGHw0`KZJMBLc^VxxiHpwmTV(gm`ez zh$j(MGQ@&xg%g^9^150A&B}B50vrQ7hxs3GH+0#qf=9HP5?UV25bH+f$)^1 zFDF)~Bx**|&VZD!^Ca9#+E_FW^!iDt04W&>P@blYyb8&gJ|2vedkLWQ5}5Sru9Lpr zO5G>vh+6&qk_vFD>wS_YIw$nXWuIh_`(%#j8D%CzSzbhp6zH9>p@xa$1Hhp7kN^H~ zq32qZpNYo;lU&c0z~oFoxDt)*3QU7l^u+ne9)X*U^+2nqA|wHfD?SzFXLrZu{ebs_ zF$mF!e=-UtF)%ndu&)n#By)v)hT|h#m_GsBkxVrHcW}a!=s#@8?p!i2nzMm|&6_;= z$ZEfpNGG!A3f4E0CmxwBx6Y)`WIGC`*5uHd-j+H2Q130$mZIL8+4E4}1mK$9p4YcM za(R}HFCNd1+-VY9`wOmtl;M%XwbZrPmF>>Wi_N{uePZ)q!SPl~UvxHSFNx01G`;3( z%Na#icgj#4*q@8aI<-rOazW%r7+=fOydUelX0^h13kS;^X%eM_`@lPA_Z z-rV8EX!7hMi|y7_dMY!Kd#7L-NFIM=c4Yc8(cHCFb5DvY+Uy@6Svt9RG8@P>6kI;h z=1XadR#)aPMQdkDvu1Nc)6tOa79HDibE0D>Fz0T}?p?f`GFG|2Gq~dDU3CnIj)5<=#dQo%fFXilumviHyyC*gL2z1%7*pRhnFWzpx)45{rTeT00_Q5Z{ z`{nrO;|2Sn)R9*pDLUO*{o;Y#mfW>ZcZ$y4sbkM$h-I8epU8}5_bg822JcLLZF=iD z+^XrVwdeQf7Tu$D_eVPN=iU|F??QKDTh1yr_AGacjR#UEpXPeoioIvm-Y?qwzv%yR z|L6M)_JhxE%(r46T($2N?R)=T^-JSd#)AE5>d5okcpmfYCfKh{1JB}utEN9I4N=I_ z`^^ao(RSH3o02Drwq`g%V1oOz^Z#u2fdN`P**4MAd8cvtl<0d$w7he_L$n-CsfpcA zr>8T4?AFEUf~6fWwDA_5rZdh=e6cBe^pla>Bf0l~e(lq1t6lrVu6>2p{a@2>f1^j1 z?k8}1kwk&F0jDRa!JmP92d7$baK_2Gst+g57Ggo1?|>nECmE*?Mq*)bF_MA!@44xK z0I`VYggHJkS$mTqw|J5RJV`>>yJ@*;y=hxQ5K#g3>^A|q*oT%6Z68__My$I|#Z?r- z#)J;*SZV=c1FM?Z0S-q$0vWw>qy8ZUnS*=-mGRX#1;A7F)=1m1ZmK*_3(TXfPxaS` z)unAv>ah+=`X~CD^;EAguL0Mkdg;7$+#^)7C4@qC^HL?hNH7rkR7xbHz^-{ZX8RW7p7Msf~_$5rgX?|3(87hG^&yzwi680Vc0 z#%D>=wBNg+^E`sYaDd=L|eA>B@l|GAuTD)u)j(XNJxnmb7w~$8UqN zlCE;DiLrT;$s~WAVCe(l!0dZ$;NU`gAQlVHL&`=Hz1;`G(MiZl4gyf_j{$YO1&ZG% zQP!Q)eWv?c<`>NJ@V#UAo4z_**l}8HJj0)aex8P6ofx-|lCwOY0rFQOK}ic)EF_&q zNJ@lxe>e!~1$l6Q{5No&zeFHwgu5>IW6^6I55CFYFQ*>Bx}3B6?EDB2@ux^}niOY9 zF-nTFq^No<{?AGC92AniqF+f3xtW}LN}6fNfO&YxLb7D|%K*iRWj~#O9XfQVEF)8` zWcVNuUL+$HK(BvDqEaL3?E3llKK-7k*^(Se?aA!P?k(C|9@;yz*ODV^j5|xE4`$Cj zWZLtZeR=P`0+dX9k#S@~4;gP>b2z`{Z~;ojTVyP^_NDh_J-Oz5%;Z z>>n@v!=?MD9?a!0{6+r!M1JB2d4DkfgHS#+liwXK01>nEe=_zrjN_5TxoT+_E$!L& z3zm+wdd*@@sUJH%AIE>ZXSJnIZ0TF>zE6GLQ)oGI|45!H!P$x93DxrP(QImb08f#cJ+-Au& z1-^uClQAl+WwqqRP##4%Afxo(Y7{M0Rt3+i1OG-Vcj|-z8$$J7At8DZSHkKiVC95S z-Y1=)YE~z5dFGo&tl_7W%x}fvBpBA1Fuj6=jPRCQmQN*WZ7Ltrw~|l&QP0T-qLPVa zJ~Q2|$mS(-GRt!MKW`DRTgzQV_Aj;y>;IeG+FoI|cCuTHAS^mYeHuCLuO|VuP6Wjh zcfIa~0>a}Q8-Vv19m`z}PI5A##0Xqm@J|NfoWIYfUPs{9`LHsM_X7k|l5?`olHvjB zdY=?CP@MOv%MoY`!N}=y&3ye31LZJOgBt)GiBGlb4*KBvhv3L75F3$$4F4`c=mf~c zqrA^7>3J?b7Z&2>$VGxyIebkKNDh5ZM@LBq5id3;NT%9Hnhj>-ugW zM(eFN({E;`vY|VNZ%3D>R*c_C9(@7U+?4Iej^(x%%$>>MwYIIFOx&JG9#5SSHO;W( z_8w8wkvyI0%yzAr%*mn0cF$7lV(Y{9-g}-E`ymLpb{8pXcsZTxywi2BbA>*=iB=$<+*H7r=VQUhG*5ygPmZRf4h04I zS{WC(>oLBP@12)z9^S_yf~0vb8V#2Z1TjO2DIWv>bTrKV0+mzFN-8@WW#_`&LH;Jt zfuW)}17%4~QPgjd>0gog*QoX1(3Ukyf3x{Ub4hiIq7Ia3wD0X-n!YlX406+mY>rzv zjep0G<`XqinJR>=3N1p@NNh`Zu;;2q zQ#D%y+0hc@g+={U6{J3s!D*ZxSQQk9(BQcD6P zi(&0lB@McL&bjxVbMCq4o^$W*e`FXV0nhJ$^3CX{`v~F>7*U?G4Dfn^B#3!}BRJAe z45r`YASvUj3H6|cBzSU6=A(2?zd04J8Kc+JFvh8i&^1Nq+9Gsa5xTxGeVFE`FEo!c z3JvNFMDC9(;|x48W+|Ypu?XE%giiAYjww>d7NMJYBWHo16@IqDZH673{R{1&DLW62 zFNi@}MeuZXM9e+XRsIqWp;yvf=z9g~T{q7qbSGzkt?&)-2x*-|PUg^Df{d~u@ zPRZ2Wd#&?oPiM!C&U2FfTj zQy-}=(Nd+fM3wioL_S|lj-LW?B&vLzk#`@6k(v_R)miRUTjE}|j=K}pAjde!i|jy7 zWS60xs7z?fld4m6%e|2XRb%ZG*^yl(+G{touSL6%Qy$;$fDu(AJJwyHAbX@1w5gS~ z`2bbKRKf{P4YkmQBjsLuvbajd_W}={%%cW;2mXGh8dg<-hDeX8xDDEqrgKE5hEi(s zbT2$34o-a^Y@}Fjg9L-ZC$et~TI_xATGZu@r_kIHXrFA4of|aSv3^|l!M{1p6YXQX zCg6$Gm*iWwEx!BF?hTrlutg)TO>{=4A3!cnou|u|<_z~-KR!ff=j|Xm@J=|gS&j|c z)=~E7qM9Dswx)&n=JY+BXU{@DL!yI)=9bgt`A1_39xBvOkVn3cHI?8&qC+KlG?(C^ zLWfK8I1(cRe}OXVKGL#Xnk<#aRp@9z9{Fps z9koRoQG28b-p%kng35(TRF4{Qv_tz*Q#R@yL-i&C9m&v}v-IPE@=fhwFCBsQg<7Ku z;rFoO%rc*i`P`y#{cK}Cw<ZkC-cTb$*FiAp{LHcNC6*( zU0O*M_DjY|AvhvS+9uDzXG9@>~ zl?*nuTi`u@n1JaF>YMP4$SCe1SIT|em@et2(-e5ixB<8E%C?fhjA;}qW-ZYe!xcFU z_mKN%>tu|=bh#CkZ71jR7_6qx$VmfGEvbQnq{a}14V&hLp`5fVU6!rX zg-(HmYyX7jKd(1_5){TmlO8YMIO_3^d%~l^z#-2F9|$*w1aD)QpO|a}tWgMtJ;FVA z7~DNH926!FPu_C_9SFW~5d0trZrsw|(%jYpH4>9)KF$jP-Y;APHL*v8oZu2s;>`)d zYM*E3n7C)fQXB1lVYWVq%tqpUE9PB^deQ89^6|IbkGdBAEMjZ>t_vEjBNXthi2#^yz5l8dDb6iOkMuSCX;CuFw8B+VyYxvhP7s z#aupo_oeN#xH7|0D?T@Z34L`m156w?+ z{%uffxV(IQVAXmHESWLxR@#Df%65F!m(av-ijI8?wCHG^F(e(%jieL>V@#SYzdAC1 zeC~L>FQHr5{k3Vu-XNMA7P#*^7Dhz2ZOJLJo!F>dBsNT0%J{B3{ZuqJE$;mdxmYi< zCzggpwi}DPxCL=f*HEnmIyeBjpd)m!4} zL9tSFIyyQt}}yQ85}l)1$*VTI)ax#Ocx)-9g0ikbl)HdK+ucPjGaR4skHkN7e;CB(g2onOfI5)26_8N@UMqhx#b9;&9}9 z&obW7gu(QHnx$g)7!s{}5(Bx_`E>Z%wdeiUp56TWpxAo-x73XnlyOT||h?lj+cxXPHfl%`LL0@DzNQzu11Th_@PMdim;rcz$4c@XwzA z#Iro;5xee)t#^J)dEa88>E-@g;`v+4x82V@cb0G85xcx%tM_*l_Zl_{ghWmYXW=cJ z!w3U=KG+JnVb=sZLCKijA_65OZxdy}ji1t$HlpOFmSU|?)rNv zoxq2ugn$=+3Q?X40OZdizkoBqJY1*?DntYvd52PbHrVi>+@}EpBE||F!-s$tF-A-g zI>I0gM@*5^qyV4UkQT!cmLnr(P8G2rV;FCV`Er-<+YuKdl zMOOJ!0f9)QLI(J_p#`1cBhbc&%YAGqT(40+h!p9sEm5z_*0X|&(*#KP%t4mWq;P#s z6IF@+h3j+u4JG;)t}oR;XNki7bC#fo9f(0JqL2=mk%iO$TK5YQ{KU|YkEky{f=PN= z3x#}G74Roo0Z;5p?6WT2+fQ9Cy`+VsC>$w`uz?TXbNhYcJRDC)J(I8)(tCDE9|Fk`vGF*m2$?6nKvW(fIKy8-^ny`lvO~ng86)%-EM&~u zCzCZw+K1kjeG2#5WFKTgG-rB3KhD_ZEp}pLJN=*c2pBE8VYk;4;@!QA;B%v(-I|J0B0-MOyU>k zgPYdO-4(zavZ17dd^ZUxj}F6am|xLZPLBcL8NyqubqOt<%dE1Q)+Mx0q2+($F8?Jb zv^AWGQZ!N5B5FQ>+|?Fqi?^>??a_h&*qi2gynX35iczqej& zJ93(ApemXy9(Y#0?BQ3L;Q|W1R~sxkR++Z{ZJ`U{Ri<^5P>P`K(Ox+l!*&j@lHT5~ zt}YiVJF;*a2$DIyRXgm5j|@|AD=Se7B(S(%J?0C`=bTWOpPm#lo4ISU$4hEB6-k;q z!Jt3A9I($xrt~6kj|BbPBSMy*c2erpM39^E^QVQMK|B0yB-8?BN==gF9|-zi2`VVohdEApe?KujOj5(5sr$K0YgRtoSdGcQznej1Z$UdU@*oK zw(^u2Ll%OyrK}jT5f(=ZOG-Hio$1S=&xc-@0rvVM6+tq8q^^;q{nez3FxRfIyFWL{ GM*Sc0eQT=# literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_385268.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_385268.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da0141c6e8a9845513af8877bba08da2393c2046 GIT binary patch literal 11620 zcmdryZA@EPcJKN7`wN?If(Zn0KKKp^$&d*lgh}`s0we<=@fbft48|tUhCuqP@w8LL z&aR57wvE{x31?Ml$R^UTQl(|8N@Z3f&1knj{4f=sS5{56QKkH2Hf?t`YCApm*$+El z2#L~B+lPGLz31F>&$;KFd+xdC@@u71hQagqfAX`N>%|!MTh!nlL!$EhUvLb&j}aJw zJFtQH8y~>gHgB9iAixof=MavI21MiH0Wpr*@KFsnO!7_eZGur3{pVWXLx?8!Hw*afcK1+KD*ZM)D6^KxKd zLTwYJr)9Akz$)09)}|zM@XLUoK0U>-o-lkQ98hh2#mtYe0W}Y^ski2-nZZrj&(MTA zO*}@@HQ}+loE8To>~lhe5ua=CYHl;PG1(V;%snmL?Jb?=t_v4>T6%lTy)B(RU6&at zV6HS>1`H#=*xB0xwWcdA=NV~pSLda^)~4PTZj=s2eYvZ*>2kmM{KXFQJ5B8utOt^+ z#KchN!|(Z!9GRX63^4wT^0->^P2i7tj}kq+2VT;bpVGJRC4B`ceT`I(yO0ujizu8b zq||qW7+o z)FWhN@OaCq+#t^_p$<}qV2NOV$|JD&B23X@DZ(IJ0;A(}=r!$4M-?FHcBMW{@u>1` zTr;?_sqI|#9L9Tu;_sj{aM&GmnHX53 zm4xuIh|}^Ytmf^HXA3Z@*s5+-wpK&sOJq2@rwrmZB16?nWH3?3sQ8@U48Q7^@H_m5 z`PEQ)B%i91W8NyNnyN(hLe)@Jtlc25QVUSF1yB~tuDQAtDOFn^ZmU8S@36L{@7ZqK zD$Yi!T5fKSC7ku(sr$eyb4t4gllog@n0OgJO3OxxYhP&EAv zK}Oy!*A~DVfZ~4I~iGXU*e3>!>C?7 zY%n6X?UuvlWJJ!JHVffqWTeYuA!p2leVmcTac0uRXquav-)S+owe%-OVl-|KX(w!E z*QCchWbs&UGTIbKV$@Vnf`Klj7h;JRsS>d1#7OF`c4C|)ATbKZecH^8!LhQs+~x!o zH!K1nHe46C-MPIB0znrNb{&jFTN8;+O(Y6?NhvDO$(}?J7bBdo6J}5jdc`L$K~*FN zBePDD+|a^l^Mr?F#7IG?g0yLNgK|9#j~r*j@I}nl)RYb^iKms}YPK?5bZVVr~Fudt1_*FAJ7rb57EF;?#`ozG|MG%aX zr9-^kCt*Uy*L=S1tGw+Z$J`SZtF7{;#X4s3+;lmQSw?J5PokCfcuoCDz%a^0_c0skv^mIj5D|R~_aXQ!KI~u5 z;h0iA+Y#sp-C9!=`C2wKy8FsGWw?4xbI5mbL#>_l2E5^-HFb&aov-Ekm?6(Ei)pmM zi(wV5ub?#*OWDg^bafl8Y4ZzXIzvbrDXF8g>u6ow(sblp4}GGC*7XF$PxYDiTjyFs zmlkvA{Fhmn?L_iKyY^aszER&+$;x`bt_;xm3Pvrfb@1 zP5Ty)<8=0MT6cU|5^3(GPxjKfUcWddR`_p!_(ACKs<$EWII~>+&>K}W`dS!yW=xwI(*3w8C@ zbLwer{c_fd{+G8_vgn43k@i8lVUX4i!iZ_RmaTNrr+OlN{j1joBYn{MdgT2Hx?zIW zP5=q0R`_E}SRWdhS4B*x=5|UKmHsu=z`E)>t-8K;K*>^~=^Z<+*ib z3$1MVV(81!&qt%mZa;5hU%imYW@H9S{LL|?I#?WdFT@M=%u8sc(a+ln*^tO)B>_n= zD^!zE^=YK)yi>IHMO6>cjYG6{D5lK_m4;8#xfTD$$E8(I_aYzpN==`+8tK2jYJNY` z|32Mli4aqC;}or(N)GeyXqZhAmZAUA(0%)yJ&cFnj%L=-|wI658Vn^ zMh!Kzwq~gtOl&uEdIihNToL@zvLgPhFw%C7Zn#EkukB`Ve|!1w`akPmx%E|Lv~eJE z&Ae(Eid+M2SR=z8y3s@HJiD3Sz9@GX;|92mi*px4mT+NIU%gcFH~3N+tv?0sK;I4? zWH-BPpi9~#y;s*%@2#t@(W+}}gE!U(hv~uLsA^;nyL<^*ip4e12kVxb*OX`1mCdxW z`HOR3Ui$n}RN3{SUGCoN`>@LfI`=5D%Zfei^2*BjUlpyierAqzS->c1ttD=bh#Hw= z#dFv#Fb1C{&%v8)Mh#UwA~46dT8MJ2lZkB_>f@GaZqJt6ogN$Muvw;T1S7UPCmgW3 zF;c6`IXXFF@z`JwwRt8;L#HowY1^l&%huKup!K1K-$d=mFcDXCjWVWV8% zl=mQB5FyJmDL@6yrz<_d@2^>C}?o`K12etC1 zd1K}9Jv1<%G*k%Z79x%~#`48S(XPMk;nM3~oZ{_?=Xtii5XJX^4&qPRtQZQ-;UT*z zemV^WGKO(hy0||0>7-BLkS3Rsf=zRgGt8Uf^vU0IPJE|~;s*>UPedtFzKPgv@NDG* z(_INsTX}FNSd&-Heqoq3Swx8;53eS~l#0N)OzbFkRmEy0UQpPJeYi{_jsc|v$wKl7 zi4#U9q=}Z)DS;MiNr00%aYE$PP$H(B7({f1#aq5MeV)um`G&E?Q2 zGVN^-SXK5m5Hv?Yvh5Ei`*93%U$3}9{Sv%p3yF!)_nwW6R7~=;NqiuZ12FF+Dh|%LR!a3du`^BWk=&?E7F48z{oP^WZ zh_RPF(lV`Nt>j!9b!NxDczNO)!t2qFNFePq%#3Gsp*zAK#Loe&uaA?E^_4^c1s z%^4wt4=5D_ulhPh2IpVTjM-rygN{}g9C|@7lZ1s45w&TztBGwoSMJ=L~e-43pRB5k%D)ly6%XgJ63M` zRQym(OAh%J!Qs$IOq08=G13}iSW9b)!xmao2FGe`9>n}bC78#6#I(k6 zHLWf2U5d#Kp{!N8F(NQV7qs3x6A|S2&&1^VkZ4tYAR;&rkspZ3RkIC&h7f;2 zIxk(59}RbZa`nO0Rr%3~;3&E^)<&*>KkHhu;Q-{eRA}{(KT^p zOkJ?9E~eGRi^A~qs=6{Fsf?(T(8t10-YKfe|!YogUn5lK@--4x3yThFPX zb84bFbph2AbwOA}t4kM6tLoZ_q&A|i{pa$^ANH*)3ut9Q$g-w1E_M4Z#DH__syOdK zD`XA)?%<`lwurirmKfJ1rJ&jHt;LMR(T8IZQv+RgYE5z)4ruC}(79FBfdDV2%I0bS z>seJB1N^XcQ;Z1}-^eh1VZgF@c$v3)_|${$72MbQ8+nKC{8OoFHaC#FF3qK-xuNQX zlk+DR&di_r2kGGri6KbPlDwEg7iw&?z0+oA==U>lcz&>kvU zD4#D6*Zkz@BL1o5p=3$-ND))0XDb0_h0fRU`?6&HTelKS`^(Jp zSeudmYooC3u-FQKtp_<-6vRLO1C;kMJLUzSi2{8u?&W#;UID@L3R5mc5~1f7Jyh86 zgAd0C{$s%-ljF2Z6I!#?DjX#Z$N=y=;3%FhcrM2%2D)2`zZ#gv_hhhlU^wPUt6bka($n6tAV9 z^i(_aL&+gXls{HHO6X<}Jt!^VK}r#o)M_{ z3gOvKX!Ks7criLi9j^u$v2j8Qk6M7C03x+tok>WJA(#+y^n{b$#K~a<5XHk;lM?O` zZ0<>i#~lw3jDU>C2vewb5Ea`Zml*`RjfyP+S@-!uED;YOw5$qVh$HR-;WGT({{ca( z4|}50_*yp=SZOUSIOuEf*9YrEC!T0Ce3za|l(X`HJlGPd57#fASgv1_obxq3m8oZs z1dc%XFuK?@?^+sOlQsI9cY!K%LPepTaB);w>bvlypmNmEji(eU9V{6Qr^ zWlc#*wm~^bmJWZF8?o5dl*4HR`XAN7l73BD|9vxE^sFiCb}*G-iVR;T8_}B}aRA29 z*?Hl@1(TMwRq_S^760*4nD!fcXL0?-Q-=z@runV3{t>@fx*pcwpoJ}7A-jD9h&%DwapW3^CX9DV4J77p-!{Sx3J9Y`-5|XziULz=X{K(NbtPC z$EvDhreiT;ZX0d<%9klN&<#uJG0@&oysLezke5mpdtirJQ>3;FZ;w5D`Z&cYN3mNBQ?RI*JKrpw$^vsDeI z(rs$lst(iWH(4nlC=X)qe?;?}m*&dx`>CoK)D{aB>^p+8na+8-A%-QHSOVm0|#v_JGPD~(1g z?VekH6~u&0YkQ>GLcVp*x#x9H-FwfyhkrI24H!Hh&irDy{RoEr8X?)ooB_6P;}|xN zaTtetv4Qj(AHWq{GfoVUIOfJj6+Kze^qbST$x(JE4PhLy2c6u5t}RUW&>Z!Ye33CQ zpwnaMk1OJIZfvxufL(eIx_%G3!L8?*J^EPaGZyaim^jl@?SLT{G4oSwfYD%XCg&?V zjoU4o$i-e+*KzZmD>ocBvOuq_zuxcYe6O?Dp#pEX4yi^&15Chg`xgMsV?p>r>1R}v zh4OC>f05}A9omOeljE$~hm**0Rtu#ZF#~o(hXq2^a(L*7a42N;b*3Q>AG?ZSyxKCE z_bm$vgm7pE5Kc1#GbyM0h{}uCNagr#!V#ew=?Pf{t6&=&1k3m|2C#NU3)G`Rl|U)b zF`-(}DNv2d29@%R?D0gN`-@Cl`uoJLLG+*hbu9kNhsrKz;NL(NBkvG0bkEY zW_v?t)DpmffrJGl7~s$#2^2@QT$(4;ppvWKVX1zXrA%+AQ7y52yI$r!$tCg#$b#oMkZ3& z;lb5=>RviEyL6rhohsU^I(`57C*$c~=c0RkU!RYsXIdt%8SK^3@2@94w-i0+VLwn- z*sFSc-TQ0$<9ejT3U+G#r?7vjy?fp6UaE`}c5~%@aMw=%*8XEuc! z?6;#YbU~$gfB!D1vd16l6OvW=F^$=^y*x^#o9H~Ui6A>JljF|7m|ZJtf?nA$!TX1J z=eT2plS!}N$;p(Hclw6hvNjO(^KSb^8NV+RSKjHDiT>-ovf;{IM}KGE&CXsRymR-C ztPceF5zg(HdYA`?^Pn#eax#B6HInZMjQH|BfN-)gD>LcvdptR9{=7MV-kd*g z&JQE;4ZEFOK&ArjM_#v2Ht_zSlb?2QBjYkN;tRTYuiN>do0Ih}zi)JM$O#k_Jsy1| z9(GoYM_Zz-pBUjB51m2RuuLNmX_d7I8kbqni4;19onGXf`;j96MuLh_6jU7|xnUqQ zxF2u_9S?I5sLRlNIVjIJCR0;jH2_~?$^lkn3Vk*QFhKay9W+EJ4G}<-2|#7dR3rSF zWD+QiN((S_01R!DteJ-X=_W8Z3BDlp3>pH2dc@@p$oeUV&+k+0=@omtI}i+9gfLjY z_w)O^4Uhf&SYX2GayJY+U1QGRu-|vuIpp>Q8v?wmA?O~TXaKAs^MQge0|OpEKYnIn z+5vP2OtBz%t@SrC=!bGt(uRB-RrDr_EzLp%(lu_iJhsE>bPa@&%~P>I!jVlkt!_<>r z2(_U=OlkDdCvUHpSAMShR2MfS9!S=fMT=y;xLW?B@b&fL()kl}Ct@e#(^6&QYH?Fo z2ePrdlI3WaN->t`y%nY^#TIXv%VK)TY!kav=7TZ2WIl$V(&)5gwt_?8-gIO-W{GvJ znrmLgwt}9wEDxTNP>0(s8Tq?IO(2EzP#?G~(YZ2{+ zr6SfNS!%=7E*@&Fs5Ptwg+(RN_hMsG8TevHz7(s`?cd zOzhPYAyjsEN>gj5)>j16L~!P4ed6(ottDI*4nUoG8xxY)U5Te_L-yPfQ~z0%g7 z93D^lK1%vON)8KRN6L08D->cACS;4)B9x8UC-&Am=Wp zof}Jfr;^_3 zKk{*$mdtgDw?&$qElW0#| zO*~3mm8?xtS<}MV#T(1`GP&HeOiFFHlJ5>kZ3Dkz-~Yn|zG2^g`GZ?{bu4tEI#Hjr zwn}BKi^WTI%Wp5=Om;j-4tk^x&#%~_{21s&RiawO*tYz3j`16IDDV6KiAFFDUx+S_ zcCF|R{R>l?-FB!%IRt&7N3vd8x{o5=IgtD@Cv|db^S5cvGp9cBh(tN^( zbbD}+h~pDuL@R0p-6z^$=1_)VqlFHDRIZ#tzn(Lq&2j3(Y*-sMdI(PYlnxn!DDx@L zXraqfzQDr-nJlMZKwd)81c#yoDKIb=W415<=5mFnJ^IP4PkZ#|N?d#N=SC0T@AdWz z_vh_v_K`2#pO+_}SOoI*I(#@s2+Zqw_UO;sDJ-9lQ{nzxoFYa+2lHRY={yGoI+!gR zV}hQevzK*B>H8Pbx0K##W z8~M;q%H*xiJMFSoIg|4@L5@EPO#rPet0;>6H2`F~&FdV0$T=^~9Cl7jc&E{!7LNaC z+Pr?3(;K)1K)U~LkSr?p{!YYVWpU~&>X*!y%tFr+x%A*`^J?9@Qh6WW4gH&F8DHBJ zp|z8Rt_pS^`Gij7+aFGB^eu>l;^wVx?=5fJ#qXC-E&<6e{K+OiJjq=l+iSM zFLE#17a0(*Wlq>-hr)*6Qkc2yTT_i>s)-M*ni|EfbyMjM{J$uSy^V`kfs-lv)=()K zDq~HnhHCN3OVSdpk;sE7S|4tX648Np34pV7gpOK%!bJP#@6X+jIe+$k-1bG?^SVUS z&rc;D{c`HdsYU$7M=6?_HAD>2rk@z2f%%!anYis|p~6J#$)vR!-4_)xN zat~U7Xd-O?4%&Ha1Pj4h8ib_~!b6%65h6KFNc;Nrrz8MEAUOi+2k71lPl(z_Mk8oB zE&4BnB85CaQo8RYMF%rhAstb$xzX7wMo2H{N7YA#v>8ZR9EJ85PB#daLk4A>bO={j zFOUX@Fi!usuz{;8EDUD|v2PF|1N^@v9UrvGz_V|B*zXme!cwoq2a+|*F<-GSGO>A0 zoQz7G{^Vx_jMnrb8~eRyh57%@Xp7!pv=%g40fSsw+zS~=$;$VDw21<>c^NfUXdrJ+ za!#oHQJnijDDx82dAV-s(*n|o^#ypN@5`PyByX0Wdt@lAN43Z&*eXV>xrh*t|$AI_)bBmO=jP(imZ;OFf| zSqD|9NpCQaPHxAbQ^|0$4wd4d`o`BF($2IgQ;-cOCWA5?guL&~+&}V82hdqFoftE1 zIy+jFu4rZ$kzvOf^E9iS@^itr}|~SymtWeDKpj zu`7H_A}hh<@dk;si8rIQu@f7tQS5waIWS)}SG97uaq+;K<+9kbVJSt8%9oq*+Qf;) z+BLd8@5yNV|XHe)j4gj`qlw^}@!Ts$ zX^brqdzI`9Gas#r-d?*=nYfv=Uiyf7#tX&sNhoc~bQa>#g-a$zM}kVBIS_PDP4Jm> z-yLO!%LG)T(-R*}HK@oc)7>%T_i|riiqY(mcHGZRdfk`!kAV-Jx&!|K%@%><_^&bc zUoqpqVO3icrlFrzKB?T&Oyl@%e5(zcZf1%_$xGOn4$}5dd%KO|w;zsx2+R-~`NXgy|4Q zV`b%AdIT9Tn6m5~yTu@k#Vn;;NY9oDqnKxdPX@QmfNgtl3^)FUzJuc>-%V&Rvvt*E Jd%`G@_+JXrC8huX literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_431864.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_431864.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e44f8ce50b56ba36dbee5a4f286adeb9a183f49d GIT binary patch literal 8238 zcmdrxTWnKFmUaEMuV2@%ywm2D;53kDhk<5jUI`E&3=?3MY#6sY_a@kJZ1?pIkf^LR ztCp~PTf&ig6r-6@yXRwvSz6k()T^eYk%s+RR{P_N(;M$yMw`*_6KOx(!>%+M?Z=+F z_H_&x!t^pL?Ox)lQ>RW>tuBt3VD=SHfmP?0=(e`cq3Tab!;oO8$ zRi$)oLaDA&y0H`MW+o8gsd&z4Y&b& z7QogMjgc+D4ZWZR*n>4U8aJ6w^Yr==3e$;3?8c4NJ$SI2td(fOCRV)w(oHni%&Aj$ zvQc9OH{Y#jA<-fjl^&MD&A8<*(HeZmP)2om9%gIBtu^z>=e10slrXMEnNQ}V4Y%Ru z8Nx|Bb|>DzZ(tMIFK&~2y_mvY8QM}aZ#~&>!R?=El$k>oWxb7BsDdGD6|4ePoo#)F zgKb!ax5*rA--LthuZe>l*v_l*4jLtPU>pDMxC!fJiMX(V)PxsS4f3hLiH^-E%J!A$ zR7REC*d$AR=O(*CxPz>ScfJ%^r#H=(d~L#t5|(qjh6dxv&?NLIW@AR%39oA z!+pKep=j~-oRL@6;_laCGk+Q_5=s2=S|mHFj>@P$!Wx7fa-=XN_F!Y8d&5Iw6Yf_0 z(J^$H12Y7EHquu79Y!Hn1+1y+lfbzOp0?o>E_~WpuAn9Q`KNn#oLuKBeWQa~p#W?}!Y#XfkuLQzV zlD?c+osy^-Njn2lzRr_yD`{iVIMC}Sp#r32B0zbXGV&@UXZmYJOWOra1tfD*4Pj(C3Y^)nv-4!7TU|jL3C_lR=Htz?# zAB;hWM*NdeFo}Ww-tP@gbVX0fIE_j#{UjZcoO}G9Wl-&%c3P4DA>Kp zlaFl7twcJJJy)>3kv#FpY`t|ReI~o5U~Wwgt?BKV(+~CDqQP3!+cJX>^-Tb->6yH~ z?UBo~bbRr6cH~Zz*xFZc^{0%F9ImC##m;P3ZeDEeS?(2^2MUh2Qu?B^IeSTTcBBn! zuC|;>bakbS#s2;I{^5`Nmi8|0&F;-z5?l7HIJ+N=qzr3%bH?yc-$+)nHD=!uZQkUG zHIFxUcrlth`^aj)HI<&qOyu4vSo@R5A6Xok-b^%iZPn79qKbCr<0DHa7f)sbxrTzv zC)#}}ZPDh+{Do-iNNLvWZfH6hvR$HMdu~p2>;mT8joH16ms6%H_jd+X96hU!e$mnY z#gQ*ZJ|8JK-bv}#zTL2Yp>4&6_RN*k@!!y%qPs0;5#76&4foy_dxk~#@B`+-g@0@k z-QQ0QzoavN;iBliNJu$DQ0mAlkQANntbXx8Zfow^r@KVwp474DF~l-Xq)%kVvV)5gxq&-VUz^{0 z4!3H4Ywh_xx<&VB-Tje{{JD2U_q)*D*p{=2jor&#V&j3-$)~yAzQS~`GJPV`_eI~A z`#;}bU=BXJG2aR^u*&QenZ3VM{mS%}slXgf9eI8m&tsn51pBqQ|5;pc)%=~(5QVHg z-<+V3fst*qDS4u3Z-x^DCb&O4|IZd57@*aYZ4<2>cN&*ZiN1G4>pSBf0l~e(lq1tDXDA&V7Z}{a+j2{zi|i zT~FZjB8dWT15Qs;gFgfJ4oR|!#Yjftzvrd{ z0>mPo6Xy8HWbI9c+~P?N@FWRg@22&p?WTPRK|}@Av)=^dVjo&Rw0~$zn6U0T6<1LR zn-V&#W2ps*4XkQrCpaAa2xRoijruPs$QHJ&6+3D$5R9GHEN4IEr(55!{Oc}UqvqPOQjI64V=$w2_h{V|}9w?Ogx zB+9yTy3cffOaFpi9=><%e$!WH3p-DXjc53i(9av7SSQBqqvR~lXN3HfNKn#377IzI z36c_F-X9J^dO;o>ApZ?q=Pwb+7U8Z7{#f)H$AfS3_sOXTurBAUK87CwBK{O9PLtvc zDMm?gmK0Tw#ea`9&p{#SEBckxkekV=r=*#N448+9EF?>YzYI{ESoYHi*r7v*$}%$7 zN`?;t;YBiH0rdKZBq}wcj?SO|;L{(7nyty9)L>>XySK=+JY+hu*ODV^v^z_s4`$Cj zq}%hFeR=P`0+e)nk#=N44{2{+b2z{CZ~;o%TcoYG_NDh_J-Oz5TmV z>>n=u{iXY-9?a!0{CWQTM1JCjd4Dkf!%#jnliw3A01>_Ge=_zrwBwQ0xoT|}t?k+O z3)U@Z^_ta|Qa^TjK92unaJ8jZZ0TL@x=(%HU1&LS|45vBw%F zvt86Q6b;6rftLRqMT1QSjBuYqI2hecDl5> z(k^7SB}Zgin*=>nKVo-s^!YuMODKpEaBN7NNl*z@LY>gCs)Y9CM@{y1YOG<^aGNF9 z6!;RlO~$COmerCMLwOY8fQ-_At5F7_vMP999r!nca;Huhu`yKd6%wKsaV4yN0#;6# z$o|DvVf%lxTl*{QmLa>v2*RRc)Tfc-{(2Hn>qJmI zao6i!C?Gt}u>p9G(Xrgs;3Ov#O0>Ym1^;9q&iQ+N>U9KuoewMHct1ceB{?VSEGZt4 zuJ=hX1I2ltx*UPF5{#TK*UZ-kF;EUeHMoAjk@!@*?tl-Te+Z7e0Jo*Bxr7625JC@s4uyiDc*V?vyGI4t%c|3JS)HK7A zJGw>9mgMP7N49g#Y)KA1W;{!+i>(jad+vEwm_rb785b#PcsZTxxYK#BW5sZIgOgCM zYnl4A8?dIs8=O3YRXADhdeEE?a4VLnIt=|uWw7j8u?+riFM3Bw{1JUN<@IusP- zYh_&EuE+RFzIR@>d3Ya-2$JT#Xf#|t5X1~6x_k`$)6p>d3sg=yE2->kl${H62l<;o z2ZoB`G?XPZMNz*+=6^+&U!&H4LtED<{mteZ%_Y?-iaJm-pnY%u%KVkNWR#mGWOv-c zY5Y5yG@q!E%3L92Rj8q{=-XZF=qh@*!&^hq)l~GfeB5KCcc47YJk^_)?2OV z){++Bk`96Cla`*e42W@+j07?vrYYL;O%o#|j&b;;Bqz<9e6uo+oTOjHp&m=@f+u&u%Xh&mcEKxm!K?D) zM>TB0`{bjPq9e)z4E?c1teV3nb$RrscEQuT;F(?UnqBZ(u7K5XYE}3#VKmD!oX_puRLhH-ReEzCS_Dc(d&=ojVZx;wfCMa5wEKyP=yNOYa=6RGoq)~g+t z&v%^Z?S@YOQ1|8D?vAV7U83?#@1@R**1mM(qP-%=5=v+j@Y|l0Bj50Vha{g#S-O>d zv-qRbxVK^tN?C^Tz#f!DhSDrl2#Tre7{=rC;G^D3p_IRz!%b#2Nw);KP_Y1=jEu@X z=*c8KE6e42GoJJu3vfcbRf1Wll60+^$Ao=%6wu$7)?JYE`}UBzLvJT8t=Hr=gGu4V~ooJnK62c5-xa|t>0AWtjF zs1|e?DK$Bo=Gz&WP$TJ5EtsZ2W0rg@e}p^@9&ha)yu=P(MGkMBrywZ>Wa2q;^29B{ z#1d(13FTl1&?Buw7j{DgP3pmhj8LDnk=G(vybXdS=^friXMs?k#tDUJHQvUY8S||g z7Y<6RZphkCXh>NgcNI-+S!TQqTKCq@tozTNL;!yB;xOgMCYYJkr+JRFRDAQ zSd%x%pr}nXte1Lv2D^s_MV;HjPq3WTHS4k7w0Z2~qBaLJ&cS?i4r2P1-jt|Gjm@Q) znM_||RcE@>Vzxb%nbq!crx??dAWU`=)XP%4C!9N|5hhEGdQ+mBcX@35yp^4p7O7Nw z+~yb+3+yiE2hfHV5yy z7LkMvI3mH86tC0ew1WB~aogr_9|e`n|Ly17q2?dD_$l{{&CWHC+w4;|&$!EZ$Tr3~ zJELta!MIA3Fctdxexo#O>?^&mTv3L#Pb@L>_t*Bd`G9eRb{RaLh35SH>w_@a%JHf1EZXD1DSF57$K6e$fow z-<-bmb6Jk$^?mKpD_8wYLcia4VMC=}EC>_?D}%#ps&Ze)O9iGfCg{SDCPvrBHASKF zn5On6firapjWIYG)0kh%WlZIk0@LY(ErHpPI;N?5sU(={mjqUz-O^!9(cZjTU9chG zL~P|-DK=9{s6A$^i&RD&v4-!(jNkM1extxBX0a|%7pxA^F~h+~OJw2Usi($ub+@l` zLrw2C1bm=t%U|4F9$Ojb;o0%6|q4(=84MhJyp4qNR~=Re0e+bF8psRkk|vnf94`of-7YHW=+* zBw$#&Fm5ai5lbgSH^Z$f<1u5cziTHZ(3SWujj1Tu8Zs`u72?C4D|28ae;24#KrN~Q zs-Q7=d#Nle3y(a|#`MikOsi*}mVS2W*Q#E$wYw91rRv>k#Wd|(Coz>F*!f&pyrC}& zE`(bn%CGcC{2d8hNvJ=jYls_4LpNfE2JiuWW5Q4zYJi;pAJEr;4;W3s$)&@vkNQg3 zN2cT(8Z#D39$Oi5#7Z0142}LyB0k3- zhd1AbI+|C4j9XOOXRS#Tw5M%zsg6iZI3c1rINNQG6;%)dXB=?c`6eVo!E>Hj-f7Py zQh+RhABun54G{b88t&#zAK`!{rs_fA zI^ail+m7ct%yZA=ReLFxyjKV`DBA9L0;WKrpcbe{sf*Z)IrY0p1B!d5dxc1zfajZ! zm7~RZX@O=HCS+r!=dB}0rP5{ zg&q_vrDqadH#;67n|5df`5l=ccjAKjo>3rtM{+pxB*6ibCsRn|E8@YbnjMot_pv&i zN_h!^WU*P?tMwcGCV$~5!BX$5ygH9Q&Bf9IV>-q0>T~e|oxvMC#x%X4L;E3UAbG?E znLrAJpqWg6q(~FNn`P3e)-L1Z9^;BV#xtcS*x>c^=O53`pMN|%fBy08{P^q68B0!e z!D-TicfVFkC?Xl4dCiI4Ly{(Kuau;Gv&qvma4OzpCfOV4{KrUoRpMAFtFO6D}RjRSSD^6!Qb-6Bq<`HC<)IXc<>-*(e+zon+UauBGtmN^Uj%B zkJ-aH-7eldJ?r+EZ*u0=E_3~vCi75}xyxb({IO*JTnfea0xM4;A^!;vkBW(fLzkt5 z)eH$M2%B&@Gt%1MHji>Ph~QjyaKQ|Tz1iku&5+wCDaRr7o87i)0OPh9ZlSr?2HIrB zxkeFBbB2Bz<@x4R5(i{Zh{##b{0t{5Snl?OowMieH;Pv1L~pkP_)B6L3&;XQPC=9> z?gwxIN_KSi(@|lR01L|crTiMEQzY|B-UK8TQY8DjhdM;P&F-1CIjlK%n><;P$Wij` zb|br+eLf%G0S!^r?yybYWNoJwDr_?|j(LiY;gnKnQIRZ` zEJZv70W6yd@ZG4}gTPq?&LMCf0qKq>VPZEhOCWkB-PbaNS=h|4RRmTvU19tygAGW2&XlC(v;P?k8d|h!$8*B@l zijo7-{R3;{fd5oNN%?R6Xg*$EaesJaIHoG~%l)^5?zm3xm&bvzC(si-9;yi!MT~3o zp#YvZSo6`@2WMldN+6sKj)v?By=l2*sU&oCsoZxVZYT@mOU=HE392w;d`^``$-Zbs z-x@%wEKX?_-wM1HBJLNg6g;OYqhxoqx_b>ERhb$KUJkq+y6~K;jgs$1>)%}iNY$c@ zUl-SC1H=oB@oP;*Oj8lMwWg``6OiE5wM1*%SJY2|qY%2PqNuWbqom^g{+0dVD}P<{ zY2zo2Pi=hj8DV+2D{83sUHmr%Nh^F^-)JzcAy^eHs$0|4`_67?^uem& zB#8H&eL?Dj2V-P$Tt)c@f&}D;4PhsA7S#cDu;l|L=w4n}S_oJE#2dDMH2GliiT;<< zPX<37`DElN{#Q5Rssbsy)qg+*2bSMidM9N2$+d9hNA(ZtBQ3vdjNJNo?%~`j{^-4U zN!k74mExb3ga>|F?kAGcKV%{uPmNy|e^LDRCC>)_t~}!Yc;Vr~YSrUEjSgLj4v&02 zY>N%sqBq&KVeXF?015WRaLE>u_4VxL4*Syd*30 z%3mL@q~ru8Kg475QR+JID)!)k#Eg}rBvax-8o=XK{y|(SPv#PWi%Kb|GAS7(-jkW6 zE;$E+5MC819J5N4VnafV)~WL5ekIM5T8?3C!7p%0-&D|s#9ZyA-@rm@NZpe5L2(eC zJ?%sPLH5bd^&R^N6r}jnWBO6bHmE?^#ZHW5r~mUF0jt$yN~GOatNnkoTHPD0R*zOI zK!Rp5^jk+rafmo-vHR5+!CNK~R1T>r@nzYB)4Sp$5Dob^D4D$*Xv ziyW!11zEk_=o|AZH-7S>*Grap<+gthC_jOp``^J#J}jZp`Oa?9SZ!;JtnzjH+k$PO z6A8W1cj1MKL3x!Fwc8>mR@>H9XM7zm)Y`?yKqFk^laY=U*OSq8_1nJAU9e0^s4_Gd zu3lqmeLact>igGMuKUjV`(k7%Oj*+$BP)Hq!P-z=f@XZ(FZ4yrWlLqxD~_xdt?N(w zE=06qc(VA zT4&n+wXJr<4Ry8LDE3{jg*=np&KeQwi^oR6M~4ai>QlZ@OHL`L~rwQyDZpGsRSL26UUn<|mPsNq=UEz+fMWA-+f%{D>P09A5 zqaR;?cs<#^_=A6H|A%&|ozN97(@XT09Ox4*$CB?0{n8uOsWomPsnDu#O14y(g1USC h&h_mAz;E};Fr4|FY6!=TFK1+!)(lza9ZE9je*uLbtULe! literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_450091.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_450091.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4504e4c31eabda04600a13551b8c3a5b86490768 GIT binary patch literal 10876 zcmdryTWniLcK7mm`F@BGN!Ek1BwLhay(~+XxAVLO(hhpkxan4xv4Cq>Ga)WgnY zLT^);OxlWUQHTx}FdYPl5~o<@qbMB&h}KwOjT9(&O`&4%x~LZGe)x}Br^RAXpgnU* zN|a1nNj6=e9hoy{&dixPbLPyMbLRRV6$&W^sl)T>gt{EVev2yUq)!3Q{t3shWsJl~ z+>VVTOMC?9>bx2LhyX`8o?SR28WGKiN5nX0!>82gVYW+>Zxc*ux9e$$6uba0d;wnc z0=)PIcp^J~OioJP72Hd)F(Q><+kd2il-jT6!l9aswFMk1EVH4Y4&^E@)gMmq< zO_ZIMYgIw3oP*Rh1*w6b7JfO|DaLf9?p@)Ca%;Er?_wh=9%fT*ty6s+HyfX#1@)SF zEYUaXo^(2_c2;=70RSs*f3>gUoaG#wduGrw*xCPTXRoENt81`xXwWj$**n-bz)GOy zV%q?;u(C6~L!AJ%UF>{?m2~v=p1siBHq@CO<@_jOJe?!}@ZtAtK!!BR4SJPKQ@k{^ zU6S}`ynCr0p1eK!@-zD8@6lI~(RUxMC-LbD45JiuKCQYT#4t|*Kr#+0q(z=02c0pV zV4J@*L<{Kr8&FB}IwelGY?tmWaEjOhw}3RIr}(?gMDl;6`AoElF1`U4z*9sQ)9Tdf zi)ndk^*CKX@1r%ShGx8>fSx#w(S?9j?7(7l9-WIYbkX!>kS2b83zM-X&em)#n*&Ik z&b^n~eKd!sgf4+G0OQ$D>!`*I>TI$X{%EovfvE(_YOVJ8xf|FA@YD2BbHg4?)UqB&2W77wAQNY>o zRL~V3a|Wl`WW=Kna`RS@!p}ta(D{S3hnC8dy)yGHEv?*RK7MLGGb#R2<~u~=&&h?; zRc=+g_gG~Wy?@v4X4#(~Ndnb-@N)Q1$V&}vr11GXKm=(@3N}Th_puI znfFQdISxNJN8R_GgC=NgQpeDlbm5ZBS+m#))PKuP)WbS=+XT+ma(vbAJm0i(OOubh zRFikpZR>CY{ahPoNp@-O4P;4;w>&4mq)lz&^sr^81Ub!nxlL`}MS&#In!#f>Nu=#* z>yKnlkMA*G+Mc#%%KwkB-zQ$GO~-ahB*UII?7?Tlb7v`wbbq8ZoaC&~39v#ZGPueT zbAaaTxMx9+MBs6AJ|BI8h7r>C!*hy0NgwC*2%xE)=dBbeA(5X68i~@@={%jmLAI0d zvNOMnyX`YipZ+e=+7KAr zDLf*fRM6q+e?u}$S)SX@N@pqOIAxu&Op>g??zEDCb+W=qhxny*cZ;<5=p6O+!9T!qT#UQKP8C|d*(pUoa?-l8i`fA+|(p#vpDD6maA6x z=me|Dz@$daM5QLsX7oZ+qKT9#++u1Z)mA+<&Nd=7N_zH1OKJ=X{9ud%HNz+r3sA`A zr%>!eA(O`nXD3MutOylJZV3b;VwP}@jk#>F3RaSaoUC$mj-qT1w#dDccMyM%w`lu);;IAh@Zx zmU*k4>q3mp+3no^@~G()=hLh##T;nfBAcW_T!|4^K*@>`QwV^8Q6WS1G5h4~tc^_0 z3RqVg$lCPBraxS$S#VO*u377-t!BbHI&F1NI30(r}zko>%Nw zJ~8GkF9a7tH&)It#Ybbt#tmaLV{HC%|5p|FD`LhDpZxKT9ub@=ZjKpSHjF12A_B4+INN#l8?VLOvo?^DKgrtm37SC1w$MN1nPU4u^otO!NEV}ZqR4P&fhWOa<7 z&f69jh<(k0#pwPcjPVE~JHiN#csmf+3g5!bcS6;V#1(Nu2?(hwVDuONto}~(?dEX* zia55fmXXwY+v5sVpqf#byj|Y_mFcnGxI7#j4qaK9WD1*N`ePgVRz~0YW&2k>_j_Xc zPQP?hU+R;>fX0$=ThQt4d#qG1P5LK8c<5A2dC1!xR~th$jN0ryvnkdFtTC~1Q)l?= zYoW5hLhxL;lF?PIN|~IN$ne_G*L7>d%!%_%&Usj=-uM9?Y6yvgKMA)m`opVvjIJdz z`B3n+^`V$K`5Ke+8g$Cp=OdoLgya74z-Xu^=vX-xIkK)iwV`Zhl|cv;VI_WNc4(>X>u?chfnlap;{XD4+k!W4}7Xwd1b|Tw`TQBgnZBz8D|>0 znEhQ3N7oet8;T2z;==mH*EcS{$y|Ihrm*^WKb-P6GF?2!5GeC?Y{I6N`d<(6LW8#n z*jyhEZSzjd6GFN~_=!MHU_Mw7#u-gXxP#FgT+zl21t0W>1)<@fb49XR_65Frkf~^4 z3@!h`$0f&qAX$RIQG9o^w63|o?q6GATKkxsz8C5F)*899@IX{m)23PzXx}<~HH@x~QP-`uue!hJUtRe8T(o_FX&7MC13uwn zwRTw?6o-W2u21V$$d#tMeXC=Up$B!5;m@5+S?9ySb#?!SdWcaEtzQ`4xL{!}SYqlo zeL|#l#-hN4ulrjt3VF*D!HF<0-0-O;+;msAV!f+cZ3DZwy0B(kBVvX#K3QCAg!5Jt zKDS!_50#%+Mh4cbtCg|(9;U2^(f0Vn+ZwiYHlyW@jIoi?G)A2T{FH)zWq|m(uUI^z3 zrS5NVg+}L6w?Du$*2T0dAOm;IM%iuFd35znI%e%~QL&Oyr(8m2*pXOTUcK0&fM4O{M!}l$W8`dpYmv-_v;{{v5v^Ildk#^i4e_ zWpstl09(&QGRGlW=;fWoehN>9S;qx>j}(nc3#O3Ev@4|_tl_y6{HFJEXyJQ$lIM}p zc-GNSpeXk!Xa%_Vh(jfp z;R{tm^RulcgE~t%#jQ0dPyS0e;?shghhbl{gi$Y_A+0|z9bG_+U3xdtV=hhU($IEE zZ#~F6K>s`{P8&c+s1l$prd6bnR+4!7`E|;;O__#hY-gV|XYkOp^-*IA^lNvbmFb=y zwJ*n~^XbRP4U{uR?eU&oBOIvJo+Xy+)ToYkW!MolMV z1dzf!I}X|3^7r~e#lbA2Jtu=rMV{UTPQsmN;&=&Riv}CvOzN%Iw!6>_m}h^>Mp%6he!W03_T0 z9?|ZD?Dt;mXN`ASZ?}f$SKO;-K6xj$@6?+9p@_-vqb@=FQ#8R-6bYJD5G+qvXTdfn zLrYfZn47UdG()Ou&xA+i+>*&Hq>MWHW~6#7KV zN+ASsU$@vNr=ioR6XFhvM0l0Mw&;d>L(}z!-!FD zpzI3DEGT;uGFF_9i{)+*fjxci$Wu{qQjF!Odkf0S5RC`|zu7rA7Y4JUaR|*VGu(D? z?+qD3M{CFup`AW`I=Kz-88IppyQ^UuYfh#8{XCl8E;#^dVz z4Yi3;n_}wy;eJL9L&Y`wz;iO@`sAA>Wgq9=&0_>b-ZQ>;;u=%9j?t8P&&Fl?knxes z6cw1FGE-cp3yB`dilTy|sH`Y1Q!X|8n?w9NlG~DXSyj0I2}F5|SmkU;e%l;wGtLsoGy1ljeInGU3N!O`!Ut@{gZ7Xqc;3N{yn2Y$VtEY^u8B2Kaba9vy!=M+jgO9g+2k4dMU@ zJA7>=XJzV>>8QDxIoPsJv_4VlHk4(IvMf9jQ`UOB|6S4N?fRFTBEL1<=IwqW&;$-J zf;<>2PDtW}jQgjJ69x_gNPsoalrZXqBG_B0da2M~_eu}RDW<_38~9^Aw8@I3dBi`6ig2f9g&Rghi|MV;CE#^Qvfck+Qbox8 zlH$uSvq#FUlkngfvm?3eO5+$Q`4H#$O-aEcBc&d=!SX_T2D1Gi$1w8DeCh_ua5j@r16tjuk5xh$f!R^Ft~tBgmAsbsLu2x8o<~p_%67R5*AZ?3WtUNg`~??s<#ycR$*Wt# ziD1Y5{g=@5*NEi5KptLfQ>pfLC*;@xcoE$1?esMU8beK+njG)hCxl{2=9dLJLyh6a zm8MAJI??WJdm>dW9r7Q7+j?rH?Y47uY+ZW7+wlUdA}>@D8Vr}l6bHOrn}wx!F5kZF z?e?8#1o<#!c?}~d@xB^35IVRiS9m+0Xbj5*!GcFcwGqR*=Ct?hrbdUd3|?LhA6PjU zIj~N&ZF4deu3W|M)&N%8w#~^dY?_nR%7^(;t8HB|mW84JRT-@6))kF^>Pc7J>xzcw zCWR-q9B(go)0!a`0%hp!?ds|>YdBp^xvfcuz1j{RnC9T2oh8x`h5qj#=Oihsol{ter|iHmRL-^@?&JUHHtCn_*_=K}Rf zn3MHN%(VXweVa}kO4h5=?T3@~>U6y(S+7mk>yq{QRQ-2HxcW0b#*?S#>JebIb#e3I zxT!3jTNu~n#SQzu$?3vzZCotBMcdrH?urUWr`_FZpt{B{|^KTl+6GD literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_460195.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_460195.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f123c6f4dc14d42c4289df6ebf96c3817f9949c0 GIT binary patch literal 11369 zcmdryZA@ENmhbud`wR0Ozyw0#gz}k?kd`Eb5JEzr2_$I~$hW!yWgeNyudGdQ4!{#s& zBXJkjpM2x}INuhHiTfouf{9$xFzwGO#s8{Ckn5C$f7 z4q0|u9;*kehOZeMTG9wV6Z~?rQw*6&^Lx^M-NtI=zK8YeMVLdsF;BxS+?xLwO{m){ z;uJmOwA14zT%7ch8!DXqWLHnyIomm|ptIN3+kT;|z1!Av`gCu5U$3pNy}P&PBBun* z<<^UU;nbboeeFx)^>>Txw)dG}u-085Ie& z^qaClJy@RdBHuQ;P`=`;FDGWwP=3Q{r++;!<63`z%UBGg$a-hX=H-T={JcpyB?%g3CZKz^UVJFNzDpWzAC1G zD$JnDrX%Sl9qo0C7^!bBvxnKsbKRFguSiQL)R=vY@&-tW-k>+o)B|L1@c8yIMIn(_ z!R%)Cz}~?AN*?jhy`9~hs_jxLxaj3{Pe&P?%t~! zBsFR#D8;~((F?IejC3j3RB9ysMmsf58ju=A;67y&#t>N9JziUiN*ES_5F4S3*XiEe z1%aRo1-A}Hp{t2P@e_sIUP_4y6uNWLaVKd5y`Wce&I(jTB5?Y&;%p;?Ylu_XCn$l9 zbjmhPQ=A+r3ssP2ZC+3`&55WnPEOrO(#SkRLtY2XNvScL-OI_Q>iF)msrn>6nh8G; zk{4BE$c#WFHUCbHB8G097`kwxsXF-8b0QDCJ@q^z$0S;A;OHSqPRY|DdA&zqhL#t6 z{&K1IZ4Y(BJ5JagwIhW620@Q_+^-PB4mVxvrR=q|V{E(@TD2)H(O}5lArCcHGk(hk za2psC?Qz@e9x$ncO?CCnhXBKAQ{8VkD7V8!b%BUzUA#50M1Jg_Pve+YKXX2KK74aU zQ|@nnXfVxbXSI>~6~i8X=R>`5#uxNO%2)K2{xfTNmPqTY$KUg%+MLKO4yY0aW2iHt zW6d?Jp=P0Av4^cc#~RKBqzO}QSQ)E4$QB%AO$QgIVkdjq!@aDjHzp^G(i zZSXk6793(thZYsFwm$YqA8YCh$k*hWz}xSBC%k`IUYStn)|9$Xez5F=hP%!4&5;XH zdAzueRo3}WKGYdzoIz(840f27gPoC%%@T?$`ibE~4#=&i;2`@XoQ#ozu<>b!(8FKqg8YnTpq%oj(# z!xlEN#-_#mCG)RuF6Fb$ow2R~wt0Xx4#0?6yY5};VUPC4E?r%|HW0f6o!^YTHO@AV zv&L~C0o6)>OhnA#;dx!mdXz0V%9@TYURmnWZ#^Rre;VdvoRbTdUVc*y|&4oiiXtdVUcyu*@sLE1}Dgs_3ii?xsci ziuUBHww=|s|7P&FqrV!BYcB*u51*?SGTEHGP-UQPO{))81m6gY!oBkfR%;1}5?dgT z6sj3TP!Y-xH>6bkDpGavQP%jhst4JYLDo2!Fy@4-BCoPVHUA>Um32Fk4}Y#_k6npf zeRJ9NR_y9qYzq-1C)t)s);O6S=3mh;>oP3I{G-7+=d3e=M~=ty8d*~#tg6xUql6G;5-_Gr#Ad+-i&u;n?V$?FthVhpCx3hPS7+neo~P|{`(EFMT{g2t)yOVucC^dOOQ(KUzSQxVE!IPTQL;uN zX^w~*nPbfpxHd2jpQ28}n`%Q1UD8dkCGT8tee4sd`x@$F#FTJ(tDSDzLAe~nq=V$- zPWQMAZaGeA_qazVh6&mMm#KrEpxk!kC+bMzoA1+vud2Pi;$3 zvhGd}BfXQ$;Mw{Opaho2r)9LHln=ypjE+&=D1w_;bQS~dDpHow(gbr+f-{A~N3(@` zqHIvRXi9WrAFPJ~=97gA;k;7B5yyDG_$a#mH$6hKoe}MX^@)5qy|F@!m7|02j1Gz(DQv065 zXQFe2DKREw!;Bt258#mMl_1Rq&7c{uLQ;i3FMwW*jDA%3q$Uu;n-mLj0uS*veRM*f zZSeYM;0+siBnS2!a%At5mLoepEk|~IS`K_?8Q=`arYp(DJVaUX8@hyjw{KpZ{^p)P504q4K7gvIHF@f{=|>{D7WFA8np#~8%>HIFT*dIA-GRG`D{ zu>uNv>MI*mF$_hU$grQh2E|{CpY@e4D2AF35*e9+`*|X6RrtIJ4JaYPD;|Z5D2bX+BrPonShJg2}=r*vm35p zDP%k#?;z{rmzz^Tj7i_JxtuouW%oc}$p?&(;@~qC);y{k1WP*rR1a#tj*1JYxQL2g zRP;f?DV<*1Bmwt6e5MtW3>q5r;z~IN)Na&GmkLoKdw{}dsvH4j5a@#;Cg+SU8cs3{ z*%FF?YQ`dloQW-N;`8&UW2*JAl zll=SntfDF)4NQh6=igpa=f>0}4^6vbC9lU#7h;+ViG9{jIt5Sto{!5gvBx0`AUVE%YsC7|DtGl|IvFFmT-T^AJpglryfa+ zp-NVg4>5lT2g$^cQYb|YSy&b6M3su5_6rNdwlmtGHbkr_3Zgdw%(n*Pio$Ty-In>5 zs3~4t9Ub~)?Ecu|$#~82FN!MzI3Euhzpsv#ClpG)VfvmXM9huMj)bWnx+1L~p1F4> zdhw@S359w_6;y@lzo!j(=cZ?;BjrEzWd*M17^JdwGiEG!SbiYd{fqWrc74_rFYip4 ziWBBCR2uT1l%YPmF^$QeJQeI(&rr@`PB3l~M-H$L2WL&}laL~xG~){b zSpDXQcLBuUR*OHDJV2)@JU-boXy7Y1xV11EPlY5vRPq<1QqU=`;C9FuIkWM74LSLv zS-x?>2`)pYJt-dm&OQ}CPX)N98~4H}(#xi27$cQG0ha}4nbLw!O{#pF=g<(io+$4^q0>K~Gv#)^Z#nhFj z=pLgD4MU=R@|Z-$=LpYtLLB;l;^{j-bG!!l1jFC{89W*RhTMY8eg~~m>IQ;IA(2hG z`Ma4KLI7FP`Lt>w+3xU8xM*+k)<<4f^6I~dYRFC4?2n-T-k>N{YzXK@e&tgxLeeuZ z@+x@BJ$Mg#{vLkbe}|*QkFDtp{*HAGR@KN#cKh1{O`)dnVMtB+&puLUXVgJ;s6E^i zX^I|RY+6yA^tV1z>1PfE4}e!N8f~5TEDWuvTKsL>K(&S8@^EjYBCf6SpI$4exH~XE z;O_{WXC+0za!)NQDff4Us=|BMG+M|w8}sIN&F)$*ty|1nF`n?BT{D_dks(VLWJ&if zR;?&n(~^uv_AlV~Y5}ThO-r%~DoC=h|MQ|4;aJfQWfACqR0j*@6>ZbE%`{4{Xb)~- zs=zck{%+p?28SE%fbQs@c3ZU zpr)K}DzMz5a7B3HE;G->iVm>m>UAZLP+`WxaPHk*^Sf9>)w-I8X%H-DR=cj{n>x&t zx31@_225vKH}X{zW-zbwQh=u=Q{NtVci>4b!aV60VYv2Bias3A`D$E*=`C?x`MYXf G+y4d5q9R`a literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_527413.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_527413.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b64a8c616db4e3868334a37b9f83cbe8cd49b19 GIT binary patch literal 11712 zcmdryTTEL?miPMoeqkQw5lkR}2?+_%B%wnRk`O{3&;-Z>=s3pLfWg>seIW^Rz3m?L zD0b8mMw(Sj_pIV>tu$s6DU7tz(EYPa@5grcj5NL&ZQOSCr8_a~_o$Lg*@?oFf5$QG4n|-E z?!bEEU%VIR>bx<2uKr|d4L#lZ@DoKRrJ$Io?~IBP!3=CyU1yIMNWwX~Z%PMzv%>FzRjx3qV4 zTwtZpam<@Ckl;FS>Xd^yU$Pl*P0r z$^9uQi)lq#Tbh)#fzCrMX}m>A2{*n3PUqcA%m?D`G15j_G>UXoLK{6hX+2p;8`EfH z8n1BkcsV*dokE}_jF;)00MC6QjY?)ovXYW`O6gL%kdvY;tyd=AMYIvDANvuHz~bvG zO;@F;1+^twbSX`hZ0S?B8FLyekr3V&-9z>ZX$0BwmPDoMBi!zqs&a+Y-0ymg@)sg) z(tnLnls{(_JbLFAjTWYKy=4}-F>J+_qPeCs&4T!~%tFO;&UdTzMeI`XY=T}o%L^NO=pt-wXf;4nNVd|r&}xtW)duyz(*@#4Oh?=$(80w-{oGG zP9O$M2&2|mHUFHQAHKOtV0cR5`y$2KTs+~M}7>F8y5kuCa5+spR8 z{QhB*wUfwUGU2Sia>MRkLtg6|f(#|q1dg^~%;Fwpg)VB?Mke6vM@=}XMm>3~$USTy zq}(JLiAl=xCi9F zi^E4*-i>|m+t2c?E;q|}+np@W1rXPM=sn8fH&{7GAkm8tT--Gvqw(d^pL%MpyU0=Z zxW#I#8MatQE!42fxyLeOb5b>K(pp2=#>Q))R+AVW1s3EUbdh6w$8VYeZie}zTu!sq z1+(}@?ScCJbHh(rI&$Eg%`vrF#t&?8Uy-V@mbRdH?y~#EPQC+ww@GyQ7>{ zhW4*$c6r+#skJj6zb90(qAv5Ej^!9b&9g3V$M59&m_{3D3#l0WUPiNbA#bsR*?*SN zob?G~y6m7dQdZC8)ib*Kg&UFMUCe&v~TN-!aXag3ZC}QC)>k9Mk3n zk22aSpXd>&?UMgepeK0YV?0zJ63_L8k1|D#Om^eqrKRpK_b**yn$Abwy2dnJW3sQs z@{J#BLnlMIb2Z^p3-zCB7Y@vKF!^sT)qTsyRj2Wx-9}iwpLuxn1jD2+be(~w6J*^I<53FHQA zKd^-K!J)Y~Bc?{CppnsTp27j9X@JoV#I#w#@{E)CpL|@}^!Qoio}RENrfG`NPHj*0 zzYx)HMHp!AivLPrB3Kbs?u2n@vjQWt#bAI6)l93uHP8|~{=r13EO>qHY`BIgJjCb^ zEgoN_KI>e({>j7#f#KVcwLj_}8s%zaW3L z^Vd~hRQ>DUp(&HC>hwW;7TJ7GQAo|?)cE9&zV|HdP8e)sJ5$gexp)N`OCSDq-`h;z z+bdVEJ-j;1Tpf<~*g~HVOCX| zK%IX$s0&)>@)%VKFnpFdU04!)W?2${(ib`VHq-PrqkTK3)dw3A>+@*%;=;*KOBPz^ z%}l{zWJgDyz>bfR%_uvJvL?vjl;&IovukkBZKJ?duuhueASjPnZX_zK-0q}oq{C*J zvJtGz4jj_8p8C5ZEN{ zkbj+Wozn(+sf^%+u67E9+`6wE%b`dfH$0sc>WOIrWbJX$Wb3!kclc0kIwQ@ z)a`Z_ExesgOu^XnPLP z3!RERJ7G0=)SQKaF31w-TTH771+ChMi;g7UfRfDMOJkdHlH73&>|Ti^Bzc6=sfL;= zQByf3P*W#rY9~&JJQ`XAcp7L?rSP=SqVZ^bSw6ildyr4)-jjHARBnGDt zUXG!fhD2u~^@adEcwnVyhuM;nc1CEBJ(Jr}%1Vyd zso92?K8f*4d{R`l<&*td_@t2vDc=-%#_kKm>DBA_7~LPbEtp@Gh?L2{z`mO zR5%T4UVay%rL>5a(As+m%Z0V1?!rtt?bD5IgT{@6)o6Dc$1QHRjWF7s)kX**IE@=e zt#NG9O&JGl2^20CK`Oz>Nwfh?7j0Z@Q<#UFjENib8V1pd>4FZY#fn^B#dywjW-X&iH}4yx-DZiBSGC9w^bX5T$FKI z$RSSeYjciIQd}82*F3-D-Zdi(WbGKryeLD#$u+dkTT!Sbnb3Bc1-~cAS^0bZ^Bx%e zeAv=(r1|6{lOOl<2Kg`nyml^LfWk8I>oF^^680&RfEAo?>279)T=0P034O^D$lORE zqu1*tPeOr}9Cld71_;ZM=^d8wamP&vouFvS-op-;)#7j;fkM3fho~>wk^bn#ev)&y zV6Gr^G2Fdinm0!cO^cIDgG}Ce@)WeM{m*ybeYf_%D1+a3CKW_|hAra|_2KjjPMFI{ z*@q@wli-BeL8B0>aDuxfgwPPgWJKLIr`tucLh#8@#0WkO^(-GgtPJ7<)J?O)J_^`Y z7X$^k%RWSexWEp}hsca6XNwa^wt`ILX_U30>& z2zGeeeDB1x#?XF7Tjo6zlV=BWm*vKYpgFR$ISRSl7?bM)*4sxSf`gHQgHgdj-;tPH z9~3RicSHnD5o1#na`}#!Ts3pZe<;YmE1i=r%S$4H(a5fj$Xi_x-@3xQb>;IPM0(AU zYu1O??94TL#5@uy9gPBtyd)-9%pC9^2w3mfXYI@K;)q}@QrZ>i?s?eV$8`5ay80vi zmdJqZ;lL;}FdDHqA|+!{z>ycP%9JyE{d=RbeD8@^TKSP!8`%4gyM9{zQT2+rCZ;ZW zs4ivHrD0*{#Y|W{QJ05J%j&v_q%NYadt@jKUX2>6eX8G8*1X^IP+7z%i-MLFrE#GX!Z091 z>$14uqt=*2<#Pwl%$|*?iy4XWp`;wMofp8ipCYmo9iM>y(6SsR+ajB zv8k{6zc745j&GuQLs zr7-?Mf1+hzMPdj~J@CwX7B56=j=c)X#8(AnKAa0eXy2EIOJaG3yTUo)hvLxw4-|<8 z-TR7g$%D%I%7xl{)!!c>Cei-5G3dN+d*GONM6=$6df?5FG<@J&gHkA3FT@l&Z+wHY z!V*g*v@^qM!=C&R@;jIv^T0-n?r0v|!}IVx0)poerbX|;gWH-oPmd2bKmL8ey@a3a z5k035j}{U_h%)cLq|}Ac}s%BM1&3=Lfj9WJu+^b z5)Yon9qDCD8iM=U$KZd#C=*)n$O)N8@d6r>(MkQF(+ARX^`rP+`pMkuDgDrLT0zL~ zEAAz9^9(&HIX(T&TLg?&buX#=?X5!n|7NtB7Z|M;jh4p1iPE*31e^y-qVt#RLRlxu z-hvDqs!75EFJ>abHf6WkI7C)X+1!-bYH{1l`%QeZV^hC1w4i|`S|@7+%05T!91`Nt z4LwdHNXjKir&kM4ZJcwG#-k2k$ZJW*t2YVB0R$7mLkZ#J{0H(X0*ENii83kSt;gn` zbiiXB=b0b}Fz%i>P^lDU%_u{f;9_Q^6P3nLmJ-m79P~|YOx%0Xaw^#5!n^~Rcj4#$ zZ}4`#*s4n7ZCzJj<#mi;r?Q>#U#cl+o1y{*3Uj36Ia?y6w~ zCEjy^@?gcPLg{UJtj)PoI9s^9W8Y%ViuRcI%&Jz8vNT!Fh04Pfi{&el<_$?kLRAa+ zM>PPIHE&3=1)7v(q3ZMeh{d*|9LyllpH&A7`W0ovUz%x{T2aaguZYNh@j|lt$qlMdSv z(hqoGs+hY2@^wCruCz%;rU+E1@5xmE?Xd(VUF`RY{lBz_A2x5K|UoB*g2#Qva0=uh|llvfVq(QD`dxaRk1FsMI8*2^ToZu9P*u!T zC0KTTurxS%m!6{|`MVi?^}3X6kYQRwF#B%dTp^<=Uzc+*1%hSGD%X`zi>NrZz@ZCGW~Ph5sLU C7DNdE literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_540784.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_540784.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..983870d1ed0da3b4329994eec7b7932406bb8ac8 GIT binary patch literal 10442 zcmd@)Yiv|kdiT!z{r20~V{B~3HuyzsUIC0T4hEZmZ61bZ7<|WO%y`D!xdub-Y=%`T zX0lZoM_VzDA~8gjGAxnSsaEO`MT#5s2Ti4F4&%zhEvtIF{NbN>*{)Wrw&{27JZ)xT zlF~|T51Dh%`Of*y_c-7A&i5VX@6~Dr0p+aWXO})P6U47^CtZqCjTiqw62u*XA}G>H zj3&S2D9QH)Q^HXZiD3e#cuF!VnUao5Ny1J}>e9<(zm(7}nlxnl=@um_02dd4OA5fH z1>my$@NpF-e_!+<#m1;YPGtX4#gxKMOctkk%zi0l0l2CFTy2-w3yh5m%3xtiW0&T~ z$qD(fh=xRzTRb898;=NlVFdEYP4uz==m( z4!4Vw9iL|4YjbkagU62d9JLm9TnKYVzocX-G;+&egQ{3IubkyG6# zVT4m28656~R`;piL!7+l_+bA?U-xitdX<4O%y_!Updp0ciz=|XI|KuEl`NBjbSwL% z$Oox$Z`lSj1(})U8_X1DW>&B=N_Ye4dMnu?dMF1^l;t7aVnuBE4d`UBH@(_?N|1}i zZn^<-d#hNI-Nfo{hzSCZvK6d8-S$?q1dOTDeHc@x+g`J)ik0x#<*YUheF?jqRpyN4 zV6$axDIQ^~E}tg|n!HY=Avv;0(h_IC6v>u8NbL{i*IUDyJQB8s74yB#t_rp$XN7#) z6R*`?5-$fQKUHIzs<)Oc4w9bVvjTQA*fsH|0*XkU9JZ2|s1`#pJPF0?^EB#M;rGPY z^3r_yWssd+M}e)x&>LW%-ufIl^2uJu)>ER-#1F9DLz$E(zqvJ&=}hw8IMd}D8Jla^ zfYKJABzY~AwzAc1%{oo<_WF_@3~5{A^=)}&yh=QuU5nlL63wa&blMY*2`yx|DhcmawwZ0C+7Irr=9hHK z`~^;4Gilm#$y|T^U!jrQ`?Zu>)>68jODnH;poFXu^P2Gytp&Y9>ss%S z^!iw9yygn~i0{TLzO#&2yO(yO~RNIt3>P3rS<) zEiz8PI61{M?Vg}*Q&tDXiJWd5#Yt?m%{5_%Q8yH>d0A?s==%dbIa!Y9&tRVQ2?sx4LH->qQ$gc8SyHKvJvbf>HzwIkPZHOh}W0 z&8Ak;X8NghvVhboY3j4q)EYE+Qi8^gmX_iQ2kJDARA?NKaN=nPWd(`wC&{f0H!(7& z&5CTjWOI&#+@Jym)IQ58#%5@1iV5kgl^@{bSmtC?Ku>XEddfQH;iUBCB)W9W?XgZX zG$-{qT*(IBY&=H@L!1!6oXlw-XZS%m?U=ZfY)bigPMQgMMkkR1&hp>bp>R7czB4x5j!;q1~sD=~O_z zQdE&f+>VM{k)d^AXVi^429RMOgV>46JCUg~S{gevf_9A{(}-WbB31dX-Fz?9_*hz( zkZD%rnqWzw`s4PyyXJR=Pei2gidH0V^&NbotWD?)!FQ0ZHXvTn8Gd?_7vP{?=e(% zEXm(1O){2#qz@emmCWr6A3?@#3k}h_ugPcwYCpJWMeW1LFyfbD4kiKyW^U}e&L1_R*dDLi=XJjhr%WEmm|kf)t>0YBK1x8;sn|^7&~D@ z`)sJnhKv{evV^H3M1N8f9u7CmzZb;B7N1|_RkEH0p+p(dG=)lF;_*Jz3DyqJU zOxOJKgt6#DTgVignA;OBN5-ZFX;k!uEhTxp(+X)KI#rJp}zTwSi??KwlmtbxZ@lBVi($TG&Xnv z?YV%;E+CaHq0%MvrT;1w$rZm-5OQT9?Cd^2dJpY=Cw7rRdl*#4Ak|EAiQn@}s0i4z zvw^e0E1{au#C%g+UGEp(ue0Bb)z>2P#K(w<6zn^n@)j=>&<8L!tfVu z3#UJCT2!F+eq=cMTC7J`5^Dd*7B+z-dm`1SVrR5HN`2iOop`t>cH}JDa~4&cMY>VH zIAJh-cp@YUo}Oz8E0CcvLNC;PK`u1he>XacTKbUwEx#na(<9;Hh-XRNv8?Vy>ds%# ze|_Wg8*%kPzhFgQ8akLdT}R<`Rql#rrEhy^jlF#V?Y)31FCe`wyWFX8{eo~wy<=J3 zh16YN3cgZ2RK(RixrT#}$)&)hARVgvgbX#zy&FC;_gF24m7@F)GP3k3=Veq!U`FDyYeOPBTR}Lx|S8@(?Y;F&WF<}nT6j;S?E%* z!axyDKUr0drlhohC0r&ZV-u740_^M~q?Xq-a7GWXQqNJa5HqW##4J4irk}zlwJD); zJ}j|zPWpj~BwQB2Pd%{VHBu`H@f&V|l?u$xuF2@rq$vqwOt&XZ=~nj3s?%7Rq+Z8L zXaOa4$$%3X))e&QE|QXZ^?sw@I0yJ*E9v^5<|KS?v&E(GA|PWWAu7wqKcX(00lU*_F$jb(Q7` z$5p$F&+lk#;iT2XZX)Cj#i;(~K?X#a^iXLb%Z%5CTH5bP{iK z+MWVrKrk)@J_Un(Vgs3^BpH2kqA|*G)gt0V1HHrDoS4sY=xV^FtDy2=uFO84kM4s8 zC)?|^Owf!R?Js~Peq0o zEcdPPvNxkMi{q$tfIb4_&waT1_iw-b_W6I{3Vz>OB;TEuS=5ldUb0O?_DV6=;iG`- zg>g*GxM$!pazLI0S(Xc~9WkW1kYP%O?Jkd-<`j^SFxRb4$7L8Ab3+!w31QRdGr$6; zb=hYjopN8Z(^mJmwVh8qAkXE~RZB7550kR@A$=5gkKt+nSA)1Zj;pt!;^YpG^{Nf7 zK^G@YXR9g(bmdIllXMN9rc9+eIDMqe7$Sj`Y`}0IrWE^Vkmc&-e&!zXJNVdhV3XX>+61wtbof+xOVLc=xVH?slKs2SV z@Jm)UneU#RKZ``=z9as36Z*=~mAPtPe?nOlDtW9l$3*6s(wtCgZtn{03JLGZ=jBVv z=J1J6Pv1NJSlJvCH6y-`DVq~o!|fjhegvsp*!XGHy{dRko8Rc`c_!5dxBumqPn+*G zFG*Vx+D*&aTBNOwh{LmwwJkAOOHA96C~H_QYe!}6@v4x@yv2{3#x8eAmi2TzQYMvxbLyF;$GhqZD~Sf2zG=N^U9d97OCo%RgFm1 z7-@?PEi``K6x;SDYTEaYs{Jb}!-s896<3(e-M_8$;4G~2S{02=5Y|B)X6;0 z_27pk*k@DH^FZ0F;N_G-{FU+UlUOh7$^=2lKOuR3Q(EvUDTP<{3L1j=Dy1LnxS?#Q zAJq@iPkyYI^usDy6{Y-4^&q7i9L*QdlbYAl|9Xo6X*Cbhx-TrP_Wz8ux>t}^kELY^ zh(8R277-ux(0>9TdKgzDxEg{AVv!ljW+6EVWxwhev-1-Je1*HC$l|!xrJXTRL~Wc6d4G@9P{I9RR}@dG5R&7BOwdFflw~ zQZD{NrEM4>VffGjA`&TW_slpMPx7k9u~YKCp296NuGU2m(-`K$RVIw^;Bauwy`H?z z^}Gt!-03%A=HJ54^KYOJAF-m*`TAB>M12PmZT9v0JA<8}t`)t}*Z)kWzO4)>gT0~7 zaA%||+PNe<=<9x_(B5teG{IFq8R?#PFN`lK-h`_)KU7@?k3vJ?T6h%l9bVZ~d-weO zd0(G@0Pc7oxupe(YJA6n^`V9pmD<<)OkaGba<1}mRco|(Nq@lCzoIwcDuu5pLI}&7y>-Ivr0x?uGwD=e z%1zBU?fdATz&tQg!gCDjRUt`|za~`wM5upBRQ@xuc~wFPWVgz1mahuzB&mQH175w( z2hOir_--RnU31qt?_6!-yUm2oaGMRVtK0bQcEVu7#;%`?5Dp$q01B}A) zz+EZdl@Uedq1w>QU3Q+0l{X<%^QxR5P!RgEP|@AW`AVd#Usduj6^0q-)T?T~t04@< zt6ILPBQ%CpJ>N7CI@2o8#j1&rC~uv=dHzKa9(nPOfFRYsl?{`m@i)@~LTip|YHlid GZT~0xVk}7j literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_555768.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_555768.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..595938e1490d114971ae8496b1e69bf1b5535498 GIT binary patch literal 13325 zcmeG?TWlNIbu)ZFhEF}Lr!}53sbvWm~IQKcD=yLZGqJS0isO{WRar9jHtxS#6gv2fBdJ!1(F8+=($4< zNy(JuY~26}I=1fId+xdCoO|xM=bn3p|6(xcFnGduJ{rEc2g81W8q!mcsyzQOj$!vP z3Zrm0b~*XRFXLicHYUHUz>_fLWhKC5Zq=CjvU*H&S%YIVK4O)~W!{vWR*d9l+ER^D zYyeko09S1QS8o8<LS}y{%Adx!C%qpgq~yarQ!6OK+=0rF{SiEq#`+8(q zez33PGa{2}nKy-hEPIr~2o!IEE6>6$*#uXSg7_RG8KimuV^x6)Pj{9Kip-p(8e7hmAgby4k*|hN z$+JhXJkn#@G}jOB$6VbLLIqY|*3C ziXnLdwoUQ#B>cd(&2f_}k{(gX=KEBvja7=gwtGt1Y<{`4yLGL0ljGRJh|FgNUD;r7 zV1Fv3u;lK`(p)9#L4BzntE9KUj+c*P%MgzZNAF~1?Dn->{TfN15)XDAcfaVN5tf7z z*u~0U3!U9B=%8X0?E(9G>Z}&Q*GtBhA?OtB%|I1ve@XA2thLUyPP^DDO8K$s5%LM} zj2M#d?9!iZj2P3^bS1ri_HM#yH*iwF7EaZhaN5D{c}<*ZHsQ4MwQ;Itw=i&-k#Hd%)cL zJ~R6F-~sL04h98yc3e<6Ctbd0C=P#yAVWz{yLue14vI0SZ&Xlv{lheqf}JOg8m65T;k48UlApYE+;QlF_09g^Gf5+(kKHM(CB4 zR)eZY0Kt@*vSZlk9u$lN6R8RK9OHflbT6)f!#g zVMYZtGn%BOWX2o=K0!6PPb8(9tP|hLf#fSmOofaEi6_1puVYluh}B7`;R6Eh6J(S7 z;IB@QdExD?6EvW|BuP|D5>+7q9h9#TW=wvS=eHMXZ+V$f-?(#tt{rv`j5__pUQe}i zi1zqveat|upB@{pg;s4!Wqz<%-=LQns~NxT0JsC}*6;N=2E1Usjg5^>2b!TrFr?5& zX~sjlnX@1tS|6Vk)+>np=P}G+nr;ubM{g_>wqWa1v-Q4V#t^GpHtz_Yd1|su2f~4v zZP~OvczUIfjJ3>ogPkh{6~T@~@h(nL_;Y=JRL2!nB?v=A8|~oDRXkBOcWJ@ISGVy* zTL@1WEs--ExsT7^#~b&}H*?22`Td=|u`?u36qJT_U*;FyZ<}e0cFgtjWe4K)Y}sp)jC_s zfm(~3`Mhkka!Ix3T?=UvxpCVTz5J02T;JE0uD;Fnz0DtXaM%6(VLzYe2YqOC)0(g* zQWUL^o&J%g`e}Ci$C~OD3o@wwnf~aFSZ!Q1nT^cjVsaTCJGeG5r#{}&w-M&x!i~ox zi`##QFYe+S&vNZo`Npffgpp@c$+sMjt;pHAZ zx*JYm*;x#vG_tzU5@;FDVw8$CzK2n2R`(wE*XeHnfTN={B90nxv{KF-;6&JlvZUfz z9ZSGieb}~cEIyJe%O%SN&5{bs)em#YK`Uz%K7wMvb&n)3WT<5??U(t}*1^g%cl{A! z9YJVHnGU9smFLoWxMC2`l1u(lzH*O_m4|I8f5Z}^o|7x1_7(Wyn?Ul(mFZ*CxpEI! zTfii)CD;Yi&Z>Rqz=BEEL=mi!!lmr(h()58d9&DB94UW@V;(z<$I<9ynW7|G=mwNF zz|PPF8U{2+Mn)MtATec3wTvDOv`nd%$%9jxfSFYRo>}aX@GQ_{4p>4JAu?nelvCDs zGy$|z^D;A%YSPM6X0#T}G+;^J@gx*@Q*ueV4f++E^eZ>%&y)HG707o4@&fsRfKh-Uk?HAfQ^*B2p!q_B22Ql zMhtSjHmRSrqO*lHL9Q8mA*(=MCrpCZ06$Kd&=)^;3tI%=yGqzRHlHnE3kUO9^Y88z zSuLw#HLQioe_!?IID8QYuR!6F&%{B39;_0p!+iKH8IJXXj7iv}ydS%T+Y6i|9FjV8 zk!=|tci$$Hw@8v6yG}#?!!tzsy(HvUNI2&w{G^}u_`D1`Nlrk-Fhur>H&HTq<@AwW z_>w#3nxZKZuB&8A`>)n z&2%%7n8>Gw?51x_xEN6S3u^Ls{o>fn)U!gjL*xZeS{8tGJ8wgpgWR1-)sRP}MIdX} zXeJf)9U}XNX_S3KRx(aAC`%z)3nGEO>I401F+L=BPm;G>ZZ~Tx-BX0C6)xr+{8>spL2{ReZ$UinsoZe zNs(8_R9!~x_*?j__&wxr&j^+>KhMC>C(n~bjqA0a_neo{~WD(MgqTQ zCr}YY1&ZUIph`R zBxF^|)RdqYpje$kFc@UcKt<3TaXZJZQ_f>kTb;-S zK`InwM{AC_y#r3S?-*2){r`l)sw%6}Vy;5OrT4NZ6_AY1DK>x0ar2W~k^ z$p1nXR@E!|K|u$pI{$5l+cgSZ1766_338OwRV8z)3_1fDboBv^Oi)fiMbS51h>Xia z(UW41&c2O#3t_dRq5~D3s5p;`E>v`*q6ZbdPzYL=&vDZU*Dv^?jFj3XvOOPDiO`Ih zO^nj5%r*p6!S)6(p41E6c|fco$bkzwxZ5R@;nxLIx}LPLtj^)AI%IT{#4;oE*8)<* z*P#usA4drU_+yF6*qW9{AVBq5eOWCJvrip)v2k$gjhn z@4vuqejCJA4`G(l*mlmmGkE4FkW{i)eX`}lExcxDNEx~rxjB1lMPI<_%b!}yxbpLH zYZpg!C3f3CQGclBH9JB?WH36EFqb?rlf0RXSs;NJbMoe0kW8_ZLN11^{6PCq%PY!* zXF^|3SSn&R-clJno6sAl4~7p$Q`k7YEuk-nZduZkoZ=d{y>WhI{tDM} z5wcC(MF+R-S{$HyGNHFd2JRi>6wO>|b6n9JI+lQBt!hcXl~bILRn7I!o#UF?zECYr za^(5A;ykC{x)C#Cx+&Zg8MyD7aV_a9I7J1guUOF;rfb4Aaa~F9WFo8nRBegW{N0WZ z_CDOZtgcO%%Ac4jc~j+_GB&wns^v7boT(Oa$4`pu`QrL`aYNX+Vk(cRc+<`~`;uut zr`gY$_Wyi$?VnwEVkqYgFOo=O1Wwk37x!?vH8iV}o1vOlVu)pL23JhA-= zA?BUy=6dF<9`E6*kMett{fIceLRi03_Z?@n1d@4q+j(Ng6GBWIw|wOMd&~UU$LF}3 zDz;E!g!7L?!=vmZ{3R6zR9i-hUwGc(~+)kpNm=^6wMaJZ10uEoF5E79G<)J zx1;lQpEf^kUTFD;!>e-HaUHKY9MVNPW0&VkRyCL*Pt3>L;zU_=@WI&Z*j!6|Yt6js zQ^RA!f-}DF#Ny5$RDWL0o$rdj(GBB!a3IhbGI7P#bKm&fw5r1l#?X!F$?#+p|Ms6m zTOORAJssjoXSHYoVe_}#KQwbF>imnN!M7kG`YH`jRq+l^QIJ4o;g?$DbV<16iME8-mPG3w zG|x6aI5vCiN7|~Vnt}+$Yf53MA?cqWOyd8-1W_VF00{L0WC#P{_5=ZVq*y`)I+A7K z>aG&4Zgn?SQ1*03?Y!}yl|LXpC*nKW6NMFtqDoZempz|N3C2 z{#ze9s{eIwi>#wr{%_67PMLZD0O<#vpXgY5{u?OoV=gQJ$Dw$|!~?Q`JfNUt0cF-_ zdkFB?KjK9|4oNTh$BIYj_=P8++GLE3RYKAUsEMPH1Zb4{D~;0lQ(WQlp<>l+`bR1# z1vir8hbD03t5A4AsjmR%fKHsJCV*#gM|xTJ3_;@RZ*WoGlokSdN*5qrK|_$7N$Cg0 z2=HvGAL1+NC-2C3-T5disO62$-$$k)->Lt-|z2Gh6d3%+`Wt z%VH4fT086tF`Q>E1Bm$=D*91z2^C%_;O7V9sE%TOLGP!1e#d~*M?325a;6V;WX232 zbcJ|Nct%0}*%zo^L_*YiVZ_=nFq>XIFcJg8EFSv-h9bpmyav0HaUqxzJ|j~eG1y@c zHwf0nAl7bx59zdT!tM7ZLn;zFlW~;~)lgiuK198VVCa{T839H2xhD2XhFTU;1#9B1 zX&9)6tBCKvp$K9tMsu)jmB4oH=M~$6t)T;v1JNeU$%Cz5S_lMV2gY#dgl^TG+X) zX~{@3604fWAJzg?*OHNB9aNHJzUqq-&PgvD26G7Xf2xD|eE6>V`(`@lUp6$nz*K_~ zdBG0xvIh4wbV_t|oH}*NZWV2nc^d$NDS4T!al_AbCgA50f<~$_D6@`!{o{wL7~Iz& z6~&;FkwK12kdL_h;^!!zpPn3N#9J1FJ{$;g$aV;d>t3%rxebsZ3i_n+JBGY&>O)Kv zNxGLpqGG~LA7kDHJTO(v*AC^X9LMorV8l-`!_Tm?pJMud$F{BD>boU(N>*h)9Isie z!Ql4(7=6PWP6x`e$tL2h3uDo^AWCkjh{S-|4BHK8W%vUk{3IcmUfAg0U0Wvfcm zfQStC4r;2!rUol0iB?7@9D!)jYA9kD2pVMJZMbF_r%Al{;6SZ$UlJ&nht7@GDI(j_3VyT!xv* MxY2e;FRJ^008ZizKmY&$ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_634902.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_634902.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fae6b135f80e2baca47072da85611e2206181211 GIT binary patch literal 11124 zcmdryZE#apcJJx)>03{dKjcpt8ykVKjcp7hU=U!ufU#q2hsAhMgr02V56L{qd^k@< z>Ci!<+p#>;4sv!|L}q6q3NtZjX5wvs)Xkr=*-ooxc?RW`9Z16;|8qjx+1csWxlf-q zGK7S5X8Sz4@7{CnJ?GqW&%O7Y^ZXl?N{+$v+2UWH|GO#-`wBJEPeH2k?4NNAyNOX4 zg^y!{$u~ZT^KHR|a8QIJm|$EyAsLiRNC%}j=D{RU#-Ln=W&fx`O76f$3)3=Y-;`nlxN-xy${}@ZfEyMVfiR_Z zNb-;6u^PZC`I^?DqDc7Z;HS?&V%R_#u89ZLnbj(|h7D>2m_w79r}heNE?Pko>NX2F zqG!_Ubh~WhoOr+m6;67vtLM;h>v695XrHyOqqnQ0+uC#FNMA>PpS8cEyRYXYCj-pe zmXml`|ON z6zMkLROz~}!gVr-qbMtVy2gl@5=1e>xAMF7rv$m9n96C;m`?&yRx;W&_vfUnWRy9$ z94Q$iQ;aY(`l`XGP zx-zr^IWpG0CK=5ykx{^GTBpIXjGoqMke9G}UGF?=OG#dfcDKGna^a64`RlQZZ6Vxc z$&vnLd#9eMXNXtWJKJ+;zv`~2-*8uKe@zzGz|^Ns?(CV+_!5bKs1-J@ofSG^{y?lS zpX9H}3Yj8W$ZS(!zHLmQ=Z~&#VYZ^3ebvlt<99f^v6`4I$rfWyV)=VuN6u;~*PeZ- z#Z)|Z>Sph?b!P^n&YUJqS8MM{)2A+;d)NI?9?7#a%W22ex?E0Iu&c1cz$RfIQux>w z45I}U;lh)5I71+~ulYOOM=d#;dDWfQyw*l@&ia3JTWRxVm8or%Z z8||QeiluA>i1^xasJ`k-tv_wDYt(ji``R@!%e!gh+^it4as|~!;Kx7U28$+a zo-t1B_MUgpDfsGsGftyvuKW)?Y2e?R2~OZ0=j4;LdxW-4Se+Cn8h6_$PGX~Nt`P@d z-JEn7K;D*SPVDizX$L311auwdU7WCIpr6Bk$_Wqk^l`$zqureR&_K$~?&H*J{cTR- zaaGFTMTygTytI>YSlv@z z>p7d(ex4(9AStT3pcDgLPA|j~F;b^sms6B9nRbd!7LcNn=6>0lqCvxoV>G&^XelaC zAV4F(h(_)YC!TatR+tfbC8Z@$6%FPz*(qDk+s21E36jI(;E3ehpeFNxQoOJbmpK7F z!Aa0k0y!}~VYPcWDc?jDdMrhOhm55#mz&a+9h?9fg3CKNvEA*lUWCPSy5J?C$5KQu zbNFRWksO|+C2+$C_YR);q|HflG9JrUK~nxLbU8*i-0j%~b2WXx=l27Rm)!K2XVPYO zG@iHF$86s7Zr66(h{NS=^w9Q3uVZ4e5n7EYb$P)SJ;QE#qG9ri72sB|2d~>@wY$M8 z+FDxoG&KW;Q>D6(IcS$-oIVZ`q6P9Sz=HX)Zw;7AGjk$%B3u<$miap#X>~VMv#PKy zt~L3OKGKjgzMwDM8`o6&f08IPMJ%&!f6tc+LxR+WPDd)(f<~5XoZqrwWt+QLvMV4- z=nKQf*yf$AaVM+axj@7Y^|QPCS$%(ycxouT**V)8J{C2yWzBIzOF;IB)ZZ+cEee-J zoUE~NzKS)r#L3qAD=c{cNbx)>*@7mPY?|M_;AUG+u;hshk2coW#_HP^ierZd*gXTR zejq?Rk*PyP!OCAW-`X>`C(;|0#!H)6S(E=@LZu1S2G979d?~F=Xmx*n^5&`8Q{l5w zCtJQNUa*_h?hcBdXmz2pthPEJPN?-EC#$aXcmA8A@R81N(>!YqTf&#(y1IZgK^BMi zv*cC~L|6Pi89p2?nr)06iMD=BMt9xrVRbFgXG6>*(D$j}snFZub03hA!;zx7v1kul z@g`gF=EBG#^|@tngl+GR_1f5W8(UyYlvG9a-wJVc!{dggJLf*peOhwA95Yj zj-O%oonZ}U62{61nwzTraec#`gZJoVem5bdDN!NKsR@YLK^R#gK^eh%`8kk1f7B2*kc`2K}RP59E>@z~Zk*|IlTa;*ci+F_Z(b^k8JWz|n#H1c_4zQnd)V#!PV`2RrTS0$LCFib}*a~GrJ`|r-_=5bcv zJU`0n4#2t?O5WcQroy{sdm_Wp{*QM=Pu+I2hMn01s4%iAY0$T~&1=|A`&r}u#kvPI zOKGb{Szv0b5^Av3GeDHjT4qIj~ zM#$NBBZs5aadksLm{6)h;^0x(SHpA4$P`;#&noK!_zSglZqe}Bg~g&@jm5govh8PC z@+_Q(;X*Xos%LP3W3&YdP6+oHCxlxJE^*$Iu)2qbJq|B;680%;5(M=L+vQY)Q#f5- z2R-hvU35^K)ajZWhZ~2J+1;+usS%sk!AWR`cZzn|bF&(yz>Cl8U4v`my8gQ1dcjSM z5mDF_?vpVxkB-4ckx=WNA@Fc-=x-F_y(hZlTU%hFrs`^ zsbo%7h%#n3^o;mM0fXo3H-z$9N}q~RQ6l7#`qYe?k&l(Y?T6$L=g1+aBpJ9g2d)ej zj18sw6FG1^Up|u&Qs<1UZj z7zVl!=9^Uxhdfb64mM9@Gc34>WoW->PC{3KmrT!u$!42{j4)Wh=eLx+B_-vQdJ4Rm zeU$IZ2$RVhyI?o?G`xj^F6b&o>ZxQjl#+o=e>(3un*L&w769H1wzf|?w}@k4_bMa{ z#HuQn23qP=OYI_{rAf6k5XVYw181(`jLRq`pAGjU}jo z+B_ZyWpcXeO%Ud~Oxa`KG%@Azn$9^=UHCYeM!SSoK`~vPjlnWK_wF;Lg1Z$ABI@bV z9@qF4lV@sj(oK8QC=>K*-q|)?n!0^xcV>M1K2s7ny%jDglg&k$;GRmN&%^D5L^OG9 z6AlyQy)x;jH;uTxrfD(Xak}ZKYjVoV*I{tu%Lg{Bm|!9+2TmUx6qe1^x|t$K7_@xThT| zN&Iy*7Tv$U^JBj#yj40^8hJa~KX1NcjT`qZOf3$x#V6=4z_0ulBr#V08wwhInl-B+ zx;<~3gwzDDUvLQAF0XTB$~^@RffF@j{nwYlMfAkB8aLrpkD!Nm|614$w z;9}_F+@&WPV@$I-M%ZI@_BfQ9&514MkEC~`EU`JD3=M}z655i7S`(`^MMwxEBQ{oB z2eBks3Smf9&4zx;rk~9cWJ|3}MNVq9P{R7c=dPL#e1pDAY51f_uWkTe3OXlA!5#5ryIw(k}%hNT0RWE5a2Zaeju`0vF%CD7}p)zQTZe0*8ZGGc*?;`H+{H@}I z|L{{08LDALMZhkEhg!h5AB@O>VF62P0We;;H*O~yASsXd_tdy68OY9sax15>mUB&7eXK) zyeGPwa+`dTmy8iGVoHqM1AY{e03M(8dyOKz8D|ZyNhw2Q64DTsjOO`l2Isj1o&A(_ z2sr!X{5%OCp2HpKW!*CZ?)wKgFK2x^iLbK2Nf9sJ;*&?<5VPl36fXA8i(Q<^o+vK z^Pk{w`mrZ!t-o_siPg5SqD}seKwGFSyz2?6_aA#osAd#FMW`d(206)H3vJ89LBHjx zTr;yRxDEV>(Wqt4JwLoG-{(KH0aRrSSB3i`)p1p=|HzZ_>RUr|L;lXd3071BEH^i@ zqAGt^s5V^pM5*$3JS7Wnmd%zeRWvOWE|UlR$DWV|ROHCg6{(HZEz~X(maHVBk*)Li z?M8sgEm=v{LDQ1VZ+%b_vpJSk!+8YyQFSnHSXQ&jNtcB%_6farl%t1)p~~B3+|VUIFa)%B0gZLZfX&qhpH`jyk=3 zg3jZ0T%M$P-#aCLa< z7Bk1hO180v`c)Z^kYl7VTyU#wu8h^zt}1w#62bJds#O)=RAairRSjR&Vrtzg$yarl t*09P;v1-61it9t~4LvJBm}i3m3|Dq!BoF~!wY?Z*ZgVnDUhu|4I_XB0&>0 z=_Q79Uvik_+kzS4u!zJkfmb{u8J5gQhovOpA*XZ&GQ}?~^oXYP#de`ai?)G_w}DHx zflIf6%gV#YRkZwl(c`=r!wNZ3{G-jZ!b41%3o;hJv~nA`Y8$xPBlT>98y8rCFs<=O z%8%u-TEMFKn$Dx9DR}kp8p@9tH`2!U#lxD-)iS+L3~L31N4q&s-5t_ld5$O4;}CGN z-r3-!-{Pa=={a zybKsld9G)m3u>KLy3TO&)4e_C2fI55x(ZY-xv}7dM+y}oyf3Q3>K2J0*j4VC5)^91 zFHJtq<3rXr^b~CNtb9XH;bzY&Rz?f&0^d+IYYNn|BwNX9?}`cHE{MQZu~eZR+QAZl zQx@8QQy1!?8degj^<6IEDk{#PP-8`GC8oN`+xFf1(}L1btnDsnErcbtvATi;uSjZR zRVBDmNv&)J#;~>1BLu;acL+8QDUn7}kh%D!Nw(r~etob!Av1i2l>R`c{4 z*d44r2d9ZVrl`5CWAp{Rg&IpFC^w5ncIOtV4R^s-Y0;u}oVs}r7Ha=NbOf)^>3FSA*rIvfst=cNE~h^Md$_DmXydi`m+}d( zli|nr3GI2G&{5(Oj@B6O*5%%cj`X4Ky^L_XD78@(^Eyv2+eyMM{ zLpWqT3=TP-;|TCMu1ypS8A4Xkj4LqBiT%MF9wrapIN=}}JXwc~69m1SVwUkwFs>Qr zB+ZGueizM2T#U;%;Q_3lla2!@*mi&u2ZDaa!-?mClE?8DC+r;@;K-|-@N{oKC+t7h z!zoS==7ZpVPV;g!%t->Cn_iEPllX3UTy%g_F#e#6x#OfKXE=Ed=Vbhx?sVtrpL97d zbY06+;&g!^Gf8`#{<)xY%oTLs;HVNvo@yy5&p=<&3$w(GG&_)ZG{__*`7 zlkq^h0-X479yU4Sn!qsrE5-fX{s1S$g)}hfbKb&3{Q05&TsHwT%}Fp}=QO9xGvmcm z=72fwVxB3c4KxCr;C3s#`#EKvhZCshrt7%{#vsKL_tSI^ghvG40xjD9gy)ZgE%Sb6 zIxy>Uds=R|+|#b$4Zm-nYr^9Twgeb=OVBej+XAhY{AvWj)C1#wW~O=ejuYTcFt?!J z2O0sRYi~c&+O{7soI2lq+QaxfUZxks#OojUDcD1p__hq!T;h$v9wnXQGKivC`(wL?6v1^GMWNJZF%ku7&6CJpS zsEZLv)?kiX)4L8K>mg(~l#-=S51_*X$S@F*WewJ-?3vNL*uBslKcBRs>H`^LTSWfU zU|Q^0=!hRp%2Cz+vB!Z60QB{z=I=8M-qKWX{M?b$y>vxvIbLJ zw(xfN;y22=tWF;rMY_7EcwMLehs%pY3q$eI@-=nj2knXDB)%SvOUcp+3J7{B_^TJUY>vzU)FLTxri7IzER?b3crS z-tr+CKNy!TG$+n4+aLDA%$DaANG+ryAY;iReVk9q(S4Nj81nml7=x zp{BQ2q^PO;tH7$Je@$}*X|AjeU0WM+p&?gBGZqnIB{nWJ#_JN&rG3e%Us0(uDGO@u zT_TQM=?}I?%quv=2 zWGO@J%!l?^_d-?tE}|ST8FNLvb73N)%$n_q`cKG2BQiH7oyh!FM7gdr$Mqj}#)I+h zrK-fccpK>2Q~K1IFYT%BFRIerL+Hp5(hWt#1-mjOMn1PM4=gvLT}M+_(5};8wXIUU zYt&^#U0&@UT@DA z*$!-e;lkngSfXReiwsTRYV^j~{KDRZKH+{?f%J}u^ci@j!_encJos@Rwj0>%y~&%X zx)oercgAogB76Ee;~HBretC1n^2K!e;wU;kim1^nWlg9)>01^hhmh@1N`h?1u{E3| zUw}_0IPx}g9DjxQxY6X#R5*eHo*?m1@;1oau$?R?P&%Ob&wXkv~G$ys?o&k|Eu zq%F?~^!~u~f#m~pNa5Sfio=5Q#NU$J^ac7MCGN?JKd`W(a`b4Dzw;E}TdxmHELpD4 z2)>0t4t2Ze;J7Wiwo?=eMfkGl($Lt390 z)5z6X6|2Sfgn6mS*YYMo|F4rxz)BebOZbdI9BJ}Tx!(cU;ti>el?LkhRe=>|YFQmE zV&P7^aKATI_-&#<4!$hhzm@hWXpJNQD#okJ2xzfS3oS{$CGp9iCC#^_K9ZJ&C{_af z{SNTu9d#5wF|pe`$SaMWUMUxyZX& zrRM&T=#jYafqY5+L}^cp>}jPvOW6|p5mg^PpH-UTmM2PET4YNrZR?gDX_4`%l!`U~ zV%IPCKHR%1ZOLkPtZD0zwk|17+lle-O=($0R%uLE?0%x$ofhp* zD|fFe)c2dC%^5{y_;j|U{*B%m?fYnMt+oZ#wk(sG+5_RUUn}*oyNT{6mOYT;MFO$& z3m4MbS|qcr$r?a~iJM77a_Z4^+HnLm9$l3kTQ}Auc0MuIN5`_onyEI?_Qcc>btg}6 zhzYUkTM1#Tjk=OeDZ!JbqYwL5$Z+>>m6yV2zG*)Z8(th;7|ob=twAzsYF#xQSQ!ZS zJQGo|dL*)d5HT{=78gR=IEovRxFT^bc`Idj6#Ci*!Tf!7R2_4z$|{mK0iA1QKm*&< zzpko}kbK@k{k<}2&&uSv#!r>ys~aXtEWrj<@zlO2+4E0b|9tU_iy8a5tiCF1ti`3y z^1_b$)DtRwIJfuRV2qm&c6!(sy!Zo@i^L=m0w0D0z7QD_goGgxEeMHAt_#72w?6!_ zCWJ6j_^Ie|J`4y+-Y`bMifJ)E{PCsO6rL?hKzyR_aftZv zh!WJX^ehcEAeB;F^_ zX?1UUdD;xR+RSfCoOXW10+E{ zIB}>sDdP#ud4qx6DUM^S+)+M2qeSAx<(i7-cS(@ZAi2;+!bk$wZ}WwDMR@DGnx9nGOCPq#oH6@$&OU}s_ay_ z^O-_>e@}D|ob*%4&L#iy_^RUV@ab)!YHQpc?@!ca)D7XY>pSWmj4X|WyCav7s1jK2 zYC$4<_+qRf-ngz(hr6Cp=EdrT>L)d=Df23IGJJlWGUB2{mWzpoWMisfRn}RQWGd0L zOg?M@sG_qd$rflqlI5naD$_2{s(QSPKz~*pEE`wV?SE;e$>6H`;44gJgvt=^;iGH_ zu<;J)={bA$tb^ihl^Fv7r_E&n&0hH7Y7TDuIa#5`;Qw6*_Iqb`(l^22bCJPrixW;w z2KnoqK+to0mf>S(W*Vazj24ag{odRLz$OU4dGKcMobY?;M+7fSA(fc%({ox8OjYINs_-IRR2b(|DCA*4`Sy!DZN+uUgd^hkR(r%8wz}HIT9V&Q0CeyqSpSv zyX4(a=h_-Vr@zle*$r*3ts~%?7tC#g%C+?cydl>%=IQ;;#J67v34yA>>$ZSs-Jf;r z%i8L*6+5yFww>a$YSJ@!3zLx>x2!+*azTD>904Y+fMt%I9Vz9}P2mGQdx z+yi!rO;_$g#=RSI9-$y8YuxmpdZ`-e8a9+XOod^F1@(rSZ)ym=c|*%rb%aL0LGe|X tfpLSEV#7#Cl=nv78+l>Em>0tWf>i%bHb9bwZ)XLB)|S!O-&68B|6e?iVRZli literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_711258.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_711258.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da6b2eaaeb0f25c371bd6d97e7bca184058f1bd GIT binary patch literal 7884 zcmds5YfKzjcCPB`cXjm>8jLXyKY;dwHjNE4V~^L^U;_rcUKr1e*V^8Os)mN9o2f25 z;M|I8Ruh??jRHg)(Ic(2?P#RI_DD;lC}EX9D8G>NhcYDFs!SwC>-^wfXULD$DvEM$ zRd>^7n!z(lA|==6_BrRCd+xcfbH2L&z%W(<%HIi3$BdN(@jJ}qi#tc`J|ziaiC_tq zj1u>>U-BNQ^14ato}MH)a$L2O@2dT>I!-@sdsT)xmO4PLKR|9cKyExhZaP42K0t0M zEe|_c>&N{IBTfuXJ+^>uaJ?dnu4*KZH>zH_r$`pZR|>sp|t*yyuMh2%K59PiX|tS2!hwKWUP25#{FXv z1k$4_JPNCO2-l?GYi-e(-_om~Ux(^YEw&)m$4saSmEn1cwWv(2N5@e81k6V~j%tvH zr5*yU@+VL&I;PI>Bsz)Qs_YaZfwtn_xCMI9d0K2hI%u=yd1zzuwAhHeydBkmbv$lU zo`PB774*fds0K^(<7-kxfIV~G20DsN1-c1Y^0avRTl9QIdS?z#Shv9-ugGc5bQf?81%a^Ym{;|-=L`WEm$IgXDxtQQj@FRYKo1F3k=ZCRF{A^Hw zbCn3k`N`&~*&xV45E0_B;7A;>2S@9NilV2(@`sRe`&MN_b2;P^_#Y8scYL=o-!>uK0E%o@3C*W zWBJ3)-Hz@D{rk?iG=1dva(K+Kb<4t*G zrdGOGPyI}+I0Qm8p8i8VcjHO z{^7>pPd0)hFSzhVFf0#_%9ls~Q9riV$(Ej#E6>Z!`IVEaKYs37?_95t-@dak6x{eJ z`yxEL@zYUxXiR>4?4R_JT{zOj4IibLBC)9`Tp@*y!~~9ya-jzttLR7Kv2lf*;okx^ z{t{G*U5hCBK%iJhrh}R&!C%K(ow(|PY6J&2TyVJa=6-ko6SzmNNA5?SB?1xSxP4!_ z50UW`*C(D&+@cwo0RUFPwbS(yKO+%YN+0P;Nk?=uIu$z7Ne@`jVXzifB?$wzMG~q- zq7-}kDi2?u-`k>PpGERdPYbTRFM=me zjhE%Az2E$Pti&Nb(r^MonZ^s!mAC0$7m_F80;9?_2UEHO=r4rCUcLPg#l zvZ)Z10Klz77Nm#$CdRR7Ut!Jp@eXJ&dGnPQxU^Vg6w^VvrNrkpmf#30V=%T7PN&Gq`ieGD!I7fd~ONjAvkjc>Hv+ zgYR2}kbt@h)JX2z6-X`V;={Ci8LGdamWU*o)D_`i3?uiIBwQ?(#3vwstiluws~b?E zl~1P_wUiIxjg36;=@or{*H8yQCKW3hV;E*7_;b)dfgP)pB#FNYgkruN4NVTSp(_t- zLQ_-GSqx_YQZ!$V#z#WY#1$a4_5|*G6sq4#M5;V({KEKa`U$;qW3_Yb$Zu|MHucGs z{roj(*YGe7kUsP4!a^SdaCj^<#VH0>n8iPUjREAE!T^Q!x5Y{h%Qv!MGv{3X7IQ}hRk$a;x}#?Rk<{I0A&E_KbdEwrWD zGqihw{p1H5`l`7fWN7>Ro5?p*)UPa`TDIs*>G1OCs_RK~Pwybo$u)euEx=#LPV*T!mD}x!c zW1%H!n=@r80(?EZ&*qiEmv+}L8h-0;0tC2RvC{FRdehyus@t?*l5YMh-4AHPQ<-X8 zydw2wToqDprmQB-Zj_yt`ZEr%)csQLS~w}|D>7zF#!Tn`yEA4xNaoE+^MdQ=AU>3& z=dDTWLd(w?F!A3`sk2> zHDJ`C_QD?EUZ(HW%e;pX)QD<+mbYg?mQol60-zBGOxAQCG)1dAPqRoCja~w~j1sKn zFG$fZk&UIluszG+H}T+gV$^l| zKYt=%w)SWFwLf^aj{i5ab-u=IU3j*LfaLEU@aa{+u_J+q@5R+EsCMe00w6HWhTt*E z$Z`)NBb=(J&;pkbf+L{>7i{rSJMchO)Q=y8eqa*clO3FMcGjWdCvX*pYRE@v$#ED< zI-x4I7l5%W_}glFG$g<`*WO62K2pu6#j*=N1AiA=Fu>CV8&gvpKZFHFfr};LJRkxS z&n2d#LV{OQC(fE${tRL@Ywj%rlJdSRt{5RlPfZJoO#mnp&E@1h&l3#ve>Co8Adt2RR1785Sv(EsQftdwmrZU5VYE(=@Cg>MXMCTt#QKuy6u zoq9C}e0KHd<`AZayPKo%vO5hgbc#7o@%SrNp~fSEdM^?JH#5cO0O*i<#uW-4CKUZ} zJRa3906uRDtsVQ|Xgtb3Av8#R6*rhKPsZ8lD0hYb8L;?+B5?`oEJc#!?+DvJ6U@I5 z)mbBFna!Y@^0E5%8`TO-8dmjX#ZF{cMZ2>=}(eOO$z+ zE-zlrp40ed!d-@sTGp@e=Lx5K9wkw>MdJg6-8p|hc|Y5#@fULZ*B4cOmm+kwyf5%D zakM&9dm>9;AstyeQC^j)JO+<_UweShm~4;GM<`1Hy=}42S0}5p2CQc!;2vSlggG-& zUYWIE%1Xc`z@F!<+1*(OmOBZzH|xTbo1h)L*!J##jv$$@%|j&V`L`(@;iv-y`w^`! G?tcLpqWDVy literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_816058.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_816058.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fabd41ad1ff4a0f9d47049bf76be5886dbc13786 GIT binary patch literal 7990 zcmdrxTWlNGl{1_nIh^5;97&c%Sr6*r2W`iaY(-Dkwqjec6-kaGdm9@CYREMWS}N>QL{oezGxf5Cobu|JwfX^2@Zz@YxoA6@sOC>H3RJA4eK zP?a~pZn1mCxpVJ1_dM>oXYM)o>fd;tC7=v>e>>fAj3EAiJB3o2Z@m5~Nf1v6kr2sw z;(qQ;-Y3nzZGpN^lY~UhntTeRxwmMO=vnt(819MG0dV>NxV=1liV+>3($Dh(?mJzC z^(Xp7r$o&9%1CDpfV&QWvj@Pr1K@7S1v6(w9v(_!rz%9xr}q0?(SF`fiTiFFA-Rj1 z^UGvV*u|=j8r)1ok>q(P{E;LYuE}U*cIiP_kqqZM!?&)!7aA!v-kYp1nVTKj6g;m# zap1L5V8wD})>dd)Z;^bS$E)>+jIov^#nRG>)KC65K0kL?&~Sq<|OjV_2?vW zKY~?8t)LTr^SCy-p8}k34^AlJf{|*AvOEnc8C^2Za;rw7Q&=K8F4AAvpJRW6qMkm4 z%7%hvsLr5zu;(TvUvPJjq+=2@!6*f^2coGXoNQR4){XVPE@VFjXFy-JY)wk*P{z{p|{Ol zq#NumfR}ZZcv%m+0R5s4YH!4M*oYGO<-TkJ^_VNyz30pBE1@dCPQ4~oAL=e$-DM)^ z9F_5gxQpY=E^a-5$1FjcK`(@3bA~;tOiOZ*lyT+|bQ(5g-e4Ey=mR;t5SkGUdOjK! z4M$iGM;=IkjT&|cjTg=v_LvftrC`57erQlvZ;ct$*v(Ocx^?%C!Ct)^8oPdb`1Lbva$|rWwHS>VELbdWaf5{^7TOfR4GI7T^RNJOJPZM!b|d%hFl-Oo;ps5w z=#tWI)*S{5?ZS93mKu_!n3C6*w*XjONS5+=!-YM}9BS|nb0meiV1_1-JOHHQZmb{d z75L%hFL&EMj>>bf#qgxmHXWXv3oFyn$XnqDQbcKs$&+o0w6NF)y|%nD74YlWR8(F# zx40YvcnG{xiAKOA;G;bqUF~o80LI|+!{;P9BF)P~prsR@*uTJXYs8ulJ=~dQ0!t zvA#{FX2*Z*+0my**K0Rs^rr3?O?p%RmjB|4CoME?HlNpp^KnP!SRD+^biT9MIr>i@ zto}Iu#;^-TQCKTkZFUYp)x8?~$7IV;jfYLfl> zkrNxODd9J#QgeE6KtD3D!j^}}*-T-4Rf^MtK-ph1wgmQ)b+4X`pG>qQS-qlpqbfDF z&Gu?ncglma4#HKDxT!Pspk(5n&NjTFNLOQq^(KNkd+e3n=Bmq5gq?roCYZ{!;7{IJ zA9(uVhF`D4wp(#nAoA(>#y!2ROBcFU96P+%wBzk$_1gV)G1aoocdgjULF2XzaJ{zP ztJj^?g|h{YxAgi8sb#(HiY{FF@3333edqM0J-(H(6F_FBmHm!3Fcyaba z-qZkI(*QdRyUN6wL}lWUE;Oy**%8b{8B<*YbRr9Zdp9ePvHa^~J?$&+X+Re1jpB&MBLG31Xyqdsz z@5Zv;(xcb(tnlgD*7fC7uYU5PUVAa_&HN4QiuYx4S?jUu-g_~*&A+1!?$D0UZaluR zGM2ck^UdpnDgW>3{@3s+69)NE?t!=5i>n$8VXd854q^ZUTj}CF>@Aa}=ZHo?nWw%_ zp?*RjVio}V%Q(aY)upm3r@9eYZiGnsfk@@cW09>K&D<&2glt2^CnVw_Kf@;y_gpQ& zoI`NFFPv-vYuC;qTB8uB+9M|sEz%-UOTZg$6KV;3=9VsUp#3aK<_NllZVJJkGJHP$6NBt<1P8NL*$poTk^}}E&1i~ zmVDt5`Q`DJ{PK89etEnl-;Q|1AqQfR2YH|8Z(&FR31dW%8QuLW6d8Y|kAfeVSy}+H zw7|krObJX#fr~#1;1ntloeE6Ph@un;#AY5zGWN>d093iO89Ozk#P-;9cu_Ja$ovd? z1~Lf44yi9Bb1Wni%5rFa2GW8d7z;_DncM~|M#`U}cf?14x_=3Noiy1z^6PB*`~1sXM&5IKLdqB`oq!0Hneb`z@RS zG$LC~1e%_G{Pbg;hE!ywGf|y{P+9fFzvfS#U#r)K(wqn92ubRh>#6H&u4$8Q+T@zj z+>s<1zmy#NnhR{wBb&`5TTpWWGnerLc9W}5`y1m!8Q7w4uYZs{x6zlHOvxLU_1eB~ z{g-wBQum(_6=MSP{lwu%(57&v9 z!c~OoLax~|2XaJ+j*B2qbdG~$2%dPJj7pYBFB3=}5Tfgsq^WP-3MwbEs{0Kr#6n=n z*$01^P!6?^`v=*lJk|H?0~?%Myk7gz%bD+T~_^3V^j!B`8 zAhkMIPL!qrcqx)Y@Kg{@M4=i%B2t56EN0P>wP5i?6J zg?9$+@;~6Qb~tg15i>B!V;JC2q)03(2YJH@R{%@%N=&{7@LV8tL~u}y8BY8v2F1P6 z6>dnDl*J{*a4Qh$=kwPomfkU(DDDZDxkxFPme~_7@&5w;6`t6Cn&C3z@oDd7-NcCt zI^7KMvMbS*?9K?4+R#geU*+Ol;(D@cy=$ZUMb|d-jyCX;_2T1mVk$YiG4M2+n%ZW6 zq+LA#%KMW|$+7j8E&hZyn5l1hHvV*6dv|3-r)xpv@iv`q(uNZ!k|#56Uc3HMsD4tn zR`+#7`-|#r;fgkt5vp)iqRa65iH(ylPHZy+mL{|7r&HwTZ2)BlEKT-73!0=(e^t8~ zmbUq+G79~V+8|X0r|$o4p&QCJ-}ybE4B@WSM$G^mbc~LU4h{|mJvrBP+=X2ixb~%H~d4A|Y zbYA=oVG7UrD06WkDlW}Sm*k(rIB-fL_E)H~6iJeQAl&~>@c%*7!Oeh;`KmKK82eGUINB{r; literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_824557.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_824557.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b65f944324202f789a09b31cfabcbd4573ed8cb2 GIT binary patch literal 11368 zcmdrxYfxKPdiU!6ehZAjARYz_Y%mXlo!E{s#u#jaFCUWf=D>FM2C4b8L&${i8>Hw{kTX)KKGt*(FKk(XZcXryQ-?@4k zL`Zhsowk?Z+;hIiIp6utcRqdJcmG+4Ca)7INR~9Bl+2koMKas&Cr=2=<0)ngE)i4A2)R*-qmU`injzQzg)i?v z!I#C8dsTgERgY3vXKmr zDOtx?#OZDwAExFwcntqLa`3bEluKtDC6JE8DQu9u2t#M!rclBMqWk!M!83;M+PcS` z9*=RJK0V|*Y1=H@LavSb+)6I-E0L>w3%NF)v}g-`y&AbHwv?-gNn6(jYy88=xRc4J z`AoT-@RTzZOc`g%JDGCM=CMzy1SmIRb|$c#O;>@>TYq~v-?G65RWIG1Y|pNE*~;(; zX06;X?-qJ^AW2y^x%KQtOj;hM>i1a*Jfw*+CA={B7w}ThCVw3OiZ!<@>-4wqV#!|` z^QJz@^o>5LdQ-k#-^f?9sV_6X%=ayP*=8KSt1nwOAM?M(mu)WBtMM6I`LgOwtaiPH zNYcLl)yTD#FRR%^uJso9s^!YzzdXCaS8+bevnT5gs&e?7by`c_WO95L_@dp+9(<3v zH?XF>jHh-}-?whIjIZH*=GNY+b`u@0pEs?;Y~OdCFiX2>9E_PsR$#qmcdgN)X0BuTJp=tLd4c7(_4Kj)zRqq|);5q#6#7`z%NYYJ za@j6V*c`0LF>13?E>=c6-B$X#g|bhw(ge;zJ6TOzYukx-%gOdPlQXdzmz%azHj8uG zZMkH14~??gEJ$+JY*2CmT^1Ey5-+4m!mcG}Qm4nsc`|_HEUDG6S(0dclcqt=OG*f*%a4Qju}^5{ucNKaaZT+kh7 z#q@ZBgg@kTv3!@^!Sb%{h5ve%=Y+Sjo+EKxwN9~8prvh7HmjTET>#vLdRBC8uLU!T zuGL#GFCUcWsss5<-|hMBK;2a*J?@&a4%zBPtwZBh_o&mc$2wwjxa(Z>P@UU0IaLRv zy5!c}U^lK|Cp}p^b=?AR3!I(X=>TEi1ot;I?c3J^7*?4iAGgsC+XUSPl3_i%xZ~aw6ltsCYr}$hsg#`%fU_38d=u z@v#cGB*Lmwr5dla1O6CE1&il_}U7Y=h3^oPDb8m&3kc1P&m*iCB3PGB!o7M%|BkqgPQ&N33fQwG1NTAX2^U<3E+iW+i?}Kp(gi z;(c5nqCzcq<>6suYKrDB7#>|-$VUe{ANQaGZ=sUGzm>e5q{ru07prJSx@M$kjw@8( zVvGMK7D#1l3PNuT>Vnq!yionTB21!uGt!yEtw^`WCw>Z7taGk2Xbt5r8O-5cWT^E? z<4RSaEUiB^($z)EBK=>~N6vp*6FYJSHJm}ZGf3HsRk4AxDq8V~j8>tBqp|h@)G&Z_ z14wxeNH-d(bfMy}$b~AjzY7_=km_VkOWsJ7vE;ZeKWRF(p#w;_J5msBdZdlEp!&82 zC#oMn+H;9xR_8iq#S2-+n1BDN5_T9k8>Bxj4fThr=HCsULnVz-UP7kN`^d^MEZG4&j@cB88R2` zAjE_orANi^a$$4x;E>qG>moot6@q$Y*9QdCQlq@&rPl!KBm+?6*=itsgXL?OJoJ|fcZp}@k zd1*~@3Tke$cplP?15EB=_--5@aZk>YT~bQAm*M9+W50fwOLCGY%Fnf{ ztQB&UYO=67602d}m6Wyczj0T5hm7I-jW|JKl$^~ZY)tGjy3=k3r!*f(JX%HsbXtxs#iau}tw-lJrWR+k_}d4ghC~vVlLc4BHH2TzDHo+UVz#G$eEC=x|5QLYaW7 zVPEr1Ve$uvcBc6cA2KC?Gu4o@nXC@V1UXv*J!-vT!y=ko)=8U*a$lda?KX`#-KH5K zM>w;$(=j#e=K8=~w>4w!stGo-dIUFKXoiws)}Tue*dgJo^ZXz*f6U+N%bcGNmx{#t zzMphJe$vky=7ZXK-CRBl1qe|h1!M(7l>LfXzzR;c_qVb_E;*)6z)P1vrd5N+fo+-9YO-)>f(g3a&)E=C04p}E$hoO_ee}sAQFY|A_#830?6wVih z&V~CU=6jZ_u1q^Xsp_Ui~lJ!2j1~1ynXhty56h;PeYI zfz#o(k4!tKAttbcMj=ag3;`O-A5eV|_1hdSC(R1&4z4N!ujVFJ28poyx@E#X4unHa z$bPx37Rnx6^}?J-x5LtzNQ6F)haI>%0SzmMX>9a_4bGcl#dgQk1f<-o6hs-D9tEp5f9oaI@(68Vs3NAR^mcv;dAPayVg9{* zB&qZXeOCfk=C7`(3u5Y>G09M@YG?^s_0IS%^F#4HF_P@`DFVa6k+`AwiJ<}+DngJS zR6%~ADG0XCDSg7Ywh-dVqLMqE^PNag;O+Fy#O12l1O5X+{vGMObWvU!6O_i}rE!5O zPzsH|O45loYS=dBQW~inKA(_F3nrosrk0wuQD|bpE_^;q{of z=eZDL<3I+CvRScT9B90KVD7-8qzE(i1LJ{c>0eZRUKK4^*!h_MH`2e9F1#BP_hH6< zU@QVgjehn$|9ea7l92wB;=9F5rS(3&x9zD|8>s!WouBN!yL(w&7gz6iqAo}3@~|*; zZBbnplhno38R!GyL-{@V&jvnw>(jSlRjo+f8k4lf)UEM?swV{vsGwn~pvkXVQSX36 zRb3f2FRB}3lE#?2@t3>mJ{)+W+<}xkg4SiFDbnjb9*31X7R7~kJD^tKa|OEQPR7(F zNMd>-se~g9T@LHRWB0~m<^!newPnfcYZ5|{7t}4v3!m{JuU#<|hsqWW75+<+>IME{ zL)*Q^$NDufAyoV;WW3#^_t=;EV!t(99px=nzjn8Gf%JBKCC~;ckRTrx43L4wAU}99 zT#N^zpe)oGz7o~n^L%MC`$V%!zcOH5mTU`O1~fOiv?M7AHs3ike<-Y5D%>3&emHq= zGJ15W_KhzKD|{rE2Wvl+hfCuUDL2slP!X`+9-SKv(tk7&YW?KI-4o%ne|jn|QU7>* zu=_##!&CQ8Eol$LCECF2A!)efg;pvQtr-Y~&YSQgLty#o2RjrlDfkVvw+TDp0gs47 zeh&#D9N#0LcphO^RXXV>)BMokf&xAy68r~(`^k{dBie$8#|SAQ4jniyECD(J*;z3y1Bjffdfq*srTQY|AvhEpyApT>LlQ*da zkDQWu6yHKa5Hu$BgS|LBTk1#gJLxAk*9QGCa!9-64;1&4y7>k@DLFm;pSK9ut?GVC z_nTXV`v1*tHQ!>lTD)6^fM`e8Z5D6=BHaxjdH^@4adQe9h*(#(0hhuh0)1>DYm`0VI3A2Kc;#f1a~p zjQ|^72Wv75>31+p2vvK^!Tqe$7coHO=3*SP623CnT+P+ ztUcoGTG1MClO@ZkP-VC(TDdG~%}6p9s*aF%>i{Zi%}BBinvx_^{djxKYFkzg=MdvXFMjj`Baq#B{ zQxH{h#$T&Bj^})Wr%0`Kj8E*XkDK?zO%-vCF|N?YRr+W8BP6MdixoGS8_XIX4_2hA z*&=_@nh*~_WVkqZDCUL|!ni$H9-O|z%rmj=HOR1gP0Ar;gtj1Pyi+t^gfx|Fat@}z vF#VizP00-ngh+mK@W$YCBOX3K%OgnTuOldKsYvFFs{?oFCKo&L$IBMl2K%K!GMtPYO%pxwQ|Q( zOKD~h({xzRbf$9q5nOvF^w0fjKHcdL0#d76cZN*d4}Lm0w9`zcA3b-q62cZCNvG|! z)63}Yx#ym9?%8wCJ?Grzw`Q{uLG%6K7gypX2z`PFx#m`Z^>zZGS;Qii2%^61ljtLq zv1W|yqX@(iBZ{0HYxc=%IBLYS8HYzKSsXt|vD7W`IpY?126#-Iku}55!Wjo$to4?*kI7rbb_?~HG>9|h^*2uuUdJj{*C(5f zpFQ7k#@~CYugiZU`9> z4W%Ye>Dq>p%u~8Wr>Gn4M~Ek;kmwT4d6>L;s6Yy}p-~PJNzr)&26xT%DiI$3Pu{8e<9?^~0ik1~f!q$JTv4m)3 z$s1t(`|1uv%X64$g^pZ~YCVSBjO}%D3kvEBaK4Rbs^`tQD;m z-O(N54pEQaTC4&s;Hl8mh=f?9K(z(Zi|s?LSp9&?$&bFY9d*|Jn|O8>)nJ#X5qBvX z?A}J^?oHC_wxcehLA_|@ABbhD9b6)vv%@BvG;Bwal^peMqX7~d@Oy32x^Y_`?+ZM8 zL>o_vdl(emBQ}aJDSNm_tW$Ky9{n<4>s9Q|+#5yr7OYexoQhgiyO$OHRBs~+<3w?; ztbThpb0vv)FBG?$va?%QRC&`1h2sCV%3J@i?rl;ykj7 zAul0QV}Z!1S0`(PpllrH!$W*v%s-`mx5w#(Nm>vI5k{^AAM6cPAgmh)c@2m@DSvpR}!SA$$gX4K&o4i+-1 z19o^!Hu7O1z)$%Vcq}mKR~gI#;DhQYBI`n!Ga?(Z(|$fIGs7W)lNGs@6^VQqu7n`n!`O(OD zV1R475*Qc_2v@?Py@4SvBs4|%fhK_)8*hS9lWGnDoFpKFX*eDxID8bW==?_ z=gQKR`jr`1mG!zuBxxh+m)Z1(q#&z_n&B?VHIORjxs+rGFq zV|69Jm$uf$sTG?$u3d4INPE+c`nlcnwfDSfM=L1lE|m_>ycIWMgIrHsPu`fLmaTi^ zRDKMLeRr;N+1dg^{z?9i8BfMtl`^N@d*|tg-%K}mr`_F;+zUOg{iZDKel6bpMY9VR zFQwg=;@wYRA4_IRB)fF!PWjxfWqVWncro1XtaO=y^Nse)~VI_1YN%Sq?n& z5s0=c4ptUBnX#3_F#sQWMVk7Z*$W&3 zBY=9o^+RNS$MTNt9c$Dm>L!VZhCre*suOiAaRUl;Rx{cFp+R#ReIH8m>;**~5yj*s zYzil=3;a+(4;J!kzA^_L@d#?tM`)yIa?Yn^CZrKjlo4<&Q`$yBATT*xvDQcyB7U}<}9cH_zBFNIso>vHRF zF`vJFZZWUV^Xn3}ge76a_D@)0NA#jWG%33Wd-;IAuU>&+rlHTZaJvTS97FZ-)U zp&Q8Gg<*hiYM5!g9l9TrvnW9%GzC=whDt?fJWi&n>i}7k?T{bO24rYh1 zc`2DX*LC5TtW|0hz7l9;o$B*=!)s!(7=Hv~=~sh+vCC}W@Qv!g_;_#%3JzSR?0YpB z9)K)+7=Y~j1W-rd;|lsLh9q}N_dxdnbB~$tet2xD;c&Y22oDt@QnXpc&gCVP`owFH zwUH}9!!F*!oLPx z1bf!zQsU^*qgnfy3a!HXfv^vAG{8=N9z$!TsILC~*YCcbre2D5#oLqZ(!q?qV$ohF zU5}k!VcZgtI4t!nGSv%I^#W6!VI0YkMW$wfs##!aGK>Z1SILtqpSyZ@2juSBwuRCI zi_C!q>c9eXV7vK`8OIZg^RcBmZK;;7E?H_5lPwq-nB^lZfZ%vZPzEsKFP8cY$)&4^!*)#jb z%o|eRN8e7>zTa@SVXo!pd*-fwFnMqCA#wjZ8QPKTmby~icYF}jlRAk>y*Y2fRmTV< zxxX?xt!~YROxDwO@P`l*mAe z6idR#0XUPmMRnWoXhbcm#UDZxF029O(tj;3T2NPoIHiMVK<7T3L=B>01fR<1nuu84 ztAJzm{jhS>sJtf~B??x@c6qK1A=dCSLeW<}_)&&6MonMBLPq$6m9>u=1)}i zIM)mIff#OLnFpr(s%^f&PG-eU|MeaLZ)>@qv;EdyVf~+ZTicg-TRVPR5kc~D_`H;o zC0ApBTE&Lq&tj+sI!I&_YyhtEI+nXOJisZ0G9z#i!9Ng)aQ+rAxr!hP^Otjc`~`qO zBt9prEbboRsUh42pu6BDv-zSNW0az=K;C}HSW40;z_kMIBoG&_+ve5smoN^0Jzzsh zO5iVIgie5pGR%9;vYzK66G0)8%`-T~WRp!lhLl8g977$r>x53$L4F#a5M)!~$AArG zBW57ll%TT7$*!1VQ-RJmj)Qa61JN=Y4PcsQ7*(J57 z+UE{FY+t62$BsQUTBr9U_9O?Tk-1}c!t;a6#&5+sw!oUpq*|#rwR6c_7wcZB+WFJ| zJN>Z}@pEaa9G2YOl%{H9XOneO{ffyP>w0SU%vR1+E><@`^eo$tLVC2ja95zq*;L(J z{lmIt`q+jhBdNxD;%*aQjmI`Lc@CS?WWMoH`9gqOHV+n2=wE7sdDpVJ{olRloUm+e zdtp%;nVd17l6Ad$5zuLU4FBCO>>?gYS(CV=$pSb{mQgvP=-cib4-f~@lys5 z*GZ%?sf0ldDk;tE*qN!U$#}{$rBxYMS;kZG%)OhiX7r}p;!Sal1bD?@nXXJ!u4w^= z#lX+;NRLM}vb)v{7&0PoQqzoSjlm-mvN)&v6a8yuJhEUaD~4>y;aRg|$bl@5H7AB# g$mUwdYp(Ybh%kRjUmytQvvCcw)<9u)lTqyP4`2j(@&Et; literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_843724.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_843724.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b785a219de465b0d1cf903c2a7d54f86f554cf4a GIT binary patch literal 8238 zcmdrxTWnKFmUaES_VxRfciQ2R;53kD!$31MuLKAXh6ylBHjLYydlT$9w)^@9NL1FE zRZG~tE#XK#iqXud-Se@-EG=zX>Q&RyNW*?DtNroC>5ca;qs?geiL@W?VOJWB_G3?7 z`#J^;VS1UBb}w<&sZ*z_PMve=oKwX=naxH7-^tY9T;ZD#`YmZFUqc0W5~C1WLM&pb zFd8q{)Ho%#RkP}G4MlLOuy$5AuA9}5>nX%hA)7MIM$M`@O~|p)Rv=bWA3vpGwI68i zR>;PU22}fxb+S4Rg*?hE8#Swc0iJ#V-oTkyBlm*YKrD!5UfOGdUXx5`=FF@GepdL| z>aRBCVeKDi$1PQPIzB+-Ru$r`)t%brDW7YdY-rS{lIXKBAsCGW!jkqv1OUl&6p8PdHOGp46mcNinfi`NEx?AZ_I5(kG zRViJYP^znxZtTRmnF)k=Y93)1HdkTwjeRw~^1K>7t9oiw18%^c z1+eu*W26JPVHUIid$8t4<0kWIo?bsfVJ6Xt-MF#32M>0WwGvI($f_4Wx{2nRId#fT zHfn6g&37wWNVEtBrH7?(Gj6#{v^E zhH%o3-HA8w8`wzpi`(R0FQ%|phB|8Itta~q-2RzHnK@)u*4wCsDj2d&UXvb@NtdK0j63wM`$W$o^) z;jB(Pz}ku&-jpR#v3si4SJfX?-p|gbM`fswZr9UV#p>&7wA_cAWp2C3sgWgB*5a-j z?(3ZnMT@WJjJ&EAcfA&y`O|2TNaByzBH2-OR7Uj?)*$SVBZV=s8ygZm8y*rHagXAU zj-kseSYB^ko%hNb?b}p?eT1&+r}}H?>sy{okNZQlk+$ORFbcUUU`?j?ZIOJLHgyH5Ig zD|MfwBWm^cODe#ruJ=ir=$z0imwl2!?vpv9XOx)?WqA=XQlNLjh8iZ04*-MSKmPl} zg`R6sekL9ZOmaO}0+TZV;Yu{JD=-aK(G%wdnq#CHY<_Vy1#k7TZp&v1N%3-c#{JCcdU{|-)g68(n_*_}(~MRPV#uz8aw zA6f0U66r+tT*3NA^28&P<<^<>nQTYF)S4Vx)7vtqAL_kD+EUb8GkYHDn*dzX+w=Oi zM=sCO@x|lWkvmOdYk$EtkTN`SxR$yWyRzN6d9k^7xle2!EI8gu>5I?P6JnWooV zZ8@Xp>P{Jo1N-v>!yorA?OoiP-J824w(MSU_B61ba_eXI6C(LVUa z_rDzfe7s;klsfVXBt@q?t6w~j+mgHX=}ytPJ9X@N46%$8=@Xf;?4HGm+~A$5uT5_~ zhg&thwf6iT-J*N6?*2$e{@lBw`(5a6Y|B~2#-8PFvGG9a&O43Er$pa7qUD|Y9iruEN=@u` zIz62UWVbF(7cA|7p^dlbG@Wr~;)_k$qo0i29?8A`^J|}8TkYB>cI_*)?*E#8`x`y7 zbU%U9izEuX4LChX4gL(=J2=&fgELOfRed;dwh#;AdrLAdf`|&JXTJ%^#XhurX#3EbFk;!!-{w7@*t`cywg ztS)VXQjc{|(m&PLtfzW~c@4NO)l284;~t@sEg=-Do0ls2MS_W_FiJ22@xb!8kx2=R zf(w}xXrq>837Sx5VvS0!!ZHe!Fr_RhYsxmIW=$W^3A12T#$z*i3&9BPnJTP-ES~{Q z+LhjonytPiyrAE#tp9@k>Z`>I`YqM*pr^DYZ6o)Uv=y|g#|CVYZ31lNGwa<-J_nY> zj3J+GlqZ*RIT>4M*vJyRvvYC5`yS`zu5yuuHj;BdJg#C7e%HG}yWoQJ;*DPk#5nJ4 zFg{C?rv2Upt@mo&yJJDM(<@QF1s5Er=V;bR>bZUhiplDu4K#=n37bz2K=C8>5=v8P zRZYT5^2KMqv>L;s9Z`m&v?(=gPc^D!I%nW{NmmZ!kYTaGt3HjSIWu(ru%wl%_Qyl$_=H43NJP2})YXVj<}? zLQ*2k`@=yd1&vNSfsrGD)6d>sGDp4FBUo&Tt#~y3! znQfw`p-3BwG$a2xinLV*403$3{}rG&1!0h-NSiadG*sn3XORR-7s>uYqI1W8cDb~= z(oSTuCP!pjn*=>nKVo-s^!YuMODKpEaBN7NNl*z@LY>gCs)Y9CM@{y1YOG<^aGNF9 z6!;RlO~$COmerCMLwOY8fQ-_At5LL2Srt654*VOf+^G`=YzWnRg@ouuTnVe6fRz(Q zd7pHGs#%@L<(Y39v4)>fGQSmvlVDh5!t@FfGQwMKSw5AhwW)ke-%399M?EJWh)O1w z`OI{;BAb`U$t=t1|GY)OZY_5e*}vE-tp9IzYkP&=+R1J)g0Scq^=ag|zn%otIuR64 z-1WK_3J8yLYyjS4bS!r@ILXO`5+iVN!9N*@bN)V`dL4mZ=flc4-VYE=NzTbSONs}i z>wQwpKyluuE=Qm(1S6-*HS_gD43xuA4Q>E%BtF%yJLrSwAA%#VKx{+~GW@#)p%Wk% zkMcgVr02Q#Tv&*gBNqu;Jon02k zQ!x(%<&WSO|1XeD60KQm$>Svx>Kv$vd24oOa^x{>zQv@O%useuZqJ>)_x7yNN0Nsh z8?CqAOuw0#%7*S7z8zhjS}}e%dGrNXb5ph>JC@s8Fn1=0*V?vzGI4t%c|3JS)HK7A z+j~S!NAh&0GuyRhGAD-~+dWIIi>(jad+&Kx?1v!W+Fhim;pKF$^G?^j&K3Ib1}C9h z_cHZq4`7XlH#m6)t8lX1{h&D?;8x62br|}S%3#^GV&3z=z33fb#XR`jqBJr&lcRFf z_37o)FgiLqJUk3Z88O^S`c}?DcZcCAa}HjiBwB%Za#I0go{t41(>w{1JUN<@IusP- zYh_&EuE+RFzIR@>d3Ya-2$JT#Xf#|t5X1~6rhE+i)6p>d3sg=yE2->kl${H62l<;o z2ZoB`43s4`MNz*+rhi4|U!&H4LtEA;{mteZ%_Y?-iaJoD(Z08TW%|lgGRRFMvN>+y zH2xh!norb7WvUReD%8+e^zABkb{Dc^EQHFr zXf9bv%Z41DlAS;fWO0<71acvpt3)^`d615|IdNm+NdvT=T%{0Y{vCauqMYBvRLJUu LtnLOQ3+ulDL_arq literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_893238.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_893238.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a4ce16283a65bf4052e3ee413c5d4fd1a274a08 GIT binary patch literal 9022 zcmdryTWlNIbu*kHIplDLA}LZ2>uK52(vtPCyt3WhJhZkX+p)D#{7MXM2(D)I7O5SM zypnMV)5Qi$A|ImqBRUAMa#1w&+CY?iDSz&MxoLrDEF~r;PM|dY(H|YRKv5KE&z<4R zG_{nsPLXzO&YXK*_uO;OxsPN26USKzcm}6GpWLn^h(F>+_Hk!`7XnETtAs!ZAK3?e;kdl`pY0<{O8-UDwejUQ(O(v!efoD5#dPU|@P-!Y_-e`MnG8pZ=EbVix$;SWUtBEC5sR zd$Gvi`I2GQ^k>?Th4Np4yq}?mYszpMa-6kgIH?@x0Yu+ERlrT>d(J|LMm0B~k&~(0 zAHQIDNva;zqq>_Aec^^ctw4s$X)zY{+JFwAs+&fFKn*ja1R;@2L>5zEk11<){wt8E z>VC$v*rIS_S^jE)x(O0lsc;i`$%N`qqbA=RFrvyVj2gin;0+%{O~|FoY3&x|j+3Fk zM+B-tCTusBMYYJBgBu~rh#wdPV%&hk12H*D6Q*FAP64fW+?FgodR3{*VNaJJQ9W4St8ty-a<#PUy&EJYyQQ#(E(Vk1z3%$&GJ!1K@{=X zS%4_A=iuTfs>o4_WfQ=Q#1Y1`k=ds+H2%g^s?i zM#%R!kV|?N>Auj$-i^95xzzI)mP_6H^iXDP3m+?(#ZtN1^P2j4%jo;d$faJ$rCxL_ zlS{n?xpW+`dM@=A%0`4kozj9f%p1$I{G%#{uVEB@14v8Plih+=73CYrwAo>@}i7Xpp4 zYR$C7fS`H82d+>}SU0fh-x-yf4~A3|b~;9C-x<}s*zKE_r4Rr;I>2CWXeb7cR)Y+S z-Kt@+N2P$KQh-+J;DX$(y?ay(hIK{^X7sR&K1uY?>Ud~O&G{$vhDoQW@Wn`#0U?c7 zhcq0Q#pxniGp8YR65P}D6`%jO()DprnhDMO$HcBl|JaOQo(u+#`X|JI+!d0>x@2)~ zz6)Aindy=t5JTfZX|8jA$p?5J1ehEQ_{M?|W+zYe9_u{`6qU=c&xle$oRtQ^L^J$C zCt*S(#J`>-9IjPvg-iIiDw-nap4uJD;aE5^vSn|KoPBEJmoLUHCYrZwb&-J`vt!d- z|Fo)R_2ZR~6E~CRl*YHWs`{d~r|zoNYb)0h*OOkQzIWZP)SuXLpNv{l&bkCrob565 zPE|FqIV)G&R@xG6Yv+{Oqg&3-sCmb1i++6jR-)~(xjDt~JDf9qE#~{N^x4v#rKBS{ zwAFZA;f_ZJpK|u)@5R0s{|JOf?sRN%ZIMCUoD^T1xT^3iF=~geOWaWS!%=F7bHyhV zu0ArD0%&@zeW+Xv_ z5y3T!!h4^aDeg#`0d)NGk9ssTsqoD(0RHH63(a+=DZ)~b<_TAA;+*0<3?p)R60RSO ztkH>Ucj}ThrRKQeI=+5mL-?Y9W8!DG9@Qx)t|+c6Q6{xt#%D10Yq4wbYioUgj%%{=im({qaN5(>r!Y{B&$FaWE0O^Oj;i7^PFMV8E|j#T$3K6Ws1@fN84qMlH`6 z>oOB#;x&mgYd1dkuU94~?%vw)Jfy$yKQ#ZMQE5B-Yv$b3vP{2W&S7Ku|Dmxsbdvpr zF;WS}Rvm=6@f?k^e}))=B(7SiN*qY`{k(bo@_O67Z){wBc;=VQ4~Kr?Q`#?VJ1#~| zPxr;NpXuLsjCkte*P??vuA0Q@m3P-#lf8F46jvt>{$k9Vh>=&_?g6*gv?eV|bx)Mb z#z_4dlf3bB|3>Bd#JyV&J&)*L`X8D9p;77j_OF?dY>e#9^c!a61uUS1N-j!$@Rd&B z*CC9tjxG4I8@_eUznE#L=BdE^EUfFQNfPA+DKLhsAKB*s09qx4hkbSx&iB(bM3&+ZNPK?|bBIH>nOa!j4d;o!CJ_e}2-vj+YUZsRuB1G&B5vGJz-1`a zOEH%~hk3ymwj&E-<;+=)Ag00E;s+${*&FXT6|R74!Z@O*OVldl+=;PNovy&blJz?9 zUx_N-OH%e?+n~ykK}|cpgy!0t7B(UV?zBYM5pc`7*`2O@RRZN}NDjFwOGg!2^@D)h zmjOAT+C<-mnn<9#Kv#8^H~&S}3G&OmRoWX5O!qT$1#}#LMURlHjqO3tRoV8So3nMY zJ?Ko1PL}DvG(E51P)5EqJumMlBVU@HAAf0jIsN(ZvyiV8G9fdvBmRNqex^oB63B6h z@LGp={|g=oN0fe3i2$_wO?Y_?s!@=a=3!A0#2Zs%qJ-0C7tgELkUy<^dAO=f`Z1e+ zK2+&3Vd{qTZQxUlaEF0Yyb&t6y~m^xjG~39@o|-c#Y?59;K-?3;n*oJ`DUl!0ErKe zpa0^dKR7R-#- z1Z@N=Jpq>{m4ZWe$O*_-kJi!|EN}DA&(ALT^pmbM2B;svFZ6r3)mj%GH{cn#8e{`pEed>rPZZW}7$Z=1sO4x5(Jn6C;n=hE2L*lWj<` z{PL;Tsf1^(cC)(UG25|8cWkm9d!ToCM`S3)S44((Y>xP;ACG@Fb7v;my56?c@K>Ac zJCQS)bF01PiM>&=H~!t*KY8cwJCE(1n@s1Xy)%XJO^Ur~t^X%y@1A{Z@7QEIHtiiL zcjKyW#kb{ded0c%xQ}ePJD#|E6nD>-yEkh6gBgx2_89eycRk@76uu$x(H7qvrJm7_ z_#uU^N->rc!)E_FQx&zF6|FJ)!N5;1+`F*Rx>5bl?q41F#{-)meE-P@qVj>b-7yX# z+KI~XLy#*Avdh+(HQw_>4*GbLpd-aNQ;a9YxKfNGRZ#_h|E;&yXiA?XZ1%``X%OZa z&v)*@Nq{5n#lORQm6#&J@Gn8y?n#CXVJb`uhOn`qI)N|!MUS=yQ-~HQsL{eL2|r;| z85sjI3PxNn*QBroSjzmBq!>BVE7V2t%D`k->#zk`O6{145KQkve}Z`wdWKU19w!qf z3wqD(shVNhYl09ge@kliX1ow)1#8&04-Z*k&C>mo+3Cyr$M%){Q(Ed3{~#9G1onaL ze&#gsC4S;GKmDJN2pBDYKkNIwqr(3GX0#Ri7_9@377?&-I)}ZqzG)5vY8N|7dLKhg z@PZw1LGXJ?ZNt;JR8|&4vTw{E5`8^hY8OiPgd}E#r0)O(`f!r$3i$O1lcw+stH11} zinhB#eaE09++`HWE45dzvSCSAgPjKa&ysb)zGEzCkq?)w<& z!>p!&UcoctPj35sc{_$W+k!m5AgeYR%6zk#D@oqq5C@uG+F12H)uFlMrHyp~IDZAd z(0^(h>ki)J(l(;CSD_DTJ8NI!#E!!iIsc5|mf08^A4v4A^(9Yi^ldZmMf#sv?aS@4 z_V{>WI@y0GxIVsZ{aWPo9$2m>(VV!n*0RO5LaC>|<+IT{qmiNLMTM?~E)RDpbaUiF zyfx7V6`jbyGlyrjZl&&VL-&Sf+wslF`5i|keihhq0j{5I8?D<+f8LVmwIl1~-7dgd z`}39*VY8O3A9+-}=@+-T@e&68S$(ixxy|+cZ+DuMw{hjI$n^|ib47-=ZP#nkGGcgm zaB$GeYq2X`#!+e451O5`P$*x3`lZTbA&JFwYI;i6{-Z!h78mEG%&GFSeyXJ4EUMBU z27|MD2H==eSv~Z96Tw;G9-;NF2PyOJTu@k;6~8Xs0zO>D582_Jrbv?fBVqe3!TpY? zgVKV5`LyzZI|!&mLb#-l>g`UV!nurMDBY#Ey9p?QjK)UOJ$n0Crv0a0 zt^I-`47RMU_h{lkU8=Dq<*7|o*QY!OQk6B&U0);ZDYNZUbQ`59fOjnXa$T%0Z3Gys zfjq!1Gj1`2yE<*bkd=TSwyoIGENFUz14B-NccyVK=}N-H kemZ)4^o1K*FD@7elKT^LnIv7$=M99t2~L5xSM{Qv*} literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_915460.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_915460.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2d9b7a07b9df99f1e94b8c0726fbe8e66d28a89 GIT binary patch literal 10519 zcmdrxZA@ENmhbuV`TGO&WgtNKG$a@UA<48&Lnxtv1WF(w5K`CWnGa)Z@@xXhJge@i zyRoyQ#f)YZBUOqMjZ|k7sh<9#{Y6Kk-Kx7<`2}Wh-)1#*hy4+0f4Gy8I;wWJd+xIz z7-CA=&5X4B!1vvI&bjBFd+xdSoO9nls#J0ep8uHryNQW14Er5w@K0{K^7Nl^47-hy z7>PTv;nW)+#<{j&S~x7iF+1+y`0!m*Z&F|vIg~qLs6`6*fQ$Bki}!#__JB*X!^e~) z@t){GTA*Pm%s}QOre(vjY5A}m$7C6QWF9HEV~)Hm(ku3WEBAn_>@xcvxG_N?(udl< z$5&`{3mE)J*@Q9 z+1~bUYd2ePrr+A%(Ra3^$J%@PbbrS{zjdIar@!|+D+A2Mw)22t6=!+|I-u5evEwyX z*52FGb)mCupo8afZWKw$KQgEY;rH}cP`izJ!H80iL%`Rz-z5G}@E}d$E8K@xm_cjY zhgOt9TSVu2%4s95xgo|dx`@`H_YDxjS4`t{F;^?0F`!WJ7@$(|b>9J6;wyEX&lpj( zJp;Z*i)bU_v^ATu=c6ZunZ4+;8?d2#r65-st>q=yAv2P+j87#sK`W&9#caT0v6Qt+{^uGbTeHFC9BcaP_F~_OW zRYYfU%ceW=PTi&Y@jQK1bY4i{$)PKJ)pRvo#bIlJf9VvA=&PkGX&F~D(dA(N*k1}r zEVaQ>8lbmcC2s`wf0o+~H4X&3dL(vy;h(g$eo!t=k7|Ij}23;$&D6E8&m!+Y3K z4bo#)J8IZx#wBz^Ry*3wk$9nYWTx{eA$?eZ`3}?d^datCFw=)QTS2ExBS5(pZBAji zon^_GOBN?iA4umCsP_amqGrog`V2(`_SJiH)+k#_fQ0v(hmD zd7GP9vB&GC?5ubWSlCU+SYhvl0T#c+3fp`8Sz-T~9#-CdAsw{#v+AAED=YEXuR856 zR^pnl+ei;9r`%o}b=^u%PP4KU&PusiZF^h$>mAnaj-fOsR_pOnlce41p7mPaws}V< zSX~Ar%{3F09zdT#g@!}}snf9eG$+kgJI!YskmkaVK5tF)pwJ3X=oq0Op2H{%7k;j3%Pa(~BT zt^T%ZQ5CVQX{-Ea9&2<9zMwBszNV@4zy77dkSMK4)gLwru1T*C{Ac~ zaYHSmtqq71`rNQAUfIGFv@rUX*nIp{KXbI7(f0?W34K9Ox@pL}-MQEq?ur&N#Z9Y* z=78+6-f+8Su_jy>xy~3HV;08PvZ_C_EMxR-&`GNcor!1|gPGBqW3|gxrs*uBJ-ao~ z5vJe>1Jp#keSkSOz~~18(haFHF!$~|;hIO%@&uvYkf}rY!P57e?i^n_9_fopSBn}M zS)>2dW0_*1FjyF}gvM882mGfJYHg@1Qoh*Bs4M-QUn(jRT7BqEMq3dSZ)o*DJAZp{ zaWMR5bdo7Kx|(~8(H;wkH}u63e9`Rh{x?NlLS+aD5-Lroo>7%RO_viIUn~wt5{AO? z8OBf*kR>HpPVP_p!lKaNVm(YsZ;l;|4SZsW4gR7&e)0{b=?zBzMnL))7Gy9u7`hN1 zT(m}8KPrz6#A@!py8QNv;g?re@;{#XteZJ8$kbn2Qw^=FMi|w|+7;{im2u|E_^N6m zAV7;Y7n}=u!WZsYB4p&~(wosKrsO!2e|-7WviFyL%X1%hGx?o~;;MVLs39`W6gO-O zvZ7#`x|!m$@qxiV2yyKY{@bC;%+TevkvG>z#+Z?@)uHjE5R<7k|16^QT0Aj>usgpLGWLkn0L9182hD6#$Xl}7PVrC4M*r}NJlfKy8FS_HW2AGxs#xM|&C5)x_@JLHU zy3`PTi7~b==cf#ZIevyQo(X6+mFk7gU}vardqo&9vlyihI^U3`b1IXy}qa@GRPDij^!^KKDoM_ z&sf@4x*1C^Q}jms(q*tC%@zE&SFFqx>)KmyufOGB-g2y7nL>7?e4)&~l3>c*L{46) zGSI#W+vrm8Qiuvy-owK+i$9L^Exr@AFnNcfwwP{J)eJjjf07M26qp zid49V$5R9?4CZp_hLRhnnVYjPo2 zh7{9q-^*WXIdZmp@--UUNyBRd(s2HkA0s8+e4gHcY@8>j%_)Pqqyk-0!j9C@8b<}+WBUy$Tly**=7R4*qvUD!z^nGb zZ)ua01A7iRveRtKksZG+M|S*{9Ai1O_WRp6Eu$qgLF-7(`;x!GaV`-iG2TLOSb=`5 z0<&Nqd`^I4L$F3E*t~EEo5M}oR7y#qJx!sl`vmY;pMH}qNNmQT%WHh}dP{lB` zUyz+YeHn_s74G+WAt;0$z@3e}WI;SA7!&q`)9B$6DR3T=RBR9V3ORYrBw|J9ItJQU zF_%?Q$nYp+rye9#ncwfH&On7FUUu51-zIG*ZXB@9%s8(@PKnaYhL@f0QJd3q0xBu| zUDUS-ia+|X_w())EfqyBMh9Z1`_|RMSC(g2#+ZV0R5#$C_)+m+-+ue;lmA8q{Ju3w zzB>&x=^YSW{exq*4c32@eL<+thV64VvpZ8j4+;escT&J}OR5VN4SP5yrHVNa1 zyw_*!T+#zsH_{NVV^Dtdeq%y9p_MbtKtf}t9{QLO~Q5_Hmu7$2G&24B3~CQ3Yk9 z*7TwDzLX)V0?N=>cs!vsu4~H}ZCOMIsZPYkXlo$5(iH_H8|9TB6y7alLNKqCSmBkfh357m1dh0}7R2Vq1smKkves2H3dDpq7Xnw3I3^zQi zSkpAbPOO|}n!7%meAGM`CkEr1!3|Bxx~78BR7Ay*`9~UaoG`~V=0su5dSM+?SoiSA zhcDfKX|?cZK)oSXEi?ohR^>*2dm^L$SgH#({JiReLw66YNmD(Wn8K#j!j^#g%UbhK zF089cAQ21O)>LJ&KL6SLp!`(1T6HdwSDGlOK&7GRX*1A#EyUz?$zzyO??3nT@$LFI z9&~^9_djYEbV-DtOT=AT=_mlU9%SYyDtr1zC~sqvm=Df36iNDUpTH;diAaG@oN*xt zM!fss{Gt%zQ{h9=gLM4ilkC$+K#NH+@}68TBmq1==@05fY)3udL!>mZbx{h@a<=Fa z&Tt7j?@8$hjO>$h^CWzD#^^{dyPk3Q%6|{%wfRqYX0BMR{I>Y)uGwa7{o66 z9+QZRnkXB9s7t6ALd75|W}qY9#K9jD+B<4N`z=g8SMY=dQOW6^T#LGn~ zCKX&d+dZ>RuP5ceQOuR{;>aCy9)EXiF^>?}P_ZQ-cYN;f_9-W?<5aN2)BhE8{3rN% z{u2UBKL(c({?4Qlt8QjQ2mKv^)=+Es=!P!G-?d4o78F4R-08GNTBApoTi1wF{X)s;b(5dZVP`&dAb;zcX--5gB2~Dl;Q0 z_n!?_hif*JDu2hOF7I~nV)3H`jmvp!x|9B{4V?iM8M2&>R7Y!;tJjFOZAl!Fx)^@f z3{ZL7wj{ftyd<%@&x~=KeN8o%MW8>a4q}EiRqLO7(5QD!)$+`s1g6aK_i)iSMA~Qv z^z@uQecGhsY&9JNr?P_vCw!=yg%1WS!Ph8srN9aulV0wI$m6xo&rs=b7$W0I*En?+ zD5!I&m_WrO>L7aC?RKU%05U=Nw!)dcb=>VF?_-=We2OvcCTE@Y6Vw9GfstaKi%=$o zIFA1gQ~nF4`d6$N?!5)XP2;=9q+lG!-^3yKfE&b-;7HPxYS&?<<#(J*&SZV6eF)R) z7w90JJe+DbVDOCqtSxCywHtZ7CDm?9^ZR`>*M2I*1j_VSUk?Z|ydtT_8Z8Oa;Y3+w zqM!tBngo_>3|E9_@6bzh z+*r>T4kcwALXPPQ!?|~gmx>u}byC5>ln9oys7k81rW(`dB{f`Ci>dWV9aq(3T0@eP qB5A-Rikl$yMtPdQ(4heO-z z?Bj-yc`iQPj^!BUfx$;K|?76m^qg zDVFw8=QF?bd0K2sCZ*?PG{H%H@=3*c#ia7QlBPWLgfTbG+Ak~h$R^Bd?Ht6)Ho?m` z!7DbwD>uP2#qpzBR`rhTPF9TbY8AEik1b)<9%`Z_C*#^LtJwsv-2|`mC_J0=jY{k= zF{}3|i;oq14bZCT^g0`6<6Q;1`&aQ#|~-eS-3E@5v)Qt{$Q6*r02$tG~Ca&vo+X(ZQ~vLDx`M-{8ps zK?OahI|rag&>ZU<>H?_qbk{3_>d48ypkX#s8dGiYT6v=xY9W!Hdvs1n(NTM&)Pk>Q%0qOO4;r~;XC zaA*snpih%)L!T}Ohqn3$3V6!a){ujc43(1^^XsX4K6+MKI0{u=gVP$S1}UqMF~|KW zDXWmSpszqmqzKg%@KxlbTl+<}x2sXj9kNGILfeoXm65s7wu{3Q#nD%(9Hw9%va1^k zOQW(o**zsI2-PBMP=RWZTx`|(D^P9048=4iU$3#uxPmdoxmt2uL)%eFj1K-1Nl+b} z1?mqZER{KJs8U?fc7i4N43^9<^4Ni-zvYPfr|4A|VeM!G2UJP0*TAAfI}6rPZ0$SH zPFDV@BDW6{a9TDNK988t=D|FWw=>5xIMo9u(y=d4rIcjI%9 zt0eSj{pej!@!a+d-S2*eMjFxXZ%89e&)~G{8{_0erJNKsXsA#FYC;atzMZH+)HCrI zdjKo8kTcUOYQDKZK7HPF@&MZUlx{`8^3;Cn3-=T4LAyU?a+XUM`uN|6M`rKWb824C zDMtH$Q?R0<{(+St3z0VOC0aK14=wBcL*|wsdJ}cGxtI9*ahv&y4WkhA|I}A()WiDI z-uje9W={LYQES?3E3k-Sp0*XWeN(4>Q$MyrI=IupV{h7{?WnEL*C2+^XB|cJ3>67R$;slZ z&qEpN@JkRA>Oc(Y5J&w9a9E33i+UK8k3d8}u@CLbj4f`tVl=RSQ_I=6-Un^4nHLIV zDQ4R*p#7koY!nuHaZ`W3k;^(AhYl2QDYkP5(2F9UY|KbgSK*&S-NzBe4qA}8XS~73 zL~D-;W(@wItR95aLpJFSUKHd3e%!-l@ni^R0D{NAf#OgL65c1Mr?|iv=bm(VSwZFt zxLHAd1%`M8g`0Ex$2>q35R{|P!?(5w@*p38K6(bAarczR@sc1td1^?Y&kE8bCkF-T z;ITeIedJU&w;2@l>ysQo5%gU0dHjOHKkjj}K|#$0csF;|#d;?NRi@9y1q9=f&LgjO zxq7$RcfIN6N5%zH0VX?ZAu7wjTtG!w5=QzgZYDdEA>Yo$BHGe=d_j*Z*V2znbDYP|69nm%aFc?Zn{?ZJFF-K;LrylmVaKK9UHx)F+o5;e<)y6~&%lIld~Y#@B{ry6V8Xo?+6 z=&`j48=Dr&k|%L<4>tBh|T83 zi<-r5e4rQGd$IWhtmnCi%C>yyPFUk(b9zuyS#DT@(S+gzURap=Wyxyh?Yn& z79ER?C3=(6W&ME%`h!@1@UJ9)t^SF6N#7HZK9gX@V2w*YY>u<>jyX*tfXntKm5Z{^ z-HS@R@6}Y#1-$P9F1vsY!x359Xp5Ua?2PmA?zxJ@+t}WotY57Af?ll0FC4$;!Y>S` z-k8Gsr?7nr8!wS9u)Xh&TVrFhFDABPi*unn`O4=t$!^^C%HrF&?KC!@i6|d|)t!l+ zi47%67o^L&wgLuNgh~#0G;M8UW#0V`VvAW2Sw80cR5}k=}iwEa+ zV#78N`6#cvN-)JOYt3-TDH)HB$GCXicWH4g{j+b+*Wl8|`GJMnPtGPK$-%n} z-hF6s+hXu1JMRso22TCWS$z1^5?R~Vm{fr_&ftzS*meeM&!)A8w8<`7)T)+Z%pY`$ z-th2j&wXXpLwotH(z()vd9EU=U$NIl^l7Cj)^}gIEp6TUJ@d!*+xGe9`LU&%c5H2r zsM5BraJFpDg~LC*l&o90@>$Pf6K;DATl=BUQWig)^+i44iz@dgk1euabS{qJ1AVD} zH$LFTm2PZ#6I5@p-Rzj{i0_+sEL++l%(|heh8J<=i%IW2*%$76O8oL`sX-5Z*@G)R z*fI)ZGbbumMn?Yw)3&OknA#_Bvrt?)sSc9r5L6Jeh?l)9FghCa@DMnSOlR+ZfG6EE zxnMx!_46Li=W$>5u!7R-pYp*~NG@al#Ppb(hq#XO@Y9@sgaj+OQVM|hcKijzy6;-v zwZ3b+Ng)|aP17M2QU%S3n#j=>{0;;4dzSZX?^#1?|867?E004ClY0yQ4!y}J;1tr3 zF-TTyRyOl)FD37>?^%%iJsYBn(G5e$8`6e!NC)o+5RQfPNRQMPL2p6HaSFK8#a_OV zq!@NYG9w|m7iR>Svr&D)LMhT0l#^L$Dmym`05_D}%7Mq%^oArMnkPY{Sdf$_k%U+# zV8t6tF2~?%k+hf=qLxuHYs{`$`n8;;el?OtZ6td_+5$}}ISHctRz7bN#oj&9+6%iu z)HLXcJAstJ9DA43BIoyOb2d(`gU2Y~;mD_`m5h*wctMM}+pOWYgT3SEN zYC@)nC1Q=(My0Io9VTSv$%zsRGLtnR1AO72;XO=7^dGf?{FPR>pG4bSK|=KF~p)N{yDoMugq;{9uK6z|W=F=|1^-~N0c6;i-k zf{8VKr1)bvrN}TyIm~_6oNK}4jK14z?_RzZ;;9j)noEb@R%Itk^DObG>9i4miFWTRDUXc z)|a`c6ix|Ej3H-rQ93FamBMNyePVJ6$-N?zr3kVS)_d6@6J#g4hB^hgm}zpgFqYc} zRgef&9S(=NK>!5ifX_YoChLCr+E(|}lo8B(1k)j3j~C)EY01eG`F zy6lF#(=RAYM37ICF3zFQOt0Nc%2O3ZAJXk(-Cv9z)Lfw3AJs}m;3%@S^G ztcOg)R1s0EY^lC=ZtfhGm4}Z--b$M)rzc~#1cO82m-OD z?`{8h*X>=)N@v=z<$<9V8*1m}iJAKbXNqyA49>K@{(-#(+gq0GZ4kd1wj>nTuw&kF z-_V+3T2qGBpPKE_{*R^~R5@{#bAeu}Y6%}D5nlIwWyS68FAZgBtvS{jSI=oumRhW> zd!TK=+6D;u1{WGWYfSBa2{*p{_u7LiTJ!tOkX@C}RirF+SiAj!wh?O^=R3de{-J5% z_-8$-rh|Cbp=E97ip3UhNgSCU!TOd5`gW{uf5_+`Fk3ORHU8ET(>Oo)L@F^T!rfmf zC~H;JJ>QU&+;4c{cK;$B?*3QJiSR3rWTsdhmX$)joMu#MMkD@bPBSGU29N-o0FKp0 zbzfFHB8nTjs4nJSX3FL-0a|Rmxy0BvgTLI~6ro986*K*jW}y?9hz;`}YKw!b_2qKK z>Q2g3_OPaLzVAm}KkmKTyHs;5ZLUaLt4M7uebP#3+9|C$oH^(tV2AlaJe(LT{tD`w zl$Q#D!y=(zh=$Nc8j`V+ki6hw6U=vG5wAii1enrKWp}cnQb_TPF%l$a<>WAkqlg4h zLdxH16vJn^!u6{_N|gVoM#9~RVh{Eds~`@XRSpB^kXl?P1A$!L)lU$;Y$#(CtNJc2 z%A3_fNW-c_+UL*^1Zr9R5HAkpnflTGPWmZ6>M8vo4brljPqlZly7?MC=|nyKueS(T zt^Q6<_nTXV;s4BPjnA=K6Im^y#Lx<&AQB;E0-!euB*sG|GJ?Qlnsvhm3I*%A>>crl zgo1|m1bNqp8=kV89a1r7;_~AjlLHM?l5_HyCDlEsGDL3xGz25o-`s`swZNAR@eVKG z(F!_A6P)((!OZo&4X87B_!R=xkqQF; zg4CS5Mw&CE$_p4I0pz;7K6AyJL=~*N=idX$2k;mCcW@Ri27w&pwD!ExXTX}!$ zo@B|g=}`Fiipff<0$F+!JLcQwHAyBC4GZ*bCt%f`Ym#ih<|J8YxL2NXdzN*h zMFje9)xm;wS=atQ&2*lJXWFNjGL+U5?h`|5h@r^|=<7Rr^r*um+G;k^&4m0-KKLj% z4c{{aCI@lk|N4ZK;OdmuKgN+`#1Xe8NGH6!`0x|tJu_3B7#ed^Vh_9#39>f>0bk|- z5EF#3Ry2FpSir}ALW#oU@`lL(JMHtl%zXfK$Z9lO)jb?LG#X{$Y5Qt_4L7)_hgO6|Mo zI$D*I27IfyQ5mgVm6HYt3de&qm12{jY~}IV`1CC_hf?K@*t%;~C3dJOlRa*`RXJCQ zjXPE~B2G(imRa4ZPHgHabIGbf1dWv5ylN6bGi9`{ic+jvDTU_U;p@XsY^3K2E1_uJ X&zT{bwtO`up$yea`kLz+QQQ9oZ5sKl literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_939610.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_939610.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5baff8c3a07a48a2f162e08115bc7922398faa16 GIT binary patch literal 11631 zcmdryTToj`cK7Nn^u7`g0S56h7H@&gV;Dae8#A^6j}7=WtYukOfIvvhmGCQ9Gn%d0 z8flUWut~~dYpN_yrD`!XmDy2r+pt$4H7y3_d^5bWRPmc}sjrM1_4B4^ZnMV#4A^e`MgVo(7d|+3pXIhY{ zW#2UUAWiohU(iyp-LmurErr`H%h*D1B}=lUtmclGAnw4xY#FP|)cxfw0W?Jh2AV2U z_wQpR{tEZ`oUue%xid9Z#Fk>N+ap$fx%RXmw-;M^2Xy7H07)xZZAOA;B&}qXIka3! zjcf^~uoW}o1i_HE2{w(%kw%h{Is2wbw&X#2Rj@pM6KetUHKen2c^ZaU987;hnIoGm9ivn+Z4!8W0=~L}9^7-r7I<}VQwVy3#O;9d*?trb*qL0N7uzkZbjct9`Z$6W@lo4!A#75z?n3#)w{G-%iH3vb>Qux@da`V*hYS)HM=HwNqnRycglw~d@Hevu*d|_=4Qw-S#n?-<0G5Z?h7^_8Y-`Ru zlsjz>L+rk1c1+fn>{=;So!K#L%a^gEkENOYK<0Y;Zz@H~=CiW2epOj+=Ss7OkJN$|upc z?~@LteNtDBPwHk5rF~LYj!)`&uGiVUC$@LG|D_yzQs?nyyLzT~Bsp^ctf-sq$#^4w zZw{Y)-smN^_ZiE~bJFDcx7&nSHq79V=?RWopWxbLMnnj-MYA^V3@7&ZrW{Nfzjf41 zGI$zh87J_$xPm#xGs)OytxlQ~xjZ(Slh_!Wd(r_^4=0@fl&_yfgsbQ|*_*go0AK}zHV|Y&Db-dwnxH*Y?%3-6uTmj?p*_hi_ z+BwU~Q#32%;k3v5j=$P(J=1?V&56@`eTvfyYKE>&BFlny2sPq6-PAfbl z9!Qz> zR`_00&%}h+dW+!%@Jj>%;bP8$n0T4J5?Jv7@taaYrCAsX4TY=X%JM+}BQ14TwWJE$ z;#yPS)FX{>&k&uD{3D^y1}3%sV@o}V|9m+ z@etARS)VILb{>tfjW_~6G7{e*oiUJGlukILD{BM z8NB)STj2xi((0s4y(w2O7KbX|Z@<^G+!Hw$mB!23ki0GMN>Z&|><-Na20l~NB(>Dy zRiv#6i8r;>pP#>bap_|CYSfALb;S$2k+wT1PU>K0NLL*=y;)?6^euS;Xa7}E^oTOt zH7}XNec_wYs%17#H3g+f${0S2sK%gllPU?f|G*Y8geRAKqm`(v8&Tb>U26wF(XDl% z-ZQb`tEl%XqOJy|k3jtwLl+k>M!J4ziPF(7RMW9Cxk~@MZ*>xNoTuCfIM`9~R8JMl@?+5mca@ZnqN<*Tvl4C-YN#eBn|ihn?c4~C^n zEs>*$YG1jyI{I_V>P2*@A5s0lTCH7}3{5WD!vjmR(T1O#TNSNbL=A`6qz^@(*d9vJ zvDZ+;Yq3kOC)DE`YAaG(6W6Y9T$@JMrsL|Fpb)FGeyKiO6Ok@AN2h;yPtX$^v>VaxIOPdSKLy(?u2`H34s62ec7&$F z6zIMjsq2Emq)M|`9~uw8645Weid3~h!DCs$f-EFkED68zgEu0z;hW26Vhz2hycg+u zS0~r#Px{s-(UIZUIXgOHM>=~_X9#zFV2RL?E>zYSom`=R+P5-^T8^NyBmX8O<;Te9 zam02}`vDoQM^tshil~FJ?$bzj`g0*kogp`M`f&XRRguw1Jt}R8UR)XdsRdNmd>EA; z{&%rJe)Kmo0+f3>bb0ZOaD7}=9TX%f!{VK#Rxr4tl5pSBWKfYT#*B*(29-dAncJ3& zB9>)kltjg5M46*~h-%);D%(@G_E){JQ4al)!{I>=Ph36s9J3sZof})fcqw-75;}1iwOo$b=mh21pr#Ntm2l2%IK9Z} zjZ?lIhN&+M3zxdX*CSoaE~IY&CkG38bEzppMeO%WV4Yxz^kw1eALt?{BE`!yv8KbQ z;xN)5UK6A|$Ile8;cMvFHKe~5l;JINVCg`(JxqVl$NL%G^3kXbm9?&vtainz<1qhi zGw+HP|3tTPVx<_Fdsed^0+|Q@NjCHpjst=tZ!t&kTL>p8gcW>jW5q|D!w09eecqZv zL4DSCD-Cc8r`zXXTn^g}2hB;H?l~8nUYy+SaZk@r+I$X9!Z>{MjN6`@#VLVd`0OPM zN6@>5cMIPwx=XMknwTg3a#rr8SYjHBwCfp%yW4w)_ln*t^cT3BS#dym8ZzmbZ0vW* zJ$k`Neg*Esir*_@Mfo_>sqC>3-!k7buw=eAtRNp2$HQL1X#z&i+iM(WNg8^<-|+Oa&{jmC>DNs^TqTI-%|o}oWiKA~IT+wRL|v*bcn$_QA(ZSdh}n}70s z!C|vmZVfB-zQ$V%*bs9D_JxVn&{7t@NM!P^)0rnXChxO92n(2L+9@#g&{gnkI736ON8W7MLpldVZ zKc}4z+G(EOj@E6{v0TKLkSlMSJaX;bZl_#%+hygN(6idFzOz|5D`91v zRK6(VCrMbL5u%2$5MJ`8fFv%%8m4f!gqMk%q**_Fr+#SO>oY;V;09dDP1&Q(bkpgZ zGG&iFJ|V$im&d4}xKrNea+!P%x7WjXO|zJXcgn^%OtkOzoa2sV=lPjxaNS|DG@nR! zGI=IUkUf}$LHJvB%M%l@eR2#J0bGEj#1niRJdr_R$ElFagI$Ooeoyh8^Dl%i`T!Ub zf>R)|!DbW_B13`+;Rs8lHvdcDmowgmhQ(>E%bKQn~2Tp;syzKJW;Yxf2Dk=K+abLW9eitC#FS=K@ zTo$myF0)AhZCC|@Jm^F~ZnX=6}xB@4{ zEs%0>L;N9zv<_rdB%}QDa6(vBPUM99GbiSk2f`VT+vl8|_sm1u074k90Pfbl+g6uz z2722)aG?f8%+oebLOX6a?GFCh3>hCjabqrK&Vi_WxWt^twGmv5;$jRJ7jSVA7nh*m zs9oA;%r4lxPAi6$Xq#Hq99Xr1#h zj05+`u1zt3`4SMvjvY%amoj&em^Y!h9e&UW zs2Iui2gSh~i#L{UZfcA%OYx%ks#DA;&nKDr~v?U~JEF0AaQT4%ib>~KP zFRJcMR3F}~KlHg+VC?!eAy9U05dtwTpe7-ds)epl*P{KdbIF-dG(|c-?782wu4syh z24gLQanT^ksmByen;L52&Cr{W_lW2}thiqhud)R7f#Z*)y2X~itoyL(ep5o)n$+yu z(9|GJO;jAYwXSK6$>^Acjzg(wO&aSrjE%_H7&kU=7`u_NJ8pa_q<(BH4_}QNn}X`k zC}Zf{`|}%>t*Ekfg^X9W2Tp#fpcd~$2G)xkKh=~ZmDFNKxL{ck)7Kzn?S`@eDI21e z=*UXLj~iofL2Nwo3+2&GCG~yF_if?Q<+7N*7AfmCl#NK)814I^4I=*2KRy#{If|N& zC6s-ehKfk_x}i37J-MS+vu>yhjYK=P#DrM+xr8uOhHTM>Rl$10OZU&Mk%56wv6)w_9 zRtA$uDpg6PhW{%~D&Y?>@K=;n7UC(|Lqng|HV03pl28d{`HBV57`iWCu8uNsW&0j* zxH_)f_i5#UVBf;4p;s5r{mGfA4F%aHnli=?Pk5hvd;Xhc!gS}F`pSB5>*1pdA>56GAB!HOBS*jF1$_jp zm=@z0gztqVfX6TWTD@f1coj}sDJ$E~{KHKx-yH>x8VNp@Xz4hN>@VQwDTAne`v9B9 zdfD|%614mSl9xBF1;2tW@GHNBh6>n>e$scR7wSj(we*v}*E9NI6|9n0e5`zszQ5rc z?H+nk@p}3{ZxJwC^@EJ=_nxih|IKW*Ut+d8JX@B45Q`c%i};X{`3|5MD=x-y@j5Q% zp@45!H*g($6HeiCczsqod{wbp%tGcWCT#b4f;Tg;!+Ycjjf;nv&Nsr`N1?~gFeH~> zI}F4JBspVr01ihgx$O>`#f%H%#PEfIcJqM(gYVA}B=QlJSp{Ec9Nu}C&zp)tDu9>@ zL~i4n2^YIUq(8zqoOs9z$h*%S0ZS@O(eWzS5v#lh9se4B-v0nM5+ESQ8yMJ968k%l zs5;Oe>|E>&cR?mKaQd-KwV((o7W>1UkYG*?BN}%s?fo7pG)CeKUbhK~Tvoet= zI2<^>2UcYaSA|C+HF4Gcz{$;hHTTAs#{&bwAtWk=A?sR^s46hHxIbLKsZ<5}AM1+l zmM@jBSG27bC3MH&D}l~{iyT=7Bm1NEtNRnOzN{qEk%krWek))L`m&Pj!e%5{X?R!~ zvpEu~i97=RkLqB>kWh90p$Con5~_pG3@Rg(`oJ)MNq}1c-T}kICr_R<>v&sbt^Z}!$nkBk0@;DyQDgl0YTyvuQf`7>yT&pPjUD7SqeYf=O(k;Oyk{l+t3UF?JJT$(gNWn^?vg)2|*|nug!D>QFEwCYWOOt}N z1Y~`|+_rQnn99)gDcF$a_v=C)ekz3UH#63o0;0{5G&d(rwaJowNrN$2RQ9?4EJ;#H zsq$U+ZFWnD0r+gTP#!AZ5@P@&XV^Lzl=7gAC@c-vgy-+E%WSN)5gD4c6#p z_sZdui+2B(g2yQ_PQRquQt_agpo+FMd{s-RsVyB}g&7#Ocqz6FghcV~_}k-83o+&C aoPZ!zzm|=Xr2ccLYfN!<)!PbQ=l>0ugH`JQ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_946209.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_946209.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..475ac51f1e218c7a9f929e202fbb6b34f8edb8c5 GIT binary patch literal 8724 zcmeG=TTB~SmQ{Y1A7x`6<`D>x1c*t}#Cbq>Hl#xXAt4EQk)*>i>vmiP*nsVniw-ci zW-6^7HSUoXql|=mXQX!b7u|_8`DFfRX{15DKQ3T)ytNq#?c|3)rn6eDR{OE%mR+_( zY)JQNR=cY$xo(|v?z!ild+zZ)=bpc(X)^&&mH748ANLW&pD?06bOw0IkOZ+vFa$$R z5JTxNIYcVBW>Pz(BMFuaD00-H=`W*Ubph+EG>kFYB6M95y1oeAP?$bqW{e-}o@6E- zGMWez!IUr*O9V;^Xk#iuHy5E>iqNgBiJ=SkjaV7m$NC|2)`s?vi6M)IU@ciKtuv&{ zv4zzekd5zMxz>Hzb5#WfWXq+2!QPu!db{rQ_Q~uWqr@!z3aWc`n+Ws z0oKC%@>75o2_7~t{RA{Bl>0K|lT5!*zXzu#%h|98r#8#ki0V=4_%K04^r!)qVTdDV zz>$Ols9ZRRbdTJ5vUEAUR0!!%{VX6^i3j%J*YDsz7&N1Dh8*vKnYDP%phlUw5!KA< z2?8BdbwU|Ro_vF}IV^?o!FX|_|ds#=(asTgM-*>ln?m=#! z|5E<7bN6yH?%`Q3;{b;P#}@3)8C+Mz#9fQg-sIXxd==QOI9E)&}1#3vKiYs!eX*)O&cMDrRBh8PUFtw znXg;cOdp5WEt5W(jL4c9z-HXCQJI}Nt^)4J+pv&^ANc(H?Y4&@Zafn9`q;KHuW#JT zkA;HAyrXQ8Z;NogHlCdfw*l71g}@Rs9v&PnG7{n@-wMxofDQyd9}0SWA#kO4PPLyn z{tooWbcTJL<$~-4*9R(Mdq@1R2_o??JtW~MTcj80xOcszLA>qg`%=!jsAbbxH)q*2TIU`< z`f2>gnz0E>rAuRXqn=-JpU-?Yv*cLnTdzMQ(Wk@<&uRPoz39EzgLun2eL%dBqRV3A z65S|X*tD0%{E6nT9ZNk+m0z?;PWQ86siI4=cdcAlJ+*e@=4!_`*OGVcOFj1``+f0Z z>U;9PFk9zaqAjt*3Fp`3(y1lm7xl~94f9#C`+3oxO-E_$UShvg(Iz?Co*iEPk#wS8 za`cP+sWNB0I@xqusy;21onEF^-j&|IDV5z6ul%#IGzFvXEY!>yQjUG`kuPs7=@NH8 z3q3PS`_D*@Gs|~YdcJL3>HCdmbyn)UCpqpFV;os|@SDKup|$HbR$HXbo5|ZAsna7l zJadK@V8y%9yRkbtdo7o3(4Bvv&r9_Am5Oz`Z%%_vy(iPA)jmi5V`W3MI}u4Rb7YFP z#~Pz|V)x_jg^{RdPV*Y!1!bP6qExIlVfgaF(xJq|FYM2bN)6{E$GH{fs_t9ws_~op z_AOeaS$0?E#nMr;`t+rNGSZsoCedh-hAQf780ktPS2?9hJK7B=50O7b`i~ zer7QeeP_^o^8uGD!o zd3{Lg9FiPE#cW=X^?$0lD_6_EeXv?7b=^n~ejs&yAUQrr-e;2I57xLya%v{Y%}C?3 zzo&k@o4wyi#y?O$ehDWIA(PWw2YfkvOv+~86h6lxWJF~1q<1=l$Oew(r?{XGhX?hk z1wgqyx5A;dNFXARy+5lga2w5laIZWJ=bNUV_=p6~%)KzjM-amxMbHZd!6;COVl=Z5 zg&1w#C~|N4Tp7~r&6U=I(FrDa&7cj;yz>z($iioDwE=2pKS5y4I0v82!wT^T)?nK! zatf4qUXxRR2jp0IyQ%}DSMNd}kvt9?%2TLdgL2gXTCpj_s>X z>bycp5sqBt!%>98n3YkA$bx$0WRw-*$jK;&{VueYoQ#Sh961?IL>H8ivqxnSj+{NJ zig4ugs4l{h)1!u=gnh^+)bbVTt^}Mal=~t*V^HfPdvF@wnA3=|wU-gCP{-pUi}Ij4 z3?dpT4kT1!vvD2P%SbcGqN(K zU+MprBtgzC!et)V!hbjp|8tZ)|H7=}b}%v(4u?3NWf~ZMCd_iU2C#*Xm#qe9T>7Fp zs*9GjjDHgH`!R2rmG#PSSr5rRB<;o#xCij=b`Do)GuX)!viXj8Vv6nMxDY3sA%o** zJQMzLRyL$_W7&|-+GL%N@qegfXRiI+B@mH2D!G1)UdBxmZkln^f}0!A$QFN)XSoU1 z`ytE7CSNERm>Ts0E$i}i){6V!xJ-}bUw;?fiwYc#0V#aNNn2(D?NWY zdiv)bG4H2mle)^evnf~WuPeW*lyuEve~gN^Bu=I*wwNQD$6>L}cSJjW{!X0y^jzHc z`RHe(zw|G;|1Ox+9Zgz}rb;UzGcBnQFKvSGzG$D=n=)E|p~v~|8f2@L)r+$Wvr%ix zWcx+a;<1HeYo_`v@rA7_S$%Zfb}UI9dttW!!CWJmYvS&8bDh||X{JGs!(w-;!nyeV z!uxUWYZZSDtcKF9sF@3CH6}@mg^So3;|MFJ*Ixeb23>^OvKS)#snED<` zSnXmzcM%-I2XFSl(GdIF%YTD*k?<1&)Tx#DO$r)8E9e-FpwFv?6^Qof(*)@l?Ni;8 zOjaQn_86mqvrj(@(v(qH19$}E4;n@B*<1#nTZmF?YJv%w0@=u)o&&dtU<_yB%NU1Y zX2GnilM=|h*|A-AJfj3-`jS-i%~(OOFlNE}1{N|y?K*9rOn$PbeXKvoK7~iUW*=mM zbKLUO`XpnUUF<|FcKToM5wKd@6V>*`y~6%~X0;`6uv!OREg~R}mkzjeN)X=?fy50! zvvnRr*P(&XImLKgq-PE$0@euE! zI5d&DOXb!O!AZ$0F*+TR9g3A+iOVy<`CE7+{{bcw2_=SXTZ!hAlJ0=mJJ%j-k9TZ3 z%EbN`6g_W=T4KHN_C)(q$8!4y^`6)Tk#hb>^hj(Z9$4yvy2Z$bxl`;eLZz$YjqzKF zgX?rNlyK?}em?x!u-G?uRnpbKlud1tu2H-aYmT=-_!fI#IGl^M3$<(ek1snn9OuRU zO-DIy@^rb9XkKbrZr-4}a+(AZt_ZSbso zgKq!to^*-dz;(49rl$yNnK+;XW0xjfAa~J<8*@Y0b?xb-@~sJ|cuzrrQl?ZWo&0J8 zeqPB$BRo4D=9H{fHVi8ME^8rOP5%wBr^%MI^Ls`^6U^x$9)8R z;EP1$J7~7GBuV~>u>KQ4|BlbT00+Zr!PA`;RITgm)zba=ZXjlZqp@1*fF zL}@v8qV2y*<7bJI(s>j`+ns6r903*2v(dBPy{q6awS>m1%D%288V;qLm4CKeCP_Oa zlAoYQXj_YsO_OcDHd?!_#|S6|e*%me6_g^tajgs1ZIgnU2}k+1A_9EeVEJVD(eO(e e_q-g_5G4H_HAs?We-3L1d&9b|@sUNb$$tU^@6QAP literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_99563.cpython-312.pyc b/src/temp/gen/__pycache__/rotary_transform.py_gen_triton_code_99563.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3291e2c5811737b5dfe1c0e9219d6e0868e9272e GIT binary patch literal 8967 zcmds6eQX=ab>HRZ@?-fe=}1vBWqp~pELoH-S?wh`TPNAp`Qsev8e6A;pt&oF5=F|( zl}_>vaddFOQAq(&$seMCqNbA;N5=<5K?_6y0-XQ6q6NyebP=(2Q6A!d^dA$aKv5KE z-z-1vqpZupDbTLW**EXKnK$$1?aXhU|Bf&8sKWB*BuircyVRHzQ{ivzAJ|L@|m3_{szL zs@nW1icx>4cv$RqLaQMve~g8pSYo!mWV7;SGzajt2k>>Qiq+SaP08%gF{5YIwdInO z2B}`698)E)e2bF39?pSVq+N!=c&5X&=vu>hqVc4ZY>PQX<~+<1v*UB@lA>JzG}e z#pfXe3eT*pL{;7l`LHOD+7Ho`l{FoQXv)i)jYx5Kuu5A|v0S2%0@?3CrQF8()$|$J zbFG?C6LQ{xPa@hJQXpHjnIWT2q`cdINN>vLt8in-PmYYiiL!#Fn7FQ?Qf?AMk4!1groTw4c5*@*1v=x!4 z6>0Cl3QP4iB!jxXTyL+kQp?kA(6g?jjgB&kPm~XFq`-5?y!8dG>aslVjzg@rgH^i9 zs_KB@!o4B~QkQ#gL{y269s>(9@a(IZW7i>;PP{TpCs92oM<;be^d#y+9g;UrqE5*l_?zkmu2e%O3sNcm zdaA}}sE&;Sjof9=noFKV-RS5i>QdB^)w|P6dnm3#rJtVX`w8vg6{}jMQ&Qa}eLaN2YHZ1+7mt|(N;QCxpT1co~lSV4RPqH%q4oJ%5tq3IF0(Mzu#!}2~sOY&!B$PR~^$U#0uf2 zu5x%o-^ZuxKBlG%L(YE-PI|?>FES@8!~6`(75N0=CON##x;3JV4~p6aEF$3}QjgJY#YQW4Hyq?ia_q`kmU94Gx~^?*oBI7q#bDF2n}8Q5YazvB=*7=!g-2I7`sRwJWPv zk}W%WXKeV1-f&k4dqpY(UKTSo)B#9f}uTjA!{&KSDzS6YtyUKiRqLg^}gWp2=$(A+oQ2v z!=*jLWx;UypT7Oi6aO%=W4IB!@JFcTXiMUzKsTrA(srT6Bha3$ceZZ|XGR2iBsP*Y z+7c^*(Um%r9u!(n3&zu1i`y&0z(v7$F?R6_wIyqFB$?Fp&6%`YXzCNdi<%0{yctXf zgr)((HV~(tz`A{B^*f21sqroAE`4T?J}b~?e{K2H@his;{bpQNQn+!D!WTMQ)?!Oo zDC%HCjELYp@VoD}wV4 z!TiRfqo1AroO(NFa{Se|@7>(InYx|s+-dR%CQs~A)?`nP3Z@RA)TWQ3ccUMEkhR*A zm(m)+*(+Fk<5bq#m<+(eHunkEK3G9(-NwNBK=N$bz1_EK{T6V_UIgB6sMC23K{Y&s zuaywVWl;|MohXOhPc#&EVsByZjEGwQVsSr>aBi%v4ulqhu#bo;j^!7*kY9s4DL#$x zmhPs1fRlHFK#G_J5wj&})q4`;QB71E)kXE8o1jy?9aX77Qe*BI>?X2N;(O!)y@0JK z4ZY}*qLvk=aOsyz_eKp;#*Xs%gpnSkVB`$pfHA@aUg03HPOV!~33rt5LrSEK>JVAW z-Xy-&8KOoct0k@7rfQBxMu~BB)Pzh(Kj)~LdyRW3^j3rGmo3Y1kDLj{P(vvG`Xs5#VzPQ`LELp6T$U@~jLkMCfYLPq%tHSj_+mjHFW<)S5WMYW#%L*wWWU16&hps~o(15@1 z%L-*tSB_IKpj$(NTQTyemA95hW-C#ZH!_tX2v!-#i48`hLK>t+x+!_IE~J!bC8qg? zQae;9*(D$|AVU_&rNH=StFu8xMW&p+IY3^{*&HBOmPU7gyqrgZxO=(%+VYBh*&*#~ z%PZ|IhqSLPKdgOa{EvvvJ+V8JvuShMk_@HqY`1QE(jN%T-`r~& z5}Jl~nud4mBk|F!xe<`0-FeTxX%`gru}ksyvN{9CkV*MH&8FtDt|g;r$>>_L=GT(F z>y5Ed=sVf)Sm(+pTp67!tJALytPUjn8-ew}uC6Q9_wnHU!NF5{Q2I4~rd17My-HgfgSl^Os@A%2>pS7eX9=x0BJS(&h#)p1H{-`%GyFQmOwFvs7 zIfD^{Tic$YO)#{j0PLRHIwu&;09M9W-<~zqC6Wt20ynk7w`CZkc~Nzcu#DiC;`SBLD8YS;{C4!jPq` zS<0L>HNd}rZ!#)Xc^9EK#zwifAQt`bD?j+;m2l91K)FE#h$x^L3GtqUx{JimJj&Dx^kKDXSONAk8e^ zBVKs$$tsPfH+s0`1Oijq~e-l7HY-&@;ME^bd=^d5NEB$xr|1BLZe?cv$lN!BJuSe=}RtSD39C&lVAY1uR$H z3hCtEj{$2RJ8J(8T)3csEYu?7bCaTqVV44aR#FsoJmiGDe#kU=d)@MV0@90I2(-B{ zF2a2E4lAto)VHk zOh`Ixxxc^+RpA5&=%R+y)nbs3aJPUjVFS*ODLg3Y@P7eoh$=Yr7Z!O@55NHuS_NcK z8Ek}xhR>bHa0^h`loWT^kkk=C(eS+d>Y&hPZ=uZyh%+}cG zQ;J^Gt?ClP$^KM-dSJVMmpUIi_f%_KJF$8qF$IVFxy|s_)UNg$v7rOFv_099980zC z&>gXnTvO}4$<4{wh4>Xg;Q*8EJ%XYob~(|J?9A!u*zi+x{f2Yh`MCMicKxpTjo4_; zY{Ny>ST3hJ(w*BKyVSYLNM=*Vx5)cFz}23sjN}DwX(U_6KXYV!>@Gc3V?cjWA8gqm zZSucuG|lhQr@v&BBJ|eSRq51qtE8{t>eZ2v5w}^2Th0s7AsP#qxH|}m^hHQ?id2c> zaFQQld1WCGn&$8s$>A?kl+Okr8w;pA!n4Z@ToF8vOQ4mzfi)Dj!r@@yd%#gB>Iz?h zcRCzo9uU%a3SmpVJ|AWlgX}lCRnWu#MG)zRGA}1d^0$QkUkLhN3Fp5NMw9}l01{|A}lr>Liv+YwVN+rT)c-iQN{s@&aS*h5@-ptJz00px#@(2R9+|X^q6N|r}MN_H4v8iyip=egu#+G xOQeM`+49mT@-{-H`)Km+~c?M|{lKwq4PLkH&EyxI?Ysb)XS10-Ne*mkcuq^-p literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_123151.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_123151.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9cae45432d509e1ce176111ebda0a514fdbc447 GIT binary patch literal 2855 zcmb_eO>7fK6rR~#+q?crf>Qzs38b_Xxin4_2dR}J2;`5H5)S0#i@SK&j-B<7cGqo! z*3u$Hf|NrNRaK0H+O(=vT6zEpiDS9v~Poa<>L3>C4es;l!&~u{LjcE$oUx2WN1SDVy#jO#?v8huzI?iAu;;ea& zZEB4IB{JFI-{**yq6B)8xz}8X`vgWr*|vZ01A8vT2%beI?svR<7g0PwAu-_C2N!U} z_kx&D8~*bjULL)exH1(VPulp&iU)o){I`4$y=t7R!$+8G6Wuen}kG)AMZfpyuV zd0bZRvzvwj*64f9)2_i>*{gXqY7S<~K9{9;$B!*(Nb?d4&G!RkpHOC1dl_X}M|6a2 zz!4omV;$U_%i}H)&~VcuXZj>J0}Z1E6rrbO@(~eiFgaeCi}(yml?;DTE~I5Xm&gc) zmzVi`S~M7`zzbxTI&08MCU1BYwIGWTzu}or6jj;q7fNceq*`)5AxdIS%&Uq_+!=Hc zjtX3r3dy;IqVlq8_zJ0%BEmWemVxID8tfIqt6r!5dAa|Sf;^`b`J~uC%O~f+c_H7& zr{Um!MNamsVy@T^T0a0v%!zVdl%mCj1k5B>fzt!9VzgSD5~)%? zsb&g!xdR-+Q$R`yIvu@gL4nY6-%?)%*Zf`jXgz$e;;)8}>lf=Sop;WEIdyw#oxj^v z>mI7M4C#~gKzJoy4II)ZcClOOUjw3z1&3%4@38dj_!1PdF1wy_4xhQ?n~8y(OTzN_3(sz z2o@vQa?co^d?_bNh6f0poHhJFDs>?tWq>P$9ubcj9gE0bi{%EBmNSBdkX#0#?1%+2 z;DRQ<0c*sUG724_sUHZh9a=r~S*QNNRw%p~>Zyi$Dts-}t4}-)OaP6T_LZ?J(^+S^ zI@?lb!}a!_-#L%R+X$jySifWfPr_MFb9fSkn_5E|RNg1wr6CkwV^_Y}g)KXoDeExA z@rZ2Bm9fg%v)WD;r8Ek{O%KD8D=+Orm5#2lB6 zA8k*MZBMq)_SkJQ6!FT($eCVM%qs=i@TH&_sj)%G7bSr=ADO{QnY>vjBAnrqMWrOE zib2f}K%p_~(f}xfwlWAx%vNHc4B5(%LMTn1v)QaC-PZhSC*{cR!=`uPr))u|qh~E` z`sAx1^0zK`FLgiZc>CdrM|}_aHn@>ZZoJBk*SJaj{I4yYmE*Sx`sA~oH}%Pzm#R#c zWpZ@sXa!dW*JmCceAM%xXM;Pt$&FRHu^KnAhZQ(lx%!Rg+t62`4KBLL#j0Ga#+}~V z;z>vJq4+5GAh*GdZE{mpZmPyz+|{JPqF}o|ZDq=kdH*uWt46DptD_S9Ehs@LhMR9n zGRZ{-oz18Qmn`Ig4D&_VOreH%#)=%B$*X3(%o{WUR-h8;f+T#0OkZ}U&lQA{B%YDa z!7fK6rR~#+q?cta7sc#0x6S+d)@mCupTx;1-MHg>o_Q`NxRJf{DCTF7?00Pgb2#dINlhn- zfB3VDBj@9nCSqd=uFjhWWcn>z?U;Ut0dsYv8AsZFN_efdU*Js6WUqpC!DD*rtm=K& z3?*!`Yc+3IVXfdby(TjQD+OPjrFT1yBbmtbQVY{}k8wN{c*EI7Y2FoGAsuixhtoH4Euvt@tYsOaITU3lHTn@zOt_4w&qRHMF%;s@X=+jvGx3x}yrL#%lQQ8H zRg~yk#x!B|RF-&jL)GM{pLk~DIYT3UHE-ndMpe$nWkt@&Swq)oIKt+@Rp5%ECT8Ng zA!-KksZ&$B4Er=$1kVyYr&C6Cy5paVJzuKYjGhw{a?iAwn1Rr0wo6Qc+a6s@^cZp` z*8^G)z=_YuT2@xNb8~T6i5ptVP_yxbD#`KQk3K$evKMwpL-nQjR6d(9QfgLfhJbJh zke-LhLa)Op5L)b7=qlo}zs(w{MD`c`rN|NMd?nm^d-%@8t%+6fZdY>htS{foW@-;$^3TW(mD*7*;uGqttFz(Qc@%*`+?aNg_ZuARFPE52U}y}R0b zci=bf_!b04S`$nOm!+Uk-q;MbK)BAXAQPqQ_hFR|=T>~h8nvt|F3CzR`ti;@7*J?g6tXx7-Zt@LqG-DL;)=8MY!Fye>1#Chgp#dDa zPf@4Lq)m+|L%h{&5iTL6z5t!OFn0Nj)&K@NC12C|&ZoYHy*1^@-{devwrQVTp>uBk~xYh!jFA;oW6jw2D#FNcu zWQBNu(}`)~2WA;_aU}(Wp&W^Fc8Z9%%5lOawUksvNKXnVJ5VqLItC%~6GT$*&lL9L(^D(q40l8}k-ab2;gCEE1}!gg6oy9C+ZcPLlDhTyE4 z+a;7S8OSy}2wQbUX{T1fNgm0&2NIt{#9fiFbx2s3=(Lz4iGK?TCDdCZiS|I@mg!wl zcqa~w>QGbdw(*kYRXYR|+*;@da!%KdfmZ)!U%k@rhTQzciS*csrDotUagb*^cr$jS2g0Bf=*=g_d&VHOQLHhyt;HiZRcrit$5Cti#)T5sRy8@aaHxojeXExr z?tj$rpkrM)y&;U2gwe8ab{8uMv~=YM&yS(+LhC~JhR|OU`pd$J-7TIpcR!RLWgcYK zh0zURq9ja|h4VX_Z1O1BWKG%yCmOWRU!!G}G*nA7fK6rR~1d)I#n4hbYAB}KHATg4%9P^&6}5b{HygafDtU#yLH-|&XK}>)lBkZG zzQIo}oIaPhI2IpHGIiEipvhUb+A%qY2D3dfltXQQMZ8*DFEKi!)0aWJ#OiFFR=w?- zqKI{RwPx)y$VwjFqf^r$DS7KOJ==a9N=LefXz1R%lw+a9Y0fgra4zW*8NW-qgwDBf z^K~9~xP2YhJo1*$T9z@IMG<;jA%}=ylgSF|bi`{?nr!+CNaq>wY(xld?q`SC}@i5%NMmmQM2h>LXxGdl+#p&xHIVjnDS4^`Q&s$ z)dWQ|z4^&WRf2VrEECU|cvdBd>U76H7rH;sE7NL0NJ`yPLUI~h=W|^`3e0w^O0rv% zvW0Hox*?pzw4~%DIa-)afFz+QX)T{iB=e$_h<$kM*zp*wm<{%(#AGp-)YAE!(gF_Q z5g=8C$v`iBfrYO5u4P>DwHc>3!_CXSa`=dGZX?usYw-5i&9N2XPFtnDuN>+#MmPN7 z#dz6&*cf`j`4+tM-bHrVd!u1RT;o14PCc__3;uck;;HK)5HOyrXRe&N7GHj^9DIAF z=T6^m%&{$Si z9Z5l!oy7M8wgg!{5p|Y{C-opqe;epl%k4uq53R)S?Y}3Ndrntchsp;>lmoCB!M1zG zWOK!=B%3UtIyq(f09)E@LQVr@2udQXl_t|;qufR{lTy;69U-|4g4_W+NPk^u%28O` z4?pz`Oc2e+zVOn48wb8>HO@W@hS!4~q8`v~FY@D~EPl7>rxTuN3buFPZD(zAEQjm}Da9!ou0Xy3@D&u0s z_2?q7gf-rg>f7~{(kaL}-49E36>6ux!HBHr*&`CKgUMYZq3a@{UEGt(jzsQ_NC>Rn z7%8+z6mHGlH43&a3f7H+^av#K|0N31_djLETNjDfjfB%Vh{R7K(Gk4Dz<7j7Iag65 zswW-EK#kxApi!`*Hd~NPzb2_#A}Od+qQ~kXFG*e%^6KF2LUWJ`s`4&y^*ej(nV9Rc z`pEJ0*zx2LjgRY9`XU~sli2iVQcle)rgswBk{0WQqEQqD>!q2ToX%NIBf^_rMN*5h zrkd1D540SsL-hdD>oC2*#2h9DOrOK_sRYvGOp z#^|2`rSd^s`${?5{ z4N(0QrYemwevSf9QSf)vUiBb~yV`c8txEB@=V#@S{)_$!0l7sZDqsn#2SYf#czmh< eM*kK9ax02(@Kqp+oBo)>sNrxW)Op2k`TZB2<3@}C literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_473025.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_473025.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05bcc1611191c29e32c18ef90faac4149de7983 GIT binary patch literal 2857 zcmb_eO>7fK6rR~X@A@YR4g?YsND=*Ut0d%ywn7mkkRMVY9N^@OwRqQ#owe7RUAGBZ zON$hVq#Q_8RS6Ojp{fLJI6#F&d#K>vlX28U3zd-I!p$YERQ1%iYwu!cl=RS%X7;@| zGw*#f@6CJrODN<=(CU8vexl4F^xS6L##D)oAAndy5|XfjqLnX>VuzoFj2unKyJ zW|OkTD5@ygtMmzr)>2uE)pV7}5y4`oVmX~yel@S>@_L2N#$-j#$XQ(@wmB@3UQBsAN*hiE+7WLX1y>Q8n8kCP87BM&ez% zoXK?o*9AaglQPN5N@s3529lUgQo5Rr#Z^g;^?lNNsOK=uSoM`fu|z%_*Hdbi>;ii} z=vo#A6TJxUTyiHSN1p{NT z*?dM;EEd8!K4JMGwDjqil7fh_gCxRKT~!Hhi%F7{R6%Ip6ol+R1!JHF894}3JD_WW zFaVhK9pS~@3%kE+Ha~kB3a^FQilMe8u@q`IN1g;mAexxQ`Jp1yyv_@=AI~3IZ`}PS zOS4>=N5Qap#z7qiOSj*ZIJlfFx(OeU4w-7!W?opAmQIaVu#aPEs3^4 z;pOKoQLr^ASPunfaDYPiFDNAc|Ku584U(GbVDNw>U?VXQyuRRggh{$rQ6!qu0jAewdV%S4nZ7scPvt>3R8Q^Y8*&8Td;nd03WJHB)is!-e*}@g{%Y%7 z>*J;m@4x$?<6g(AFt8>J7lq-HFlruuQP;e5;EHOFK5IK@j?SJbGA)%Y`{wp7;iaDC z^Y?c>XuH?8DjZuAhKj;aNf_B;7Z{qq@GW~QbThOnbgl`#MWMGO9Nv1z3gcAum_fJ}EC)8IGb*BP919>RKaPw7)*ty7}(<$8& z;%XLxVJb(Q)M;_&9nU$gaxTvbRD2Xms*3a-a?H3HKBG!`MLtT7gF3t3X(9||8e{w% z1)reMZ>SZ1Vk!P|%cYhw#o_KB$wT8O;}rsOgGN-)5thR|4$mK6>|f|FBOsTDZ=Jq* ddc!6+4j>$Q73{=~zfWLPzrR%1ektHU_y^dGM^XR) literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_502063.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_502063.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0024c0e99ffc6edebc0eb651b6becaf07d5fb9f8 GIT binary patch literal 2910 zcmb_eU2M}<6u#H79Xo$5=_s_c&`k`KP4u5>(kduWS~{R;f~UMhCU)DTP8`|J)*?qT zQYGkkXsc;zkdW3*lg7p#Kth5&!TX*35mV%Rw$FEJ?&g4cD=1?_pmGF`kZs` zIp?#_chAkw(P#+4IHCP8IT}Rh1y$Sux-&K&!ea@^NX9Bk)}AI*cSc|m+ zSHQc(S)9+yAG_`-V=Hi{9_;f`?o5W!$1R7OrO$ot*T`z`$q*B?4nQeTf7nY)&$!5<$Y4DhU;IZ9RdPz+G<2Gd z2$xkYJw@Y%Ym#9~x=DiC__(3KCR#CqrwN`hDCs)9@#lrE&oq6?C`f6gYf?&2L4GZN zM4EuuU51|SGL>AR3$!l4l$uiXyrOm%W>R2DnR>?5@~N~YE2*R1z1<)5!3t@pZAy(7 z^Jz1q<@IJrNY8~)gvmy4VyH1bUz#hGKdCkxu+OZ;Tg%Bx{IGp?E!J}DvO03@=SqsM(l9liQd+3=Enh(we7r1h8v0+(W5#F~?)NS+Ox$we?8!<4j{IxSz z&soG7`E>g;*{mE(NDDsUD zVfqs9blT>y1wDr*aJfH@At*EE@Fc+yzE(BDrsd41ptCszsL}uq2K^=-w!Rx4L>N$| z++|rhcBMTnsj6G~0icSFQA%s3VW1P+v@ld*Th;_| zO^B@t@wKM*-$jn&*CQwrx6eDk(-6xW9#1;qOD&-cD(%qvCZH7GU|;Lmf-Rd(D(f-C z%ZOagm9QzgR%^4NG8O~n4h+DOuSIQjt*|mD^E)63x=h>_Nx+9B;4z)Hb|ee`K;l1T zb@&r#2NYh@-WCPtL&13{=p}(<@xP#uL;q7{f<7cc4~bw2fFw+jSO^YaWGul>xD~7m z>&IQoU|kRf;ilkZbEcq>u&EekDlHjGs>iuRUedb0M8F|9ws1Qr1w%gsCjZV}ztZQW zTz}+ddhBL$g~rEzm;MB=chH!;spJh!C&6*Jm(0Fi=odvWWcR zO@lDgJ#gJPS5yxuy{^&=N}sFrfzt0P{RX8pJ{|Gj9!_RmIlZ~Fo z8tsuc5fo~eZ<}j-()`~2cOM+NcVtBzToq4O#M4!A#6J03tfl+F44sYWHi54z@%Y7I9I=Uiuu8MsXv9Br~ z-QMF#bLV~KLGE5|MI2fcM=RoJRXn?;%esIfP4<}EDHHIXiKtTseILlAq1LiHRrsS& zgqs-(P?vPtl1Lz%F^QPg^1zAdg6?!u!e6QtFp$ZcPT8CfNy4>yr6x30{vJ6g-G-mj zb z9Q&*NAsk;gx-_sju#P~yebIS5JpJh0Pv&8W2&YP=4la@ac4H!&Fl{|v#pT{ zOhAayheS$A1i`cvYHJ@91fehWy-(|gCR|ex>Wgoyu~7Qdb7yxZ(>5f1=!Ly=&b{}X z^Ua*^p0htkqag(2=C$v~FNg^JMip;>?vC{z;jw^Jq+%T<>rb4-uFmAyB!`iPbFQC; zMdP6|8kdXweT`}vMrE&a4;mXuQAVnyavI9D{QE5UkX3&3y^Np=*ZE|~MUM;LPG9)+%=y%%;pAYN3q;~UX5qKqf$4i%FxN+p@uZEXiXSxgr#PFl z*=t}u#oK(qD(JgmsA8La(C~H*)}{nou$c*1nGyq*!bTiVvXL!N3tRkw@jOgPmbZ;^ zk}vu~F6@iGU`sx3IS_G^?AW*wQFdd}F9QRkNfc*C4EhprOt`#aPQ*pRSUL%njN+J~ zwK~j-s>b@N})M%R-ES-#oNtG>aujmIkgsw%pI6LFn4J7(7bZ1 z$2qgc2{Y%W&)rBa9;ig$U+TX<`U}^$4uO#ogNX=+W|fUXI!ICq&ZV*g!VI^7Y?Slq z3`>pSU5}fb)krTwxlL$mBcC+P4K9OKd|jGiEw3;jn{5M+Kv8CvvoR2~;RD0UB}$ic zFgQ&eMZSI%rZ4e!H*FT%P;qDgkM|i2PMHyxBXNcZ^`sFlt!6(1ox3MK_Q4c-x)Sa5qeUIOy)qdwcHeSxP>9_n}@n ze5Tr#sO&y#?1IfWu7~FcUnu7_o$$cx^f(Ctx2(yOo(0ZOn#6fGOLaaAb+Qv~%*d*B zg!E%j=ANvB3P7N|mZw3Jfw*9FQJ;dP70jYR#0+#LtG^E_Mp;$dYepnJ zTW~u@TqdHSnPuHF33K%@bR4%q9R{V(Q~E&Z_mqB620UfJq?D!?+8y;v-O_k#zvvhr z!=XO-nQJgP=tWbDGxRosLd|m>vmMWN9(?@4lkP{|%ks$;d9Wf6R^=h*^vkBU#eKJm z&d`g_{m#&hixsZDZt~vjdy9CnZ)x=LjwhXuI+x`WD{`VDC#v$oBaK@5%7OQcN8L0S zdq5`5^`hOQ!{3B5v}Q;^Lo#SdB5W>e5jkBf04J`N47ZjNVYHq=He0aVw7DOWgzNcA zjTLqEJLHD+N`AhmmUZp8kpQ2x2e}=)#$t@0qsVg<{S|f81jI-)?bGcwCW`w3_dBtL z$nD4*8MJj4F_HTEezdFoH6O*Tt72rvp0;Z&$S+!AHM#~qgz(Je>C5XPXzK?Nj=qWX T;MnVNjGFgWo4TgMF08)*-=$T9 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_557502.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_557502.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1246743312d533b0cec51a79603837900c08ee40 GIT binary patch literal 2847 zcmb_dO>7fK6rNe{+Ft)8fyHT)I{dfDNF*e|q?ABX2n0e22RQk%HQtS5$M(9r>o!4a zX{nF^^#Bo-?A}6!L@1RCQY3mS_g*lL8f%(LNc6%j5PIpUZ`Qkxp;73eBYXC}H}7xe zee>qmV9B-a)8zP%9hNM4DRQM&Elni99Ze1etSSGXA8$jNgB#e58seGRYv864qX zP?zJxckHus$Ipo8PsdIq*rq6Shc5W9hB4it4s(5^8HZZ?O8CCLUSti{a9slHB4=<- zR?WU`LkSzM`?j}BAS=2Jx4}$-r08j~bg#v6C<7U8YGHVOWE>AgukI|PG}P}@r2wn8Ba_>d?mLp zo&?`Lnwsd*y1@_xHgHX{{@X6NOwt zPboRI9RkuiX$6=}^tuIYX)hg_Jui>- zRfI$4XiexW$11`eb7WZvl?G-9%F(%j?}i?1SriVMpW3q0;n~CGFK!(*PcF0WtEVoX zx)v+%sRVb-9elw4#zt2mFj57Wh`WA@iDsbEu-*U>351~x;w(+E&{U6?xX%3%W+v4o2*op-#g*A@(B?Mf6`J?Q8fYzL44drUO| z^!9baAURJV_`-zs2&+;HwS`WtbfOyoKn7D9lrX2Kk_Dxp5iVEA${E4|_!5)E2i4bS z#7t^RChmH@37e2oUs_d-aMq!!@4w-b_;R8Ra809UV1Z2y0Y&_YcqU`RNZkn%D%B=o zGEvPRxG{8nXwe^j!iJyt!!=qm zPzn{_(b1C7(Z9z{9c7*Vwbpymo<^hw(EtiwE4i`OCmARtl3atJ-_&{5dns6nlicsf z#Nz;B!%UjM2QX{y*huoel}YpEvGCg8cz z`}^QB7o@m#6NxvI%31F~geM;Ru^`NJFFYmI%hL-?pTqP4)9*0-z#MRx0~+mOjR&TF z@7wHMuU-C+L1{nywPl!0^sJ@L9DN-?zSbMz>*1&Edlx=@wC~}*MSggRKT+XNRQXZ! z*q<#O<(=Os=IFDo-R9`EvlX_pZnE|I)-o>l&0SasJ?eVcwa5=G@go&}q{^S%zzPE0 zym*JZ8@wG{vLe;b+dG(5q(VkMst$u5z0KN1m{wdE@S_FFmp1BypSg8%>k literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_560359.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_560359.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84fc193e35dffef750753163edcd540c8e7c805b GIT binary patch literal 2851 zcmb_dU1%Fe5Z;yUN2edlKXqI?a+8{rq7>VSe?kkUcJt>rX`F|o1Qh5Y>&~)dNhj`3 ztsPEqDS^bKV4FTzg+km?XwnqgfFXTr-un{8rWRWl0_jWMoH%{yQ)f?i=SymwK6H`x zW@dMGW_P}s`85&=A{adr-%s7-5qeA+XRtQ0wF1Nvl8}URC{h371h#lz-k)GFl5y5L z$5E|+C7;Y>!!ODxH|dl7SD8EZLV}k95+kE*+P!8zMGRg*A z(j~GXmvjju;QB3iEbb5&1KSn_%O-sp7#Pi>xPM%wmxyD+=VGlWxRt(4O>3b_65`1@S%r;0kQ6_Y8scPg2h z2K!2(JDG;Fdo?xHtIPRfFL1r!lQ=D_1v%GKoE1SL>S{(;3Svr;WO4Awq5i@5V1-2M zo5aadA*E-Og1QF`(l}`)m`wD%0fi#--E-X)yc+B` z>Bh(A$qk0PcIwKh>xs(S)ySL6M}A^{V+OatFj8A!B3%6yQ(mA_=PWos1Y!vpDC+_m z1om(eMDv|SpJCWzhuK-DUuRF*XJE5ccM1W@?g&!{UIlF6GL~@JuM5s;JCy~)x1ADY zMvpjq4Q3Lz0|UUvbPvGZPACkL^AvJcYQO3Hj<#Wl`a&#J!u!kHvk5N(KS8hr%|ZfYw~ zB$P_#ayF3EE|5^EmW9bgwNQ9|cy4$t6nntL9)x1Grsk!##kMcn&C|652bae`|7bnZ zS&ej7l53IuKXc~ETGMOh>5ZngrM-)LD+vHi(*W>H%R*w|Y?bM%HAHJI?X{Lp`28iY zEVmg(;b!x!+6LxQ@OOs`fgS+H63U?RE2>%=s`@wBQ&o1TVcYd53q9iH;erKkhuzq& zl|Be1lKexU=BYkAbqhvfCGHiT@D3z)^~3`}fSR{zBPsBbC*B0Iua-uy@WQRfyL!QT z-h*|$p!FX~!hi8X3jR-#;XR)4t|tK_0G@=XCkBG_3XjK`v=d#M*c^5wBQ_BjfhNF8 z`fO1qAzjvVF_qM0@vzk&p3t~v5MU_oqlI5CYU%+Hd3W}CnSR&g<}Z$=`;H}tXuR)L z8HjUgCq0wX<$|WDgr9_#qxbhgI+vuRm5U^h%M`3$5ElqOO~o>2sJq zVEP@VAD97$8PI4YbDK;(^Be73Z>Rseu;~!|wGEg|^r)fH9DN=}!RUN!F7|NG!Iih} zb>Hn?6GqmB6IJ2FsxWFE{j;IH(se^IM;~<_Fh{RLns?My-k5u%f-8N?7gk#Db>8h< z6NcA?<5l7Ks&H}_Eikk&@h$sZJ()S~wdYHttl9O&D%ToQl14DU5D&^$&)e+c7vss&fY15A}evHDu zqsSj9_6YOWI<9nV`oj1KINQ{+6kZHJ6M)=wTkAr-bna@XlEO1(rD^BN?kGQCV>UJ3N)86AVTQ&Ny~l z4R?_#g~>$zDI=UrZ!RBZWM+!F<0>S0S(I4?WokCHAj;h4d&8W}PqB%RAFnWl5@8A{ zVc&9O0>?woiP<4LluIhAl2vksZikM)fA+-b@`(xcifg$(Zyd(4J&XHlP)>?JCOHr0=-aE`e;4GdYtQ1IZ*G(BwA!c$A4uj%b+t zx0GjLQZT%2lo5Q=Co*B5^a)e&;}!!RH_4xgU5}#UliUmpj3!W=9@5B1#Iem}C4DT; z+mxZ&p@Nnl(WGoLE!&)=Nx2clX4JeSlT*g1P3!5L&FV&8Q{o|;9Zwbv%?{;@MxkgB zS|ctra%(nS02|?pYCbiV)D1~9Y(76ctSg4Dk-M?+xQ!=t0>);4{B^eNT08*pt9M{-Qja*Enom`H*oX=rI9pG&6lvXXq~aMz)? zyJ5$!t{h4Z7jr2iozH1?;E=o*U5Cj+&q9%zmg$yxTn;r@Cst!S=R>8~KI`;qRsGH5 zw+3$vE=ac<%8gy6sxE6_H5{8wl)`(hzQ;mnhM(qV*?InY^@6-C9JG$PvYGI7c=p(} zDiAQ-l~b2bT}{ltQHs8{(0;q?SLUsCaEr8Bm~2i{jG~rH1))i<#sRPe#1&=^nJD9f z7)0*y9>^z*;U*QK%qCzpkq;H>Dl?2_JS9xhhF366K`+Q1+=M?jJa zIw!A-M819qrqA$p?`s;HP>F~G5ADMkj8a1mK;o3mRdQxCDLMTS@Jv<#ppXyJ$(=M? z+D=$(p+}Xn7iH-$Q(`?X!L8OIg(jLqhXSw*#3Kyzx;4gs``iKLna#1N3gStm;sFmn}%+sufTmMamG z%OJq*tpphkgr@C>y%_xT6EHzEtFOf7c3t1~X}$H{<7jLp+Ej`*%}eEIv(^77+z)VK zYG?aOO#P}Lt_oGFLTt6R=?{@*xwQz2#H=$;;3=@>4<9G3a8q+AjV8A#b|a9AFL0pP zY=V}JvXt==;(J6M<|eTrdQx+vgi85*ND^Qd+Yc`C`%G@@Q@E#JkjD!wELi(?BW4d$#Y4V(%QF+oe?pOOIM|&3z z9fY6$7$yrnsj9ICevhC~^-SY*dL_O6I0OXA70IA9%rR#iX0?^@m( zc+#}r8n}9<#57b?UY&k*9?y3yoWHkoscErkSv_c7C4%{_$B*Q^z-Pl z*t#Nimc-7o*uAyJ!@Aac%2IYQyDavth=V0@uq>Y5RAo&-ky>lWt(4CZl>d0tNrScr zzS*v>l6XONVsiZb{zLG#Xoim)YITvY%EBQ)} z1=u?Cj+vl9+yYBlv-}kH&$@&3=qc0<^ Txc0>;M%8=ERn3>fPEdaXe|lHM literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_834634.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_834634.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbd4ee0fcecef55a832aa7d10aa52bc6ae580d6d GIT binary patch literal 2862 zcmb_eO>7fK6rR~1d)I#nP6;F=kkZoDt&)(~RH%v|kdTm4!U0acxQln~*jfK*cikpv zEiF3^PD3M-Z?zA@IUePNu63Vv!`$>qvrx=l4VB z>a5Gl9lPZyVx7LzigpETWsmOBsadd;y)IABRvuf@k?tWLy7vdlj!@<_dmCjrM|6a& z-w_=_=N#O;E8{jb&~YmxZ{;Lk1_nm+C_+yw46-rvMq&4MyLXxGNl-E>+q%-Lv zgz{gO3(47pstJl_dJCzPD#1P}mWk&~Jg*X7^*Q6u3xl5)lv%YXB&ER_Avp`F3;BK_ z4PggWB{`@`x#A#bg8(NnD=B$N9w^Qyz>?6Ej8@1ek_Ayp9FNA{kH%ofY->K0NR{$Q zEmO!VosbZY0#Zvb8R%6A`GZUSi~SW`^K~0z_3++`uNppVoU4brZk_&e>gLp%aJ#$K z6RU<|#$??eUXEA&hm7$joNvjy=v`(j-j%jBah-e5IMuQ(`4|1mr*4G6z<92my?XX~ zyz+K6`1;z=?bt8O(M?E7K$6l$y4v5ur>^vsxyT6f>WI&g3MZNEkmENQv z%XY;_fK-C4o*1yVyNTm|nEo)(tl&}_pHTL}W&}6WGbWoa z!2j^^V@( zd6xAw0w@qR&Rf8f5X%`3PX^#it)dJn@6zqkP>QdytJ`eDmaV>&br|AgL^kKjSmSN0 zzSTo1oq}@HBe3P_O55ES(CbCdE=asK6L&;HyO7Wh(@8f+BKHp@?o-x+pGLc&aQgI) zC|DN?)GQpDN>ZA@Mp$IGqC|eu6|tZ~z0-5hiU{uq8B3+LqClz>UJ? zz>3;@Q8N9Sq-u$zph}4$>)v=y>iQf3hv3-4g&-AG2h%KUc{SgCDQ*<28O_2QNspeEA#p+u&Ehb$(!jk5>6;jX$=t$D_`H`_hBl zz1%uKzQIpb`KcOzZd;cIhXNhOwB0G+f%l9AtuiR2y_#*!mOUWDe}fX-$WVZmq>z?m z(%Fn=^2tITI5AgLtWIirE;S2CXY!g=HtR#0{${-r>4GeNhpd!#!_O7Ok}RE2NdF`^ zkoqx94H{$o3dj?U;uZ#n8B#+P%YGV)o&sA3;N_qR{#J2 literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_931009.cpython-312.pyc b/src/temp/gen/__pycache__/sin_kernel.py_gen_triton_code_931009.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b311f64ebf02ebafb9a4264c82cb91f1cb1639a GIT binary patch literal 2862 zcmb_eO>7fK6rR~1d)I#nP6;F=kkZo1t@0nFB1I5LNJuH+04HDE#k+Rwtbeq-ZWFXt zH&P@>IV4e4#Ym{6Ri)C>14u|5%e^P#sEHOTA<+vrL#R~s)Hl1{&9;q_9y*d|-g~p} zy>It@^Tt1i!vO^2tN9oG8v<8^YEa2&+g$B9>9yUU3{-I+dg23`P>pT5%j# zdle~($%g)JBU*|Q=>_I)VpNZ8Vtek)<4Idlm8B#(tU68J)fg-ep#2 zJzn0}4M!2{^xa0Zt6(epbe~SmfTisBc>1>UIFgQZAMw!rKT=MFGN(D)D9gE`D`bPN z=n6XL;^sXWcc_7m8yR^kC;2ikFq%U#dP*TL5yK{v6V#cQ-=s9z3>1|@S`l)IjA;4< zMaZWmlaUL8NS@NBO};Z_DQ2Kh(uyU`mh%ZomU2>FQx%fVq>B(L zcttKGXA-I=D4OXnq*AH``=nSVo;C5DN_f@fjz2HM-Cqe*qeqSNn~}C#XKqj4oLm#`wAVU@ zs*xdMVlxVW3rtBbHwkYn1{)w;V^)!lvM!6k)ELgG_?*!iRYfSf16y_E za)r9iq_Bt=xH7GgSGR^vx4{a`q!t++11}pfFjy`TmL$E&NyP}_tq)-O67O}u7O@WP zgd}k2p2856nzE=7qfDP&Et5%#nNL7xauQIa2^>`NO}Y$aKP)CUEDO0SqHywRXHt-5 zr{cpvDnV9H_Bh+^#PJ|Ze;DXi>+M4~53R-Ty>U;j_K()u#;OO$l>@LD!*+VcWb>t* zB%3VIIXP_xfK=LCLe2nJ2t8t~RUMngKAYnvlU6dKjgWi{LfH`;WY7amc?=Ois?^t&1YYn|?)7OR}b#)NDUoGuExt z56Xa}41hA|D1)F3Im(bqC{2#DHS8nZ+<0rZDjtRi@o`IlOqdf-3`S zmmlnZ*m=Kmojz?m z(%Fn=^2tITI5AsPtWIkBF53m9GkMJ_oAn{hpk1#-x*&_+BP*rT@N)&RBugh1(m%-! zq<#!joyHhHN1-Pu{446H`w+!lYoBkgQz6_Bxc9ZJhE_tac+j?JM1?H-%fkU2T|T}# kyfR!zAl|vOMV7615Dvcz_28D5(-<`!sYSZxgBF6n0L`OGiU0rr literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_108037.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_108037.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cacad4998643b2e16763063249fa2aba5849eb5e GIT binary patch literal 5497 zcmbUlYfM|$`Cj|lzJBt?yaOQ+-0~<*R??Lv&_YO>KugonhIREWzSqRy2f6o>6z^EN zsZzymCt=o1IMt*&Ql+{LQfdDzlP2wKO zQY68VGpZb&RqI&;N6k2P!tjhWJ_JrZ1a2x0_tUKTDs?Njz@&w>aKub$k@e`}a6iRb zuTqoNyk@qm#H7tYaJIZm`vU2)Z(xyQisR|C&pmU-d+zk)uy=IFY$xuFVm*D%J3KKw z<~`#p+zt(5jPTp)H(`T{B$kW68H3grwk)|~xRs-bl@`!ag;d%Y8Ij{wPDZS3A6g1I z^R(rt43$S=H?ay-iYlTYBUXtjQJD&>5-s{#Kxaf%$bsQeaCKfnH9$*Ib<_-SAdb0J zIWPklmR5t@8dt%tupM@NEV-?LDyJ6JASYgJU%tFHDqr2c^sfBg>k)x!!TH1+2A1#} zYGF^vjMqf<7>eOJs72Kq@ah^3h}0+w_-K`{ZWo_*2VY~s-io9Pav-PQAm%MoEOl0l z`Q0~(B|0ff7Dz4@A10L7TIU}ke)tji=Kv}Z?C1b;srz?f{g7$4jv)9)3uG2Xac)u6 zkdMyheJC6Alr7=2$fU1^<0y7;AKr#NzBo7wam7+Ardv3d;EV0M{vL1bmC=Hl3U*Si zNiJ+rGiuRo7lorT)`S|63u%#o=MAV4=4u?M_{N`nOcAJQrm9eiHb#b->b=mJn!T{R zk}upim)3m{m*>@dC$tNFFDg$E6s= zpkkZjBhx$+@&;Ii^09$SiZK)jD`tjg!qc2$igJ8JP%Qpngb|NCu9!411yfcrvJnx4 zL=+0Ir`Wh47vjRAS6om`f*9dBPlsY+ymKP2n5MMOr*`D1LXIgUTn>WB2UyO_gtmvX z`QcQ-Fr0__wrPD~9>+b8)0~I1is0d}B^2WvC}R!JhrBTG0=Yg6E2a7{5cW>I#aT@%iRM z^Faozp^&o*$tdKMVxRT$T!;yT1L0|f^yS2WUYa}~$W!eEa(rIwIiUm8YW{}bS19OL79E-=1XM!XOScQMo8njV4o^@v<(t_Nm4V0$qVh80BEV)xvF7l7Uz&=&!> z^ged%u_JweQJfl)cb4PBT#!Ej`@#32unM6k5nomkcIRSOqAN9i*LFa9=3Z&}vSrDV zB2)eAr5)1g-Se!wIsK&U?p&iPvnAzg)$MXgyL2jBUY9yAm$yr2?p0JL7t$^1+10ah z`+$6KV7+2cI-7O4Q%BMba($QV=$3}F&eFAtR@vE_ambFN((rv(MKUP64!-xo`{!?- z&s<#XTyHxeyG}?W*@t58IZGDLC(b7?raIT1hozJEo#n~tcbIf#YFe&t&(zA*kFQSL z9{OF&?U9eYzgdvK^SoUByzCs8PG*79|IT=tO1-oc$ynv;-qjI}P(C^Ylt=TF56UZ- zJC-_9XEU8^^?kqMe;xg0^ltgHpe2aty*{-(y)>O-(hcj?-Jdyo;^g1#m0-BDa!p63 zU#{u9-Fv5|UpJun?UMLl!6Y$wzpfsbJLRTsxvo2YCXY)`WQIOz$&B3euKrkVd-~JX zyQO0Vtm%ml2Q%Vsbe68T&M#kDx&$K2-bOOr8nJEjTqW)2RxY{{uH^i3bSe5*cPIFEo$itbzl~3@Wna=PD_6CFX{&l==U(eNp9d$ns(#9( zDwET)t0`RuzE5Ayh@XsSF3X4eZnwyXpOT&Z(#d<2`P!)~r{WXI@zlg;)S)d1T!McB zH+|6f@ad_4l`v4t(iMPSvtF|;6Ns46u?0O#swEOx6sML?vyilciiZ-pNLiDxCTxBK$j2w(t`DF^%FGo?zeZxl;uK=uGR0mr zSU=vzn$-PKF%xT!nORHB5`KJNX+_>aIcf9`X;zSCFDUJ7dI|748MnqkU)Fw=j?rS7 zt}voiuMkj_8nH#-;wEM)bnJG|=ARP}!ISz*;FGWCFOJ{lH}1n<9KX#^ArtFB2G)U} z6|4c<30aVN#d#}Nn2`iBo`dUrYy-cL)LJ0wJ{|T&LUW85muYyhrhArJ`*Pzsn? z?ozDNAo*`&9aCmLt$O5PY9Y2BrIB=q6-wzA(b8``1^cu3D>dA*3h{}Bla3ZUwlunN;EI$cK8^^W_0IZvKdGRjb z^cWRtboj+Vz7ZyT6Ep(WuvwD$JzmdxyaAQA*82pH zJ+Ig%nBY7&%<~amy^j^MR!n*9+h-ep05d;{8?3xmC_RTqW6&tnG*n*&2C@S%Uu6so zXofOt=HXugXnB1eCi}r*aj$|@ys*dsL8WgaeYT?U-)@J=yje#$oYE;ZEPPVkLGwOlDp9eU+H0D{xV^G6d4RFOSa!`Bw7y+tI55EX* z=7XZ31+5z0IJFBR9tXRCYtvjHYeB4u&qea9fW@yPfrUmQ?z>8*k^8RtDNaV@Q6Y_~uJ<@p zt@vkQQKdvXsk(#k%9@AlS7?32VqR&RJp^I$W zbPyF)%biP|n@(-k)B92IgP=Y;LynSv?)+m8>mo|a7f}LzQ3AAEM#A9GWc-M1CY^C^ pF_;K$5-?27zl(05Ej-+Mfh0)h7miA@;$KIJlDfA-SM2JJ{|C=+uYCXj literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_12912.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_12912.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1baffb45df85819a35721f9d18bfb97c97bfa91 GIT binary patch literal 5241 zcmbUkYi!%b`AAVOiEkmfzUkS5J!a3s4)_wgbtGt z%P2)4o|rVbO=%dNh@&}@rzS!3#>QE;z*Ae`t()V64$gMXdb_B8)Xp)Sokx>Z8`ndK zb8LZkZjKMyxQc7mQD@0+*EKX+K_b4Qw3=&%@VOu1rG_>7h0|wW9P+<%YV^4O%sOa2 zZAhaBUq1Wlxguwve5Vf3WP#u6Un%Ty3H)xfNz$CHjT~{iNKaOkaf0N=CR=3sc2Psp zvkfmLGbP??xl*o9SP@Fr$W?Mp0(2y6htKh?YgpW#oX-2b(QcU1|`!p--pOys8 zI%W(3eIU+g#IQyTX--iRLLBdBBWu&C(sZnB8ZBW1YrKJ|fOB8K$4-D-0ts;WA#s-Cyc*|S9U@KV)&<a6Uaa4%R_;|!7OI=F7u9N?GW4LP zJ~NYhPW5&#&`UMXE2j&zH~UnsMQ!d<>2BqCfvH-o=}?)DJgw4)l;g|pnoL-A@A}~7 z4=>!jkiWFhxwPk~>OQIr7Ph23U@B)Xq%UMHWjmLceaeYtraCkJ9-FJpj;r;1^Nni# z!G+iE4*a40?%+rMPiEAo&#Cq2RAxjuQ2@^1dm}k2`<=O1-l^92E({us>Y)MPJXGR* zSY126e{O&FbiQ-3x$n2a?-IXB+^>EK3VYe~oz`O3cpR(E7%(&`q z&3PcGxhr|;S0ni=>b}0a?drbgROYa9;sIs5e)8(c)N7fM>}#J=yH_C@k?=I`p21xo zbb;bS2vdYS@Y4&u??c93w_Uf-BY0wT()x}g^x}iql%^LZQ~M$D?JvP&YnG)yB0yL9 zYDJYMUFW7u2)K!Dxl0AMu6z9!b?744@s=%V0bN^Alhj*wC`75!er;ttAeC+jUy*v7 z#P{TembRtsK@#lRq(lTP6dW7d9lQAeCCQX61=5`3nj>kI;ORG-Y}I{>plBnpe%KwF zER{|s7nCui`KWI}FBRfo{gdT4r?1JA+sJQDFUfOG_^Y_%x?S7_fn@R(+ejqz=PSX2s)$7AslZEQ1I z?a-Q*wn#{XyE!lc&$^A{FNXrWfFl+teaQfT!&?Kt)i88FBe(m`Nkcjt$H1)S%IXhX zRBUKzNE1#!8YVxkb%F#q!ZN(}B;y?#kBfja!e`N_Gsn;O3oQT$JD?Ns5o%Ku;V_Jv z<8YXbjB)IdL__g`IvkD#*syq{+oZgJg}R~pw}O6N|9Bo*4HrotK zg|k4x_n7E`F#bnD1s6Trm+s5-rk~Av-uK?{<~%o={^sgfq>e1^IkE)3s{`{KO&`7Q zYFVV7U2J`J33^w{=GbLJ-_-%7U%6DE8T{&8J+8cvB7i?bzSXx#9ayYAutXh5^)1^S z%FF}WGdq+X%6hW>SzGSXhi~3|bK%m*S3bIOpFWx*9=d|cz(dND*{xFEg2S2WO|fgk ztRwBnc;0qpM&@6idp*m(^If2QRb-vFp#T5NJiN-C-wLRz&)jgZ!(-bs|D!h zQ3xd=fiD6kNhHan<;m%bZ=-b|dz*Z=apWCK>2{;G(MQUb|4W~>?6aPdt>z(0+GHCX zEDE0n8Np$)4*M*R5M;z71yI5WbWo(nIo3xSd5AMob4fgutpQeq^36vHFJT<6(>bVA z^}4VfItNE1zk(je0~U3 MHoYCW>eBcBKe?8G)&Kwi literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_186313.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_186313.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fa27342bd3db1391aa667814dbbb2e9f34c7a0 GIT binary patch literal 5366 zcmbUlTTC3+b!K+fs$Bi9>o|+H zHy{Gl!RZsPQ5+GXYQY|4!fT=ijK%l@){L=Ozq&>vA~lW@Ig2V+zeUcvNv^5HwlW!p z7~}|1V$mz*vh6-qw0^Hxp|f&iiRN;VVMN7U>+%D{kADW|vu~RR3=?o}+dN>z*jb3g zMDw=2t2DHA?{raOe~EQv@)k9sX3lV*zK3%HTB2na{wB146aN9Uw?x9NX({bd3p$|h zdlqiS#9mm_iL_J2a3g90SUp3P-t^Nsnm~KUYf80fb3}}LwqwWlY{wPzw)I=)z`<{s z18wNwmZ;Xwv~5dtZk$j%qL6K?lpedtjm{aHg7rvnw zqT!fgVtF<;7F3M0K|U@h=1?Tgiie+2j2fB-kW&m?Tm&U?-U|8@TQCxg24kXMoKcJd zV1vFxijnnCio9YR)rNpNkmnR~NFmQCRzc*$T+q)(H>MNC>1fF`R>TH2cmpvWhc=Hh zlZT6h;NgBGczmrYhDq3vH8vIX!^96V`ZT6g>eFz{KgP$WCKZc1U>Dd!VO|gmGK_{6 z!m~lpVAK?;LrgKpLm>ehu1!Y4Wh!O@T!%L)C?<8!f?{U_fvMV{Jg<1v4f6~WG~$r(Q&{RLzo4hPip^izj>4g*DTXiWZzARh}x z_&%@*J6-q{#GXWaRY%wzbDgQq483e?l%9KB>0U6;n=@qQ*m7mNbb2eC_qJrivbTGQ zuFhAumpljMii6Uryt^)QRd)NNGmoo0>6z@avbSrIS+4rNbT-d;Gl#Pca($=FbV(=j zj>@H~7TM8~V`S!tbYjI>m5#{H{qMf;-j#Q*=>3#=7BSGcQ{LD zUYw8Ttg@$PaZqEFj|>3kks{~k?&^j1`S#4&T*p%V(cklbnEmbSBlq*LK~T~2*6706 z{8)y~HY|I(K6P{_$-KMvZgb{RrcHM5&koD(gSn>`5B!c?Ya?JNxS>H9u!F)#)+W*_3s`8D^(*;%|m?({kI<2hDQZv$Eru=JG39?+32jK+Y>S z_sBIpvZH75^vC9h^k3PBX8HI9>~_Zm>EvVDbozK!rUZW!cTeN) zd(Z_+A9egFspTUFv3%Ql+qOW!qXfvUjU&~PhAc}{%bkgKND1@vP{ivc?@!3`x~{@n zuUiu|Dyt1s6ZGpgL`hv@;Tb20@lvY`#-*$&TZn=+a5ipffGtxRmJ|k5(wYQm&Unj` zFqY>Q!4j>N`-s8&QY$+CFIv0EzQF;lm1D!0=$z9~j(>LU)UF6Hti}IX_Ga(bla1=5l z%YDbaDs2}efvAh{1P`s^ZxHp}p7oxN1>(_3Rt%3ug8R5Iyrn`o9uwgCHgUll;XqJ* zX9LBr8j!??ja9Bkp!)^2+c!@Nf=~-yy=p6|Gl~ZcstPRy0mZ2St@GmwB}9VhhM|qO z7h9;k7Q7QAeFlX-cj8h%-v9vL2%UgKcb_Ejy)Y`4V-Yqw%CWt(dkQbku}D0?Mugrj zo$@Re>Vody67fsV+x2(qv%@($_p;o6Y`OOPe>VNa{zv=6{)ghyrBV5OVEH&F*K+cn z5Puv*R@H7Fp5Z(R6R@#WY{067JDdqV^M>(T3kS7Om}dD&L9yXNK@j1=6nqw_X|A!8 zP!Yj19=N6$;q4S>_~$^IR{m%uQqNQBlc<=q%Eo8ka^Wa8HXw#?z#+$Dd=m&O^ca*t znh`@-*?HV;c+L)B@x?!WT=O8*1^A`rUj*Wl@Gmq%Q2a+C@^)tKXzFOXC-qFm^|tqp zH|x4n_ldn_iSAu$?Olf6-csPXIUw~**Ea)Zi!}3?am}4coyoW|{TWmC`g^au^UC7& zkETDIe#G>_+hGq$11o0B?dhA-Z?!IT&Ub!d-kWc1%DV4fp67G!B$KbMNm}!D^_jrk zw)x@g^^|SJTc@LIJW1QMc1 z8gCx|HRo$Qd51IauFqFB{;SquG_5rdj6*u5Mn(W!RcyFk#8I)Hhi(DFBmt>$6)*`h zK_v`dpBVUt+Vn|T=0@hI`v$#0Ptd#QqmbeM(q}CB@qk7~Jv|d}+d#_F$6!y0fHXA@ z`3$QBLpv<4*SRbW^ts20JT?|6e8(;;f^n5fwjJa^7&&BwQSPaLR&#ajF zpfD8?1uZ;r^s3ptCBy|>e-}bl3s+5jA!c6&Dn1B-gH9q=oR!kxinIRaAdbgm^H&UE zb67|Dw>WmBIFr}Y&2KPL|4L(%G?+Xm({*bWf~k^* z)T8%NLqk4B-5GxZ7FH^>3#%&vpE*;I;fh6{@%YC$#V{Tg6>9+QUorT~BwxttOK>O0 zdDS00-iTt?T;d;#N4WP1bt(Oli^jRBNU)a=0Uy4zg=sY4YQs?;GDhTJ+-g`sL(a(ARUzN`Ug-9S)`s_%=d zBKjqXpWRH-hwLYjY&vFllU?>XUR5*Z;++DT3e;;wh<7~( z&ost^UiU?M%w4n0a}kYsY)J6b*7D9#QSU>t)TqXO|K%U-ABnwsWb9z<*dwz?X(RC_ z%if3=?eJgyBTc+0QzmW7W98n>j0=(HqRU30B+90dE!7#ss%TLOd#t-C<|N3MRW}?!x0=L@`OLv3Nx0> zEc;>VcMj)7t?gacwgI&4PA}@CG<(^m7}JCTwl6E$k#~;&y^1 z_$t<=+XHlcjSfs192Pl7must&b3Y>2Y1vzYyMbBko3P0>=QO$}H*B@tZ#Zwl| zjUvN=YgU``2hfjx0_U>{Z^C}E%iL4ztajk1$Qe$PsMx&DVk@=FvVMbGh&%9R-hP9= zPNE)KZp#z+J8{<|{M~q?MIvNqvG#Zi?lvvU`e22Pu%;gy=SJooxD)33De3giUnn%f z8>d>VS`6570+IF9+HKtWEhD?V_7M@0}vQA%_0^4yFtY05=y8lB>#Y0Z(5&j@0a z63O+C`ZSxI)I2j{W>Vx*u>`NtaXxWQv!@bi&B=*edQ#9FSwYN5nrkAN;pAtZ(HsVv zhAFSv`HT!oGF5W_^O{4FGolcUY7QhW+>$INcp=86*1+i+ zIBo&cHEev1H=Y(r7>nediEwd{W}kuWyVJ9&7yvQgG{LkMFu_DRHYsLiXEa71h{4{8 zgeb{X8P-5ciL3w`9EKu&NNcXl#Dqk)YXIZmARRLejv}T>np3w_(!5+eKAW0Na^MJ! z0{V1}aB5C+0x_;Sj%yUBQP9V$8o@e4bsSXd-U5+$j0j0luyFG?ZC}08cRnLdOEX+t z=sUy3r#bmdCcTrJ6w-2^B*yz>AvMznvpzTiIg^GZWZ3SmnYkDsV^w541E=%M(7@op z08lia!4#Vo#I%qU_rQioUXp$W0jHoZL&)o!-&xpMqLw}F%Km$S(4uR>Ria9}mjgSL zBlYu2>*hOjbEPSCC$dd#+NKhOV5_u zuIyFW?_AdAARsDOF z;mTtv_k2zBrwgZxXG^`yzU|7P`@T?d@+wzuE={VDZA)!xq9C$TpnI*Rx>AyK_FscIwfODY6`5@H1xMN{Q>E)%~I~{|c zh@WOZ&fX0jRYo3wivD-T7bh1cOI*2gIkM|BUtgZ81VhC))L>WMSqZk5=&M}GweVDV z-;(p%Q8n17W3N)BfjaC1);~X77%m&9-g287dRk?k2Ibai!MW0c{NG*KfZWaV{(>KbvkTc}W{a}#+c=xJ>@~w>YD*9J ztfgP|t#_O6IC$6J@++>?T%1(>o612r)bjZy`S&N5&a2x8Z*{5LUsQd&l|%Pv=cU6J z4(Cr5Pn1r5MsHn(SVZFUq#J}zbBcmIE2iV*IW$`s0KMX40DFnK~) zk*D-Zglb6BtCZ;-5dN;Ch_{1*=#C&)LZ*W?RK!HsF`z|I@q>p{Z~ z3Ss-JD9VmLB&EHcIJ2!s((z1chLaQHNntaeNC{~vkx5ICKb^dAE)f?*^6EWQNQXe~ z;Y$P}-hW6!EnfXf01zRPmPpL4z6ss0=_h?-6zm1ctez~npQx>C1-ekHKuLu$L4h6A zDA}a(c#GuYF6z+eV+T*}vn0xPK#LvFNn~%E6=>r30MeMxt?G85;L z(wDI6HDxJ+jE~dKlOg>y}j=?ULHEBww+WroD%l}U({h) zSx?b=NZDOr*sCrMEk)C_018K^WMVaCsB4oFY(oVXu!=-$?ojpj58hA6vs z%f(K@GDJF}Upz*-QRp;!5{e7U@@s|@&zOnyC1{5IkeW#f6i)(`=#%~l-G3ES@v`%S zg~4Kf;e}H0-PS9u<=~a}zk0jx(0lIm>{*81+g;_ka6s9ooPA_+kMYfW3to7lsN$h= z|NBGNhHmzMH2C4*UFJo2N5c_FZbsJvP#~xrfdrr&{fGCIa^L|SEN)fl)(Yd!_vg8_ zVV)^4#o$|@b@BCu*Gt^nKLpx0MYeQRj9;ju65jal7O%s((vDc4a#%mII9Oe4#HDul zuU4R2LNOdN{Yt}j{S_OFH2dp+;(ytea}^DeVj+4sO2aCH=a{%o=9Ef(T#<0fS| ze@vy@R~W>Gl~Loivbt5`QCL)KGM}0O`w^&rK@B&nAUEA~BS_nbI$zTd`9V^*LhyhJ1fV zohwdcV=lE{XkW4Iq>1s+Pr10W5H%_HCPyNr#&hBC_ zahfCTx%{^|=Rg1X&v*8(E|(KQsgB<`f7^x77kFS4TduS5HFOpci&!Fv#`KmLBebz8 zWF8|4#1T`P9Ya-bSrbQ2(T~&ch&4ZfPL`%mSXs+u@@CG^n2oh^XsWz~UE34r_R@5) z>$psgIr1xWUPfb16XKltZ&9;^ms-Q08dd2RPQ3Esao_2e#)f<&kBuJZ9ru?N&CL#d zGyFF0k=TnOVeqC*Mz7Ga#LXN(QC1`gnwyJcmgt)~3yJb=WJ$@Dm#vV>rHZ%(p+u!r zCRN74MxsiplFBt&wP@@*0y(o(Ezy`BhuzIvr~zu3R1>!X9j2A2)$G6sXxN%Zs@1iM zW`(`5sjx(01)7~Y$s@V&*S6KmuSV0W-&Wq8pL>IZq&nDd^kWl?CQS7(C&`MxDK%g! zrst?OZQa1vzR@TVxHmB~FM$ORfo%nD zWVuv)Sfu=_4gDkNhi}9F>@3nOaf+nIExxcwJ0-WaE?C1f>@jP;Mc%|w4JEN_8+%P# zJonA^nx!2@7Ay6y6xY%&so7Xl9I{qo2h7)lEP)A-uwXouBl$$j4ye5?lF~YQJiHPSp-b?T} z33%PANersaC?A>RnUF8Qs-&L{TvW}WKv=ahJQJSeR7;%WBZ6w12u2uj|I?~PCzCK_ zRWlnA!AOM1sY-QnK`z9FMV~mUS_GhS-gec(_@W}OTE_LpuQkMoN{p(+an&J+e1PSA zOsLSE&UeR)y5T(4Uy${Oc^vvYK6N~#4y2l+unb2y7V<&I2a-lNtd<+yK-f3QM`BUc zt~Ge!)I@+6#GDSTlZ8N>0}B@2kk*7%TV!HFz|++`Pgi1`S1fvq;D&fy*!YXtVelVeV1yub$yiGzotfcu{m9Jmzo{8|&LX`1y-RBwq zG$Wpmggco@*oJO__jik2DB2C9Zt$}h34?tR=F}CP^#ReBBl;t-1^W*iIMBNvII2q* z@=bGmm<#d;VK#Vwg)0zUGWx0#QLeepRA-u4bvDQ^J}9qTv@O`uM7n>qe4qT%=5f~3 zy!4#n=~yAFvt^Ylwe3n-yL>EL*^oY`RJO~ex+>=ArVTt+ z)1$byI?Z(weCw|H8I!J_pH$pCmnvXOmu52JuTN!Wl)Zg-T9my96<5D}^Z{wTa_rKv z9e1cdo~~*5x*Z-y-=xEo)cqy*pKs{QL3P$!M!@DuQ;zziwLe5gImMsxi)VRi316DF|r z1a89gE+yhTsFf)If}2aqnW834F!vHL5U2nuHkVsgdV`btdfLfo<)%H_rw@k8Y0wd(+K-e6sq z?Fom7<3%fyLyHFlK+}x<1(!Z51z6a;r6=(7PYBqt5`d>jzcjy~Z{9|~G{2xvN*2i~ z*(Ccd_syJ-2qc+LBX8O08h+gT5I-ft;5p!&f!@M#gEDozt%4VM&bGOHbS zLms2I2?AVX9xb2@ydM45F-~mN^u! z%iX;F!HAy;3WvH3&dXS-3#$Lf=$Ex0HePF7I+Yl55_sFacrkUm0a7daf_kH{X+bK3(yl z=bC4!;#$K!s&$1tw96##%N$M}zE3r+kOx`ncd~2`S{uE&o1|T)_1$_zWpFvh&KC`nM*TQ+ZH<)I`7$bWj*ytx~QtD zO**sHH7Q5d;hgiNJXePnM;1ozIa;!w1}vymSS2s`$X&P6*p<2P>C7iHnZQc@v#ajs zR_xC`q9)`ckI0JoJqqc`+8xQBBvWYS>?!+v#d|Qz#n%^JPc!eo0laTZtZ7-M@AhY_ zc4X@RU1qmf*Q*ijl8BP1a>;g(Uac@06B@k2o3 z&q4)np%}|}P1mO||iy5DSWe9_ZS! zz=@&gsNho~U`U+{em(AW^SR)E1GpR(W1*7KLwC76{LtNaX&47ORr08wPiSpb{4jh}Ez|Q{ zR}kKTF-Rz?-RSWcCaC7AfT%hknFu2HR@BH0XW+_!G=Y7=<3*^H?hfB%B*=b(w9gp1 zF%)5ALGBQbBMzr};ipion+byW0@43Ku79GYbt^L2uQXh0SU2q@+SX}QRkPT!(6R2) zN8LT220sZJqvON~@wbk@bbN`q8)jspb*(eR9>SI6=7Oo2dpK$^#A|> literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_391924.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_391924.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cbb359f176a9798a2c25c0d5c049867f97ddd39 GIT binary patch literal 6371 zcmbUlZERE5^*;OgD}Mfn<0Rw*a1s)58pr}+5hE)E@&OH`P$(q<>NdoP5z z*OHl-u%TL+`^lS;Kp&Gu)s`%%A!Ro)w$klH`|OPfyG_G{<9XFE>n zkg_ZJzH`slJ@?#u&OO(EF_{Vwv{zsG;Nq87guWt!Oe;(S%db%g%^()BlppmeU#gFi z$Ep!^A59@IH6-iFh$>%J<)w$1)i^R@)$8Etb?};X@Y;3ox^?jSb?}D#_yH?haFxED zUO=CYEo6;eGz5NH+gN5DylEZ0IX`|t&swf(`V3iLTd$(N0u}NWWUVkxQcl}rVnC0? z96I*=;cnNdqkSE&bv1B^S)bo98oB~AP3 z-V+@jd(Ikl0IT7*+^HcaL4fn1v>{anTKTfn9o6kLN7$ajORw-`AgspB?X-??@dms! zHf4EBumzWdz?yI=w&T(e=m>AZo3KU3l?lXRx#4157J{>hEm_%eTqetC0b3q|J_MZM ziX0r#R*5SxL$u{srB5r(`LNW=3S>Q-aV0kWAi3;nWVtOH&STE5U@Jy=Gn^6ha}|pQ zRApcf){;5#R)Qt?G*(BjM87a6wEf+fSsQaPIQoevGpt~w*%jxOAX{i(q={`mLStHCwcEZYut4wHS{mNP@1$i)u4 zjiv8sZjcr3B`aAlJ zPboBsoQ56_CO}R~qTQE$JbX2*FgfZS%*Z=c5P~ zofKy#;wDKY_@#nTE;z`!M_fKuqCKqdlB6E-1thJTa|Z^!k|yNkg1n>~@CV((p52m0 zLDMi~CF(S&2uf-;C`cM!2y$L$gQRi0Mg>mN^eZ2a{GpCZRF6b;O9o!xe5}{y9$5j0 zvtWM?7|3EhE4-cnN1}}*0nb5xLL4Mv#F3oMkrc}rNJkQlq#lJ883N-YE&yDhA_E2_ zYXnfa!XWyS{&A>=<$Fa*=Hc{N%ML523V2fDcXNu zXq*Ug!~Ce*<88d?_6)m)i^0H7_nI$>>yF6Lwc}K2&8$x>ZA{Ykl*N`TsS_=Aks~QvdE&fib4I!!7L~*&XZMSx zyXKk2qMt^NrI^yh;bir0v2ypkU#vJJGKV7_DU&r>R3n;d<`|J_j&wXS7sdUex%#~q zZuQ>mog15PSgbuDnh!*}QtMJ4nk-YjvEKMtqG8drJ<|EeWQz~J<({=C2E~%PxiYb2 z_xzcA?SHDe*LBymFe&amA(osFO(!FrDc~G<>*Oq*cxgH~XAn!8=DQR|vAG>Mo3or> z+w3ztrgtQc%{3$|T0Z7J34IiLV0#Xh2`ZZ2=${#!9!$7rD;G<4eQ9cpQh&2-N)=Zm z+J94(=$dxTj?FdAVey&$;`aSw@&0IM#+lmLGxK{MSX-iPIhb2#Z=Rh$^}xC>*7kLJ z6SS{ea#+Y9#h9krW9{+c8{JWA^*~?r1~+qQ`V!bvI43+PZ5H*-QDaJPjyKGfiMD!C zUmw-38pgCsWyt)arAlP6KehXb6MJV%KPa4QpDPusn#AHJ(bP15bU}BY{>*(}CmuXO z9BDcM{x%o?(w(r!2SxLi*+SUn*@-#fcPHm2#O*Ejs)%z;t&z?pEi%}r%rSF(d{MtS z()N(nUORH-Nc2qnWa7-1^tNRnM)K__WBuQ1L9Jn73EPzaGrpNyA3hW?uI z8Z(118p@WighgqY5}T!lcR)atJ%?D;lR4jpr_4L*jNH#CO#MK;^piQirx)cbw9iAt zUN?r-po`ojYZ{@3>FW%X*tEPJ6Rbyr)u3C8>9^54*@q3-t%@`XBoWJPh-Bk6EoO`{ z1H>D7N?^d|DGFCgh_EV9k5&0>RJ3Njtu|3zd3%+Qx401 zwhGxtWWJI)V9BB9`I`Xiv4Ji6b-^zv3SreLy_=(s%U^$1?pvAf;j3yh8n z4#6AXgPgOS}ULGA--9+|D6&fFK>ykmk31_mAcMfa!|q>K#?Z^qKc5fFKd zAniOpNFWkNj~%2V=hc&&j!5_w!Ef0Q-LF&|y|-f&hJR+rfy}&+a;nihi6I7I%*t=Z z3_q!j2Pkl`x@V;@Bdb*^2#y%zR7>>nj?-;i6%4p)=y;M;9gzrk2)+_M2o+(-o~}Sz z{Xvi0&p*41V9yceCg}bhK@Y7ZGsbCSqAelJ_J~{eELxj?f9bP9l2$2pxl%$WLGCoIxpuf3BBknftu^mOB65aw zr%1QrIx**U%Q3*FXKX?c`McVz$-mqZIv_TrcUST|MVBSj7WqAAVG)OUKMvBkHE zbZJU&h&DytE1xNSOdl`2VT_-gIXite;eP!b(0(9t$;fE5OAIQho@eg2f;}l~VWKHj zWRGL9sAhIdEZUK(-286Wn_cgX-wNLhe^J@|TL0A0^`VceKdJwyK2@}Nu57uK=-7f7 zQ{;$zvOTcH*+#Bk5>d-3=w^@)g(0fX113y`Rblm$W13tUV>Tu5cBXK44DC1 znBG923ag(=pC;!YHyUd)nI)`+O9inssn9Tj45R6Bs=00$anF&CK=5E4XWdSf9D8yU zN=Cs8uQ!jI_qv*#G}lYG$eWc7a3qvQd1Q56=MF5O#eWY5L_z#welO5}-#T{lm@={>t8EI$@KQ0Ml_17E z)f#L4wiHIoYNTT1xxX2uc2TCNcgi30FChREju>zq;CwWeDm_`5xO;_6O*Pc7$u_wWl zZm;b7IPZJzx#ynw3r)Kbw10a4)E}S_`W7#Y<|s^7et^j&;t)r~P`5r3-GsKb#I4;V zfp}s-lVh;zBWK~s0g$@3an@(R$@1`CJ7>E}-Ytmfc5rqc4Xj%$6(P>?47jsA9OSyL zlHIQ2&fHf~x7&hvcTtyTj0ky_u&OSVI&$jtZ#yEdp6EUr>0Gs1jq6}5OtxV!TH(JU zkl2e7Vc-TV#;i1Q#N7fuRq-53OObM&Tp?5AHe{9!y4E(1kg2-`o#TK^`Q&vn{ge3U zReY5u%gXEV8uY0u8Oc8IFM8X;pY@VBAgeF%Q8 z0bGsj)$D{@lB#{mBi2DUm!~{Zi=pEX52^J}`9j0VYkff$IqFQ-dW}7lM!8yE&yn|R zcX4DwOVvMz-Y?g!qHmA`CXO<9QLdL83_pz1pdkP{y|RAFcwH^~VXd91hWGzrl0-i< zk~jR!NZu%06dU`2@$G|lg5VlK0`prqLOh&98_&dsh<}@ zgy6z-v}%!Js(VOC_6bZp66I7oBQS|RUbT($LQ+&6y|E-C?S4VE>0pxOqL)+>fShXO zl9Fl@rKG@vn50T#jjEfE@o_#OMWnG%qiSO!Lz1A{dh~(S2BK3Xx>TY=b%~M?<@g8_ zFHHxF(;m|_QG~K3S~elzxW)+=36NEgY8#4jk+?o2RC}_wSA?ZL8tXyWBQ`@+?V44h z>S0)RBt8;jU@IyysCot?0v~6hiD;q^6KjqGVr?WCmFQ6&nrN`4FahD(-asg;(TR;< zPw;u|hqEoCNnubNVpzVVpJ4|XsXv+6!SsOxTSS3vk@)yf3#?keQBpDiq9rhN=g?RL zfRO@_O@a@1zxeBL>n^~kv`!Qm6!- z6q$cxB01wy0^!+XbFE+b=JtMa`hHYtJ*@;zEA%V6$e%>`q_TQ)+r+l~shQ>n^?Sb% zz8wF2{88n}bjK4=75<=SvTvd<&rCJU2X=l#w`7QgKux|eA5-eKD}n79&q8JGN4|Vp zzGk9j>abF|W#*T&)t?T}*4$~C>rl40E573yyP2j#scbQs%qXoo%@c4{d$v7yY$`A# zKB9JK2<*JokOgPe2g2l~iA!LQ?|oTuZq9fXo!(sY)Ow|IyW-pqs?5u*ec6w^yOwKF zMRhKwcz<#G?DY9t=Vyjzo9DM2RJ;e%?F-(jXJKIbZ|I&@{9C}){;)!aXWQr6zG|Fn zS6bWdk14IM=)U&)-(&LCxjx0aamojJrbcI^PtVMZDqHu?H7Z+QR_OidV~BV{fOZNlckjGabl#s9scVXyMyTD5$LW0G~$7nv7 zXtmI%0KMk9Mol6a4H#I{C==zNgWEv(df+h2EN}rw87Sru9Vel3H7dH_6Tn}tYSC*G zqAbpe>z1OdVkP=4Sr@>PcRVQz*i=Rn#-^C zZ1A^pj+CRk45?Thm8;w}Punt^~=sVdhD>GM)@r#ZgzCDYpc#xZY&9R@Wh@ z*+^_RysBuxTmu*AIBW5hDvvMLwBWfP=P!>h@srQtFOM(r+hn`!xJTbDSV16}d=-Vf zUE?pDNDPfgL5WX@NkI(8M?@*u!v~XzAiSSJk%{v`j=vOT`5=?vg4$y%;Dgp4;qmZ^ z1e=TxF;cWA282;~5XERRAqIPsLJ-ok5WNf{V^VY|7G;4dnYd#SaE+?kb*W2-V6ha$ z9S*G}e4Caq!4i&sw(#))Q-tsVSHXX!AG!}M&-?Vq!o}bSe2_-jW-3|AT&*k|GN)Y{ zgTB`SF0yu&-d-p!?D|tA;F}+^s$}QU*A5Hyun-!c6S4WhbXwR0qw3rrW8ys=b6~u- z@SN?BC0QmW9@uGMj$o!=LHD0D`eWc`{f+vmGc)AOMP=Lm`Sri~;?kFIeg4+h;jg&| zuQEzk&-_7FspW)yKp(P}%+gANkW<(XaNW5!v7dnw#Rl&@+ppT-ZI#9fB~Eyc_E4&J zy(9_Ir4%P@00AlyQ3+P1dcf4dh!`E`p|CKo6025Fx+F&8kx@n%5>+>ziW0o)qDuBb z#bfdezFej%;m{%7M|Rya!fSwTfxp-cNB-Y5TJTWU_ht9xh9@siTz=@;GK=MFKOoON*m7W*PSr50`FxLi%z|3R$z@bY;RBrZil4W}P|T zyP$9K+{C#&^ZuKF`>8;GF9#4sr;lsVz=G$B4Ht+w4p$bSn?zBRf_rcQ zASohcNm+mV7C_2_4_MO&WtoJ}S>TfrYk-y_pCivATmLV4w$d{qld{dYizz!KFvO8K z#$^Qe($*ESE+Gh(r8Izq*P&BMc7$Vum+{mi@uC!A8Ig~KL!@vV!*FTOC4`fh9>-hA zp*jR!g!Clpfr{f_yT_(G(FrgSC%i)N>S3#kFU0Nzz+S{;9CT^4=v|j?U-Z^rX~)5q zY5bl-E^jHuZf}|1acjq%{gLxv`tTET^sT{oGsC&YcPR+M4I9(#nNEeQTXrIaNr+TbhUy z#4?4d{FBWS&C7IQ*%JOZb|+R`b`YJ!Uz`8j{GEnhu_6mq;OruT1fAip$Fi|y1k?P; nP5FkbO=oSNQJ+vN7`k$nKm`3ARZUd=s})t$y&J#c(QN)7q(fsd literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_417385.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_417385.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ca539b8332956b68fbfbbd1f68d1b2a2c7302ea GIT binary patch literal 5271 zcmbUlZ%iB4{mwqyXZvh}4R*r6LI_Eml2E7V=msN@kdQP94N03ew3B!E4r1^ha(AS~ zeJ5R}OtITan5`Rbztov3g*K5!_hp$hY0*Aa+ZS%2%DOQrQu`&}Xvw5)+K2t#o&P~e z>H37N$}*PuBRkwjGT$5CYjAR3~}Zs;Hf9zE!*SAY0i3+x?40bYU7-ojVC5+wy{WW zj~}Nv`%TNJz3iLgCNb(T5xk?U)j3NBnMYXVu*Mue_u`4O;Y(*mPlnHLf;Q96YRu5Z z7hk?mK!;q1>X1jr)k}EAdPB9SJ_Y-XJY`uws@G+#fc2-KPXVVK zsKQ}wyHEfzSXv(U_Ho;HOjluF68=7(q*4?X-b!F*wLoD>`fyg&~C^D z;^!ug7&q<4Yp`O=s1ai^zKFG9EViQi+k{Aiqe_m}!#H=$;L_5Ddu^5x$pbsx4i%RWwv|(nUDkR zH!`x6sin+9PIO@NZj~rx`PX3@vGf6ScF)bt~1U4Ry zavBxkqSrNZJets~tiUEFc+HaHg`}w2#$!oVdiIEBG0+r%oMz^d5-3R)$)O);7EwwH zeDG<_!iJ|ML9>h*Lqs3Q^BOss<06GC!kTBG#=d+i*L3Zhh$VGOhwP4S?? zVkpvwgl0>QkBfNC#$*hlq+_NaTG&5Pv+C}Onv;!0X5uq37NVh%Q<`lmEJ~~(X--2R z9GMXWETwM?Na-WVYUG$kMvD5uVZ%@a4Cw&`0}-87d>Z_1{qF5Q7rWj_3RB`V8{xaI zvXLoPx|&QJVkh{7)Fld$E{Ts%cLCG|`ywS1U_^rDAD*5K12SAhMv|~)M?>AuKK~q0 zG?&2?p5lcB9}@;(Nq8H?Uqd!3#N!6S>6$x~Ig}e&bv&h z6)ptY^HDX>wM^9&Jl|7h$^^mcTKT7(XRSKOG5^udUOytkCr*V^wupLm_Lv^x7fMd z*!@T0PpLnoR=qEP!(c4*-q^y#{6vn;H?8;&f9~o^lLfCY8&kb|(zb%XId6q^q@9J@ zrrb-v3FTf^W(XZJeluV z@ppl4+MVsp?^eABRr(<4tp+pJbsyn=X5CNN>*m}U_j|%ZYCg3>?@&uw>lAWZnqI?99e`6{ zdjab1QdRpi5_DBtE2=Efb?$Oy4cz38+?JB8u6y$mb?9~Q<86m*23=2}%I3G7623n5 zHi-}KmX>j3oZ}`~`!>}hVDU4urB&W0nc{4+g`;IFXP0f9L#ES?G-%_TH)+`}K_y@` zWY^LRNzb^l8jHYQgkfFw{p=qpAe(TgqUhP&%%ZKX6ry}d`6U7 z$N1JrG9g7LW|A`%;_+>W@XlK~{(3aR3;0{|k%Bt{EI#VTGrWfpY`@g-hWp^R5rghm zrky^HGLXbZ@|Q~3s$vC_ZEMRw;&A$bDDmUMXq<#Vn1?G5I9|2!zzA+o5j1Pm`I9dV z2u%P8&CrSX;It}=&<~?V_r}=x7{~Udb`>9~-dHli#>BqECCVTcdIq|GDa0D>o^xiL z*@3K-8&(5Puh5RE2-S%Kc!^kgjPyD_K?DA$V04PcFS z=`kM3k?^>L?mU0{`K8eByFcw-rF-Bd@cH4fHoAH|t5eQBrc8`g8GK9=7H@X1N(BnE zJsnE3m0^y~(AnB|oY|3uSLa{NvA=u`Xx|lC-&QF7{nO{LSmA+UT*_&Edm|8{a>JJ& zj{HUex&{Dun)K2NJYS+@D`<4o|uh%G>qsU7s0koo`8XQ}#E z0fj83Qzl#C`XM-~7i=N|PMc*oXcmS55>DX`&!E(w;n<+b$W@%Rnp5JT3XZTM9}dA8 zxPWoE+UF8N0E^-L3feTAz>706Ni_0OPhfmVD+#XWvxxtRQOs2%O%3(MTzw6wWAGC> z=oI3iyG9v$=x)3*gmWw1_Lw2;#Z6W3hi)IbZ(XI2DFa_tC)*Z^JAEzN_6`F% z*W98ErO&HW!#YhcK4sWA*`!9mQN*4JXQ9(P#=Ubm2A?`JPz-ByX(r&arkN+Bl4gg) zE=v5H(?an~zXTU`QqbceU`I5k5t8skGRA#O=%x&x>E3vfn~Cv#0`?PMz2dvjt(!@b z{DNTqPPqO-G(9A3w;FCVteXy#9qSC?t6%7x?_76b$X$Xw7^*2jwHWdip|5=aZI}rY zqigu;r(_f9O7nBEOl+NiX>R62^gi0a(~TEMf^>b&)RDfgjuM`Rcj7mk`fC3N8YYhS literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_654780.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_654780.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da487f1f869688a8e03272a45f68a6a7101495b GIT binary patch literal 5407 zcmbUlYfK!+dG>Da{eU~(PaE49UkWB9j&WmZLju7zHnC%i1nT6v+#dMg9(Z@L!)aDs zSE^I+gzWHY68~*Ha*%7qVCm)>ql?$QI$v|(m9I)o45L!SiVyQ41(Z19OrH%~| z;|NWmAT^=t(M7c{YY5U4&W$)SVvUc$(~rQLisQ$utoaT4M$W*9g|!6HL}?Lw)<@uN z#qnci*8T=PVq2fF;|(-oH=v+>-Iim9@;M(8Q-g~0x#KTBKj1%obfnLJV#8=7Z9uUd zJ?-xs>KpVA1Pa|Gfi}W#?R66gv;g7O+Jr%e@?Vy^k)w;H*od8RAhCuxgdbqw7d@E9 zt^%w}#d@*lbxJG~%W)|#S1}d10wAy`RtnH-4SR5pKbOJS$3ljy-lqC>33HQ?>Ep57WDX|_STnkHz zer#aTn4yOFU?!GvJ;4%u4r?J;;-fm>225!j1#(tZu5Oc@ZG&86fp6B$(EzJX>>4u& z>x5z(=%#Vq`)vYjPpz?3fo8Xk&3~~8ug_YS??tcv1ERehm+)S^9lMLoTU`~?B&<6( zTwtLvQ*%ae11{kpic!zjx>)+A=?2MeXkybAdK-)D%||}(;|9Guq9?b|-0-zkHtD{T zd?&NspmA$Y)vU>T22sz&wWz)wcw2#oUM(JlJ#NA~^wo;PT8rBt%5JRXA{jT}Mi{G^ zr_vjL@=qG!Z4(uRS~R#It17=vT|Yl|e!HD3?)r9nR&2%wY{$Dc?m>OeU2 zxO=kP zIv7<-bubk5k8`o6A3G~g%e#qWVx zZU2VttC!j@#<)p-iU|bU&oO~XMmQIXwlU+usL;-Hfp#GnnQDhoJA__{MZtmqUhSQl z@dMJILk427D*K<=*LmPcpeQbl$v+w7qQNkC2>h|Yk9R@%CG=%Ea=7N&5^X7J#a=5t zpY?c?XJt=Qx<~djXIkab{nF8*xLs|_H2jq(>s3N3Vsh zh8C&Cu9e;0a!L2C9@*U|^=BVXzUL~L8%_)-&!<{fTrJWIS(i6C{w|X)PmRl!yE8R% zMCQg)rvILF>Sm8MfK&&M)0xw2!iUt^T_^#bR}U-<&?XQ-@dRsG;-(ctII>L ze)>bX>$PQORPGvGo`~Enjn4KIm$)-9~yS)B90VrvIvc@%&Q95|$4R z$ejan^T6`yp}VCo7wAq8ebAE;9>bZnmc7F*T$sNA6ZO1}W$VsaN7m|2wx(-j?-R21 ziCOc8VZyvxhTKoCRw7&ZoIBwL;rM)f#kxc4$-2vuVcEUo{gWk;H-~NNyb%I_g~H$)J}3kD#wVGho(_ zmoHAw%Nw_lFHXHsq$9H{C|8b?;R)yzdK?O&f-sPOIplBr?rtp}W-SWbX`ns>KmIG|{wtxZ!#UTP z=uCDbo=th)tGZH^_FSpGd{jy59dbO?spi+-|n@h(YIL4&j2ec=-OQx%`R@-dHER+At zSrgWz=Pl5>@XGuvDdz3(1MQn4E1I&}pO>8R-W_QMDuS%HE?d^{?<%LsyjqW(F6oH6 zhye(}dLu<3iSb$%x&;(MB4qJdz=)J+5RH$|VK{H#`CRbuC~^j}26)gC3(!P*3w;J` z{J->>1X5M04gLXQ(sz<*hO+~nc3nWh2zHcd&}Za&VZ=R8Isz$lV47uo1}zgwPAU!| z2qkQQ;e-ATAI-f;aHM!=qg)jcBS{2FZQ+9abXefEJUdZlVY3* z35qQM=cf>SeTvKN)-!OX#yB+=9GQgT(8A##kA>OmNHwFMMUfag9S$Di7~msEEB|xo zR*e)zeU6-eL$1H0`c*SBSTEOJs$DfaK{c;BQCY=8>wN2~OB=O!{4#t!td9n$6VzW? zKh06ysMI@$6Zne~;9WB!gHzKnLe*2Q+2C9_5ne?Arl#M+SFj2$nm%@Z=v*V%wPO^b WTwgfLsj`3VLnXCuMJ_p1=l=)1V#Bxq literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_769893.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_769893.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92afb7fbefa03e389b0f38ad6ab1c1c88a0f2c6c GIT binary patch literal 5407 zcmbUlYfK!+dG>Da{eZjU{W93b_!4YN9>$4@4G9F>*u;)45~!2Ya(m!|d*I#04ro?g zSBcCel{lJIoYP7&BHV)j#U~q9K*q{OUKmx7U8T|VjV`f2DTKv zWnj^$p_5cZ2}xmt+7;rWVa5@f3XR#uUhBtL9hJ-yxoXPcrR|m?qcg!cf~XT`_2s& zXef--oDy7*OE~ai%(JyDmcC)SPNEx{*tCV*hT?Mbp3mF3LGJe0i7h18e`}Xby6q(1 z$*MO<+}ct#YVwwW*K>X?DsMa9R$!r5^G9KioA7phw`O3k#ckkaH`Zd2%CAqELPr^XLNUdKvmwq;apc1DyA*>EQS6gke2inFfiSBWqv4ohW;iA`7E(+zAui4< zmeEL@5uSWPF==2r$cE1=Gys8~jSC*10 zVN9`)aq+3i03V(SDWy6%91Eym#j0U}+c6sEcp=B{%rk*BlO|OC#1u<>bd)FdYl9K+ zn+ll#|B4F<@5@w;JF zjo-0-^K#pTI5)vhGQm*WStdBa2xsH5JC);4w2Cf(4G0-4@ReL9= z0{{%ox^YM&Lu4rHE(+K6KdVun&b;|qR zvW4CAyHm$9t;@9?zu|s2^Q)OV-lMQkP}ct5$imqCSc*y4tyJv&!qqlQ{nb&HqEg4? zs@<7=i#vZ!Ew)@cBv(DVbn?AH0^!QtCyURxe~ z?eibWov$x5BXZ}+@_6)4X>7K;ILGb5SC)s0awNL%>pIednVwIYGJRJ9i|3Zwm#}=G zU+(CaoBNkf4c;kzr9gIi@T2aG@CeGR)%PB^aDM(gEY$NZmaRKx9a*b8*_y7Fy^qP( z$7anNrU~<^54nG^T7hh3bMAy2m}ll^R;=5l?yTFFjL7coAHIBb=*m#$++yoW^C8)N z2=@Bn81VE^y7D7WrYBP=H?_;F~wFvxCW#slhMkoojF^5VwzX9ndLe zE+kBGu^`FHdMks0S}MB%ddYstv4Aid*Rcg(mQo8mswhq^{w5$Z@Rp#ThiB*II@D%u2xt`6JObt5QRjYKD0xoCX7bkwQHlR+;79)e$gX27Z+ z&0id!=QnP_UmTz3r=hem!#d2Q?N+Z#b^!$|wpR6Q_{hIlvJ&s<-$|<&ju0 z9-U-_@JJ-o$cCe#7$1(uc(}*SZ0LMA7~;sY>w!c%1b6;}E;1ql9*}?w<{nuv)bh0t zzH1TaerDL}iI#vk(&A4q=nBdaB79Rzz#g$`OzQNc^ckgKO~!%zElo_9+}Pmjni1y{OQxtf)$M!7@4;y33lE2tmT zdf@`YP4bGJ)Cjx)w=nS_c6Zt&n)#RM0rFwH#=g4BIbuSLbIl{9{5J_~RS zgk(SI$m$Bpm1AUj96E&_gF>jl4dh=A`5V8hOY?_W^8$AYuusE}{|37MN+|1a&UGX@ zlI@9SQ=ShhFIT2Lmuqf2b}Z8emzxi+K=0U*qq*24bxY?qOzv7;bB=@qQW%wdG2MQ3 z|CRlV?Vomh+;PWxK&jtN>QQfV>F9mOsMK?x_9S=8bY<3Rn{A(E@}D_t z!kYBF15y`WoqsjOy!%7IeOF+4Q&#))k{s{do@SsT$a-tDzWRSxI!)%)I^=Xoht*9C zf(zCgDFTU)*Rs$ppfD04iq8Ng$_=!jNhO|B8f@GA%viW z4KjQv(C(+Xmk5j$?`({#By1#zAgC=|h@XlGycT?FsFH(N5WXaGd16^EAhlT3#ODHX z2(SS{#zH5dd+t)H@1DE%Vjl^%*`}|Z$mY(+*oVzm_gvYtWWHlPBz4~}jJ{16_w2c3 z(>qQG!ukfOZ}zxM*Q{ER(Il@cwl?2W*yWeN_gV%0|+IZTRiJS-@- zAmmRW^u{EYJFTZ7r^Y$e7aUoH;?Ud?7>h^PYe-e2=b~twor;7GatzRs)XM)Hx>X}Z zQC}kGm&o-uRJUqI2J5Aoi#4l;$EoI3C-Rjqw9dD#y0lqa`!6HcBKoYKI!^ti_46Fo zjY_?9IDx+^0opYqGB`CEXQ(>LH5;0XBqFN_hN-C!@MWwH7fqi!KX$GW=-OclQLe9? TWt8t^s`38=#6iP{ literal 0 HcmV?d00001 diff --git a/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_993568.cpython-312.pyc b/src/temp/gen/__pycache__/triton_matmul.py_gen_triton_code_993568.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1344c73317cd0845004a06ad4ce21ba81ca268a5 GIT binary patch literal 5481 zcmbUlTTC3+_0I0>b7mLVp*6AhAinIJUvYCb2`(tcIOkys!^EGuQ#$ zm6phnckQYyjVsnp)oi3lp|L-l{*+3U>QLD8vm$6`rBBc6%MkjCEOO0~ooxOWCQFD%JQYM^+DMI2>RK1l zkI@wJQIo2i9IH0+Iv+jBJ_;i%UjGC*{RFt7G~93IjhE?L*$u`_Jj0uOXtKP-j^@&E zKh0Y%8^$bo+pL$-m{o^-*1S|^mU6S}MC7o-zI@{3E5|)&UL70o3_Y@X6n9)<2Tz}T z?NpZ1Tda`@B)A7i! zs|9_qk?e`<36#LIP!oX?`RX1Gn9?YU_{=I_-4;H}BYcfTx|PTl%wn5gC*{3TDsyh@ z^V;u}AYZk{Q$>g${AzlwycT=Y4{k5=iK~8AT@r73;JRnGm>;C%`MTmk(T2^r1jlF>(SIp71G# zIiC;_6_Y<0;iP?g6@vz*VaY3cJ|ZawQHlsYcbj71Jkye(7{;~1s}9tVLJcd_am6A^ zLV)*qxKLp_m7k6mO~ZMpw?OL+3nT^wl3D`X1W14j4k`L+P}>ro33*`R0nXesq?G5T zfv{&nh|EkYW_2KjGX8)dN-9rCVKuNAnDco;ih&5q4q?R<@%u%hvo;wA$Ec7ga1gOi zRE(;*qQY=q?@VYW$bmN$3h>0^9xgkMD-@?t<5@9amnP2(@>D;G>XC?TV3X@Rw*Q>& zx)c$n#A(j!>pIVQr#R_+B;3JG_`*_`D0sUhUue1uR$btDDH8UCI0^LYoSyXn(31sv zBXB0qJ-6?`!Dj)Z*fb)~lurozg2D@+4=F*!UqXb%(6>7fV_WD*bfiY^S$D-=xnJ&B zGA)`?RH}Eaye;o7ROqpY~x>YV~jUCN6>QZmX4tMPMeP?xYHvNKJ z(|LzobG{fmkzs36`_c_^eTU3;#s)IB@>OTEY-?U-Wp;mT;DOzl49fQ1AD{l@jgQ_~ zzHp~~t>utxKNK6xJP~u>UcL}W1XLaEZLuR6n^ zSm~DAdsi;X?E`nEzl_|yBtLU<_4FC}nKN?5Xxx}_?n*bw&gQgVcD8@Y-KkujkR1nC zDrCp=E0?kihg60`8pD0ozA%^=Odd&BFN^oseR1k>NXG2EE-YPKya=jQypLsbbDYVT z?aB6Zt?YPOHa`vf7B3UVjWT5a(S`$+RetYdAZ|aaIa97`0dH1y%eKcIYC8pfXPrOm zOG?Q>nQcg2kXcvy?DFWZyO+<(EeBR8x#c;Tc|O*EpEh1OI)60&dh*S*^d;TC31Nza z1Ef0$oj3bp5hx-zsao&-6TQtU0 zl!6naY)d7Yjv-!`<5VS=s2bh^eXp`nJxF{4HcDT$VLhe`Iu>fh$r4pwf=xh$19p_E zeoFcT>Yc1LQA263N#rV4Td8_U^$Ek%Q6v1Dcth08Q&9_Vj9SC{wv}hkNiK{f?mTX?%g-h)vkcGuVQyx7l0SYKubHa0a<;!|S9UWK*FN{lj zJiPhODc~ZZL0rBWgzo3M?LM{=kYYSV4klL=78OK9|1B*6q0FnF4LN#JeEBIbSU+5- z7>H%sgC~%Pa_bd(XkfHYsDp)251mLt+!c!n&%>ygdxKnPoabJe+mU@VdxH@#7ZhLW z%wdKIQzvx)ilJXre^`H`K0UHbFTX9f^{&;v_$T9^nLjX}_kAv{j*iPGy=#Ygxt5oA z_=R3zQOmV;;-z(xp>AVA(W27EO6nI;g7;k1%9=}@FfFW zdtNcXdoRrjhd^YuUh6(6kRVXl(=%Zy5b_NOLPSs>b;YQaV{Uul%2ceJR|;I@q(CGr zG!Q=B@x2P=-KyQIfD(=oJUsy=U6B)sZ;Pv!!-qA0LM=o}58-vdlgn8=3&#IDhB6Gh z(39v%a^!*hf@xf(8&{b|0zHs8a1G#BdsBTM4%`^{U@(2*=80AM$m-Ka)}UwFvaIv{ zvA)=a3~M8=-TXl8Wq4c}HmSSX6B~~A+&9}6m;{p?Nm1|;c7O8hN6+5r{!PzkJ@?En z#QGkX%vUbWU%KA1)Unv{g=tr&rZ&zNbyd}IYo@X)VaZsm3pI(F>jO(ei$h;nnu=&t zmOSA@d);b%|C)VZ)jaT!@yGff(iO?4WV$9}w#2*RTwz!+C(Ox;YoOZFxy5rS?)^6b z_g#TiO_}Upb;jA4an%2#hBX*B>JV#-9aSUD3r@~AQZ1A4-poL^gaRlEDfSjXqEu8D z)&KtlgWD0lz>B_!OT0t8uIP=Rw~?pA`tK#rPPE zB8$Bco*%b-((zHpit(QLP^|A^ao}#j*y9(HP1jh6-^QJ>!T69&*KL>)>x>PnNAISF z-K=Vu&=316W!k;h8HCTd8MsUob8aS(p9K~DWI$3Z-bh%KeD6%>Zii9zZBX4I5IYn` z^N43666Ak{v}-JTSB4_|Owjj|;0HQ#sfj;@ZbMH|)K`f88?yZ!HEdXs&U~eAzHUS3 zraCrQ= vob_start_id) & (ids < vob_end_id), ids, vob_start_id) + valid_mask = (ids >= vob_start_id) & (ids < vob_end_id) + + offset_weight = ids[:, None] * stride_v + cols_d[None, :] + vals = tl.load( + weight + offset_weight, + mask=mask_l[:, None] & mask_d[None, :] & valid_mask[:, None] + ) + + offset_out = ( + pid_batch * stride_ob + + cols_l[:, None] * stride_ol + + cols_d[None, :] + ) + tl.store( + out + offset_out, + vals, + mask=mask_l[:, None] & mask_d[None, :] + ) + + +def embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor +) -> None: + assert token_ids.device == weight.device == out.device + assert token_ids.dtype == torch.int64 or token_ids.dtype == torch.int32 + assert weight.ndim == 2 and out.ndim == 3 + B, L = token_ids.shape + V, D = weight.shape + assert out.shape == (B, L, D) + + BLOCK_DMODEL = triton.next_power_of_2(D) + BLOCK_N = 32 + BLOCK_NN = 256 + grid = (B, triton.cdiv(D, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + token_ids.stride(0), + token_ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stdout new file mode 100644 index 0000000..2de91c2 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_155036.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_155036 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py new file mode 100644 index 0000000..80887d1 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids_ptr, # int32* + weight_ptr, # fp16/fp32* + out_ptr, # fp16/fp32* + seq_len, # int32 + num_tokens, # int32 + d_model, # int32 + vob_start_id, # int32 + vob_end_id, # int32 + stride_id, # int32 + stride_w0, # int32 + stride_w1, # int32 + stride_out0, # int32 + stride_out1, # int32 + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid_seq = tl.program_id(0) # sequence index + pid_col = tl.program_id(1) # d_model block index + + if pid_seq >= seq_len: + return + + cols = pid_col * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + d_mask = cols < d_model + out_offset = pid_seq * stride_out0 + cols * stride_out1 + weight_offset_col = cols * stride_w1 + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + n_offsets = tl.arange(0, BLOCK_N) + n_mask = n_offsets < num_tokens + flat_ids = pid_seq * stride_id + n_offsets + token_ids = tl.load(token_ids_ptr + flat_ids, mask=n_mask, other=0) + valid = (token_ids >= vob_start_id) & (token_ids < vob_end_id) + token_ids = token_ids - vob_start_id + + w_offs = token_ids * stride_w0 + weight_offset_col + w_vals = tl.load(weight_ptr + w_offs, mask=d_mask[:, None] & valid[None, :], other=0.0) + acc = tl.sum(w_vals, axis=1) + + acc = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + out_offset, acc, mask=d_mask) + + +def embedding( + token_ids: torch.Tensor, # int32, shape (seq_len, num_tokens) + weight: torch.Tensor, # fp16/fp32, shape (vocab_size, d_model) + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor = None, +) -> torch.Tensor: + seq_len, num_tokens = token_ids.shape + _, d_model = weight.shape + assert token_ids.dtype == torch.int32 + assert weight.dtype in [torch.float16, torch.float32] + assert weight.is_contiguous() + + if out is None: + out = torch.empty((seq_len, d_model), dtype=weight.dtype, device=weight.device) + + BLOCK_DMODEL = triton.next_power_of_2(d_model) + BLOCK_N = min(triton.next_power_of_2(num_tokens), 64) + + grid = (seq_len, triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + seq_len, + num_tokens, + d_model, + vob_start_id, + vob_end_id, + token_ids.stride(0), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stdout new file mode 100644 index 0000000..0150467 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_176773.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_176773 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py new file mode 100644 index 0000000..8c62e47 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py @@ -0,0 +1,244 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + ids, # pointer to the seq-len length token-id vector, [batch, seq-len] + weight, # pointer to the embedding table, [vocab, d] + out, # pointer to the output embeddings, [batch, seq-len, d] + stride_ids_b, # stride(ids, 0) + stride_ids_s, # stride(ids, 1) + stride_weight_v, # stride(weight, 0) + stride_weight_d, # stride(weight, 1) + stride_out_b, # stride(out, 0) + stride_out_s, # stride(out, 1) + stride_out_d, # stride(out, 2) + vocab_size, + d, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) * BLOCK_N + pid_d = tl.program_id(2) * BLOCK_DMODEL + + # row-major iteration + for n_base in range(0, BLOCK_N, BLOCK_NN): + # offset into the tokens + offsets_s = pid_s + n_base + tl.arange(0, BLOCK_NN) # [BLOCK_NN] + mask_s = offsets_s < d # valid mask over seq-len + + # load token ids [BLOCK_NN] + ids_ptr = ids + pid_b * stride_ids_b + offsets_s * stride_ids_s + cur_ids = tl.load(ids_ptr, mask=mask_s, other=0) + + # mask valid indices in vocab range + mask_vocab = (cur_ids >= 0) & (cur_ids < vocab_size) + + # Embed over feature dimension + for d_base in range(0, BLOCK_DMODEL, BLOCK_DMODEL): + offsets_d = pid_d + d_base + tl.arange(0, BLOCK_DMODEL) # [BLOCK_DMODEL] + mask_d = offsets_d < d + + # compute weight ptrs + weight_ptrs = ( + weight + + cur_ids[:, None] * stride_weight_v # [BLOCK_NN, 1] * stride + + offsets_d[None, :] * stride_weight_d # [1, BLOCK_DMODEL] + ) + weight_vals = tl.load( + weight_ptrs, + mask=mask_s[:, None] & mask_d[None, :] & mask_vocab[:, None], + other=0.0 + ) + out_ptrs = ( + out + + pid_b * stride_out_b + + offsets_s[:, None] * stride_out_s + + offsets_d[None, :] * stride_out_d + ) + tl.store( + out_ptrs, + weight_vals, + mask=mask_s[:, None] & mask_d[None, :] + ) + + +def embedding( + ids: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + """ + Wrapper function that launches the Triton embedding_kernel. + Args: + ids: Tensor of token indices, shape [batch, seq-len] + weight: Embedding weight matrix, shape [vocab_size, d] + Returns: + out: Tensor of embeddings, shape [batch, seq-len, d] + """ + assert ids.dtype in (torch.int64, torch.int32), "ids must be int64/int32" + assert len(ids.shape) == 2 # [batch, seq-len] + assert len(weight.shape) == 2 # [vocab, d] + + batch, seq_len = ids.shape + vocab_size, d = weight.shape + + out = torch.empty((batch, seq_len, d), dtype=weight.dtype, device=weight.device) + + # Simple heuristics + BLOCK_N = 64 + BLOCK_NN = 16 + BLOCK_DMODEL = 64 + + grid = ( + batch, + triton.cdiv(seq_len, BLOCK_N), + triton.cdiv(d, BLOCK_DMODEL) + ) + + embedding_kernel[grid]( + ids, + weight, + out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + vocab_size, + d, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stdout new file mode 100644 index 0000000..a17766f --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_180807.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_180807 due to embedding() takes 2 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py new file mode 100644 index 0000000..f28b316 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py @@ -0,0 +1,195 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(input_ids_ptr, + weight_ptr, + out_ptr, + vob_start_id: tl.constexpr, + vob_end_id: tl.constexpr, + stride_ids_0, stride_ids_1, + stride_w_v, stride_w_d, + stride_out_0, stride_out_1, stride_out_2, + seq_len, vocab_size, dim, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid_d = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + row_start = pid_b * seq_len + for j in range(0, seq_len, BLOCK_N): + offs_n = j + tl.arange(0, BLOCK_N) + + mask_n = offs_n < seq_len + flat_idx = row_start + offs_n + pid = tl.load(input_ids_ptr + flat_idx, mask=mask_n, other=0) + + valid_id = (pid >= vob_start_id) & (pid < vob_end_id) & mask_n + real_id = pid - vob_start_id + real_id = tl.where(valid_id, real_id, 0) + + w_off = real_id[:, None] * stride_w_v + offs_d[None, :] * stride_w_d + embed = tl.load(weight_ptr + w_off, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + o_off = pid_b * stride_out_0 + offs_n[:, None] * stride_out_1 + offs_d[None, :] * stride_out_2 + tl.store(out_ptr + o_off, + embed, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + +def embedding(input_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor) -> torch.Tensor: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + batch, seq_len = input_ids.shape + vocab_size, dim = weight.shape + assert out.shape == (batch, seq_len, dim) + assert weight.dtype == out.dtype + + BLOCK_DMODEL = triton.next_power_of_2(dim) + BLOCK_N = 128 + grid = (triton.cdiv(dim, BLOCK_DMODEL), batch) + + embedding_kernel[grid]( + input_ids, weight, out, + vob_start_id, vob_end_id, + input_ids.stride(0), input_ids.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + seq_len, vocab_size, dim, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stdout new file mode 100644 index 0000000..6586448 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_18528.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_18528 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py new file mode 100644 index 0000000..7e41178 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py @@ -0,0 +1,235 @@ + +import torch +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Triton kernel +# --------------------------------------------------------------------------- +@triton.jit +def embedding_kernel( + ids, # pointer to token-ids tensor shape [B, S] (int64) + weight, # pointer to embedding weight matrix shape [V, D] (float16 or float32) + out, # pointer to output 3-D tensor shape [B, S, D] + B, # int: number of sequences (batch size) + S, # int: max sequence length for all sequences + V, # int: vocabulary size + D, # int: embedding dimension + stride_ids_0, # leading stride of ids: = S + stride_w_0, # leading stride of weight: = D + stride_out_0, # leading stride of out: = S * D + stride_out_1, # + stride_out_2, # + vob_start_id, # (unused) beginning of allowed index range in V + vob_end_id, # (unused) end of allowed index range + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) # 1-D grid: each block handles one sequence in the batch + + # each block handles every BLOCK_N tokens, each tid loop handles BLOCK_NN tokens + for b_base in range(0, S, BLOCK_N): + ids_offset = pid * stride_ids_0 + b_base + # Load mask + n_ids = tl.arange(b_base, b_base + BLOCK_N) + mask_n = n_ids < S + + # Load token indices + token_ids = tl.load(ids + ids_offset + tl.arange(0, BLOCK_N), mask=mask_n, other=0) + + # Clamp to legal bounds [0, V-1] + token_ids = tl.maximum(0, token_ids) + token_ids = tl.minimum(V - 1, token_ids) + + # Iterate over tokens in groups of BLOCK_NN + for start in range(0, BLOCK_N, BLOCK_NN): + idx_group = start + tl.arange(0, BLOCK_NN) + group_mask = mask_n & (idx_group < BLOCK_N) + + # Current token ids for this group + tid = token_ids[start : start + BLOCK_NN] # shape [BLOCK_NN] + outs_idx = pid * stride_out_0 + (b_base + start + tl.arange(0, BLOCK_NN)) * stride_out_1 + + # Iterate over the embedding dimension in blocks + for d_start in range(0, D, BLOCK_DMODEL): + offs_d = d_start + tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < D + + valid_mask = group_mask[:, None] & mask_d[None, :] + + # Weight pointer: weight[tid, d_offs] = weight + tid * stride_w_0 + offs_d + weight_ptr = weight + tid[:, None] * stride_w_0 + offs_d[None, :] + emb_vec = tl.load(weight_ptr, mask=valid_mask, other=0.0) + + # Output pointer + output_ptr = out + outs_idx[:, None] * stride_out_2 + offs_d[None, :] + tl.store(output_ptr, emb_vec, mask=valid_mask) + +# --------------------------------------------------------------------------- +# Python wrapper +# --------------------------------------------------------------------------- +def embedding( + ids: torch.Tensor, # [B, S] long int + weight: torch.Tensor, # [V, D] float16 or float32 + vob_start_id: int = 0, + vob_end_id: int = None, + out: torch.Tensor = None, +) -> torch.Tensor: # Returns: [B, S, D] + B, S = ids.shape + V, D = weight.shape + device = weight.device + dtype = weight.dtype + + if vob_end_id is None: + vob_end_id = V + + if out is None: + out = torch.empty((B, S, D), dtype=dtype, device=device) + + BLOCK_N = 64 + BLOCK_NN = 64 + BLOCK_DMODEL = triton.next_power_of_2(D) + + grid = (B,) + + embedding_kernel[grid]( + ids, + weight, + out, + B, S, V, D, + ids.stride(0), + weight.stride(0), + out.stride(0), + out.stride(1), + out.stride(2), + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stdout new file mode 100644 index 0000000..11bc87e --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_200147.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_200147 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py new file mode 100644 index 0000000..04b9476 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py @@ -0,0 +1,205 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(token_ids_ptr, out_ptr, weight_ptr, + stride_tokens_b, stride_tokens_s, + stride_out_b, stride_out_s, stride_out_d, + stride_weight_v, stride_weight_d, + vocab_size: tl.constexpr, hidden_size: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr): + pid_b = tl.program_id(0) # batch dimension + pid_n = tl.program_id(1) # sequence-block dimension + + offs_s = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # load token indices for this batch row + seq_len = 0 + # Token tensor: shape (batch, seq) => stride (seq, 1) + # We will access: token_ids_ptr += pid_b * stride_tokens_b + token_ids_row_ptr = token_ids_ptr + pid_b * stride_tokens_b + + # Since seq length is fixed per call from wrapper, assume seq_len is known + # We'll pass seq_len explicitly via a scalar; instead handle via BLOCK_N mask + # For now, pass seq_len as a placeholder scalar (not used in kernel after fixing wrapper) + + # Load the tokens for this block + mask_s = offs_s < stride_tokens_s # Effective seq_len from wrapper stride storage + tok_ids = tl.load(token_ids_row_ptr + offs_s, mask=mask_s, other=0) + + for start_d in range(0, hidden_size, BLOCK_D): + offs_d = start_d + tl.arange(0, BLOCK_D) + mask_d = offs_d < hidden_size + + # Compute weight offset: [vocab, hidden] + weight_offs = tok_ids[:, None] * stride_weight_v + offs_d[None, :] * stride_weight_d + mask_w = mask_s[:, None] & mask_d[None, :] + + emb = tl.load(weight_ptr + weight_offs, mask=mask_w, other=0.0) + + # Compute out offset: [batch, seq, hidden] + out_offs = pid_b * stride_out_b + offs_s[:, None] * stride_out_s + offs_d[None, :] * stride_out_d + mask_out = mask_s[:, None] & mask_d[None, :] + tl.store(out_ptr + out_offs, emb, mask=mask_out) + + +def embedding(token_ids: torch.Tensor, weights: torch.Tensor, vob_start_id=None, vob_end_id=None, + out: torch.Tensor = None) -> torch.Tensor: + if token_ids.dim() == 1: + token_ids = token_ids.unsqueeze(0) + elif token_ids.dim() != 2: + raise ValueError("token_ids should be 1-D or 2-D tensor") + + vocab_size, hidden_size = weights.shape + batch, seq = token_ids.shape + assert weights.dtype in [torch.float16, torch.float32] + assert token_ids.dtype in [torch.int32, torch.int64] + + if out is None: + out = torch.empty((batch, seq, hidden_size), dtype=weights.dtype, device=weights.device) + + # Ensure contiguous + token_ids = token_ids.contiguous() + weights = weights.contiguous() + out = out.contiguous() + + BLOCK_N = 64 + BLOCK_D = min(128, triton.next_power_of_2(hidden_size)) + + grid = (batch, triton.cdiv(seq, BLOCK_N)) + + embedding_kernel[grid]( + token_ids, out, weights, + int(token_ids.stride(0)), int(seq), + int(out.stride(0)), int(out.stride(1)), int(out.stride(2)), + int(weights.stride(0)), int(weights.stride(1)), + vocab_size, hidden_size, + BLOCK_N, BLOCK_D + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stdout new file mode 100644 index 0000000..9b307cd --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_211539.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_211539 due to Dimension out of range (expected to be in range of [-2, 1], but got 2) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py new file mode 100644 index 0000000..3625cb4 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py @@ -0,0 +1,181 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + out, ids, weight, + stride_os, stride_om, + stride_ws, stride_wm, + N, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + start_id = pid * BLOCK_NN + cols_d = tl.arange(0, BLOCK_DMODEL) + + for i in range(0, BLOCK_NN, BLOCK_N): + rows_seq = start_id + i + tl.arange(0, BLOCK_N) + mask = rows_seq < N + ids_off = rows_seq + token_ids = tl.load(ids + ids_off, mask=mask, other=0) + + w_off = (token_ids[:, None] * stride_ws + cols_d[None, :] * stride_wm) + x = tl.load(weight + w_off, mask=mask[:, None], other=0.0) + + o_off = rows_seq[:, None] * stride_os + cols_d[None, :] * stride_om + tl.store(out + o_off, x, mask=mask[:, None]) + + +def embedding(ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id=None, + vob_end_id=None, + out: torch.Tensor = None) -> torch.Tensor: + N = ids.numel() + DMODEL = weight.size(-1) + if out is None: + out = torch.empty((N, DMODEL), dtype=weight.dtype, device=weight.device) + + BLOCK_N = 16 + BLOCK_NN = 32 + BLOCK_DMODEL = triton.next_power_of_2(DMODEL) + + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_NN']),) + + embedding_kernel[grid]( + out, ids, weight, + out.stride(0), out.stride(1), + weight.stride(0), weight.stride(1), + N, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stdout new file mode 100644 index 0000000..30513bb --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_322972.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_322972.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py new file mode 100644 index 0000000..89a3d3a --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py @@ -0,0 +1,198 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weights_ptr, id_ptr, out_ptr, + stride_wd, stride_wn, + stride_o0, stride_o1, + seq_len, dim, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_NN: tl.constexpr): + pid_d0 = tl.program_id(0) # block row + pid_b = tl.program_id(1) # batch index + + # D offsets handled within each program + offs_d = pid_d0 * BLOCK_D + tl.arange(0, BLOCK_D) + + # per-seq id & out base pointers + seq_id_ptr = id_ptr + pid_b * seq_len + seq_out_ptr = out_ptr + pid_b * seq_len * stride_o0 + + for block_n_start in range(0, seq_len, BLOCK_NN): + offs_n = block_n_start + tl.arange(0, BLOCK_N) + mask_n = offs_n < seq_len + + ids = tl.load(seq_id_ptr + offs_n, mask=mask_n) # [BLOCK_N] + mask = (ids >= 0) & mask_n + ids = tl.where(mask, ids, 0) + + # Row-major weights: shape (dim, vocab) → stride (stride_wd, stride_wn=1) + # pointer = &weights[ids, offs_d] + ptrs = weights_ptr + ids[:, None] * stride_wd + offs_d[None, :] # [BLOCK_N, BLOCK_D] + vals = tl.load(ptrs, mask=mask[:, None] & (offs_d[None, :] < dim)) + + out_ptrs = seq_out_ptr + (offs_n[:, None] * stride_o0 + offs_d[None, :] * stride_o1) + tl.store(out_ptrs, vals, mask=mask[:, None] & (offs_d[None, :] < dim)) + + +def embedding(weights: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + """ + weights: (vocab_size, dim) + indices: (batch, seq) + returns: (batch, seq, dim) + """ + orig_dtype = weights.dtype + # Triton prefers float32 accumulation + weights_fp = weights.to(torch.float32) + + vocab, dim = weights_fp.shape + batch, seq = indices.shape + + BLOCK_D_TRITON = triton.next_power_of_2(dim) + + out = torch.empty((batch, seq, dim), dtype=weights_fp.dtype, device=weights.device) + stride_wd, stride_wn = weights_fp.stride() + stride_o0, stride_o1 = out.stride(0), out.stride(2) + + # block sizes + BLOCK_N = 64 + BLOCK_NN = 128 + grid = (triton.cdiv(dim, BLOCK_D_TRITON), batch) + + embedding_kernel[grid]( + weights_fp, indices, out, + stride_wd, stride_wn, + stride_o0, stride_o1, + seq, dim, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D_TRITON, + BLOCK_NN=BLOCK_NN, + ) + + return out.to(orig_dtype) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stdout new file mode 100644 index 0000000..1040a61 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_347928.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_347928 due to embedding() takes 2 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py new file mode 100644 index 0000000..fc11875 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py @@ -0,0 +1,195 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(weight, + ids, + out, + stride_w, + stride_out, + num_tokens, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid = tl.program_id(0) + pid_d = tl.program_id(1) + + start_token = pid * BLOCK_NN + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + tok_idx = start_token + tl.arange(0, BLOCK_NN) + + mask_t = tok_idx < num_tokens + mask_d = offs_d < stride_w + + ids_vec = tl.load(ids + tok_idx, mask=mask_t) + + for i in range(BLOCK_NN): + if i < BLOCK_N: + pass + valid_tok = (start_token + i) < num_tokens + if valid_tok: + token_id = ids_vec[i] + w_off = token_id * stride_w + offs_d + vec = tl.load(weight + w_off, mask=mask_d) + o_off = (start_token + i) * stride_out + offs_d + tl.store(out + o_off, vec, mask=mask_d) + + +def embedding(weight: torch.Tensor, + ids: torch.Tensor, + out: torch.Tensor, + BLOCK_N: int = 1, + BLOCK_NN: int = 32, + BLOCK_DMODEL: int = None): + assert weight.ndim == 2 + assert ids.ndim == 1 + assert out.ndim == 2 + assert out.shape == (ids.shape[0], weight.shape[1]) + + _, d_model = weight.shape + num_tokens = ids.numel() + + if BLOCK_DMODEL is None: + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = (triton.cdiv(num_tokens, BLOCK_NN), + triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + weight, + ids, + out, + stride_w=weight.stride(0), + stride_out=out.stride(0), + num_tokens=num_tokens, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stdout new file mode 100644 index 0000000..b9d3de5 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_355413.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_355413 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py new file mode 100644 index 0000000..65ca966 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def embedding_kernel(indexes, weight, out, + stride_idx_b, stride_idx_s, + stride_wt_v, stride_wd, + stride_ot_b, stride_ot_s, stride_ot_d, + VOCAB_SIZE: tl.constexpr, D_MODEL: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, vob_start_id: tl.constexpr, vob_end_id: tl.constexpr): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + pid_d = tl.program_id(2) + + seq_start = pid_s * BLOCK_N + d_start = pid_d * BLOCK_NN + + offs_s = seq_start + tl.arange(0, BLOCK_N) + offs_d = d_start + tl.arange(0, BLOCK_NN) + + mask_seq = offs_s < (seq_start + BLOCK_N) + mask_d = offs_d < D_MODEL + + idx_ptr = indexes + pid_b * stride_idx_b + offs_s * stride_idx_s + token_ids = tl.load(idx_ptr, mask=mask_seq, other=0) + + clamp_low = tl.full_like(token_ids, vob_start_id) + clamp_high = tl.full_like(token_ids, vob_end_id - 1) + token_ids = tl.where(token_ids < vob_start_id, clamp_low, token_ids) + token_ids = tl.where(token_ids > (vob_end_id - 1), clamp_high, token_ids) + token_ids = token_ids - vob_start_id + + w_offs = (token_ids[:, None] * stride_wt_v) + (offs_d[None, :] * stride_wd) + emb_vec = tl.load(weight + w_offs, mask=mask_seq[:, None] & mask_d[None, :], other=0.0) + + o_offs = (pid_b * stride_ot_b) + (offs_s * stride_ot_s)[:, None] + (offs_d * stride_ot_d)[None, :] + tl.store(out + o_offs, emb_vec, mask=mask_seq[:, None] & mask_d[None, :]) + + +def embedding(indexes: torch.Tensor, weight: torch.Tensor, vob_start_id: int, vob_end_id: int, out: Optional[torch.Tensor] = None) -> torch.Tensor: + B, S = indexes.shape + VOCAB_SIZE, D_MODEL = weight.shape + + out = torch.empty((B, S, D_MODEL), dtype=weight.dtype, device=weight.device) if out is None else out + + BLOCK_N = min(64, triton.next_power_of_2(S)) + BLOCK_NN = min(64, triton.next_power_of_2(D_MODEL)) + + grid = (B, triton.cdiv(S, BLOCK_N), triton.cdiv(D_MODEL, BLOCK_NN)) + + embedding_kernel[grid]( + indexes, weight, out, + indexes.stride(0), indexes.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + VOCAB_SIZE=VOCAB_SIZE, + D_MODEL=D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + vob_start_id=vob_start_id, + vob_end_id=vob_end_id + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stdout new file mode 100644 index 0000000..800d0a5 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_429595.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_429595 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py new file mode 100644 index 0000000..fb26140 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py @@ -0,0 +1,222 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + ids, # [B, L] + weight, # [V, D] + out, # [B, L, D] + stride_ids_b, + stride_ids_l, + stride_weight_v, + stride_weight_d, + stride_out_b, + stride_out_l, + stride_out_d, + V, + D, + BLOCK_L: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_b = tl.program_id(0) # batch dim + pid_l = tl.program_id(1) * BLOCK_L # seq-len dim + pid_d = tl.program_id(2) * BLOCK_D # embed dim + + offs_l = pid_l + tl.arange(0, BLOCK_L) + offs_d = pid_d + tl.arange(0, BLOCK_D) + + mask_l = offs_l < ids.shape[1] + mask_d = offs_d < D + + ids_ptrs = ids + pid_b * stride_ids_b + offs_l * stride_ids_l + idx = tl.load(ids_ptrs, mask=mask_l, other=0).to(tl.int32) + idx = idx[:, None] + + w_ptrs = ( + weight + + idx * stride_weight_v + + offs_d[None, :] * stride_weight_d + ) + + mask_v = (idx >= 0) & (idx < V) + mask = mask_l[:, None] & mask_d[None, :] & mask_v + + embs = tl.load(w_ptrs, mask=mask, other=float(0)) + out_ptrs = ( + out + + pid_b * stride_out_b + + offs_l[:, None] * stride_out_l + + offs_d[None, :] * stride_out_d + ) + tl.store(out_ptrs, embs, mask=mask_l[:, None] & mask_d[None, :]) + + +def embedding( + ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, +) -> torch.Tensor: + assert ids.dtype in (torch.int32, torch.int64) + assert weight.ndim == 2 + assert out.numel() == 0 or out.shape[:-1] == ids.shape + assert vob_end_id - vob_start_id <= weight.size(0) + + B, L = ids.shape + _, D = weight.shape + + if out.numel() == 0: + out = torch.empty((B, L, D), dtype=weight.dtype, device=weight.device) + + BLOCK_L = 64 + BLOCK_D = triton.next_power_of_2(D) + + grid = ( + B, + triton.cdiv(L, BLOCK_L), + triton.cdiv(D, BLOCK_D), + ) + + embedding_kernel[grid]( + ids, + weight, + out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + weight.size(0), # V + weight.size(1), # D + BLOCK_L=BLOCK_L, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stdout new file mode 100644 index 0000000..cbf28ac --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_43398.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_43398 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py new file mode 100644 index 0000000..65ca966 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def embedding_kernel(indexes, weight, out, + stride_idx_b, stride_idx_s, + stride_wt_v, stride_wd, + stride_ot_b, stride_ot_s, stride_ot_d, + VOCAB_SIZE: tl.constexpr, D_MODEL: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, vob_start_id: tl.constexpr, vob_end_id: tl.constexpr): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + pid_d = tl.program_id(2) + + seq_start = pid_s * BLOCK_N + d_start = pid_d * BLOCK_NN + + offs_s = seq_start + tl.arange(0, BLOCK_N) + offs_d = d_start + tl.arange(0, BLOCK_NN) + + mask_seq = offs_s < (seq_start + BLOCK_N) + mask_d = offs_d < D_MODEL + + idx_ptr = indexes + pid_b * stride_idx_b + offs_s * stride_idx_s + token_ids = tl.load(idx_ptr, mask=mask_seq, other=0) + + clamp_low = tl.full_like(token_ids, vob_start_id) + clamp_high = tl.full_like(token_ids, vob_end_id - 1) + token_ids = tl.where(token_ids < vob_start_id, clamp_low, token_ids) + token_ids = tl.where(token_ids > (vob_end_id - 1), clamp_high, token_ids) + token_ids = token_ids - vob_start_id + + w_offs = (token_ids[:, None] * stride_wt_v) + (offs_d[None, :] * stride_wd) + emb_vec = tl.load(weight + w_offs, mask=mask_seq[:, None] & mask_d[None, :], other=0.0) + + o_offs = (pid_b * stride_ot_b) + (offs_s * stride_ot_s)[:, None] + (offs_d * stride_ot_d)[None, :] + tl.store(out + o_offs, emb_vec, mask=mask_seq[:, None] & mask_d[None, :]) + + +def embedding(indexes: torch.Tensor, weight: torch.Tensor, vob_start_id: int, vob_end_id: int, out: Optional[torch.Tensor] = None) -> torch.Tensor: + B, S = indexes.shape + VOCAB_SIZE, D_MODEL = weight.shape + + out = torch.empty((B, S, D_MODEL), dtype=weight.dtype, device=weight.device) if out is None else out + + BLOCK_N = min(64, triton.next_power_of_2(S)) + BLOCK_NN = min(64, triton.next_power_of_2(D_MODEL)) + + grid = (B, triton.cdiv(S, BLOCK_N), triton.cdiv(D_MODEL, BLOCK_NN)) + + embedding_kernel[grid]( + indexes, weight, out, + indexes.stride(0), indexes.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + VOCAB_SIZE=VOCAB_SIZE, + D_MODEL=D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + vob_start_id=vob_start_id, + vob_end_id=vob_end_id + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stdout new file mode 100644 index 0000000..ead9d0a --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_459432.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_459432 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py new file mode 100644 index 0000000..96e0bc0 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py @@ -0,0 +1,189 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(weight, + ids, + out, + stride_w, + stride_out, + num_tokens, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid = tl.program_id(0) + pid_d = tl.program_id(1) + + start_token = pid * BLOCK_NN + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + tok_idx = start_token + tl.arange(0, BLOCK_NN) + + mask_t = tok_idx < num_tokens + mask_d = offs_d < (stride_w * 1) # stride_w is row-stride (D) + + ids_vec = tl.load(ids + tok_idx, mask=mask_t) + + for i in range(BLOCK_NN): + valid_tok = (start_token + i) < num_tokens + token_id = ids_vec[i] # scalar token-id + w_off = token_id * stride_w + offs_d + vec = tl.load(weight + w_off, mask=mask_d) + o_off = (start_token + i) * stride_out + offs_d + tl.store(out + o_off, vec, mask=mask_d) + + +def embedding(ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, + BLOCK_N: int = 1, + BLOCK_NN: int = 32, + BLOCK_DMODEL: int = None): + _, d_model = weight.shape + num_tokens = ids.numel() + + if BLOCK_DMODEL is None: + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = (triton.cdiv(num_tokens, BLOCK_NN), + triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + weight, + ids, + out, + stride_w=weight.stride(0), + stride_out=out.stride(0), + num_tokens=num_tokens, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stdout new file mode 100644 index 0000000..073078d --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_474863.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_474863 due to at 24:19: + start_token = pid * BLOCK_NN + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + tok_idx = start_token + tl.arange(0, BLOCK_NN) + + mask_t = tok_idx < num_tokens + mask_d = offs_d < (stride_w * 1) # stride_w is row-stride (D) + + ids_vec = tl.load(ids + tok_idx, mask=mask_t) + + for i in range(BLOCK_NN): + valid_tok = (start_token + i) < num_tokens + token_id = ids_vec[i] # scalar token-id + ^ +ValueError('Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)') diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py new file mode 100644 index 0000000..6fd91fc --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py @@ -0,0 +1,218 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids_ptr, + weight_ptr, + out_ptr, + seq_len, + num_tokens, + d_model, + vob_start_id, + vob_end_id, + stride_id, + stride_w0, + stride_w1, + stride_out0, + stride_out1, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid_seq = tl.program_id(0) + pid_col = tl.program_id(1) + + if pid_seq >= seq_len: + return + + cols = pid_col * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + d_mask = cols < d_model + out_offset = pid_seq * stride_out0 + cols * stride_out1 + weight_offset_col = cols * stride_w1 + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + n_offsets = tl.arange(0, BLOCK_N) + n_mask = n_offsets < num_tokens + flat_ids = pid_seq * stride_id + n_offsets + token_ids = tl.load(token_ids_ptr + flat_ids, mask=n_mask, other=0) + valid = (token_ids >= vob_start_id) & (token_ids < vob_end_id) + token_ids_off = token_ids - vob_start_id + + w_offs = token_ids_off * stride_w0 + weight_offset_col + w_vals = tl.load(weight_ptr + w_offs, mask=d_mask[:, None] & valid[None, :], other=0.0) + acc = tl.sum(w_vals, axis=1) + + acc = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + out_offset, acc, mask=d_mask) + + +def embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor = None, +) -> torch.Tensor: + if token_ids.dim() == 1: + seq_len, num_tokens = token_ids.numel(), 1 + else: + seq_len, num_tokens = token_ids.shape + + _, d_model = weight.shape + assert token_ids.dtype == torch.int32 + assert weight.dtype in [torch.float16, torch.float32] + assert weight.is_contiguous() + + if out is None: + out = torch.empty((seq_len, d_model), dtype=weight.dtype, device=weight.device) + + BLOCK_DMODEL = triton.next_power_of_2(d_model) + BLOCK_N = min(triton.next_power_of_2(num_tokens), 64) + + grid = (seq_len, triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + seq_len, + num_tokens, + d_model, + vob_start_id, + vob_end_id, + token_ids.stride(0) if token_ids.dim() > 1 else 1, + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stdout new file mode 100644 index 0000000..b5ea42b --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_477598.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_477598 due to at 39:13: + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + n_offsets = tl.arange(0, BLOCK_N) + n_mask = n_offsets < num_tokens + flat_ids = pid_seq * stride_id + n_offsets + token_ids = tl.load(token_ids_ptr + flat_ids, mask=n_mask, other=0) + valid = (token_ids >= vob_start_id) & (token_ids < vob_end_id) + token_ids_off = token_ids - vob_start_id + + w_offs = token_ids_off * stride_w0 + weight_offset_col + w_vals = tl.load(weight_ptr + w_offs, mask=d_mask[:, None] & valid[None, :], other=0.0) + ^ diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py new file mode 100644 index 0000000..feb12f4 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py @@ -0,0 +1,187 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(weight, + ids, + out, + stride_w, + stride_out, + num_tokens, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid_bn = tl.program_id(0) # block id over the batch dimension + pid_d = tl.program_id(1) # block id over the d_model dimension + + # offset of token indices handled by this program instance + offs_n = pid_bn * BLOCK_NN + tl.arange(0, BLOCK_NN) + mask_n = offs_n < num_tokens # mask out-of-bounds tokens + token_ids = tl.load(ids + offs_n, mask=mask_n) # block of token-ids + + # offset of feature dimensions handled by this program instance + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < stride_w # stride_w == d_model + + # load d_model vectors, one per active token_id, using broadcasting + # shape = (BLOCK_NN, BLOCK_DMODEL) + w_offs = token_ids[:, None] * stride_w + offs_d[None, :] + vec = tl.load(weight + w_offs, mask=mask_n[:, None] & mask_d[None, :]) + + # write to output tensor + o_offs = offs_n[:, None] * stride_out + offs_d[None, :] + tl.store(out + o_offs, vec, mask=mask_n[:, None] & mask_d[None, :]) + + +def embedding(ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, + BLOCK_NN: int = 32, + BLOCK_DMODEL: int = None): + num_tokens = ids.numel() + _, d_model = weight.shape + + if BLOCK_DMODEL is None: + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = ( + triton.cdiv(num_tokens, BLOCK_NN), + triton.cdiv(d_model, BLOCK_DMODEL), + ) + + embedding_kernel[grid]( + weight, + ids, + out, + stride_w=weight.stride(0), + stride_out=out.stride(0), + num_tokens=num_tokens, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stdout new file mode 100644 index 0000000..863aa16 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_480728.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_480728.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py new file mode 100644 index 0000000..65ca966 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def embedding_kernel(indexes, weight, out, + stride_idx_b, stride_idx_s, + stride_wt_v, stride_wd, + stride_ot_b, stride_ot_s, stride_ot_d, + VOCAB_SIZE: tl.constexpr, D_MODEL: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, vob_start_id: tl.constexpr, vob_end_id: tl.constexpr): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + pid_d = tl.program_id(2) + + seq_start = pid_s * BLOCK_N + d_start = pid_d * BLOCK_NN + + offs_s = seq_start + tl.arange(0, BLOCK_N) + offs_d = d_start + tl.arange(0, BLOCK_NN) + + mask_seq = offs_s < (seq_start + BLOCK_N) + mask_d = offs_d < D_MODEL + + idx_ptr = indexes + pid_b * stride_idx_b + offs_s * stride_idx_s + token_ids = tl.load(idx_ptr, mask=mask_seq, other=0) + + clamp_low = tl.full_like(token_ids, vob_start_id) + clamp_high = tl.full_like(token_ids, vob_end_id - 1) + token_ids = tl.where(token_ids < vob_start_id, clamp_low, token_ids) + token_ids = tl.where(token_ids > (vob_end_id - 1), clamp_high, token_ids) + token_ids = token_ids - vob_start_id + + w_offs = (token_ids[:, None] * stride_wt_v) + (offs_d[None, :] * stride_wd) + emb_vec = tl.load(weight + w_offs, mask=mask_seq[:, None] & mask_d[None, :], other=0.0) + + o_offs = (pid_b * stride_ot_b) + (offs_s * stride_ot_s)[:, None] + (offs_d * stride_ot_d)[None, :] + tl.store(out + o_offs, emb_vec, mask=mask_seq[:, None] & mask_d[None, :]) + + +def embedding(indexes: torch.Tensor, weight: torch.Tensor, vob_start_id: int, vob_end_id: int, out: Optional[torch.Tensor] = None) -> torch.Tensor: + B, S = indexes.shape + VOCAB_SIZE, D_MODEL = weight.shape + + out = torch.empty((B, S, D_MODEL), dtype=weight.dtype, device=weight.device) if out is None else out + + BLOCK_N = min(64, triton.next_power_of_2(S)) + BLOCK_NN = min(64, triton.next_power_of_2(D_MODEL)) + + grid = (B, triton.cdiv(S, BLOCK_N), triton.cdiv(D_MODEL, BLOCK_NN)) + + embedding_kernel[grid]( + indexes, weight, out, + indexes.stride(0), indexes.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + VOCAB_SIZE=VOCAB_SIZE, + D_MODEL=D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + vob_start_id=vob_start_id, + vob_end_id=vob_end_id + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stdout new file mode 100644 index 0000000..20fcf39 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_490985.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_490985 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py new file mode 100644 index 0000000..e60fd1e --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py @@ -0,0 +1,228 @@ + +import torch +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Triton kernel +# --------------------------------------------------------------------------- +@triton.jit +def embedding_kernel( + ids, # pointer to token-ids tensor shape [B, S] (int64) + weight, # pointer to embedding weight matrix shape [V, D] (float16 or float32) + out, # pointer to output 3-D tensor shape [B, S, D] + B, # int: number of sequences (batch size) + S, # int: max sequence length for all sequences + V, # int: vocabulary size + D, # int: embedding dimension + stride_ids_0, # leading stride of ids: = S + stride_w_0, # leading stride of weight: = D + stride_out_0, # leading stride of out: = S * D + stride_out_1, # + stride_out_2, # + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) # 1-D grid: each block handles one sequence in the batch + + # each block handles every BLOCK_N tokens, each tid loop handles BLOCK_NN tokens + for b_base in range(0, S, BLOCK_N): + ids_offset = pid * stride_ids_0 + b_base + # Load mask + n_ids = tl.arange(b_base, b_base + BLOCK_N) + mask_n = n_ids < S + + # Load token indices + token_ids = tl.load(ids + ids_offset + tl.arange(0, BLOCK_N), mask=mask_n, other=0) + + # Ensure token_ids in [0, V-1] + token_ids = tl.maximum(0, token_ids) + token_ids = tl.minimum(V - 1, token_ids) + + # Iterate over tokens in groups of BLOCK_NN + for start in range(0, BLOCK_N, BLOCK_NN): + idx_group = start + tl.arange(0, BLOCK_NN) + group_mask = mask_n & (idx_group < BLOCK_N) + + # Current token ids for this group + tid = token_ids[start : start + BLOCK_NN] # shape [BLOCK_NN] + outs_idx = pid * stride_out_0 + (b_base + start + tl.arange(0, BLOCK_NN)) * stride_out_1 + + # Iterate over the embedding dimension in blocks + for d_start in range(0, D, BLOCK_DMODEL): + offs_d = d_start + tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < D + + valid_mask = group_mask[:, None] & mask_d[None, :] + + # Weight pointer: address strides: weight[tid, d_offs] = weight + tid * stride_w_0 + offs_d + weight_ptr = weight + tid[:, None] * stride_w_0 + offs_d[None, :] + emb_vec = tl.load(weight_ptr, mask=valid_mask, other=0.0) + + # Output pointer: address strides + output_ptr = out + outs_idx[:, None] * stride_out_2 + offs_d[None, :] + tl.store(output_ptr, emb_vec, mask=valid_mask) + +# --------------------------------------------------------------------------- +# Python wrapper +# --------------------------------------------------------------------------- +def embedding( + ids: torch.Tensor, # [B, S] long int + weight: torch.Tensor, # [V, D] float16 or float32 +) -> torch.Tensor: # Returns: [B, S, D] + B, S = ids.shape + V, D = weight.shape + device = weight.device + dtype = weight.dtype + + out = torch.empty((B, S, D), dtype=dtype, device=device) + + BLOCK_N = 64 + BLOCK_NN = 64 + BLOCK_DMODEL = triton.next_power_of_2(D) + + grid = (B,) + + embedding_kernel[grid]( + ids, # int64 + weight, # fp16 / fp32 + out, # fp16 / fp32 + B, + S, + V, + D, + ids.stride(0), + weight.stride(0), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stdout new file mode 100644 index 0000000..6ea3fa0 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_507685.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_507685 due to embedding() takes 2 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py new file mode 100644 index 0000000..0f773e9 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py @@ -0,0 +1,226 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids, + weight, + out, + stride_b, + stride_l, + stride_v, + stride_d, + stride_ob, + stride_ol, + vob_start_id, + vob_end_id, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + pid_batch = tl.program_id(0) + pid_dim = tl.program_id(1) + + cols_d = pid_dim * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_d = cols_d < stride_d + + rows_n_full = stride_l + for start_n in tl.range(0, rows_n_full, BLOCK_NN): + rows_n = start_n + tl.arange(0, BLOCK_NN) + mask_n = rows_n < rows_n_full + + offset_ids = pid_batch * stride_b + rows_n + ids = tl.load(token_ids + offset_ids, mask=mask_n, other=0) + + mask_valid = (ids >= vob_start_id) & (ids < vob_end_id) + safe_ids = tl.where(mask_valid, ids, 0) + + emb_ptr = weight + (safe_ids[:, None] * stride_v + cols_d[None, :] * 1) + emb_vals = tl.load(emb_ptr, mask=(mask_n[:, None] & mask_d[None, :]), other=0.0) + + out_ptr = out + (pid_batch * stride_ob + rows_n[:, None] * stride_ol + cols_d[None, :]) + tl.store(out_ptr, emb_vals, mask=(mask_n[:, None] & mask_d[None, :])) + + +def embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor +) -> None: + assert token_ids.dtype in (torch.int64, torch.int32) + V, D = weight.shape + B = token_ids.numel() // token_ids.size(-1) if token_ids.ndim > 1 else 1 + L = token_ids.size(-1) + + if token_ids.ndim == 1: + assert out.numel() == L * D and out.size(-1) == D + else: + assert out.numel() == B * L * D and out.size(-1) == D + + BLOCK_DMODEL = triton.next_power_of_2(D) + BLOCK_N = 32 + BLOCK_NN = 256 + + if token_ids.ndim == 1: + token_ids = token_ids.contiguous() + out_view = out.view(L, D).contiguous() + grid = (1, triton.cdiv(D, BLOCK_DMODEL)) + stride_b = 0 + stride_l = L + stride_ob = 0 + stride_ol = D + else: + token_ids = token_ids.view(B, L).contiguous() + out_view = out.view(B, L, D).contiguous() + grid = (B, triton.cdiv(D, BLOCK_DMODEL)) + stride_b = token_ids.stride(0) + stride_l = token_ids.size(-1) + stride_ob = out_view.stride(0) + stride_ol = out_view.stride(1) + + stride_v = weight.stride(0) + stride_d_true = D + + embedding_kernel[grid]( + token_ids, + weight, + out_view, + stride_b, + stride_l, + stride_v, + stride_d_true, + stride_ob, + stride_ol, + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stdout new file mode 100644 index 0000000..42e34f8 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_524778.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_524778.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py new file mode 100644 index 0000000..e742370 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py @@ -0,0 +1,188 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(indexes, weight, out, stride_idx_b, + stride_idx_s, stride_wt_v, stride_wd, + stride_ot_b, stride_ot_s, stride_ot_d, + VOCAB_SIZE: tl.constexpr, D_MODEL: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + pid_d = tl.program_id(2) + + seq_start = pid_s * BLOCK_N + d_start = pid_d * BLOCK_NN + + idx_base = pid_b * stride_idx_b + seq_start * stride_idx_s + valid_seq_len = tl.load(indexes + idx_base) + valid_seq_len = tl.minimum(valid_seq_len, BLOCK_N) + + offs_s = seq_start + tl.arange(0, BLOCK_N) + offs_d = d_start + tl.arange(0, BLOCK_NN) + + mask_seq = offs_s < (seq_start + valid_seq_len) + mask_d = offs_d < D_MODEL + + idx_ptr = indexes + idx_base + offs_s * stride_idx_s + token_ids = tl.load(idx_ptr, mask=mask_seq, other=0) + + w_offs = (token_ids[:, None] * stride_wt_v) + (offs_d[None, :] * stride_wd) + emb_vec = tl.load(weight + w_offs, mask=mask_seq[:, None] & mask_d[None, :], other=0.0) + + o_offs = (pid_b * stride_ot_b) + (offs_s * stride_ot_s)[:, None] + (offs_d * stride_ot_d)[None, :] + tl.store(out + o_offs, emb_vec, mask=mask_seq[:, None] & mask_d[None, :]) + + +def embedding(indexes: torch.Tensor, weight: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + B, S = indexes.shape + VOCAB_SIZE, D_MODEL = weight.shape + + out = torch.empty((B, S, D_MODEL), dtype=weight.dtype, device=weight.device) if out is None else out + + BLOCK_N = 64 + BLOCK_NN = min(64, triton.next_power_of_2(D_MODEL)) + + assert indexes.is_contiguous() + assert weight.is_contiguous() + assert out.is_contiguous() + + grid = (B, triton.cdiv(S, BLOCK_N), triton.cdiv(D_MODEL, BLOCK_NN)) + + embedding_kernel[grid]( + indexes, weight, out, + indexes.stride(0), indexes.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + VOCAB_SIZE=VOCAB_SIZE, + D_MODEL=D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stdout new file mode 100644 index 0000000..848b09e --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_533885.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_533885 due to embedding() takes from 2 to 3 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py new file mode 100644 index 0000000..4fb1bab --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py @@ -0,0 +1,231 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + ids, # [B, L] + weight, # [V, D] + out, # [B, L, D] + stride_ids_b, + stride_ids_l, + stride_weight_v, + stride_weight_d, + stride_out_b, + stride_out_l, + stride_out_d, + vob_start_id, + V, + D, + BLOCK_L: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) * BLOCK_L + pid_d = tl.program_id(2) * BLOCK_D + + offs_l = pid_l + tl.arange(0, BLOCK_L) + offs_d = pid_d + tl.arange(0, BLOCK_D) + + mask_l = offs_l < ids.shape[1] + mask_d = offs_d < D + + ids_ptrs = ids + pid_b * stride_ids_b + offs_l * stride_ids_l + idx = tl.load(ids_ptrs, mask=mask_l, other=0).to(tl.int32) + idx = idx - vob_start_id + idx = idx[:, None] + + w_ptrs = ( + weight + + idx * stride_weight_v + + offs_d[None, :] * stride_weight_d + ) + + mask_v = (idx >= 0) & (idx < V) + mask = mask_l[:, None] & mask_d[None, :] & mask_v + + embs = tl.load(w_ptrs, mask=mask, other=0.0) + out_ptrs = ( + out + + pid_b * stride_out_b + + offs_l[:, None] * stride_out_l + + offs_d[None, :] * stride_out_d + ) + tl.store(out_ptrs, embs, mask=mask_l[:, None] & mask_d[None, :] & mask_v) + +def embedding( + ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, +) -> torch.Tensor: + assert ids.dtype in (torch.int32, torch.int64) + assert weight.ndim == 2 + inferred_D = weight.shape[1] + if out.numel() == 0: + out = torch.empty((*ids.shape, inferred_D), dtype=weight.dtype, device=weight.device) + else: + assert out.shape[:-1] == ids.shape + assert out.shape[-1] == inferred_D + + B = ids.shape[0] + L = ids.numel() // B + D = inferred_D + V = vob_end_id - vob_start_id + assert V <= weight.shape[0] + + ids = ids.view(B, L) + out = out.view(B, L, D) + + BLOCK_L = 64 + BLOCK_D = triton.next_power_of_2(D) + + grid = ( + B, + triton.cdiv(L, BLOCK_L), + triton.cdiv(D, BLOCK_D), + ) + + embedding_kernel[grid]( + ids, + weight, + out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + vob_start_id, + V, + D, + BLOCK_L=BLOCK_L, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stdout new file mode 100644 index 0000000..9c7addd --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_552958.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_552958 due to at 25:22: + D, + BLOCK_L: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) * BLOCK_L + pid_d = tl.program_id(2) * BLOCK_D + + offs_l = pid_l + tl.arange(0, BLOCK_L) + offs_d = pid_d + tl.arange(0, BLOCK_D) + + mask_l = offs_l < ids.shape[1] + ^ +IndexError('list index out of range') diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py new file mode 100644 index 0000000..77c1639 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py @@ -0,0 +1,210 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + token_ids_ptr, # pointer to 1-D int32 token_ids + weight_ptr, # pointer to 2-D float weights (vocab, d_model) + out_ptr, # pointer to 2-D output (seq_len, d_model) + seq_len, # int + d_model, # int + vob_start_id, # int + vob_end_id, # int + stride_id, # int + stride_w0, # int + stride_w1, # int + stride_out0, # int + stride_out1, # int + BLOCK_DMODEL: tl.constexpr, +): + pid_seq = tl.program_id(0) + pid_col = tl.program_id(1) + + if pid_seq >= seq_len: + return + + # token id for this sequence position + token_id = tl.load(token_ids_ptr + pid_seq * stride_id) + valid = (token_id >= vob_start_id) & (token_id < vob_end_id) + tok_id_off = (token_id - vob_start_id) * stride_w0 # row offset + + cols = pid_col * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + d_mask = cols < d_model + + weight_off = tok_id_off + cols * stride_w1 + val = tl.load(weight_ptr + weight_off, mask=d_mask & valid, other=0.0) + + out_off = pid_seq * stride_out0 + cols * stride_out1 + tl.store(out_ptr + out_off, val, mask=d_mask) + + +def embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor = None, +) -> torch.Tensor: + """ + Extract embedding vectors for each token-id in `token_ids` using 2-D matrix `weight`. + + token_ids : torch.int32 shape(SEQ_LEN,) with values in [vob_start_id, vob_end_id) + weight : weight[vocab, d_model] (slice assumed to start at vocab 0) + """ + + seq_len, = token_ids.shape + vocab_size, d_model = weight.shape + assert token_ids.dtype == torch.int32 + assert weight.dtype in (torch.float16, torch.float32) + assert weight.is_contiguous() + + if out is None: + out = torch.empty((seq_len, d_model), dtype=weight.dtype, device=weight.device) + + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = (seq_len, triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + seq_len, + d_model, + vob_start_id, + vob_end_id, + token_ids.stride(0) if token_ids.dim() == 1 else 1, + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_574109.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py new file mode 100644 index 0000000..7d4854d --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py @@ -0,0 +1,225 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + ptr_ids, # 1d raw pointer + ptr_weight, # 1d raw pointer + ptr_out, # 1d raw pointer + stride_ids_b, + stride_ids_l, + stride_weight_v, + stride_weight_d, + stride_out_b, + stride_out_l, + stride_out_d, + vob_start_id, + B, + L, + D, + V, + BLOCK_L: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_l_blk = tl.program_id(1) + pid_d_blk = tl.program_id(2) + + offs_l = pid_l_blk * BLOCK_L + tl.arange(0, BLOCK_L) + offs_d = pid_d_blk * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_l = offs_l < L + mask_d = offs_d < D + + ids_base = ptr_ids + pid_b * stride_ids_b + idxs = tl.load(ids_base + offs_l * stride_ids_l, + mask=mask_l, other=0).to(tl.int32) + idxs = idxs - vob_start_id + idxs = idxs[:, None] + + w_base = ptr_weight + w_row_stride = stride_weight_v + w_col_stride = stride_weight_d + weight_ptrs = w_base + \ + idxs * w_row_stride + offs_d[None, :] * w_col_stride + + mask_v = (idxs >= 0) & (idxs < V) + mask = mask_l[:, None] & mask_d[None, :] & mask_v + + embs = tl.load(weight_ptrs, mask=mask, other=0.0) + + out_base = ptr_out + pid_b * stride_out_b + out_ptrs = out_base + \ + offs_l[:, None] * stride_out_l + offs_d[None, :] * stride_out_d + tl.store(out_ptrs, embs, mask=mask) + +def embedding( + ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor, +) -> torch.Tensor: + assert ids.dtype in (torch.int32, torch.int64) + assert weight.ndim == 2 + inferred_D = weight.shape[1] + if out.numel() == 0: + out = torch.empty((*ids.shape, inferred_D), dtype=weight.dtype, device=weight.device) + else: + assert out.shape[:-1] == ids.shape + assert out.shape[-1] == inferred_D + + B = ids.shape[0] + L = ids.shape[1] if ids.ndim == 2 else 1 + ids = ids.view(B, L) + out = out.view(B, L, inferred_D) + + D = inferred_D + V = vob_end_id - vob_start_id + assert V <= weight.shape[0] + + BLOCK_L = 64 + BLOCK_D = triton.next_power_of_2(D) + + grid = (B, triton.cdiv(L, BLOCK_L), triton.cdiv(D, BLOCK_D)) + + embedding_kernel[grid]( + ids, weight, out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + vob_start_id, + B, L, D, V, + BLOCK_L=BLOCK_L, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_58716.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py new file mode 100644 index 0000000..66ea5f3 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py @@ -0,0 +1,182 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(token_ids_ptr, out_ptr, weight_ptr, + stride_tokens, stride_outb, stride_outs, stride_outd, + stride_weightv, stride_weightd, + vocab_size: tl.constexpr, hidden_size: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_bn = pid_b * stride_outb + offs_nn = pid_n * BLOCK_NN + + seq_len = tl.load(token_ids_ptr + pid_b * stride_tokens).to(tl.int32) + + for start_n in range(0, seq_len, BLOCK_N): + offs_tok = start_n + tl.arange(0, BLOCK_N) + mask_tok = offs_tok < seq_len + tok_ids = tl.load(token_ids_ptr + offs_bn + offs_tok * stride_tokens, mask=mask_tok, other=0) + + for start_d in range(0, hidden_size, BLOCK_NN): + offs_d = start_d + tl.arange(0, BLOCK_NN) + mask_d = offs_d < hidden_size + offs_weight = tok_ids[:, None] * stride_weightv + offs_d[None, :] * stride_weightd + mask_weight = mask_tok[:, None] & mask_d[None, :] + embedded = tl.load(weight_ptr + offs_weight, mask=mask_weight, other=0.0) + + offs_out = (offs_bn + offs_tok[:, None] * stride_outs + offs_d[None, :] * stride_outd) + mask_out = mask_tok[:, None] & mask_d[None, :] + tl.store(out_ptr + offs_out, embedded, mask=mask_out) + + +def embedding(token_ids: torch.Tensor, weights: torch.Tensor, vob_start_id=None, vob_end_id=None, out: torch.Tensor = None) -> torch.Tensor: + b, seq = token_ids.shape + vocab_size, hidden_size = weights.shape + assert weights.dtype in [torch.float16, torch.float32] + assert token_ids.dtype in [torch.int32, torch.int64] + + if out is None: + out = torch.empty((b, seq, hidden_size), dtype=weights.dtype, device=weights.device) + + BLOCK_N = 64 + BLOCK_NN = min(128, triton.next_power_of_2(hidden_size)) + + grid = (b, triton.cdiv(seq, BLOCK_N)) + + embedding_kernel[grid]( + token_ids, out, weights, + token_ids.stride(0), out.stride(0), out.stride(1), out.stride(2), + weights.stride(0), weights.stride(1), + vocab_size, hidden_size, + BLOCK_N, BLOCK_NN + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stdout new file mode 100644 index 0000000..0da5f55 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_600998.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_600998 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py new file mode 100644 index 0000000..bbd01e5 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(weight_ptr, + ids_ptr, + out_ptr, + stride_w, + stride_out, + num_tokens, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid = tl.program_id(0) # token block + pid_d = tl.program_id(1) # d-model block + start_token = pid * BLOCK_NN + + ids = tl.load(ids_ptr + start_token + tl.arange(0, BLOCK_NN)) + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + mask_tokens = start_token + tl.arange(0, BLOCK_NN) < num_tokens + mask_d = offs_d < stride_w + + for i in range(BLOCK_NN): + if not mask_tokens[i]: + break + token_id = ids[i] + w_off = token_id * stride_w + offs_d + vec = tl.load(weight_ptr + w_off, mask=mask_d) + o_off = (start_token + i) * stride_out + offs_d + tl.store(out_ptr + o_off, vec, mask=mask_d) + + +def embedding(weight: torch.Tensor, + ids: torch.Tensor, + out: torch.Tensor, + BLOCK_N: int = 1, + BLOCK_NN: int = 32, + BLOCK_DMODEL: int = None): + assert weight.ndim == 2 + assert ids.ndim == 1 + assert out.ndim == 2 + assert out.shape == (ids.shape[0], weight.shape[1]) + + _, d_model = weight.shape + num_tokens = ids.numel() + + if BLOCK_DMODEL is None: + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = (triton.cdiv(num_tokens, BLOCK_NN), + triton.cdiv(d_model, BLOCK_DMODEL)) + embedding_kernel[grid]( + weight, + ids, + out, + stride_w=weight.stride(0), + stride_out=out.stride(0), + num_tokens=num_tokens, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stdout new file mode 100644 index 0000000..b69f644 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_605163.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_605163 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py new file mode 100644 index 0000000..cfdc332 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py @@ -0,0 +1,178 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + out_ptr, ids_ptr, weight_ptr, + stride_os, stride_om, + stride_ws, stride_wm, + N, D_MODEL, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + rows_seq = pid * BLOCK_N + tl.arange(0, BLOCK_N) + mask_rows = rows_seq < N + + token_ids = tl.load(ids_ptr + rows_seq, mask=mask_rows, other=0) + # Mask out-of-range indices to 0 to prevent invalid reads + token_ids = tl.where((token_ids >= 0) & (token_ids < stride_ws), token_ids, 0) + + cols_d = tl.arange(0, BLOCK_DMODEL) + mask_cols = cols_d < D_MODEL + + w_ptrs = weight_ptr + (token_ids[:, None] * stride_ws + cols_d[None, :] * stride_wm) + x = tl.load(w_ptrs, mask=mask_rows[:, None] & mask_cols[None, :], other=0.0) + + o_ptrs = out_ptr + (rows_seq[:, None] * stride_os + cols_d[None, :] * stride_om) + tl.store(o_ptrs, x, mask=mask_rows[:, None] & mask_cols[None, :]) + + +def embedding(ids: torch.Tensor, weight: torch.Tensor, + out: torch.Tensor = None) -> torch.Tensor: + ids = ids.contiguous() + N = ids.numel() + D_MODEL = weight.shape[-1] + + if out is None: + out = torch.empty((N, D_MODEL), dtype=weight.dtype, device=weight.device) + + BLOCK_N = 32 + BLOCK_DMODEL = triton.next_power_of_2(D_MODEL) + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_N']),) + + embedding_kernel[grid]( + out, ids, weight, + out.stride(0), out.stride(1), + weight.stride(0), weight.stride(1), + N, D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stdout new file mode 100644 index 0000000..1840465 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_620455.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_620455 due to embedding() takes from 2 to 3 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py new file mode 100644 index 0000000..fc47731 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py @@ -0,0 +1,182 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(token_ids_ptr, out_ptr, weight_ptr, + stride_tokens, stride_outb, stride_outs, stride_outd, + stride_weightv, stride_weightd, + vocab_size: tl.constexpr, hidden_size: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_bn = pid_b * stride_outb + offs_nn = pid_n * BLOCK_NN + + seq_len = tl.load(token_ids_ptr + pid_b * stride_tokens).to(tl.int32) + + for start_n in range(0, seq_len, BLOCK_N): + offs_tok = start_n + tl.arange(0, BLOCK_N) + mask_tok = offs_tok < seq_len + tok_ids = tl.load(token_ids_ptr + offs_bn + offs_tok * stride_tokens, mask=mask_tok, other=0) + + for start_d in range(0, hidden_size, BLOCK_NN): + offs_d = start_d + tl.arange(0, BLOCK_NN) + mask_d = offs_d < hidden_size + offs_weight = tok_ids[:, None] * stride_weightv + offs_d[None, :] * stride_weightd + mask_weight = mask_tok[:, None] & mask_d[None, :] + embedded = tl.load(weight_ptr + offs_weight, mask=mask_weight, other=0.0) + + offs_out = (offs_bn + offs_tok[:, None] * stride_outs + offs_d[None, :] * stride_outd) + mask_out = mask_tok[:, None] & mask_d[None, :] + tl.store(out_ptr + offs_out, embedded, mask=mask_out) + + +def embedding(token_ids: torch.Tensor, weights: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + b, seq = token_ids.shape + vocab_size, hidden_size = weights.shape + assert weights.dtype in [torch.float16, torch.float32] + assert token_ids.dtype in [torch.int32, torch.int64] + + if out is None: + out = torch.empty((b, seq, hidden_size), dtype=weights.dtype, device=weights.device) + + BLOCK_N = 64 + BLOCK_NN = min(128, triton.next_power_of_2(hidden_size)) + + grid = (b, triton.cdiv(seq, BLOCK_N)) + + embedding_kernel[grid]( + token_ids, out, weights, + token_ids.stride(0), out.stride(0), out.stride(1), out.stride(2), + weights.stride(0), weights.stride(1), + vocab_size, hidden_size, + BLOCK_N, BLOCK_NN + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stdout new file mode 100644 index 0000000..60eaafa --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_635331.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_635331 due to embedding() takes from 2 to 3 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py new file mode 100644 index 0000000..fa3672f --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py @@ -0,0 +1,189 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + out_ptr, ids_ptr, weight_ptr, + stride_os, stride_om, + stride_ws, stride_wm, + N, D_MODEL, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + rows_seq = pid * BLOCK_N + tl.arange(0, BLOCK_N) + mask_rows = rows_seq < N + + token_ids = tl.load(ids_ptr + rows_seq, mask=mask_rows, other=-1) + + cols_d = tl.arange(0, BLOCK_DMODEL) + mask_cols = cols_d < D_MODEL + + mask_valid = mask_rows[:, None] & mask_cols[None, :] & ( + token_ids[:, None] >= 0) & (token_ids[:, None] < stride_ws) + + w_ptrs = weight_ptr + ( + token_ids[:, None] * stride_ws + cols_d[None, :] * stride_wm) + x = tl.load(w_ptrs, mask=mask_valid, other=0.0) + + o_ptrs = out_ptr + (rows_seq[:, None] * stride_os + + cols_d[None, :] * stride_om) + tl.store(o_ptrs, x, mask=mask_valid) + + +def embedding(ids: torch.Tensor, weight: torch.Tensor, + vob_start_id: int = 0, vob_end_id: int = None, + out: torch.Tensor = None) -> torch.Tensor: + ids = ids.contiguous() + N = ids.numel() + D_MODEL = weight.shape[-1] + + if vob_end_id is None: + vob_end_id = weight.shape[0] + + if out is None: + out = torch.empty((N, D_MODEL), dtype=weight.dtype, + device=weight.device) + + if N == 0: + return out + + BLOCK_N = 32 + BLOCK_DMODEL = triton.next_power_of_2(D_MODEL) + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_N']),) + + embedding_kernel[grid]( + out, ids, weight, + out.stride(0), out.stride(1), + weight.stride(0), weight.stride(1), + N, D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stdout new file mode 100644 index 0000000..820f029 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_64602.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_64602.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py new file mode 100644 index 0000000..2adf69f --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py @@ -0,0 +1,205 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids, # [B, L] int64 + weight, # [V, D] any + out, # [B, L, D] + stride_b, stride_l, # token_ids strides + stride_v, stride_d, # weight strides + stride_ob, stride_ol, # out strides + vob_start_id, # lower inclusive vocab bound + vob_end_id, # upper exclusive vocab bound + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + pid_batch = tl.program_id(0) + pid_dim = tl.program_id(1) + + cols_d = pid_dim * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_d = cols_d < stride_d # mask along D + + for start_l in tl.range(0, stride_l, BLOCK_NN): + cols_l = start_l + tl.arange(0, BLOCK_NN) # [BLOCK_NN] + mask_l = cols_l < stride_l # mask along L + + offset_ids = pid_batch * stride_b + cols_l # [BLOCK_NN] + ids = tl.load(token_ids + offset_ids, mask=mask_l) # [BLOCK_NN] + + # clamp to valid range; then mask ids that are out-of-vocabulary bounds + ids = tl.where((ids >= vob_start_id) & (ids < vob_end_id), ids, vob_start_id) + valid_mask = (ids >= vob_start_id) & (ids < vob_end_id) + + offset_weight = ids[:, None] * stride_v + cols_d[None, :] # [BLOCK_NN, BLOCK_DMODEL] + vals = tl.load(weight + offset_weight, + mask=mask_l[:, None] & mask_d[None, :] & valid_mask[:, None]) + + offset_out = pid_batch * stride_ob + cols_l[:, None] * stride_ol + \ + cols_d[None, :] # [BLOCK_NN, BLOCK_DMODEL] + tl.store(out + offset_out, + vals, + mask=mask_l[:, None] & mask_d[None, :]) + + +def embedding(token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor) -> None: + assert token_ids.device == weight.device == out.device + assert token_ids.is_cuda or token_ids.device.type == "cuda" + assert token_ids.dtype is torch.int64 + assert weight.ndim == 2 and out.ndim == 3 + V, D = weight.shape + B, L = token_ids.shape + assert out.shape == (B, L, D) + + BLOCK_DMODEL = triton.next_power_of_2(D) + BLOCK_N = 32 # optimal warp granularity + BLOCK_NN = 256 # sequence span loaded per program + + grid = (B, triton.cdiv(D, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + token_ids.stride(0), + token_ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stdout new file mode 100644 index 0000000..0b3a539 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_68534.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_68534 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py new file mode 100644 index 0000000..86d64b9 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py @@ -0,0 +1,191 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(token_ids_ptr, out_ptr, weight_ptr, + stride_tokens, stride_outb, stride_outs, stride_outd, + stride_weightv, stride_weightd, + vocab_size: tl.constexpr, hidden_size: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_bn = pid_b * stride_outb + offs_nn = pid_n * BLOCK_NN + + seq_len = tl.load(token_ids_ptr + pid_b * stride_tokens).to(tl.int32) + + for start_n in range(0, seq_len, BLOCK_N): + offs_tok = start_n + tl.arange(0, BLOCK_N) + mask_tok = offs_tok < seq_len + tok_ids = tl.load(token_ids_ptr + offs_bn + offs_tok * stride_tokens, mask=mask_tok, other=0) + + for start_d in range(0, hidden_size, BLOCK_NN): + offs_d = start_d + tl.arange(0, BLOCK_NN) + mask_d = offs_d < hidden_size + offs_weight = tok_ids[:, None] * stride_weightv + offs_d[None, :] * stride_weightd + mask_weight = mask_tok[:, None] & mask_d[None, :] + embedded = tl.load(weight_ptr + offs_weight, mask=mask_weight, other=0.0) + + offs_out = (offs_bn + offs_tok[:, None] * stride_outs + offs_d[None, :] * stride_outd) + mask_out = mask_tok[:, None] & mask_d[None, :] + tl.store(out_ptr + offs_out, embedded, mask=mask_out) + + +def embedding(token_ids: torch.Tensor, weights: torch.Tensor, vob_start_id=None, vob_end_id=None, + out: torch.Tensor = None) -> torch.Tensor: + if token_ids.dim() == 1: + token_ids = token_ids.unsqueeze(0) + elif token_ids.dim() != 2: + raise ValueError("token_ids should be 1-D or 2-D tensor") + + vocab_size, hidden_size = weights.shape + batch, seq = token_ids.shape + assert weights.dtype in [torch.float16, torch.float32] + assert token_ids.dtype in [torch.int32, torch.int64] + + if out is None: + out = torch.empty((batch, seq, hidden_size), dtype=weights.dtype, device=weights.device) + + # Prepare tensor holding seq lengths + seq_len = torch.full((batch,), seq, dtype=torch.int32, device=token_ids.device) + + BLOCK_N = 64 + BLOCK_NN = min(128, triton.next_power_of_2(hidden_size)) + + grid = (batch, triton.cdiv(seq, BLOCK_N)) + + embedding_kernel[grid]( + seq_len, out, weights, + seq_len.stride(0), out.stride(0), out.stride(1), out.stride(2), + weights.stride(0), weights.stride(1), + vocab_size, hidden_size, + BLOCK_N, BLOCK_NN + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stdout new file mode 100644 index 0000000..7228cc8 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_713720.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_713720 due to Dimension out of range (expected to be in range of [-2, 1], but got 2) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py new file mode 100644 index 0000000..cc0ea63 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py @@ -0,0 +1,232 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + ids, # int32/64 [B, L] + weight, # fp* [V, D] + out, # fp* [B, L, D] + stride_ids_b, # tl.constexpr ignores run-time values + stride_ids_l, + stride_weight_v, + stride_weight_d, + stride_out_b, + stride_out_l, + stride_out_d, + V, + D, + BLOCK_L: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_b = tl.program_id(0) # batch dim + pid_l = tl.program_id(1) * BLOCK_L # seq-len dim + pid_d = tl.program_id(2) * BLOCK_D # embed dim + + # Compute ranges + offs_l = pid_l + tl.arange(0, BLOCK_L) # [BLOCK_L] + offs_d = pid_d + tl.arange(0, BLOCK_D) # [BLOCK_D] + + mask_l = offs_l < ids.shape[1] # [BLOCK_L] + mask_d = offs_d < D # [BLOCK_D] + + # --- Load token ids for this tile ------------------------------------------------- + ids_ptr = ids + pid_b * stride_ids_b + offs_l * stride_ids_l # [BLOCK_L] + idx = tl.load(ids_ptr, mask=mask_l, other=0).to(tl.int32) # [BLOCK_L] + + # Broadcast ids for weight lookup + # idx: [BLOCK_L] -> [BLOCK_L, 1] + idx = idx[:, None] + + # --- Load weight rows ------------------------------------------------------------- + w_ptrs = ( + weight + + idx * stride_weight_v # broadcast: [BLOCK_L, 1] * stride + + offs_d[None, :] * stride_weight_d # broadcast: [1, BLOCK_D] * stride + ) # -> [BLOCK_L, BLOCK_D] + + mask_v = (idx >= 0) & (idx < V) # row-valid mask: [BLOCK_L, 1] + mask = mask_l[:, None] & mask_d[None, :] & mask_v + + embs = tl.load(w_ptrs, mask=mask, other=0.0) # [BLOCK_L, BLOCK_D] + + # --- Store into output tensor ------------------------------------------------------ + out_ptrs = ( + out + + pid_b * stride_out_b + + offs_l[:, None] * stride_out_l + + offs_d[None, :] * stride_out_d + ) # [BLOCK_L, BLOCK_D] + + tl.store(out_ptrs, embs, mask=mask_l[:, None] & mask_d[None, :]) + + +def embedding_forward( + ids: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + """ + Triton-accelerated embedding lookup. + ids : [B, L] (int32/int64) + weight: [V, D] + returns: [B, L, D] + """ + assert ids.dtype in {torch.int32, torch.int64}, "ids must be int32/int64" + assert weight.ndim == 2, "weight should be 2-D: [V, D]" + + B, L = ids.shape + V, D = weight.shape + out = torch.empty((B, L, D), dtype=weight.dtype, device=weight.device) + + # choose tile sizes that divide dimensions well + BLOCK_L = 64 + BLOCK_D = triton.next_power_of_2(D) + + grid = ( + B, + triton.cdiv(L, BLOCK_L), + triton.cdiv(D, BLOCK_D), + ) + + embedding_kernel[grid]( + ids, + weight, + out, + ids.stride(0), + ids.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + V, + D, + BLOCK_L=BLOCK_L, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stdout new file mode 100644 index 0000000..fbb5315 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_721645.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_721645 due to name 'embedding' is not defined diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py new file mode 100644 index 0000000..8f9f373 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py @@ -0,0 +1,181 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + out_ptr, ids_ptr, weight_ptr, + stride_os, stride_om, + stride_ws, stride_wm, + N, D_MODEL, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + rows_seq = pid * BLOCK_N + tl.arange(0, BLOCK_N) + mask_rows = rows_seq < N + + token_ids = tl.load(ids_ptr + rows_seq, mask=mask_rows, other=0) + token_ids = tl.where((token_ids >= 0) & (token_ids < stride_ws), token_ids, 0) + + cols_d = tl.arange(0, BLOCK_DMODEL) + mask_cols = cols_d < D_MODEL + + w_ptrs = weight_ptr + (token_ids[:, None] * stride_ws + cols_d[None, :] * stride_wm) + x = tl.load(w_ptrs, mask=mask_rows[:, None] & mask_cols[None, :], other=0.0) + + o_ptrs = out_ptr + (rows_seq[:, None] * stride_os + cols_d[None, :] * stride_om) + tl.store(o_ptrs, x, mask=mask_rows[:, None] & mask_cols[None, :]) + + +def embedding(ids: torch.Tensor, weight: torch.Tensor, + vob_start_id: int = 0, vob_end_id: int = None, + out: torch.Tensor = None) -> torch.Tensor: + ids = ids.contiguous() + N = ids.numel() + D_MODEL = weight.shape[-1] + + if vob_end_id is None: + vob_end_id = weight.shape[0] + + if out is None: + out = torch.empty((N, D_MODEL), dtype=weight.dtype, device=weight.device) + + BLOCK_N = 32 + BLOCK_DMODEL = triton.next_power_of_2(D_MODEL) + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_N']),) + + embedding_kernel[grid]( + out, ids, weight, + out.stride(0), out.stride(1), + weight.stride(0), weight.stride(1), + N, D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stdout new file mode 100644 index 0000000..b1cb2da --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_759146.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_759146.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py new file mode 100644 index 0000000..3a694b1 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py @@ -0,0 +1,197 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(token_ids_ptr, out_ptr, weight_ptr, + stride_tokens_b, stride_tokens_s, + stride_out_b, stride_out_s, stride_out_d, + stride_weight_v, stride_weight_d, + seq_len, vocab_size, hidden_size, + BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr): + pid_b = tl.program_id(0) # batch + pid_n = tl.program_id(1) # sequence block + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # [BLOCK_N] + + token_ptr = token_ids_ptr + pid_b * stride_tokens_b + offs_n * stride_tokens_s + mask_n = offs_n < seq_len + tok_ids = tl.load(token_ptr, mask=mask_n, other=0) # [BLOCK_N] int32 + + for start_d in range(0, hidden_size, BLOCK_D): + offs_d = start_d + tl.arange(0, BLOCK_D) # [BLOCK_D] + + mask_d = offs_d < hidden_size + mask_w = mask_n[:, None] & mask_d[None, :] + + # weight: [v, h] => gather[token, :] => [BLOCK_N, BLOCK_D] + w_offs = tok_ids[:, None] * stride_weight_v + offs_d[None, :] * stride_weight_d + emb = tl.load(weight_ptr + w_offs, mask=mask_w, other=0.0) + + # out : [b, seq, h] + o_offs = pid_b * stride_out_b + offs_n[:, None] * stride_out_s + offs_d[None, :] * stride_out_d + tl.store(out_ptr + o_offs, emb, mask=mask_w) + + +def embedding(token_ids: torch.Tensor, weights: torch.Tensor, + out: torch.Tensor = None) -> torch.Tensor: + if token_ids.dim() == 1: + token_ids = token_ids.unsqueeze(0) + elif token_ids.dim() != 2: + raise ValueError("token_ids must be 1-D or 2-D") + + vocab_size, hidden_size = weights.shape + batch, seq_len = token_ids.shape + + if weights.dtype not in (torch.float16, torch.float32, torch.bfloat16): + raise TypeError("weights dtype must be fp16/fp32/bf16") + if token_ids.dtype not in (torch.int32, torch.int64): + raise TypeError("token_ids dtype must be int32/int64") + + if out is None: + out = torch.empty((batch, seq_len, hidden_size), + dtype=weights.dtype, device=weights.device) + + token_ids = token_ids.contiguous() + weights = weights.contiguous() + out = out.contiguous() + + BLOCK_N = 64 + BLOCK_D = min(128, triton.next_power_of_2(hidden_size)) + + grid = (batch, triton.cdiv(seq_len, BLOCK_N)) + + embedding_kernel[grid]( + token_ids, out, weights, + token_ids.stride(0), token_ids.stride(1), + out.stride(0), out.stride(1), out.stride(2), + weights.stride(0), weights.stride(1), + seq_len, vocab_size, hidden_size, + BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stdout new file mode 100644 index 0000000..41a52b4 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_764635.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_764635 due to embedding() takes from 2 to 3 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py new file mode 100644 index 0000000..14e998d --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py @@ -0,0 +1,188 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids, # [B, L] + weight, # [V, D] + out, # [B, L, D] + stride_b, stride_l, stride_v, stride_d, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, BLOCK_DMODEL: tl.constexpr +): + pid_batch = tl.program_id(0) + pid_dim = tl.program_id(1) + + # Offset for D dimension in weight and out + cols_d = pid_dim * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_d = cols_d < stride_d + + # Iterate over sequence positions in strides of BLOCK_NN + for start_l in tl.range(0, stride_l, BLOCK_NN): + # Global sequence index + cols_l = start_l + tl.arange(0, BLOCK_NN)[:, None] # [BLOCK_NN, 1] + + # Build batch * sequence pointers to token_ids + off_ids = pid_batch * stride_b + cols_l # [BLOCK_NN, 1] + mask_l = cols_l < stride_l + ids = tl.load(token_ids + off_ids, mask=mask_l) # [BLOCK_NN, 1] + + # Each id identifies a row in weight + # Compute pointer into weight: rows=[ids], cols=[cols_d] + off_weight = ids * stride_v + cols_d # [BLOCK_NN, BLOCK_DMODEL] + vals = tl.load(weight + off_weight, mask=mask_l & mask_d) + + # Store into out: [B, L, D] + off_out = ( + pid_batch * stride_b # batch stride + + cols_l * stride_l # seq stride + + cols_d * 1 # dim stride + ) + tl.store(out + off_out, vals, mask=mask_l & mask_d) + + +def embedding(token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + assert token_ids.is_cuda or token_ids.device.type == "cuda", "inputs should live on GPU" + assert weight.dim() == 2 + V, D = weight.shape + B, L = token_ids.shape + out = torch.empty((B, L, D), dtype=weight.dtype, device=weight.device) + + BLOCK_DMODEL = triton.next_power_of_2(D) + BLOCK_N = 32 + BLOCK_NN = triton.next_power_of_2(256) # load up to 256 sequence positions per program + + grid = lambda META: (B, triton.cdiv(D, META["BLOCK_DMODEL"])) + + embedding_kernel[grid]( + token_ids, + weight, + out, + token_ids.stride(0), token_ids.stride(1), weight.stride(0), weight.stride(1), + BLOCK_N=BLOCK_N, BLOCK_NN=BLOCK_NN, BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stdout new file mode 100644 index 0000000..6de004a --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_76684.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_76684 due to embedding() takes 2 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py new file mode 100644 index 0000000..65ca966 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def embedding_kernel(indexes, weight, out, + stride_idx_b, stride_idx_s, + stride_wt_v, stride_wd, + stride_ot_b, stride_ot_s, stride_ot_d, + VOCAB_SIZE: tl.constexpr, D_MODEL: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_NN: tl.constexpr, vob_start_id: tl.constexpr, vob_end_id: tl.constexpr): + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + pid_d = tl.program_id(2) + + seq_start = pid_s * BLOCK_N + d_start = pid_d * BLOCK_NN + + offs_s = seq_start + tl.arange(0, BLOCK_N) + offs_d = d_start + tl.arange(0, BLOCK_NN) + + mask_seq = offs_s < (seq_start + BLOCK_N) + mask_d = offs_d < D_MODEL + + idx_ptr = indexes + pid_b * stride_idx_b + offs_s * stride_idx_s + token_ids = tl.load(idx_ptr, mask=mask_seq, other=0) + + clamp_low = tl.full_like(token_ids, vob_start_id) + clamp_high = tl.full_like(token_ids, vob_end_id - 1) + token_ids = tl.where(token_ids < vob_start_id, clamp_low, token_ids) + token_ids = tl.where(token_ids > (vob_end_id - 1), clamp_high, token_ids) + token_ids = token_ids - vob_start_id + + w_offs = (token_ids[:, None] * stride_wt_v) + (offs_d[None, :] * stride_wd) + emb_vec = tl.load(weight + w_offs, mask=mask_seq[:, None] & mask_d[None, :], other=0.0) + + o_offs = (pid_b * stride_ot_b) + (offs_s * stride_ot_s)[:, None] + (offs_d * stride_ot_d)[None, :] + tl.store(out + o_offs, emb_vec, mask=mask_seq[:, None] & mask_d[None, :]) + + +def embedding(indexes: torch.Tensor, weight: torch.Tensor, vob_start_id: int, vob_end_id: int, out: Optional[torch.Tensor] = None) -> torch.Tensor: + B, S = indexes.shape + VOCAB_SIZE, D_MODEL = weight.shape + + out = torch.empty((B, S, D_MODEL), dtype=weight.dtype, device=weight.device) if out is None else out + + BLOCK_N = min(64, triton.next_power_of_2(S)) + BLOCK_NN = min(64, triton.next_power_of_2(D_MODEL)) + + grid = (B, triton.cdiv(S, BLOCK_N), triton.cdiv(D_MODEL, BLOCK_NN)) + + embedding_kernel[grid]( + indexes, weight, out, + indexes.stride(0), indexes.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + VOCAB_SIZE=VOCAB_SIZE, + D_MODEL=D_MODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + vob_start_id=vob_start_id, + vob_end_id=vob_end_id + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stdout new file mode 100644 index 0000000..b18d5c2 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_804525.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_804525 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py new file mode 100644 index 0000000..709d7d8 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py @@ -0,0 +1,195 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(input_ids_ptr, + weight_ptr, + out_ptr, + vob_start_id: tl.constexpr, + vob_end_id: tl.constexpr, + stride_ids_0, stride_ids_1, + stride_w_v, stride_w_d, + stride_out_0, stride_out_1, stride_out_2, + seq_len, vocab_size, dim, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid_d = tl.program_id(0) + pid_b = tl.program_id(1) + + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + row_start = pid_b * seq_len + for j in range(0, seq_len, BLOCK_N): + offs_n = j + tl.arange(0, BLOCK_N) + + mask_n = offs_n < seq_len + flat_idx = row_start + offs_n + pid = tl.load(input_ids_ptr + flat_idx, mask=mask_n, other=0) + + valid_id = (pid >= vob_start_id) & (pid < vob_end_id) & mask_n + real_id = pid - vob_start_id + real_id = tl.where(valid_id, real_id, 0) + + w_off = real_id[:, None] * stride_w_v + offs_d[None, :] * stride_w_d + embed = tl.load(weight_ptr + w_off, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + o_off = pid_b * stride_out_0 + offs_n[:, None] * stride_out_1 + offs_d[None, :] * stride_out_2 + tl.store(out_ptr + o_off, + embed, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + +def embedding(input_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor) -> torch.Tensor: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + batch, seq_len = input_ids.shape + vocab_size, dim = weight.shape + assert out.shape == (batch, seq_len, dim) + assert weight.dtype == out.dtype + + BLOCK_DMODEL = triton.next_power_of_2(dim) + BLOCK_N = 128 + grid = (triton.cdiv(dim, BLOCK_DMODEL), batch) + + embedding_kernel[grid]( + input_ids, weight, out, + vob_start_id, vob_end_id, + input_ids.stride(0), input_ids.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + seq_len, vocab_size, dim, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stdout new file mode 100644 index 0000000..e072cf0 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_823958.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_823958 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py new file mode 100644 index 0000000..1a0bb5d --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py @@ -0,0 +1,195 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel(input_ids_ptr, + weight_ptr, + out_ptr, + vob_start_id: tl.constexpr, + vob_end_id: tl.constexpr, + stride_ids_0, stride_ids_1, + stride_w_v, stride_w_d, + stride_out_0, stride_out_1, stride_out_2, + seq_len, vocab_size, dim, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr): + pid_d = tl.program_id(0) + pid_b = tl.program_id(1) + + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + row_start = pid_b * seq_len + for j in range(0, seq_len, BLOCK_N): + offs_n = j + tl.arange(0, BLOCK_N) + + mask_n = offs_n < seq_len + flat_idx = row_start + offs_n + pid = tl.load(input_ids_ptr + flat_idx, mask=mask_n, other=0) + + valid_id = (pid >= vob_start_id) & (pid < vob_end_id) & mask_n + real_id = pid - vob_start_id + real_id = tl.where(valid_id, real_id, 0) + + w_off = real_id[:, None] * stride_w_v + offs_d[None, :] * stride_w_d + embed = tl.load(weight_ptr + w_off, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + o_off = pid_b * stride_out_0 + offs_n[:, None] * stride_out_1 + offs_d[None, :] * stride_out_2 + tl.store(out_ptr + o_off, + embed, + mask=valid_id[:, None] & (offs_d[None, :] < dim)) + + +def embedding(input_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor) -> torch.Tensor: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + batch, seq_len = input_ids.shape + vocab_size, dim = weight.shape + assert out.shape == (batch, seq_len, dim) or (input_ids.ndim == 1 and out.shape == (1, seq_len, dim)) + assert weight.dtype == out.dtype + + BLOCK_DMODEL = triton.next_power_of_2(dim) + BLOCK_N = 128 + grid = (triton.cdiv(dim, BLOCK_DMODEL), batch) + + embedding_kernel[grid]( + input_ids, weight, out, + vob_start_id, vob_end_id, + input_ids.stride(0), input_ids.stride(1), + weight.stride(0), weight.stride(1), + out.stride(0), out.stride(1), out.stride(2), + seq_len, vocab_size, dim, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stdout new file mode 100644 index 0000000..d1da6b3 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_830218.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_830218 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py new file mode 100644 index 0000000..bc4055c --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Triton kernel +# --------------------------------------------------------------------------- +@triton.jit +def embedding_kernel( + ids, # pointer to 1-D flattened token-ids tensor shape [B*S] (int64) + weight, # pointer to embedding weight matrix shape [V, D] (float16/32) + out, # pointer to output tensor shape [B, S, D] + B, # int: batch size + S, # int: sequence length + V, # int: vocab size + D, # int: embedding dim + stride_w_0, # weight stride 0 = D + stride_out_0, # stride for out dim 0 = S*D + stride_out_1, # stride for out dim 1 = D + stride_out_2, # stride for out dim 2 = 1 + vob_start_id, # unused + vob_end_id, # unused + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) # 1-D grid: one program per batch element + b_offset = pid * S # base offset for this sequence in flat ids + # Each program handles its sequence: S tokens + for seq_off in range(0, S, BLOCK_N): + idx_n = seq_off + tl.arange(0, BLOCK_N) + mask_n = idx_n < S + # indices into 1-D ids tensor + ids_idx = b_offset + idx_n + token_ids = tl.load(ids + ids_idx, mask=mask_n, other=0) + + # Clamp token ids into [0, V-1] + token_ids = tl.maximum(0, token_ids) + token_ids = tl.minimum(V-1, token_ids) + + # Group BLOCK_N tokens into BLOCK_NN chunks + for grp_off in range(0, BLOCK_N, BLOCK_NN): + gn = grp_off + tl.arange(0, BLOCK_NN) + mask_gn = (gn < BLOCK_N) & mask_n + tid = token_ids[grp_off: grp_off + BLOCK_NN] + + out_base = pid * stride_out_0 + (seq_off + grp_off) * stride_out_1 + # Iterate over D in blocks + for d_off in range(0, D, BLOCK_DMODEL): + offs_d = d_off + tl.arange(0, BLOCK_DMODEL) + mask_d = offs_d < D + mask = mask_gn[:, None] & mask_d[None, :] + + # Weight load: weight[tid, offs_d] + w_ptr = weight + tid[:, None] * stride_w_0 + offs_d[None, :] + emb_vec = tl.load(w_ptr, mask=mask, other=0.0) + + # Output store: out[pid, seq_off+grp_off:grp_off+BLOCK_NN, offs_d] + o_ptr = out + out_base + gn[:, None] * stride_out_1 + d_off + offs_d[None, :] + tl.store(o_ptr, emb_vec, mask=mask) + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- +def embedding( + ids: torch.Tensor, # [B, S] or [B*S] int64 + weight: torch.Tensor, # [V, D] + vob_start_id: int = 0, + vob_end_id: int = None, + out: torch.Tensor = None, +) -> torch.Tensor: # Returns [B, S, D] + + device = weight.device + dtype = weight.dtype + V, D = weight.shape + + if ids.dim() == 1: + # Handle flattened case: expect B*S tokens + flat_ids = ids + else: # dim == 2 + B, S = ids.shape + flat_ids = ids.view(-1) # [B*S] + + if vob_end_id is None: + vob_end_id = V + + if ids.dim() == 1: + # Infer B and S from flattened tensor + # Assume first argument already passed as full 1-D sequence (B*S) + # Use vob_start_id / vob_end_id if given, but not used in kernel + # Need outer computation to provide true B and S + raise ValueError( + "With 1-D ids you must compute/reshape outside to [B,S] or directly provide 2-D [B, S]" + ) + else: + B, S = ids.shape + flat_ids = ids.view(-1) + + if out is None: + out = torch.empty((B, S, D), dtype=dtype, device=device) + + BLOCK_N = 128 + BLOCK_NN = 64 + BLOCK_DMODEL = triton.next_power_of_2(D) + + grid = (B,) + + embedding_kernel[grid]( + flat_ids, + weight, + out, + B, + S, + V, + D, + weight.stride(0), + out.stride(0), + out.stride(1), + out.stride(2), + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stdout new file mode 100644 index 0000000..99a4915 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_837397.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_837397 due to With 1-D ids you must compute/reshape outside to [B,S] or directly provide 2-D [B, S] diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py new file mode 100644 index 0000000..362a88e --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py @@ -0,0 +1,220 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids, + weight, + out, + stride_b, + stride_l, + stride_v, + stride_d, + stride_ob, + stride_ol, + vob_start_id, + vob_end_id, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + pid_batch = tl.program_id(0) + pid_dim = tl.program_id(1) + + cols_d = pid_dim * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_d = cols_d < stride_d + + for start_l in tl.range(0, stride_l, BLOCK_NN): + cols_l = start_l + tl.arange(0, BLOCK_NN) + mask_l_outer = cols_l < stride_l + + flat_offset = pid_batch * stride_b + cols_l + ids = tl.load(token_ids + flat_offset, mask=mask_l_outer) + + valid_mask = (ids >= vob_start_id) & (ids < vob_end_id) + + safe_ids = tl.where(valid_mask, ids, vob_start_id) + + warp_offsets_l = cols_l[:, None] + warp_offsets_d = cols_d[None, :] + + emb_offsets = safe_ids[:, None] * stride_v + warp_offsets_d + vals = tl.load(weight + emb_offsets, mask=(mask_l_outer[:, None] & mask_d[None, :])) + + out_offsets = pid_batch * stride_ob + warp_offsets_l * stride_ol + warp_offsets_d + tl.store(out + out_offsets, vals, mask=(mask_l_outer[:, None] & mask_d[None, :])) + + +def embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor +) -> None: + assert token_ids.device == weight.device == out.device + assert token_ids.dtype == torch.int64 or token_ids.dtype == torch.int32 + assert weight.ndim == 2 + V, D = weight.shape + B = token_ids.numel() // token_ids.size(-1) if token_ids.ndim > 1 else 1 + L = token_ids.size(-1) + + if token_ids.ndim == 1: + assert out.numel() == L * D and out.size(-1) == D + else: + assert out.numel() == B * L * D and out.size(-1) == D + + BLOCK_DMODEL = triton.next_power_of_2(D) + BLOCK_N = 32 + BLOCK_NN = 256 + + if token_ids.ndim == 1: + grid = (1, triton.cdiv(D, BLOCK_DMODEL)) + token_ids = token_ids.contiguous() + out = out.view(L, D).contiguous() + else: + grid = (B, triton.cdiv(D, BLOCK_DMODEL)) + token_ids = token_ids.view(B, L).contiguous() + out = out.view(B, L, D).contiguous() + + embedding_kernel[grid]( + token_ids, + weight, + out, + token_ids.stride(0) if token_ids.ndim > 1 else 0, + token_ids.stride(-1), + weight.stride(0), + weight.stride(1), + out.stride(0) if out.ndim > 2 else 0, + out.stride(-2) if out.ndim > 2 else out.stride(0), + vob_start_id, + vob_end_id, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stdout new file mode 100644 index 0000000..178b749 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_92676.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: embedding_triton_kernel.py_gen_triton_code_92676.py diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py new file mode 100644 index 0000000..6275e29 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(weight_ptr, ids_ptr, out_ptr, + stride_w, # stride of weight in dim-1 + stride_out, # stride of out in dim-0 + num_tokens, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr + ): + pid = tl.program_id(0) # token block + pid_d = tl.program_id(1) # d-model block + start_token = pid * BLOCK_NN + ids = tl.load(ids_ptr + start_token + tl.arange(0, BLOCK_NN)) # [BLOCK_NN] + + offsets_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + mask_tokens = start_token + tl.arange(0, BLOCK_NN) < num_tokens + + for i in range(BLOCK_NN): + cond = mask_tokens[i] + if not cond: + break + token_id = ids[i] + weight_offsets = token_id * stride_w + offsets_d + weight_vec = tl.load(weight_ptr + weight_offsets, mask=offsets_d < stride_w) + output_offsets = (start_token + i) * stride_out + offsets_d + tl.store(out_ptr + output_offsets, weight_vec, mask=offsets_d < stride_out) + +def embedding(weight: torch.Tensor, ids: torch.Tensor, out: torch.Tensor, + BLOCK_N: int = 1, BLOCK_NN: int = 32, BLOCK_DMODEL: int = None): + """ + Wrapper: weight shape [vocab, d_model] + ids shape [num_tokens] + out shape [num_tokens, d_model] + """ + assert weight.ndim == 2 + assert ids.ndim == 1 + assert out.ndim == 2 + assert out.shape == (ids.shape[0], weight.shape[1]) + vocab, d_model = weight.shape + num_tokens = ids.numel() + + if BLOCK_DMODEL is None: + BLOCK_DMODEL = triton.next_power_of_2(d_model) + + grid = lambda META: (triton.cdiv(num_tokens, meta["BLOCK_NN"]), + triton.cdiv(d_model, meta["BLOCK_DMODEL"])) + meta={ + "BLOCK_N": BLOCK_N, + "BLOCK_NN": BLOCK_NN, + "BLOCK_DMODEL": BLOCK_DMODEL + } + embedding_kernel[triton.cdiv(num_tokens, BLOCK_NN), + triton.cdiv(d_model, BLOCK_DMODEL)]( + weight, ids, out, + stride_w=weight.stride(0), + stride_out=out.stride(0), + num_tokens=num_tokens, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stdout new file mode 100644 index 0000000..4609420 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_940390.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_940390 due to diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py new file mode 100644 index 0000000..1efba06 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py @@ -0,0 +1,196 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel(input_ids_ptr, weight_ptr, out_ptr, + vob_start_id: int, vob_end_id: int, + stride_ids_0, stride_ids_1, + stride_w_v, stride_w_d, + stride_out_0, stride_out_1, stride_out_2, + seq_len: int, vocab_size: int, dim: int, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_NN: tl.constexpr): + pid_d = tl.program_id(0) + pid_b = tl.program_id(1) + + offs_d = pid_d * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + block_offset_id = pid_b * seq_len + block_offset_out = pid_b * seq_len * stride_out_1 + + for j in range(0, seq_len, BLOCK_NN): + offs_n = j + tl.arange(0, BLOCK_N) + mask_n = offs_n < seq_len + + token_ids = tl.load(input_ids_ptr + block_offset_id + offs_n, mask=mask_n) # BLOCK_N + valid_mask = (token_ids >= vob_start_id) & (token_ids < vob_end_id) & mask_n + token_ids_offs = token_ids - vob_start_id + + token_ids_offs = tl.where(valid_mask, token_ids_offs, 0) + + w_ptrs = weight_ptr + token_ids_offs[:, None] * stride_w_v + offs_d[None, :] * stride_w_d + emb_vec = tl.load(w_ptrs, mask=valid_mask[:, None] & (offs_d[None, :] < dim)) + + out_ptrs = out_ptr + block_offset_out + offs_n[:, None] * stride_out_1 + offs_d[None, :] * stride_out_2 + tl.store(out_ptrs, emb_vec, mask=valid_mask[:, None] & (offs_d[None, :] < dim)) + + +def embedding(input_ids: torch.Tensor, + weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: torch.Tensor + ) -> torch.Tensor: + batch, seq_len = input_ids.shape + vocab_size, dim = weight.shape + assert out.shape == (batch, seq_len, dim) + assert weight.dtype == out.dtype + + stride_ids_0, stride_ids_1 = input_ids.stride() + stride_w_v, stride_w_d = weight.stride() + stride_out_0, stride_out_1, stride_out_2 = out.stride() + + BLOCK_DMODEL = triton.next_power_of_2(dim) + + BLOCK_N = 64 + BLOCK_NN = 128 + grid = lambda META: (triton.cdiv(dim, META['BLOCK_DMODEL']), batch) + + embedding_kernel[grid]( + input_ids, weight, out, + int(vob_start_id), int(vob_end_id), + stride_ids_0, stride_ids_1, + stride_w_v, stride_w_d, + stride_out_0, stride_out_1, stride_out_2, + int(seq_len), int(vocab_size), int(dim), + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_NN=BLOCK_NN + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stdout new file mode 100644 index 0000000..e9e9575 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_965031.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_965031 due to not enough values to unpack (expected 2, got 1) diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py new file mode 100644 index 0000000..bea0a65 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py @@ -0,0 +1,174 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def embedding_kernel( + out, ids, weight, stride_os, stride_om, + stride_ws, stride_wm, + N, BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, BLOCK_DMODEL: tl.constexpr, +): + pid = tl.program_id(0) + start_id = pid * BLOCK_NN + cols_d = tl.arange(0, BLOCK_DMODEL) + + for i in range(0, BLOCK_NN, BLOCK_N): + rows_seq = start_id + i + tl.arange(0, BLOCK_N) + mask = rows_seq < N + ids_off = rows_seq + token_ids = tl.load(ids + ids_off, mask=mask, other=0) + + w_off = (token_ids[:, None] * stride_ws + cols_d[None, :] * stride_wm) + x = tl.load(weight + w_off, mask=mask[:, None], other=0.0) + + o_off = rows_seq[:, None] * stride_os + cols_d[None, :] * stride_om + tl.store(out + o_off, x, mask=mask[:, None]) + + +def embedding(ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + N = ids.numel() + DMODEL = weight.size(-1) + out = torch.empty((N, DMODEL), dtype=weight.dtype, device=weight.device) + + BLOCK_N = 16 + BLOCK_NN = 32 + BLOCK_DMODEL = triton.next_power_of_2(DMODEL) + + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_NN']),) + + embedding_kernel[grid]( + out, ids, weight, + out.stride(0), out.stride(1), + weight.stride(0), weight.stride(1), + N, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stdout new file mode 100644 index 0000000..715b812 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_984659.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_984659 due to embedding() takes 2 positional arguments but 5 were given diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py new file mode 100644 index 0000000..111765e --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py @@ -0,0 +1,227 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def embedding_kernel( + token_ids_ptr, # int32* + weight_ptr, # fp16/fp32* + out_ptr, # fp16/fp32* + seq_len, # int32 + num_tokens, # int32 + d_model, # int32 + stride_id, # int32 + stride_w0, # int32 + stride_w1, # int32 + stride_out0, # int32 + stride_out1, # int32 + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + pid_seq = tl.program_id(0) # sequence index + pid_col = tl.program_id(1) # d_model block index + + if pid_seq >= seq_len: + return + + # token dimension indices in this block + cols = pid_col * BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL) + + # offset into each token’s embedding slice + d_mask = cols < d_model + out_offset = pid_seq * stride_out0 + cols * stride_out1 + weight_offset_col = cols * stride_w1 + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, num_tokens, BLOCK_N): + # iterate over tokens in blocks of BLOCK_N + block_start = start_n + block_end = start_n + BLOCK_N + n_block = tl.arange(block_start, block_end) + n_mask = n_block < num_tokens + + # flat token index = seq * max_tokens_per_seq + token_in_seq + flat_idx = pid_seq * stride_id + n_block + token_ids = tl.load(token_ids_ptr + flat_idx, mask=n_mask, other=0) + + # gather weight rows: token_ids[BLOCK_N] × lookup[BLOCK_DMODEL] + for inner in range(0, BLOCK_N, BLOCK_NN): + inner_start = inner + inner_end = inner + BLOCK_NN + inner_range = inner_start + tl.arange(0, BLOCK_NN) + mask_inner = (n_block < num_tokens) & (inner_range < BLOCK_N) + inner_seq_ids = token_ids[inner_range - inner_start] if BLOCK_N > 1 else token_ids + + # load weight rows = inner_seq_ids + w_offs = inner_seq_ids * stride_w0 + weight_offset_col + w_vals = tl.load(weight_ptr + w_offs, mask=d_mask & mask_inner, other=0.0) + acc += w_vals + + # store gathered embedding for this sequence + acc = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + out_offset, acc, mask=d_mask) + + +def embedding( + token_ids: torch.Tensor, # int32, shape (seq_len, num_tokens) + weight: torch.Tensor, # fp16/fp32, shape (vocab_size, d_model) + out: torch.Tensor = None, +) -> torch.Tensor: + seq_len, num_tokens = token_ids.shape + _, d_model = weight.shape + assert token_ids.dtype == torch.int32 + assert weight.dtype in [torch.float16, torch.float32] + assert weight.is_contiguous() + + if out is None: + out = torch.empty((seq_len, d_model), dtype=weight.dtype, device=weight.device) + + BLOCK_DMODEL = triton.next_power_of_2(d_model) + BLOCK_N = 16 + BLOCK_NN = 8 + + grid = (seq_len, triton.cdiv(d_model, BLOCK_DMODEL)) + + embedding_kernel[grid]( + token_ids, + weight, + out, + seq_len, + num_tokens, + d_model, + token_ids.stride(0), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_embedding(): + + # 参数定义 + + vocab_size = 1000 # 词汇表大小 + + embedding_dim = 512 # 嵌入维度 + + sequence_length = 128 # 输入序列长度 + + vob_start_id = 10 # 词汇表起始 ID + + vob_end_id = 1000 # 词汇表结束 ID + + + + # 创建测试输入张量 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + + + # 调用嵌入函数 + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + + + # 保存结果 + + results = {} + + results['test_case_1'] = out.clone() + + + + # 测试不同的输入 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_2'] = out.clone() + + + + # 测试不同的词汇表范围 + + vob_start_id = 0 + + vob_end_id = 500 + + input_ids = torch.randint( + + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_3'] = out.clone() + + + + # 测试不同的嵌入维度 + + embedding_dim = 256 + + weight = torch.randn( + + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + out = torch.zeros( + + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + + ) + + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + results['test_case_4'] = out.clone() + + + + return results + + + +result_gold = test_embedding() diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stderr b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stdout b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stdout new file mode 100644 index 0000000..68f2751 --- /dev/null +++ b/src/temp/gen/embedding_triton_kernel.py_gen_triton_code_992208.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module embedding_triton_kernel.py_gen_triton_code_992208 due to embedding() takes from 2 to 3 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py new file mode 100644 index 0000000..6685521 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py @@ -0,0 +1,201 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch).to(tl.int32) + + block_n_size = tl.maximum(0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + ptr_start = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + ) + ptr_start_lse = ( + cur_batch * stride_mid_o_eb + + cur_head * stride_mid_o_eh + ) + + for block_seq_n in range(block_n_size): + tv = tl.load( + Mid_O + ptr_start + block_seq_n * stride_mid_os + offs_d * stride_mid_od, + mask=offs_d < BLOCK_DMODEL, + other=0.0, + ) + tlogic = tl.load(Mid_O_LogExpSum + ptr_start_lse + block_seq_n * stride_mid_o_es) + + new_max_logic = tl.maximum(tlogic, max_logic) + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv.to(tl.float32) + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + out_ptr = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d + tl.store(out_ptr, (acc / sum_exp).to(Out.dtype.element_ty)) + +def flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, BLOCK_SEQ): + assert Mid_O.shape[-1] in {16, 32, 64, 128, 256, 512} + batch, head_num, D = Mid_O.shape[0], Mid_O.shape[1], Mid_O.shape[-1] + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=D, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stdout new file mode 100644 index 0000000..ef48cdd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_126106.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_126106 due to diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py new file mode 100644 index 0000000..4dbfa96 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py @@ -0,0 +1,211 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = tl.maximum(0, (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + + offs_out = cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(Out + offs_out, acc.to(Out.type.element_ty)) + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + batch, head_num = Mid_O.shape[0], Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + BLOCK_SEQ = Mid_O.shape[2] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stdout new file mode 100644 index 0000000..1c295ed --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_14965.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_14965 due to flash_decode_stage2() takes 4 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py new file mode 100644 index 0000000..9756e66 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py @@ -0,0 +1,210 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp += exp_logic + + max_logic = new_max_logic + + offs_out = ( + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d + ) + final = acc / sum_exp + tl.store(Out + offs_out, final.to(Out.type.element_ty)) + +def flash_decode_stage2( + b_seq_len: torch.Tensor, + mid_out: torch.Tensor, + mid_out_logexpsum: torch.Tensor, + output: torch.Tensor, + BLOCK_SEQ: int, +) -> None: + batch, head_num = b_seq_len.size(0), mid_out.size(1) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + b_seq_len, + mid_out, + mid_out_logexpsum, + output, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=output.size(2), + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stdout new file mode 100644 index 0000000..764d81f --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_198114.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_198114 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py new file mode 100644 index 0000000..0dda36d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py @@ -0,0 +1,203 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + + seq_len = tl.load(B_Seqlen + cur_batch) + n_blocks = tl.maximum((seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ, 1) + + acc = tl.zeros([BLOCK_DMODEL], dtype=Mid_O.dtype.element_ty) + max_logic = -float('inf') + sum_exp = 0.0 + + for i in range(n_blocks): + v = tl.load( + Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + i * stride_mid_os + + offs_d * stride_mid_od + ) + logsum = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + i * stride_mid_les_s + ) + new_max = tl.maximum(logsum, max_logic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + v * tl.exp(logsum - new_max) + sum_exp = sum_exp * scale + tl.exp(logsum - new_max) + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc + ) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, +): + B, H = Out.shape[0], Out.shape[1] + grid = (B, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3) if Mid_O.ndim == 4 else 1, + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stdout new file mode 100644 index 0000000..013ca5b --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_23614.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_23614 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py new file mode 100644 index 0000000..cd3682f --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py @@ -0,0 +1,212 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + Out, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_seqlen = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DMODEL) + block_n_size = (cur_seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid).to(tl.float32) + + offs_lse = ( + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_lse) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp += exp_logic + max_logic = new_max_logic + + offs_out = ( + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d + ) + tl.store(Out + offs_out, acc / sum_exp) + return + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +) -> None: + batch, head_num = Mid_O.size(0), Mid_O.size(1) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Out.size(2), + num_warps=4, + num_stages=2, + ) + return + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stdout new file mode 100644 index 0000000..5f2cf36 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_269764.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_269764 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py new file mode 100644 index 0000000..ba9ae8a --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py @@ -0,0 +1,198 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = tl.cdiv(cur_seq_len, BLOCK_SEQ) + + offsets_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float('-inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_id * stride_mid_os + offsets_d * stride_mid_od + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_lse = cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_s + tlogic = tl.load(Mid_O_LogExpSum + offs_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + + exp_logic = tl.exp(tlogic - new_max) + acc += tv * exp_logic + sum_exp += exp_logic + max_logic = new_max + + offs_out = cur_batch * stride_out_b + cur_head * stride_out_h + offsets_d * stride_out_d + tl.store(Out + offs_out, (acc / sum_exp).to(Out.dtype.element_ty)) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int +): + batch, num_heads = Out.shape[0], Out.shape[1] + BLOCK_DMODEL = Out.shape[-1] + grid = (batch, num_heads) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_335674.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py new file mode 100644 index 0000000..36aec71 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py @@ -0,0 +1,225 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_b_seqlen, + stride_mid_os, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_ls, + stride_mid_lh, + stride_mid_lb, + stride_out_s, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_b_seqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_block = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + cur_batch * stride_mid_os + + cur_head * stride_mid_oh + + block_id * stride_mid_ob + + offs_block * stride_mid_od, + mask=offs_block < BLOCK_DMODEL, + other=0.0 + ).to(tl.float32) + + tlogic = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_ls + + cur_head * stride_mid_lh + + block_id * stride_mid_lb + ).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + sum_exp *= scale + acc *= scale + + tlogic_exp = tl.exp(tlogic - new_max) + acc += tv * tlogic_exp + sum_exp += tlogic_exp + + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_s + + cur_head * stride_out_h + + offs_block * stride_out_d, + acc.to(Out.type.element_ty), + mask=offs_block < BLOCK_DMODEL + ) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + + assert BLOCK_SEQ > 0 + assert Out.shape == (batch, head_num, BLOCK_DMODEL) + assert Mid_O.shape[:-1] == (batch, head_num, (B_Seqlen.max() + BLOCK_SEQ - 1) // BLOCK_SEQ) + assert Mid_O_LogExpSum.shape == Mid_O.shape[:-1] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0) if B_Seqlen.dim() == 1 else 0, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stdout new file mode 100644 index 0000000..8c8674a --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_349606.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_349606 due to diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py new file mode 100644 index 0000000..19e4413 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py @@ -0,0 +1,208 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [head_num, seq_blocks, BLOCK_DMODEL] + Mid_O_LogExpSum, # [head_num, seq_blocks] + Out, # [head_num, BLOCK_DMODEL] + stride_mid_oh, # stride(head_num) + stride_mid_ob, # stride(seq_blocks) + stride_mid_od, # stride(BLOCK_DMODEL) + stride_mid_o_lseh, # stride(head) + stride_mid_o_lseb, # stride(seq_blocks) + stride_oh, # stride(head_num) + stride_od, # stride(BLOCK_DMODEL) + B_START_ID, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + ptr_tv = ( + Mid_O + + (cur_head * stride_mid_oh) + + (block_id * stride_mid_ob) + + offs_d * stride_mid_od + ) + + tv = tl.load(ptr_tv) + ptr_tlogic = ( + Mid_O_LogExpSum + + cur_head * stride_mid_o_lseh + + block_id * stride_mid_o_lseb + ) + tlogic = tl.load(ptr_tlogic) + + max_prev = max_logic + max_logic = tl.maximum(max_prev, tlogic) + sum_exp = sum_exp * tl.exp(max_prev - max_logic) + tl.exp(tlogic - max_logic) + acc = acc * tl.exp(max_prev - max_logic) + tv * tl.exp(tlogic - max_logic) + + result = acc / (sum_exp + 1e-6) + + ptr_out = Out + cur_head * stride_oh + offs_d * stride_od + tl.store(ptr_out, result.cast(ptr_out.dtype.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + B_START_ID: int = 0, + BLOCK_SEQ: int = 64, +): + head_num, seq_blocks, BLOCK_DMODEL = Mid_O.shape + batch_size = B_Seqlen.shape[0] + + triton_grid = (batch_size, head_num) + + _fwd_kernel_flash_decode_stage2[ + triton_grid + ]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Out.stride(0), + Out.stride(1), + B_START_ID, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stdout new file mode 100644 index 0000000..4f9e1ba --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_369704.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_369704 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py new file mode 100644 index 0000000..0dda36d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py @@ -0,0 +1,203 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + + seq_len = tl.load(B_Seqlen + cur_batch) + n_blocks = tl.maximum((seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ, 1) + + acc = tl.zeros([BLOCK_DMODEL], dtype=Mid_O.dtype.element_ty) + max_logic = -float('inf') + sum_exp = 0.0 + + for i in range(n_blocks): + v = tl.load( + Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + i * stride_mid_os + + offs_d * stride_mid_od + ) + logsum = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + i * stride_mid_les_s + ) + new_max = tl.maximum(logsum, max_logic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + v * tl.exp(logsum - new_max) + sum_exp = sum_exp * scale + tl.exp(logsum - new_max) + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc + ) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, +): + B, H = Out.shape[0], Out.shape[1] + grid = (B, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3) if Mid_O.ndim == 4 else 1, + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stdout new file mode 100644 index 0000000..b7c1518 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_38100.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_38100 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py new file mode 100644 index 0000000..3316994 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + seqlen = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + + offs_out = cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(Out + offs_out, acc.to(Out.type.element_ty)) + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + batch = Mid_O.size(0) + head_num = Mid_O.size(1) + BLOCK_DMODEL = Mid_O.size(-1) + BLOCK_SEQ = Mid_O.size(2) if len(Mid_O.shape) == 4 else 1 + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stdout new file mode 100644 index 0000000..c016b85 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_405645.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_405645 due to flash_decode_stage2() takes 4 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py new file mode 100644 index 0000000..d17ca28 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py @@ -0,0 +1,200 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_olesb, + stride_mid_olesh, + stride_mid_oles, + stride_oub, + stride_ouh, + stride_oud, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_head = tl.program_id(1) + cur_batch = tl.program_id(0) + + offs_d = tl.arange(0, BLOCK_DMODEL) + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + for block_n in range(block_n_size): + tv = tl.load(Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + block_n * stride_mid_os + offs_d * stride_mid_od) + tlogic = tl.load(Mid_O_LogExpSum + cur_batch * stride_mid_olesb + + cur_head * stride_mid_olesh + block_n * stride_mid_oles) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + new_scale = tl.exp(tlogic - new_max_logic) + + acc = acc * old_scale + tv * new_scale + sum_exp = sum_exp * old_scale + new_scale + max_logic = new_max_logic + + acc = acc / sum_exp + tl.store(Out + cur_batch * stride_oub + cur_head * stride_ouh + offs_d * stride_oud, acc) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + assert Mid_O_LogExpSum.shape[1] == head_num + BLOCK_SEQ = Mid_O.shape[2] + BLOCK_DMODEL = Mid_O.shape[3] + + _fwd_kernel_flash_decode_stage2[(batch, head_num)]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stdout new file mode 100644 index 0000000..f440707 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_42419.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_42419 due to flash_decode_stage2() takes 4 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py new file mode 100644 index 0000000..162a08f --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py @@ -0,0 +1,224 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_b_seqlen, + stride_mid_os, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_ls, + stride_mid_lh, + stride_mid_lb, + stride_out_s, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_b_seqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_block = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + cur_batch * stride_mid_os + + cur_head * stride_mid_oh + + block_id * stride_mid_ob + + offs_block * stride_mid_od + ).to(tl.float32) + + tlogic = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_ls + + cur_head * stride_mid_lh + + block_id * stride_mid_lb + ).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + sum_exp *= scale + acc *= scale + + tlogic_exp = tl.exp(tlogic - new_max) + acc += tv * tlogic_exp + sum_exp += tlogic_exp + + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_s + + cur_head * stride_out_h + + offs_block * stride_out_d, + acc.to(Out.type.element_ty) + ) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + + assert BLOCK_SEQ > 0 + assert Out.shape == (batch, head_num, BLOCK_DMODEL) + assert Mid_O.shape[:-1] == (batch, head_num, (B_Seqlen.max().item() + BLOCK_SEQ - 1) // BLOCK_SEQ) + assert Mid_O_LogExpSum.shape == Mid_O.shape[:-1] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0) if B_Seqlen.dim() >= 1 else 0, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stdout new file mode 100644 index 0000000..5ac48aa --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_450387.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_450387 due to diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py new file mode 100644 index 0000000..6e48b97 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py @@ -0,0 +1,212 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, D] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + Out, # [batch, head, D] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_lse_b, + stride_lse_h, + stride_lse_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch).to(tl.int32) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + sum_exp = 0.0 + max_logic = -float("inf") + + for block_id in range(0, block_n_size): + v_ptrs = ( + Mid_O + + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + log_ptr = ( + Mid_O_LogExpSum + + cur_batch * stride_lse_b + + cur_head * stride_lse_h + + block_id * stride_lse_s + ) + v = tl.load(v_ptrs) + lse = tl.load(log_ptr) + + new_max = tl.maximum(max_logic, lse) + scale = tl.exp(max_logic - new_max) + exp_lse = tl.exp(lse - new_max) + + acc = acc * scale + v.to(tl.float32) * exp_lse + sum_exp = sum_exp * scale + exp_lse + max_logic = new_max + + out_ptrs = ( + Out + + cur_batch * stride_ob + + cur_head * stride_oh + + offs_d * stride_od + ) + final_val = acc / sum_exp + tl.store(out_ptrs, final_val.to(Out.dtype.element_ty)) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +): + batch, head_num, D = Mid_O.shape[0], Mid_O.shape[1], Mid_O.shape[-1] + assert D in {16, 32, 64, 128, 256, 512} + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2) if Mid_O.ndim == 4 else 0, + Mid_O.stride(-1), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2) if Mid_O_LogExpSum.ndim == 3 else 0, + Out.stride(0), + Out.stride(1), + Out.stride(2) if Out.ndim == 3 else Out.stride(-1), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=D, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stdout new file mode 100644 index 0000000..46ceb43 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_506478.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_506478 due to diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py new file mode 100644 index 0000000..bee2bb6 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py @@ -0,0 +1,209 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + offs_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid).to(tl.float32) + + offs_lse = ( + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_lse) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp += exp_logic + max_logic = new_max_logic + + offs_out = ( + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d + ) + tl.store(Out + offs_out, acc / sum_exp) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +) -> None: + batch, head_num = Mid_O.size(0), Mid_O.size(1) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Out.size(2), + num_warps=4, + num_stages=2, + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stdout new file mode 100644 index 0000000..cccb5a0 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_543766.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_543766 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py new file mode 100644 index 0000000..acbae13 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py @@ -0,0 +1,212 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp += exp_logic + + max_logic = new_max_logic + + offs_out = ( + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d + ) + final = acc / sum_exp + tl.store(Out + offs_out, final.to(Out.type.element_ty)) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +) -> None: + batch, head_num = B_Seqlen.size(0), Mid_O.size(1) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + 1, + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + 1, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Out.size(2), + num_warps=4, + num_stages=2, + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stdout new file mode 100644 index 0000000..9555d6d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_560861.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_560861 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py new file mode 100644 index 0000000..b19a7fd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py @@ -0,0 +1,197 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = tl.cdiv(cur_seq_len, BLOCK_SEQ) + + offsets_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float('-inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + tv_ptrs = Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_id * stride_mid_os + offsets_d * stride_mid_od + tv = tl.load(tv_ptrs).to(tl.float32) + tlogic_ptr = Mid_O_LogExpSum + cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_s + tlogic = tl.load(tlogic_ptr).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + sum_exp = sum_exp * scale + + exp_logic = tl.exp(tlogic - new_max) + acc += tv * exp_logic + sum_exp += exp_logic + max_logic = new_max + + out_ptrs = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offsets_d * stride_out_d + tl.store(out_ptrs, (acc / sum_exp).to(Out.type.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, + BLOCK_DMODEL: int +): + batch, num_heads = Out.shape[0], Out.shape[1] + grid = (batch, num_heads) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stdout new file mode 100644 index 0000000..d102724 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_576804.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_576804 due to flash_decode_stage2() missing 1 required positional argument: 'BLOCK_DMODEL' diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py new file mode 100644 index 0000000..0dda36d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py @@ -0,0 +1,203 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + + seq_len = tl.load(B_Seqlen + cur_batch) + n_blocks = tl.maximum((seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ, 1) + + acc = tl.zeros([BLOCK_DMODEL], dtype=Mid_O.dtype.element_ty) + max_logic = -float('inf') + sum_exp = 0.0 + + for i in range(n_blocks): + v = tl.load( + Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + i * stride_mid_os + + offs_d * stride_mid_od + ) + logsum = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + i * stride_mid_les_s + ) + new_max = tl.maximum(logsum, max_logic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + v * tl.exp(logsum - new_max) + sum_exp = sum_exp * scale + tl.exp(logsum - new_max) + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc + ) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, +): + B, H = Out.shape[0], Out.shape[1] + grid = (B, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3) if Mid_O.ndim == 4 else 1, + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stdout new file mode 100644 index 0000000..74d5df9 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_653084.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_653084 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py new file mode 100644 index 0000000..9b3f4da --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py @@ -0,0 +1,223 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_b_seqlen, + stride_mid_os, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_ls, + stride_mid_lh, + stride_mid_lb, + stride_out_s, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_b_seqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_block = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + cur_batch * stride_mid_os + + cur_head * stride_mid_oh + + block_id * stride_mid_ob + + offs_block * stride_mid_od + ).to(tl.float32) + + tlogic = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_ls + + cur_head * stride_mid_lh + + block_id * stride_mid_lb + ).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + sum_exp *= scale + acc *= scale + + tlogic_exp = tl.exp(tlogic - new_max) + acc += tv * tlogic_exp + sum_exp += tlogic_exp + + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_s + + cur_head * stride_out_h + + offs_block * stride_out_d, + acc.to(Out.type.element_ty) + ) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + + assert BLOCK_SEQ > 0 + assert Out.shape == (batch, head_num, BLOCK_DMODEL) + assert Mid_O.shape[:-1][:3] == (batch, head_num, B_Seqlen.shape[0]) + assert Mid_O_LogExpSum.shape == (batch, head_num, B_Seqlen.shape[0]) + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0) if B_Seqlen.dim() >= 1 else 0, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stdout new file mode 100644 index 0000000..0d5c6f8 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_661704.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_661704 due to diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py new file mode 100644 index 0000000..fe19f0a --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py @@ -0,0 +1,213 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_obs, + stride_mid_oh, + stride_mid_oseq, + stride_mid_od, + stride_mid_les_bs, + stride_mid_les_h, + stride_mid_les_seq, + stride_out_bs, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + + cur_batch * stride_mid_obs + + cur_head * stride_mid_oh + + block_id * stride_mid_oseq + + offs_d * stride_mid_od + ) + tlogic = tl.load( + Mid_O_LogExpSum + + cur_batch * stride_mid_les_bs + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_seq + ) + + new_max_logic = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max_logic) + acc = acc * scale + sum_exp = sum_exp * scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += tv * exp_logic + sum_exp += exp_logic + + max_logic = new_max_logic + + acc = acc / sum_exp + tl.store( + Out + + cur_batch * stride_out_bs + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc, + ) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, +): + B, H = Out.shape[0], Out.shape[1] + grid = (B, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stdout new file mode 100644 index 0000000..a17c0dc --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_684759.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_684759 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py new file mode 100644 index 0000000..8e639d5 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + seqlen = tl.load(B_Seqlen + cur_batch).to(tl.int32) + block_n_size = tl.maximum(0, (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + + offs_out = cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(Out + offs_out, acc.to(Out.dtype.element_ty)) + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = None, +) -> torch.Tensor: + if BLOCK_SEQ is None: + BLOCK_SEQ = Mid_O.shape[2] + + batch, head_num = Mid_O.shape[0], Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_690508.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py new file mode 100644 index 0000000..e1f07d1 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch).to(tl.int32) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + sum_exp = 0.0 + max_logic = -float("inf") + + for block_id in range(0, block_n_size): + v = tl.load( + Mid_O + + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + lse = tl.load( + Mid_O_LogExpSum + + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + + new_max = tl.maximum(max_logic, lse) + old_scale = tl.exp(max_logic - new_max) + exp_lse = tl.exp(lse - new_max) + + acc = acc * old_scale + v.to(tl.float32) * exp_lse + sum_exp = sum_exp * old_scale + exp_lse + max_logic = new_max + + out_vals = acc / sum_exp + tl.store( + Out + + cur_batch * stride_ob + + cur_head * stride_oh + + offs_d * stride_od, + out_vals.to(Out.dtype.element_ty) + ) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +): + B_total, H, D = Mid_O.shape + seq_block_num = 1 + + Mid_O = Mid_O.view(B_total, H, seq_block_num, D) + Mid_O_LogExpSum = Mid_O_LogExpSum.view(B_total, H, seq_block_num) + + BLOCK_DMODEL = D + grid = (B_total, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stdout new file mode 100644 index 0000000..74e4d5a --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_720655.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_720655 due to shape '[2, 4, 1]' is invalid for input of size 2 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py new file mode 100644 index 0000000..cbd9bd8 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py @@ -0,0 +1,219 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head_num, seq_blocks, BLOCK_DMODEL] + Mid_O_LogExpSum, # [batch, head_num, seq_blocks] + Out, # [batch, head_num, BLOCK_DMODEL] + stride_mid_obh, # stride(batch, head) + stride_mid_ob, # stride(seq_blocks) + stride_mid_o, # stride(BLOCK_DMODEL) + stride_mid_o_lsebh, # stride(batch * head) + stride_mid_o_lseb, # stride(seq_blocks) + stride_ob, # stride(batch, head) + stride_oh, # stride(head_num) + stride_od, # stride(BLOCK_DMODEL) + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + offset_mid_o_bh = cur_batch * stride_mid_obh + cur_head * stride_oh + ptr_tv = ( + Mid_O + + offset_mid_o_bh + + block_id * stride_mid_ob + + offs_d * stride_mid_o + ) + + mask = block_id * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) < seq_len + valid_len = tl.sum(mask.to(tl.int32)) + + tv = tl.load(ptr_tv) + ptr_tlogic = ( + Mid_O_LogExpSum + + offset_mid_o_bh + + block_id * stride_mid_o_lseb + ) + tlogic = tl.load(ptr_tlogic) + + max_prev = max_logic + max_logic = tl.maximum(max_prev, tlogic) + sum_exp = sum_exp * tl.exp(max_prev - max_logic) + tl.exp(tlogic - max_logic) + acc = acc * tl.exp(max_prev - max_logic) + tv * tl.exp(tlogic - max_logic) + + result = acc / (sum_exp + 1e-7) + + ptr_out = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(ptr_out, result.cast(ptr_out.dtype.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, + BLOCK_SIZE: int = 64, +): + batch_num, head_num, seq_blocks, _ = Mid_O.shape + assert Mid_O.dtype == torch.float32 + assert Mid_O_LogExpSum.dtype == torch.float32 + block_size = BLOCK_SEQ + BLOCK_DMODEL_ = BLOCK_DMODEL + triton_grid = (batch_num, head_num) + + seq_len_stride = B_Seqlen.stride(0) if B_Seqlen.stride(0) != 0 else 1 + mid_o_stride_b, mid_o_stride_h, mid_o_stride_block, mid_o_stride_d = Mid_O.stride() + mid_o_lse_stride_bh, _, mid_o_lse_stride_block = Mid_O_LogExpSum.stride() + out_stride_b, out_stride_h, out_stride_d = Out.stride() + + _fwd_kernel_flash_decode_stage2[triton_grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + mid_o_stride_b, + mid_o_stride_block, + mid_o_stride_d, + mid_o_lse_stride_bh, + mid_o_lse_stride_block, + out_stride_b, + out_stride_h, + out_stride_d, + BLOCK_SEQ=block_size, + BLOCK_DMODEL=BLOCK_DMODEL_, + BLOCK_M=BLOCK_SIZE, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stdout new file mode 100644 index 0000000..a8086f0 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_721584.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_721584 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py new file mode 100644 index 0000000..c927683 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + seqlen = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + + offs_out = cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(Out + offs_out, acc.to(Out.type.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + batch = Mid_O.size(0) + head_num = Mid_O.size(1) + BLOCK_DMODEL = Mid_O.size(-1) + BLOCK_SEQ = Mid_O.size(-2) + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stdout new file mode 100644 index 0000000..9791cce --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_735113.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_735113 due to flash_decode_stage2() takes 4 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py new file mode 100644 index 0000000..a327a5e --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py @@ -0,0 +1,212 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seqlen = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp += exp_logic + + max_logic = new_max_logic + + offs_out = ( + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d + ) + final = acc / sum_exp + tl.store(Out + offs_out, final.to(Out.type.element_ty)) + +def flash_decode_stage2( + b_seq_len: torch.Tensor, + mid_out: torch.Tensor, + mid_out_logexpsum: torch.Tensor, + output: torch.Tensor, + BLOCK_SEQ: int, +) -> None: + batch, head_num = b_seq_len.size(0), mid_out.size(1) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + b_seq_len, + mid_out, + mid_out_logexpsum, + output, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + 1, + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + output.stride(0), + output.stride(1), + 1, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=output.size(2), + num_warps=4, + num_stages=2, + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stdout new file mode 100644 index 0000000..5d92dd1 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_739112.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_739112 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py new file mode 100644 index 0000000..0f24744 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py @@ -0,0 +1,218 @@ + +import torch +import triton +import triton.language as tl +# import math # Only if standard math functions are truly needed outside the kernel + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_obs, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + + tlogic = tl.load( + Mid_O_LogExpSum + + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s + ) + + new_max_logic = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max_logic) + exp_logic = tl.exp(tlogic - new_max_logic) + + acc = acc * scale + tv * exp_logic + sum_exp = sum_exp * scale + exp_logic + max_logic = new_max_logic + + acc = acc / sum_exp + + offs_out = ( + cur_batch * stride_obs + + cur_head * stride_oh + + offs_d * stride_od + ) + tl.store(Out + offs_out, acc.to(Out.dtype.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, + BLOCK_DMODEL: int, +): + batch, head_num, seq_block, head_dim = Mid_O.shape + assert head_dim == BLOCK_DMODEL + num_warps = 4 + num_stages = 2 + + grid = (batch, head_num) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stdout new file mode 100644 index 0000000..b2fc4bf --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_754689.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_754689 due to flash_decode_stage2() missing 1 required positional argument: 'BLOCK_DMODEL' diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py new file mode 100644 index 0000000..9a02743 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py @@ -0,0 +1,241 @@ + +import torch +import triton +import triton.language as tl +import math + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [B, H, seq_block_num, D] + Mid_O_LogExpSum, # [B, H, seq_block_num] + Out, # [B, H, D] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + # program ids + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + # load actual sequence length + cur_seq_len = tl.load(B_Seqlen + cur_batch).to(tl.int32) + block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + # initialize softmax accumulators + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + # register mapping across D + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od, + ) + + tlogic = tl.load( + Mid_O_LogExpSum + + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_s, + ) + + new_max_logic = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max_logic) + exp_logic = tl.exp(tlogic - new_max_logic) + + acc = acc * scale + tv * exp_logic + sum_exp = sum_exp * scale + exp_logic + max_logic = new_max_logic + + # avoid possible NaN when every block has identical logic + acc = acc / sum_exp + + # write final result + offs_out = ( + cur_batch * stride_ob + + cur_head * stride_oh + + offs_d * stride_od + ) + tl.store(Out + offs_out, acc.to(Out.dtype.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +): + """ + Triton wrapper matching externally generated unit-test signature. + + Parameters: + B_Seqlen : [B] int32 — sequence lengths per batch + Mid_O : [B, H, seq_block_num, D] + Mid_O_LogExpSum : [B, H, seq_block_num] (accumulated/pre-rotated logit terms) + Out : [B, H, D] (output to populate) + BLOCK_SEQ : int — blocking granularity + Returns: + torch.Tensor : Out view after completion + """ + + B, H, seq_block_num, D = Mid_O.shape + assert Mid_O_LogExpSum.shape == (B, H, seq_block_num) + assert B_Seqlen.shape == (B,) + assert Out.shape == (B, H, D) + + # Kernel compile-time constants taken from tensor shapes + BLOCK_DMODEL = D + + grid = (B, H) + num_warps = 4 + num_stages = 2 + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stdout new file mode 100644 index 0000000..bb5d328 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_802348.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_802348 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py new file mode 100644 index 0000000..18eee97 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py @@ -0,0 +1,226 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_b_seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_l_b, + stride_mid_l_h, + stride_mid_l_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_b_seqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + tv = tl.load( + Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ).to(tl.float32) + + tlogic = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_l_b + + cur_head * stride_mid_l_h + + block_id * stride_mid_l_s + ).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + scale = tl.exp(max_logic - new_max) + sum_exp *= scale + acc *= scale + + tlogic_exp = tl.exp(tlogic - new_max) + acc += tv * tlogic_exp + sum_exp += tlogic_exp + + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc.to(Out.type.element_ty) + ) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + BLOCK_DMODEL = Mid_O.shape[-1] + + assert BLOCK_SEQ > 0 + assert Out.shape == (batch, head_num, BLOCK_DMODEL) + assert Mid_O.shape[0] == batch + assert Mid_O.shape[1] == head_num + assert Mid_O.dim() == 4 + assert Mid_O_LogExpSum.dim() == 3 + assert Mid_O_LogExpSum.shape[0] == batch + assert Mid_O_LogExpSum.shape[1] == head_num + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_812012.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py new file mode 100644 index 0000000..de22cbd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py @@ -0,0 +1,202 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_olesb, + stride_mid_olesh, + stride_mid_oles, + stride_oub, + stride_ouh, + stride_oud, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_head = tl.program_id(1) + cur_batch = tl.program_id(0) + + offs_d = tl.arange(0, BLOCK_DMODEL) + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + seq_len = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + for block_n in range(block_n_size): + tv = tl.load(Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + block_n * stride_mid_os + offs_d * stride_mid_od) + tlogic = tl.load(Mid_O_LogExpSum + cur_batch * stride_mid_olesb + + cur_head * stride_mid_olesh + block_n * stride_mid_oles) + + new_max_logic = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max_logic) + new_scale = tl.exp(tlogic - new_max_logic) + + acc = acc * old_scale + tv * new_scale + sum_exp = sum_exp * old_scale + new_scale + max_logic = new_max_logic + + acc = acc / sum_exp + tl.store(Out + cur_batch * stride_oub + cur_head * stride_ouh + offs_d * stride_oud, acc) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + block_seq: int +): + batch = B_Seqlen.shape[0] + head_num = Mid_O.shape[1] + assert Mid_O_LogExpSum.shape[1] == head_num + + BLOCK_SEQ = block_seq + BLOCK_DMODEL = Mid_O.shape[3] + + _fwd_kernel_flash_decode_stage2[(batch, head_num)]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2 + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_83138.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py new file mode 100644 index 0000000..2f8acde --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py @@ -0,0 +1,206 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_o_b, + stride_mid_o_h, + stride_mid_o_block, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_block, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + sum_exp = 0.0 + max_logic = float('-inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(block_n_size): + tv = tl.load( + Mid_O + cur_batch * stride_mid_o_b + cur_head * stride_mid_o_h + + block_id * stride_mid_o_block + offs_d + ) + + tlogic = tl.load( + Mid_O_LogExpSum + + cur_batch * stride_mid_lse_b + + cur_head * stride_mid_lse_h + + block_id * stride_mid_lse_block + ) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.math.exp(max_logic - new_max) + new_scale = tl.math.exp(tlogic - new_max) + + acc = acc * old_scale + acc += tv * new_scale + sum_exp = sum_exp * old_scale + new_scale + max_logic = new_max + + sum_exp_inv = 1.0 / sum_exp + acc = acc * sum_exp_inv + + out_ptr = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d + tl.store(out_ptr, acc.to(out_ptr.dtype.element_ty)) + + +@torch.no_grad() +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + block_seq: int, +): + batch, head_num = Out.shape[0], Out.shape[1] + BLOCK_DMODEL = Out.shape[2] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_870175.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py new file mode 100644 index 0000000..0dda36d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py @@ -0,0 +1,203 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_out_b, + stride_out_h, + stride_out_d, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + offs_d = tl.arange(0, BLOCK_DMODEL) + + seq_len = tl.load(B_Seqlen + cur_batch) + n_blocks = tl.maximum((seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ, 1) + + acc = tl.zeros([BLOCK_DMODEL], dtype=Mid_O.dtype.element_ty) + max_logic = -float('inf') + sum_exp = 0.0 + + for i in range(n_blocks): + v = tl.load( + Mid_O + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + i * stride_mid_os + + offs_d * stride_mid_od + ) + logsum = tl.load( + Mid_O_LogExpSum + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + i * stride_mid_les_s + ) + new_max = tl.maximum(logsum, max_logic) + scale = tl.exp(max_logic - new_max) + acc = acc * scale + v * tl.exp(logsum - new_max) + sum_exp = sum_exp * scale + tl.exp(logsum - new_max) + max_logic = new_max + + acc = acc / sum_exp + tl.store( + Out + cur_batch * stride_out_b + + cur_head * stride_out_h + + offs_d * stride_out_d, + acc + ) + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int = 64, + BLOCK_DMODEL: int = 128, +): + B, H = Out.shape[0], Out.shape[1] + grid = (B, H) + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3) if Mid_O.ndim == 4 else 1, + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stdout new file mode 100644 index 0000000..bfb8bbc --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_882682.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_882682 due to Dimension out of range (expected to be in range of [-1, 0], but got 1) diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py new file mode 100644 index 0000000..9880e7d --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_bseqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_les_b, + stride_mid_les_h, + stride_mid_les_s, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + seqlen = tl.load(B_Seqlen + cur_batch * stride_bseqlen) + block_n_size = (seqlen + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = float("-inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for block_id in range(0, block_n_size): + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(Mid_O + offs_mid_o).to(tl.float32) + + offs_mid_lse = ( + cur_batch * stride_mid_les_b + + cur_head * stride_mid_les_h + + block_id * stride_mid_les_s + ) + tlogic = tl.load(Mid_O_LogExpSum + offs_mid_lse).to(tl.float32) + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + acc = acc * old_scale + sum_exp = sum_exp * old_scale + + exp_val = tl.exp(tlogic - new_max) + acc += tv * exp_val + sum_exp += exp_val + max_logic = new_max + + acc = acc / sum_exp + + offs_out = cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(Out + offs_out, acc.to(Out.type.element_ty)) + + +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, +) -> torch.Tensor: + batch = Mid_O.size(0) + head_num = Mid_O.size(1) + BLOCK_DMODEL = Mid_O.size(-1) + BLOCK_SEQ = Mid_O.size(-2) + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + B_Seqlen.stride(0), + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + return Out + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stdout new file mode 100644 index 0000000..9c6a9f6 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_900175.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_900175 due to flash_decode_stage2() takes 4 positional arguments but 5 were given diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py new file mode 100644 index 0000000..1a4fe4c --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py @@ -0,0 +1,218 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head_num, seq_blocks, BLOCK_DMODEL] + Mid_O_LogExpSum, # [batch, head_num, seq_blocks] + Out, # [batch, head_num, BLOCK_DMODEL] + stride_mid_ob, # stride(batch) + stride_mid_oh, # stride(head_num) + stride_mid_os, # stride(seq_blocks) + stride_mid_od, # stride(BLOCK_DMODEL) + stride_mid_o_eb, # stride(batch) + stride_mid_o_eh, # stride(head_num) + stride_mid_o_es, # stride(seq_blocks) + stride_ob, # stride(batch) + stride_oh, # stride(head_num) + stride_od, # stride(BLOCK_DMODEL) + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + for block_id in range(0, block_n_size): + ptr_tv = ( + Mid_O + + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + block_id * stride_mid_os + + offs_d * stride_mid_od + ) + tv = tl.load(ptr_tv) + + ptr_tlogic = ( + Mid_O_LogExpSum + + cur_batch * stride_mid_o_eb + + cur_head * stride_mid_o_eh + + block_id * stride_mid_o_es + ) + tlogic = tl.load(ptr_tlogic) + + max_prev = max_logic + max_logic = tl.maximum(max_prev, tlogic) + sum_exp = sum_exp * tl.exp(max_prev - max_logic) + tl.exp(tlogic - max_logic) + acc = acc * tl.exp(max_prev - max_logic) + tv * tl.exp(tlogic - max_logic) + + result = acc / (sum_exp + 1e-6) + + ptr_out = ( + Out + + cur_batch * stride_ob + + cur_head * stride_oh + + offs_d * stride_od + ) + tl.store(ptr_out, result.to(ptr_out.dtype.element_ty)) + + +@torch.no_grad() +def flash_decode_stage2( + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + B_Seqlen: torch.Tensor, + Out: torch.Tensor, + block_seq: int, +): + batch, head_num, seq_blocks, BLOCK_DMODEL = Mid_O.shape + + triton_grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[ + triton_grid + ]( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + Mid_O.stride(0), + Mid_O.stride(1), + Mid_O.stride(2), + Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), + Mid_O_LogExpSum.stride(1), + Mid_O_LogExpSum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_925215.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py new file mode 100644 index 0000000..284b8d3 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py @@ -0,0 +1,193 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, + Mid_O_LogExpSum, + Out, + stride_mid_o_b, + stride_mid_o_h, + stride_mid_o_block, + stride_mid_o_d, + stride_mid_lse_b, + stride_mid_lse_h, + stride_mid_lse_block, + stride_ob, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + sum_exp = 0.0 + max_logic = float('-inf') + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + seq_len = tl.load(B_Seqlen + cur_batch) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + for block_id in range(block_n_size): + + offs_d = tl.arange(0, BLOCK_DMODEL) + + cur_mid_o_ptr = Mid_O + cur_batch * stride_mid_o_b + cur_head * stride_mid_o_h + block_id * stride_mid_o_block + offs_d * stride_mid_o_d + tv = tl.load(cur_mid_o_ptr) # [BLOCK_DMODEL] + + cur_mid_lse_ptr = Mid_O_LogExpSum + cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_block + tlogic = tl.load(cur_mid_lse_ptr) # scalar + + new_max = tl.maximum(max_logic, tlogic) + old_scale = tl.exp(max_logic - new_max) + new_scale = tl.exp(tlogic - new_max) + + acc = acc * old_scale + acc += tv * new_scale + sum_exp = sum_exp * old_scale + new_scale + max_logic = new_max + + sum_exp_inv = 1.0 / sum_exp + acc = acc * sum_exp_inv + + out_ptr = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od + tl.store(out_ptr, acc.to(out_ptr.type.element_ty)) + + +def flash_decode_stage2( + B_Seqlen: torch.Tensor, + Mid_O: torch.Tensor, + Mid_O_LogExpSum: torch.Tensor, + Out: torch.Tensor, + BLOCK_SEQ: int, +): + batch, head_num = Out.shape[0], Out.shape[1] + BLOCK_DMODEL = Out.shape[2] + + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + Mid_O, Mid_O_LogExpSum, Out, + Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3), + Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2), + Out.stride(0), Out.stride(1), Out.stride(2), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + ) + + +################################################################################################################################################## + + + + + +import torch + + + +# Define the test function + +def test_flash_decode_stage2(): + + # Define the parameters for different test cases + + batch_size = 2 + + head_num = 4 + + seq_block_num = 3 + + head_dim = 64 + + block_seq = 16 + + + + test_cases = { + + "test_case_1": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + + }, + + "test_case_2": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq + 1 # Different block size + + }, + + "test_case_3": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq // 2 # Different block size + + }, + + "test_case_4": { + + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + + "block_seq": block_seq * 2 # Different block size + + } + + } + + + + # Execute the function for all test cases + + results = {} + + for key, test_case in test_cases.items(): + + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + + results[key] = test_case["Out"] + + + + return results + + + +# Run the test + +result_gold = test_flash_decode_stage2() diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stderr b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stdout b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stdout new file mode 100644 index 0000000..a065d90 --- /dev/null +++ b/src/temp/gen/flash_decode2_phi.py_gen_triton_code_959027.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module flash_decode2_phi.py_gen_triton_code_959027 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py new file mode 100644 index 0000000..55540cf --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py @@ -0,0 +1,197 @@ + +import torch +import triton +import triton.language as tl + +# -------------------------------------------------- +# Triton kernel +# -------------------------------------------------- +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + x_ptr, qw_ptr, sc_ptr, zp_ptr, c_ptr, + M, N, K, + stride_xm, stride_xk, + stride_qwk, stride_qwn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_cm, stride_cn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_sp_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = pid_sp_k * BLOCK_SIZE_K * SPLIT_K + tl.arange(0, BLOCK_SIZE_K * SPLIT_K) + + mask_m = offs_m < M + mask_n = offs_n < N + mask_k = offs_k < K + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k0 in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + current_offs_k = k0 * BLOCK_SIZE_K * SPLIT_K + offs_k + mask_kk = current_offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + current_offs_k[None, :] * stride_xk + x_blk = tl.load(x_ptrs, mask=mask_m[:, None] & mask_kk[None, :], other=0.0) + + qw_ptrs = qw_ptr + (current_offs_k[:, None] // 8) * stride_qwk + offs_n[None, :] * stride_qwn + qw_blk = tl.load(qw_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0) + + g_idx = (current_offs_k // group_size) + sc_ptrs = sc_ptr + g_idx[:, None] * stride_scg + offs_n[None, :] * stride_scn + zp_ptrs = zp_ptr + g_idx[:, None] * stride_zpg + (offs_n[None, :] // 8) * stride_zpn + + sc = tl.load(sc_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0.0).to(tl.float32) + zp = tl.load(zp_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0) + + shifts = (current_offs_k % 8) * 4 + int4_w = (qw_blk >> shifts[:, None]) & 0xF + zp_shifts = (offs_n[None, :] % 8) * 4 + int4_zp = (zp >> zp_shifts) & 0xF + deq_w = ((int4_w.float() - int4_zp.float()) * sc).to(tl.float16) + + acc += tl.dot(x_blk.to(tl.float16), deq_w).to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_out = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] + + if SPLIT_K > 1: + tl.atomic_add(out_ptrs, acc.astype(tl.float16), mask=mask_out) + else: + tl.store(out_ptrs, acc.astype(tl.float16), mask=mask_out) + +# -------------------------------------------------- +# Wrapper +# -------------------------------------------------- +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, + scale: torch.Tensor, zero_point: torch.Tensor, + group_size: int = 128) -> torch.Tensor: + assert x.dim() == 2 + assert qweight.dim() == 2 + assert scale.dim() == 2 + assert zero_point.dim() == 2 + M, K = x.shape + K8, N = qweight.shape + assert K == K8 * 8 + x = x.contiguous() + output = torch.empty((M, N), dtype=torch.float16, device=x.device) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, scale, zero_point, output, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + scale.stride(0), scale.stride(1), + zero_point.stride(0), zero_point.stride(1), + output.stride(0), output.stride(1), + group_size, + ) + return output + +# -------------------------------------------------- +# Quantization helpers +# -------------------------------------------------- +def quantize_int4(x: torch.Tensor, group_size: int = 128): + orig_shape = x.shape + x = x.view(-1, orig_shape[-1]) + K, N = x.shape + if K % group_size: + pad_k = (K + group_size - 1) // group_size * group_size + x = torch.nn.functional.pad(x, (0, 0, 0, pad_k - K)) + K = pad_k + x = x.view(-1, group_size, N) + x_min = x.amin(dim=1, keepdim=True) + x_max = x.amax(dim=1, keepdim=True) + denom = x_max - x_min + denom[denom.abs() < 1e-12] = 1.0 + sc = (denom) / 15.0 + zp = torch.round(-x_min / sc) + q = torch.clamp(torch.round(x / sc + zp), 0, 15).to(torch.int32) + qf = q.view(K, N) + packed = torch.zeros(K // 8, N, dtype=torch.int32, device=x.device) + for shift in range(8): + packed |= qf[shift::8, :] << (shift * 4) + sc = sc.view(K // group_size, N) + zp = zp.view(K // group_size, N) + return packed, sc, zp, orig_shape + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, + group_size: int = 128): + K8, N = packed.shape + K = K8 * 8 + w = torch.empty(K, N, dtype=torch.float32, device=packed.device) + for shift in range(8): + w[shift::8, :] = ((packed >> (shift * 4)) & 0xF).float() + scale1 = scale.view(-1, N) + zp1 = zero_point.view(-1, N) + return ((w.view(-1, group_size, N) - zp1.unsqueeze(1)) * scale1.unsqueeze(1)).view(K, N) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stdout new file mode 100644 index 0000000..e7a9660 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_124574.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_124574 due to dynamic_func() missing 1 required positional argument: 'SPLIT_K' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py b/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py new file mode 100644 index 0000000..556c5fb --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py @@ -0,0 +1,173 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_k, stride_scales_n, + stride_zeros_k, stride_zeros_n, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + grid_m = tl.cdiv(M, BLOCK_SIZE_M) + grid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * grid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(grid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + scales_ptrs = scales_ptr + ((offs_k[None, :] // group_size) * stride_scales_k + offs_n[:, None] * stride_scales_n) + zeros_ptrs = zeros_ptr + ((offs_k[None, :] // group_size) * stride_zeros_k + (offs_n[:, None] // 8) * stride_zeros_n) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_start = k * BLOCK_SIZE_K * SPLIT_K + pid_z * BLOCK_SIZE_K + idx_k = k_start + tl.arange(0, BLOCK_SIZE_K) + mask_k = idx_k[None, :] < K + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + + idx_k_packed = (idx_k[None, :] // 8) + b = tl.load(b_ptr + idx_k_packed * stride_bk + offs_n[None, :] * stride_bn, mask=mask_k, other=0) + + idx_g = (idx_k[None, :] // group_size) + bs = tl.load(scales_ptr + idx_g * stride_scales_k + offs_n[None, :] * stride_scales_n, mask=mask_k, other=1.0) + zs = tl.load(zeros_ptr + idx_g * stride_zeros_k + (offs_n[None, :] // 8) * stride_zeros_n, mask=mask_k, other=0.0) + + shift = (idx_k[None, :] % 8) * 4 + int4_val = (b >> shift) & 0xF + zp4_val_all = zs & 0x0F0F0F0F + zp4_val = (zs >> ((offs_n[None, :] % 8) * 4)) & 0xF + b_fp = (int4_val - zp4_val) * bs + accumulator += tl.dot(a, b_fp.to(a.dtype)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, accumulator, mask=mask_c) + else: + tl.store(c_ptrs, accumulator, mask=mask_c) + +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, K: int) -> torch.Tensor: + M, _ = x.shape + _, N = qweight.shape + + c = torch.empty((M, N), dtype=x.dtype, device=x.device) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, c, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + 128 + ) + return c + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]: + shape = weights.shape + w = weights.to(torch.float32).reshape(-1, group_size) + w_min = w.min(dim=1, keepdim=True)[0] + w_max = w.max(dim=1, keepdim=True)[0] + scales = (w_max - w_min) / 15.0 + zeros = -w_min / scales + w = torch.round((w - w_min) / scales).clamp(0, 15).to(torch.uint8) + packed = torch.zeros(w.shape[0], (w.shape[1] + 1) // 2, dtype=torch.int32, device=weights.device) + packed[:, :w.shape[1]//2] = (w[:, ::2] & 0x0F) | ((w[:, 1::2] & 0x0F) << 4) + packed = packed.reshape(shape[0], shape[1] // 8) + scales = scales.reshape(shape[0], shape[1] // group_size) + zeros = zeros.reshape(shape[0], shape[1] // group_size) + return packed, scales, zeros, None + +def unpack_int4(weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + shape = weights.shape + assert shape[-1] == scales.shape[-1] + w = weights.view(-1, weights.shape[-1]) + scales = scales.view(-1, scales.shape[-1]) + zeros = zeros.view(-1, zeros.shape[-1]) + unpacked = torch.zeros(w.shape[0], w.shape[1] * 8, dtype=torch.float32, device=weights.device) + for i in range(8): + unpacked[:, i::8] = ((w >> (4*i)) & 0x0F) * scales[:, w.shape[1]*(8*i)//group_size:w.shape[1]*(8*i)//group_size+1].expand(-1, unpacked.shape[1]//8) + \ + zeros[:, w.shape[1]*(8*i)//group_size:w.shape[1]*(8*i)//group_size+1].expand(-1, unpacked.shape[1]//8) + return unpacked.reshape(weights.shape[0], weights.shape[1] * 8) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stdout new file mode 100644 index 0000000..8596f88 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_178552.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_178552 due to shape '[4096, 512]' is invalid for input of size 8388608 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py b/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py new file mode 100644 index 0000000..6a648a3 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py @@ -0,0 +1,217 @@ + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,num_warps=8), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * (BLOCK_SIZE_K) + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk) + offs_n[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_off = k * BLOCK_SIZE_K * SPLIT_K + k_now = k_off + offs_k + a = tl.load(a_ptrs, mask=k_now[None, :] < K, other=0.0) + + b_packed = tl.load(b_ptrs, mask=k_now[:, None] < K, other=0) + b_shift = ((k_now[:, None] % 8) * 4) + b_i4 = (b_packed >> b_shift) & 0xF + + g_id = k_now[:, None] // group_size + b_scale = tl.load(bs_ptr + g_id * stride_bsk + offs_n[None, :] * stride_bsn, + mask=k_now[:, None] < K, other=0.0) + b_zero = tl.load(bzp_ptr + g_id * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn, + mask=k_now[:, None] < K, other=0) + + zp_shift = ((offs_n[None, :] % 8) * 4) + b_zp_i4 = (b_zero >> zp_shift) & 0xF + + b_float = (b_i4 - b_zp_i4) * b_scale + acc += tl.dot(a, b_float.to(a.dtype)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + c = acc + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, + qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + M, K = x.shape + _, N = scales.shape + assert K == qweight.shape[0] * 8 + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, ) + return output + + +def quantize_int4(w: torch.Tensor, group_size: int = 128): + w = w.float() + oc, ic = w.shape # (K, N) un-transposed + assert ic % group_size == 0 + w = w.view(oc, ic // group_size, group_size) + + wmax = w.max(dim=2, keepdim=True)[0] + wmin = w.min(dim=2, keepdim=True)[0] + scale = (wmax - wmin) / 15.0 + zero = torch.round(-wmin / scale).clamp(0, 15).to(torch.uint8) + + q = torch.round((w - wmin) / scale).clamp(0, 15).to(torch.uint8) + + q = q.view(oc, ic) + zero = zero.view(oc, ic // group_size) + + ncols = ic + packed = torch.zeros((oc, ncols // 8), dtype=torch.int32, device=w.device) + for i in range(0, ncols, 8): + chunk = q[:, i:i+8].to(torch.int32) + packed[:, i//8] = ( + chunk[:,7] << 28 | + chunk[:,6] << 24 | + chunk[:,5] << 20 | + chunk[:,4] << 16 | + chunk[:,3] << 12 | + chunk[:,2] << 8 | + chunk[:,1] << 4 | + chunk[:,0] + ) + + zero_packed = torch.zeros((oc, (ncols // group_size + 7) // 8), + dtype=torch.int32, device=w.device) + nz = zero.shape[1] + for i in range(0, nz, 8): + zchunk = zero[:, i:i+8].to(torch.int32) + idx = torch.arange(zchunk.size(1), device=w.device) + zpacked = torch.sum(zchunk << (idx * 4), dim=1, keepdim=True) + zero_packed[:, i//8] = zpacked.squeeze(1) + + return packed.view(torch.int32), scale.squeeze(-1).half(), zero_packed.view(torch.int32) + + +def unpack_int4(w_packed: torch.IntTensor, scale: torch.FloatTensor, + zero: torch.IntTensor, group_size: int = 128): + oc, Nw = w_packed.shape + ic = Nw * 8 + num_groups = ic // group_size + scale = scale.view(oc, num_groups, 1).expand(-1, -1, group_size).reshape(oc, ic) + zero_shape = (oc, num_groups) + nz = zero.shape[1] * 8 + zero = zero.view(oc, nz)[:, :num_groups] + zero = zero.view(oc, num_groups, 1).expand(-1, -1, group_size).reshape(oc, ic) + + w_bytes = w_packed.view(torch.uint8).view(oc, ic // 2) + w = torch.empty((oc, ic), dtype=torch.uint8, device=w_packed.device) + for k in range(0, ic, 2): + b = w_bytes[:, k//2] + w[:, k] = (b & 0xF).to(torch.uint8) + w[:, k+1] = ((b >> 4) & 0xF).to(torch.uint8) + + return (w.float() - zero.float()) * scale.float() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stdout new file mode 100644 index 0000000..63d9c5d --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_216434.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_216434 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py b/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py new file mode 100644 index 0000000..bc48d3b --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py @@ -0,0 +1,216 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'] +) +@triton.jit +def matmul_kernel( + A, B, C, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bq, stride_bs, stride_bz, + SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + k_start = pid_k * tl.cdiv(K, SPLIT_K) + k_end = min((pid_k + 1) * tl.cdiv(K, SPLIT_K), K) + + offs_k = k_start + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = B + (offs_k // 2)[:, None] * stride_bk + offs_n[None, :] * stride_bn + + accum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + group_size = tl.constexpr(32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_valid = k + tl.arange(0, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (k_valid[None, :] < K), other=0.0) + + qoffs = k_valid // group_size + shift = ((k_valid % group_size) & 1) * 4 + mask = (k_valid < K)[:, None] + + packed = tl.load(B + (k_valid // 2)[:, None] * stride_bk + offs_n[None, :] * stride_bn, + mask=mask & (offs_n[None, :] < N), other=0) + packed = packed.to(tl.int32) + + scale_ptrs = B + stride_bq + qoffs[:, None] * stride_bs + zero_ptrs = B + stride_bq + qoffs[:, None] * stride_bz + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + zero = tl.load(zero_ptrs, mask=mask, other=0.0) + + q = ((packed >> shift[:, None]) & 0xF).to(tl.float32) + b = (q - zero) * scale + accum += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, accum, mask=mask) + else: + tl.atomic_add(c_ptrs, accum, mask=mask) + +def matmul_dequantize_int4_s2( + x: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + split_k: int = 1 +) -> torch.Tensor: + B, M, K = x.shape + K_packed = w.shape[0] + N = w.shape[1] + assert K_packed == K // 2, f"Packed shape {K_packed} must equal K//2={K//2}" + assert w.dtype == torch.int32 + c = torch.empty((B, M, N), dtype=x.dtype, device=x.device) + + total_M = B * M + grid = lambda META: ( + triton.cdiv(total_M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + split_k + ) + + aux = torch.empty((2 * (scales.numel() + zeros.numel()),), dtype=torch.float32, device=w.device) + stride_bq = w.numel() * 4 + stride_bs = scales.stride(-1) if scales.dim() >= 1 else 1 + stride_bz = zeros.stride(-1) if zeros.dim() >= 1 else 1 + + matmul_kernel[grid]( + x.view(-1, K), w, c.view(-1, N), + total_M, N, K, + x.stride(-2) if x.dim() >= 2 else K, + x.stride(-1) if x.dim() >= 1 else 1, + w.stride(-2), + w.stride(-1), + c.stride(-2) if c.dim() >= 2 else N, + c.stride(-1) if c.dim() >= 1 else 1, + stride_bq, stride_bs, stride_bz, + SPLIT_K=split_k, + GROUP_SIZE_M=8, + ) + return c + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert weights.dim() == 2, "weights must be 2-D (K, N)" + K, N = weights.shape + assert K % group_size == 0, f"K={K} must divide group_size={group_size}" + num_groups = K // group_size + flat = weights.to(torch.float32).view(num_groups, group_size, N) + + mn, mx = flat.aminmax(dim=1, keepdim=True) + scale = (mx - mn) / 15.0 + scale = torch.where(scale == 0, 1.0, scale) + zero = -mn / scale + + q = ((flat / scale + zero + 0.5).floor()).clamp(0, 15) + q = q.view(num_groups * group_size, N) + + q_low = q[:q.shape[0]//2] + q_high = q[q.shape[0]//2:] + packed = (q_low & 0xF) | ((q_high & 0xF) << 4) + packed = packed.view(K // 2, N).to(torch.int32) + + scale = scale.squeeze(-2).squeeze(-1) + zero = zero.squeeze(-2).squeeze(-1) + + return packed, scale, zero + +def unpack_int4(w: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + assert w.dim() == 2 and w.dtype == torch.int32 + K_half, N = w.shape + K = K_half * 2 + num_groups = K // group_size + assert scales.shape[-1] == num_groups + assert zeros.shape[-1] == num_groups + + b0 = (w & 0xF).float() + b1 = ((w >> 4) & 0xF).float() + + q = torch.stack([b0, b1], dim=-1).view(K, N) + scales = scales.view(-1, 1).repeat(1, group_size).view(-1, 1) + zeros = zeros.view(-1, 1).repeat(1, group_size).view(-1, 1) + unpacked = (q - zeros[:K]) * scales[:K] + return unpacked.view(K, N).to(torch.float16) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stdout new file mode 100644 index 0000000..78ce73c --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_219875.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_219875 due to "bitwise_and_cuda" not implemented for 'Float' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py b/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py new file mode 100644 index 0000000..5d9a59e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# int4 de-quant helpers +@triton.jit +def _dequantize_int4_unpack(xi32, mask0=0x0f, mask1=0xf0): + xi0 = (xi32 & mask0).to(tl.int8) + xi1 = ((xi32 & mask1) >> 4).to(tl.int8) + return xi0, xi1 + + +@triton.jit +def _dequantize_int4_kernel(ptr, scales_ptr, zeros_ptr, M, N, + stride_q, stride_s, stride_z, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + q_offsets = (rm[:, None] * stride_q + (rn // 8)[None, :]) + scales_offsets = (rm[:, None] * stride_s + (rn // 8)[None, :]) + zeros_offsets = (rm[:, None] * stride_z + (rn // 8)[None, :]) + + mask_m = rm < M + mask_n = rn < N + mask = mask_m[:, None] & mask_n[None, :] + + packed = tl.load(ptr + q_offsets, mask=mask, other=0) + s = tl.load(scales_ptr + scales_offsets, mask=mask, other=1.0) + z = tl.load(zeros_ptr + zeros_offsets, mask=mask, other=0.0) + + offsets_0 = (rn % 8) * 4 + offsets_1 = offsets_0 + 4 + i0, i1 = _dequantize_int4_unpack(packed) + v0 = (i0.to(tl.float32) - z) * s + v1 = (i1.to(tl.float32) - z) * s + + return v0, v1 + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_eval_k, stride_eval_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_k = tl.program_id(2) + + n_blocks_m = tl.cdiv(M, BLOCK_SIZE_M) + n_blocks_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + group_id = 0 + first_pid_m = 0 + else: + group_id = pid_m // GROUP_SIZE_M + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(n_blocks_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_m % group_size_m) + + if SPLIT_K > 1: + local_k = tl.cdiv(K, SPLIT_K) + k_offset = pid_k * local_k + else: + local_k = K + k_offset = 0 + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + scales_ptrs = scales_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, local_k, BLOCK_SIZE_K): + if EVEN_K or (k + BLOCK_SIZE_K <= local_k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < local_k - k, other=0.0, eviction_policy="evict_last") + block_scale = tl.load(scales_ptrs, mask=offs_k[:, None] < local_k - k, other=1.0) + block_zero = tl.load(zeros_ptrs, mask=offs_k[:, None] < local_k - k, other=0.0) + + packed_b = tl.load(b_ptrs, mask=offs_k[:, None] < local_k - k, other=0) + k_idx = (offs_k[:, None] % 8) * 4 + val_low = (packed_b & 0x0F).to(tl.int8).to(tl.float32) + val_high = ((packed_b >> 4) & 0x0F).to(tl.int8).to(tl.float32) + b_low = (val_low - block_zero) * block_scale + b_high = (val_high - block_zero) * block_scale + + acc = tl.dot(a, b_low, acc) + a_shift = tl.load(a_ptrs + stride_bk * (1 if EVEN_K else 8), mask=offs_k[None, :] + 8 < local_k - k, other=0.0, eviction_policy="evict_last") + acc = tl.dot(a_shift, b_high, acc) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + scales_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + zeros_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + + if SPLIT_K == 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(c_ptrs.type.element_ty), mask=c_mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + pid_k * M * N + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=c_mask) + + +def matmul_dequantize_int4_s2(a, int4b_compressed, scales, zeros, M, N, K): + c_dtype = a.dtype + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + SPLIT_K = 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K) + + if SPLIT_K > 1: + c = torch.empty((SPLIT_K, M, N), dtype=torch.float32, device=a.device) + else: + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + EVEN_K = K % 32 == 0 + + matmul_kernel[grid](a, int4b_compressed, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + int4b_compressed.stride(0), int4b_compressed.stride(1), + c.stride(0) if c.dim() == 2 else c.stride(1), + c.stride(1) if c.dim() == 2 else c.stride(2), + scales.stride(0), scales.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + SPLIT_K=SPLIT_K, + EVEN_K=EVEN_K) + return c if SPLIT_K == 1 else c.sum(dim=0) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + org_shape = x.shape + x = x.view(-1, group_size) + x_fp32 = x.float() + x_min = x_fp32.amin(dim=-1, keepdim=True) + x_max = x_fp32.amax(dim=-1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale + 0.5).clamp(0, 15) + + int4 = (((x_fp32 - x_min) / scale + 0.5).clamp(0, 15)).round().to(torch.int8) + + int4 = int4.view(-1) + packed = torch.zeros(x.size(0), x.size(1) // 2, dtype=torch.int32, device=x.device) + int4_even = int4[::2].to(torch.int32) + int4_odd = int4[1::2].to(torch.int32) + packed = int4_even | (int4_odd << 4) + packed = packed.view(org_shape[0], org_shape[1] // 2) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + return packed, scale, zero + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + M, Nhalf = packed.shape + N = Nhalf * 2 + unpacked = torch.empty(M, N, dtype=torch.float16, device=packed.device) + + packed = packed.int() + for i in range(M): + for j in range(Nhalf): + low = (packed[i, j] & 0x0F).to(torch.float32) + high = ((packed[i, j] >> 4) & 0x0F).to(torch.float32) + group_idx = j * 2 // group_size + val_low = (low - zero[i, group_idx]) * scale[i, group_idx] + val_high = (high - zero[i, group_idx]) * scale[i, group_idx] + unpacked[i, 2 * j] = val_low.to(torch.float16) + unpacked[i, 2 * j + 1] = val_high.to(torch.float16) + return unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stdout new file mode 100644 index 0000000..b72ae4e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_243114.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_243114 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py b/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py new file mode 100644 index 0000000..c557c58 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py @@ -0,0 +1,257 @@ + +import torch +import triton +import triton.language as tl + +# ========================= +# Triton kernel (batched INT4 matrix multiply) +# ========================= +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16, 'SPLIT_K': 2}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr): + pid0 = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) # only meaningful when SPLIT_K > 1 + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid0 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid0 % group_size_m) + pid_n = (pid0 % num_pid_in_group) // group_size_m + + # block row/col indices + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + k_step = BLOCK_SIZE_K * SPLIT_K + k_lo = pid_k * BLOCK_SIZE_K + offs_k_block = k_lo + tl.arange(0, BLOCK_SIZE_K) + + # pointers + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k_block[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k_block[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, k_step)): + cur_k = offs_k_block + k * k_step + mask_k = cur_k[None, :] < K + mask_n = offs_n[None, :] < N + + a = tl.load(a_ptrs, mask=mask_k & (offs_m[:, None] < M), other=0.0) + + packed_b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0) + + # group indices + gidx = cur_k[None, :] // group_size + + scales = tl.load(scales_ptr + + gidx * stride_bsk + + offs_n[None, :] * stride_bsn, mask=mask_k & mask_n, other=0.0) + + zeros_packed = tl.load(zeros_ptr + + gidx * stride_bzpk + + (offs_n[None, :] // 8) * stride_bzpn, + mask=mask_k & mask_n, other=0) + zeros_packed = zeros_packed.to(tl.int32) + + shift = (cur_k[None, :] % 8) * 4 + zp_shift = (offs_n[None, :] % 8) * 4 + + int_b = (packed_b >> shift) & 0xF + int_zp = (zeros_packed >> zp_shift) & 0xF + b = ((int_b.to(tl.float32) - int_zp.to(tl.float32)) * scales) + acc += tl.dot(a, b) + + a_ptrs += k_step * stride_ak + b_ptrs += (k_step // 8) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask_c) + else: + tl.atomic_add(c_ptrs, acc, mask=mask_c) + +# ========================= +# Front-end helpers +# ========================= + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize weights to INT4, packing 8 INT4 values per int32 row. + Returns (qweight, scales, zeros) where + - qweight: [Kw, N] int32,Kw = ceil_div(K, 8) + - scales: [num_groups, N] float + - zeros: [num_groups, N] int32 after packing (8 zeros per int32) + """ + assert weights.dim() == 2 + K, N = weights.shape + assert K % group_size == 0 + + num_groups = K // group_size + w_groups = weights.view(num_groups, group_size, N) # [G, Gsz, N] + w_min, w_max = w_groups.aminmax(dim=1) # [G, N] + scale = (w_max - w_min) / 15.0 + scale = torch.where(scale == 0, torch.tensor(1.0, device=scale.device), scale) + zero = (-w_min / scale) + q = ((w_groups / scale.unsqueeze(1) + zero.unsqueeze(1) + 0.5).floor()).clamp(0, 15).to(torch.int32) + + q = q.view(K, N) # [K, N] + q_low = q[0::2] + q_high = q[1::2] + # pack into int32: [Kw, N] + packed = (q_low & 0xF) | ((q_high & 0xF) << 4) + + # pack zeros similarly + zero_int = zero.round().int().clip(0, 15) + zero_low = zero_int[..., 0::2] + zero_high = zero_int[..., 1::2] + zeros_packed = (zero_low & 0xF) | ((zero_high & 0xF) << 4) + + return packed, scale, zeros_packed + + +def unpack_int4(w, scales, zeros, group_size: int = 128): + """ + De-quantize w for numeric validation. + w: [Kw, N] int32, scales: [num_groups, N], zeros: [num_groups, N] int32 + returns float dequantized weight [K, N] + """ + Kw, N = w.shape + K = Kw * 8 + num_groups = K // group_size + assert num_groups == scales.shape[0] + + # unstitch + w0 = (w & 0xF).to(torch.float32) + w1 = ((w >> 4) & 0xF).to(torch.float32) + deq = torch.empty((K, N), device=w.device, dtype=w0.dtype) + deq[0::8] = w0[::2, :] + deq[1::8] = w1[::2, :] + deq[2::8] = (w0[1::2, :] if w0.shape[0] > 1 else w0[0:1, :]) + deq[3::8] = (w1[1::2, :] if w1.shape[0] > 1 else w0[0:1, :]) + # Because each int32 stores 8 int4 weights in four successive rows, + # split again correctly to rows [2,3] [4,5] [6,7] + idx = torch.arange(K, device=w.device)[:, None] + block = idx // 8 + offset_in_block = idx % 8 + gather = ((w[block, :] >> (4 * offset_in_block)) & 0xF).to(torch.float32) + deq_correct = gather.view(K, N) + + # broadcast scales and zeros + scales = scales.view(num_groups, 1, N).expand(num_groups, group_size, N).reshape(K, N) + zeros = zeros.view(num_groups, 1, N).expand(num_groups, group_size, N).reshape(K, N) + return deq_correct * scales + zeros + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.FloatTensor, + group_size: int = 128) -> torch.FloatTensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + M, K = x.shape + Kw, N = qweight.shape + assert Kw == K // 8 + + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + + matmul_kernel[grid]( + x, qweight, output, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size, + ) + return output + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stdout new file mode 100644 index 0000000..a1fff73 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_291697.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_291697 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py b/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py new file mode 100644 index 0000000..c50a853 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py @@ -0,0 +1,290 @@ + +import torch +import triton +import triton.language as tl + +configs_matmul = [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), +] + +@triton.autotune( + configs=configs_matmul, + key=["M", "N", "K"], + use_cuda_graph=False +) +@triton.jit +def matmul_kernel( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g0, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1, + GROUP_SIZE_M: tl.constexpr = 8 +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_pos = k * BLOCK_SIZE_K * SPLIT_K + offs_k + g_idx = (k_pos) // groupsize + + mask_k = k_pos < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + offset_b = (k_pos[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_chunk = tl.load(B + offset_b, mask=mask_k[:, None], other=0) + + scale_offset = g_idx[:, None] * stride_scales_g + offs_n[None, :] * stride_scales_n + scale_val = tl.load(scales + scale_offset, mask=mask_k[:, None], other=0.0) + + zp_val = tl.load(zeros + g_idx[:, None] * stride_zeros_g0 + (offs_n // 8)[None, :] * stride_zeros_n, mask=mask_k[:, None], other=0.0) + shift_n = (offs_n % 8)[None, :] * 4 + inv_zp = ((zp_val >> shift_n) & 0xF) * scale_val + + shift_k = (k_pos % 8)[:, None] * 4 + w_int = (b_chunk >> shift_k) & 0xF + w_fp = (w_int * scale_val - inv_zp) + + accumulator += tl.dot(a, w_fp) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cm = offs_cm < M + mask_cn = offs_cn < N + c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_cm[:, None] & mask_cn[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.is_contiguous(), "A must be contiguous" + Kx, N = qweight.shape + K = Kx * 8 + M = x.shape[0] + assert x.shape[1] == K, f"A second dim {x.shape[1]} must equal weight rows {K}" + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + else: + assert output.shape == (M, N), "output shape must be (M, N)" + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + max(META.get('SPLIT_K', 1), 1), + ) + num_groups = max(1, K // group_size) + second_dim = 1 if N <= 8 else (N + 7) // 8 + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +configs_dequant = [ + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), +] + +@triton.autotune( + configs=configs_dequant, + key=["K", "N"], + use_cuda_graph=False +) +@triton.jit +def dequantize_kernel( + qw_ptr, sc_ptr, zp_ptr, fpw_ptr, + K, N, group_size, + stride_qk, stride_qn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_fk, stride_fn, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + k_blk = tl.program_id(0) + n_blk = tl.program_id(1) + + offs_k = k_blk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = n_blk * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_k = offs_k[:, None] < K + mask_n = offs_n[None, :] < N + mask = mask_k & mask_n + + grp = offs_k[:, None] // group_size + + qw_offs = (offs_k[:, None] // 8) * stride_qk + offs_n[None, :] * stride_qn + qw_local = tl.load(qw_ptr + qw_offs, mask=mask, other=0) + + sc_offs = grp * stride_scg + offs_n[None, :] * stride_scn + sc_local = tl.load(sc_ptr + sc_offs, mask=mask, other=0.0) + + zp_offs = grp * stride_zpg + (offs_n // 8)[None, :] * stride_zpn + zp_quad = tl.load(zp_ptr + zp_offs, mask=mask, other=0) + + bits = 4 + shift_k = (offs_k % 8)[:, None] * bits + shift_n = (offs_n % 8)[None, :] * bits + + qh = (qw_local >> shift_k) & 0xF + qz = (zp_quad >> shift_n) & 0xF + + dq_val = (qh - qz) * sc_local + tl.store(fpw_ptr + offs_k[:, None] * stride_fk + offs_n[None, :] * stride_fn, dq_val, mask=mask) + +def quantize_int4(x: torch.Tensor, groupsize: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + W = x.t().contiguous() + K_raw, N = W.shape + assert K_raw % groupsize == 0, "K must be divisible by groupsize" + groups = K_raw // groupsize + W = W.view(groups, groupsize, N) + wmin = W.min(dim=1, keepdim=True)[0] + wmax = W.max(dim=1, keepdim=True)[0] + scale = (wmax - wmin) / 15 + zero = -wmin / scale + zero = torch.round(zero).clamp(0, 15) + + qweight = torch.clamp(torch.round(W / scale + zero), 0, 15) + qweight = qweight.to(torch.int8) + + packed = torch.zeros((groups * groupsize) // 8, N, dtype=torch.int32, device=x.device) + for col in range(N): + w_col = qweight[:, :, col].flatten() + for idx in range(0, w_col.size(0), 8): + v = w_col[idx:idx+8].view(torch.uint8) + v &= 0xF + v1 = v[::2] | (v[1::2] << 4) + packed[idx//8, col] = v1.view(torch.int32) + packed = packed.t() + qweight = packed + + scale = scale.squeeze(1).view(groups, N).t() + zero = zero.squeeze(1).view(groups, N).t() + + qzeros = torch.zeros_like(zero, dtype=torch.int32) + for col in range(N): + z_col = zero[:, col] + for idx in range(0, groups, 8): + seg = z_col[idx: idx+8] + tmp = 0 + for k, val in enumerate(seg): + tmp |= (val.int() & 0xF) << (4 * k) + qzeros[idx//8, col] = tmp + + return qweight.contiguous(), scale.contiguous(), qzeros.contiguous() + +def unpack_int4( + packed_weights: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + groupsize: int +) -> torch.Tensor: + K_pack, N = packed_weights.shape + K = K_pack * 8 + device = packed_weights.device + dtype = scales.dtype + + unpacked = torch.zeros((K, N), dtype=dtype, device=device) + for i in range(K_pack): + chunk = packed_weights[i, :] + for j in range(8): + bits = (chunk >> (j * 4)) & 0xF + unpacked[i * 8 + j, :] = bits.to(dtype) + + num_groups = K // groupsize + scale_ex = scales.t().repeat_interleave(groupsize, dim=0) + zeros_ex = zeros.t().repeat_interleave(groupsize, dim=0) + + zeros_unp = torch.zeros((K, N), dtype=dtype, device=device) + for i in range(num_groups): + z_chunk = zeros.t()[i, :] + for j in range(groupsize): + for n in range(N): + g_group = i * groupsize + j + zeros_unp[g_group, n] = z_chunk[n] + + result = (unpacked - zeros_unp) * scale_ex + return result.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stdout new file mode 100644 index 0000000..4c5a028 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_298484.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_298484 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py b/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py new file mode 100644 index 0000000..c278b48 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py @@ -0,0 +1,215 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m, M - first_pid_m * BLOCK_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N: + return + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + a_mask = offs_am[:, None] < M and offs_k[None, :] < K + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + offs_k_in_group = offs_k // GROUP_SIZE + scales = tl.load(scales_ptr + offs_bn[None, :] * stride_scales_n + offs_k_in_group[:, None] * stride_scales_g) + zeros = tl.load(zeros_ptr + offs_bn[None, :] * stride_zeros_n + offs_k_in_group[:, None] * stride_zeros_g) + + b = tl.load(b_ptrs, mask=offs_k[:, None] < K and offs_bn[None, :] < N, other=0.0) + b = b.to(tl.int32) + + b0 = (b & 0x0F) - 8 + b1 = ((b >> 4) & 0x0F) - 8 + + dequant_b0 = b0.to(tl.float32) * scales + zeros + dequant_b1 = b1.to(tl.float32) * scales + zeros + + b_reconstructed = tl.zeros((BLOCK_SIZE_K * 2, BLOCK_SIZE_N), dtype=tl.float32) + b_reconstructed = tl.where(tl.arange(0, BLOCK_SIZE_K * 2)[:, None] % 2 == 0, + dequant_b0[tl.arange(0, BLOCK_SIZE_K)[:, None], :], + dequant_b1[tl.arange(0, BLOCK_SIZE_K)[:, None], :]) + + valid_k = min(BLOCK_SIZE_K * 2, K - k * BLOCK_SIZE_K * 2) + a_inner = a[:, :valid_k] + b_inner = b_reconstructed[:valid_k, :] + + acc += tl.dot(a_inner, b_inner) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + offs_k += BLOCK_SIZE_K * SPLIT_K + + if SPLIT_K > 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = offs_cm[:, None] < M and offs_cn[None, :] < N + tl.atomic_add(c_ptrs, acc, mask=mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = offs_cm[:, None] < M and offs_cn[None, :] < N + tl.store(c_ptrs, acc, mask=mask) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + rows, cols = x.shape + assert cols % group_size == 0 + num_groups = cols // group_size + + x_groups = x.view(rows, num_groups, group_size) + x_min = x_groups.min(dim=2, keepdim=True)[0] + x_max = x_groups.max(dim=2, keepdim=True)[0] + scale = (x_max - x_min) / 15.0 + zero = -x_min / scale + + x_quantized = ((x_groups - x_min) / scale).round().clamp(0, 15).to(torch.int32) - 8 + x_quantized_uint = (x_quantized + 8).to(torch.uint8) + + packed = torch.zeros(rows, num_groups, group_size // 2, dtype=torch.int32, device=x.device) + for j in range(group_size // 2): + idx = j * 2 + packed[:, :, j] = ( + (x_quantized_uint[:, :, idx] & 0x0F) | + ((x_quantized_uint[:, :, idx + 1] & 0x0F) << 4) + ) + + scales = scale.squeeze(-1).contiguous() + zeros = zero.squeeze(-1).contiguous() + return packed.view(rows, num_groups * group_size // 2), scales, zeros + + +def matmul_dequantize_int4_s2( + a: torch.Tensor, b_q: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + group_size: int = 128, split_k: int = 1 +) -> torch.Tensor: + assert a.dim() == 2 + assert b_q.dim() == 2 + assert scales.dim() == 2 + assert zeros.dim() == 2 + assert a.shape[1] == b_q.shape[0] * 2, "Dimension mismatch between A and quantized B" + M, K = a.shape + N, _ = scales.shape + + c = torch.empty((M, N), dtype=torch.float32, device=a.device) + if split_k > 1: + c.zero_() + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + split_k, + ) + + matmul_kernel[grid]( + a, b_q, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + b_q.stride(0), b_q.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + GROUP_SIZE=group_size, + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32, + SPLIT_K=split_k, + ) + return c + + +def unpack_int4(b_q: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + assert b_q.dim() == 2 + assert scales.dim() == 2 + assert zeros.dim() == 2 + rows, cols_packed = b_q.shape + assert cols_packed * 2 == scales.shape[1] * group_size + + cols = cols_packed * 2 + b_unpacked = torch.zeros(rows, cols, dtype=torch.float32, device=b_q.device) + + scales_expanded = scales.repeat_interleave(group_size, dim=1) + zeros_expanded = zeros.repeat_interleave(group_size, dim=1) + + for j in range(cols_packed): + packed_col = b_q[:, j] + idx = j * 2 + b_unpacked[:, idx] = ((packed_col & 0x0F).to(torch.float32) - 8) * scales_expanded[:, idx] + zeros_expanded[:, idx] + if idx + 1 < cols: + b_unpacked[:, idx + 1] = (((packed_col >> 4) & 0x0F).to(torch.float32) - 8) * scales_expanded[:, idx + 1] + zeros_expanded[:, idx + 1] + + return b_unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stdout new file mode 100644 index 0000000..006f0d2 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_308542.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_308542 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py b/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py new file mode 100644 index 0000000..5d9a59e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# int4 de-quant helpers +@triton.jit +def _dequantize_int4_unpack(xi32, mask0=0x0f, mask1=0xf0): + xi0 = (xi32 & mask0).to(tl.int8) + xi1 = ((xi32 & mask1) >> 4).to(tl.int8) + return xi0, xi1 + + +@triton.jit +def _dequantize_int4_kernel(ptr, scales_ptr, zeros_ptr, M, N, + stride_q, stride_s, stride_z, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + q_offsets = (rm[:, None] * stride_q + (rn // 8)[None, :]) + scales_offsets = (rm[:, None] * stride_s + (rn // 8)[None, :]) + zeros_offsets = (rm[:, None] * stride_z + (rn // 8)[None, :]) + + mask_m = rm < M + mask_n = rn < N + mask = mask_m[:, None] & mask_n[None, :] + + packed = tl.load(ptr + q_offsets, mask=mask, other=0) + s = tl.load(scales_ptr + scales_offsets, mask=mask, other=1.0) + z = tl.load(zeros_ptr + zeros_offsets, mask=mask, other=0.0) + + offsets_0 = (rn % 8) * 4 + offsets_1 = offsets_0 + 4 + i0, i1 = _dequantize_int4_unpack(packed) + v0 = (i0.to(tl.float32) - z) * s + v1 = (i1.to(tl.float32) - z) * s + + return v0, v1 + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_eval_k, stride_eval_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_k = tl.program_id(2) + + n_blocks_m = tl.cdiv(M, BLOCK_SIZE_M) + n_blocks_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + group_id = 0 + first_pid_m = 0 + else: + group_id = pid_m // GROUP_SIZE_M + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(n_blocks_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_m % group_size_m) + + if SPLIT_K > 1: + local_k = tl.cdiv(K, SPLIT_K) + k_offset = pid_k * local_k + else: + local_k = K + k_offset = 0 + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + scales_ptrs = scales_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, local_k, BLOCK_SIZE_K): + if EVEN_K or (k + BLOCK_SIZE_K <= local_k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < local_k - k, other=0.0, eviction_policy="evict_last") + block_scale = tl.load(scales_ptrs, mask=offs_k[:, None] < local_k - k, other=1.0) + block_zero = tl.load(zeros_ptrs, mask=offs_k[:, None] < local_k - k, other=0.0) + + packed_b = tl.load(b_ptrs, mask=offs_k[:, None] < local_k - k, other=0) + k_idx = (offs_k[:, None] % 8) * 4 + val_low = (packed_b & 0x0F).to(tl.int8).to(tl.float32) + val_high = ((packed_b >> 4) & 0x0F).to(tl.int8).to(tl.float32) + b_low = (val_low - block_zero) * block_scale + b_high = (val_high - block_zero) * block_scale + + acc = tl.dot(a, b_low, acc) + a_shift = tl.load(a_ptrs + stride_bk * (1 if EVEN_K else 8), mask=offs_k[None, :] + 8 < local_k - k, other=0.0, eviction_policy="evict_last") + acc = tl.dot(a_shift, b_high, acc) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + scales_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + zeros_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + + if SPLIT_K == 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(c_ptrs.type.element_ty), mask=c_mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + pid_k * M * N + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=c_mask) + + +def matmul_dequantize_int4_s2(a, int4b_compressed, scales, zeros, M, N, K): + c_dtype = a.dtype + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + SPLIT_K = 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K) + + if SPLIT_K > 1: + c = torch.empty((SPLIT_K, M, N), dtype=torch.float32, device=a.device) + else: + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + EVEN_K = K % 32 == 0 + + matmul_kernel[grid](a, int4b_compressed, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + int4b_compressed.stride(0), int4b_compressed.stride(1), + c.stride(0) if c.dim() == 2 else c.stride(1), + c.stride(1) if c.dim() == 2 else c.stride(2), + scales.stride(0), scales.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + SPLIT_K=SPLIT_K, + EVEN_K=EVEN_K) + return c if SPLIT_K == 1 else c.sum(dim=0) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + org_shape = x.shape + x = x.view(-1, group_size) + x_fp32 = x.float() + x_min = x_fp32.amin(dim=-1, keepdim=True) + x_max = x_fp32.amax(dim=-1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale + 0.5).clamp(0, 15) + + int4 = (((x_fp32 - x_min) / scale + 0.5).clamp(0, 15)).round().to(torch.int8) + + int4 = int4.view(-1) + packed = torch.zeros(x.size(0), x.size(1) // 2, dtype=torch.int32, device=x.device) + int4_even = int4[::2].to(torch.int32) + int4_odd = int4[1::2].to(torch.int32) + packed = int4_even | (int4_odd << 4) + packed = packed.view(org_shape[0], org_shape[1] // 2) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + return packed, scale, zero + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + M, Nhalf = packed.shape + N = Nhalf * 2 + unpacked = torch.empty(M, N, dtype=torch.float16, device=packed.device) + + packed = packed.int() + for i in range(M): + for j in range(Nhalf): + low = (packed[i, j] & 0x0F).to(torch.float32) + high = ((packed[i, j] >> 4) & 0x0F).to(torch.float32) + group_idx = j * 2 // group_size + val_low = (low - zero[i, group_idx]) * scale[i, group_idx] + val_high = (high - zero[i, group_idx]) * scale[i, group_idx] + unpacked[i, 2 * j] = val_low.to(torch.float16) + unpacked[i, 2 * j + 1] = val_high.to(torch.float16) + return unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stdout new file mode 100644 index 0000000..a2a64f3 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_312025.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_312025 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py b/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py new file mode 100644 index 0000000..ed8a9dd --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py @@ -0,0 +1,226 @@ + +import torch +import triton +import triton.language as tl + +# ------------------------------------------------------------------ +# INT4 matmul kernel (tiling: M tile, N tile, K tile, SPLIT-K capable) +# ------------------------------------------------------------------ +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + x_ptr, qw_ptr, sc_ptr, zp_ptr, o_ptr, + M, N, K, + stride_xm, stride_xk, + stride_qwk, stride_qwn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_om, stride_on, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr = 1, +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K * SPLIT_K + tl.arange(0, BLOCK_SIZE_K * SPLIT_K) + + mask_k = offs_k < K + mask_m = offs_m < M + mask_n = offs_n < N + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in tl.range(0, K, BLOCK_SIZE_K * SPLIT_K): + current_offs_k = k + tl.arange(0, BLOCK_SIZE_K * SPLIT_K) + mask_kk = current_offs_k < K + + x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + + current_offs_k[None, :] * stride_xk) + x_blk = tl.load(x_ptrs, mask=mask_m[:, None] & mask_kk[None, :], other=0.0) + + qw_ptrs = qw_ptr + ((current_offs_k[None, :] // 8) * stride_qwk + + offs_n[:, None] * stride_qwn) + qw_blk = tl.load(qw_ptrs, mask=mask_kk[None, :] & mask_n[:, None], other=0) + + grp_idx = (current_offs_k // group_size) + sc_ptrs = sc_ptr + grp_idx * stride_scg + offs_n[None, :] * stride_scn + sc_blk = tl.load(sc_ptrs, mask=mask_n[None, :], other=0.0) + + zp_ptrs = zp_ptr + grp_idx * stride_zpg + (offs_n[None, :] // 8) * stride_zpn + zp_blk = tl.load(zp_ptrs, mask=mask_n[None, :], other=0.0) + + shifts = (current_offs_k % 8) * 4 + int4s = (qw_blk >> shifts[None, :]) & 0xF + zp_shifts = (offs_n[None, :] % 8) * 4 + zp_int4 = (zp_blk >> zp_shifts) & 0xF + fp_blk = (int4s.to(tl.float32) - zp_int4.to(tl.float32)) * sc_blk.to(tl.float32) + + acc += tl.dot(x_blk.to(tl.float16), fp_blk.to(tl.float16)).to(tl.float32) + + c_ptrs = o_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + mask_mn = mask_m[:, None] & mask_n[None, :] + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, acc, mask=mask_mn) + else: + tl.store(c_ptrs, acc.astype(tl.float16), mask=mask_mn) + + +# ------------------------------------------------------------------ +# Wrapper for tensor-packed int4 inference +# ------------------------------------------------------------------ +def matmul_dequantize_int4_s2( + x: torch.Tensor, + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + group_size: int = 128, +) -> torch.Tensor: + assert x.dim() == 2 + assert qweight.dim() == 2 + assert scale.dim() == 2 + assert zero_point.dim() == 2 + M, K = x.shape + K8, N = qweight.shape + assert K == K8 * 8 + assert group_size > 0 + x = x.contiguous() + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + + matmul_kernel[ + lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + 1, + ) + ]( + x, qweight, scale, zero_point, output, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + scale.stride(0), scale.stride(1), + zero_point.stride(0), zero_point.stride(1), + output.stride(0), output.stride(1), + group_size, + ) + return output + + +# ------------------------------------------------------------------ +# Quantization helpers +# ------------------------------------------------------------------ +def quantize_int4(x: torch.Tensor, group_size: int = 128): + """ + Returns 3 tensors: + packed_quant (int32), scale (float), zero_point (float) + """ + orig_shape = x.shape + x = x.view(-1, orig_shape[-1]) + K, N = x.shape + + if K % group_size: + K_pad = (K + group_size - 1) // group_size * group_size + x = torch.nn.functional.pad(x, (0, 0, 0, K_pad - K)) + else: + K_pad = K + + x = x.view(-1, group_size) + x_min = x.min(dim=-1, keepdim=True)[0] + x_max = x.max(dim=-1, keepdim=True)[0] + x_max = torch.max(x_max, x_min + 1e-7) + scale = (x_max - x_min) / 15.0 + zero_point = torch.round(-x_min / scale) + + q = torch.clamp(torch.round(x / scale + zero_point), 0, 15).to(torch.int32) + + q_f = q.view(K_pad, N) + int32_q = torch.zeros((K_pad // 8, N), dtype=torch.int32, device=x.device) + for shift in range(8): + int32_q |= q_f[shift::8, :] << (shift * 4) + + int32_q = int32_q.view(*orig_shape[:-1], N // 8) + scale = scale.view(orig_shape[0] // group_size, orig_shape[-1]) + zero_point = zero_point.view(orig_shape[0] // group_size, orig_shape[-1]) + + return int32_q, scale, zero_point + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, group_size: int = 128): + """ + Decompress back to fp32 for testing. + """ + K8, N = packed.shape + K = K8 * 8 + fp = torch.zeros((K, N), dtype=torch.float32, device=packed.device) + for shift in range(8): + fp[shift::8, :] = (packed >> (shift * 4)) & 0xF + fp = fp.view(-1, N) + scale_rs = scale.view(-1, N) + zp_rs = zero_point.view(-1, N) + fp = fp.to(torch.float32) + scale_rs = scale_rs.to(torch.float32) + zp_rs = zp_rs.to(torch.float32) + fp = (fp - zp_rs) * scale_rs + return fp.view(K, N) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stdout new file mode 100644 index 0000000..3742e80 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_357204.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_357204 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py b/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py new file mode 100644 index 0000000..b38c1f3 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py @@ -0,0 +1,177 @@ + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1 + }, num_stages=2, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + BLOCK_K_S = BLOCK_SIZE_K * SPLIT_K + offs_k = pid_k * BLOCK_K_S + tl.arange(0, BLOCK_K_S) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K_S)): + k_offs = k * BLOCK_K_S + offs_k[None, :] + a_mask = (offs_am[:, None] < M) & (k_offs < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N), other=0) + g_idx = ( offs_k[:, None] // group_size ) + scales = tl.load(scales_ptr + g_idx * stride_scales_g + offs_bn[None, :] * stride_scales_n) + zeros = tl.load(zeros_ptr + g_idx * stride_zeros_g + (offs_bn[None, :] // 8) * stride_zeros_n) + shift = (offs_k[:, None] % 8) * 4 + zp_shift = (offs_bn[None, :] % 8) * 4 + b_vals = (b >> shift) & 0xF + b_zp = (zeros >> zp_shift) & 0xF + b_fp = (b_vals - b_zp) * scales + acc += tl.dot(a.to(tl.float16), b_fp.to(tl.float16)) + a_ptrs += BLOCK_K_S * stride_ak + b_ptrs += (BLOCK_K_S // 8) * stride_bk + c = acc.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + + +def quantize_int4(w: torch.Tensor, group_size: int = 128): + assert w.dim() == 2 + K_, N = w.shape + assert K_ % group_size == 0 + K = K_ + w = w.view(-1, group_size, N) + wmin = torch.amin(w, dim=1, keepdim=True) + wmax = torch.amax(w, dim=1, keepdim=True) + scale = (wmax - wmin) / 15. + zero = (-wmin / scale).round().clamp(0, 15).to(torch.int32) + q = (w / scale + zero).round().clamp(0, 15).to(torch.int32) + q = q.to(torch.uint8) + packed = (q[::2, :, :] | (q[1::2, :, :] << 4)).view(-1, N) + scales = scale.squeeze(1).contiguous() + zeros = zero.squeeze(1).contiguous() + return packed, scales, zeros, None + + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, qweight: torch.IntTensor, + scales: torch.FloatTensor, qzeros: torch.IntTensor, + group_size: int = 128, output=None +) -> torch.FloatTensor: + assert x.is_contiguous() + M, K = x.shape + Kw, N = qweight.shape + assert K == Kw * 8 + if output is None: + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + else: + output.fill_(0.0) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + + +def unpack_int4(b_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + b_packed = b_packed.contiguous() + Kw, N = b_packed.shape + K = Kw * 8 + b_uint = torch.zeros((K, N), dtype=torch.uint8, device=b_packed.device) + b_uint[0::2, :] = b_packed & 0xF + b_uint[1::2, :] = (b_packed >> 4) & 0xF + group_idx = torch.arange(K, device=b_packed.device) // group_size + scl = scales[group_idx, :] + z = zeros[group_idx, :] # shape (K, N) + z = ((z.view(-1, 1, N) >> (4 * torch.arange(N//8, device=b_packed.device)[None, :, None])) & 0xF).view(z.shape[0], -1)[:, :N] + b_fp = (b_uint.to(torch.float32) - z) * scl + return b_fp + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stdout new file mode 100644 index 0000000..35b2353 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_365790.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_365790 due to diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py b/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py new file mode 100644 index 0000000..79b0377 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py @@ -0,0 +1,313 @@ + +import torch +import triton +import triton.language as tl + +configs_matmul = [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), +] + +@triton.autotune( + configs=configs_matmul, + key=["M", "N", "K"], + use_cuda_graph=False +) +@triton.jit +def matmul_kernel( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g0, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1, + GROUP_SIZE_M: tl.constexpr = 8 +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = B + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_pos = k * BLOCK_SIZE_K * SPLIT_K + offs_k + g_idx = (k_pos) // groupsize + + mask_k = k_pos < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + offset_b = (k_pos[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_chunk = tl.load(B + offset_b, mask=mask_k[:, None], other=0) + + scale_offset = g_idx[:, None] * stride_scales_g + offs_n[None, :] * stride_scales_n + scale_val = tl.load(scales + scale_offset, mask=mask_k[:, None], other=0.0) + + zp_val = tl.load(zeros + g_idx[:, None] * stride_zeros_g0 + (offs_n // 8)[None, :] * stride_zeros_n, mask=mask_k[:, None], other=0.0) + shift_n = (offs_n % 8)[None, :] * 4 + inv_zp = ((zp_val >> shift_n) & 0xF) * scale_val + + shift_k = (k_pos % 8)[:, None] * 4 + w_int = (b_chunk >> shift_k) & 0xF + w_fp = (w_int * scale_val - inv_zp) + + accumulator += tl.dot(a, w_fp) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cm = offs_cm < M + mask_cn = offs_cn < N + c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_cm[:, None] & mask_cn[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.is_contiguous(), "A must be contiguous" + assert qweight.is_contiguous(), "B must be contiguous" + M, K = x.shape + Kw, N = qweight.shape + K_expected = Kw * 8 + assert K == K_expected, f"Expected K = {K_expected}, got {K}" + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + else: + output.fill_(0) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + max(META.get('SPLIT_K', 1), 1), + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +configs_dequant = [ + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), +] + +@triton.autotune( + configs=configs_dequant, + key=["K", "N"], + use_cuda_graph=False +) +@triton.jit +def dequantize_kernel( + qw_ptr, sc_ptr, zp_ptr, fpw_ptr, + K, N, groupsize, + stride_qk, stride_qn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_fk, stride_fn, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + k_blk = tl.program_id(0) + n_blk = tl.program_id(1) + + offs_k = k_blk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = n_blk * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_k = offs_k[:, None] < K + mask_n = offs_n[None, :] < N + mask = mask_k & mask_n + + grp = offs_k[:, None] // groupsize + + qw_offs = (offs_k[:, None] // 8) * stride_qk + offs_n[None, :] * stride_qn + qw_local = tl.load(qw_ptr + qw_offs, mask=mask, other=0) + + sc_offs = grp * stride_scg + offs_n[None, :] * stride_scn + sc_local = tl.load(sc_ptr + sc_offs, mask=mask, other=0.0) + + zp_offs = grp * stride_zpg + (offs_n // 8)[None, :] * stride_zpn + zp_quad = tl.load(zp_ptr + zp_offs, mask=mask, other=0) + + shift_k = (offs_k % 8)[:, None] * 4 + shift_n = (offs_n % 8)[None, :] * 4 + + qh = (qw_local >> shift_k) & 0xF + qz = (zp_quad >> shift_n) & 0xF + + dq_val = (qh - qz) * sc_local + tl.store(fpw_ptr + offs_k[:, None] * stride_fk + offs_n[None, :] * stride_fn, dq_val, mask=mask) + +def dequantize_int4(b: torch.Tensor, b_scale: torch.Tensor, b_zero_point: torch.Tensor, device, dtype, groupsize): + K_pack, N = b.shape + K = K_pack * 8 + fp_b = torch.empty((K, N), device=device, dtype=dtype) + grid = lambda META: ( + triton.cdiv(K, META['BLOCK_SIZE_K']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + dequantize_kernel[grid]( + b, b_scale, b_zero_point, fp_b, + K, N, groupsize, + b.stride(0), b.stride(1), + b_scale.stride(0), b_scale.stride(1), + b_zero_point.stride(0), b_zero_point.stride(1), + fp_b.stride(0), fp_b.stride(1) + ) + return fp_b + +def matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, groupsize=128, out=None): + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + Kw, N = b.shape + if out is None: + out = torch.empty((M, N), device=a.device, dtype=a.dtype) + fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, groupsize) + torch.mm(a, fp_b, out=out) + fp_b = None + return out + +def quantize_int4(x: torch.Tensor, groupsize: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + W = x.t().contiguous() + K_raw, N = W.shape + assert K_raw % groupsize == 0, "K must be divisible by groupsize" + groups = K_raw // groupsize + W = W.view(groups, groupsize, N) + wmin = W.min(dim=1, keepdim=True)[0] + wmax = W.max(dim=1, keepdim=True)[0] + scale = (wmax - wmin) / 15 + zero = -wmin / scale + zero = torch.round(zero).clamp(0, 15) + + qweight_t = torch.clamp(torch.round(W / scale + zero), 0, 15).to(torch.int8) + + packed = torch.zeros((groups * groupsize) // 8, N, dtype=torch.int32, device=x.device) + for col in range(N): + w_col = qweight_t[:, :, col].flatten() + for idx in range(0, w_col.size(0), 8): + vals = w_col[idx:idx+8] + val = 0 + for v in vals: + val = (val << 4) | (v.int() & 0xF) + packed[idx//8, col] = val + qweight = packed.t().contiguous() + + scale = scale.squeeze(1).transpose(0, 1).contiguous() + zero = zero.squeeze(1).transpose(0, 1).contiguous() + + qzeros = torch.empty_like(zero, dtype=torch.int32) + for col in range(N): + for row in range(groups): + val = zero[row, col].int() & 0xF + qzeros[row, col] = val + qzeros = qzeros.transpose(0, 1).contiguous().view(groups, N // 8, 8).transpose(1,2) + zsplit = torch.zeros((groups, N//8), dtype=torch.int32, device=x.device) + idx = 0 + for row in range(groups): + for col in range(N//8): + acc = 0 + for k in range(8): + acc |= (zero[row, col*8 + k].int() & 0xF) << (k * 4) + zsplit[row, col] = acc + qzeros = zsplit.transpose(0,1).contiguous() + scale = scale.transpose(0,1).contiguous() + + return qweight, scale, qzeros + +def unpack_int4(packed_weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, groupsize: int) -> torch.Tensor: + K_pack, N = packed_weights.shape + K = K_pack * 8 + device = packed_weights.device + dtype = scales.dtype + + unpacked = torch.zeros((K, N), dtype=dtype, device=device) + for col in range(N): + chunk = packed_weights[:, col] + for i in range(K_pack): + val = chunk[i].int() + for j in range(8): + bits = (val >> (j * 4)) & 0xF + unpacked[i * 8 + j, col] = float(bits) + + num_groups = K // groupsize + scale_expanded = scales.t().repeat_interleave(groupsize, dim=0) + zero_expanded = zeros.t().repeat_interleave(groupsize, dim=0) + + result = (unpacked - zero_expanded) * scale_expanded + return result.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stdout new file mode 100644 index 0000000..414d183 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_41463.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_41463 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py new file mode 100644 index 0000000..25649ef --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py @@ -0,0 +1,252 @@ + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_sp_k = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_idx = k * BLOCK_SIZE_K * SPLIT_K + offs_k[None, :] + g_idx = k_idx // group_size + bs_ptrs = bs_ptr + g_idx * stride_bsk + offs_bn[None, :] * stride_bsn + bzp_ptrs = bzp_ptr + g_idx * stride_bzpk + (offs_bn[None, :] // 8) * stride_bzpn + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + bs = tl.load(bs_ptrs, mask=offs_bn[None, :] < N, other=0.0) + bzp = tl.load(bzp_ptrs, mask=offs_bn[None, :] < N, other=0) + b_shift = (offs_k[:, None] % 8) * 4 + z_shift = (offs_n[None, :] % 8) * 4 + b_q = (b >> b_shift) & 0xF + z_q = (bzp >> z_shift) & 0xF + b_deq = (b_q.to(tl.float32) - z_q.to(tl.float32)) * bs + accumulator += tl.dot(a, b_deq.to(a.dtype)) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + c = accumulator.to(c_ptr.dtype.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output: torch.FloatTensor = None +) -> torch.FloatTensor: + assert x.is_contiguous(), "A must be contiguous" + assert qweight.is_contiguous(), "qweight must be contiguous" + M, K = x.shape + Kq = qweight.shape[0] * 8 + N = qweight.shape[1] + assert K == Kq, "Leading dimension mismatch" + assert scales.shape[0] == (K + group_size - 1) // group_size + assert qzeros.shape[0] == (K + group_size - 1) // group_size + assert scales.shape[1] == N + assert qzeros.shape[1] == (N + 7) // 8 + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=8), + ], + key=['K', 'N'], +) +@triton.jit +def quantize_int4_kernel( + x_ptr, qweight_ptr, scales_ptr, zeros_packed_ptr, + K, N, + stride_xk, stride_xn, + stride_qw, stride_qwn, + stride_sc, stride_scn, + stride_zp, stride_zpn, + group_size, + BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + group_pid = tl.program_id(0) + sub_k = group_pid * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)[:, None] + tid_n = tl.program_id(1) + sub_n = tid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] + mask_k = sub_k < K + mask_n = sub_n < N + mask = mask_k & mask_n + x = tl.load(x_ptr + sub_k * stride_xk + sub_n * stride_xn, mask=mask, other=0.0) + g_idx = sub_k // group_size + x_min = tl.min(x, axis=0, keepdim=True) + x_max = tl.max(x, axis=0, keepdim=True) + scale = (x_max - x_min) / 15.0 + z = (-x_min / scale).to(tl.int32) + q = tl.clamp((x.to(tl.float32) / scale + z + 0.5).to(tl.int32), 0, 15) + q = q.to(tl.int32) + packed = tl.zeros([BLOCK_SIZE_K, BLOCK_SIZE_N // 8], dtype=tl.int32) + shifts = tl.arange(0, 8) * 4 + cols_bit = (sub_n % 8) * 4 + q = tl.reshape(q, [BLOCK_SIZE_K, BLOCK_SIZE_N]) + for i in range(0, 8): + col_i = (sub_n // 8) * 8 + i + val = tl.where((col_i < N), q[:, col_i], 0) + shifted = val << (i * 4) + packed |= shifted + for i in range(0, 8): + zp_col = (sub_n // 8) * 8 + i + shifted_zp = tl.where((zp_col < N), z[:, zp_col], 0) << (i * 4) + zeros_packed = tl.sum(shifted_zp, axis=1, keepdim=True).to(tl.int32) + zeros_ptrs = zeros_packed_ptr + g_idx * stride_zpn + (sub_n // 8) * stride_zp + tl.store(zeros_ptrs, zeros_packed, mask=mask_k) + qstor = qweight_ptr + (sub_k // 8) * stride_qw + (sub_n // 8) * stride_qwn + tl.store(qstor, packed, mask=mask_k) + sc_ptrs = scales_ptr + g_idx * stride_scn + (sub_n) * stride_sc + tl.store(sc_ptrs, scale, mask=mask_n) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + x = x.contiguous() + K, N = x.shape + qweight = torch.zeros((K // 8, N), dtype=torch.int32, device=x.device) + scales = torch.empty((K // group_size, N), dtype=torch.float32, device=x.device) + zeros = torch.empty((K // group_size, (N + 7) // 8), dtype=torch.int32, device=x.device) + grid = lambda META: ( + triton.cdiv(K, META['BLOCK_SIZE_K']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + quantize_int4_kernel[grid]( + x, qweight, scales, zeros, + K, N, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size + ) + return qweight, scales, zeros + + +def unpack_int4( + qweight: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.IntTensor, + group_size: int = 128 +) -> torch.FloatTensor: + Kq, N = qweight.shape + K = Kq * 8 + weight = torch.empty((K, N), dtype=torch.float32, device=qweight.device) + zeros_float = torch.empty_like(scales) + for g in range(scales.shape[0]): + for n in range(N): + z = (zeros[g, n // 8] >> ((n % 8) * 4)) & 0xF + zeros_float[g, n] = float(z) + for k in range(K): + for n in range(N): + q = (qweight[k // 8, n] >> ((k % 8) * 4)) & 0xF + g = k // group_size + weight[k, n] = (q - zeros_float[g, n]) * scales[g, n] + return weight.contiguous() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stdout new file mode 100644 index 0000000..989e538 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_430740.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_430740 due to at 20:12: + BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + group_pid = tl.program_id(0) + sub_k = group_pid * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)[:, None] + tid_n = tl.program_id(1) + sub_n = tid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] + mask_k = sub_k < K + mask_n = sub_n < N + mask = mask_k & mask_n + x = tl.load(x_ptr + sub_k * stride_xk + sub_n * stride_xn, mask=mask, other=0.0) + g_idx = sub_k // group_size + x_min = tl.min(x, axis=0, keepdim=True) + ^ +TypeError("min() got an unexpected keyword argument 'keepdim'") diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py b/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py new file mode 100644 index 0000000..5d9a59e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# int4 de-quant helpers +@triton.jit +def _dequantize_int4_unpack(xi32, mask0=0x0f, mask1=0xf0): + xi0 = (xi32 & mask0).to(tl.int8) + xi1 = ((xi32 & mask1) >> 4).to(tl.int8) + return xi0, xi1 + + +@triton.jit +def _dequantize_int4_kernel(ptr, scales_ptr, zeros_ptr, M, N, + stride_q, stride_s, stride_z, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + q_offsets = (rm[:, None] * stride_q + (rn // 8)[None, :]) + scales_offsets = (rm[:, None] * stride_s + (rn // 8)[None, :]) + zeros_offsets = (rm[:, None] * stride_z + (rn // 8)[None, :]) + + mask_m = rm < M + mask_n = rn < N + mask = mask_m[:, None] & mask_n[None, :] + + packed = tl.load(ptr + q_offsets, mask=mask, other=0) + s = tl.load(scales_ptr + scales_offsets, mask=mask, other=1.0) + z = tl.load(zeros_ptr + zeros_offsets, mask=mask, other=0.0) + + offsets_0 = (rn % 8) * 4 + offsets_1 = offsets_0 + 4 + i0, i1 = _dequantize_int4_unpack(packed) + v0 = (i0.to(tl.float32) - z) * s + v1 = (i1.to(tl.float32) - z) * s + + return v0, v1 + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_eval_k, stride_eval_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_k = tl.program_id(2) + + n_blocks_m = tl.cdiv(M, BLOCK_SIZE_M) + n_blocks_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + group_id = 0 + first_pid_m = 0 + else: + group_id = pid_m // GROUP_SIZE_M + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(n_blocks_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_m % group_size_m) + + if SPLIT_K > 1: + local_k = tl.cdiv(K, SPLIT_K) + k_offset = pid_k * local_k + else: + local_k = K + k_offset = 0 + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + scales_ptrs = scales_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, local_k, BLOCK_SIZE_K): + if EVEN_K or (k + BLOCK_SIZE_K <= local_k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < local_k - k, other=0.0, eviction_policy="evict_last") + block_scale = tl.load(scales_ptrs, mask=offs_k[:, None] < local_k - k, other=1.0) + block_zero = tl.load(zeros_ptrs, mask=offs_k[:, None] < local_k - k, other=0.0) + + packed_b = tl.load(b_ptrs, mask=offs_k[:, None] < local_k - k, other=0) + k_idx = (offs_k[:, None] % 8) * 4 + val_low = (packed_b & 0x0F).to(tl.int8).to(tl.float32) + val_high = ((packed_b >> 4) & 0x0F).to(tl.int8).to(tl.float32) + b_low = (val_low - block_zero) * block_scale + b_high = (val_high - block_zero) * block_scale + + acc = tl.dot(a, b_low, acc) + a_shift = tl.load(a_ptrs + stride_bk * (1 if EVEN_K else 8), mask=offs_k[None, :] + 8 < local_k - k, other=0.0, eviction_policy="evict_last") + acc = tl.dot(a_shift, b_high, acc) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + scales_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + zeros_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + + if SPLIT_K == 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(c_ptrs.type.element_ty), mask=c_mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + pid_k * M * N + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=c_mask) + + +def matmul_dequantize_int4_s2(a, int4b_compressed, scales, zeros, M, N, K): + c_dtype = a.dtype + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + SPLIT_K = 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K) + + if SPLIT_K > 1: + c = torch.empty((SPLIT_K, M, N), dtype=torch.float32, device=a.device) + else: + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + EVEN_K = K % 32 == 0 + + matmul_kernel[grid](a, int4b_compressed, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + int4b_compressed.stride(0), int4b_compressed.stride(1), + c.stride(0) if c.dim() == 2 else c.stride(1), + c.stride(1) if c.dim() == 2 else c.stride(2), + scales.stride(0), scales.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + SPLIT_K=SPLIT_K, + EVEN_K=EVEN_K) + return c if SPLIT_K == 1 else c.sum(dim=0) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + org_shape = x.shape + x = x.view(-1, group_size) + x_fp32 = x.float() + x_min = x_fp32.amin(dim=-1, keepdim=True) + x_max = x_fp32.amax(dim=-1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale + 0.5).clamp(0, 15) + + int4 = (((x_fp32 - x_min) / scale + 0.5).clamp(0, 15)).round().to(torch.int8) + + int4 = int4.view(-1) + packed = torch.zeros(x.size(0), x.size(1) // 2, dtype=torch.int32, device=x.device) + int4_even = int4[::2].to(torch.int32) + int4_odd = int4[1::2].to(torch.int32) + packed = int4_even | (int4_odd << 4) + packed = packed.view(org_shape[0], org_shape[1] // 2) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + return packed, scale, zero + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + M, Nhalf = packed.shape + N = Nhalf * 2 + unpacked = torch.empty(M, N, dtype=torch.float16, device=packed.device) + + packed = packed.int() + for i in range(M): + for j in range(Nhalf): + low = (packed[i, j] & 0x0F).to(torch.float32) + high = ((packed[i, j] >> 4) & 0x0F).to(torch.float32) + group_idx = j * 2 // group_size + val_low = (low - zero[i, group_idx]) * scale[i, group_idx] + val_high = (high - zero[i, group_idx]) * scale[i, group_idx] + unpacked[i, 2 * j] = val_low.to(torch.float16) + unpacked[i, 2 * j + 1] = val_high.to(torch.float16) + return unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stdout new file mode 100644 index 0000000..959cf8e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_434177.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_434177 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py b/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py new file mode 100644 index 0000000..c39d1ca --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py @@ -0,0 +1,241 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_offset = k * BLOCK_SIZE_K * SPLIT_K + a_idx = offs_k[None, :] + k_offset + b_idx = offs_k[:, None] + k_offset + mask_a = (offs_m[:, None] < M) & (a_idx < K) + mask_b = (b_idx < K) & (offs_n[None, :] < N) + + a = tl.load(a_ptrs + k_offset * stride_ak, mask=mask_a, other=0.0) + b = tl.load(b_ptrs + k_offset * stride_bk, mask=mask_b, other=0.0) + + accumulator += tl.dot(a, b) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, accumulator, mask=mask_c) + else: + tl.store(c_ptrs, accumulator.to(tl.float16), mask=mask_c) + +def matmul_dequantize_int4_s2( + x: torch.Tensor, qw_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor +) -> torch.Tensor: + M, K = x.shape + N = qw_packed.shape[0] * 8 // 4 + y = torch.empty((M, N), dtype=x.dtype, device=x.device) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + 1, + ) + + matmul_kernel[grid]( + x, qw_packed, y, + M, N, K, + x.stride(0), x.stride(1), + qw_packed.stride(0), 4, + y.stride(0), y.stride(1), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + SPLIT_K=1, + ) + return y + +@triton.jit +def quantize_int4_kernel( + src_ptr, dst_ptr, scales_ptr, zeros_ptr, + num_rows, num_cols, + stride_sr, stride_sc, + stride_dr, stride_dc, + stride_scale_r, + BLOCK_SIZE: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + row = tl.program_id(0) + col_start = tl.program_id(1) * GROUP_SIZE + offs = col_start + tl.arange(0, BLOCK_SIZE) + + mask = offs < num_cols + src_ptrs = src_ptr + row * stride_sr + offs * stride_sc + src = tl.load(src_ptrs, mask=mask, other=0.0).to(tl.float32) + + min_val = tl.min(src) + max_val = tl.max(src) + scale = (max_val - min_val) / ((2 ** 4) - 1) + zero = -min_val / scale + scale_store = scale.to(tl.float16) + zero_store = zero.to(tl.float16) + + grouped = (src - min_val) / scale + int4 = tl.cast(grouped + 0.5, tl.int32) + packed = (int4 & 0xF) | (tl.shl(int4, 4) & 0xF) + packed = tl.view(packed, tl.int32) + + scale_zero_idx = row + (col_start // GROUP_SIZE) * stride_scale_r + scales_ptrs = scales_ptr + scale_zero_idx + zeros_ptrs = zeros_ptr + scale_zero_idx + + tl.store(scales_ptrs, scale_store) + tl.store(zeros_ptrs, zero_store) + + if col_start < num_cols: + src_ptrs = src_ptr + row * stride_sr + col_start * stride_sc + for j in range(0, tl.cdiv(GROUP_SIZE, BLOCK_SIZE)): + offset = j * BLOCK_SIZE + mask = (col_start + offset + tl.arange(0, BLOCK_SIZE)) < num_cols + src = tl.load(src_ptrs + offset * stride_sc, mask=mask, other=0.0).to(tl.float32) + rescaled = (src - min_val) / scale + int4 = tl.cast(rescaled + 0.5, tl.int32) + packed = tl.zeros([BLOCK_SIZE // 8], dtype=tl.int32) + for k in range(0, BLOCK_SIZE // 8): + idx = k * 8 + tl.arange(0, 8) + packed[k] = ( + (int4[idx] & 0xF) | + tl.shl((int4[idx + 1] & 0xF), 4) | + tl.shl((int4[idx + 2] & 0xF), 8) | + tl.shl((int4[idx + 3] & 0xF), 12) | + tl.shl((int4[idx + 4] & 0xF), 16) | + tl.shl((int4[idx + 5] & 0xF), 20) | + tl.shl((int4[idx + 6] & 0xF), 24) | + tl.shl((int4[idx + 7] & 0xF), 28) + ) + dst_ptrs = dst_ptr + row * stride_dr + (offset // 8) * stride_dc + write_mask = (col_start + offset) < num_cols + tl.store(dst_ptrs, packed, mask=write_mask) + +def quantize_int4(weight: torch.Tensor, group_size: int = 128) -> tuple: + assert weight.dim() == 2 + num_rows, num_cols = weight.shape + group_size = min(group_size, num_cols) + assert num_cols % group_size == 0 + num_groups = num_cols // group_size + + qw_packed = torch.empty( + (num_rows, num_cols // 8), + dtype=torch.int32, + device=weight.device + ) + scales = torch.empty( + (num_rows, num_groups), + dtype=torch.float16, + device=weight.device + ) + zeros = torch.empty_like(scales) + + def grid(): + return ( + num_rows, + num_groups, + ) + + quantize_int4_kernel[grid]( + weight, qw_packed, scales, zeros, + num_rows, num_cols, + weight.stride(0), weight.stride(1), + qw_packed.stride(0), qw_packed.stride(1), + scales.stride(0), + BLOCK_SIZE=32, + GROUP_SIZE=group_size, + ) + return qw_packed, scales, zeros + +def unpack_int4(qw_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + num_rows = qw_packed.size(0) + num_cols = qw_packed.size(1) * 8 + weight = torch.empty((num_rows, num_cols), dtype=torch.float16, device=qw_packed.device) + for row in range(num_rows): + group_idx = 0 + for col in range(0, num_cols, 8): + packed = qw_packed[row, col // 8] + scale = scales[row, group_idx] + zero = zeros[row, group_idx] + if (col + 8) % group_size == 0: + group_idx += 1 + for i in range(8): + val = (packed >> (4 * i)) & 0xF + rescaled = val * scale + zero + weight[row, col + i] = rescaled + return weight + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stdout new file mode 100644 index 0000000..3f079f3 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_461728.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_461728 due to module 'triton.language' has no attribute 'shl' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py b/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py new file mode 100644 index 0000000..fc2f471 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py @@ -0,0 +1,190 @@ + +import torch +import triton +import triton.language as tl + +### ------------------- Triton kernel for INT4 matmul – autotuned ------------------- ### +@triton.autotune( + configs=[ + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + k_per_iter = BLOCK_SIZE_K * SPLIT_K + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k0 = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_m = offs_m < M + mask_n = offs_n < N + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k0[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k0[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, k_per_iter)): + idx_k = k * k_per_iter + pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = idx_k < K + a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + + b_offs = (idx_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_int = tl.load(b_ptr + b_offs, mask=mask_k[:, None] & mask_n[None, :], other=0) + + g = idx_k[:, None] // group_size + bs_offs = g * stride_bsk + offs_n[None, :] * stride_bsn + bzp_offs = g * stride_bzpk + (offs_n[None, :]//8) * stride_bzpn + bs = tl.load(bs_ptr + bs_offs, mask=g*0 == 0, other=1.0) + bzp = tl.load(bzp_ptr + bzp_offs, mask=g*0 == 0, other=0) + + shift_k = (idx_k[:, None] % 8) * 4 + shift_n = (offs_n[None, :] % 8) * 4 + b_val = ((b_int >> shift_k) & 0xF) - ((bzp >> shift_n) & 0xF) + b_fp = (b_val * bs).to(a.dtype) + + acc += tl.dot(a, b_fp) + a_ptrs += k_per_iter * stride_ak + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] + mask_c = mask_m[:, None] & mask_n[None, :] + c_ptrs = c_ptr + offs_cm * stride_cm + offs_cn * stride_cn + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask_c) + else: + tl.atomic_add(c_ptrs, acc, mask=mask_c) + +### ------------------- Launcher ------------------- ### +def matmul_dequantize_int4_s2( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int = 128 +) -> torch.Tensor: + assert x.is_contiguous() + M, _ = x.shape + N = scales.shape[1] + out = torch.empty((M, N), dtype=x.dtype, device=x.device) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, out, + scales, qzeros, + M, N, x.shape[1], + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + out.stride(0), out.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return out + +### ------------------- quantize / unpack ------------------- ### +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]: + w = weights.to(torch.float32) + K, N = w.shape + assert N % group_size == 0, "N must be divisible by group_size" + w = w.view(-1, group_size) + wmin = w.min(dim=1, keepdim=True)[0] + wmax = w.max(dim=1, keepdim=True)[0] + scales = (wmax - wmin) / 15.0 + zeros = (-wmin / scales).round().clamp(0, 15) + + w = torch.round((w - wmin) / scales).clamp(0, 15).to(torch.uint8) + packed = torch.zeros(w.shape[0], w.shape[1] // 2, dtype=torch.int32, device=w.device) + packed |= (w[:, 0::2] & 0xF) + packed |= (w[:, 1::2] & 0xF) << 4 + packed = packed.view(K, N // 8) + scales = scales.view(K, N // group_size).contiguous() + zeros = zeros.view(K, N // group_size).contiguous() + return packed.contiguous(), scales, zeros, None + +def unpack_int4(weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + Bits = 4 + Kq, N = weights.shape + K = Kq * Bits + unpacked = torch.empty(K, N, dtype=scales.dtype, device=weights.device) + for b in range(Bits): + int4 = (weights >> b) & 0x1 + row_off = b + unpacked[row_off::Bits, :] = int4 * scales.view(-1, N).expand_as(int4) + zeros.view(-1, N).expand_as(int4) + return unpacked.view(weights.shape[0] * 8, N) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stdout new file mode 100644 index 0000000..4b8c3c4 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_48845.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_48845 due to shape '[4096, 512]' is invalid for input of size 8388608 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py b/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py new file mode 100644 index 0000000..b61545c --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py @@ -0,0 +1,198 @@ + +import torch +import triton +import triton.language as tl + +# -------------------------------------------------- +# Triton kernel +# -------------------------------------------------- +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + x_ptr, qw_ptr, sc_ptr, zp_ptr, c_ptr, + M, N, K, + stride_xm, stride_xk, + stride_qwk, stride_qwn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_cm, stride_cn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_sp_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = pid_sp_k * BLOCK_SIZE_K * SPLIT_K + tl.arange(0, BLOCK_SIZE_K * SPLIT_K) + + mask_m = offs_m < M + mask_n = offs_n < N + mask_k = offs_k < K + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k0 in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + current_offs_k = k0 * BLOCK_SIZE_K * SPLIT_K + offs_k + mask_kk = current_offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + current_offs_k[None, :] * stride_xk + x_blk = tl.load(x_ptrs, mask=mask_m[:, None] & mask_kk[None, :], other=0.0) + + qw_ptrs = qw_ptr + (current_offs_k[:, None] // 8) * stride_qwk + offs_n[None, :] * stride_qwn + qw_blk = tl.load(qw_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0) + + # scale & zp indices + g_idx = (current_offs_k // group_size) + sc_ptrs = sc_ptr + g_idx[:, None] * stride_scg + offs_n[None, :] * stride_scn + zp_ptrs = zp_ptr + g_idx[:, None] * stride_zpg + (offs_n[None, :] // 8) * stride_zpn + + sc = tl.load(sc_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0.0).to(tl.float32) + zp = tl.load(zp_ptrs, mask=mask_kk[:, None] & mask_n[None, :], other=0) + + shifts = (current_offs_k % 8) * 4 + int4_w = (qw_blk >> shifts[:, None]) & 0xF + zp_shifts = (offs_n[None, :] % 8) * 4 + int4_zp = (zp >> zp_shifts) & 0xF + deq_w = ((int4_w.float() - int4_zp.float()) * sc).to(tl.float16) + + acc += tl.dot(x_blk.to(tl.float16), deq_w).to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_out = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] + + if SPLIT_K > 1: + tl.atomic_add(out_ptrs, acc.astype(tl.float16), mask=mask_out) + else: + tl.store(out_ptrs, acc.astype(tl.float16), mask=mask_out) + +# -------------------------------------------------- +# Wrapper +# -------------------------------------------------- +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, + scale: torch.Tensor, zero_point: torch.Tensor, + group_size: int = 128) -> torch.Tensor: + assert x.dim() == 2 + assert qweight.dim() == 2 + assert scale.dim() == 2 + assert zero_point.dim() == 2 + M, K = x.shape + K8, N = qweight.shape + assert K == K8 * 8 + x = x.contiguous() + output = torch.empty((M, N), dtype=torch.float16, device=x.device) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, scale, zero_point, output, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + scale.stride(0), scale.stride(1), + zero_point.stride(0), zero_point.stride(1), + output.stride(0), output.stride(1), + group_size, + ) + return output + +# -------------------------------------------------- +# Quantization helpers +# -------------------------------------------------- +def quantize_int4(x: torch.Tensor, group_size: int = 128): + orig_shape = x.shape + x = x.view(-1, orig_shape[-1]) + K, N = x.shape + if K % group_size: + pad_k = (K + group_size - 1) // group_size * group_size + x = torch.nn.functional.pad(x, (0, 0, 0, pad_k - K)) + K = pad_k + x = x.view(-1, group_size, N) + x_min = x.amin(dim=1, keepdim=True) + x_max = x.amax(dim=1, keepdim=True) + denom = x_max - x_min + denom[denom.abs() < 1e-12] = 1.0 + sc = (denom) / 15.0 + zp = torch.round(-x_min / sc) + q = torch.clamp(torch.round(x / sc + zp), 0, 15).to(torch.int32) + qf = q.view(K, N) + packed = torch.zeros(K // 8, N, dtype=torch.int32, device=x.device) + for shift in range(8): + packed |= qf[shift::8, :] << (shift * 4) + sc = sc.view(K // group_size, N) + zp = zp.view(K // group_size, N) + return packed, sc, zp + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, + group_size: int = 128): + K8, N = packed.shape + K = K8 * 8 + w = torch.empty(K, N, dtype=torch.float32, device=packed.device) + for shift in range(8): + w[shift::8, :] = ((packed >> (shift * 4)) & 0xF).float() + scale1 = scale.view(-1, N) + zp1 = zero_point.view(-1, N) + return ((w.view(-1, group_size, N) - zp1.unsqueeze(1)) * scale1.unsqueeze(1)).view(K, N) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stdout new file mode 100644 index 0000000..5597bb1 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_490790.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_490790 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py b/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py new file mode 100644 index 0000000..6efc026 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py @@ -0,0 +1,203 @@ + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_start = pid_z * (BLOCK_SIZE_K * SPLIT_K) + tl.arange(0, BLOCK_SIZE_K * SPLIT_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k_start[None, :] * stride_ak + b_ptrs = b_ptr + ((offs_k_start[:, None] // 8) * stride_bk) + offs_n[None, :] * stride_bn + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + a = tl.load(a_ptrs, mask=(offs_k_start[None, :] < K), other=0.0) + b_i32 = tl.load(b_ptrs, mask=(offs_k_start[:, None] < K), other=0) + + n_idx = offs_n[None, :] + k_idx = offs_k_start[:, None] + mask_valid = (k_idx < K) + + group_id_k = k_idx // group_size + scales = tl.load(bs_ptr + group_id_k * stride_bsk + n_idx * stride_bsn, mask=mask_valid, other=0.0) + zeros = tl.load(bzp_ptr + group_id_k * stride_bzpk + (n_idx // 8) * stride_bzpn, mask=mask_valid, other=0) + + b_shift = ((k_idx % 8) * 4) + zp_shift = ((n_idx % 8) * 4) + + b_i4 = (b_i32 >> b_shift) & 0xF + zp_i4 = (zeros >> zp_shift) & 0xF + b_float = (b_i4 - zp_i4).to(tl.float32) * scales.to(tl.float32) + + accumulator += tl.dot(a.to(tl.float32), b_float.to(tl.float32)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.ndim == 2 and qweight.ndim == 2 + assert x.shape[-1] == (qweight.shape[0] * 8) + assert x.is_contiguous() + + M, K = x.shape + N = scales.shape[1] + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + + def grid_fn(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) + + matmul_kernel[grid_fn]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + + +def quantize_int4(w: torch.Tensor, group_size: int = 128): + assert w.dim() == 2 + w = w.float() + oc, ic = w.shape + assert ic % group_size == 0 + w = w.reshape(oc, ic // group_size, group_size) + + wmax = w.amax(dim=2, keepdim=True) + wmin = w.amin(dim=2, keepdim=True) + scale = (wmax - wmin) / 15.0 + zero = (-wmin / scale).round().clamp(0, 15).to(torch.int8) + + int_w = ((w - wmin) / scale).round().clamp(0, 15).to(torch.int8) + + int_w_reshaped = int_w.view(oc, ic) + zero_reshaped = zero.view(oc, -1) + + col_bytes = torch.empty(oc, ic // 2, dtype=torch.int8, device=w.device) + for j in range(0, ic, 2): + lo = int_w_reshaped[:, j] + hi = int_w_reshaped[:, j + 1] + packed = (hi << 4) | lo + col_bytes[:, j // 2] = packed.to(torch.int8) + + out = col_bytes.view(oc, ic // 8).view(torch.int32) + return out, scale.squeeze(-1).half(), zero_reshaped + + +def unpack_int4(w_packed: torch.IntTensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + oc, ic_bytes = w_packed.shape + ic = ic_bytes * 8 + assert ic % group_size == 0 + + w_int = torch.empty(oc, ic, dtype=torch.int8, device=w_packed.device) + packed = w_packed.view(torch.int8).view(oc, ic // 8) + for j in range(ic // 8): + b = packed[:, j] + for k in range(8): + val = (b >> (k * 4)) & 0xF + w_int[:, j * 8 + k] = val + + num_groups = ic // group_size + scale = scale.view(oc, num_groups, 1).expand(-1, -1, group_size).reshape(oc, ic) + zero = zero.view(oc, num_groups, 1).expand(-1, -1, group_size).reshape(oc, ic) + return (w_int.float() - zero.float()) * scale.float() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stdout new file mode 100644 index 0000000..c5e31e1 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_511041.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_511041 due to shape '[4096, 512]' is invalid for input of size 8388608 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py b/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py new file mode 100644 index 0000000..2e27b87 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py @@ -0,0 +1,199 @@ + +import torch +import triton +import triton.language as tl + +# ------------------------------------------------------------------------- +# Triton kernel – core INT4 matmul +# ------------------------------------------------------------------------- +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=2), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_offs = k * BLOCK_SIZE_K * SPLIT_K + offs_k + ks = bs_ptr + (k_offs // group_size) * stride_bsk + offs_n[None, :] * stride_bsn + kzp = bzp_ptr + (k_offs // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn + a = tl.load(a_ptrs, mask=k_offs[None, :] < K, other=0.0) + b = tl.load(b_ptrs, mask=k_offs[:, None] < K, other=0) + scale = tl.load(ks) + zero = tl.load(kzp) + b_shift = (k_offs[:, None] % 8) * 4 + z_shift = (offs_n[None, :] % 8) * 4 + b_deq = (((b >> b_shift) & 0xF).to(tl.float32) - ((zero >> z_shift) & 0xF).to(tl.float32)) * scale + accumulator += tl.dot(a.to(tl.float16), b_deq.to(tl.float16)) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K // 8 * stride_bk + c = accumulator.to(tl.float16) + + if SPLIT_K > 1: + offs_cm = offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.atomic_add(c_ptr + offs_cm, c, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)) + else: + offs_cm = offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptr + offs_cm, c, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)) + +# ------------------------------------------------------------------------- +# Python wrapper +# ------------------------------------------------------------------------- +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output: torch.FloatTensor = None +) -> torch.FloatTensor: + assert x.is_contiguous(), "input must be contiguous" + M, K = x.shape + N = scales.shape[1] + if output is None: + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(1), qweight.stride(0), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + +# ------------------------------------------------------------------------- +# Quantization / De-quantization helpers +# ------------------------------------------------------------------------- +def quantize_int4(x: torch.Tensor, group_size: int = 128): + """ + Converts fp16/fp32 weight tensor of shape (K, N) into INT4 representation. + Returns (packed_int32, scales, zeros) all on same device/dtype + layout expected by the kernel. + """ + x = x.t().contiguous() # -> (N, K) + N, K = x.shape + assert K % group_size == 0, f"K ({K}) not divisible by group_size {group_size}" + + x = x.view(N, K // group_size, group_size).float() + x_min = x.min(dim=2, keepdim=True)[0] + x_max = x.max(dim=2, keepdim=True)[0] + + scales = (x_max - x_min) / 15.0 + zp_fp = (-x_min / scales).round().clamp(0, 15) + x_q = (x / scales + zp_fp).round().clamp(0, 15) + scales = scales.squeeze(2).t().contiguous() + zeros = zp_fp.squeeze(2).t().contiguous() + + # Flatten to (N, K) before packing + x_q = x_q.view(N, K) + packed = torch.zeros((N, K // 8), dtype=torch.int32, device=x.device) + for i in range(8): + packed |= ((x_q[:, i::8].to(torch.int32) & 0xF) << (4 * i)) + packed = packed.t().contiguous() + return packed, scales, zeros + + +def unpack_int4(qweight: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.FloatTensor, + group_size: int = 128) -> torch.FloatTensor: + """ + Unpack INT4 tensor back into float32/float16 tensor for correctness test. + """ + qweight, scales, zeros = qweight.t(), scales.t(), zeros.t() + N, K_w = qweight.shape + K = K_w * 8 + weight = torch.zeros((N, K), dtype=torch.float32, device=qweight.device) + + for i in range(8): + mask = 0xF << (i * 4) + cols = torch.arange(i, K, 8, device=qweight.device) + scale_col = scales[:, cols // group_size] + zero_col = zeros[:, cols // group_size] + vals = ((qweight & mask) >> (i * 4)).to(torch.float32) + weight[:, cols] = (vals - zero_col) * scale_col + return weight.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stdout new file mode 100644 index 0000000..36d63e1 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_512013.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_512013 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py b/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py new file mode 100644 index 0000000..ec6cfdc --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py @@ -0,0 +1,180 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_in_group = GROUP_SIZE_M * grid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(grid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 2) * stride_bk + offs_n[None, :] * stride_bn) + + scales_ptrs = scales_ptr + (offs_k * stride_scales) + zeros_ptrs = zeros_ptr + (offs_k * stride_zeros) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_idx = k * BLOCK_SIZE_K * SPLIT_K + offs_k + mask_k = k_idx[None, :] < K + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + + b_idx = k_idx[:, None] // 2 + b_raw = tl.load(b_ptrs, mask=b_idx < (K * N) // 8, other=0) + + scales = tl.load(scales_ptrs, mask=k_idx < K, other=1.0) + zeros = tl.load(zeros_ptrs, mask=k_idx < K, other=0.0) + + b_dequant = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(0, BLOCK_SIZE_K): + sub_i = i // 2 + shift = (i % 2) * 4 + mask = tl.full((BLOCK_SIZE_N,), 0x0F, dtype=tl.int32) + val = (b_raw[sub_i, :] >> shift) & mask + val_f = val.to(tl.float32) + dequant = val_f * scales[i] + zeros[i] + b_dequant = tl.store(b_dequant, dequant, mask=i < BLOCK_SIZE_K) + + accumulator += tl.dot(a, b_dequant) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 2) * stride_bk + scales_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_scales + zeros_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_zeros + + if SPLIT_K > 1: + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, accumulator, mask=mask) + else: + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=mask) + +def matmul_dequantize_int4_s2(a: torch.Tensor, b_quant: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, K: int) -> torch.Tensor: + M, _ = a.shape + _, N = b_quant.shape + c = torch.empty((M, N), dtype=a.dtype, device=a.device) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 1) + matmul_kernel[grid]( + a, b_quant, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + b_quant.stride(0), b_quant.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), zeros.stride(0), + ) + return c + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + w_f = weights.to(torch.float32) + shape = w_f.shape + w_f = w_f.reshape(-1, group_size) + w_min = w_f.min(dim=1, keepdim=True)[0] + w_max = w_f.max(dim=1, keepdim=True)[0] + scales = (w_max - w_min) / 15.0 + zeros = -w_min / scales + w_int4 = torch.round((w_f - w_min) / scales).clamp(0, 15).to(torch.uint8) + w_packed = torch.zeros(w_int4.shape[0], w_int4.shape[1] // 2, dtype=torch.int32, device=weights.device) + for i in range(0, w_int4.shape[1], 2): + val0 = w_int4[:, i].to(torch.int32) + val1 = w_int4[:, i+1].to(torch.int32) if (i+1) < w_int4.shape[1] else 0 + w_packed[:, i//2] = (val0 & 0x0F) | ((val1 & 0x0F) << 4) + w_packed = w_packed.reshape(shape[0], shape[1] // 2) + scales = scales.reshape(shape[0], shape[1] // group_size) + zeros = zeros.reshape(shape[0], shape[1] // group_size) + return w_packed, scales, zeros + +def unpack_int4(weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + shape = weights.shape + w_packed = weights.reshape(-1, shape[-1]) + scales = scales.reshape(-1, scales.shape[-1]) + zeros = zeros.reshape(-1, zeros.shape[-1]) + w_unpacked = torch.zeros(w_packed.shape[0], w_packed.shape[1] * 2, dtype=torch.float32, device=weights.device) + for i in range(w_packed.shape[0]): + group_idx = i // (shape[-1] * 2 // group_size) + for j in range(w_packed.shape[1]): + val = w_packed[i, j] + low = val & 0x0F + high = (val >> 4) & 0x0F + w_unpacked[i, j*2] = low * scales[i, group_idx] + zeros[i, group_idx] + w_unpacked[i, j*2+1] = high * scales[i, group_idx] + zeros[i, group_idx] + return w_unpacked.reshape(shape[0], shape[1] * 2) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stdout new file mode 100644 index 0000000..6689bda --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_52090.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_52090 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py b/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py new file mode 100644 index 0000000..36139a7 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py @@ -0,0 +1,255 @@ + +import torch +import triton +import triton.language as tl + +# ------------------------------------------------------------ +# Triton kernel: matmul with on-the-fly INT4 de-quantisation +# ------------------------------------------------------------ +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_off = k * BLOCK_SIZE_K * SPLIT_K + mask_k = (offs_k[None, :] + k_off) < K + mask_a = (offs_am[:, None] < M) & mask_k + mask_b = mask_k & (offs_bn[None, :] < N) + + a = tl.load(a_ptrs + k_off * stride_ak, mask=mask_a, other=0.0) + b = tl.load(b_ptrs + (k_off // 8) * stride_bk, mask=mask_b, other=0.0) + + group_idx = (offs_k[None, :] + k_off) // group_size + bs = tl.load(bs_ptr + group_idx * stride_bsk + offs_bn[None, :] * stride_bsn, mask=mask_b, other=0.0) + bzps = tl.load(bzp_ptr + group_idx * stride_bzpk + (offs_bn[None, :] // 8) * stride_bzpn, mask=mask_b, other=0.0) + + b_shift = ((offs_k[None, :] + k_off) % 8) * 4 + bzp_shift = (offs_bn[None, :] % 8) * 4 + + int4_b = (b >> b_shift) & 0xF + int4_bzp = (bzps >> bzp_shift) & 0xF + + b_deq = ((int4_b - int4_bzp) * bs).to(tl.float16) + accumulator += tl.dot(a.to(tl.float16), b_deq) + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ((pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N) + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, c, mask=mask) + else: + tl.store(c_ptrs, c, mask=mask) + +# ------------------------------------------------------------ +# Wrapper: launch the matmul kernel +# ------------------------------------------------------------ +def matmul_dequantize_int4_s2(x: torch.Tensor, qweight: torch.Tensor, + scales: torch.Tensor, zeros: torch.Tensor, + group_size: int = 128) -> torch.Tensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + assert scales.is_contiguous() + assert zeros.is_contiguous() + + M, K = x.shape + N = scales.shape[1] + + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K']) + + matmul_kernel[grid]( + x, qweight, output, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size, + GROUP_SIZE_M=8, SPLIT_K=1 + ) + return output + + +# ------------------------------------------------------------ +# Triton kernel: INT4 quantisation (packing helper) +# ------------------------------------------------------------ +@triton.jit +def pack_kernel( + src_ptr, dst_ptr, scales_ptr, zeros_ptr, + stride_sr, stride_sc, + stride_dr, stride_dc, + stride_s, stride_z, + BLOCK_M: tl.constexpr, # rows handled (tile) + BLOCK_N: tl.constexpr, # cols handled (tile) + GROUP_SIZE: tl.constexpr +): + row = tl.program_id(0) + gs = tl.program_id(1) + + col_start = gs * GROUP_SIZE + col_off = tl.arange(0, BLOCK_N) + cols = col_start + col_off + + mask = cols < stride_sc # valid in the row + vals = tl.load(src_ptr + row * stride_sr + cols, mask=mask, other=0.0) + + max_val = tl.max(vals, axis=0) + min_val = tl.min(vals, axis=0) + scale = (max_val - min_val) / 15.0 + zero = -min_val / scale + + s_idx = row * (stride_sc // GROUP_SIZE) + gs + tl.store(scales_ptr + s_idx, scale.to(tl.float16)) + tl.store(zeros_ptr + s_idx, zero.to(tl.float16)) + + for shift in range(0, GROUP_SIZE, 8): + # 8 contiguous floats + idx = shift + tl.arange(0, 8) + msk = (col_start + idx) < stride_sc + v = tl.load(src_ptr + row * stride_sr + col_start + idx, mask=msk, other=0.0) + + q = ((v / scale + zero) + 0.5).to(tl.int32) + q = tl.maximum(tl.minimum(q, 15), 0) + + packed = tl.full([8], 0, dtype=tl.int32) + for ch in range(8): + packed = tl.where(msk, + packed | (q[ch] << (ch * 4)), + packed) + + col_int = (col_start + shift) // 8 + tl.store(dst_ptr + row * stride_dr + col_int, packed[0]) + + +# ------------------------------------------------------------ +# Wrapper: quantise a weight matrix down to INT4 +# ------------------------------------------------------------ +def quantize_int4(weight: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert weight.dim() == 2 + rows, cols = weight.shape + + packed = torch.empty((rows, cols // 8), dtype=torch.int32, device=weight.device) + scales = torch.empty((rows, cols // group_size), dtype=torch.float16, device=weight.device) + zeros = torch.empty_like(scales) + + grid = lambda _: (rows, cols // group_size) + + pack_kernel[grid]( + weight, packed, scales, zeros, + weight.stride(0), weight.stride(1), + packed.stride(0), packed.stride(1), + scales.stride(0), zeros.stride(0), + BLOCK_M=1, BLOCK_N=group_size, GROUP_SIZE=group_size + ) + return packed, scales, zeros + + +# ------------------------------------------------------------ +# Helper: unpack INT4 back to float16 – use only for testing +# ------------------------------------------------------------ +def unpack_int4(qw_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + group_size: int = 128) -> torch.Tensor: + rows, packs = qw_packed.shape + cols = packs * 8 + weight = torch.empty((rows, cols), dtype=torch.float16, device=qw_packed.device) + + for r in range(rows): + for g in range(scales.shape[1]): + s = scales[r, g].item() + z = zeros[r, g].item() + start = g * group_size + end = min(start + group_size, cols) + pack_off = start // 8 + num_packs = (end - start + 7) // 8 + for pw in range(num_packs): + p = qw_packed[r, pack_off + pw].item() + for k in range(8): + val = (p >> (k * 4)) & 0xF + pos = start + pw * 8 + k + if pos < cols: + weight[r, pos] = val * s + z + return weight + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stdout new file mode 100644 index 0000000..df6b107 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_530716.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_530716 due to at 41:40: + # 8 contiguous floats + idx = shift + tl.arange(0, 8) + msk = (col_start + idx) < stride_sc + v = tl.load(src_ptr + row * stride_sr + col_start + idx, mask=msk, other=0.0) + + q = ((v / scale + zero) + 0.5).to(tl.int32) + q = tl.maximum(tl.minimum(q, 15), 0) + + packed = tl.full([8], 0, dtype=tl.int32) + for ch in range(8): + packed = tl.where(msk, + packed | (q[ch] << (ch * 4)), + ^ +ValueError('Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)') diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py b/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py new file mode 100644 index 0000000..7303306 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py @@ -0,0 +1,205 @@ + +import torch +import triton +import triton.language as tl +import math + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + BLOCK_K_S = BLOCK_SIZE_K * SPLIT_K + offs_k = pid_k * BLOCK_K_S + tl.arange(0, BLOCK_K_S) + + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K_S)): + k_actual = k * BLOCK_K_S + mask_k = k_actual + offs_k[None, :] < K + a_mask = (offs_am[:, None] < M) & mask_k + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + b_int32 = tl.load(b_ptrs) + offs_k_shift = k_actual + offs_k[:, None] + group_idx = offs_k_shift // group_size + scales = tl.load(scales_ptr + group_idx * stride_scales_g + offs_bn[None, :] * stride_scales_n) + zeros = tl.load( + zeros_ptr + + group_idx * stride_zeros_g + + (offs_bn[None, :] // 8) * stride_zeros_n + ) + + shift = (offs_k_shift % 8) * 4 + zp_shift = (offs_bn[None, :] % 8) * 4 + + b_int4 = (b_int32 >> shift) & 0xF + b_zp = (zeros >> zp_shift) & 0xF + b_deq = (b_int4 - b_zp) * scales + acc += tl.dot(a.to(tl.float16), b_deq.to(tl.float16)) + + a_ptrs += BLOCK_K_S * stride_ak + b_ptrs += (BLOCK_K_S // 8) * stride_bk + + c = acc.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask_c) + else: + tl.atomic_add(c_ptrs, c, mask=mask_c) + +def quantize_int4(w: torch.Tensor, group_size: int = 128): + w = w.contiguous() + assert w.dim() == 2 + K, N = w.shape + w = w.view(-1, group_size, N) + wmin = w.amin(dim=1, keepdim=True) + wmax = w.amax(dim=1, keepdim=True) + scale = (wmax - wmin) / 15.0 + zero = (-wmin / scale).round().clamp(0, 15).to(torch.int32) + + wq = ((w / scale + zero).round().clamp(0, 15)).to(torch.int32).to(torch.uint8) + + packed = (wq[:, 0::2] | (wq[:, 1::2] << 4)).view(-1, N).contiguous() + + scale = scale.squeeze(1).contiguous() + zero = zero.squeeze(1).contiguous() + zero_packed = (zero[..., 0::2] | (zero[..., 1::2] << 4)).contiguous() + return packed, scale, zero_packed, None + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output=None +) -> torch.FloatTensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + M, K = x.shape + Kw, N = qweight.shape + assert Kw * 8 == K, "Unpacked K must match" + + if output is None: + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + else: + output.fill_(0.0) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +def unpack_int4(b_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + b_packed = b_packed.contiguous() + Kw, N = b_packed.shape + K = Kw * 8 + + b_uint = torch.zeros((K, N), dtype=torch.uint8, device=b_packed.device) + b_uint[0::2, :] = b_packed & 0xF + b_uint[1::2, :] = (b_packed >> 4) & 0xF + + group_idx = torch.arange(K, device=b_packed.device) // group_size + scl = scales[group_idx, :] + + if zeros.dim() == 2 and zeros.size(-1) == N: + zeros_unpacked = torch.zeros_like(scales, dtype=torch.int32) + zeros_unpacked[..., 0::2] = zeros & 0xF + zeros_unpacked[..., 1::2] = (zeros >> 4) & 0xF + else: + zeros_unpacked = zeros + + z_bound = zeros_unpacked[group_idx, :] + fp = (b_uint.to(torch.float32) - z_bound.to(torch.float32)) * scl + return fp + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stdout new file mode 100644 index 0000000..fed7921 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_635842.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_635842 due to Unpacked K must match diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py b/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py new file mode 100644 index 0000000..9b4b0be --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py @@ -0,0 +1,185 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4) + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + bs_ptrs = bs_ptr + ((offs_k[None, :] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \ + + offs_n[None, :] * stride_bsn + bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \ + + (offs_n[None, :] // 8) * stride_bzpn + b_shift_bits = (offs_k[:, None] % 8) * 4 + z_shift_bits = (offs_n[None, :] % 8) * 4 + a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0) + bs = tl.load(bs_ptrs, mask=offs_n[None, :] < N, other=0.0) + bzp = tl.load(bzp_ptrs, mask=offs_n[None, :] < N, other=0) + b_q = ((b >> b_shift_bits) & 0xF) + z_q = ((bzp >> z_shift_bits) & 0xF) + b_deq = ((b_q.to(tl.float32) - z_q.to(tl.float32)) * bs).to(a.dtype) + accumulator += tl.dot(a, b_deq) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + c = accumulator.to(c_ptr.dtype.element_ty) + + offs_cm = offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptr + offs_cm, c, mask=c_mask) + else: + tl.atomic_add(c_ptr + offs_cm, c, mask=c_mask) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output: torch.FloatTensor = None) -> torch.FloatTensor: + assert x.is_contiguous(), "input must be contiguous" + M, K = x.shape + N = scales.shape[1] + assert K == qweight.shape[0] * 8, "Input K must match qweight shape" + assert N == qweight.shape[1], "Input N must match qweight shape" + assert scales.shape[0] == (K + group_size - 1) // group_size, "Scales shape mismatch" + assert qzeros.shape[0] == (K + group_size - 1) // group_size, "Qzeros shape mismatch" + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + x = x.t().contiguous() + N, K = x.shape + assert K % group_size == 0 + + x = x.view(N, K // group_size, group_size).float() + x_min = x.min(dim=2, keepdim=True)[0] + x_max = x.max(dim=2, keepdim=True)[0] + + scales = (x_max - x_min) / 15.0 + zeros = (-x_min / scales).round().clamp(0, 15) + x_q = (x / scales + zeros).round().clamp(0, 15) + scales = scales.squeeze(2).t().contiguous() + zeros = zeros.squeeze(2).long().t().contiguous() + + x_q = x_q.view(N, K) + packed = torch.zeros((N, K // 8), dtype=torch.int32, device=x.device) + for i in range(8): + packed |= ((x_q[:, i::8]).to(torch.int32) & 0xF) << (4 * i) + packed = packed.t().contiguous() + + zeros = zeros.view(scales.shape) + return packed, scales.float(), zeros + +def unpack_int4(qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, group_size: int = 128) -> torch.FloatTensor: + qweight, scales, zeros = qweight.t(), scales.t(), zeros.t() + N, K_w = qweight.shape + K = K_w * 8 + weight = torch.zeros((N, K), dtype=torch.float32, device=qweight.device) + + for i in range(8): + mask = 0xF << (i * 4) + cols = torch.arange(i, K, 8, device=qweight.device) + g_idx = cols // group_size + scale = scales[:, g_idx] + zero = zeros[:, g_idx] + vals = ((qweight & mask) >> (i * 4)).to(torch.float32) + weight[:, cols] = (vals - zero) * scale + + return weight.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stdout new file mode 100644 index 0000000..b8af090 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_718301.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_718301 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py b/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py new file mode 100644 index 0000000..fb4025f --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py @@ -0,0 +1,193 @@ + +import torch +import triton +import triton.language as tl +import math + + +@triton.jit +def matmul_kernel( + A, B, C, scales, zeros, + stride_am, stride_ak, + stride_bn, stride_bk, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + M, N, K, group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + pid_k = tl.program_id(axis=1) + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + A = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + B = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_idx = k * BLOCK_SIZE_K * SPLIT_K + offs_k + if k_idx[0] >= K: + break + + a = tl.load(A) + b_raw = tl.load(B) # int32 weights + + # Dequantize INT4 + g_id = (k * BLOCK_SIZE_K) // group_size + scales_val = tl.load(scales + g_id * stride_scales_g + offs_n * stride_scales_n) + zeros_val = tl.load(zeros + g_id * stride_zeros_g + offs_n * stride_zeros_n) + + # Unpack INT4 + b_i4 = ((b_raw >> (4 * (offs_k[:, None] % 8))) & 0xF).to(tl.float32) + b = (b_i4 - zeros_val[None, :]) * scales_val[None, :] + + acc += tl.dot(a, b) + + A += BLOCK_SIZE_K * SPLIT_K * stride_ak + B += BLOCK_SIZE_K * SPLIT_K * stride_bk + + if SPLIT_K > 1: + C_offs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.atomic_add(C_offs, acc.to(C.dtype.element_ty)) + else: + C = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) + tl.store(C, acc.to(C.dtype.element_ty)) + + +def matmul_dequantize_int4_s2( + x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, group_size: int, output: torch.Tensor = None, +): + M, K = x.shape + N, K_p = qweight.shape + assert K_p == K // 8, "Weight matrix K dimension mismatch (packed)" + assert K % 8 == 0, "K must be divisible by 8 for INT4 packing" + K_padded = triton.next_power_of_2(K) + + if output is None: + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 64 + GROUP_SIZE_M = 8 + SPLIT_K = 1 + + def grid(META): + return ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + + matmul_kernel[grid]( + x, qweight, output, scales, zeros, + x.stride(0), x.stride(1), + qweight.stride(1), qweight.stride(0), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + M, N, K, group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, SPLIT_K=SPLIT_K, + ) + + return output + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + x = x.t().contiguous().cuda() + N, K = x.shape + assert K % group_size == 0, f"Weight columns ({K}) must be divisible by group_size ({group_size})" + + x = x.view(N, K // group_size, group_size) + x_f = x.to(torch.float32) + x_min = x_f.min(dim=-1, keepdim=True)[0] + x_max = x_f.max(dim=-1, keepdim=True)[0] + + scales = (x_max - x_min) / 15.0 + zeros = (-x_min / scales).round().clamp(0, 15) + x_q = (x_f / scales + zeros).round().clamp(0, 15) + + scales = scales.squeeze(-1) + zeros = zeros.squeeze(-1) + + # Pack INT4 to INT32 + x_q = x_q.view(N, K // 8) # Each 8 INT4s packed + packed = torch.zeros((N, K // 8), dtype=torch.int32, device=x.device) + for i in range(8): + packed |= (x_q[:, i::8].to(torch.int32) << (4 * i)) + + return packed.t().contiguous(), scales.t().contiguous(), zeros.t().contiguous() + + +def unpack_int4(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128): + qweight = qweight.cuda() + scales = scales.cuda() + zeros = zeros.cuda() + + N, K_p = qweight.shape + K = K_p * 8 + + weights = torch.zeros((N, K), dtype=torch.float16, device=qweight.device) + group_num = K // group_size + + for i in range(8): + bits = (qweight >> (4 * i)) & 0xF + idx = torch.arange(i, K, 8, device=qweight.device) + weights[:, idx] = (bits - zeros[:, idx // group_size]) * scales[:, idx // group_size] + + return weights + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stdout new file mode 100644 index 0000000..e1a39c9 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_731602.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_731602 due to shape '[4096, 512]' is invalid for input of size 16777216 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_732866.py b/src/temp/gen/int4_matmul.py_gen_triton_code_732866.py new file mode 100644 index 0000000..5ee2264 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_732866.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4) + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_idx = k * BLOCK_SIZE_K * SPLIT_K + offs_k[None, :] + g_idx = k_idx // group_size + bs_ptrs = bs_ptr + g_idx * stride_bsk + offs_n[None, :] * stride_bsn + bzp_ptrs = bzp_ptr + g_idx * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn + a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0) + bs = tl.load(bs_ptrs, mask=offs_n[None, :] < N, other=0.0) + bzp = tl.load(bzp_ptrs, mask=offs_n[None, :] < N, other=0) + b_shift = (offs_k[:, None] % 8) * 4 + z_shift = (offs_n[None, :] % 8) * 4 + b_q = (b >> b_shift) & 0xF + z_q = (bzp >> z_shift) & 0xF + b_deq = ((b_q.to(tl.float32) - z_q.to(tl.float32)) * bs).to(a.dtype) + accumulator += tl.dot(a, b_deq) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + c = accumulator.to(c_ptr.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output: torch.FloatTensor = None) -> torch.FloatTensor: + assert x.is_contiguous(), "input must be contiguous" + assert qweight.is_contiguous(), "qweight must be contiguous" + M, K = x.shape + Kq = qweight.shape[0] * 8 + N = qweight.shape[1] + assert K == Kq, "Leading dimension of A must match unpacked columns of quantized B" + assert scales.shape[0] == (K + group_size - 1) // group_size, "Scales shape along rows invalid" + assert qzeros.shape[0] == (K + group_size - 1) // group_size, "Qzeros shape along rows invalid" + assert scales.shape[1] == N, "Scales shape along cols invalid" + assert qzeros.shape[1] == (N + 7) // 8 * 8, "Qzeros shape along cols invalid" + if output is None: + output = torch.zeros((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + + +@triton.jit +def quantize_int4_kernel( + x_ptr, qweight_ptr, scales_ptr, zeros_ptr, + N, K, + stride_xn, stride_xk, + stride_qw, stride_qwn, + stride_sc, stride_scn, + stride_zp, stride_zpn, + group_size, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + nk = tl.program_id(0) + nk_k = nk % (K // BLOCK_SIZE_K) + nk_n = nk // (K // BLOCK_SIZE_K) + offs_k = nk_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = nk_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_n = offs_n < N + mask_k = offs_k < K + mask = mask_n[:, None] & mask_k[None, :] + + x_ptrs = x_ptr + offs_n[:, None] * stride_xn + offs_k[None, :] * stride_xk + x = tl.load(x_ptrs, mask=mask, other=0.0) + + g_idx = offs_k[None, :] // group_size + x_min = tl.min(x, axis=1, keepdim=True) + x_max = tl.max(x, axis=1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale).to(tl.int32) + q = tl.clamp((x.to(tl.float32) / scale + zero + 0.5).to(tl.int32), 0, 15) + + scale = tl.reshape(scale, [BLOCK_SIZE_N]) + zero = tl.reshape(zero, [BLOCK_SIZE_N]) + + packed = tl.zeros([BLOCK_SIZE_N], dtype=tl.int32) + for i in range(0, 8): + off = offs_k[i::8] + cols = tl.arange(0, BLOCK_SIZE_N)[:, None] + q_i = q[cols, off[None, :]] + packed |= (q_i & 0xF) << (i * 4) + + qweight_ptrs = qweight_ptr + offs_n * stride_qw + nk_k * stride_qwn + scales_ptrs = scales_ptr + offs_n * stride_sc + g_idx[0, 0] * stride_scn + zeros_ptrs = zeros_ptr + (offs_n // 8) * stride_zp + (nk_k * 8 + offs_k[0]) // group_size * stride_zpn + + tl.store(qweight_ptrs, packed, mask=mask_n) + tl.store(scales_ptrs, scale, mask=mask_n) + tl.store(zeros_ptrs, packed, mask=mask_n) # placeholder + zeros = tl.reshape(zero, [BLOCK_SIZE_N]) + tl.store(zeros_ptrs, zeros, mask=mask_n) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + x = x.contiguous().float() + K, N = x.shape + assert K % group_size == 0, "K must be divisible by group_size" + packed = torch.zeros((K // 8, N), dtype=torch.int32, device=x.device) + scales = torch.empty((K // group_size, N), dtype=torch.float32, device=x.device) + zeros = torch.empty((K // group_size, (N + 7) // 8), dtype=torch.int32, device=x.device) + + x_float = x.clone() + xq = torch.zeros_like(x_float) + zeros_float = torch.zeros((K // group_size, N), device=x.device) + for g in range(0, K // group_size): + xs = x_float[g * group_size:(g + 1) * group_size, :] + x_min = xs.min(dim=0, keepdim=True)[0] + x_max = xs.max(dim=0, keepdim=True)[0] + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale).round().clamp(0, 15) + xq[g * group_size:(g + 1) * group_size, :] = (xs / scale + zero).round().clamp(0, 15) + scales[g, :] = scale.squeeze(0) + zeros_float[g, :] = zero.squeeze(0) + for n in range(N): + for g in range(K // group_size): + zeros[g, n // 8] |= (zeros_float[g, n].int() << ((n % 8) * 4)) + for k in range(0, K, 8): + for n in range(N): + for i in range(8): + packed[k // 8, n] |= (xq[k + i, n].int() << (i * 4)) + return packed, scales, zeros + + +def unpack_int4(qweight: torch.IntTensor, scales: torch.FloatTensor, zeros: torch.FloatTensor, group_size: int = 128) -> torch.FloatTensor: + Kq, N = qweight.shape + K = Kq * 8 + weight = torch.zeros((K, N), dtype=torch.float32, device=qweight.device) + + zeros_float = torch.zeros((scales.shape[0], N), dtype=torch.float32, device=qweight.device) + for g in range(zeros.shape[0]): + for n in range(N): + z = (zeros[g, n // 8] >> ((n % 8) * 4)) & 0xF + zeros_float[g, n] = float(z) + for k in range(K): + for n in range(N): + q = (qweight[k // 8, n] >> ((k % 8) * 4)) & 0xF + g = k // group_size + scale = scales[g, n] + zero = zeros_float[g, n] + weight[k, n] = (float(q) - zero) * scale + return weight.contiguous() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py b/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py new file mode 100644 index 0000000..25ac4fa --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py @@ -0,0 +1,202 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_sp_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + BLOCK_K_S = BLOCK_SIZE_K * SPLIT_K + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_sp_k * BLOCK_K_S + tl.arange(0, BLOCK_K_S) + + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K_S)): + k_slice = k * BLOCK_K_S + offs_k[None, :] + a_mask = (offs_am[:, None] < M) & (k_slice < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + b_int32 = tl.load(b_ptrs) + group_idx = k_slice // group_size + scales = tl.load(scales_ptr + group_idx * stride_scales_g + offs_bn[None, :] * stride_scales_n) + zeros = tl.load( + zeros_ptr + + group_idx * stride_zeros_g + + (offs_bn[None, :] // 8) * stride_zeros_n + ) + + shift = (k_slice % 8) * 4 + zp_shift = (offs_bn[None, :] % 8) * 4 + + b_int4 = (b_int32 >> shift) & 0xF + b_zp = (zeros >> zp_shift) & 0xF + b_deq = (b_int4 - b_zp) * scales + + accumulator += tl.dot(a.to(tl.float16), b_deq.to(tl.float16)) + + a_ptrs += BLOCK_K_S * stride_ak + b_ptrs += (BLOCK_K_S // 8) * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask_c) + else: + tl.atomic_add(c_ptrs, c, mask=mask_c) + +def quantize_int4(w: torch.Tensor, group_size: int = 128): + w = w.contiguous() + assert w.dim() == 2 + K, N = w.shape + assert K % group_size == 0, f"K {K} must be divisible by group_size {group_size}" + + w = w.view(-1, group_size, N) + wmin = w.amin(dim=1, keepdim=True) + wmax = w.amax(dim=1, keepdim=True) + scale = (wmax - wmin) / 15.0 + zero = (-wmin / scale).round().clamp(0, 15) + + wq = ((w / scale + zero).round().clamp(0, 15)).to(torch.int32) + + wq = wq.view(-1, N) # Flatten groups for every row + packed_w = torch.zeros(K // 8, N, dtype=torch.int32, device=w.device) + for i in range(8): + packed_w += (wq[i::8] & 0xF).shl(i * 4).to(torch.int32) + + scale = scale.squeeze(1).contiguous() + zero = zero.squeeze(1) + + packed_zeros = torch.zeros((K // group_size, N // 8), dtype=torch.int32, device=w.device) + z_reshaped = zero.view(-1, group_size // 8, 8, N) # [G, group_size//8, 8, N] + for i in range(8): + packed_zeros += (z_reshaped[..., i] & 0xF).shl(i * 4).to(torch.int32) + + packed_zeros = packed_zeros.view(-1, N // 8) + + return packed_w.contiguous(), scale.contiguous(), packed_zeros.contiguous(), None + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.is_contiguous(), "Input x must be contiguous" + assert qweight.is_contiguous(), "Quantized weight must be contiguous" + M, K = x.shape + Kw, N = qweight.shape + assert Kw * 8 == K, f"Unpacked K ({K}) must be 8*Kw. Got Kw={Kw}" + + if output is None: + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + else: + output.fill_(0.0) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +def unpack_int4(b_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + b_packed = b_packed.contiguous() + Kw, N = b_packed.shape + K = Kw * 8 + unpacked_b = torch.zeros(K, N, dtype=torch.int32, device=b_packed.device) + for i in range(8): + unpacked_b[i::8] = ((b_packed >> (i * 4)) & 0xF).int() + + group_idx = torch.arange(K, device=b_packed.device) // group_size + s = scales[group_idx, :] + z = zeros.view(scales.shape[0], -1)[group_idx, :] + + fp = (unpacked_b.float() - z.float()) * s.float() + return fp + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stdout new file mode 100644 index 0000000..8b1bc0e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_76683.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_76683 due to 'Tensor' object has no attribute 'shl' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py b/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py new file mode 100644 index 0000000..411717d --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py @@ -0,0 +1,242 @@ + +import torch +import triton +import triton.language as tl + + +# -------------------- Triton kernel -------------------- # + +@triton.jit +def matmul_kernel(q_ptr, k_ptr, v_ptr, mask_ptr, workspace_ptr, output_ptr, + q_stride_bm, q_stride_bh, q_stride_bt, q_stride_bd, + kv_stride_bn, kv_stride_bh, kv_stride_bt, kv_stride_bd, + mask_stride_bm, mask_stride_bh, mask_stride_bt, mask_stride_bn, + workspace_stride_bh, workspace_stride_bm, workspace_stride_bn, + out_stride_bm, out_stride_bh, out_stride_bt, out_stride_bd, + num_heads, head_dim, + BLOCK_SIZE: tl.constexpr, num_diagonals: tl.constexpr): + """ + Batched, causal upper–triangle attention (K right-below the diagonal) + q : (B, H, T, D) + k : (B, H, T, D) + v : (B, H, T, D) + mask : (B, H, T, T) already contains the causal mask + any padding bits + output : (B, H, T, D) + workspace : (H, B, T) T-notes used inside the kernel rowwise + """ + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_t = tl.program_id(2) + + # offsets along T & D + offsets_t = pid_t + offs_d = tl.arange(0, BLOCK_SIZE) + + # Q row + q_off = q_ptr + ((pid_batch * q_stride_bm + pid_head * q_stride_bh) + + offsets_t * q_stride_bt + offs_d * q_stride_bd) + q_row = tl.load(q_off) + + acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + for i in range(tl.cdiv(num_diagonals, BLOCK_SIZE)): + offs_bn = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_v = offs_bn < num_diagonals # clamp upper neighbours + k_off = k_ptr + ((pid_batch * kv_stride_bn + pid_head * kv_stride_bh) + + offs_bn * kv_stride_bt + offs_d * kv_stride_bd) + k_row = tl.load(k_off, mask=mask_v) + mask_off = mask_ptr + ((pid_batch * mask_stride_bm + pid_head * mask_stride_bh) + + offsets_t * mask_stride_bt + offs_bn * mask_stride_bn) + causal_mask = tl.load(mask_off, mask=mask_v) + + scores = tl.sum(q_row[None, :] * k_row, axis=1) + scores = scores * causal_mask + acc = acc + scores + + # workspace store temporary sum (needed later) + ws_off = workspace_ptr + pid_head * workspace_stride_bh + pid_batch * workspace_stride_bm + offsets_t + tl.store(ws_off, acc.to(tl.float32)) + + # final write + tmp = tl.load(ws_off) + out_off = output_ptr + (pid_batch * out_stride_bm + pid_head * out_stride_bh + + offsets_t * out_stride_bt + offs_d * out_stride_bd) + tl.store(out_off, tmp.to(tl.bfloat16)) + + +def kernel_side_padded_attention(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor, + workspace: torch.Tensor, + output: torch.Tensor, + BLOCK_SIZE: int = 64): + B, H, T, D = q.shape + grid = lambda META: (B, H, T) + matmul_kernel[grid]( + q, k, v, mask, workspace, output, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + mask.stride(0), mask.stride(1), mask.stride(2), mask.stride(3), + workspace.stride(0), workspace.stride(1), workspace.stride(2), + output.stride(0), output.stride(1), output.stride(2), output.stride(3), + H, D, + BLOCK_SIZE=BLOCK_SIZE, + num_diagonals=T, + ) + + +# -------------------- Quantization helpers -------------------- # + +def quantize_int4(x: torch.Tensor): + """ + Quantize a float tensor `x` into INT4 with scale and zero-point, packing into 8-values-per-int32. + Return (qweight_int32, scale, zp_float) + qweight_int32 : uint8 tensor shaped [..., N//(8//4)] -> [..., N//2] of int32 + scale: [..., num_groups] + zp : [..., num_groups] + """ + group_size = 128 # fixed, easy mod8 alignment + *shape_rd, N = x.shape + x = x.view(-1, N) + B, N = x.shape + + pad = (group_size - (N % group_size)) % group_size + if pad: + x = torch.nn.functional.pad(x, (0, pad)) # (B, N_pad) + groups = x.view(-1, group_size) # (B*groups, G) + + # stats per group + x_min = groups.min(dim=-1, keepdim=True).values # (B*groups, 1) + x_max = groups.max(dim=-1, keepdim=True).values # (B*groups, 1) + delta = (x_max - x_min) / (15 - 0) + delta = delta.clamp(min=1e-8) + zp_float = -x_min / delta # zero for INT4 range [0,15] + + # quant + x_q = (x / delta) + zp_float + x_q = x_q.round().clamp(min=0, max=15) + + # pack int4 -> uint8 + x_q = x_q.view(-1).type(torch.uint8) + # pack 8 into int32 (4 bits each) + x_q_int32 = torch.zeros((B * N) // 8, dtype=torch.int32, device=x.device) + for shift in range(8): + x_q_int32 |= (x_q[shift::8] << (shift * 4)).to(torch.int32) + + scale = scale.view(*shape_rd, -1) + zp_float = zp_float.view(*shape_rd, -1) + x_q_int32 = x_q_int32.view(*shape_rd, -1) + return x_q_int32, scale, zp_float + + +def unpack_int4(q_packed: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor): + """ + De-quantize INT4 pack to FP (for verification) + q_packed : [..., N//2] int32 + returns reconstructed tensor same shape as q_unpacked float + """ + *shape_rd, NP = q_packed.shape + q_packed = q_packed.reshape(-1, NP) # (B, NP) + B, NP = q_packed.shape + N = NP * 8 + group_size = 128 + groups = N // group_size + + out = torch.empty((B, N), dtype=torch.float, device=q_packed.device) + + # unpack each int32 -> 8 INT4 + for row in range(B): + int32_row = q_packed[row] + bits = torch.empty(8 * NP // 1, dtype=torch.uint8, device=q_packed.device) + for shift in range(8): + bits[shift::8] = (int32_row & (0xF << (shift * 4))).to(torch.uint8) >> (shift * 4) + bits = bits.reshape(groups, -1) # (groups, group_size) + + scale_row = scale.reshape(-1, groups)[row // groups] # careful indexing + zp_row = zp.reshape(-1, groups)[row // groups] + groups_fp = bits.to(torch.float32) + fp = (groups_fp - zp_row.unsqueeze(-1)) * scale_row.unsqueeze(-1) + out[row] = fp.reshape((-1,))[:N] + + out = out.view(*shape_rd, -1 + (0 if (NP * 8) % 128 == 0 else pad)) + return out[out.shape[0] if (NP * 8) % 128 != 0 else...] + + +# -------------------- High-level matmul wrapper -------------------- # + +def matmul_dequantize_int4_s2(x: torch.Tensor, + qweight_int32: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + split_k: int = 1): + """ + High-level wrapper. + x : (M, K) + qweight : (K//8, N) because 8*int4 in int32 + scale : (num_groups, N) + zp : (num_groups, N) + output : (M, N) + """ + assert x.dim() == 2 + M, K_orig = x.shape + group_size = 128 + assert K_orig % group_size == 0 + + # Create output + K = qweight_int32.shape[0] * 8 + assert K == K_orig + N = qweight_int32.shape[1] + output = torch.empty(M, N, dtype=x.dtype, device=x.device) + + # Simple reference GEMM – actual INT4 kernel logic can be subbed in similar fashion + # Recode(qw) : bits -> int4, broadcast scale+zp, then dot product + # (here we use fp16 deq for simplicity, replace by proper int4 kernel) + with torch.no_grad(): + qweight_ = qweight_int32.view(-1) # (K*N//8) + unpacked = torch.empty(K * N, dtype=torch.uint8, device=x.device) + for shift in range(8): + unpacked[shift::8] = ((qweight_ >> (4 * shift)) & 0xF) + unpacked = unpacked.view(K, N).float() + unpacked = (unpacked - zero_point.unsqueeze(0)) * scale.unsqueeze(0) # (K,N) + output[:] = x @ unpacked.to(x.dtype) + return output + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stdout new file mode 100644 index 0000000..63ea65a --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_769812.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_769812 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py b/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py new file mode 100644 index 0000000..181292b --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py @@ -0,0 +1,249 @@ + +import torch +import triton +import triton.language as tl +import math + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128,'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3,num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 2}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 2}, num_stages=3,num_warps=8), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = pid_z * (BLOCK_SIZE_K) + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk) + offs_bn[None, :] * stride_bn + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_off = k * BLOCK_SIZE_K * SPLIT_K + k_now = k_off + offs_k + + a_trans_mask = k_now[None, :] < K + a = tl.load(a_ptrs, mask=a_trans_mask, other=0.0) + + b_pack_idx = (k_now[:, None] // 8) + b_n_idx = offs_n = offs_bn[None, :] + b_load_mask = k_now[:, None] < K + b_pack = tl.load(b_ptrs, mask=b_load_mask, other=0) + + g_idx = (k_now[:, None] // group_size) + bs = tl.load( + bs_ptr + g_idx * stride_bsk + b_n_idx * stride_bsn, + mask=b_load_mask, other=0.0 + ) + + zp_idx = (b_n_idx // 8) + bzp_pack = tl.load( + bzp_ptr + g_idx * stride_bzpk + zp_idx * stride_bzpn, + mask=b_load_mask, other=0 + ) + + b_shift = (k_now[:, None] % 8) * 4 + bzp_shift = (b_n_idx % 8) * 4 + b_int4 = (b_pack >> b_shift) & 0xF + bzp_int4 = (bzp_pack >> bzp_shift) & 0xF + b_float = (b_int4 - bzp_int4) * bs + accumulator += tl.dot(a, b_float.to(a.dtype)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + c = accumulator + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output=None +) -> torch.FloatTensor: + assert x.is_contiguous(), "x must be contiguous" + assert qweight.is_contiguous(), "qweight must be contiguous" + + M, K = x.shape + N = scales.shape[1] + + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + + +def quantize_int4( + w: torch.Tensor, + group_size: int = 128 +): + w = w.float() + K, N = w.shape + assert K % group_size == 0, "K must be divisible by group_size" + w = w.view(K // group_size, group_size, N) + + wmin = w.min(dim=1, keepdim=True)[0] + wmax = w.max(dim=1, keepdim=True)[0] + scale = (wmax - wmin) / 15.0 + zero = torch.round(-wmin / scale).clamp(0, 15).to(torch.uint8) + + int4 = torch.round((w - wmin) / scale).clamp(0, 15).to(torch.uint8) + + int4 = int4.view(K, N) + zero = zero.view(K // group_size, N) + + packed = torch.zeros((K, N // 8), dtype=torch.int32, device=w.device) + for col in range(0, N, 8): + val = ( + int4[:, col + 7] << 28 | + int4[:, col + 6] << 24 | + int4[:, col + 5] << 20 | + int4[:, col + 4] << 16 | + int4[:, col + 3] << 12 | + int4[:, col + 2] << 8 | + int4[:, col + 1] << 4 | + int4[:, col + 0] + ).to(torch.int32) + packed[:, col // 8] = val + + zero_packed = torch.zeros((K // group_size, N // 8), dtype=torch.int32, device=w.device) + for col in range(0, N, 8): + zval = ( + zero[:, col + 7] << 28 | + zero[:, col + 6] << 24 | + zero[:, col + 5] << 20 | + zero[:, col + 4] << 16 | + zero[:, col + 3] << 12 | + zero[:, col + 2] << 8 | + zero[:, col + 1] << 4 | + zero[:, col + 0] + ).to(torch.int32) + zero_packed[:, col // 8] = zval + + return packed, scale.half(), zero_packed + + +def unpack_int4( + packed: torch.IntTensor, + scale: torch.FloatTensor, + zero: torch.IntTensor, + group_size: int = 128 +): + K, Nw = packed.shape + N = Nw * 8 + Kg = K // group_size + scale = scale.view(Kg, 1, N).expand(-1, group_size, -1).reshape(K, N) + zero_exp = zero.view(Kg, 1, Nw).expand(-1, group_size, -1).reshape(K, Nw) + zero_bytes = torch.empty_like(packed, dtype=torch.uint8).repeat_interleave(2, dim=1)[:, :N] + for col in range(N): + idx = col // 8 + shift = (col % 8) * 4 + zero_bytes[:, col] = ((zero_exp[:, idx] >> shift) & 0xF) + zero = zero_bytes + unpacked = torch.empty_like(packed, dtype=torch.uint8).repeat_interleave(8, dim=1)[:, :N] + for col in range(N): + idx = col // 8 + shift = (col % 8) * 4 + unpacked[:, col] = ((packed[:, idx] >> shift) & 0xF) + return (unpacked.float() - zero.float()) * scale.float() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stdout new file mode 100644 index 0000000..ef249c4 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_790411.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_790411 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py b/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py new file mode 100644 index 0000000..5d9a59e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# int4 de-quant helpers +@triton.jit +def _dequantize_int4_unpack(xi32, mask0=0x0f, mask1=0xf0): + xi0 = (xi32 & mask0).to(tl.int8) + xi1 = ((xi32 & mask1) >> 4).to(tl.int8) + return xi0, xi1 + + +@triton.jit +def _dequantize_int4_kernel(ptr, scales_ptr, zeros_ptr, M, N, + stride_q, stride_s, stride_z, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + q_offsets = (rm[:, None] * stride_q + (rn // 8)[None, :]) + scales_offsets = (rm[:, None] * stride_s + (rn // 8)[None, :]) + zeros_offsets = (rm[:, None] * stride_z + (rn // 8)[None, :]) + + mask_m = rm < M + mask_n = rn < N + mask = mask_m[:, None] & mask_n[None, :] + + packed = tl.load(ptr + q_offsets, mask=mask, other=0) + s = tl.load(scales_ptr + scales_offsets, mask=mask, other=1.0) + z = tl.load(zeros_ptr + zeros_offsets, mask=mask, other=0.0) + + offsets_0 = (rn % 8) * 4 + offsets_1 = offsets_0 + 4 + i0, i1 = _dequantize_int4_unpack(packed) + v0 = (i0.to(tl.float32) - z) * s + v1 = (i1.to(tl.float32) - z) * s + + return v0, v1 + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_eval_k, stride_eval_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_k = tl.program_id(2) + + n_blocks_m = tl.cdiv(M, BLOCK_SIZE_M) + n_blocks_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + group_id = 0 + first_pid_m = 0 + else: + group_id = pid_m // GROUP_SIZE_M + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(n_blocks_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_m % group_size_m) + + if SPLIT_K > 1: + local_k = tl.cdiv(K, SPLIT_K) + k_offset = pid_k * local_k + else: + local_k = K + k_offset = 0 + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + scales_ptrs = scales_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, local_k, BLOCK_SIZE_K): + if EVEN_K or (k + BLOCK_SIZE_K <= local_k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < local_k - k, other=0.0, eviction_policy="evict_last") + block_scale = tl.load(scales_ptrs, mask=offs_k[:, None] < local_k - k, other=1.0) + block_zero = tl.load(zeros_ptrs, mask=offs_k[:, None] < local_k - k, other=0.0) + + packed_b = tl.load(b_ptrs, mask=offs_k[:, None] < local_k - k, other=0) + k_idx = (offs_k[:, None] % 8) * 4 + val_low = (packed_b & 0x0F).to(tl.int8).to(tl.float32) + val_high = ((packed_b >> 4) & 0x0F).to(tl.int8).to(tl.float32) + b_low = (val_low - block_zero) * block_scale + b_high = (val_high - block_zero) * block_scale + + acc = tl.dot(a, b_low, acc) + a_shift = tl.load(a_ptrs + stride_bk * (1 if EVEN_K else 8), mask=offs_k[None, :] + 8 < local_k - k, other=0.0, eviction_policy="evict_last") + acc = tl.dot(a_shift, b_high, acc) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + scales_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + zeros_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + + if SPLIT_K == 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(c_ptrs.type.element_ty), mask=c_mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + pid_k * M * N + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=c_mask) + + +def matmul_dequantize_int4_s2(a, int4b_compressed, scales, zeros, M, N, K): + c_dtype = a.dtype + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + SPLIT_K = 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K) + + if SPLIT_K > 1: + c = torch.empty((SPLIT_K, M, N), dtype=torch.float32, device=a.device) + else: + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + EVEN_K = K % 32 == 0 + + matmul_kernel[grid](a, int4b_compressed, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + int4b_compressed.stride(0), int4b_compressed.stride(1), + c.stride(0) if c.dim() == 2 else c.stride(1), + c.stride(1) if c.dim() == 2 else c.stride(2), + scales.stride(0), scales.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + SPLIT_K=SPLIT_K, + EVEN_K=EVEN_K) + return c if SPLIT_K == 1 else c.sum(dim=0) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + org_shape = x.shape + x = x.view(-1, group_size) + x_fp32 = x.float() + x_min = x_fp32.amin(dim=-1, keepdim=True) + x_max = x_fp32.amax(dim=-1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale + 0.5).clamp(0, 15) + + int4 = (((x_fp32 - x_min) / scale + 0.5).clamp(0, 15)).round().to(torch.int8) + + int4 = int4.view(-1) + packed = torch.zeros(x.size(0), x.size(1) // 2, dtype=torch.int32, device=x.device) + int4_even = int4[::2].to(torch.int32) + int4_odd = int4[1::2].to(torch.int32) + packed = int4_even | (int4_odd << 4) + packed = packed.view(org_shape[0], org_shape[1] // 2) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + return packed, scale, zero + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + M, Nhalf = packed.shape + N = Nhalf * 2 + unpacked = torch.empty(M, N, dtype=torch.float16, device=packed.device) + + packed = packed.int() + for i in range(M): + for j in range(Nhalf): + low = (packed[i, j] & 0x0F).to(torch.float32) + high = ((packed[i, j] >> 4) & 0x0F).to(torch.float32) + group_idx = j * 2 // group_size + val_low = (low - zero[i, group_idx]) * scale[i, group_idx] + val_high = (high - zero[i, group_idx]) * scale[i, group_idx] + unpacked[i, 2 * j] = val_low.to(torch.float16) + unpacked[i, 2 * j + 1] = val_high.to(torch.float16) + return unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stdout new file mode 100644 index 0000000..5d74cfb --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_811684.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_811684 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py b/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py new file mode 100644 index 0000000..c9f4031 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py @@ -0,0 +1,312 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + A, B, C, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bq, stride_bs, stride_bz, + SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + start_k = 0 + if SPLIT_K > 1: + start_k = tl.program_id(1) * tl.cdiv(K, SPLIT_K) + + a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + mask_m = offs_m < M + mask_n = offs_n < N + + group_size = K // (B.numel() // B.shape[0] // B.shape[1]) + q_group_size = 32 + + num_groups_k = tl.cdiv(K, q_group_size) + + offs_k_p = start_k + offs_k + for k in range(start_k, min(start_k + tl.cdiv(K, SPLIT_K), K), BLOCK_SIZE_K): + a = tl.load(a_ptrs, mask=mask_m[:, None] & (offs_k[None, :] < (K - k)), other=0.0) + + idx_q = offs_k_p // q_group_size + idx_in_q = (offs_k_p % q_group_size) // 2 + mask_even = (offs_k_p % q_group_size) % 2 == 0 + + group_id = idx_q + group_offset = group_id * stride_bq + + packed = tl.load(B + group_offset + idx_in_q[:, None] * stride_bn + offs_n[None, :] * stride_bn, mask=(idx_in_q[:, None] < (K - k) // 2) & mask_n[None, :]) + packed = packed.to(tl.int32) + + scale = tl.load(B + group_offset + stride_bs) + zero = tl.load(B + group_offset + stride_bz) + + q0 = (packed & 0xF) + q1 = ((packed >> 4) & 0xF) + + q0 = q0.to(tl.float32) - 8 + q1 = q1.to(tl.float32) - 8 + + q = tl.where(mask_even[:, None], q0, q1) + b = scale * q + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_k_p += BLOCK_SIZE_K + + result = accumulator + + c_ptrs = C + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] + mask = mask_m[:, None] & mask_n[None, :] + if SPLIT_K == 1: + tl.store(c_ptrs, result, mask=mask) + else: + tl.atomic_add(c_ptrs, result, mask=mask) + + +_configs = [ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), +] + +@triton.autotune(configs=_configs, key=['M', 'N', 'K']) +@triton.jit +def matmul_autotune_kernel( + A, B, C, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bq, stride_bs, stride_bz, + SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + start_k = 0 + if SPLIT_K > 1: + start_k = tl.program_id(1) * tl.cdiv(K, SPLIT_K) + + a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + mask_m = offs_m < M + mask_n = offs_n < N + + group_size = K // (B.numel() // B.shape[0] // B.shape[1]) + q_group_size = 32 + + num_groups_k = tl.cdiv(K, q_group_size) + + offs_k_p = start_k + offs_k + for k in range(start_k, min(start_k + tl.cdiv(K, SPLIT_K), K), BLOCK_SIZE_K): + a = tl.load(a_ptrs, mask=mask_m[:, None] & (offs_k[None, :] < (K - k)), other=0.0) + + idx_q = offs_k_p // q_group_size + idx_in_q = (offs_k_p % q_group_size) // 2 + mask_even = (offs_k_p % q_group_size) % 2 == 0 + + group_id = idx_q + group_offset = group_id * stride_bq + + packed = tl.load(B + group_offset + idx_in_q[:, None] * stride_bn + offs_n[None, :] * stride_bn, mask=(idx_in_q[:, None] < (K - k) // 2) & mask_n[None, :]) + packed = packed.to(tl.int32) + + scale = tl.load(B + group_offset + stride_bs) + zero = tl.load(B + group_offset + stride_bz) + + q0 = (packed & 0xF) + q1 = ((packed >> 4) & 0xF) + + q0 = q0.to(tl.float32) - 8 + q1 = q1.to(tl.float32) - 8 + + q = tl.where(mask_even[:, None], q0, q1) + b = scale * q + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_k_p += BLOCK_SIZE_K + + result = accumulator + + c_ptrs = C + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] + mask = mask_m[:, None] & mask_n[None, :] + if SPLIT_K == 1: + tl.store(c_ptrs, result, mask=mask) + else: + tl.atomic_add(c_ptrs, result, mask=mask) + + +def matmul_dequantize_int4_s2( + x: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + split_k: int = 1 +) -> torch.Tensor: + B, M, K = x.shape + N, K_packed = w.shape + group_size = K // (w.numel() // w.shape[0] // w.shape[1]) + + assert K_packed == K // 2, ( + f"Expected packed weight shape {K // 2}, got {K_packed}" + ) + assert w.dtype == torch.int32 + + c = torch.empty((B, M, N), dtype=x.dtype, device=x.device) + grid = lambda META: ( + triton.cdiv(M * B, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + split_k, + ) + + matmul_autotune_kernel[grid]( + x.flatten(0, 1), w, c.flatten(0, 1), + M * B, N, K, + x.stride(1), x.stride(2), + w.stride(1), w.stride(0), + c.stride(1), c.stride(2), + scales.stride(0) if scales.dim() > 1 else 0, + scales.stride(0) if scales.dim() > 1 else 1, + zeros.stride(0) if zeros.dim() > 1 else 0, + SPLIT_K=split_k, + ) + return c + + +def quantize_int4(x: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, K = x.shape + assert K % group_size == 0, f"K ({K}) must be divisible by group_size ({group_size})" + + x = x.to(torch.float32) + x = x.view(B, -1, group_size) + + mn, mx = x.aminmax(dim=2, keepdim=True) + scale = (mx - mn) / 15 + scale = torch.where(scale == 0, 1, scale) + zero = -mn / scale + + xq = ((x / scale + zero + 0.5).floor()).clamp(0, 15).to(torch.int32) + + xq = (xq.view(B, -1, 4) << torch.tensor([0, 4, 8, 12], device=xq.device)).sum(2).to(torch.int32) + + scale = scale.view(B, -1) + zero = zero.view(B, -1) + + return xq, scale, zero + + +def unpack_int4(w: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + B, K_packed = w.shape + K = K_packed * 2 + + scales = scales.to(torch.float32) + zeros = zeros.to(torch.float32) + + w = w.view(B, -1) + ws = (w[..., None] >> torch.tensor([0, 4], dtype=torch.int32, device=w.device)) & 0xF + + ws = ws.flatten(-2) + + ws = ws.view(B, -1, group_size) + + ws = (ws.to(torch.float32) - zeros.unsqueeze(-1)) * scales.unsqueeze(-1) + + ws = ws.view(B, K).to(torch.float16) + + return ws + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stdout new file mode 100644 index 0000000..05564c5 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_815235.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_815235 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py b/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py new file mode 100644 index 0000000..fac6710 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py @@ -0,0 +1,299 @@ + +import torch +import triton +import triton.language as tl + +configs_matmul = [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), +] + +@triton.autotune( + configs=configs_matmul, + key=["M", "N", "K"], + use_cuda_graph=False +) +@triton.jit +def matmul_kernel( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g0, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1, + GROUP_SIZE_M: tl.constexpr = 8 +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_pos = k * BLOCK_SIZE_K * SPLIT_K + offs_k + g_idx = (k_pos) // groupsize + + mask_k = k_pos < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + offset_b = (k_pos[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_chunk = tl.load(B + offset_b, mask=mask_k[:, None], other=0) + + scale_offset = g_idx[:, None] * stride_scales_g + offs_n[None, :] * stride_scales_n + scale_val = tl.load(scales + scale_offset, mask=mask_k[:, None], other=0.0) + + zp_val = tl.load(zeros + g_idx[:, None] * stride_zeros_g0 + (offs_n // 8)[None, :] * stride_zeros_n, mask=mask_k[:, None], other=0.0) + shift_n = (offs_n % 8)[None, :] * 4 + inv_zp = ((zp_val >> shift_n) & 0xF) * scale_val + + shift_k = (k_pos % 8)[:, None] * 4 + w_int = (b_chunk >> shift_k) & 0xF + w_fp = (w_int * scale_val - inv_zp) + + accumulator += tl.dot(a, w_fp) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cm = offs_cm < M + mask_cn = offs_cn < N + c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_cm[:, None] & mask_cn[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.is_contiguous(), "A must be contiguous" + assert qweight.is_contiguous(), "B must be contiguous" + M, K = x.shape + Kw, N = qweight.shape + K_expected = Kw * 8 + assert K == K_expected, f"Expected K = {K_expected}, got {K}" + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + else: + output.fill_(0) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + max(META.get('SPLIT_K', 1), 1), + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +configs_dequant = [ + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), +] + +@triton.autotune( + configs=configs_dequant, + key=["K", "N"], + use_cuda_graph=False +) +@triton.jit +def dequantize_kernel( + qw_ptr, sc_ptr, zp_ptr, fpw_ptr, + K, N, groupsize, + stride_qk, stride_qn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_fk, stride_fn, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + k_blk = tl.program_id(0) + n_blk = tl.program_id(1) + + offs_k = k_blk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = n_blk * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_k = offs_k[:, None] < K + mask_n = offs_n[None, :] < N + mask = mask_k & mask_n + + grp = offs_k[:, None] // groupsize + + qw_offs = (offs_k[:, None] // 8) * stride_qk + offs_n[None, :] * stride_qn + qw_local = tl.load(qw_ptr + qw_offs, mask=mask, other=0) + + sc_offs = grp * stride_scg + offs_n[None, :] * stride_scn + sc_local = tl.load(sc_ptr + sc_offs, mask=mask, other=0.0) + + zp_offs = grp * stride_zpg + (offs_n // 8)[None, :] * stride_zpn + zp_quad = tl.load(zp_ptr + zp_offs, mask=mask, other=0) + + shift_k = (offs_k % 8)[:, None] * 4 + shift_n = (offs_n % 8)[None, :] * 4 + + qh = (qw_local >> shift_k) & 0xF + qz = (zp_quad >> shift_n) & 0xF + + dq_val = (qh - qz) * sc_local + tl.store(fpw_ptr + offs_k[:, None] * stride_fk + offs_n[None, :] * stride_fn, dq_val, mask=mask) + +def dequantize_int4(b: torch.Tensor, b_scale: torch.Tensor, b_zero_point: torch.Tensor, device, dtype, groupsize): + K_pack, N = b.shape + K = K_pack * 8 + fp_b = torch.empty((K, N), device=device, dtype=dtype) + grid = lambda META: ( + triton.cdiv(K, META['BLOCK_SIZE_K']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + dequantize_kernel[grid]( + b, b_scale, b_zero_point, fp_b, + K, N, groupsize, + b.stride(0), b.stride(1), + b_scale.stride(0), b_scale.stride(1), + b_zero_point.stride(0), b_zero_point.stride(1), + fp_b.stride(0), fp_b.stride(1) + ) + return fp_b + +def quantize_int4(x: torch.Tensor, groupsize: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + W = x.t().contiguous() + K_raw, N = W.shape + assert K_raw % groupsize == 0, "K must be divisible by groupsize" + groups = K_raw // groupsize + W = W.view(groups, groupsize, N) + wmin = W.min(dim=1, keepdim=True)[0] + wmax = W.max(dim=1, keepdim=True)[0] + scale = (wmax - wmin) / 15 + zero = -wmin / scale + zero = torch.round(zero).clamp(0, 15) + + qweight_t = torch.clamp(torch.round(W / scale + zero), 0, 15).to(torch.int8) + + packed = torch.zeros((groups * groupsize) // 8, N, dtype=torch.int32, device=x.device) + for col in range(N): + w_col = qweight_t[:, :, col].flatten() + for idx in range(0, w_col.size(0), 8): + vals = w_col[idx:idx+8] + val = 0 + for v in vals: + val = (val << 4) | (v.int() & 0xF) + packed[idx//8, col] = val + qweight = packed.t().contiguous() + + scale = scale.squeeze(1).transpose(0, 1).contiguous() + zero = zero.squeeze(1).transpose(0, 1).contiguous() + + qzeros = torch.empty_like(zero, dtype=torch.int32) + for col in range(N): + for row in range(groups): + val = zero[row, col].int() & 0xF + qzeros[row, col] = val + qzeros = qzeros.transpose(0, 1).contiguous().view(groups, N // 8, 8).transpose(1,2) + zsplit = torch.zeros((groups, N//8), dtype=torch.int32, device=x.device) + idx = 0 + for row in range(groups): + for col in range(N//8): + acc = 0 + for k in range(8): + acc |= (zero[row, col*8 + k].int() & 0xF) << (k * 4) + zsplit[row, col] = acc + qzeros = zsplit.transpose(0,1).contiguous() + scale = scale.transpose(0,1).contiguous() + + return qweight, scale, qzeros + +def unpack_int4(packed_weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, groupsize: int) -> torch.Tensor: + K_pack, N = packed_weights.shape + K = K_pack * 8 + device = packed_weights.device + dtype = scales.dtype + + unpacked = torch.zeros((K, N), dtype=dtype, device=device) + for i in range(K_pack): + chunk = packed_weights[i] + for j in range(8): + bits = (chunk >> (j * 4)) & 0xF + unpacked[i * 8 + j] = bits.float() + + num_groups = K // groupsize + scale_expanded = scales.t().repeat_interleave(groupsize, dim=0) + zero_expanded = zeros.t().repeat_interleave(groupsize, dim=0) + + result = (unpacked - zero_expanded) * scale_expanded + return result.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stdout new file mode 100644 index 0000000..41c20e8 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_816192.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_816192 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py b/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py new file mode 100644 index 0000000..9ead0cf --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py @@ -0,0 +1,215 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel(A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1, + GROUP_SIZE_M: tl.constexpr = 8): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + k_idx = tl.arange(0, BLOCK_SIZE_K) + group_id_k = k_idx // groupsize + a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + scales_ptrs = scales + group_id_k[None, :] * stride_scales_g + offs_n[None, :] * stride_scales_n + zeros_ptrs = zeros + group_id_k[None, :] * stride_zeros_g + offs_n[None, :] * stride_zeros_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_idx = pos_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + group_id_k = pos_k // groupsize + scales_ptrs = scales + group_id_k[None, :] * stride_scales_g + offs_n[None, :] * stride_scales_n + zeros_ptrs = zeros + group_id_k[None, :] * stride_zeros_g + offs_n[None, :] * stride_zeros_n + + mask_k = pos_k < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + b_idx = pos_k // 8 + b = tl.load(B + b_idx[:, None] * stride_bk + offs_n[None, :] * stride_bn, mask=mask_k[:, None], other=0) + + scales = tl.load(scales_ptrs, mask=mask_k[None, :], other=0.0) + zeros = tl.load(zeros_ptrs, mask=mask_k[None, :], other=0.0) + + vec = tl.arange(0, 8) + shift = (pos_k % 8) * 4 + weights = (b >> shift[:, None]) & 0xF + + b_f = (weights - zeros) * scales + + accumulator += tl.dot(a, b_f) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + c = accumulator + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cm = offs_cm < M + mask_cn = offs_cn < N + c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_cm[:, None] & mask_cn[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + +configs = [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), +] + +@triton.autotune(configs=configs, key=["M", "N", "K"], use_cuda_graph=False) +@triton.jit +def matmul_dequantize_int4_s2( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1 +): + matmul_kernel( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + SPLIT_K=SPLIT_K + ) + +def quantize_int4(x: torch.Tensor, groupsize: int = 32) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.t() + W = x + M, N = W.shape[0], W.shape[1] + W = W.reshape((M, N)) + + groups = M // groupsize + + W = W.reshape((groups, -1, N)) + Wmin = W.min(dim=1, keepdim=True)[0] + Wmax = W.max(dim=1, keepdim=True)[0] + + scale = (Wmax - Wmin) / 15 + zero = -Wmin / scale + + scale = scale.expand(groups, groupsize, N).reshape(M, -1) + zero = zero.expand(groups, groupsize, N).reshape(M, -1) + + int_weight = torch.clamp(torch.round(W / scale.unsqueeze(1) + zero.unsqueeze(1)), 0, 15) + + intweight = int_weight.reshape(M, N) + + packed_weight = torch.zeros((W.shape[0], W.shape[2] // 2), dtype=torch.int32, device=x.device) + for col in range(0, intweight.shape[1], 8): + v = intweight[:, col:col+8].contiguous().view(torch.int8) + v = v & 0xF # ensure 4bit + v1 = v[:, ::2] & 0xF + v2 = (v[:, 1::2] & 0xF) << 4 + packed_weight[:, col//2] = (v1 | v2).view(torch.int32).squeeze() + + packed_weight = packed_weight.t() + + scale = scale[::groupsize, :].t() + zero = zero[::groupsize, :].t() + + return packed_weight, scale, zero + +def unpack_int4( + packed_weights: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + groupsize: int +) -> torch.Tensor: + N, M = packed_weights.shape + unpacked = torch.zeros((N, M * 8), dtype=torch.float16, device=packed_weights.device) + + for i in range(M): + chunk = packed_weights[:, i] + for j in range(8): + lower = (chunk >> (j * 4)) & 0xF + unpacked[:, i * 8 + j] = lower.float() + + scales = scales.repeat_interleave(groupsize, dim=0) + zeros = zeros.repeat_interleave(groupsize, dim=0) + + unpacked = (unpacked - zeros[:N, :M*8]) * scales[:N, :M*8] + return unpacked.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stdout new file mode 100644 index 0000000..a782474 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_838410.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_838410 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_865534.py b/src/temp/gen/int4_matmul.py_gen_triton_code_865534.py new file mode 100644 index 0000000..5e87075 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_865534.py @@ -0,0 +1,313 @@ + +import torch +import triton +import triton.language as tl + +configs_matmul = [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), +] + +@triton.autotune( + configs=configs_matmul, + key=["M", "N", "K"], + use_cuda_graph=False +) +@triton.jit +def matmul_kernel( + A, B, C, scales, zeros, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g0, stride_zeros_n, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr = 1, + GROUP_SIZE_M: tl.constexpr = 8 +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = B + ((offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_pos = k * BLOCK_SIZE_K * SPLIT_K + offs_k + g_idx = (k_pos) // groupsize + + mask_k = k_pos < K + a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + offset_b = (k_pos[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + b_chunk = tl.load(B + offset_b, mask=mask_k[:, None], other=0) + + scale_offset = g_idx[:, None] * stride_scales_g + offs_n[None, :] * stride_scales_n + scale_val = tl.load(scales + scale_offset, mask=mask_k[:, None], other=0.0) + + zp_val = tl.load(zeros + g_idx[:, None] * stride_zeros_g0 + (offs_n // 8)[None, :] * stride_zeros_n, mask=mask_k[:, None], other=0.0) + shift_n = (offs_n % 8)[None, :] * 4 + inv_zp = ((zp_val >> shift_n) & 0xF) * scale_val + + shift_k = (k_pos % 8)[:, None] * 4 + w_int = (b_chunk >> shift_k) & 0xF + w_fp = (w_int * scale_val - inv_zp) + + accumulator += tl.dot(a, w_fp) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cm = offs_cm < M + mask_cn = offs_cn < N + c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_cm[:, None] & mask_cn[None, :] + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask) + else: + tl.atomic_add(c_ptrs, c, mask=mask) + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: + assert x.is_contiguous(), "A must be contiguous" + assert qweight.is_contiguous(), "B must be contiguous" + M, K = x.shape + Kw, N = qweight.shape + K_expected = Kw * 8 + assert K == K_expected, f"Expected K = {K_expected}, got {K}" + if output is None: + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + else: + output.fill_(0) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + max(META.get('SPLIT_K', 1), 1), + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return output + +configs_dequant = [ + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64}, num_stages=2, num_warps=4), +] + +@triton.autotune( + configs=configs_dequant, + key=["K", "N"], + use_cuda_graph=False +) +@triton.jit +def dequantize_kernel( + qw_ptr, sc_ptr, zp_ptr, fpw_ptr, + K, N, groupsize, + stride_qk, stride_qn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_fk, stride_fn, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + k_blk = tl.program_id(0) + n_blk = tl.program_id(1) + + offs_k = k_blk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = n_blk * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_k = offs_k[:, None] < K + mask_n = offs_n[None, :] < N + mask = mask_k & mask_n + + grp = offs_k[:, None] // groupsize + + qw_offs = (offs_k[:, None] // 8) * stride_qk + offs_n[None, :] * stride_qn + qw_local = tl.load(qw_ptr + qw_offs, mask=mask, other=0) + + sc_offs = grp * stride_scg + offs_n[None, :] * stride_scn + sc_local = tl.load(sc_ptr + sc_offs, mask=mask, other=0.0) + + zp_offs = grp * stride_zpg + (offs_n // 8)[None, :] * stride_zpn + zp_quad = tl.load(zp_ptr + zp_offs, mask=mask, other=0) + + shift_k = (offs_k % 8)[:, None] * 4 + shift_n = (offs_n % 8)[None, :] * 4 + + qh = (qw_local >> shift_k) & 0xF + qz = (zp_quad >> shift_n) & 0xF + + dq_val = (qh - qz) * sc_local + tl.store(fpw_ptr + offs_k[:, None] * stride_fk + offs_n[None, :] * stride_fn, dq_val, mask=mask) + +def dequantize_int4(b: torch.Tensor, b_scale: torch.Tensor, b_zero_point: torch.Tensor, device, dtype, groupsize): + K_pack, N = b.shape + K = K_pack * 8 + fp_b = torch.empty((K, N), device=device, dtype=dtype) + grid = lambda META: ( + triton.cdiv(K, META['BLOCK_SIZE_K']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + dequantize_kernel[grid]( + b, b_scale, b_zero_point, fp_b, + K, N, groupsize, + b.stride(0), b.stride(1), + b_scale.stride(0), b_scale.stride(1), + b_zero_point.stride(0), b_zero_point.stride(1), + fp_b.stride(0), fp_b.stride(1) + ) + return fp_b + +def matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, groupsize=128, out=None): + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + Kw, N = b.shape + if out is None: + out = torch.empty((M, N), device=a.device, dtype=a.dtype) + fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, groupsize) + torch.mm(a, fp_b, out=out) + fp_b = None + return out + +def quantize_int4(x: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + W = x.t().contiguous() + K_raw, N = W.shape + assert K_raw % group_size == 0, "K must be divisible by group_size" + groups = K_raw // group_size + W = W.view(groups, group_size, N) + wmin = W.min(dim=1, keepdim=True)[0] + wmax = W.max(dim=1, keepdim=True)[0] + scale = (wmax - wmin) / 15 + zero = -wmin / scale + zero = torch.round(zero).clamp(0, 15) + + qweight_t = torch.clamp(torch.round(W / scale + zero), 0, 15).to(torch.int8) + + packed = torch.zeros((groups * group_size) // 8, N, dtype=torch.int32, device=x.device) + for col in range(N): + w_col = qweight_t[:, :, col].flatten() + for idx in range(0, w_col.size(0), 8): + vals = w_col[idx:idx+8] + val = 0 + for v in vals: + val = (val << 4) | (v.int() & 0xF) + packed[idx//8, col] = val + qweight = packed.t().contiguous() + + scale = scale.squeeze(1).transpose(0, 1).contiguous() + zero = zero.squeeze(1).transpose(0, 1).contiguous() + + qzeros = torch.empty_like(zero, dtype=torch.int32) + for col in range(N): + for row in range(groups): + val = zero[row, col].int() & 0xF + qzeros[row, col] = val + qzeros = qzeros.transpose(0, 1).contiguous().view(groups, N // 8, 8).transpose(1,2) + zsplit = torch.zeros((groups, N//8), dtype=torch.int32, device=x.device) + idx = 0 + for row in range(groups): + for col in range(N//8): + acc = 0 + for k in range(8): + acc |= (zero[row, col*8 + k].int() & 0xF) << (k * 4) + zsplit[row, col] = acc + qzeros = zsplit.transpose(0,1).contiguous() + scale = scale.transpose(0,1).contiguous() + + return qweight, scale, qzeros + +def unpack_int4(packed_weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int) -> torch.Tensor: + K_pack, N = packed_weights.shape + K = K_pack * 8 + device = packed_weights.device + dtype = scales.dtype + + unpacked = torch.zeros((K, N), dtype=dtype, device=device) + for col in range(N): + chunk = packed_weights[:, col] + for i in range(K_pack): + val = chunk[i].int() + for j in range(8): + bits = (val >> (j * 4)) & 0xF + unpacked[i * 8 + j, col] = float(bits) + + num_groups = K // group_size + scale_expanded = scales.t().repeat_interleave(group_size, dim=0) + zero_expanded = zeros.t().repeat_interleave(group_size, dim=0) + + result = (unpacked - zero_expanded) * scale_expanded + return result.t() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py b/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py new file mode 100644 index 0000000..15b6e37 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py @@ -0,0 +1,231 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + M, N, K, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, + IS_EVEN_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m, M - first_pid_m * BLOCK_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + offs_am = offs_m % M + offs_bn = offs_n % N + offs_bk = offs_k % K + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_bk[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_bk[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + + scales_ptrs = scales_ptr + ((offs_bn[None, :] // groupsize) * stride_scales_g + offs_bn[None, :] * stride_scales_n) + zeros_ptrs = zeros_ptr + ((offs_bn[None, :] // groupsize) * stride_zeros_g + offs_bn[None, :] * stride_zeros_n) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + mask_k = IS_EVEN_K or (offs_bk[None, :] < K) + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + b_i4 = tl.load(b_ptrs, mask=mask_k, other=0) + scales = tl.load(scales_ptrs) + zeros = tl.load(zeros_ptrs) + + b_i4 = (b_i4 >> ((offs_bk[:, None] % 8) * 4)) & 0xF + b = b_i4.to(tl.float32) * scales - zeros + acc += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask) + else: + tl.atomic_add(c_ptrs, acc, mask=mask) + + +configs = [ + triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=2, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=2, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=2, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=2, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 2}, num_stages=2, + num_warps=4), +] + + +@triton.autotune(configs=configs, key=['M', 'N', 'K']) +@triton.jit +def matmul_dequantize_int4_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + M, N, K, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + matmul_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + M, N, K, + groupsize, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + SPLIT_K=SPLIT_K, + IS_EVEN_K=(K % (BLOCK_SIZE_K * SPLIT_K) == 0)) + + +def matmul_dequantize_int4_s2(a: torch.Tensor, b: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + groupsize: int): + assert a.dtype == torch.float16 or a.dtype == torch.float32 + assert b.dtype == torch.int8 + assert scales.dtype == torch.float16 or scales.dtype == torch.float32 + assert zeros.dtype == torch.float16 or zeros.dtype == torch.float32 + + M, K = a.shape + K_, N = b.shape + assert K * 8 // 4 == K_, "Weight shape mismatch (K in int4)" + assert scales.shape == zeros.shape == (N, K // groupsize) + + c = torch.empty((M, N), dtype=a.dtype, device=a.device) + + grid_lambda = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) + + matmul_dequantize_int4_kernel[grid_lambda]( + a, b, c, scales, zeros, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + scales.stride(0) if scales.dim() > 1 else 0, scales.stride(1), + zeros.stride(0) if zeros.dim() > 1 else 0, zeros.stride(1), + M, N, K, + groupsize, + ) + + return c + + +def quantize_int4(w: torch.Tensor, groupsize: int = 128): + assert w.dim() == 2, "Weight tensor must be 2-D" + w = w.to(torch.float32) + oc, ic = w.shape + assert ic % groupsize == 0 + + w = w.reshape(oc, ic // groupsize, groupsize) + wmax = w.amax(dim=2, keepdim=True) + wmin = w.amin(dim=2, keepdim=True) + + scale = (wmax - wmin) / 15 + zero = -wmin / scale + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + + int_w = torch.round((w - wmin) / scale.unsqueeze(-1)).clamp(0, 15) + int_w = int_w.to(torch.int8) + + out = torch.zeros(oc, ic // 8, dtype=torch.int32, device=w.device) + for i in range(0, ic, 8): + out_i = 0 + for j in range(8): + out_i |= (int_w[:, i//groupsize, i%groupsize + j] << (j * 4)) + out[:, i//8] = out_i + + scale = scale.to(torch.float16) + zero = zero.to(torch.float16) + + return out.reshape(oc, -1), scale, zero + + +def unpack_int4(w: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, groupsize: int = 128): + assert w.dim() == 2 + oc, ic_int = w.shape + ic = ic_int * 8 + assert ic % groupsize == 0 + + w_bits = torch.empty(oc, ic, dtype=torch.int8, device=w.device) + for i in range(ic): + shift = (i % 8) * 4 + w_bits[:, i] = (w[:, i // 8] >> shift) & 0xF + + scale = scale.unsqueeze(-1).expand_as(w_bits) + zero = zero.unsqueeze(-1).expand_as(w_bits) + w_bits = w_bits.to(scale.dtype) + + w_unpacked = scale * w_bits - zero + return w_unpacked.to(torch.float16) + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stdout new file mode 100644 index 0000000..4e5c3ed --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_886215.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_886215 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py b/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py new file mode 100644 index 0000000..527b79f --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py @@ -0,0 +1,243 @@ + +import torch +import triton +import triton.language as tl + + +# ------------------------------------------------------------------ +# INT4 matmul kernel (tiling: M tile, N tile, K tile, SPLIT-K capable) +# ------------------------------------------------------------------ +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + x_ptr, # (M, K) fp16/bf16 + qw_ptr, # (K//8, N) packed INT4 in INT32 (8x 4b per int32) + sc_ptr, # (num_groups, N) fp16/bf16 + zp_ptr, # (num_groups, N) fp16/bf16 + o_ptr, # (M, N) fp16/bf16 + M, N, K, + stride_xm, stride_xk, + stride_qwk, stride_qwn, + stride_scg, stride_scn, + stride_zpg, stride_zpn, + stride_om, stride_on, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr = 1, # Use for split-k reduction +): + # Program & tile coordinates + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) # for SPLIT_K + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # Global tile spans + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # Adjust for SPLIT_K + k_max = (pid_k + 1) * BLOCK_SIZE_K + if k_max > K: + k_max = K + # Clamp inside kernel + mask_k = offs_k < K + mask_m = offs_m < M + mask_n = offs_n < N + + # Pointers in batch offset + x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk) + qw_ptrs = qw_ptr + ((offs_k[None, :] // 8) * stride_qwk + offs_n[:, None] * stride_qwn) + sc_ptrs = sc_ptr + ((offs_n[:, None] // group_size) * stride_scg + offs_n[:, None] * stride_scn) + zp_ptrs = zp_ptr + ((offs_n[:, None] // group_size) * stride_zpg + offs_n[:, None] * stride_zpn) + + accum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + a = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + qwp = tl.load(qw_ptrs, mask=mask_k[None, :] & mask_n[:, None], other=0) + scales = tl.load(sc_ptrs, mask=mask_n[:, None], other=0.0) + zps = tl.load(zp_ptrs, mask=mask_n[:, None], other=0.0) + + # unpack 8x INT4 per int32 + local_offs = (offs_k % 8) * 4 # (BLOCK_SIZE_K,) -> 0,4,8,...,28 + q4_mask = 0xF # 4 bits + qw_int4 = (qwp >> local_offs) & q4_mask + qw_fp = (qw_int4.to(tl.float32) - zps) * scales + + # accumulate matmul + accum += tl.dot(a, qw_fp) + + # advance + x_ptrs += BLOCK_SIZE_K * stride_xk + qw_ptrs += (BLOCK_SIZE_K // 8) * stride_qwk + + if SPLIT_K > 1: + o_blk_ptrs = o_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + mask_mn = mask_m[:, None] & mask_n[None, :] + tl.atomic_add(o_blk_ptrs, accum, mask=mask_mn) + else: + o_blk_ptrs = o_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + mask_mn = mask_m[:, None] & mask_n[None, :] + tl.store(o_blk_ptrs, accum, mask=mask_mn) + + +# ------------------------------------------------------------------ +# High-level wrapper launching quantized matmul kernel +# ------------------------------------------------------------------ +def matmul_dequantize_int4_s2(x: torch.Tensor, + qweight_int32: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + split_k: int = 1): + """ + Launch INT4 GEMM: x @ dequantize(qw) + + x : (M, K) fp16/bf16 + qw : (K//8, N) int32 packed + scale/zp : (num_groups, N) fp16/bf16 + Returns : (M, N) fp16/bf16 + """ + assert x.dim() == 2 + assert qweight_int32.dim() == 2 + assert scale.dim() == 2 + assert zero_point.dim() == 2 + M, K = x.shape + K8, N = qweight_int32.shape + assert K == K8 * 8 + + group_size = K // scale.shape[0] + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), split_k) + + matmul_kernel[grid]( + x, qweight_int32, scale, zero_point, output, + M, N, K, + x.stride(0), x.stride(1), + qweight_int32.stride(0), qweight_int32.stride(1), + scale.stride(0), scale.stride(1), + zero_point.stride(0), zero_point.stride(1), + output.stride(0), output.stride(1), + group_size, + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32, + SPLIT_K=split_k, + ) + if split_k > 1: + # Reduction here, currently left to caller + pass + return output + + +# ------------------------------------------------------------------ +# Quantization utilities +# ------------------------------------------------------------------ +def quantize_int4(x: torch.Tensor, group_size: int = 128): + """ + Convert fp tensor to INT4 packed (8 int4 per int32), return (qint32, scale, zero_point) + """ + *rest, N = x.shape + x = x.reshape(-1, N).contiguous() + + pad = (group_size - (N % group_size)) % group_size + if pad: + x = torch.nn.functional.pad(x, (0, pad)) + x = x.view(-1, group_size) + + # per-group min/max -> scale, zero-pt + x_min = x.min(dim=-1, keepdim=True)[0] + x_max = x.max(dim=-1, keepdim=True)[0] + x_max = torch.max(x_max, x_min + 1e-7) # ensure non-degenerate + scale = (x_max - x_min) / 15.0 + zero_point = (-x_min / scale) + + # quantize & clamp + xq = torch.round(x / scale + zero_point) + xq = xq.clamp(0, 15).to(torch.int32) + + # reshape back to packed layout + xq = xq.view(-1) + # pack 8 int4 into int32 (order: lowest 4 bits first) + num_i32 = xq.numel() // 8 + qw = torch.zeros(num_i32, dtype=torch.int32, device=x.device) + for shift in range(8): + qw |= (xq[shift::8] << (shift * 4)) + + # reshape back to original mapping + qw = qw.view(*rest, -1) + scale = scale.view(*rest, -1) + zero_point = zero_point.view(*rest, -1) + return qw, scale, zero_point + + +def unpack_int4(q_packed: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor): + """ + (for testing) map packed INT4 back to fp32 tensor + """ + shape = q_packed.shape[:-1] + (-1,) # [-1] already N//2 for 4-bit + qw = q_packed.view(-1) + out = torch.zeros(qw.numel() * 8, dtype=torch.float32, device=qw.device) + for shift in range(8): + unpacked = (qw >> (shift * 4)) & 0xF + out[shift::8] = unpacked.to(torch.float32) + out = (out - zero_point.view(-1)) * scale.view(-1) + out = out.view(shape) + return out + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stdout new file mode 100644 index 0000000..87179ca --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_891149.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_891149 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py b/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py new file mode 100644 index 0000000..d933b46 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py @@ -0,0 +1,232 @@ + +import torch +import triton +import triton.language as tl + + +############################################## +# Triton kernel(s) +############################################## + +@triton.autotune( + configs=[ + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + offs_k_step = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_off = offs_m[:, None] * stride_am + offs_k_step[None, :] * stride_ak + b_off = (offs_k_step[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + bs_off = (offs_k_step[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn + bzp_off = (offs_k_step[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k_idx in range(0, num_pid_k): + kk = k_idx * BLOCK_SIZE_K * SPLIT_K + mask_k = (kk + offs_k_step[None, :]) < K + + a = tl.load(a_ptr + a_off, mask=offs_m[:, None] < M, other=0.0) + b = tl.load(b_ptr + b_off, mask=mask_k, other=0) + + bs = tl.load(bs_ptr + bs_off, mask=mask_k, other=1.0) + bzp = tl.load(bzp_ptr + bzp_off, mask=mask_k, other=0) + + shift = (kk + offs_k_step)[None, :] % 8 * 4 + shift_zp = offs_n[None, :] % 8 * 4 + + nib = (b >> shift) & 0xF + z = (bzp >> shift_zp) & 0xF + + b_deq = (nib - z) * bs + acc += tl.dot(a, b_deq.to(a.dtype)) + + a_off += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_off += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + bs_off += (BLOCK_SIZE_K * SPLIT_K // group_size) * stride_bsk + bzp_off += (BLOCK_SIZE_K * SPLIT_K // group_size) * stride_bzpk + + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + mask_c = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask_c) + else: + tl.atomic_add(c_ptrs, acc, mask=mask_c) + + +############################################## +# Python wrappers +############################################## + +def matmul_dequantize_int4_s2( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int = 128 +) -> torch.Tensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + M, K = x.shape + N = scales.shape[1] + out = torch.empty((M, N), dtype=x.dtype, device=x.device) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, out, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + out.stride(0), out.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return out + + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]: + """ + quantize weight matrix (K, N) to INT4 packed (K, N//8) plus per_group scale/zero_point. + Scale / zero_point when per_group have shape (K//group_size, N). + """ + w = weights.to(torch.float32) + K, N = w.shape + assert N % group_size == 0, "N must be divisible by group_size" + + w_ = w.view(-1, group_size) + wmin = w_.min(dim=1, keepdim=True)[0] + wmax = w_.max(dim=1, keepdim=True)[0] + scales = (wmax - wmin) / 15.0 + zeros = (-wmin / scales).round().clamp(0, 15) + + quantized = torch.round((w_ - wmin) / scales).clamp(0, 15).to(torch.uint8) + + dw = quantized.shape[1] + packed = torch.zeros( + quantized.shape[0], + (dw + 1) // 2, # two nibbles per byte + dtype=torch.int32, + device=quantized.device + ) + + idx_even = torch.arange(0, dw, 2, device=quantized.device) + idx_odd = torch.arange(1, dw, 2, device=quantized.device) + + packed_even = quantized[:, idx_even] & 0xF + packed_odd = (quantized[:, idx_odd] & 0xF) << 4 + packed[:, idx_even // 2] = packed_even + packed_odd + + packed = packed.view(K, N // 8) + scales = scales.view(K, N // group_size).contiguous() + zeros = zeros.view(K, N // group_size).contiguous() + return packed.contiguous(), scales, zeros, None + + +def unpack_int4(weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + """ + revert packed (K, N//8) into fp tensor (K*8, N). + """ + Kq, N = weights.shape + K, _ = scales.shape + N_orig = scales.shape[1] * group_size + + unpacked = torch.empty(K, N * 8, dtype=scales.dtype, device=weights.device) + + flat = weights.view(-1) + + for b in range(8): + nib = (flat >> (b * 4)) & 0xF + unpacked[:, b::8] = nib.view(K, N) + + unpacked = unpacked.view(K, N * 8) + scales = scales.view(-1, N).repeat_interleave(group_size, dim=0) + zeros = zeros.view(-1, N).repeat_interleave(group_size, dim=0) + return (unpacked.float() - zeros) * scales + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stdout new file mode 100644 index 0000000..82f1811 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_912380.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_912380 due to Index put requires the source and destination dtypes match, got Int for the destination and Byte for the source. diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py b/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py new file mode 100644 index 0000000..d407859 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py @@ -0,0 +1,299 @@ + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K', 'NO_GROUPS'], +) +@triton.jit +def gptq_gemm_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, NO_GROUPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + bits = 4 + infearure_per_bits = 8 + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + + scales_ptrs = scales_ptr + offs_bn * stride_scales_n + zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) + + shifter = ((offs_k % infearure_per_bits) * bits)[:, None] + zeros_shift = ((offs_bn % infearure_per_bits) * bits)[None, :] + + if NO_GROUPS: + scales = tl.load(scales_ptrs) + zeros = tl.load(zeros_ptrs) + zeros_int = (zeros >> zeros_shift) & 0xF + zeros = zeros_int * scales + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + a = tl.load(a_ptrs, mask=offs_am[:, None] < M, other=0.0) + b_i32 = tl.load(b_ptrs) + b_u8 = (b_i32 >> shifter) & 0xF + b_fp = b_u8.to(tl.float32) + + if not NO_GROUPS: + g_id = k // (groupsize // BLOCK_SIZE_K) + ptr_s = scales_ptrs + g_id * stride_scales_g + ptr_z = zeros_ptrs + g_id * stride_zeros_g + scales = tl.load(ptr_s) + zeros = tl.load(ptr_z) + zeros_int = (zeros >> zeros_shift) & 0xF + zeros = zeros_int * scales + + b = b_fp * scales - zeros + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator, mask=mask) + + +def matmul_dequantize_int4_gptq(x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size) -> torch.FloatTensor: + assert x.dim() == 2 and qweight.dim() == 2 + assert x.shape[-1] == (qweight.shape[0] * 8), "x inner dim mismatch" + assert x.is_contiguous(), "x must be contiguous" + + M, K = x.shape + N = qweight.shape[1] + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + def grid(META): return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + gptq_gemm_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0) if scales.dim() > 1 else 0, + scales.stride(1), + qzeros.stride(0) if qzeros.dim() > 1 else 0, + qzeros.stride(1), + group_size, group_size == K, + ) + return output + + +configs_s2 = [ + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128,'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +] + + +@triton.autotune(configs=configs_s2, key=['M', 'N', 'K']) +@triton.jit +def matmul_dequantize_int4_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + M, N, K, + groupsize, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + GROUP_SIZE_M_local = 8 + num_pid_in_group = GROUP_SIZE_M_local * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M_local + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M_local) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = (pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_cur = offs_k[None, :] + k * BLOCK_SIZE_K * SPLIT_K + valid_k = k_cur < K + + scale_ptrs = scales_ptr + (k_cur // groupsize) * stride_scales_g + offs_n[None, :] * stride_scales_n + zeros_ptrs = zeros_ptr + (k_cur // groupsize) * stride_zeros_g + (offs_n[None, :] // 8) * stride_zeros_n + + a = tl.load(a_ptrs, mask=valid_k, other=0.0) + b_i32 = tl.load(b_ptrs, mask=valid_k, other=0) + + scales = tl.load(scale_ptrs, mask=valid_k, other=0.0) + zeros = tl.load(zeros_ptrs, mask=valid_k, other=0) + + b_shift = (k_cur % 8) * 4 + zeros_shift = ((offs_n[None, :] % 8) * 4) + b_i4 = (b_i32 >> b_shift) & 0xF + zp_i4 = (zeros >> zeros_shift) & 0xF + b_fp = (b_i4 - zp_i4).to(tl.float16) * scales.to(tl.float16) + + acc += tl.dot(a.to(tl.float16), b_fp) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask) + else: + tl.atomic_add(c_ptrs, acc, mask=mask) + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + groupsize: int = 128) -> torch.FloatTensor: + assert x.is_contiguous() and qweight.is_contiguous() + M, K = x.shape + K_, N = qweight.shape + assert K * 8 // 4 == K_, "K dim mismatch" + assert scales.shape == zeros.shape == (N, K // groupsize) + + c = torch.empty((M, N), device=x.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) + matmul_dequantize_int4_kernel[grid]( + x, qweight, c, scales, qzeros, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + M, N, K, + groupsize, + ) + return c + + +def quantize_int4(w: torch.Tensor, groupsize: int = 128): + assert w.dim() == 2 + w = w.float() + oc, ic = w.shape + assert ic % groupsize == 0 + + w = w.reshape(oc, ic // groupsize, groupsize) + wmax = w.amax(dim=2, keepdim=True) + wmin = w.amin(dim=2, keepdim=True) + scale = (wmax - wmin) / 15 + zero = (-wmin / scale).round().clamp(0, 15) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + + int_w = torch.round((w - wmin) / scale.unsqueeze(-1)).clamp(0, 15).to(torch.int8) + + out = torch.zeros(oc, ic // 8, dtype=torch.int32, device=w.device) + for i in range(0, ic, 8): + packed = 0 + for j in range(8): + packed |= int_w[:, i // groupsize, i % groupsize + j] << (j * 4) + out[:, i // 8] = packed + + return out.reshape(oc, -1), scale.half(), zero.half() + + +def unpack_int4(w_packed: torch.IntTensor, scale: torch.Tensor, zero: torch.Tensor, groupsize: int = 128): + oc, ic_int = w_packed.shape + ic = ic_int * 8 + w_bits = torch.empty(oc, ic, dtype=torch.float32, device=w_packed.device) + for i in range(ic): + shift = (i % 8) * 4 + w_bits[:, i] = torch.bitwise_and(torch.bitwise_right_shift(w_packed[:, i // 8], shift), 0xF).float() + + scale = scale.unsqueeze(-1).expand_as(w_bits) + zero = zero.unsqueeze(-1).expand_as(w_bits) + return (scale * w_bits - zero).half() + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stdout new file mode 100644 index 0000000..9629ab6 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_925632.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_925632 due to quantize_int4() got an unexpected keyword argument 'group_size' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py new file mode 100644 index 0000000..9c38158 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py @@ -0,0 +1,277 @@ + +import torch +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + k_offset = k * BLOCK_SIZE_K * SPLIT_K + mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k_offset < K) + mask_b = (offs_k[:, None] + k_offset < K) & (offs_bn[None, :] < N) + + a = tl.load(a_ptrs + k_offset * stride_ak, mask=mask_a, other=0.0) + b = tl.load(b_ptrs + (k_offset // 8) * stride_bk, mask=mask_b, other=0.0) + + group_idx = (offs_k[:, None] + k_offset) // group_size + bs_ptrs = bs_ptr + group_idx * stride_bsk + offs_bn[None, :] * stride_bsn + bzp_ptrs = bzp_ptr + group_idx * stride_bzpk + (offs_bn[None, :] // 8) * stride_bzpn + + bs = tl.load(bs_ptrs, mask=mask_b, other=0.0) + bzp = tl.load(bzp_ptrs, mask=mask_b, other=0.0) + + b_shift = ((offs_k[:, None] + k_offset) % 8) * 4 + bzp_shift = (offs_bn[None, :] % 8) * 4 + + int4_b = (b >> b_shift) & 0xF + int4_bzp = (bzp >> bzp_shift) & 0xF + + fp_b = ((int4_b - int4_bzp) * bs).to(tl.float16) + accumulator += tl.dot(a.to(tl.float16), fp_b) + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K > 1: + tl.atomic_add(c_ptrs, c, mask=mask_c) + else: + tl.store(c_ptrs, c, mask=mask_c) + +def matmul_dequantize_int4_s2( + x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128 +) -> torch.Tensor: + assert x.is_contiguous(), "Input x must be contiguous" + assert qweight.is_contiguous(), "qweight must be contiguous" + assert scales.is_contiguous(), "scales must be contiguous" + assert zeros.is_contiguous(), "zeros must be contiguous" + + M, K = x.shape + N = scales.shape[1] + + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, output, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size, + GROUP_SIZE_M=8, + SPLIT_K=1, + ) + return output + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 128}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE': 64}, num_stages=2, num_warps=4), + ], + key=['num_rows', 'num_cols'], +) +@triton.jit +def quantize_int4_kernel( + src_ptr, dst_ptr, scales_ptr, zeros_ptr, + num_rows, num_cols, + stride_sr, stride_sc, + stride_dr, stride_dc, + stride_scale, + BLOCK_SIZE: tl.constexpr, GROUP_SIZE: tl.constexpr, +): + row = tl.program_id(0) + group = tl.program_id(1) + cols_per_int32 = 8 + + group_start = group * GROUP_SIZE + group_end = tl.minimum(group_start + GROUP_SIZE, num_cols) + num_ints = (GROUP_SIZE + cols_per_int32 - 1) // cols_per_int32 + + col_offsets = group_start + tl.arange(0, BLOCK_SIZE) + + max_val = tl.full([BLOCK_SIZE], -float('inf'), dtype=tl.float32) + min_val = tl.full([BLOCK_SIZE], float('inf'), dtype=tl.float32) + + for offset in range(0, GROUP_SIZE, BLOCK_SIZE): + mask = (col_offsets + offset) < group_end + src_offs = src_ptr + row * stride_sr + (col_offsets + offset) * stride_sc + vals = tl.load(src_offs, mask=mask, other=0.0) + max_val = tl.where(mask, tl.maximum(max_val, vals), max_val) + min_val = tl.where(mask, tl.minimum(min_val, vals), min_val) + + max_val = tl.max(max_val) + min_val = tl.min(min_val) + + scale = (max_val - min_val) / 15.0 + zero = -min_val / scale + + scale_idx = row * (num_cols // GROUP_SIZE) + group + tl.store(scales_ptr + scale_idx, scale.to(tl.float16)) + tl.store(zeros_ptr + scale_idx, zero.to(tl.float16)) + + for offset in range(0, GROUP_SIZE, BLOCK_SIZE): + mask = (col_offsets + offset) < group_end + src_offs = src_ptr + row * stride_sr + (col_offsets + offset) * stride_sc + vals = tl.load(src_offs, mask=mask, other=0.0) + + q = tl.clamp((vals / scale + zero).to(tl.int32), 0, 15) + + int32_ptrs = dst_ptr + row * stride_dr + ((group_start + offset) // cols_per_int32) * stride_dc + + for i_offset in range(0, BLOCK_SIZE, cols_per_int32): + i = offset + i_offset + if i < GROUP_SIZE: + packed = tl.full([1], 0, dtype=tl.int32) + for ch in range(cols_per_int32): + idx = i_offset + ch + val = q[idx] if (group_start + i + ch) < num_cols else tl.full([], 0, dtype=tl.int32) + packed = tl.bitwise_or(packed, tl.left_shift(val & 0xF, ch * 4)) + addr = int32_ptrs + (i // cols_per_int32) * stride_dc + tl.store(addr, packed) + +def quantize_int4(weight: torch.Tensor, group_size: int = 128) -> tuple: + assert weight.dim() == 2, "weight must be 2D" + num_rows, num_cols = weight.shape + group_size = min(group_size, num_cols) + assert num_cols % group_size == 0 + + packed = torch.empty( + (num_rows, num_cols // 8), + dtype=torch.int32, + device=weight.device + ) + scales = torch.empty( + (num_rows, num_cols // group_size), + dtype=torch.float16, + device=weight.device + ) + zeros = torch.empty_like(scales) + + def grid(): + return ( + num_rows, + num_cols // group_size, + ) + + quantize_int4_kernel[grid]( + weight, packed, scales, zeros, + num_rows, num_cols, + weight.stride(0), weight.stride(1), + packed.stride(0), packed.stride(1), + scales.stride(0), + GROUP_SIZE=group_size, + ) + return packed, scales, zeros + +def unpack_int4(qw_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + assert qw_packed.dim() == 2 and scales.dim() == 2 and zeros.dim() == 2 + num_rows = qw_packed.size(0) + num_cols = scales.size(1) * group_size + weight = torch.empty((num_rows, num_cols), dtype=torch.float16, device=qw_packed.device) + + for row in range(num_rows): + for group in range(scales.size(1)): + scale = scales[row, group].item() + zero = zeros[row, group].item() + start_col = group * group_size + end_col = start_col + group_size + for col in range(start_col, end_col, 8): + if (col // 8) >= qw_packed.size(1): + continue + packed = qw_packed[row, col // 8].item() + for k in range(8): + val = (packed >> (4 * k)) & 0xF + rescaled = val * scale + zero + idx = col + k + if idx < num_cols: + weight[row, idx] = rescaled + return weight + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stdout new file mode 100644 index 0000000..d09f3ed --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_927195.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_927195 due to module 'triton.language' has no attribute 'bitwise_or' diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py b/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py new file mode 100644 index 0000000..5d9a59e --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py @@ -0,0 +1,250 @@ + +import torch +import triton +import triton.language as tl + +# int4 de-quant helpers +@triton.jit +def _dequantize_int4_unpack(xi32, mask0=0x0f, mask1=0xf0): + xi0 = (xi32 & mask0).to(tl.int8) + xi1 = ((xi32 & mask1) >> 4).to(tl.int8) + return xi0, xi1 + + +@triton.jit +def _dequantize_int4_kernel(ptr, scales_ptr, zeros_ptr, M, N, + stride_q, stride_s, stride_z, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + q_offsets = (rm[:, None] * stride_q + (rn // 8)[None, :]) + scales_offsets = (rm[:, None] * stride_s + (rn // 8)[None, :]) + zeros_offsets = (rm[:, None] * stride_z + (rn // 8)[None, :]) + + mask_m = rm < M + mask_n = rn < N + mask = mask_m[:, None] & mask_n[None, :] + + packed = tl.load(ptr + q_offsets, mask=mask, other=0) + s = tl.load(scales_ptr + scales_offsets, mask=mask, other=1.0) + z = tl.load(zeros_ptr + zeros_offsets, mask=mask, other=0.0) + + offsets_0 = (rn % 8) * 4 + offsets_1 = offsets_0 + 4 + i0, i1 = _dequantize_int4_unpack(packed) + v0 = (i0.to(tl.float32) - z) * s + v1 = (i1.to(tl.float32) - z) * s + + return v0, v1 + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_eval_k, stride_eval_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_k = tl.program_id(2) + + n_blocks_m = tl.cdiv(M, BLOCK_SIZE_M) + n_blocks_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + group_id = 0 + first_pid_m = 0 + else: + group_id = pid_m // GROUP_SIZE_M + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(n_blocks_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_m % group_size_m) + + if SPLIT_K > 1: + local_k = tl.cdiv(K, SPLIT_K) + k_offset = pid_k * local_k + else: + local_k = K + k_offset = 0 + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + scales_ptrs = scales_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // 8) * stride_eval_k) + offs_n[None, :] * stride_eval_n + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, local_k, BLOCK_SIZE_K): + if EVEN_K or (k + BLOCK_SIZE_K <= local_k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < local_k - k, other=0.0, eviction_policy="evict_last") + block_scale = tl.load(scales_ptrs, mask=offs_k[:, None] < local_k - k, other=1.0) + block_zero = tl.load(zeros_ptrs, mask=offs_k[:, None] < local_k - k, other=0.0) + + packed_b = tl.load(b_ptrs, mask=offs_k[:, None] < local_k - k, other=0) + k_idx = (offs_k[:, None] % 8) * 4 + val_low = (packed_b & 0x0F).to(tl.int8).to(tl.float32) + val_high = ((packed_b >> 4) & 0x0F).to(tl.int8).to(tl.float32) + b_low = (val_low - block_zero) * block_scale + b_high = (val_high - block_zero) * block_scale + + acc = tl.dot(a, b_low, acc) + a_shift = tl.load(a_ptrs + stride_bk * (1 if EVEN_K else 8), mask=offs_k[None, :] + 8 < local_k - k, other=0.0, eviction_policy="evict_last") + acc = tl.dot(a_shift, b_high, acc) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + scales_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + zeros_ptrs += (BLOCK_SIZE_K // 8) * stride_eval_k + + if SPLIT_K == 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc.to(c_ptrs.type.element_ty), mask=c_mask) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + pid_k * M * N + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.atomic_add(c_ptrs, acc, mask=c_mask) + + +def matmul_dequantize_int4_s2(a, int4b_compressed, scales, zeros, M, N, K): + c_dtype = a.dtype + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + SPLIT_K = 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']), + SPLIT_K) + + if SPLIT_K > 1: + c = torch.empty((SPLIT_K, M, N), dtype=torch.float32, device=a.device) + else: + c = torch.empty((M, N), dtype=c_dtype, device=a.device) + + EVEN_K = K % 32 == 0 + + matmul_kernel[grid](a, int4b_compressed, c, + scales, zeros, + M, N, K, + a.stride(0), a.stride(1), + int4b_compressed.stride(0), int4b_compressed.stride(1), + c.stride(0) if c.dim() == 2 else c.stride(1), + c.stride(1) if c.dim() == 2 else c.stride(2), + scales.stride(0), scales.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + SPLIT_K=SPLIT_K, + EVEN_K=EVEN_K) + return c if SPLIT_K == 1 else c.sum(dim=0) + + +def quantize_int4(x: torch.Tensor, group_size: int = 128): + org_shape = x.shape + x = x.view(-1, group_size) + x_fp32 = x.float() + x_min = x_fp32.amin(dim=-1, keepdim=True) + x_max = x_fp32.amax(dim=-1, keepdim=True) + scale = (x_max - x_min) / 15.0 + zero = (-x_min / scale + 0.5).clamp(0, 15) + + int4 = (((x_fp32 - x_min) / scale + 0.5).clamp(0, 15)).round().to(torch.int8) + + int4 = int4.view(-1) + packed = torch.zeros(x.size(0), x.size(1) // 2, dtype=torch.int32, device=x.device) + int4_even = int4[::2].to(torch.int32) + int4_odd = int4[1::2].to(torch.int32) + packed = int4_even | (int4_odd << 4) + packed = packed.view(org_shape[0], org_shape[1] // 2) + scale = scale.squeeze(-1) + zero = zero.squeeze(-1) + return packed, scale, zero + + +def unpack_int4(packed: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor, group_size: int = 128): + M, Nhalf = packed.shape + N = Nhalf * 2 + unpacked = torch.empty(M, N, dtype=torch.float16, device=packed.device) + + packed = packed.int() + for i in range(M): + for j in range(Nhalf): + low = (packed[i, j] & 0x0F).to(torch.float32) + high = ((packed[i, j] >> 4) & 0x0F).to(torch.float32) + group_idx = j * 2 // group_size + val_low = (low - zero[i, group_idx]) * scale[i, group_idx] + val_high = (high - zero[i, group_idx]) * scale[i, group_idx] + unpacked[i, 2 * j] = val_low.to(torch.float16) + unpacked[i, 2 * j + 1] = val_high.to(torch.float16) + return unpacked + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stdout new file mode 100644 index 0000000..1df9f8f --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_93329.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_93329 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py b/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py new file mode 100644 index 0000000..1e26b96 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py @@ -0,0 +1,229 @@ + +import torch +import triton +import triton.language as tl + +# ----------------- Triton Kernel (INT4 matrix multiply) ----------------- + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, + num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16, 'SPLIT_K': 1}, + num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16, 'SPLIT_K': 2}, + num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_start = pid_k * BLOCK_SIZE_K + offs_k = offs_k_start + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + accum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + cur_k = offs_k_start + k * BLOCK_SIZE_K * SPLIT_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = cur_k[None, :] < K + mask_n = offs_n[None, :] < N + load_a = tl.load(a_ptrs, mask=mask_k & (offs_m[:, None] < M), other=0.0) + packed_b = tl.load(b_ptrs, mask=mask_k & mask_n, other=0) + packed_b = packed_b.to(tl.int32) + + group_idx = cur_k[None, :] // group_size + scale_ptr = scales_ptr + offs_n[None, :] * stride_bsn + zero_ptr = zeros_ptr + (offs_n[None, :] // 8) * stride_bzpn + scale_ptr += group_idx * stride_bsk + zero_ptr += group_idx * stride_bzpk + + scale = tl.load(scale_ptr, mask=mask_k & mask_n, other=0.0) + zero_packed = tl.load(zero_ptr, mask=mask_k & mask_n, other=0) + zero_packed = zero_packed.to(tl.int32) + + shift = (cur_k[None, :] % 8) * 4 + zp_shift = (offs_n[None, :] % 8) * 4 + + int_b = (packed_b >> shift) & 0xF + int_zp = (zero_packed >> zp_shift) & 0xF + b = ((int_b.to(tl.float32) - int_zp.to(tl.float32)) * scale) + accum += tl.dot(load_a, b) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, accum, mask=mask) + else: + tl.atomic_add(c_ptrs, accum, mask=mask) + +# ----------------- Python helpers ---------------------------------------- + +def matmul_dequantize_int4_s2( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + group_size: int = 128 +) -> torch.Tensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + + M, K = x.shape + Kw, N = qweight.shape + assert Kw == K // 2, "Packed weight shape mismatch" + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + + matmul_kernel[grid]( + x, qweight, output, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size, + ) + return output + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert weights.dim() == 2 + K, N = weights.shape + assert K % group_size == 0 + num_groups = K // group_size + + flat = weights.float().view(num_groups, group_size, N) + mn, mx = flat.aminmax(dim=1) + scale = (mx - mn) / 15.0 + scale = torch.where(scale == 0, 1.0, scale) + zero = (-mn / scale) + q = ((flat / scale[:, :, None] + zero[:, :, None] + 0.5).floor()).clamp(0, 15).to(torch.int32) + + q = q.view(K, N) + q_low = q[::2] + q_high = q[1::2] + packed = (q_low & 0xF) | ((q_high & 0xF) << 4) + + return packed, scale.squeeze(1), zero.squeeze(1) + +def unpack_int4(w: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + assert w.dtype == torch.int32 + K_qua, N = w.shape + K = K_qua * 2 + assert K % group_size == 0 + num_groups = K // group_size + + w0 = (w & 0xF).to(torch.float32) + w1 = ((w >> 4) & 0xF).to(torch.float32) + unpacked = torch.zeros(K, N, dtype=torch.float32, device=w.device) + unpacked[::2] = w0 + unpacked[1::2] = w1 + + scales = scales.view(num_groups, 1, N).repeat(1, group_size, 1).reshape(K, N) + zeros = zeros.view(num_groups, 1, N).repeat(1, group_size, 1).reshape(K, N) + return (unpacked - zeros) * scales + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stdout new file mode 100644 index 0000000..6ed9a2d --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_942564.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_942564 due to The size of tensor a (128) must match the size of tensor b (4096) at non-singleton dimension 1 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py b/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py new file mode 100644 index 0000000..6a02111 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py @@ -0,0 +1,205 @@ + +import torch +import triton +import triton.language as tl + +# ---------- Kernel ---------- + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + group_size, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr +): + pid = tl.program_id(0) + pid_k = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + BLOCK_K_S = BLOCK_SIZE_K * SPLIT_K + offs_k = pid_k * BLOCK_K_S + tl.arange(0, BLOCK_K_S) + + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K_S)): + k_actual = k * BLOCK_K_S + mask_k = k_actual + offs_k[None :] < K + a_mask = (offs_am[:, None] < M) & mask_k + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + b_int32 = tl.load(b_ptrs) + offs_k_shift = k_actual + offs_k[:, None] + group_idx = offs_k_shift // group_size + scales = tl.load(scales_ptr + group_idx * stride_scales_g + offs_bn[None, :] * stride_scales_n) + zeros = tl.load( + zeros_ptr + + group_idx * stride_zeros_g + + (offs_bn[None, :] // 8) * stride_zeros_n + ) + + shift = (offs_k_shift % 8) * 4 + zp_shift = (offs_bn[None, :] % 8) * 4 + + b_int4 = (b_int32 >> shift) & 0xF + b_zp = (zeros >> zp_shift) & 0xF + b_deq = (b_int4 - b_zp) * scales + acc += tl.dot(a.to(tl.float16), b_deq.to(tl.float16)) + + a_ptrs += BLOCK_K_S * stride_ak + b_ptrs += (BLOCK_K_S // 8) * stride_bk + + c = acc.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask_c) + else: + tl.atomic_add(c_ptrs, c, mask=mask_c) + +# ---------- Quantization/Utility ---------- + +def quantize_int4(w: torch.tensor, group_size: int = 128): + assert w.dim() == 2 + _, N = w.shape + w = w.view(-1, group_size, N) + wmin = w.amin(dim=1, keepdim=True) + wmax = w.amax(dim=1, keepdim=True) + scale = (wmax - wmin) / 15. + zero = (-wmin / scale).round().clamp(0, 15).to(torch.int32) + + wq = ((w / scale + zero).round().clamp(0, 15)).to(torch.int32).to(torch.uint8) + + packed = (wq[0::2] | (wq[1::2] << 4)).view(-1, N).contiguous() + + scale = scale.squeeze(1).contiguous() + zero = zero.squeeze(1).contiguous() + zero_packed = (zero[..., 0::2] | (zero[..., 1::2] << 4)).contiguous() + return packed, scale, zero_packed, None + +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output=None +) -> torch.FloatTensor: + assert x.is_contiguous() + M, Kin = x.shape + Kw, N = qweight.shape + K = Kw * 8 + assert Kin == K, "A inner dimension must match unpacked weight rows" + + if output is None: + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + else: + output.fill_(0.0) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + x, qweight, output, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size + ) + return output + +def unpack_int4(b_packed: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + b_packed = b_packed.contiguous() + Kw, N = b_packed.shape + K = Kw * 8 + b_uint = torch.zeros((K, N), dtype=torch.uint8, device=b_packed.device) + b_uint[0::2, :] = b_packed & 0xF + b_uint[1::2, :] = (b_packed >> 4) & 0xF + + group_idx = torch.arange(K, device=b_packed.device) // group_size + scl = scales[group_idx, :] + z = zeros[group_idx, :] + + z_expanded = torch.zeros_like(scales) + for i in range(N // 8): + z_expanded[:, i*8:(i+1)*8] = ((z[:, i:i+1] >> (4 * torch.arange(8).to(z.device))) & 0xF) + + z_bound = z_expanded[group_idx, :] + fp = (b_uint.to(torch.float32) - z_bound) * scl + return fp + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stdout new file mode 100644 index 0000000..8ea4c22 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_977481.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_977481 due to A inner dimension must match unpacked weight rows diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py b/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py new file mode 100644 index 0000000..5875bf7 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py @@ -0,0 +1,214 @@ + +import torch +import triton +import triton.language as tl + + +############################################## +# Triton kernel(s) +############################################## + +@triton.autotune( + configs=[ + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N', 'K'], + reset_to_zero=['c_ptr'] +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + bs_ptr, bzp_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_m = tl.where(offs_m < M, offs_m, 0) + offs_n = tl.where(offs_n < N, offs_n, 0) + + offs_k_step = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k_step[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k_step[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k_idx in range(0, num_pid_k): + kk = k_idx * BLOCK_SIZE_K * SPLIT_K + mask_k = kk + offs_k_step[None, :] + + a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (mask_k < K), other=0.0) + b = tl.load(b_ptrs, mask=mask_k < K, other=0) + + offs_gp = (kk + offs_k_step)[None, :] // group_size + bs_ptrs = bs_ptr + offs_gp * stride_bsk + offs_n[None, :] * stride_bsn + bzp_ptrs = bzp_ptr + offs_gp * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn + + bs = tl.load(bs_ptrs, mask=mask_k < K, other=1.0) + bzp = tl.load(bzp_ptrs, mask=mask_k < K, other=0) + + mask_n = offs_n[None, :] + shift = (kk + offs_k_step)[:, None] % 8 * 4 + shift_zp = mask_n % 8 * 4 + + nib = (b >> shift) & 0xF + zp = (bzp >> shift_zp) & 0xF + fp_b = (nib - zp) * bs + acc += tl.dot(a, fp_b.to(a.dtype)) + + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_cm < M + mask_n = offs_cn < N + offs_cm = tl.where(mask_m, offs_cm, 0) + offs_cn = tl.where(mask_n, offs_cn, 0) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = mask_m[:, None] & mask_n[None, :] + if SPLIT_K == 1: + tl.store(c_ptrs, acc, mask=mask) + else: + tl.atomic_add(c_ptrs, acc, mask=mask) + + +############################################## +# Python wrappers +############################################## + +def matmul_dequantize_int4_s2( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int = 128 +) -> torch.Tensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + M, K = x.shape + N = scales.shape[1] + scales = scales.contiguous() + qzeros = qzeros.contiguous() + out = torch.empty((M, N), dtype=x.dtype, device=x.device) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + 1, + ) + matmul_kernel[grid]( + x, qweight, out, + scales, qzeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + out.stride(0), out.stride(1), + scales.stride(0), scales.stride(1), + qzeros.stride(0), qzeros.stride(1), + group_size, + ) + return out + + +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]: + w = weights.to(torch.float32) + K, N = w.shape + assert K % group_size == 0 + + w_ = w.view(K // group_size, group_size, N) + wmin = w_.amin(dim=1, keepdim=True) + wmax = w_.amax(dim=1, keepdim=True) + scales = ((wmax - wmin) / 15.0).squeeze(1) + zeros = ((-wmin) / scales).round().clamp(0, 15).squeeze(1) + + quantized = torch.round((w_ - wmin) / scales).clamp(0, 15).to(torch.uint8) + quantized = torch.bitwise_and(quantized, 0xF) + quantized = quantized.view(K, N) + + packed = torch.zeros(K, N // 8, dtype=torch.int32, device=w.device) + for i in range(8): + packed |= (quantized[:, i::8] << (4 * i)).to(torch.int32) + return packed.contiguous(), scales.contiguous(), zeros.contiguous(), None + + +def unpack_int4(weights: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int = 128) -> torch.Tensor: + Kq, N = weights.shape + K = Kq * 1 + unpacked = torch.zeros(K, N * 8, device=weights.device, dtype=scales.dtype) + w_flat = weights.view(-1) + for b in range(8): + nib = (w_flat >> (b * 4)) & 0xF + unpacked.view(-1)[b::8] = nib.float() + unpacked = unpacked.view(K, N * 8) + scales = scales.view(-1, N).repeat_interleave(group_size, dim=0)[:K] + zeros = zeros.view(-1, N).repeat_interleave(group_size, dim=0)[:K] + return (unpacked - zeros) * scales + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stdout new file mode 100644 index 0000000..ae078c8 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_991002.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_991002 due to The size of tensor a (128) must match the size of tensor b (32) at non-singleton dimension 1 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py b/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py new file mode 100644 index 0000000..07814f5 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py @@ -0,0 +1,234 @@ + +import torch +import triton +import triton.language as tl + +# ========================= +# Triton kernel (batched INT4 matrix multiply) +# ========================= +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'SPLIT_K': 2}, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bsk, stride_bsn, + stride_bzpk, stride_bzpn, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr): + + pid = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k_step = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + k_step = BLOCK_SIZE_K * SPLIT_K + k_last = min((pid_k + 1) * BLOCK_SIZE_K, K) + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k_step[None, :] * stride_ak + b_ptrs = b_ptr + (offs_k_step[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, k_step)): + k_off = k * k_step + a_mask = (offs_k_step[None, :] + k_off < K) & (offs_m[:, None] < M) + b_mask = (offs_k_step[:, None] + k_off < K) & (offs_n[None, :] < N) + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + packed = tl.load(b_ptrs, mask=b_mask, other=0) + + gidx = ((offs_k_step[:, None] + k_off) // group_size)[:, 0] + gidx = tl.view(gidx, (BLOCK_SIZE_K, 1)) + scales = tl.load(scales_ptr + gidx * stride_bsk + offs_n[None, :] * stride_bsn, + mask=b_mask, other=0) + + bzp = tl.load(zeros_ptr + gidx * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn, + mask=b_mask, other=0) + + shift = ((offs_k_step[:, None] + k_off) % 8) * 4 + int_b = ((packed >> shift) & 0xF).to(tl.float32) + + zp_shift = (offs_n[None, :] % 8) * 4 + int_zp = ((bzp >> zp_shift) & 0xF).to(tl.float32) + + b = (int_b - int_zp) * scales + acc += tl.dot(a, b) + + a_ptrs += k_step * stride_ak + b_ptrs += (k_step // 8) * stride_bk + + c = acc.to(c_ptr.dtype.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=mask_c) + else: + tl.atomic_add(c_ptrs, c, mask=mask_c) + +# ========================= +# Front-end helpers +# ========================= +def quantize_int4(weights: torch.Tensor, group_size: int = 128) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor]: + assert weights.dim() == 2 + K, N = weights.shape + assert K % group_size == 0 + + num_groups = K // group_size + w_groups = weights.view(num_groups, group_size, N) + wmin, wmax = w_groups.aminmax(dim=1) + scale = (wmax - wmin) / 15.0 + scale = torch.where(scale == 0, + torch.tensor(1.0, dtype=scale.dtype, device=scale.device), + scale) + zero_fp = -wmin / scale + q = ((w_groups / scale.unsqueeze(1) + zero_fp.unsqueeze(1) + 0.5) + .floor().clamp(0, 15).to(torch.int32)) + + q = q.view(K, N) + packed = torch.empty((K // 8, N), dtype=torch.int32, device=weights.device) + for k in range(0, 8): + packed |= (q[k::8] & 0xF) << (k * 4) + + zero_int = zero_fp.round().int().clamp(0, 15) + zeros_packed = torch.empty((num_groups, N // 8), dtype=torch.int32, device=weights.device) + for n8 in range(0, 8): + zeros_packed |= ((zero_int.view(num_groups * N)[n8::8] & 0xF) + << (n8 * 4)) + zeros_packed = zeros_packed.view(num_groups, N // 8) + + return packed, scale, zeros_packed + + +def unpack_int4(w, scales, zeros, group_size: int = 128): + K = w.shape[0] * 8 + N = w.shape[1] + assert w.ndim == 2 and scales.ndim == 2 and zeros.ndim == 2 + num_groups = scales.size(0) + + deq = torch.zeros((K, N), dtype=torch.float32, device=w.device) + for k in range(K): + for n in range(N): + k_block = k // 8 + k_nibble = k % 8 + val = (w[k_block, n] >> (k_nibble * 4)) & 0xF + group = k // group_size + gp_n = n // 8 + znib = n % 8 + zp = (zeros[group, gp_n] >> (znib * 4)) & 0xF + deq[k, n] = (float(val) - float(zp)) * scales[group, n] + return deq + + +def matmul_dequantize_int4_s2(x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + zeros: torch.IntTensor, + group_size: int = 128) -> torch.FloatTensor: + assert x.is_contiguous() + assert qweight.is_contiguous() + M, K = x.shape + Kw, N = qweight.shape + assert K == Kw * 8 + + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + matmul_kernel[grid]( + x, qweight, output, + scales, zeros, + M, N, K, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + group_size, + ) + return output + + +################################################################################################################################################## + + + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + + group_size = 128 + + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + + + # Test case + + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + + + results = { + + "test_case_1": triton_output + + } + + + + return results + + + +result_gold = test_correct_int4_s2() diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stderr b/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stdout b/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stdout new file mode 100644 index 0000000..dd6bd75 --- /dev/null +++ b/src/temp/gen/int4_matmul.py_gen_triton_code_995030.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module int4_matmul.py_gen_triton_code_995030 due to The size of tensor a (512) must match the size of tensor b (16384) at non-singleton dimension 1 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py new file mode 100644 index 0000000..04f121e --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py @@ -0,0 +1,137 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = 1.0 / tl.sqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x + + tl.store(DX + cols, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + dx = torch.empty_like(x) + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError(f"This l2 norm does not support feature dim >= {BLOCK_N}.") + + grid = lambda META: (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_143388.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py new file mode 100644 index 0000000..6477695 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py @@ -0,0 +1,139 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, # pointer to the input + DY, # pointer to the output gradient + DX, # pointer to the input gradient + stride_x_row, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x + + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-5, +) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, dy.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_167554.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py new file mode 100644 index 0000000..6dcdcb2 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py @@ -0,0 +1,139 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, # pointer to the input + DY, # pointer to the output gradient + DX, # pointer to the input gradient + stride_x_row, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + mask = cols < N + dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32) + dy = tl.where(cols < N, dy, 0.0) + gy = tl.sum(dy * x) + dx = dy * rstd - gy * (1 / (var + eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-5, +) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = x.shape + dx = torch.empty_like(x) + + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.") + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_215639.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py new file mode 100644 index 0000000..02115ee --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py @@ -0,0 +1,148 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= M: + return + + x_ptr = X + pid * stride_x_row + dy_ptr = DY + pid * stride_dy_row + dx_ptr = DX + pid * stride_dx_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + x_sq = x * x + var = tl.sum(x_sq, axis=0) / N + rstd = tl.rsqrt(var + eps) + + x_hat = x * rstd + + term1 = dy * rstd + term2 = tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x + + dx = term1 - term2 + + tl.store(dx_ptr + cols, dx, mask=mask) + + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float, +) -> torch.Tensor: + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError("Feature dimension too large") + + dx = torch.empty_like(x) + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stdout new file mode 100644 index 0000000..7b8ba93 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_220059.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_220059 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py new file mode 100644 index 0000000..c128a9c --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py @@ -0,0 +1,140 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + if row >= M: + return + + offs_n = tl.arange(0, BLOCK_N) + mask = offs_n < N + + x_ptr = X + row * stride_x_row + dy_ptr = DY + row * stride_dy_row + dx_ptr = DX + row * stride_dx_row + + x = tl.load(x_ptr + offs_n, mask=mask, other=0.0) + dy = tl.load(dy_ptr + offs_n, mask=mask, other=0.0) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + sum_dy_x = tl.sum(dy * x, axis=0) + dx = dy * rstd - sum_dy_x * (1.0 / (var + eps)) * rstd * x / N + tl.store(dx_ptr + offs_n, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + shape = x.shape + x = x.reshape(-1, shape[-1]) + dy = dy.reshape(-1, shape[-1]) + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError(f"Feature dimension {N} cannot exceed {BLOCK_N}") + + dx = torch.empty_like(x) + n_rows = M + + grid = lambda META: (n_rows,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + dx = dx.reshape(shape) + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stdout new file mode 100644 index 0000000..2120130 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_28664.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_28664 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py new file mode 100644 index 0000000..e484149 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py @@ -0,0 +1,147 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + if row_id >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row_id * stride_x_row + cols + dy_ptrs = DY + row_id * stride_dy_row + cols + dx_ptrs = DX + row_id * stride_dx_row + cols + + x = tl.load(x_ptrs, mask=mask, other=0.0) + dy = tl.load(dy_ptrs, mask=mask, other=0.0) + + x_fp32 = x.to(tl.float32) + var = tl.sum(x_fp32 * x_fp32, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dy_fp32 = dy.to(tl.float32) + dot = tl.sum(dy_fp32 * x_fp32, axis=0) + coeff = dot * (1.0 / (var + eps)) * rstd + dx = dy_fp32 * rstd - coeff * x_fp32 + + dx_out = dx.to(DX.type.element_ty) + tl.store(dx_ptrs, dx_out, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + assert x.shape == dy.shape, "Shape mismatch between x and dy" + + original_shape = x.shape + if x.dim() > 2: + x = x.view(-1, x.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError("Feature dimension N too large for BLOCK_N") + + dx = torch.empty_like(x) + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + stride_x_row = x.stride(0) if x.stride(-1) == 1 else x.stride(-2) + stride_dy_row = dy.stride(0) if dy.stride(-1) == 1 else dy.stride(-2) + stride_dx_row = dx.stride(0) if dx.stride(-1) == 1 else dx.stride(-2) + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N=BLOCK_N, + ) + + dx = dx.view(original_shape) + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stdout new file mode 100644 index 0000000..ef7dd90 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_338946.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_338946 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py new file mode 100644 index 0000000..38df419 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py @@ -0,0 +1,146 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= M: + return + + x_ptr = X + pid * stride_x_row + dy_ptr = DY + pid * stride_dy_row + dx_ptr = DX + pid * stride_dx_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + x_sq = x * x + var = tl.sum(x_sq, axis=0) / N + rstd = tl.rsqrt(var + eps) + + term1 = dy * rstd + term2 = tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x + + dx = term1 - term2 + + tl.store(dx_ptr + cols, dx, mask=mask) + + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float, +) -> torch.Tensor: + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError("Feature dimension too large") + + dx = torch.empty_like(x) + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stdout new file mode 100644 index 0000000..0c64cfe --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_347725.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_347725 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py new file mode 100644 index 0000000..34d30ce --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py @@ -0,0 +1,170 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + # Compute the base pointers for this row + x_ptrs = X + row * stride_x_row + cols + dy_ptrs = DY + row * stride_dy_row + cols + dx_ptrs = DX + row * stride_dx_row + cols + + # Load X and DY for this row + x = tl.load(x_ptrs, mask=mask, other=0.0) + dy = tl.load(dy_ptrs, mask=mask, other=0.0) + + # Compute variance and reciprocal standard deviation + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + # Compute the gradient contribution + term1 = dy * rstd + term2 = tl.sum(dy * x, axis=0) * rstd / (var + eps) * x + + dx = term1 - term2 + + # Store the result + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float) -> torch.Tensor: + # Ensure inputs are contiguous + x = x.contiguous() + dy = dy.contiguous() + + # Get original shape + original_shape = x.shape + + # Reshape to 2D if necessary + M = x.numel() // original_shape[-1] + N = original_shape[-1] + x = x.view(M, N) + dy = dy.view(M, N) + + # Determine optimal block size + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + + # Ensure N doesn't exceed BLOCK_N + if N > BLOCK_N: + raise ValueError(f"Feature dimension {N} is too large") + + # Allocate output + dx = torch.empty_like(x) + + # Configure kernel launch + kernel_args = { + 'BLOCK_N': BLOCK_N, + } + + # Launch kernel + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + **kernel_args + ) + + # Restore original shape + dx = dx.view(original_shape) + + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stdout new file mode 100644 index 0000000..641cfcc --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_387667.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_387667 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py new file mode 100644 index 0000000..94313b6 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py @@ -0,0 +1,134 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, stride_x_row, stride_dy_row, stride_dx_row, N, eps, + BLOCK_N: tl.constexpr +): + pid_row = tl.program_id(0) + offs_n = tl.arange(0, BLOCK_N) + mask = offs_n < N + + x_ptrs = X + pid_row * stride_x_row + offs_n + dy_ptrs = DY + pid_row * stride_dy_row + offs_n + dx_ptrs = DX + pid_row * stride_dx_row + offs_n + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + mean = tl.sum(x) / N + var = tl.sum((x - mean) ** 2) / N + rstd = tl.math.rsqrt(var + eps) + + gy = tl.sum(dy * x) + dx = dy * rstd - gy * (1.0 / (var + eps)) * rstd * x + dx = dx.to(DX.type.element_ty) + + tl.store(dx_ptrs, dx, mask=mask) + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + dx: torch.Tensor, + N: int, + eps: float, +): + M = x.numel() // N + x = x.view(M, N) if x.stride(-1) != 1 else x + dy = dy.view(M, N) if dy.stride(-1) != 1 else dy.contiguous() + x = x.contiguous() + dy = dy.contiguous() + dx = dx.view(M, N) if dx.stride(-1) != 1 else dx + dx = dx.contiguous() + + max_block_n = triton.next_power_of_2(N) + BLOCK_N = max_block_n + if N > BLOCK_N: + raise ValueError(f"Feature dimension {N} exceeds maximum block size {BLOCK_N}") + + grid = (triton.cdiv(M, 1),) + _l2_norm_bwd_kernel[grid]( + x, dy, dx, + x.stride(0), dy.stride(0), dx.stride(0), + N, eps, + BLOCK_N=BLOCK_N + ) + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stdout new file mode 100644 index 0000000..018c020 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_404776.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_404776 due to _l2_norm_bwd() missing 3 required positional arguments: 'dx', 'N', and 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py new file mode 100644 index 0000000..418c292 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + dy_ptrs = DY + row * stride_dy_row + cols + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (rstd * rstd * rstd / N) * x + + dx_ptrs = DX + row * stride_dx_row + cols + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + shape = x.shape + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + M, N = x.shape + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + + if N > BLOCK_N: + raise ValueError("N exceeds BLOCK_N") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, M, N, eps, + x.stride(0), dy.stride(0), dx.stride(0), + BLOCK_N=BLOCK_N + ) + + return dx.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stdout new file mode 100644 index 0000000..fd418ac --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_414029.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_414029.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py new file mode 100644 index 0000000..418c292 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + dy_ptrs = DY + row * stride_dy_row + cols + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (rstd * rstd * rstd / N) * x + + dx_ptrs = DX + row * stride_dx_row + cols + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + shape = x.shape + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + M, N = x.shape + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + + if N > BLOCK_N: + raise ValueError("N exceeds BLOCK_N") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, M, N, eps, + x.stride(0), dy.stride(0), dx.stride(0), + BLOCK_N=BLOCK_N + ) + + return dx.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stdout new file mode 100644 index 0000000..56bd86f --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_419949.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_419949.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py new file mode 100644 index 0000000..a7270f3 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py @@ -0,0 +1,128 @@ + +import triton +import triton.language as tl +import torch + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, _N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < _N + + x_ptrs = X + row * stride_x_row + cols + dy_ptrs = DY + row * stride_dy_row + cols + dx_ptrs = DX + row * stride_dx_row + cols + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + N = tl.sum(mask.to(tl.float32)) + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x / N + dx_out = dx.to(DX.dtype.element_ty) + + tl.store(dx_ptrs, dx_out, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + if x.shape != dy.shape: + raise ValueError("Shapes of x and dy must match") + + original_shape = x.shape + x = x.reshape(-1, x.shape[-1]).contiguous() + dy = dy.reshape(-1, dy.shape[-1]).contiguous() + + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise RuntimeError("Feature dimension N too large") + + dx = torch.empty_like(x) + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + + return dx.view(original_shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stdout new file mode 100644 index 0000000..b1bac67 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_433589.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_433589.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py new file mode 100644 index 0000000..3295b26 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py @@ -0,0 +1,149 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + if row_id >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row_id * stride_x_row + cols + dy_ptrs = DY + row_id * stride_dy_row + cols + dx_ptrs = DX + row_id * stride_dx_row + cols + + x = tl.load(x_ptrs, mask=mask, other=0.0) + dy = tl.load(dy_ptrs, mask=mask, other=0.0) + + x_fp32 = x.to(tl.float32) + var = tl.sum(x_fp32 * x_fp32, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dy_fp32 = dy.to(tl.float32) + dot = tl.sum(dy_fp32 * x_fp32, axis=0) + coeff = dot * (1.0 / (var + eps)) * rstd + dx = dy_fp32 * rstd - coeff * x_fp32 + + dx_out = dx.to(DX.type.element_ty) + tl.store(dx_ptrs, dx_out, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + assert x.shape == dy.shape, "Shape mismatch between x and dy" + + original_shape = x.shape + if x.dim() > 2: + x = x.view(-1, x.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError("Feature dimension N too large for BLOCK_N") + + dx = torch.empty_like(x) + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + stride_x_row = x.stride(0) if x.stride(-1) == 1 else x.stride(-2) + stride_dy_row = dy.stride(0) if dy.stride(-1) == 1 else dy.stride(-2) + stride_dx_row = dx.stride(0) if dx.stride(-1) == 1 else dx.stride(-2) + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N=BLOCK_N, + ) + + dx = dx.view(original_shape) + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stdout new file mode 100644 index 0000000..51741ee --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_459560.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_459560 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py new file mode 100644 index 0000000..46165e8 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py @@ -0,0 +1,147 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + x_masked = tl.where(mask, x, 0.0) + var = tl.sum(x_masked * x_masked) + rstd = tl.math.rsqrt(var + eps) + + dy_masked = tl.where(mask, dy, 0.0) + term1 = dy_masked * rstd + sum_dyx = tl.sum(dy_masked * x_masked) + term2 = sum_dyx * rstd * rstd * rstd * x_masked + dx = term1 - term2 + + tl.store(DX + cols, dx, mask=mask) + + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-5, +): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This L2 norm backward does not support feature dim >= 64KB.") + + # ROCm-friendly launch grid & stream model + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_486455.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py new file mode 100644 index 0000000..d3ad64a --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py @@ -0,0 +1,136 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, # *tl.pointer + DY, # *tl.pointer + DX, # *tl.pointer + stride_x_row, # tl.int32 + N, # tl.int32 + eps, # tl.float32 + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = 1.0 / tl.sqrt(var + eps) + + dx = dy * rstd - (tl.sum(dy * x, axis=0) / N) * rstd * rstd * rstd * x + + tl.store(DX + cols, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + dx = torch.empty_like(x) + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError(f"This l2 norm does not support feature dim >= {BLOCK_N}.") + with torch.cuda.device(x.device.index): + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stdout new file mode 100644 index 0000000..6b921a9 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_493519.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_493519 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py new file mode 100644 index 0000000..656a80b --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py @@ -0,0 +1,136 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel(X, DY, DX, M, N, stride_x_row, eps, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + if pid_m >= M: + return + + offs_n = tl.arange(0, BLOCK_N) + mask = offs_n < N + x_ptrs = X + pid_m * stride_x_row + offs_n + dy_ptrs = DY + pid_m * N + offs_n + dx_ptrs = DX + pid_m * N + offs_n + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x) / N + rstd = tl.math.rsqrt(var + eps) + + term = tl.sum(dy * x) * (1.0 / (var + eps)) * rstd + dx = dy * rstd - term * x + + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-12): + if x.dim() > 2: + x = x.reshape(-1, x.shape[-1]) + if dy.dim() > 2: + dy = dy.reshape(-1, dy.shape[-1]) + + M = x.shape[0] + N = x.shape[1] + dx = torch.empty_like(x) + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N ({N}) exceeds block size BLOCK_N ({BLOCK_N}).") + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + x.stride(0), + eps, + BLOCK_N=BLOCK_N, + ) + + return dx.reshape(x.shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stdout new file mode 100644 index 0000000..6d368d5 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_570539.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_570539.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py new file mode 100644 index 0000000..418c292 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + dy_ptrs = DY + row * stride_dy_row + cols + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (rstd * rstd * rstd / N) * x + + dx_ptrs = DX + row * stride_dx_row + cols + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + shape = x.shape + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + M, N = x.shape + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + + if N > BLOCK_N: + raise ValueError("N exceeds BLOCK_N") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, M, N, eps, + x.stride(0), dy.stride(0), dx.stride(0), + BLOCK_N=BLOCK_N + ) + + return dx.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stdout new file mode 100644 index 0000000..a22ac98 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_597752.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_597752.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py new file mode 100644 index 0000000..3445c92 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py @@ -0,0 +1,147 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= M: + return + + x_ptr = X + pid * stride_x_row + dy_ptr = DY + pid * stride_dy_row + dx_ptr = DX + pid * stride_dx_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + x_sq = x * x + var = tl.sum(x_sq * mask.to(tl.float32), axis=0) / tl.sum(mask.to(tl.float32), axis=0) + rstd = tl.rsqrt(var + eps) + + term1 = dy * rstd + term2 = tl.sum(dy * x * mask.to(tl.float32), axis=0) / (tl.sum(mask.to(tl.float32), axis=0) * (var + eps)) * rstd * x + + dx = term1 - term2 + + tl.store(dx_ptr + cols, dx, mask=mask) + + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float, +) -> torch.Tensor: + M = x.shape[0] + N = x.shape[-1] + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError("Feature dimension too large") + + dx = torch.empty_like(x) + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + + return dx + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stdout new file mode 100644 index 0000000..9b10096 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_637799.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_637799 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py new file mode 100644 index 0000000..4feab41 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py @@ -0,0 +1,130 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + sum_dy_x = tl.sum(dy * x, axis=0) + dx = dy * rstd - sum_dy_x * rstd.pow(3) * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stdout new file mode 100644 index 0000000..9e02ec5 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_640557.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_640557 due to at 21:32: + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + sum_dy_x = tl.sum(dy * x, axis=0) + dx = dy * rstd - sum_dy_x * rstd.pow(3) * x + ^ +AttributeError("'tensor' object has no attribute 'pow'") diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py new file mode 100644 index 0000000..6f9f408 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + dy_ptrs = DY + row * stride_dy_row + cols + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (rstd * rstd * rstd / N) * x + + dx_ptrs = DX + row * stride_dx_row + cols + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + shape = x.shape + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + M, N = x.shape + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + + if N > BLOCK_N: + raise ValueError("N exceeds BLOCK_N") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, M, N, eps, + x.stride(0), dy.stride(0), dx.stride(0), + BLOCK_N=BLOCK_N + ) + + return dx.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stdout new file mode 100644 index 0000000..b537294 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_712104.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_712104 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py new file mode 100644 index 0000000..418c292 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, eps, stride_x_row, stride_dy_row, stride_dx_row, BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + dy_ptrs = DY + row * stride_dy_row + cols + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (rstd * rstd * rstd / N) * x + + dx_ptrs = DX + row * stride_dx_row + cols + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + shape = x.shape + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + M, N = x.shape + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + dx = torch.empty_like(x) + + BLOCK_N = triton.next_power_of_2(N) + + if N > BLOCK_N: + raise ValueError("N exceeds BLOCK_N") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, M, N, eps, + x.stride(0), dy.stride(0), dx.stride(0), + BLOCK_N=BLOCK_N + ) + + return dx.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stdout new file mode 100644 index 0000000..6c129f0 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_786715.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: l2_norm_bwd.py_gen_triton_code_786715.py diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py new file mode 100644 index 0000000..8e494f1 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py @@ -0,0 +1,134 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + sum_dy_x = tl.sum(dy * x, axis=0) + dx = dy * rstd - sum_dy_x * rstd * rstd * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-5 +) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_827439.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py new file mode 100644 index 0000000..03e94d0 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py @@ -0,0 +1,136 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, + M, N, eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr, +): + row_id = tl.program_id(0) + if row_id >= M: + return + + offs_n = tl.arange(0, BLOCK_N) + mask = offs_n < N + + x_ptrs = X + row_id * stride_x_row + offs_n + dy_ptrs = DY + row_id * stride_dy_row + offs_n + dx_ptrs = DX + row_id * stride_dx_row + offs_n + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + + mean = tl.sum(x, axis=0) / N + var = tl.sum((x - mean) * (x - mean), axis=0) / N + rstd = 1.0 / tl.math.sqrt(var + eps) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x + tl.store(dx_ptrs, dx, mask=mask) + + +def _l2_norm_bwd( + x: torch.Tensor, + dy: torch.Tensor, + eps: float, +): + shape = list(x.shape) + x = x.view(-1, shape[-1]) + dy = dy.view(-1, shape[-1]) + assert x.shape == dy.shape, "x and dy must have the same shape" + M, N = x.shape + dx = torch.empty_like(x) + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise RuntimeError(f"This l2 norm does not support feature dim >= {BLOCK_N}.") + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, + M, N, eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + return dx.view(*shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stdout new file mode 100644 index 0000000..044ef3a --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_843690.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_843690 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py new file mode 100644 index 0000000..c8e030e --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py @@ -0,0 +1,147 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, DY, DX, M, N, + stride_x_row, + stride_x_col, + stride_dy_row, + stride_dy_col, + stride_dx_row, + stride_dx_col, + eps, + BLOCK_N: tl.constexpr +): + row_idx = tl.program_id(0) + if row_idx >= M: + return + + offs_n = tl.arange(0, BLOCK_N) + mask = offs_n < N + + x_ptr = X + row_idx * stride_x_row + offs_n * stride_x_col + dy_ptr = DY + row_idx * stride_dy_row + offs_n * stride_dy_col + dx_ptr = DX + row_idx * stride_dx_row + offs_n * stride_dx_col + + x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptr, mask=mask, other=0.0).to(tl.float32) + + x_sum = tl.sum(x * x) + var = x_sum / N + rstd = tl.math.rsqrt(var + eps) + + grad_x_part = tl.sum(dy * x) + coef = (1.0 / (var + eps)) * rstd * grad_x_part + dx = dy * rstd - coef * x + + tl.store(dx_ptr, dx.to(tl.load(x_ptr).dtype), mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float): + assert x.shape == dy.shape, "x and dy must have the same shape" + if x.dim() > 2: + x = x.view(-1, x.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + orig_shape = x.shape + + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise RuntimeError(f"Feature dimension {N} exceeds allowed BLOCK_N={BLOCK_N}") + + x = x.contiguous() + dy = dy.contiguous() + + dx = torch.empty_like(x) + + grid = lambda META: (M,) + _l2_norm_bwd_kernel[grid]( + x, dy, dx, M, N, + x.stride(0), x.stride(1) if x.dim() == 2 else 0, + dy.stride(0), dy.stride(1) if dy.dim() == 2 else 0, + dx.stride(0), dx.stride(1) if dx.dim() == 2 else 0, + eps, + BLOCK_N=BLOCK_N + ) + + return dx.view(orig_shape) + + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stdout new file mode 100644 index 0000000..193c210 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_864396.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_bwd.py_gen_triton_code_864396 due to _l2_norm_bwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py new file mode 100644 index 0000000..7f42365 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py @@ -0,0 +1,135 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, # pointer to the input + DY, # pointer to the output gradient + DX, # pointer to the input gradient + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + scale_k = 1.0 / (var + eps) + rstd = tl.math.sqrt(scale_k) + + dx = dy * rstd - tl.sum(dy * x, axis=0) * scale_k * rstd * x + tl.store(DX + cols, dx.to(DX.dtype.element_ty), mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + x_shape_og = x.shape + x = x.reshape(-1, x_shape_og[-1]) + dy = dy.reshape(-1, x_shape_og[-1]) + + if x.stride(1) != 1: + x = x.contiguous() + if dy.stride(1) != 1: + dy = dy.contiguous() + + M, N = x.shape + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This L2-norm backward doesn't support feature dim >= 64KB.") + + dx = torch.empty_like(x) + + _l2_norm_bwd_kernel[(M,)]( + x, dy, dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_885795.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py new file mode 100644 index 0000000..2c64313 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py @@ -0,0 +1,130 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + DY += row * stride_x_row + DX += row * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + x_masked = tl.where(mask, x, 0.0) + var = tl.sum(x_masked * x_masked) + rstd = 1 / tl.sqrt(var + eps) + dx = dy * rstd - tl.sum(dy * x_masked) * rstd / (var + eps) * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_960121.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py new file mode 100644 index 0000000..6081fc8 --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py @@ -0,0 +1,139 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_bwd_kernel( + X, # pointer to input + DY, # pointer to output grad + DX, # pointer to input grad + stride_x_row, # stride between rows + N, # number of columns + eps, # small epsilon + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + base_x = X + row * stride_x_row + base_dy = DY + row * stride_x_row + base_dx = DX + row * stride_x_row + + mask = cols < N + x = tl.load(base_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(base_dy + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = 1.0 / tl.sqrt(var + eps) + + coef = tl.sum(dy * x) * (1.0 / (var + eps)) * rstd + dx = dy * rstd - coef * x + + tl.store(base_dx + cols, dx, mask=mask) + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + + if x.stride(-1) != 1: + x = x.contiguous() + if dy.stride(-1) != 1: + dy = dy.contiguous() + + assert x.shape == dy.shape + M, N = x.shape + dx = torch.empty_like(x) + + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("Feature dimension too large for L2 backward kernel.") + + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the backward L2 normalization + +def test_l2_norm_bwd(): + + results = {} + + + + # Test case 1: Default case + + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_1'] = dx + + + + # Test case 2: Different shape + + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_2'] = dx + + + + # Test case 3: Larger tensor + + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_3'] = dx + + + + # Test case 4: Edge case with small tensor + + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + + dx = _l2_norm_bwd(x, dy) + + results['test_case_4'] = dx + + + + return results + + + +# Run the tests + +result_gold = test_l2_norm_bwd() diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stderr b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stdout b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_bwd.py_gen_triton_code_972847.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py new file mode 100644 index 0000000..8616379 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py @@ -0,0 +1,104 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + pid_m = tl.program_id(0) + offsets_n = tl.arange(0, BLOCK_N) + mask = offsets_n < N + + row_start = pid_m * stride_x_row + x_ptrs = X + row_start + offsets_n + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = rsqrt(var + eps) + y = x * rstd + + y_ptrs = Y + row_start + offsets_n + tl.store(y_ptrs, y.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + x = x.contiguous() + y = torch.empty_like(x) + M, N = x.shape + element_size = x.element_size() + BLOCK_N = triton.next_power_of_2(min(65536 // element_size, N)) + if N > BLOCK_N: + raise ValueError(f"N={N} exceeds BLOCK_N={BLOCK_N}") + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N + ) + return y.view(x.shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stdout new file mode 100644 index 0000000..e76157b --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_212491.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_212491 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py new file mode 100644 index 0000000..405dc9f --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py @@ -0,0 +1,116 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + row_id = tl.program_id(0) + row_start = X + row_id * stride_x_row + row_out_start = Y + row_id * stride_x_row + + col_offsets = tl.arange(0, BLOCK_N) + mask = col_offsets < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + x_vals = tl.load(row_start + col_offsets, mask=mask, other=0.0) + acc = x_vals * x_vals + var = tl.sum(acc) + rstd = 1.0 / tl.sqrt(var + eps) + + out_vals = x_vals * rstd + tl.store(row_out_start + col_offsets, out_vals, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float): + x = x.contiguous() + shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + + element_size = x.element_size() + MAX_FUSED = 65536 // element_size + BLOCK_N = min(MAX_FUSED, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError + + grid = lambda META: (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + + return y.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stdout new file mode 100644 index 0000000..0a67d23 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_254823.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_254823 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py new file mode 100644 index 0000000..542514a --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py @@ -0,0 +1,114 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + row_id = tl.program_id(0) + row_start = X + row_id * stride_x_row + row_out_start = Y + row_id * stride_x_row + + col_offsets = tl.arange(0, BLOCK_N) + mask = col_offsets < N + + x_vals = tl.load(row_start + col_offsets, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x_vals * x_vals) + rstd = 1.0 / tl.sqrt(var + eps) + + out_vals = x_vals * rstd + tl.store(row_out_start + col_offsets, out_vals, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-6): + x = x.contiguous() + shape = x.shape + if x.ndim > 2: + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + + element_size = x.element_size() + MAX_FUSED = 65536 // element_size + BLOCK_N = min(MAX_FUSED, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError("N too large") + + grid = lambda META: (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + + return y.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_318959.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py new file mode 100644 index 0000000..8193726 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py @@ -0,0 +1,112 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + row_idx = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row_idx * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + y = x * rstd + + y_ptrs = Y + row_idx * stride_x_row + cols + tl.store(y_ptrs, y.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + N = x.size(-1) + M = x.size(0) + + element_size = x.element_size() + MAX_FUSED_SIZE = 65536 // element_size + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("this kernel does not support N >= 64KB") + + y = torch.empty_like(x) + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), N, eps, + BLOCK_N + ) + return y.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_336206.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py new file mode 100644 index 0000000..7ec7cd0 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py @@ -0,0 +1,114 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_N) + + row_start_ptr = X + row_id * stride_x_row + mask = col_offsets < N + + x = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + y = x * rstd + + tl.store(Y + row_id * stride_x_row + col_offsets, y, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + if x.stride(-1) != 1: + x = x.contiguous() + + original_shape = x.shape + x = x.view(-1, x.shape[-1]) + M, N = x.shape + + elem_size = x.element_size() + BLOCK_N = min(65536 // elem_size, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"N ({N}) exceeds max BLOCK_N ({BLOCK_N})") + + y = torch.empty((M, N), dtype=x.dtype, device=x.device) + + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, + y, + stride_x_row=x.stride(0), + N=N, + eps=eps, + BLOCK_N=BLOCK_N, + ) + + y = y.view(original_shape) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_357644.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py new file mode 100644 index 0000000..f81044e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + x_ptrs = X + row * stride_x_row + cols + mask = cols < N + + x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + x2 = x_block * x_block + var = tl.sum(x2, axis=0) + rstd = tl.math.rsqrt(var + eps) + + y_block = x_block * rstd + y_ptrs = Y + row * N + cols + tl.store(y_ptrs, y_block, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + assert len(x.shape) == 2, "L2 norm kernel requires 2-D input tensor" + if not x.is_contiguous(): + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(65536 // x.element_size(), triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N={N} larger than maximum allowed ({BLOCK_N})") + + with torch.cuda.device(x.device): # For ROCm we still use CUDA semantics + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stdout new file mode 100644 index 0000000..12d83dc --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_392963.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_392963 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py new file mode 100644 index 0000000..906dbe7 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py @@ -0,0 +1,124 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + # compute base pointers + x_ptr = X + row_id * stride_x_row + y_ptr = Y + row_id * stride_x_row + + # load and accumulate + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x) + rstd = tl.math.rsqrt(var + eps) + + # normalize and store + y = x * rstd + tl.store(y_ptr + cols, y.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float): + if x.dim() == 3: + B, M, N = x.shape + x_view = x.view(B * M, N) + elif x.dim() == 2: + B, M, N = 1, *x.shape + x_view = x + else: + raise ValueError("Unsupported input dim") + + x_contig = x_view.contiguous() + y = torch.empty_like(x_contig) + + BLOCK_N = triton.next_power_of_2(N) + byte_per_elem = x_contig.element_size() + upper = (65536 // byte_per_elem) + if BLOCK_N > upper: + BLOCK_N = (65536 // byte_per_elem) + + assert N <= BLOCK_N, f"Feature dim {N} exceeds limit {BLOCK_N}" + + _l2_norm_fwd_1pass_kernel[(x_contig.shape[0],)]( + x_contig, + y, + x_contig.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + + if x.dim() == 3: + return y.view(B, M, N) + return y.view(x.shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stdout new file mode 100644 index 0000000..748c48c --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_403404.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_403404 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py new file mode 100644 index 0000000..f81044e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + x_ptrs = X + row * stride_x_row + cols + mask = cols < N + + x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + x2 = x_block * x_block + var = tl.sum(x2, axis=0) + rstd = tl.math.rsqrt(var + eps) + + y_block = x_block * rstd + y_ptrs = Y + row * N + cols + tl.store(y_ptrs, y_block, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + assert len(x.shape) == 2, "L2 norm kernel requires 2-D input tensor" + if not x.is_contiguous(): + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(65536 // x.element_size(), triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N={N} larger than maximum allowed ({BLOCK_N})") + + with torch.cuda.device(x.device): # For ROCm we still use CUDA semantics + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stdout new file mode 100644 index 0000000..83d567e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_466457.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_466457 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py new file mode 100644 index 0000000..5175985 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py @@ -0,0 +1,111 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row_idx = tl.program_id(0) + row_off = row_idx * stride_x_row + col_idx = tl.arange(0, BLOCK_N) + mask = col_idx < N + x = tl.load(X + row_off + col_idx, mask=mask, other=0.0).to(tl.float32) + ssq = tl.sum(x * x) + rstd = tl.math.rsqrt(ssq + eps) + y = x * rstd + tl.store(Y + row_off + col_idx, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + x_shape_og = x.shape + x = x.view(-1, x_shape_og[-1]).contiguous() + M, N = x.shape + y = torch.empty_like(x) + + MAX_FUSED = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise NotImplementedError("N > BLOCK_N not handled in 1-pass kernel") + + grid = (M,) + with torch.cuda.device(x.device): + _l2_norm_fwd_1pass_kernel[grid]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return y.view(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_598128.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py new file mode 100644 index 0000000..10a6b6a --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py @@ -0,0 +1,112 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + row_idx = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptrs = X + row_idx * stride_x_row + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + y = x * rstd + + y_ptrs = Y + row_idx * stride_x_row + cols + tl.store(y_ptrs, y.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x = x.contiguous() + *shape, N = x.shape + x = x.view(-1, N) + M = x.size(0) + + element_size = x.element_size() + max_block = (64 * 1024) // element_size + BLOCK_N = triton.next_power_of_2(N) + if BLOCK_N > max_block: + raise ValueError(f"N ({N}) exceeds block size limit ({max_block})") + + y = torch.empty_like(x) + + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), N, eps, + BLOCK_N=BLOCK_N + ) + + return y.view(*shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stdout new file mode 100644 index 0000000..8afa060 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_599125.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_599125 due to shape '[4]' is invalid for input of size 32 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py new file mode 100644 index 0000000..bb6e661 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py @@ -0,0 +1,104 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + pid_m = tl.program_id(0) + offsets_n = tl.arange(0, BLOCK_N) + mask = offsets_n < N + + row_start = pid_m * stride_x_row + x_ptrs = X + row_start + offsets_n + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + y = x * rstd + + y_ptrs = Y + row_start + offsets_n + tl.store(y_ptrs, y.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + x = x.contiguous() + y = torch.empty_like(x) + M, N = x.shape + element_size = x.element_size() + BLOCK_N = triton.next_power_of_2(min(65536 // element_size, N)) + if N > BLOCK_N: + raise ValueError(f"N={N} exceeds BLOCK_N={BLOCK_N}") + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, x.stride(0), N, eps, BLOCK_N=BLOCK_N + ) + return y.view(x.shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_637798.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py new file mode 100644 index 0000000..9559e33 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py @@ -0,0 +1,113 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptr = X + row * stride_x_row + y_ptr = Y + row * stride_x_row + + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x) + rstd = tl.math.rsqrt(var + eps) + + y = x * rstd + tl.store(y_ptr + cols, y.to(Y.dtype.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + M, N = x.shape + + y = torch.empty_like(x) + + max_fused_size = 65536 // x.element_size() + BLOCK_N = min(max_fused_size, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError(f"This L2 norm does not support feature dim >= 64KB. Got: {N}") + + _l2_norm_fwd_1pass_kernel[(M,)]( + X=x, + Y=y, + stride_x_row=x.stride(0), + N=N, + eps=eps, + BLOCK_N=BLOCK_N, + ) + + return y.reshape(x_shape_og) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_650964.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py new file mode 100644 index 0000000..851612e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py @@ -0,0 +1,108 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + pid = tl.program_id(0) + row_start = pid * stride_x_row + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x = tl.load(X + row_start + cols, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) + rstd = tl.math.rsqrt(var + eps) + y = x * rstd + tl.store(Y + row_start + cols, y, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + x = x.contiguous() + shape = x.shape + if x.dim() > 2: + x = x.view(-1, shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + + element_size = x.element_size() + BLOCK_N = min(triton.next_power_of_2(N), 1024) + if BLOCK_N * element_size > 65536: + raise ValueError("BLOCK_N too large") + if N > BLOCK_N: + raise NotImplementedError("N > BLOCK_N not handled in 1-pass kernel") + + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N + ) + + return y.view(shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stdout new file mode 100644 index 0000000..3b6f1af --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_674736.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_674736 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py new file mode 100644 index 0000000..f81044e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + x_ptrs = X + row * stride_x_row + cols + mask = cols < N + + x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + x2 = x_block * x_block + var = tl.sum(x2, axis=0) + rstd = tl.math.rsqrt(var + eps) + + y_block = x_block * rstd + y_ptrs = Y + row * N + cols + tl.store(y_ptrs, y_block, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + assert len(x.shape) == 2, "L2 norm kernel requires 2-D input tensor" + if not x.is_contiguous(): + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(65536 // x.element_size(), triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N={N} larger than maximum allowed ({BLOCK_N})") + + with torch.cuda.device(x.device): # For ROCm we still use CUDA semantics + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stdout new file mode 100644 index 0000000..c1965cf --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_786517.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_786517 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py new file mode 100644 index 0000000..f81044e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + x_ptrs = X + row * stride_x_row + cols + mask = cols < N + + x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + x2 = x_block * x_block + var = tl.sum(x2, axis=0) + rstd = tl.math.rsqrt(var + eps) + + y_block = x_block * rstd + y_ptrs = Y + row * N + cols + tl.store(y_ptrs, y_block, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + assert len(x.shape) == 2, "L2 norm kernel requires 2-D input tensor" + if not x.is_contiguous(): + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(65536 // x.element_size(), triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N={N} larger than maximum allowed ({BLOCK_N})") + + with torch.cuda.device(x.device): # For ROCm we still use CUDA semantics + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stdout new file mode 100644 index 0000000..dbdea73 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_800477.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_800477 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py new file mode 100644 index 0000000..ce86c02 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_N) + + # Row pointer + row_start_ptr = X + row_id * stride_x_row + + # Initialize accumulators + var = tl.zeros([BLOCK_N], dtype=tl.float32) + mask = col_offsets < N + + # Load data + x = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # Compute variance (sum of squares) + var = x * x + var = tl.sum(var, axis=0) + + # Compute rstd + rstd = tl.math.rsqrt(var + eps) + + # Normalize and store + y = x * rstd + tl.store(Y + row_id * stride_x_row + col_offsets, y, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + # Ensure contiguous + if x.stride(-1) != 1: + x = x.contiguous() + + # Flatten to 2D (M, N) + original_shape = x.shape + x = x.view(-1, x.shape[-1]) + M, N = x.shape + + # Compute BLOCK_N + elem_size = x.element_size() + BLOCK_N = 2 ** (16 - elem_size.bit_length()) # Max 64KB / elem_size + BLOCK_N = min(BLOCK_N, triton.next_power_of_2(N)) + + if N > BLOCK_N: + raise ValueError(f"N ({N}) exceeds max BLOCK_N ({BLOCK_N})") + + # Prepare output + y = torch.empty((M, N), dtype=x.dtype, device=x.device) + + # Launch kernel + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + stride_x_row=x.stride(0), + N=N, + eps=eps, + BLOCK_N=BLOCK_N, + ) + + # Reshape back + y = y.view(original_shape) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stdout new file mode 100644 index 0000000..36e7c6e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_839169.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_839169 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py new file mode 100644 index 0000000..ec6caf7 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py @@ -0,0 +1,109 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_N) + mask = col_offsets < N + + X_ptr = X + row_id * stride_x_row + col_offsets + x = tl.load(X_ptr, mask=mask, other=0.0) + sum2 = x.to(tl.float32) * x.to(tl.float32) + var = tl.sum(sum2, axis=0) + rstd = tl.math.rsqrt(var + eps) + out = x * rstd + Y_ptr = Y + row_id * stride_x_row + col_offsets + tl.store(Y_ptr, out, mask=mask) + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + if x.dim() == 3: + b, m, n = x.shape + x = x.view(-1, n) + orig_shape = (b, m, n) + else: + orig_shape = x.shape + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = 1024 // x.element_size() + if N > BLOCK_N: + raise RuntimeError(f"Feature dimension N={N} exceeds maximum {BLOCK_N}") + + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y.view(*orig_shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_846578.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py new file mode 100644 index 0000000..f81044e --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, + Y, + stride_x_row, + N, + eps, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_N) + + x_ptrs = X + row * stride_x_row + cols + mask = cols < N + + x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + x2 = x_block * x_block + var = tl.sum(x2, axis=0) + rstd = tl.math.rsqrt(var + eps) + + y_block = x_block * rstd + y_ptrs = Y + row * N + cols + tl.store(y_ptrs, y_block, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + assert len(x.shape) == 2, "L2 norm kernel requires 2-D input tensor" + if not x.is_contiguous(): + x = x.contiguous() + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(65536 // x.element_size(), triton.next_power_of_2(N)) + if N > BLOCK_N: + raise ValueError(f"Feature dimension N={N} larger than maximum allowed ({BLOCK_N})") + + with torch.cuda.device(x.device): # For ROCm we still use CUDA semantics + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N, + ) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stdout new file mode 100644 index 0000000..a9f9867 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_964700.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_964700 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py new file mode 100644 index 0000000..6fa1aee --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py @@ -0,0 +1,113 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr +): + pid_m = tl.program_id(0) + row_start = pid_m * stride_x_row + _sum = tl.zeros([BLOCK_N], dtype=tl.float32) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N) + mask = cols < N + x_ptrs = X + row_start + cols + x_vals = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + _sum += x_vals * x_vals + var = tl.sum(_sum, axis=0) + rstd = tl.math.rsqrt(var + eps) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N) + mask = cols < N + x_ptrs = X + row_start + cols + y_ptrs = Y + row_start + cols + x_vals = tl.load(x_ptrs, mask=mask, other=0.0) + y_vals = x_vals * rstd + tl.store(y_ptrs, y_vals, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-6): + x = x.contiguous() + shape = x.shape + x = x.view(-1, shape[-1]) + M, N = x.shape + y = torch.empty_like(x) + + BLOCK_N = min(triton.next_power_of_2(N), 1 << 16) + assert N <= BLOCK_N, "Feature dimension N must not exceed BLOCK_N (64KB limit)" + _l2_norm_fwd_1pass_kernel[(M,)]( + x, y, + stride_x_row=x.stride(0), + N=N, + eps=eps, + BLOCK_N=BLOCK_N + ) + return y.view(*shape) + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_965300.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py new file mode 100644 index 0000000..ce86c02 --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py @@ -0,0 +1,132 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_id = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_N) + + # Row pointer + row_start_ptr = X + row_id * stride_x_row + + # Initialize accumulators + var = tl.zeros([BLOCK_N], dtype=tl.float32) + mask = col_offsets < N + + # Load data + x = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # Compute variance (sum of squares) + var = x * x + var = tl.sum(var, axis=0) + + # Compute rstd + rstd = tl.math.rsqrt(var + eps) + + # Normalize and store + y = x * rstd + tl.store(Y + row_id * stride_x_row + col_offsets, y, mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float) -> torch.Tensor: + # Ensure contiguous + if x.stride(-1) != 1: + x = x.contiguous() + + # Flatten to 2D (M, N) + original_shape = x.shape + x = x.view(-1, x.shape[-1]) + M, N = x.shape + + # Compute BLOCK_N + elem_size = x.element_size() + BLOCK_N = 2 ** (16 - elem_size.bit_length()) # Max 64KB / elem_size + BLOCK_N = min(BLOCK_N, triton.next_power_of_2(N)) + + if N > BLOCK_N: + raise ValueError(f"N ({N}) exceeds max BLOCK_N ({BLOCK_N})") + + # Prepare output + y = torch.empty((M, N), dtype=x.dtype, device=x.device) + + # Launch kernel + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + stride_x_row=x.stride(0), + N=N, + eps=eps, + BLOCK_N=BLOCK_N, + ) + + # Reshape back + y = y.view(original_shape) + return y + + +################################################################################################################################################## + + + + + +import torch + + + +# Test the forward L2 normalization + +def test_l2_norm_fwd(): + + results = {} + + + + # Test case 1 + + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + + y1 = _l2_norm_fwd(x1) + + results['test_case_1'] = y1 + + + + # Test case 2: Different batch size + + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + + y2 = _l2_norm_fwd(x2) + + results['test_case_2'] = y2 + + + + # Test case 3: Different feature size + + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + + y3 = _l2_norm_fwd(x3) + + results['test_case_3'] = y3 + + + + # Test case 4: Larger tensor + + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + + y4 = _l2_norm_fwd(x4) + + results['test_case_4'] = y4 + + + + return results + + + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stderr b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stdout b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stdout new file mode 100644 index 0000000..3871e8d --- /dev/null +++ b/src/temp/gen/l2_norm_triton1.py_gen_triton_code_973282.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module l2_norm_triton1.py_gen_triton_code_973282 due to _l2_norm_fwd() missing 1 required positional argument: 'eps' diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py new file mode 100644 index 0000000..1875402 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py @@ -0,0 +1,124 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + n_rows = SIZE_M + n_cols = D_HEAD + + num_tasks = n_rows * n_cols + for i in range(pid, num_tasks, BLOCK_SIZE): + if i < num_tasks: + row = i // n_cols + col = i % n_cols + + in_ptr = M + tl.make_block_ptr( + base=M, + shape=(n_rows, n_cols), + strides=(matrix_stridex, matrix_stridey), + offsets=(row, col), + block_shape=(1, 1), + order=(0, 1) + ) + val = tl.load(in_ptr) + + out_ptr = Out + tl.make_block_ptr( + base=Out, + shape=(n_cols, n_rows), + strides=(out_stridex, out_stridey), + offsets=(col, row), + block_shape=(1, 1), + order=(0, 1) + ) + tl.store(out_ptr, val) + + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(SIZE_M * D_HEAD, BLOCK_SIZE),) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_SIZE, + ) + return out + + +if __name__ == "__main__": + SIZE_M = 64 + D_HEAD = 128 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = wrapper(SIZE_M, D_HEAD) + ref = matrix.t() + assert torch.allclose(ref, out, atol=1e-3), abs(ref - out).max() + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stdout new file mode 100644 index 0000000..5d0959d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_114093.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_114093 due to at 22:22: +): + pid = tl.program_id(0) + n_rows = SIZE_M + n_cols = D_HEAD + + num_tasks = n_rows * n_cols + for i in range(pid, num_tasks, BLOCK_SIZE): + if i < num_tasks: + row = i // n_cols + col = i % n_cols + + in_ptr = M + tl.make_block_ptr( + ^ +IncompatibleTypeErrorImpl('invalid operands of type pointer and pointer<<[1, 1], fp16>>') diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py new file mode 100644 index 0000000..4ef316d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py @@ -0,0 +1,96 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = tl.arange(0, SIZE_M) + offs_n = tl.arange(0, D_HEAD) + + matrix_ptr = M + offs_n[None, :] * matrix_stridex + offs_m[:, None] * matrix_stridey + out_ptr = Out + offs_n[None, :] * out_stridey + offs_m[:, None] * out_stridex + + x = tl.load(matrix_ptr) + tl.store(out_ptr, x) + + +def wrapper(SIZE_M: int, D_HEAD: int): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + + grid = (D_HEAD, SIZE_M) + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M=SIZE_M, + D_HEAD=D_HEAD, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_11496.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py new file mode 100644 index 0000000..96ed226 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py @@ -0,0 +1,115 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + idxs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_M), BLOCK_M) + idxs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + + matrix_ptrs = M + (idxs_m[:, None] * matrix_stridex + idxs_n[None, :] * matrix_stridey) + out_ptrs = Out + (idxs_n[:, None] * out_stridex + idxs_m[None, :] * out_stridey) + + mask = mask_m[:, None] & mask_n[None, :] + a = tl.load(matrix_ptrs, mask=mask) + tl.store(out_ptrs, a, mask=mask) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_M = 32 + BLOCK_N = 32 + grid = lambda meta: (triton.cdiv(SIZE_M, meta['BLOCK_M']), + triton.cdiv(D_HEAD, meta['BLOCK_N'])) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return out + +if __name__ == "__main__": + SIZE_M = 64 + D_HEAD = 128 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = wrapper(SIZE_M, D_HEAD) + ref = matrix.t() + assert torch.allclose(ref, out, atol=1e-3), abs(ref - out).max() + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stdout new file mode 100644 index 0000000..a6e237c --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_14792.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_14792.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py new file mode 100644 index 0000000..10638c5 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py @@ -0,0 +1,105 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + ptrs_m = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + ptrs_out = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + x = tl.load(ptrs_m, mask=mask_m[:, None] & mask_n[None, :]) + tl.store(ptrs_out, x, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(): + SIZE_M = 128 + D_HEAD = 64 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = lambda META: ( + triton.cdiv(SIZE_M, BLOCK_SIZE_M), + triton.cdiv(D_HEAD, BLOCK_SIZE_N) + ) + + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stdout new file mode 100644 index 0000000..fe987ae --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_160821.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_160821 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py new file mode 100644 index 0000000..d26212d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py @@ -0,0 +1,89 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + rn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + + mask = (rm < SIZE_M) & (rn < D_HEAD) + in_idx = rm * matrix_stridex + rn * matrix_stridey + out_idx = rn * out_stridex + rm * out_stridey + + val = tl.load(M + in_idx, mask=mask) + tl.store(Out + out_idx, val, mask=mask) + + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, 32), triton.cdiv(D_HEAD, 32)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stdout new file mode 100644 index 0000000..d0ea652 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_205496.py.stdout @@ -0,0 +1,8 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_205496 due to at 6:30: +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py new file mode 100644 index 0000000..f794b81 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py @@ -0,0 +1,90 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD +): + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + + m_ptrs = M + pid_x * matrix_stridey + pid_y * matrix_stridex + o_ptrs = Out + pid_y * out_stridey + pid_x * out_stridex + + val = tl.load(m_ptrs) + tl.store(o_ptrs, val) + +def wrapper(): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = lambda meta: (SIZE_M, D_HEAD) + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD + ) + return out + +SIZE_M = 64 +D_HEAD = 128 +result = wrapper() + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stdout new file mode 100644 index 0000000..ee57fb4 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_216901.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_216901 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py new file mode 100644 index 0000000..f6cf2d3 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py @@ -0,0 +1,102 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + idx = (offs_m[:, None] * matrix_stridey + offs_n[None, :] * matrix_stridex) + matrix_data = tl.load(M + idx, mask=mask_m[:, None] & mask_n[None, :]) + + idx_out = (offs_n[:, None] * out_stridey + offs_m[None, :] * out_stridex) + tl.store(Out + idx_out, matrix_data, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn(SIZE_M, D_HEAD, dtype=torch.float16, device='cuda') + out = torch.zeros(D_HEAD, SIZE_M, dtype=torch.float16, device='cuda') + + BLOCK_M = 128 + BLOCK_D = 128 + + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_D)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stdout new file mode 100644 index 0000000..eb9a2c1 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_274099.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_274099.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py new file mode 100644 index 0000000..f5c5934 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py @@ -0,0 +1,105 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + d_offs = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = m_offs < SIZE_M + mask_d = d_offs < D_HEAD + + src_ptrs = M + d_offs[None, :] * matrix_stridey + m_offs[:, None] * matrix_stridex + mask = mask_m[:, None] & mask_d[None, :] + src = tl.load(src_ptrs, mask=mask) + src_T = tl.trans(src) + + dst_ptrs = Out + m_offs[None, :] * d_offs[:, None] * 0 + m_offs[None, :] * out_stridex + d_offs[:, None] * out_stridey + tl.store(dst_ptrs, src_T, mask=mask.T) + +@torch.no_grad() +def wrapper(SIZE_M: int, D_HEAD: int): + BLOCK_M = 32 + BLOCK_D = 32 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + grid = lambda META: (triton.cdiv(SIZE_M, META['BLOCK_M']), + triton.cdiv(D_HEAD, META['BLOCK_D'])) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stdout new file mode 100644 index 0000000..cbacf29 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_369711.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_369711.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py new file mode 100644 index 0000000..e52c3d1 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py @@ -0,0 +1,85 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M: tl.constexpr, D_HEAD: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * D_HEAD + tl.arange(0, D_HEAD) + offs_n = pid_n + + m_ptrs = M + offs_m * matrix_stridex + offs_n * matrix_stridey + mask = offs_m < SIZE_M * D_HEAD + A = tl.load(m_ptrs, mask=mask) + + out_ptrs = Out + offs_n * out_stridex + offs_m * out_stridey + tl.store(out_ptrs, A, mask=mask) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = lambda META: (SIZE_M,) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M=SIZE_M, + D_HEAD=D_HEAD + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stdout new file mode 100644 index 0000000..fddd219 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_412290.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_412290.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py new file mode 100644 index 0000000..c6b4ddc --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, + SIZE_M, D_HEAD, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + x_idx = offs_m[:, None] * D_HEAD + offs_n[None, :] + y_idx = offs_n[:, None] * SIZE_M + offs_m[None, :] + + mask_i = (offs_m[:, None] < SIZE_M) & (offs_n[None, :] < D_HEAD) + mask_o = (offs_n[:, None] < D_HEAD) & (offs_m[None, :] < SIZE_M) + + val = tl.load(M + x_idx, mask=mask_i, other=0.0) + tl.store(Out + y_idx, val.trans(), mask=mask_o) + +def wrapper( + SIZE_M: int, D_HEAD: int +) -> torch.Tensor: + device = torch.device("cuda") + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = ( + triton.cdiv(SIZE_M, BLOCK_SIZE_M), + triton.cdiv(D_HEAD, BLOCK_SIZE_N), + ) + + # For row-major, matrix.stride(1) = D_HEAD, matrix.stride(0) = 1 => stride_x = D_HEAD, stride_y = 1 + # out.stride(1) = SIZE_M, out.stride(0) = 1 => stride_x_out = SIZE_M, stride_y_out = 1 + # tl.arange * stride simplifies when stride==1 and we use element offsets; so we pass stride=1. + kernel[grid]( + matrix, out, + 1, 1, + 1, 1, + SIZE_M, D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_429164.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py new file mode 100644 index 0000000..87fea90 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py @@ -0,0 +1,105 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + d_offs = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = m_offs < SIZE_M + mask_d = d_offs < D_HEAD + + src_ptrs = M + d_offs[None, :] * matrix_stridey + m_offs[:, None] * matrix_stridex + mask = mask_m[:, None] & mask_d[None, :] + src = tl.load(src_ptrs, mask=mask) + src_T = tl.trans(src) + + dst_ptrs = Out + d_offs[:, None] * out_stridex + m_offs[None, :] * out_stridey + tl.store(dst_ptrs, src_T, mask=mask.T) + +@torch.no_grad() +def wrapper(SIZE_M: int, D_HEAD: int): + BLOCK_M = 32 + BLOCK_D = 32 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + grid = lambda META: (triton.cdiv(SIZE_M, META['BLOCK_M']), + triton.cdiv(D_HEAD, META['BLOCK_D'])) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_469771.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py new file mode 100644 index 0000000..725c086 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py @@ -0,0 +1,84 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + if (pid_m < SIZE_M) and (pid_n < D_HEAD): + in_idx = pid_m * matrix_stridex + pid_n * matrix_stridey + out_idx = pid_n * out_stridex + pid_m * out_stridey + + val = tl.load(M + in_idx) + tl.store(Out + out_idx, val) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, 32), triton.cdiv(D_HEAD, 32)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stdout new file mode 100644 index 0000000..43ce7e9 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_493615.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_493615.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py new file mode 100644 index 0000000..e20c50f --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py @@ -0,0 +1,87 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + size_m_arange = tl.arange(0, SIZE_M) + d_head_arange = tl.arange(0, D_HEAD) + matrix_ptr = M + d_head_arange[None, :] * matrix_stridex + size_m_arange[:, None] * matrix_stridey + out_ptr = Out + size_m_arange[None, :] * out_stridex + d_head_arange[:, None] * out_stridey + matrix = tl.load(matrix_ptr) + tl.store(out_ptr, matrix) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + grid = (1,) + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M, + D_HEAD, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stdout new file mode 100644 index 0000000..801f724 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_529486.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_529486 due to at 16:4: + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + size_m_arange = tl.arange(0, SIZE_M) + d_head_arange = tl.arange(0, D_HEAD) + matrix_ptr = M + d_head_arange[None, :] * matrix_stridex + size_m_arange[:, None] * matrix_stridey + out_ptr = Out + size_m_arange[None, :] * out_stridex + d_head_arange[:, None] * out_stridey + matrix = tl.load(matrix_ptr) + tl.store(out_ptr, matrix) + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py new file mode 100644 index 0000000..f194cf6 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py @@ -0,0 +1,97 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + row = tl.arange(0, SIZE_M)[:, None] + col = tl.arange(0, D_HEAD)[None, :] + in_ptrs = M + row * matrix_stridex + col * matrix_stridey + out_ptrs = Out + col * out_stridex + row * out_stridey + vals = tl.load(in_ptrs) + tl.store(out_ptrs, vals) + + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + grid = lambda meta: (1,) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + ) + return out + + +if __name__ == "__main__": + SIZE_M = 64 + D_HEAD = 128 + out = wrapper(SIZE_M, D_HEAD) + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + assert torch.allclose(matrix.t(), out) + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_571713.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py new file mode 100644 index 0000000..910500d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, + SIZE_M, D_HEAD, + BLOCK_SIZE_M: tl.constexpr = 32, BLOCK_SIZE_N: tl.constexpr = 32 +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + in_ptrs = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + out_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + load_mask = mask_m[:, None] & mask_n[None, :] + load_data = tl.load(in_ptrs, mask=load_mask, other=0.0) + + store_mask = mask_n[:, None] & mask_m[None, :] + tl.store(out_ptrs, tl.trans(load_data.to(tl.float16)), mask=store_mask) + +def wrapper( + matrix: torch.Tensor, out: torch.Tensor, + matrix_stridex: int, matrix_stridey: int, + out_stridex: int, out_stridey: int, + SIZE_M: int, D_HEAD: int +): + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + + grid = lambda META: ( + triton.cdiv(SIZE_M, META["BLOCK_SIZE_M"]), + triton.cdiv(D_HEAD, META["BLOCK_SIZE_N"]), + ) + + kernel[grid]( + matrix, out, + matrix_stridex, matrix_stridey, + out_stridex, out_stridey, + SIZE_M, D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stdout new file mode 100644 index 0000000..ed9f109 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_580037.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_580037 due to wrapper() missing 6 required positional arguments: 'matrix_stridex', 'matrix_stridey', 'out_stridex', 'out_stridey', 'SIZE_M', and 'D_HEAD' diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py new file mode 100644 index 0000000..d26212d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py @@ -0,0 +1,89 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + rn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + + mask = (rm < SIZE_M) & (rn < D_HEAD) + in_idx = rm * matrix_stridex + rn * matrix_stridey + out_idx = rn * out_stridex + rm * out_stridey + + val = tl.load(M + in_idx, mask=mask) + tl.store(Out + out_idx, val, mask=mask) + + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, 32), triton.cdiv(D_HEAD, 32)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stdout new file mode 100644 index 0000000..911f936 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_608628.py.stdout @@ -0,0 +1,8 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_608628 due to at 6:30: +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py new file mode 100644 index 0000000..ee0e5ae --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * 128 + tl.arange(0, 128) + offs_n = pid_n * 128 + tl.arange(0, 128) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + # Load data + idx = offs_m[:, None] * matrix_stridey + offs_n[None, :] * matrix_stridex + data = tl.load(M + idx, mask=mask_m[:, None] & mask_n[None, :]) + + # Store transposed + idx_out = offs_n[:, None] * out_stridey + offs_m[None, :] * out_stridex + tl.store(Out + idx_out, data, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(): + SIZE_M = 1024 + D_HEAD = 512 + + # ROCm uses "cuda" device name + matrix = torch.randn(SIZE_M, D_HEAD, dtype=torch.float16, device='cuda') + out = torch.zeros(D_HEAD, SIZE_M, dtype=torch.float16, device='cuda') + + grid = (triton.cdiv(SIZE_M, 128), triton.cdiv(D_HEAD, 128)) + + kernel[grid]( + matrix, out, + matrix.stride(1), matrix.stride(0), + out.stride(1), out.stride(0), + SIZE_M, D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stdout new file mode 100644 index 0000000..cc28e1b --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_619005.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_619005 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py new file mode 100644 index 0000000..77d9ea3 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py @@ -0,0 +1,102 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + BLOCK_SIZE_M = tl.program_id(2) if hasattr(tl.program_id, '__call__') else 32 + BLOCK_SIZE_N = tl.program_id(3) if hasattr(tl.program_id, '__call__') else 32 + + offs_m = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_n = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + mask_m = offs_m < D_HEAD + mask_n = offs_n < SIZE_M + + ptrs_m = M + offs_n[:, None] * matrix_stridex + offs_m[None, :] * matrix_stridey + ptrs_out = Out + offs_m[:, None] * out_stridex + offs_n[None, :] * out_stridey + + x = tl.load(ptrs_m, mask=mask_n[:, None] & mask_m[None, :]) + tl.store(ptrs_out, x, mask=mask_m[:, None] & mask_n[None, :]) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = lambda META: ( + triton.cdiv(D_HEAD, BLOCK_SIZE_N), + triton.cdiv(SIZE_M, BLOCK_SIZE_M) + ) + + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stdout new file mode 100644 index 0000000..c7b4396 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_620806.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_620806 due to at 14:39: + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + BLOCK_SIZE_M = tl.program_id(2) if hasattr(tl.program_id, '__call__') else 32 + ^ +NameError('hasattr is not defined') diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py new file mode 100644 index 0000000..3722550 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py @@ -0,0 +1,103 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + ptrs_m = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + ptrs_out = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + x = tl.load(ptrs_m, mask=mask_m[:, None] & mask_n[None, :]) + tl.store(ptrs_out, x, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = lambda META: ( + triton.cdiv(SIZE_M, BLOCK_SIZE_M), + triton.cdiv(D_HEAD, BLOCK_SIZE_N) + ) + + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stdout new file mode 100644 index 0000000..0656835 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_671609.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_671609.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py new file mode 100644 index 0000000..bded7c9 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, + SIZE_M, D_HEAD, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + in_ptrs = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + out_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + load_mask = mask_m[:, None] & mask_n[None, :] + load_data = tl.load(in_ptrs, mask=load_mask, other=0.0) + + store_mask = mask_n[:, None] & mask_m[None, :] + tl.store(out_ptrs, load_data.trans(), mask=store_mask) + +def wrapper( + matrix: torch.Tensor, out: torch.Tensor, + matrix_stridex: int, matrix_stridey: int, + out_stridex: int, out_stridey: int, + SIZE_M: int, D_HEAD: int +): + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + + grid = lambda META: ( + triton.cdiv(SIZE_M, META["BLOCK_SIZE_M"]), + triton.cdiv(D_HEAD, META["BLOCK_SIZE_N"]), + ) + + kernel[grid]( + matrix, out, + matrix_stridex, matrix_stridey, + out_stridex, out_stridey, + SIZE_M, D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stdout new file mode 100644 index 0000000..09e6d01 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_724790.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_724790 due to wrapper() missing 6 required positional arguments: 'matrix_stridex', 'matrix_stridey', 'out_stridex', 'out_stridey', 'SIZE_M', and 'D_HEAD' diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py new file mode 100644 index 0000000..d26212d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py @@ -0,0 +1,89 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + rn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + + mask = (rm < SIZE_M) & (rn < D_HEAD) + in_idx = rm * matrix_stridex + rn * matrix_stridey + out_idx = rn * out_stridex + rm * out_stridey + + val = tl.load(M + in_idx, mask=mask) + tl.store(Out + out_idx, val, mask=mask) + + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, 32), triton.cdiv(D_HEAD, 32)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stdout new file mode 100644 index 0000000..cdb1c01 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_738982.py.stdout @@ -0,0 +1,8 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_738982 due to at 6:30: +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py new file mode 100644 index 0000000..d26212d --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py @@ -0,0 +1,89 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + rn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :] + + mask = (rm < SIZE_M) & (rn < D_HEAD) + in_idx = rm * matrix_stridex + rn * matrix_stridey + out_idx = rn * out_stridex + rm * out_stridey + + val = tl.load(M + in_idx, mask=mask) + tl.store(Out + out_idx, val, mask=mask) + + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, 32), triton.cdiv(D_HEAD, 32)) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stdout new file mode 100644 index 0000000..26e7c69 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_74175.py.stdout @@ -0,0 +1,8 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_74175 due to at 6:30: +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + BLOCK_SIZE = 32 + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py new file mode 100644 index 0000000..c60dda3 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py @@ -0,0 +1,108 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + ptrs_m = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + ptrs_out = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + x = tl.load(ptrs_m, mask=mask_m[:, None] & mask_n[None, :]) + tl.store(ptrs_out, x, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(SIZE_M: int, D_HEAD: int): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device=device) + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device=device) + + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 32 + grid = ( + triton.cdiv(SIZE_M, BLOCK_SIZE_M), + triton.cdiv(D_HEAD, BLOCK_SIZE_N), + ) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stdout new file mode 100644 index 0000000..d43acc3 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_757083.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_757083.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py new file mode 100644 index 0000000..b49b641 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py @@ -0,0 +1,104 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + + idx = (offs_m[:, None] * matrix_stridey + offs_n[None, :] * matrix_stridex) + matrix_data = tl.load(M + idx, mask=mask_m[:, None] & mask_n[None, :]) + + idx_out = (offs_n[:, None] * out_stridey + offs_m[None, :] * out_stridex) + tl.store(Out + idx_out, matrix_data, mask=mask_n[:, None] & mask_m[None, :]) + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn(SIZE_M, D_HEAD, dtype=torch.float16, device='cuda') + out = torch.zeros(D_HEAD, SIZE_M, dtype=torch.float16, device='cuda') + + BLOCK_M = 128 + BLOCK_D = 128 + + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_D)) + + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stdout new file mode 100644 index 0000000..c573083 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_759138.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_759138.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py new file mode 100644 index 0000000..5f068d6 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py @@ -0,0 +1,105 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = offs_m < SIZE_M + mask_d = offs_d < D_HEAD + + mask_orig = mask_m[:, None] & mask_d[None, :] + + matrix_ptrs = M + offs_m[:, None] * matrix_stridey + offs_d[None, :] * matrix_stridex + out_ptrs = Out + offs_d[:, None] * out_stridey + offs_m[None, :] * out_stridex + + data = tl.load(matrix_ptrs, mask=mask_orig) + tl.store(out_ptrs, data, mask=mask_orig.T) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + BLOCK_M = 32 + BLOCK_D = 32 + + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_D)) + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stdout new file mode 100644 index 0000000..3cb8118 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_780911.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_780911.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py new file mode 100644 index 0000000..d084549 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py @@ -0,0 +1,107 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = offs_m < SIZE_M + mask_d = offs_d < D_HEAD + mask = mask_m[:, None] & mask_d[None, :] + + m_ptrs = M + offs_m[:, None] * matrix_stridey + offs_d[None, :] * matrix_stridex + o_ptrs = Out + offs_d[:, None] * out_stridex + offs_m[None, :] * out_stridey + + vals = tl.load(m_ptrs, mask=mask) + tl.store(o_ptrs, vals.T, mask=mask.T) + +def wrapper(SIZE_M, D_HEAD): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + BLOCK_M = 32 + BLOCK_D = 32 + grid = ( + triton.cdiv(SIZE_M, BLOCK_M), + triton.cdiv(D_HEAD, BLOCK_D), + ) + + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M, + D_HEAD, + BLOCK_M, + BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stdout new file mode 100644 index 0000000..28c54eb --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_783719.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_783719.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py new file mode 100644 index 0000000..db8aeac --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py @@ -0,0 +1,111 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + ROW_TILE: tl.constexpr, + COL_TILE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * ROW_TILE + tl.arange(0, ROW_TILE) + offs_n = pid_n * COL_TILE + tl.arange(0, COL_TILE) + + mask_m = offs_m < SIZE_M + mask_n = offs_n < D_HEAD + mask = mask_m[:, None] & mask_n[None, :] + + rows = offs_m[:, None] + cols = offs_n[None, :] + + m_addrs = M + rows * matrix_stridex + cols * matrix_stridey + x = tl.load(m_addrs, mask=mask) + + out_addrs = Out + cols * out_stridex + rows * out_stridey + tl.store(out_addrs, x, mask=mask.T) + + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + ROW_TILE = 16 + COL_TILE = 16 + grid = lambda META: ( + triton.cdiv(SIZE_M, META['ROW_TILE']), + triton.cdiv(D_HEAD, META['COL_TILE']), + ) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M=SIZE_M, + D_HEAD=D_HEAD, + ROW_TILE=ROW_TILE, + COL_TILE=COL_TILE, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_81159.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py new file mode 100644 index 0000000..877ab6e --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py @@ -0,0 +1,89 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + grid_m = tl.cdiv(SIZE_M, 16) + grid_n = tl.cdiv(D_HEAD, 16) + + for i in range(0, grid_m): + for j in range(0, grid_n): + offs_m = i * 16 + tl.arange(0, 16)[:, None] + offs_n = j * 16 + tl.arange(0, 16)[None, :] + + mask = (offs_m < SIZE_M) & (offs_n < D_HEAD) + + m_ptrs = M + pid_m * matrix_stridex + pid_n * matrix_stridey + offs_m * matrix_stridex + offs_n * matrix_stridey + out_ptrs = Out + pid_m * out_stridex + pid_n * out_stridey + offs_n * out_stridex + offs_m * out_stridey + + m_val = tl.load(m_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, m_val, mask=mask) + +def wrapper(): + SIZE_M = 64 + D_HEAD = 128 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = (1,) + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stdout new file mode 100644 index 0000000..db68c61 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_853096.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_853096 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py new file mode 100644 index 0000000..3c01166 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py @@ -0,0 +1,111 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + n_rows = SIZE_M + n_cols = D_HEAD + + num_tasks = n_rows * n_cols + for i in range(pid, num_tasks, BLOCK_SIZE): + if i < num_tasks: + row = i // n_cols + col = i % n_cols + + in_ptr = M + row * matrix_stridex + col * matrix_stridey + out_ptr = Out + col * out_stridex + row * out_stridey + + val = tl.load(in_ptr) + tl.store(out_ptr, val) + + +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(SIZE_M * D_HEAD, BLOCK_SIZE),) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_SIZE, + ) + return out + + +if __name__ == "__main__": + SIZE_M = 64 + D_HEAD = 128 + out = wrapper(SIZE_M, D_HEAD) + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + ref = matrix.t() + assert torch.allclose(ref, out), abs(ref - out).max() + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stdout new file mode 100644 index 0000000..4732268 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_869907.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_869907.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py new file mode 100644 index 0000000..f061654 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py @@ -0,0 +1,105 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + mask_m = offs_m < SIZE_M + mask_d = offs_d < D_HEAD + + mask = mask_m[:, None] & mask_d[None, :] + + matrix_ptrs = M + offs_m[:, None] * matrix_stridex + offs_d[None, :] * matrix_stridey + out_ptrs = Out + offs_d[:, None] * out_stridex + offs_m[None, :] * out_stridey + + data = tl.load(matrix_ptrs, mask=mask) + tl.store(out_ptrs, data, mask=mask.T) + +def wrapper(SIZE_M: int, D_HEAD: int) -> torch.Tensor: + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + BLOCK_M = 32 + BLOCK_D = 32 + + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_D)) + kernel[grid]( + matrix, + out, + matrix.stride(1), + matrix.stride(0), + out.stride(1), + out.stride(0), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stdout new file mode 100644 index 0000000..1abe6de --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_879575.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_879575.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py new file mode 100644 index 0000000..a1ded3c --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py @@ -0,0 +1,87 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + BLOCK_M = 16 + BLOCK_N = 16 + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + + mask = (offs_m < SIZE_M) & (offs_n < D_HEAD) + + m_ptrs = M + offs_m * matrix_stridex + offs_n * matrix_stridey + out_ptrs = Out + offs_n * out_stridex + offs_m * out_stridey + + m_val = tl.load(m_ptrs, mask=mask) + tl.store(out_ptrs, m_val, mask=mask) + +def wrapper(size_m: int, d_head: int): + SIZE_M = size_m + D_HEAD = d_head + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = (triton.cdiv(SIZE_M, 16), triton.cdiv(D_HEAD, 16)) + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stdout new file mode 100644 index 0000000..f37fa34 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_892743.py.stdout @@ -0,0 +1,10 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_892743 due to at 8:31: +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + BLOCK_M = 16 + BLOCK_N = 16 + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + ^ diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py new file mode 100644 index 0000000..f8f9139 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py @@ -0,0 +1,88 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + + mask = (offs_m < SIZE_M) & (offs_n < D_HEAD) + + m_ptrs = M + offs_m * matrix_stridex + offs_n * matrix_stridey + out_ptrs = Out + offs_n * out_stridex + offs_m * out_stridey + + m_val = tl.load(m_ptrs, mask=mask) + tl.store(out_ptrs, m_val, mask=mask) + +def wrapper(size_m: int, d_head: int): + SIZE_M = size_m + D_HEAD = d_head + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + BLOCK_M = 16 + BLOCK_N = 16 + grid = (triton.cdiv(SIZE_M, BLOCK_M), triton.cdiv(D_HEAD, BLOCK_N)) + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_917011.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py new file mode 100644 index 0000000..37a2b05 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py @@ -0,0 +1,107 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) * BLOCK_M + pid_d = tl.program_id(1) * BLOCK_D + + m_offs = pid_m + tl.arange(0, BLOCK_M) + d_offs = pid_d + tl.arange(0, BLOCK_D) + + mask_m = m_offs < SIZE_M + mask_d = d_offs < D_HEAD + + src_ptrs = M + d_offs[None, :] * matrix_stridey + m_offs[:, None] * matrix_stridex + mask = mask_m[:, None] & mask_d[None, :] + src = tl.load(src_ptrs, mask=mask) + + dst_ptrs = Out + m_offs[None, :] * out_stridex + d_offs[:, None] * out_stridey + tl.store(dst_ptrs, src, mask=mask.T) + +@torch.no_grad() +def wrapper(SIZE_M: int, D_HEAD: int): + BLOCK_M = 32 + BLOCK_D = 32 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.empty((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = lambda META: (triton.cdiv(SIZE_M, META['BLOCK_M']), + triton.cdiv(D_HEAD, META['BLOCK_D'])) + + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stdout new file mode 100644 index 0000000..3011225 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_930305.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_transpose.py_gen_triton_code_930305.py diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py new file mode 100644 index 0000000..2e1b8d7 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py @@ -0,0 +1,87 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + BLOCK_M = 16 + BLOCK_N = 16 + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + + mask = (offs_m < SIZE_M) & (offs_n < D_HEAD) + + m_ptrs = M + offs_m * matrix_stridex + offs_n * matrix_stridey + out_ptrs = Out + offs_n * out_stridex + offs_m * out_stridey + + m_val = tl.load(m_ptrs, mask=mask) + tl.store(out_ptrs, m_val, mask=mask) + +def wrapper(): + SIZE_M = 64 + D_HEAD = 128 + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = (triton.cdiv(SIZE_M, 16), triton.cdiv(D_HEAD, 16)) + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stdout new file mode 100644 index 0000000..57a585c --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_953212.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_953212 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py new file mode 100644 index 0000000..4bcc63a --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py @@ -0,0 +1,90 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, + matrix_stridex, matrix_stridey, + out_stridex, out_stridey, + SIZE_M: tl.constexpr, D_HEAD: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + cols = tl.arange(0, SIZE_M) + rows = tl.arange(0, D_HEAD) + + mask = (cols[:, None] < SIZE_M) & (rows[None, :] < D_HEAD) + + m_ptrs = M + cols[:, None] * matrix_stridey + rows[None, :] * matrix_stridex + o_ptrs = Out + rows[None, :] * out_stridey + cols[:, None] * out_stridex + + vals = tl.load(m_ptrs, mask=mask) + tl.store(o_ptrs, vals, mask=mask) + +def wrapper(): + SIZE_M = 128 + D_HEAD = 64 + + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device='cuda') + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device='cuda') + + grid = (1,) + kernel[grid]( + matrix, out, + matrix.stride(1), matrix.stride(0), + out.stride(1), out.stride(0), + SIZE_M, D_HEAD, + num_warps=4, + num_stages=2 + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stdout new file mode 100644 index 0000000..889f3e4 --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_984648.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_984648 due to wrapper() takes 0 positional arguments but 2 were given diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py b/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py new file mode 100644 index 0000000..fccc2bc --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py @@ -0,0 +1,102 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel( + M, + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + offs_m = pid_m * D_HEAD + tl.arange(0, D_HEAD) + offs_d = pid_d * SIZE_M + tl.arange(0, SIZE_M) + + mask_m = offs_m < SIZE_M * D_HEAD + mask_d = offs_d < SIZE_M * D_HEAD + + src_ptrs = M + offs_m[:, None] * matrix_stridex + offs_d[None, :] * matrix_stridey + src_mask = (offs_m[:, None] < SIZE_M) & (offs_d[None, :] < D_HEAD) + + src_data = tl.load(src_ptrs, mask=src_mask, other=0.0) + + dst_ptrs = Out + offs_d[:, None] * out_stridex + offs_m[None, :] * out_stridey + dst_mask = (offs_d[:, None] < D_HEAD) & (offs_m[None, :] < SIZE_M) + + tl.store(dst_ptrs, src_data.T, mask=dst_mask) + +@torch.no_grad() +def wrapper(SIZE_M: int, D_HEAD: int): + matrix = torch.randn((SIZE_M, D_HEAD), dtype=torch.float16, device="cuda") + out = torch.zeros((D_HEAD, SIZE_M), dtype=torch.float16, device="cuda") + + grid = (1,) + kernel[grid]( + matrix, + out, + matrix.stride(0), + matrix.stride(1), + out.stride(0), + out.stride(1), + SIZE_M, + D_HEAD, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_triton_vs_torch(): + + results = {} + + + + # 测试用例 1: 基本矩阵转置 (小矩阵) + + size_m, d_head = 16, 16 + + out = wrapper(size_m, d_head) + + results["test_case_1"] = out.clone() + + + + # 测试用例 2: 非方形矩阵 + + size_m, d_head = 32, 64 + + out = wrapper(size_m, d_head) + + results["test_case_2"] = out.clone() + + + + return results + + + + + +# 运行测试 + +result_gold = test_triton_vs_torch() + +# print(result_gold) \ No newline at end of file diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stderr b/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stdout b/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stdout new file mode 100644 index 0000000..3e71b8f --- /dev/null +++ b/src/temp/gen/matrix_transpose.py_gen_triton_code_997014.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module matrix_transpose.py_gen_triton_code_997014 due to at 14:30: + Out, + matrix_stridex, + matrix_stridey, + out_stridex, + out_stridey, + SIZE_M, + D_HEAD, +): + pid_m = tl.program_id(0) + pid_d = tl.program_id(1) + + offs_m = pid_m * D_HEAD + tl.arange(0, D_HEAD) + ^ diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py new file mode 100644 index 0000000..e612788 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py @@ -0,0 +1,87 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A_ptr, B_ptr, C_ptr, N, M, stride_am, stride_ak, stride_bk, stride_cn, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid = tl.program_id(0) + row_start = pid * BLOCK_N + offs_n = row_start + tl.arange(0, BLOCK_N) + + col_start = 0 + acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + for col_start in range(0, M, BLOCK_M): + offs_m = col_start + tl.arange(0, BLOCK_M) + a_idx = A_ptr + offs_n[:, None] * stride_am + offs_m[None, :] * stride_ak + a_mask = (offs_n[:, None] < N) & (offs_m[None, :] < M) + a_vals = tl.load(a_idx, mask=a_mask, other=0.0) + + b_idx = B_ptr + offs_m * stride_bk + b_mask = offs_m < M + b_vals = tl.load(b_idx, mask=b_mask, other=0.0) + + acc += tl.sum(a_vals * b_vals[None, :], axis=1) + + c_idx = C_ptr + offs_n * stride_cn + c_mask = offs_n < N + tl.store(c_idx, acc, mask=c_mask) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1, "A must be 2D and B must be 1D" + assert A.size(1) == B.size(0), "Inner matrix dimensions must agree" + N, M = A.shape + C = torch.empty((N,), dtype=torch.float32, device=A.device) + + def grid(meta): + return (triton.cdiv(N, meta['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, N, M, + A.stride(0), A.stride(1), B.stride(0), C.stride(0), + BLOCK_N=32, BLOCK_M=32 + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_164112.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py new file mode 100644 index 0000000..253dd7b --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py @@ -0,0 +1,88 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_bn, stride_cm, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = tl.where(rm < M, rm, 0) + rn = tl.where(rn < N, rn, 0) + acc = tl.zeros([BLOCK_M], dtype=tl.float32) + + for k in range(0, N, BLOCK_N): + a_idx = (rm[:, None] * stride_am + (k + rn)[None, :] * stride_an) + b_idx = (k + rn) * stride_bn + a = tl.load(A + a_idx, mask=(k + rn)[None, :] < N, other=0).to(tl.float32) + b = tl.load(B + b_idx, mask=(k + rn) < N, other=0).to(tl.float32) + acc += tl.sum(a * b[None, :], axis=1) + + c_idx = rm * stride_cm + mask = rm < M + tl.store(C + c_idx, acc, mask=mask) + + +def mv(A: torch.Tensor, B: torch.Tensor): + assert A.dim() == 2 and B.dim() == 1, "A must be 2D and B must be 1D" + M, N = A.shape + assert B.shape[0] == N, "Dimension mismatch between matrix and vector" + assert A.device == B.device, "Input tensors must be on the same device" + assert A.dtype in [torch.float16, torch.float32] and B.dtype in [torch.float16, torch.float32], "dtypes must be float16 or float32" + + C = torch.empty((M,), dtype=A.dtype, device=A.device) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']),) + + mv_kernel[grid]( + A, B, C, M, N, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_M=64, + BLOCK_N=32 + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stdout new file mode 100644 index 0000000..e9c6751 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_205689.py.stdout @@ -0,0 +1 @@ +True*#*#False*#*#None*#*#Generated output does not match reference output for file: matrix_vector_multip.py_gen_triton_code_205689.py diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py new file mode 100644 index 0000000..cee87ab --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py @@ -0,0 +1,88 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, stride_am, stride_an, stride_b, stride_cm, + N: tl.constexpr, M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_N + offs_n = block_start + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + + mask_n = offs_n < N + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + for start_m in range(0, M, BLOCK_M): + offs_m_curr = start_m + offs_m + mask_m = offs_m_curr < M + a_ptrs = A + (offs_n[:, None] * stride_am + offs_m_curr[None, :] * stride_an) + b_ptrs = B + offs_m_curr * stride_b + a = tl.load(a_ptrs, mask=mask_n[:, None] & mask_m[None, :], other=0.0) + b = tl.load(b_ptrs, mask=mask_m, other=0.0) + acc += tl.sum(a * b[None, :], axis=1) + + c_ptrs = C + offs_n * stride_cm + tl.store(c_ptrs, acc, mask=mask_n) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2, "A must be 2D" + assert B.dim() == 1, "B must be 1D" + N, M = A.shape + assert B.shape[0] == M, "A and B shapes incompatible" + + C = torch.empty(N, dtype=A.dtype, device=A.device) + + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, + A.stride(0), A.stride(1), B.stride(0), C.stride(0), + N, M, + BLOCK_N=64, BLOCK_M=64, + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_334537.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py new file mode 100644 index 0000000..dc86e3c --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py @@ -0,0 +1,111 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel( + A, B, C, + N, M, + stride_am, stride_an, + stride_b, + stride_c, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_N + offs_n = block_start + tl.arange(0, BLOCK_N) + col_mask = offs_n < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + for mid in range(0, M, BLOCK_M): + offs_m = mid + tl.arange(0, BLOCK_M).to(tl.int32) + + # Compute A pointers [BLOCK_N, BLOCK_M] + a_ptrs = A + offs_n[:, None] * stride_am + offs_m[None, :] * stride_an + mask_a = col_mask[:, None] & (offs_m[None, :] < M) + + a_block = tl.load(a_ptrs, mask=mask_a, other=0.0) + + # Compute B pointers [BLOCK_M] + b_ptrs = B + offs_m * stride_b + mask_b = offs_m < M + b_block = tl.load(b_ptrs, mask=mask_b, other=0.0) + + # Reduce along block_m dimension + acc += tl.sum(a_block * b_block[None, :], axis=1).to(tl.float32) + + # Store + c_ptrs = C + offs_n * stride_c + tl.store(c_ptrs, acc.to(C.type.element_ty), mask=col_mask) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2, "A must be a 2D matrix (N x M)" + assert B.dim() == 1, "B must be a 1D vector (M)" + assert A.shape[1] == B.shape[0], "Matrix-vector dimension mismatch" + + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 64 + + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), ) + + mv_kernel[grid]( + A, B, C, + N, M, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M + ) + + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_370413.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py new file mode 100644 index 0000000..9d7ded6 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py @@ -0,0 +1,86 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_bn, stride_cm, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < M + acc = tl.zeros([BLOCK_M], dtype=tl.float32) + + for k in range(0, N, BLOCK_N): + rn = k + tl.arange(0, BLOCK_N) + mask_n = rn < N + a_idx = rm[:, None] * stride_am + rn[None, :] * stride_an + a = tl.load(A + a_idx, mask=(mask_m[:, None] & mask_n[None, :]), other=0.) + b_idx = rn * stride_bn + b = tl.load(B + b_idx, mask=mask_n, other=0.) + acc += tl.sum(a * b[None, :], axis=1) + + c_idx = rm * stride_cm + tl.store(C + c_idx, acc, mask=mask_m) + + +def mv(A: torch.Tensor, B: torch.Tensor): + assert A.dim() == 2 and B.dim() == 1, "A must be 2D and B must be 1D" + M, N = A.shape + assert N == B.shape[0], "Dimension mismatch between matrix and vector" + assert A.device == B.device, "Input tensors must be on the same device" + assert A.dtype in [torch.float16, torch.float32] and B.dtype in [torch.float16, torch.float32] + + C = torch.empty((M,), dtype=A.dtype, device=A.device) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + mv_kernel[grid]( + A, B, C, M, N, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_M=64, + BLOCK_N=64 + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_424820.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py new file mode 100644 index 0000000..1164574 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py @@ -0,0 +1,88 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, N, M, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_n = tl.program_id(0) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + for m_start in range(0, M, BLOCK_M): + offs_m_cur = m_start + offs_m + mask_m = offs_m_cur < M + offs_a = A + offs_n[:, None] * M + offs_m_cur[None, :] + mask_a = (offs_n[:, None] < N) & mask_m[None, :] + a_block = tl.load(offs_a, mask=mask_a, other=0.0) + offs_b = B + offs_m_cur + b_vals = tl.load(offs_b, mask=mask_m, other=0.0) + acc += tl.sum(a_block * b_vals[None, :], axis=1) + + offs_c = C + offs_n + mask_c = offs_n < N + tl.store(offs_c, acc.to(C.type.element_ty), mask=mask_c) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2 and B.dim() == 1, "A must be 2-D and B must be 1-D" + N, M = A.shape + assert B.shape[0] == M, "Dimension mismatch: B must have size M where A is NxM" + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 64 + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, + N, M, + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M + ) + + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554113.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py new file mode 100644 index 0000000..e5e6119 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py @@ -0,0 +1,101 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel( + A_ptr, B_ptr, C_ptr, + M, N, + stride_am, stride_an, + stride_bn, + stride_cm, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M], dtype=tl.float32) + + for start_n in range(0, N, BLOCK_N): + current_n = start_n + offs_n + a_ptrs = A_ptr + (offs_m[:, None] * stride_am + current_n[None, :] * stride_an) + b_ptrs = B_ptr + current_n * stride_bn + mask = (offs_m[:, None] < M) & (current_n[None, :] < N) + + a_vals = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32) + b_vals = tl.load(b_ptrs, mask=current_n < N, other=0.0).to(tl.float32) + + acc += tl.sum(a_vals * b_vals[None, :], axis=1) + + c_ptrs = C_ptr + offs_m * stride_cm + mask = offs_m < M + tl.store(c_ptrs, acc.to(tl.float32), mask=mask) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2, "A must be 2-D (M x N)" + assert B.dim() == 1, "B must be 1-D (N)" + M, N = A.shape + assert B.shape[0] == N, "B must have same length as N dimension of A" + + C = torch.empty(M, device=A.device, dtype=A.dtype) + + BLOCK_M = 32 + BLOCK_N = 32 + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']),) + + mv_kernel[grid]( + A, B, C, + M, N, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_554981.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py new file mode 100644 index 0000000..521353b --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py @@ -0,0 +1,98 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M], dtype=tl.float32) + + for k in range(0, N, BLOCK_N): + rn_k = k + rn + mask_a = (rm[:, None] < M) & (rn_k[None, :] < N) + mask_b = rn_k < N + + a_ptrs = A + (rm[:, None] * stride_am + rn_k[None, :] * stride_an) + b_ptrs = B + rn_k * stride_b + + a_block = tl.load(a_ptrs, mask=mask_a, other=0.0).to(tl.float32) + b_block = tl.load(b_ptrs, mask=mask_b, other=0.0).to(tl.float32) + + acc += tl.sum(a_block * b_block[None, :], axis=1) + + mask_c = rm < M + c_ptrs = C + rm * stride_c + tl.store(c_ptrs, acc, mask=mask_c) + + +def mv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dim() == 2 + assert b.dim() == 1 + assert a.size(1) == b.size(0) + + M, N = a.shape + C = torch.empty(M, dtype=a.dtype, device=a.device) + + BLOCK_M = 64 + BLOCK_N = 64 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']),) + + mv_kernel[grid]( + a, b, C, + M, N, + a.stride(0), a.stride(1), + b.stride(0), + C.stride(0), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N + ) + + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_561330.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py new file mode 100644 index 0000000..f5e907f --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py @@ -0,0 +1,95 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + # Compute block pointer for C + c_ptrs = C + offs_m * stride_c + mask_m = offs_m < M + c_acc = tl.zeros([BLOCK_M], dtype=tl.float32) + + # Load A (BLOCK_M x BLOCK_N) + a_blk_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + # Load and compute + for start_n in range(0, N, BLOCK_N): + # Offset to current block in N + curr_n = start_n + offs_n + mask_n = curr_n < N + a_ptrs = a_blk_ptrs + start_n * stride_an + a = tl.load(a_ptrs, mask=(mask_m[:, None] & mask_n[None, :]), other=0.0) + + b_ptrs = B + curr_n * stride_b + b = tl.load(b_ptrs, mask=mask_n, other=0.0).to(tl.float32) + c_acc += tl.sum(a.to(tl.float32) * b[None, :], axis=1) + + tl.store(c_ptrs, c_acc.to(C.type.element_ty), mask=mask_m) + + +def mv(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor = None): + assert A.dim() == 2 and B.dim() == 1, "A must be 2-D and B must be 1-D" + M, N = A.shape + assert B.shape[0] == N, "Inner dimensions must match" + if C is None: + C = torch.empty(M, dtype=A.dtype, device=A.device) + + stride_am = A.stride(0) + stride_an = A.stride(1) + stride_b = B.stride(0) + stride_c = C.stride(0) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']),) + mv_kernel[grid]( + A, B, C, M, N, + stride_am, stride_an, stride_b, stride_c, + BLOCK_M=64, BLOCK_N=32 + ) + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_686366.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py new file mode 100644 index 0000000..bc75064 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py @@ -0,0 +1,99 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel(A, B, C, M, N, stride_am, stride_an, stride_b, stride_c, + BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr): + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_N + tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + offs_b_base = tl.arange(0, BLOCK_M) + + for k in range(0, M, BLOCK_M): + offs_k = k + offs_b_base + mask_A = (offs_m[:, None] < N) & (offs_k[None, :] < M) + offs_A = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_an + a_tile = tl.load(offs_A, mask=mask_A, other=0.0) + mask_B = offs_k < M + offs_B = B + offs_k * stride_b + b_vec = tl.load(offs_B, mask=mask_B, other=0.0) + acc += tl.sum(a_tile * b_vec[None, :], axis=1) + + offs_c = C + offs_m * stride_c + mask_c = offs_m < N + tl.store(offs_c, acc.to(C.type.element_ty), mask=mask_c) + + +def mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.dim() == 2, "Input tensor A must be 2D (N x M)" + assert B.dim() == 1, "Input tensor B must be 1D" + N, M = A.shape + assert B.shape[0] == M, "Incompatible dimensions for MV multiplication" + + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 32 + + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), ) + + mv_kernel[grid]( + A, + B, + C, + M, + N, + A.stride(0), + A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, + ) + + return C + + +################################################################################################################################################## + + + + + +def test_mv(): + + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + + triton_result_2 = mv(A, B) + + + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + + A = torch.randn(32, 16, device='cuda') + + B = torch.randn(16, device='cuda') + + triton_result_3 = mv(A, B) + + + + return { + + "test_case_2": triton_result_2, + + "test_case_3": triton_result_3, + + } + + + +result_gold = test_mv() diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stderr b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stderr new file mode 100644 index 0000000..59ad8c4 --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stderr @@ -0,0 +1,2 @@ +/opt/conda/envs/py_3.12/lib/python3.12/site-packages/redis/connection.py:77: UserWarning: redis-py works best with hiredis. Please consider installing + warnings.warn(msg) diff --git a/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stdout b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/matrix_vector_multip.py_gen_triton_code_80693.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py b/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py new file mode 100644 index 0000000..9a97e64 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py @@ -0,0 +1,312 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_nheads, + stride_out_seqlen, + stride_out_headdim, + stride_x_batch, + stride_x_nheads, + stride_x_seqlen, + stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + + if not IS_VARLEN: + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + cos_ptr = COS + pid_batch * seqlen_ro * (rotary_dim // 2) + sin_ptr = SIN + pid_batch * seqlen_ro * (rotary_dim // 2) + seqlen_i = seqlen + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen_i = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + cos_ptr = COS + sin_ptr = SIN + + if pid_m * BLOCK_M >= seqlen_i: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + # Masks + mask_m = rm < seqlen_i + mask_k_half = rk_half < (rotary_dim // 2) + + if not INTERLEAVED: + # Non-interleaved: contiguous real and imag parts + x_real_offset = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x_imag_offset = x_real_offset + (rotary_dim // 2) * stride_x_headdim + + x_real = tl.load(x_real_offset, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + x_imag = tl.load(x_imag_offset, mask=mask_m[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + + cos_offset = cos_ptr + rm_cs[:, None] * (rotary_dim // 2) + rk_half[None, :] + sin_offset = sin_ptr + rm_cs[:, None] * (rotary_dim // 2) + rk_half[None, :] + + cos = tl.load(cos_offset, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(sin_offset, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half[None, :], other=0.0).to(tl.float32) + + if CONJUGATE: + sin_val = -sin_val + + o_real = x_real * cos - x_imag * sin_val + o_imag = x_real * sin_val + x_imag * cos + + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o_real, mask=mask_m[:, None] & mask_k_half[None, :]) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rotary_dim // 2 + rk_half[None, :]) * stride_out_headdim, + o_imag, mask=mask_m[:, None] & mask_k_half[None, :]) + else: + # Interleaved: even indices real, odd indices imag + rk_even = rk * 2 + rk_odd = rk * 2 + 1 + rk_half = rk // 2 + + mask_k_even = (rk_even < rotary_dim) + mask_k_odd = (rk_odd < rotary_dim) + mask_k_half_ready = rk_half < (rotary_dim // 2) + + cos_offset = cos_ptr + rm_cs[:, None] * (rotary_dim // 2) + rk_half[None, :] + sin_offset = sin_ptr + rm_cs[:, None] * (rotary_dim // 2) + rk_half[None, :] + + cos_val = tl.load(cos_offset, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half_ready[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(sin_offset, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half_ready[None, :], other=0.0).to(tl.float32) + + if CONJUGATE: + sin_val = -sin_val + + x_even_offset = x_ptr + rm[:, None] * stride_x_seqlen + rk_even[None, :] * stride_x_headdim + x_odd_offset = x_ptr + rm[:, None] * stride_x_seqlen + rk_odd[None, :] * stride_x_headdim + + x_even = tl.load(x_even_offset, mask=mask_m[:, None] & mask_k_even[None, :], other=0.0).to(tl.float32) + x_odd = tl.load(x_odd_offset, mask=mask_m[:, None] & mask_k_odd[None, :], other=0.0).to(tl.float32) + + grouped_even = x_even.reshape([-1, x_even.shape[1] // 2, 2]) + grouped_odd = x_odd.reshape([-1, x_odd.shape[1] // 2, 2]) + + grouped_even_t = grouped_even[:, :, 0] + grouped_odd_t = grouped_odd[:, :, 0] + + out_even = grouped_even_t * cos_val - grouped_odd_t * sin_val + out_odd = grouped_even_t * sin_val + grouped_odd_t * cos_val + + out_even_unpacked = out_even.reshape(x_even.shape) + out_odd_unpacked = out_odd.reshape(x_even.shape) + + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_even[None, :] * stride_out_headdim, + out_even_unpacked, mask=mask_m[:, None] & mask_k_even[None, :]) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_odd[None, :] * stride_out_headdim, + out_odd_unpacked, mask=mask_m[:, None] & mask_k_odd[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + batch_ro, seqlen_ro, rotary_dim_half = cos.shape + + assert batch == batch_ro, f"batch mismatch: {batch} != {batch_ro}" + assert sin.shape == cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim, f"rotary_dim ({rotary_dim}) must be <= headdim ({headdim})" + assert cos.dtype == sin.dtype == x.dtype, "All dtypes must match" + assert not (cu_seqlens is not None and max_seqlen is None), "max_seqlen is required with cu_seqlens" + + seqlen_ro_needed = seqlen + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.to(torch.int32).contiguous() + seqlen_ro_needed += seqlen_offsets.max().item() + else: + seqlen_ro_needed += seqlen_offsets + assert seqlen_ro >= seqlen_ro_needed, f"seqlen_ro ({seqlen_ro}) must be >= seqlen_ro_needed ({seqlen_ro_needed})" + + cos = cos.contiguous() + sin = sin.contiguous() + + output = x if inplace else torch.empty_like(x) + if not inplace and rotary_dim < headdim: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + IS_VARLEN = (cu_seqlens is not None) + CU_SEQLENS_ptr = (cu_seqlens.int().contiguous() if IS_VARLEN else None) + + rotary_kernel[grid]( + output, + x, + cos, + sin, + CU_SEQLENS_ptr, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + x.stride(0), + x.stride(2), + x.stride(1), + x.stride(3), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=IS_VARLEN, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stdout new file mode 100644 index 0000000..c9c90b6 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_105954.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_105954 due to not enough values to unpack (expected 3, got 2) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py b/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py new file mode 100644 index 0000000..8230c46 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py @@ -0,0 +1,237 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + X, COS, SIN, CU_SEQLENS, SEQLENS, OUT, + stride_batch, stride_seqlen, stride_head, stride_dim, + rotary_dim, max_seqlen, total_seqlens, + nheads, seqlen_ro, interleaved, conj, BLOCK_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_batch: + return + if pid_head >= nheads: + return + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_start + else: + seq_start = pid_batch * max_seqlen + seqlen_i = tl.load(SEQLENS + pid_batch) if SEQLENS is not None else max_seqlen + + if pid_m >= seqlen_i: + return + + offset_m = seq_start + pid_m + + rotary_dim_half = rotary_dim // 2 + BLOCK_K = tl.min(BLOCK_SIZE_M, rotary_dim_half) + for k in range(0, rotary_dim_half, BLOCK_K): + k_idx = k + tl.arange(0, BLOCK_K) + mask = k_idx < rotary_dim_half + + pos_m = pid_m + cos_idx = pos_m * rotary_dim + k_idx + cos_offset = COS + cos_idx + cos_val = tl.load(cos_offset, mask=mask).to(tl.float32) + + sin_idx = pos_m * rotary_dim + k_idx + sin_offset = SIN + sin_idx + sin_val = tl.load(sin_offset, mask=mask).to(tl.float32) + if conj: + sin_val = -sin_val + + if interleaved: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + 1 + mask_2 = 2 * k_idx + 1 < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask_2).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_2).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask_2) + tl.store(OUT + x_idx1, out1, mask=mask_2) + else: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + rotary_dim_half + mask_half = k_idx + rotary_dim_half < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_half).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask_half) + + if IS_EVEN_K: + return + + k_offset = rotary_dim + total_dim = stride_dim + for d in range(rotary_dim, total_dim): + idx = offset_m * stride_seqlen + pid_head * stride_head + d + val = tl.load(X + idx) + tl.store(OUT + idx, val.astype(X.dtype.element_ty)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seq_offset: int = 0, + cu_seqlens: Optional[torch.Tensor] = None, + seqlens: Optional[torch.Tensor] = None, + rotary_dim: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + rotary_dim = rotary_dim or headdim + + assert rotary_dim % 2 == 0, "rotary_dim must be even" + assert headdim >= rotary_dim, "headdim must be >= rotary_dim" + assert cos.shape == (seqlen, rotary_dim), f"cos shape mismatch: {cos.shape} vs ({seqlen}, {rotary_dim})" + assert sin.shape == (seqlen, rotary_dim), f"sin shape mismatch: {sin.shape} vs ({seqlen}, {rotary_dim})" + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + BLOCK_SIZE_M = min(max(triton.next_power_of_2(rotary_dim // 2), 16), 64) + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, META["BLOCK_M"])) + + rotary_kernel[grid]( + x, + cos, + sin, + cu_seqlens, + seqlens, + out, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + rotary_dim, + seqlen, + batch * seqlen, + nheads, + seqlen, + interleaved, + conjugate, + BLOCK_SIZE_M, + IS_EVEN_K=(headdim == rotary_dim), + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stdout new file mode 100644 index 0000000..d1ae5e4 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_260701.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_260701 due to cos shape mismatch: torch.Size([128, 16]) vs (128, 64) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py b/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py new file mode 100644 index 0000000..ba93256 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py @@ -0,0 +1,287 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, # int32 + nheads, # int32 + rotary_dim, # int32 + seqlen_ro, # int32 + CACHE_KEY_SEQLEN, # int32 + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K : tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR : tl.constexpr, + IS_VARLEN : tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE : tl.constexpr, + BLOCK_M : tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch= tl.program_id(1) + pid_head = tl.program_id(2) + + rot_half = rotary_dim // 2 + offset_batch = pid_batch * stride_x_batch if IS_VARLEN == 0 else 0 + cu_b = 0 + cur_seqlen = seqlen + if IS_VARLEN != 0: + cu_b = tl.load(CU_SEQLENS + pid_batch) + cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - cu_b + offset_x_batch = cu_b * stride_x_seqlen + pid_head * stride_x_nheads + offset_o_batch = cu_b * stride_out_seqlen + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < cur_seqlen + + seq_off = tl.load(SEQLEN_OFFSETS + pid_batch) if IS_SEQLEN_OFFSETS_TENSOR else SEQLEN_OFFSETS + base_t = rm + seq_off + + offs_k = tl.arange(0, BLOCK_K) + + for k_base in range(0, rot_half, BLOCK_K): + k = k_base + offs_k + mask_k = k < rot_half + + idx_cos_s = base_t[:, None] * rot_half + k[None, :] + mask_cs = (base_t[:, None] < seqlen_ro) & mask_k[None, :] + cos = tl.load(COS + idx_cos_s, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN + idx_cos_s, mask=mask_cs, other=0.0).to(tl.float32) + + if INTERLEAVED == 0: + idx0 = rm[:, None] * stride_x_seqlen + (k[None, :] * stride_x_headdim) + idx1 = rm[:, None] * stride_x_seqlen + ((k[None, :] + rot_half) * stride_x_headdim) + x0 = tl.load(X + offset_x_batch + idx0, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(X + offset_x_batch + idx1, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE != 0: + sin = -sin + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + tl.store(OUT + offset_o_batch + idx0, + y0, mask=mask_m[:, None] & mask_k[None, :]) + tl.store(OUT + offset_o_batch + idx1, + y1, mask=mask_m[:, None] & mask_k[None, :]) + else: + idx_even = rm[:, None] * stride_x_seqlen + (2 * k[None, :] * stride_x_headdim) + idx_odd = rm[:, None] * stride_x_seqlen + ((2 * k[None, :] + 1) * stride_x_headdim) + real = tl.load(X + offset_x_batch + idx_even, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + imag = tl.load(X + offset_x_batch + idx_odd, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE != 0: + sin = -sin + new_real = real * cos - imag * sin + new_imag = real * sin + imag * cos + tl.store(OUT + offset_o_batch + idx_even, + new_real, mask=mask_m[:, None] & mask_k[None, :]) + tl.store(OUT + offset_o_batch + idx_odd, + new_imag, mask=mask_m[:, None] & mask_k[None, :]) + + for k_base in range(rotary_dim, stride_x_headdim, BLOCK_K): + k = k_base + offs_k + mask_k = k < stride_x_headdim + idx = rm[:, None] * stride_x_seqlen + k[None, :] * stride_x_headdim + val = tl.load(X + offset_x_batch + idx, + mask=mask_m[:, None] & mask_k[None, :]) + tl.store(OUT + offset_o_batch + idx, + val, mask=mask_m[:, None] & mask_k[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.numel() - 1 + seqlen = max_seqlen + + seqlen_ro, rot_half = cos.shape + rotary_dim = rot_half * 2 + assert rotary_dim <= headdim + assert seqlen_ro >= seqlen + assert rotary_dim % 2 == 0 + assert cos.dtype == sin.dtype == x.dtype + assert headdim <= 512 + + x = x.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.to(torch.int32).contiguous() + else: + assert seqlen + seqlen_offsets <= seqlen_ro + + cu_seqlens_host = None + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.to(torch.int32).contiguous() + cu_seqlens_host = cu_seqlens + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = triton.next_power_of_2(min(rotary_dim // 2, 128)) + + grid = (triton.cdiv(seqlen, 4), batch, nheads) + + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + 0, + output.stride(-4) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + x.stride(-4) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + 4, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_329295.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py b/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py new file mode 100644 index 0000000..5867d85 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py @@ -0,0 +1,321 @@ + +import torch +import triton +import triton.language as tl +from typing import Union, Optional + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = x0_ptr + rotary_dim_half * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + sin_ptr = SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + + cos = tl.load( + cos_ptr, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_ptr, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load( + x0_ptr, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x1 = tl.load( + x1_ptr, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + + out0_ptr = OUT + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_ptr = out0_ptr + rotary_dim_half * stride_out_headdim + tl.store( + out0_ptr, + o0, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + tl.store( + out1_ptr, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 + rk_repeat = tl.arange(0, BLOCK_K) // 2 + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + sin_ptr = SIN + rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :] + + cos = tl.load( + cos_ptr, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_ptr, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load( + x0_ptr, + mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), + other=0.0, + ).to(tl.float32) + x1 = tl.load( + x1_ptr, + mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), + other=0.0, + ).to(tl.float32) + + if CONJUGATE: + sin = -sin + out = tl.where(rk[None, :] % 2 == 0, x0 * cos - x1 * sin, x0 * cos + x1 * sin) + out_ptr = OUT + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + tl.store(out_ptr, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + assert x.ndim == 4, "Expected 4-D tensor [batch, seqlen, heads, dim] for non-varlen inputs" + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is provided, max_seqlen must be specified" + assert x.ndim == 3, "Expected 3-D tensor [total_seqlen, heads, dim] for varlen inputs" + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert cos.dtype == sin.dtype and x.dtype == cos.dtype + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert int(seqlen_offsets) + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + (output[..., rotary_dim:].copy_(x[..., rotary_dim:]) if not is_varlen + else output[:, rotary_dim:].copy_(x[:, rotary_dim:])) + + BLOCK_K = 32 if rotary_dim <= 32 else ( + 64 if rotary_dim <= 64 else ( + 128 if rotary_dim <= 128 else 256 + ) + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + def stride_or_zero(tensor, idx, fixed=None): + return tensor.stride(idx) if fixed is None else fixed + + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, + stride_or_zero(output, -4, 0) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + stride_or_zero(x, -4, 0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_338032.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py b/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py new file mode 100644 index 0000000..fda87e6 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py @@ -0,0 +1,289 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, # *float32 + X, # *float32 + COS, # *float32 + SIN, # *float32 + CU_SEQLENS, # *int32 + SEQLEN_OFFSETS, # *int32 + seqlen, # int32 + rotary_dim, # int32 # rotary dimension (must be even) + seqlen_ro, # int32 # rotary sequence length + stride_out_batch, # int64 + stride_out_seqlen, # int64 + stride_out_nheads, # int64 + stride_out_headdim, # int64 + stride_x_batch, # int64 + stride_x_seqlen, # int64 + stride_x_nheads, # int64 + stride_x_headdim, # int64 + BLOCK_K: tl.constexpr, # rotary dimension (must be even) + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, # bool + IS_VARLEN: tl.constexpr, # bool + INTERLEAVED: tl.constexpr, # bool + CONJUGATE: tl.constexpr, # bool + BLOCK_M: tl.constexpr, # block size along sequence dimension +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + if not IS_VARLEN: + offset_b = pid_batch * stride_x_batch + offset_bo = pid_batch * stride_out_batch + current_seqlen = seqlen + else: + seqlen_start = tl.load(CU_SEQLENS + pid_batch) + seqlen_end = tl.load(CU_SEQLENS + pid_batch + 1) + current_seqlen = seqlen_end - seqlen_start + offset_b = seqlen_start * stride_x_seqlen + offset_bo = seqlen_start * stride_out_seqlen + + X = X + offset_b + pid_head * stride_x_nheads + OUT = OUT + offset_bo + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= current_seqlen: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < current_seqlen + + if IS_SEQLEN_OFFSETS_TENSOR: + seqlen_offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + seqlen_offset = SEQLEN_OFFSETS + + rk_half = tl.arange(0, BLOCK_K // 2) + rk_full = tl.arange(0, BLOCK_K) + + if not INTERLEAVED: + # Non-interleaved + cos_offset = (rm[:, None] + seqlen_offset) * rotary_dim + rk_half[None, :] + cos = tl.load(COS + cos_offset, + mask=((rm[:, None] + seqlen_offset) < seqlen_ro) & (rk_half[None, :] < rotary_dim//2), + other=1.0).to(tl.float32) + sin = tl.load(SIN + cos_offset, + mask=((rm[:, None] + seqlen_offset) < seqlen_ro) & (rk_half[None, :] < rotary_dim//2), + other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + x0_offset = rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x0 = tl.load(X + x0_offset, mask=mask_m[:, None] & (rk_half[None, :] < rotary_dim//2), other=0.0).to(tl.float32) + x1_offset = rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim//2) * stride_x_headdim + x1 = tl.load(X + x1_offset, mask=mask_m[:, None] & (rk_half[None, :] < rotary_dim//2), other=0.0).to(tl.float32) + + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + + tl.store(OUT + x0_offset, y0, mask=mask_m[:, None] & (rk_half[None, :] < rotary_dim//2)) + tl.store(OUT + x1_offset, y1, mask=mask_m[:, None] & (rk_half[None, :] < rotary_dim//2)) + + # Remaining dimensions + if rotary_dim < BLOCK_K: + rk_rem = tl.arange(rotary_dim, BLOCK_K) + x_rem = tl.load(X + rm[:, None] * stride_x_seqlen + rk_rem[None, :] * stride_x_headdim, + mask=mask_m[:, None] & (rk_rem[None, :] < BLOCK_K), other=0.0) + tl.store(OUT + rm[:, None] * stride_out_seqlen + rk_rem[None, :] * stride_out_headdim, + x_rem, mask=mask_m[:, None] & (rk_rem[None, :] < BLOCK_K)) + + else: + # Interleaved + cos_offset = (rm[:, None] + seqlen_offset) * rotary_dim + (rk_full[None, :]//2) + cos = tl.load(COS + cos_offset, + mask=((rm[:, None] + seqlen_offset) < seqlen_ro) & (rk_full[None, :] < rotary_dim), + other=1.0).to(tl.float32) + sin = tl.load(SIN + cos_offset, + mask=((rm[:, None] + seqlen_offset) < seqlen_ro) & (rk_full[None, :] < rotary_dim), + other=0.0).to(tl.float32) + + x_offset = rm[:, None] * stride_x_seqlen + rk_full[None, :] * stride_x_headdim + x = tl.load(X + x_offset, mask=mask_m[:, None] & (rk_full[None, :] < rotary_dim), other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + rk_even = (rk_full[None, :] % 2) == 0 + y = tl.where(rk_even, x * cos - x * sin.flip(1), x * sin + x * cos.flip(1)) + + tl.store(OUT + x_offset, y, mask=mask_m[:, None] & (rk_full[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + + assert rotary_dim <= headdim + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype == x.dtype + assert rotary_dim % 2 == 0 + + if cu_seqlens is not None: + assert cu_seqlens.dtype == torch.int32 + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.dtype == torch.int32 + seqlen_offsets = seqlen_offsets.contiguous() + + x = x.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = triton.next_power_of_2(rotary_dim) + + if cu_seqlens is None: + grid_m = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + else: + grid_m = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), cu_seqlens.shape[0]-1, nheads) + + BLOCK_M = 4 if interleaved else 8 + + rotary_kernel[grid_m]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets if isinstance(seqlen_offsets, torch.Tensor) else None, + seqlen, + rotary_dim, + seqlen_ro, + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=cu_seqlens is not None, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stdout new file mode 100644 index 0000000..0e677b7 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_339628.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_339628 due to at 60:22: + + if IS_SEQLEN_OFFSETS_TENSOR: + seqlen_offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + seqlen_offset = SEQLEN_OFFSETS + + rk_half = tl.arange(0, BLOCK_K // 2) + rk_full = tl.arange(0, BLOCK_K) + + if not INTERLEAVED: + # Non-interleaved + cos_offset = (rm[:, None] + seqlen_offset) * rotary_dim + rk_half[None, :] + ^ +AssertionError("cannot convert None of type to tensor") diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py b/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py new file mode 100644 index 0000000..8230c46 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py @@ -0,0 +1,237 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + X, COS, SIN, CU_SEQLENS, SEQLENS, OUT, + stride_batch, stride_seqlen, stride_head, stride_dim, + rotary_dim, max_seqlen, total_seqlens, + nheads, seqlen_ro, interleaved, conj, BLOCK_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_batch: + return + if pid_head >= nheads: + return + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_start + else: + seq_start = pid_batch * max_seqlen + seqlen_i = tl.load(SEQLENS + pid_batch) if SEQLENS is not None else max_seqlen + + if pid_m >= seqlen_i: + return + + offset_m = seq_start + pid_m + + rotary_dim_half = rotary_dim // 2 + BLOCK_K = tl.min(BLOCK_SIZE_M, rotary_dim_half) + for k in range(0, rotary_dim_half, BLOCK_K): + k_idx = k + tl.arange(0, BLOCK_K) + mask = k_idx < rotary_dim_half + + pos_m = pid_m + cos_idx = pos_m * rotary_dim + k_idx + cos_offset = COS + cos_idx + cos_val = tl.load(cos_offset, mask=mask).to(tl.float32) + + sin_idx = pos_m * rotary_dim + k_idx + sin_offset = SIN + sin_idx + sin_val = tl.load(sin_offset, mask=mask).to(tl.float32) + if conj: + sin_val = -sin_val + + if interleaved: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + 1 + mask_2 = 2 * k_idx + 1 < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask_2).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_2).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask_2) + tl.store(OUT + x_idx1, out1, mask=mask_2) + else: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + rotary_dim_half + mask_half = k_idx + rotary_dim_half < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_half).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask_half) + + if IS_EVEN_K: + return + + k_offset = rotary_dim + total_dim = stride_dim + for d in range(rotary_dim, total_dim): + idx = offset_m * stride_seqlen + pid_head * stride_head + d + val = tl.load(X + idx) + tl.store(OUT + idx, val.astype(X.dtype.element_ty)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seq_offset: int = 0, + cu_seqlens: Optional[torch.Tensor] = None, + seqlens: Optional[torch.Tensor] = None, + rotary_dim: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + rotary_dim = rotary_dim or headdim + + assert rotary_dim % 2 == 0, "rotary_dim must be even" + assert headdim >= rotary_dim, "headdim must be >= rotary_dim" + assert cos.shape == (seqlen, rotary_dim), f"cos shape mismatch: {cos.shape} vs ({seqlen}, {rotary_dim})" + assert sin.shape == (seqlen, rotary_dim), f"sin shape mismatch: {sin.shape} vs ({seqlen}, {rotary_dim})" + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + BLOCK_SIZE_M = min(max(triton.next_power_of_2(rotary_dim // 2), 16), 64) + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, META["BLOCK_M"])) + + rotary_kernel[grid]( + x, + cos, + sin, + cu_seqlens, + seqlens, + out, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + rotary_dim, + seqlen, + batch * seqlen, + nheads, + seqlen, + interleaved, + conjugate, + BLOCK_SIZE_M, + IS_EVEN_K=(headdim == rotary_dim), + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stdout new file mode 100644 index 0000000..1fccb18 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_344391.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_344391 due to cos shape mismatch: torch.Size([128, 16]) vs (128, 64) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py b/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py new file mode 100644 index 0000000..6066a2d --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py @@ -0,0 +1,343 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + OUT, + CU_SEQLENS, + SEQLENS, + stride_x_batch, + stride_x_head, + stride_x_m, + stride_x_k, + stride_c_stride, + stride_cos_m, + stride_cos_k, + stride_sin_m, + stride_sin_k, + stride_out_batch, + stride_out_head, + stride_out_m, + stride_out_k, + n_ctx, + HEAD_K: tl.constexpr, + IS_VARIABLE_KV: tl.constexpr, + CONJUGATE: tl.constexpr, + INTERLEAVED: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) * BLOCK_H + tl.arange(0, BLOCK_H) + pid_m = tl.program_id(2) * BLOCK_M + tl.arange(0, BLOCK_M) + + mask_h = pid_head < HEAD_K + mask_m = pid_m < n_ctx + + if IS_VARIABLE_KV: + # Handle variable sequence lengths + cu_seq = tl.load(CU_SEQLENS + pid_batch) + seq_len = tl.load(SEQLENS + pid_batch) + offset_m = cu_seq + pid_m + else: + # Handle fixed sequence length + offset_m = pid_batch * n_ctx + pid_m + seq_len = n_ctx + + mask_seq = pid_m < seq_len + + if INTERLEAVED: + # Interleaved format: real and imag parts are interleaved + load_real_idx = 2 * pid_m + 0 + load_imag_idx = 2 * pid_m + 1 + + off_real = ( + pid_batch * stride_x_batch + + pid_head[None, :] * stride_x_head + + load_real_idx[:, None] * stride_x_m + + tl.arange(0, HEAD_K // 2)[None, :] * stride_x_k + ) + off_imag = ( + pid_batch * stride_x_batch + + pid_head[None, :] * stride_x_head + + load_imag_idx[:, None] * stride_x_m + + tl.arange(0, HEAD_K // 2)[None, :] * stride_x_k + ) + + # Load real and imaginary parts + x_real = tl.load(X + off_real, mask=mask_m[:, None] & mask_h[None, :], other=0.0) + x_imag = tl.load(X + off_imag, mask=mask_m[:, None] & mask_h[None, :], other=0.0) + + # Load COS and SIN + off_cos_m = offset_m[:, None] * stride_cos_m + off_sin_m = offset_m[:, None] * stride_sin_m + + # Get the right dimension for COS/SIN + off_cos_real = ( + off_cos_m + + (2 * tl.arange(0, HEAD_K // 2))[None, :] * stride_cos_k + ) + off_sin_real = ( + off_sin_m + + (2 * tl.arange(0, HEAD_K // 2))[None, :] * stride_sin_k + ) + off_cos_imag = ( + off_cos_m + + (2 * tl.arange(0, HEAD_K // 2) + 1)[None, :] * stride_cos_k + ) + off_sin_imag = ( + off_sin_m + + (2 * tl.arange(0, HEAD_K // 2) + 1)[None, :] * stride_sin_k + ) + + cos_real = tl.load(COS + off_cos_real, mask=mask_m[:, None], other=1.0) + sin_real = tl.load(SIN + off_sin_real, mask=mask_m[:, None], other=0.0) + cos_imag = tl.load(COS + off_cos_imag, mask=mask_m[:, None], other=0.0) + sin_imag = tl.load(SIN + off_sin_imag, mask=mask_m[:, None], other=0.0) + + else: + # Non-interleaved format: first half is real, second half is imag + half_k = HEAD_K // 2 + + # Offsets for real and imaginary parts + off_real = ( + pid_batch * stride_x_batch + + pid_head[None, :] * stride_x_head + + pid_m[:, None] * stride_x_m + + tl.arange(0, half_k)[None, :] * stride_x_k + ) + off_imag = ( + pid_batch * stride_x_batch + + (half_k + pid_head)[None, :] * stride_x_head + + pid_m[:, None] * stride_x_m + + tl.arange(0, half_k)[None, :] * stride_x_k + ) + + # Load real and imaginary parts + x_real = tl.load(X + off_real, mask=mask_m[:, None] & (pid_head < half_k)[None, :], other=0.0) + x_imag = tl.load(X + off_imag, mask=mask_m[:, None] & (pid_head >= half_k)[None, :], other=0.0) + + # Load COS and SIN for non-interleaved + off_cos = ( + offset_m[:, None] * stride_cos_m + + tl.arange(0, half_k)[None, :] * stride_cos_k + ) + off_sin = ( + offset_m[:, None] * stride_sin_m + + tl.arange(0, half_k)[None, :] * stride_sin_k + ) + + cos = tl.load(COS + off_cos, mask=mask_m[:, None], other=1.0) + sin = tl.load(SIN + off_sin, mask=mask_m[:, None], other=0.0) + + cos_real = cos + sin_real = sin + cos_imag = cos + sin_imag = sin + + # Compute rotary transform + if CONJUGATE: + # With conjugation + out_real = x_real * cos_real + x_imag * sin_real + out_imag = -x_real * sin_imag + x_imag * cos_imag + else: + # Without conjugation + out_real = x_real * cos_real - x_imag * sin_real + out_imag = x_real * sin_imag + x_imag * cos_imag + + # Store results + if INTERLEAVED: + off_out_real = ( + pid_batch * stride_out_batch + + pid_head[None, :] * stride_out_head + + load_real_idx[:, None] * stride_out_m + + tl.arange(0, HEAD_K // 2)[None, :] * stride_out_k + ) + off_out_imag = ( + pid_batch * stride_out_batch + + pid_head[None, :] * stride_out_head + + load_imag_idx[:, None] * stride_out_m + + tl.arange(0, HEAD_K // 2)[None, :] * stride_out_k + ) + tl.store(OUT + off_out_real, out_real, mask=mask_m[:, None] & mask_h[None, :]) + tl.store(OUT + off_out_imag, out_imag, mask=mask_m[:, None] & mask_h[None, :]) + else: + off_out_real = ( + pid_batch * stride_out_batch + + pid_head[None, :] * stride_out_head + + pid_m[:, None] * stride_out_m + + tl.arange(0, half_k)[None, :] * stride_out_k + ) + off_out_imag = ( + pid_batch * stride_out_batch + + (half_k + pid_head)[None, :] * stride_out_head + + pid_m[:, None] * stride_out_m + + tl.arange(0, half_k)[None, :] * stride_out_k + ) + tl.store(OUT + off_out_real, out_real, mask=mask_m[:, None] & (pid_head < half_k)[None, :]) + tl.store(OUT + off_out_imag, out_imag, mask=mask_m[:, None] & (pid_head >= half_k)[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offset: int = 0, + cu_seqlens: torch.Tensor = None, + seqlens: torch.Tensor = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + # Determine output tensor + out = x if inplace else torch.empty_like(x) + + # Get dimensions + batch, head, n_ctx, head_k = x.shape + + # Determine if variable sequence lengths are used + is_variable_kv = cu_seqlens is not None and seqlens is not None + + # Building the grid + BLOCK_H = 64 + BLOCK_M = 32 + grid = (batch, triton.cdiv(head, BLOCK_H), triton.cdiv(n_ctx, BLOCK_M)) + + # Launch the kernel + rotary_kernel[grid]( + x, + cos, + sin, + out, + cu_seqlens, + seqlens, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + 1 if is_variable_kv else 0, + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + n_ctx, + HEAD_K=head_k, + IS_VARIABLE_KV=is_variable_kv, + CONJUGATE=conjugate, + INTERLEAVED=interleaved, + BLOCK_H=BLOCK_H, + BLOCK_M=BLOCK_M, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stdout new file mode 100644 index 0000000..ae32b4a --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_373163.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_373163 due to at 106:14: + sin_imag = tl.load(SIN + off_sin_imag, mask=mask_m[:, None], other=0.0) + + else: + # Non-interleaved format: first half is real, second half is imag + half_k = HEAD_K // 2 + + # Offsets for real and imaginary parts + off_real = ( + pid_batch * stride_x_batch + + pid_head[None, :] * stride_x_head + + pid_m[:, None] * stride_x_m + + tl.arange(0, half_k)[None, :] * stride_x_k + ^ diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py b/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py new file mode 100644 index 0000000..a4a52e9 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py @@ -0,0 +1,272 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + stride_cos_seqlen, + stride_cos_dim, + stride_sin_seqlen, + stride_sin_dim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, + ROTARY_DIM_HALF: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + seq_start = tl.load(CU_SEQLENS + pid_batch) + cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - seq_start + x_ptr = X + seq_start * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + seq_start * stride_out_seqlen + pid_head * stride_out_nheads + if pid_m * BLOCK_M >= cur_seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk_half = tl.arange(0, BLOCK_K // 2) + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + rm_cs = tl.where(rm_cs < seqlen_ro, rm_cs, seqlen_ro - 1) + if not INTERLEAVED: + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + (rk_half + ROTARY_DIM_HALF)[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + rk_half[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_dim + mask_m = rm[:, None] < cur_seqlen + mask_k_half = rk_half[None, :] < ROTARY_DIM_HALF + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_k_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rk_half + ROTARY_DIM_HALF)[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_k_half) + else: + rk_even = 2 * tl.arange(0, ROTARY_DIM_HALF) + rk_odd = 2 * tl.arange(0, ROTARY_DIM_HALF) + 1 + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_even[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_odd[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + tl.arange(0, ROTARY_DIM_HALF)[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + tl.arange(0, ROTARY_DIM_HALF)[None, :] * stride_sin_dim + mask_m = rm[:, None] < cur_seqlen + mask_half = tl.arange(0, ROTARY_DIM_HALF)[None, :] < ROTARY_DIM_HALF + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_even[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_odd[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_half) + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Apply rotary embedding to the input tensor x using Triton kernels optimized for AMD GPU ROCm.""" + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + if max_seqlen is None: + raise ValueError("max_seqlen must be provided if cu_seqlens is used") + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype == x.dtype + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 if rotary_dim <= 32 else + 64 if rotary_dim <= 64 else + 128 if rotary_dim <= 128 else 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, + 0, + output.stride(0) if not is_varlen else 0, + output.stride(-3), output.stride(-2), output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), x.stride(-2), x.stride(-1), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ROTARY_DIM_HALF=rotary_dim_half + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_385268.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py b/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py new file mode 100644 index 0000000..ea52312 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py @@ -0,0 +1,275 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + CU_SEQLENS, + OUT, + stride_xb, + stride_xh, + stride_xn, + stride_xd, + stride_cosn, + stride_cosd, + stride_sinn, + stride_sind, + stride_cu_off, + stride_ob, + stride_oh, + stride_on, + stride_od, + nheads, + seqlen, + rotary_dim, + interleaved, + conjugate, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_EVEN_N: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + rot_dim_half = rotary_dim // 2 + + if CU_SEQLENS is None: + seq_start = 0 + seq_id = pid_batch + else: + seq_start = 0 + if pid_batch > 0: + seq_start = tl.load(CU_SEQLENS + pid_batch - 1) + seq_end = tl.load(CU_SEQLENS + pid_batch) + seq_id = seq_start + pid_m + if seq_id >= seq_end: + return + + offset_b = seq_id * stride_xb + offset_h = pid_head * stride_xh + offset_n = pid_m * stride_xn + offset_d = tl.arange(0, BLOCK_K) + offset_k = tl.arange(0, BLOCK_N) + + # Compute input pointer base for this element + x_base = X + offset_b + offset_h + offset_n + # Load input values for rotary dimensions + if IS_EVEN_K: + x_rot = tl.load(x_base + offset_d, mask=offset_d < rotary_dim) + else: + mask_d = offset_d < rotary_dim + x_rot = tl.load(x_base + offset_d, mask=mask_d) + + # Compute cosine/sine pointers + cos_base = COS + seq_id * stride_cosn + sin_base = SIN + seq_id * stride_sinn + + # Load cosine and sine values + if IS_EVEN_K: + cos = tl.load(cos_base + offset_d, mask=offset_d < rotary_dim) + sin = tl.load(sin_base + offset_d, mask=offset_d < rotary_dim) + else: + mask_d = offset_d < rotary_dim + cos = tl.load(cos_base + offset_d, mask=mask_d) + sin = tl.load(sin_base + offset_d, mask=mask_d) + + # Split into two halves + x0 = x_rot[:rot_dim_half] if rotary_dim <= BLOCK_K else x_rot[0:rot_dim_half:2] if interleaved else x_rot[:rot_dim_half] + x1 = x_rot[rot_dim_half:] if rotary_dim <= BLOCK_K else x_rot[1:rot_dim_half*2:2] if interleaved else x_rot[rot_dim_half:] + + # Gather corresponding cos/sin for each half + cos0 = cos[:rot_dim_half] if rotary_dim <= BLOCK_K else cos[0:rot_dim_half:2] if interleaved else cos[:rot_dim_half] + cos1 = cos[rot_dim_half:] if rotary_dim <= BLOCK_K else cos[1:rot_dim_half*2:2] if interleaved else cos[rot_dim_half:] + sin0 = sin[:rot_dim_half] if rotary_dim <= BLOCK_K else sin[0:rot_dim_half:2] if interleaved else sin[:rot_dim_half] + sin1 = sin[rot_dim_half:] if rotary_dim <= BLOCK_K else sin[1:rot_dim_half*2:2] if interleaved else sin[rot_dim_half:] + + if conjugate: + sin0 = -sin0 + sin1 = -sin1 + + # Apply rotary transform + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + + # Prepare output pointers + out_base = OUT + offset_b + offset_h + offset_n + + # Store rotary section + if interleaved: + rot_indices = tl.arange(0, rotary_dim) + # Handle interleaved storage pattern + if rotary_dim <= BLOCK_K: + tl.store(out_base + rot_indices[0::2], y0, mask=rot_indices[0::2] < rotary_dim) + tl.store(out_base + rot_indices[1::2], y1, mask=rot_indices[1::2] < rotary_dim) + else: + tl.store(out_base + rot_indices[0::2], y0, mask=rot_indices[0::2] < rotary_dim) + tl.store(out_base + rot_indices[1::2], y1, mask=rot_indices[1::2] < rotary_dim) + else: + if rotary_dim <= BLOCK_K: + tl.store(out_base + offset_d[:rot_dim_half], y0, mask=offset_d[:rot_dim_half] < rotary_dim) + tl.store(out_base + offset_d[rot_dim_half:], y1, mask=offset_d[rot_dim_half:] < rotary_dim) + else: + tl.store(out_base + offset_d[:rot_dim_half], y0, mask=offset_d[:rot_dim_half] < rotary_dim) + tl.store(out_base + offset_d[rot_dim_half:], y1, mask=offset_d[rot_dim_half:] < rotary_dim) + + # Copy non-rotary dimensions + if rotary_dim < BLOCK_K: + if IS_EVEN_K: + x_non_rot = tl.load(x_base + offset_d + rotary_dim, mask=offset_d + rotary_dim < BLOCK_K) + tl.store(out_base + offset_d + rotary_dim, x_non_rot, mask=offset_d + rotary_dim < BLOCK_K) + else: + mask_rest = (offset_d + rotary_dim) < BLOCK_K + x_non_rot = tl.load(x_base + offset_d + rotary_dim, mask=mask_rest) + tl.store(out_base + offset_d + rotary_dim, x_non_rot, mask=mask_rest) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int = 0, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, +) -> torch.Tensor: + rotary_dim = cos.shape[-1] + assert rotary_dim % 2 == 0 + assert cos.shape == sin.shape + + batch, seqlen = x.shape[0], x.shape[2] + nheads = x.shape[1] + + if max_seqlen is None: + max_seqlen = seqlen + + # Determine block parameters + BLOCK_K = triton.next_power_of_2(rotary_dim) + BLOCK_N = triton.next_power_of_2(max_seqlen) + + # Grid dimensions + grid = lambda META: (batch, nheads, triton.cdiv(max_seqlen, META['BLOCK_N'])) + + # Prepare output tensor + out = torch.empty_like(x) + + rotary_kernel[grid]( + x, cos, sin, cu_seqlens, out, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + cu_seqlens.stride(0) if cu_seqlens is not None else 0, + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + nheads, seqlen, rotary_dim, False, False, + BLOCK_K=BLOCK_K, + BLOCK_N=min(BLOCK_N, 128), + IS_EVEN_N=x.shape[2] % min(BLOCK_N, 128) == 0, + IS_EVEN_K=rotary_dim % BLOCK_K == 0, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stdout new file mode 100644 index 0000000..1f4d5d9 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_405620.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_405620 due to at 76:9: + + # Load cosine and sine values + if IS_EVEN_K: + cos = tl.load(cos_base + offset_d, mask=offset_d < rotary_dim) + sin = tl.load(sin_base + offset_d, mask=offset_d < rotary_dim) + else: + mask_d = offset_d < rotary_dim + cos = tl.load(cos_base + offset_d, mask=mask_d) + sin = tl.load(sin_base + offset_d, mask=mask_d) + + # Split into two halves + x0 = x_rot[:rot_dim_half] if rotary_dim <= BLOCK_K else x_rot[0:rot_dim_half:2] if interleaved else x_rot[:rot_dim_half] + ^ +ValueError('unsupported tensor index: slice(None, , None)') diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py b/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py new file mode 100644 index 0000000..8230c46 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py @@ -0,0 +1,237 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + X, COS, SIN, CU_SEQLENS, SEQLENS, OUT, + stride_batch, stride_seqlen, stride_head, stride_dim, + rotary_dim, max_seqlen, total_seqlens, + nheads, seqlen_ro, interleaved, conj, BLOCK_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_batch: + return + if pid_head >= nheads: + return + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_start + else: + seq_start = pid_batch * max_seqlen + seqlen_i = tl.load(SEQLENS + pid_batch) if SEQLENS is not None else max_seqlen + + if pid_m >= seqlen_i: + return + + offset_m = seq_start + pid_m + + rotary_dim_half = rotary_dim // 2 + BLOCK_K = tl.min(BLOCK_SIZE_M, rotary_dim_half) + for k in range(0, rotary_dim_half, BLOCK_K): + k_idx = k + tl.arange(0, BLOCK_K) + mask = k_idx < rotary_dim_half + + pos_m = pid_m + cos_idx = pos_m * rotary_dim + k_idx + cos_offset = COS + cos_idx + cos_val = tl.load(cos_offset, mask=mask).to(tl.float32) + + sin_idx = pos_m * rotary_dim + k_idx + sin_offset = SIN + sin_idx + sin_val = tl.load(sin_offset, mask=mask).to(tl.float32) + if conj: + sin_val = -sin_val + + if interleaved: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + 1 + mask_2 = 2 * k_idx + 1 < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask_2).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_2).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask_2) + tl.store(OUT + x_idx1, out1, mask=mask_2) + else: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + rotary_dim_half + mask_half = k_idx + rotary_dim_half < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_half).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask_half) + + if IS_EVEN_K: + return + + k_offset = rotary_dim + total_dim = stride_dim + for d in range(rotary_dim, total_dim): + idx = offset_m * stride_seqlen + pid_head * stride_head + d + val = tl.load(X + idx) + tl.store(OUT + idx, val.astype(X.dtype.element_ty)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seq_offset: int = 0, + cu_seqlens: Optional[torch.Tensor] = None, + seqlens: Optional[torch.Tensor] = None, + rotary_dim: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + rotary_dim = rotary_dim or headdim + + assert rotary_dim % 2 == 0, "rotary_dim must be even" + assert headdim >= rotary_dim, "headdim must be >= rotary_dim" + assert cos.shape == (seqlen, rotary_dim), f"cos shape mismatch: {cos.shape} vs ({seqlen}, {rotary_dim})" + assert sin.shape == (seqlen, rotary_dim), f"sin shape mismatch: {sin.shape} vs ({seqlen}, {rotary_dim})" + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + BLOCK_SIZE_M = min(max(triton.next_power_of_2(rotary_dim // 2), 16), 64) + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, META["BLOCK_M"])) + + rotary_kernel[grid]( + x, + cos, + sin, + cu_seqlens, + seqlens, + out, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + rotary_dim, + seqlen, + batch * seqlen, + nheads, + seqlen, + interleaved, + conjugate, + BLOCK_SIZE_M, + IS_EVEN_K=(headdim == rotary_dim), + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stdout new file mode 100644 index 0000000..f074403 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_431864.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_431864 due to cos shape mismatch: torch.Size([128, 16]) vs (128, 64) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py b/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py new file mode 100644 index 0000000..ff2ed16 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py @@ -0,0 +1,265 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLENS_OFFSETS, + stride_out_batch, stride_out_head, stride_out_m, stride_out_k, + stride_x_batch, stride_x_head, stride_x_m, stride_x_k, + stride_cos_batch, stride_cos_m, stride_cos_k, + stride_sin_batch, stride_sin_m, stride_sin_k, + rotary_dim, rotary_half, conjugate, + HEADS: tl.constexpr, SEQLEN: tl.constexpr, DIM: tl.constexpr, + IS_VARIABLE: tl.constexpr, INTERLEAVED: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_out_batch: + return + + seqlen_offset = 0 + if IS_VARIABLE: + seqlen_offset = tl.load(SEQLENS_OFFSETS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - tl.load(CU_SEQLENS + pid_batch) + if pid_m >= seq_len: + return + else: + if SEQLEN is not None and pid_m >= SEQLEN: + return + seqlen_offset = tl.load(SEQLENS_OFFSETS + pid_batch) if SEQLENS_OFFSETS else 0 + + rotary_dim = rotary_dim + k = tl.arange(0, BLOCK_K) + + # Compute offsets for X + if INTERLEAVED: + offs_x = ( + pid_batch * stride_x_batch + + pid_head * stride_x_head + + pid_m * stride_x_m + + (k * 2) * stride_x_k + ) + else: + offs_x = ( + pid_batch * stride_x_batch + + pid_head * stride_x_head + + pid_m * stride_x_m + + k * stride_x_k + ) + + # Compute offsets for COS/SIN + offs_cos_sin = pid_m * stride_cos_m + k * stride_cos_k + + # Load COS/SIN + cos = tl.load(COS + offs_cos_sin, mask=k < rotary_dim, other=1.0) + sin = tl.load(SIN + offs_cos_sin, mask=k < rotary_dim, other=0.0) + + # Process rotary pairs + for i in range(0, tl.cdiv(rotary_dim, 2), BLOCK_K // 2): + # Calculate indices for current pair + if INTERLEAVED: + idx = i * 2 + k0 = idx + k1 = idx + 1 + else: + idx = i + k0 = idx + k1 = idx + rotary_half + + # Load x0, x1 + x0 = tl.load(X + offs_x + k0 * stride_x_k, mask=k0 < rotary_dim, other=0.0) + x1 = tl.load(X + offs_x + k1 * stride_x_k, mask=k1 < rotary_dim, other=0.0) + + # Apply rotation + if conjugate: + out0 = x0 * cos - x1 * sin + out1 = x0 * sin + x1 * cos + else: + out0 = x0 * cos + x1 * sin + out1 = -x0 * sin + x1 * cos + + # Store results + tl.store(OUT + offs_x + k0 * stride_x_k, out0, mask=k0 < rotary_dim) + tl.store(OUT + offs_x + k1 * stride_x_k, out1, mask=k1 < rotary_dim) + + # Handle non-rotary dimensions (copy original values) + if rotary_dim < DIM: + for i in range(rotary_dim, DIM, BLOCK_K): + offs_non_rot = ( + pid_batch * stride_x_batch + + pid_head * stride_x_head + + pid_m * stride_x_m + + i * stride_x_k + ) + val = tl.load(X + offs_non_rot) + tl.store(OUT + offs_non_rot, val) + + +def apply_rotary( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, max_seqlen: int = None, interleaved: bool = False, + in_place: bool = False, conjugate: bool = False +) -> torch.Tensor: + dims = x.dim() + assert dims in [3, 4], "Input tensor must be 3D (B, T, D) or 4D (B, H, T, D)" + + if dims == 3: # Treat as (B, T, D) + batch, seqlen, dim = x.shape + heads = 1 + x = x.view(batch, heads, seqlen, dim) + else: # dims == 4: (B, H, T, D) + batch, heads, seqlen, dim = x.shape + + rotary_dim = cos.shape[-1] + rotary_half = rotary_dim // 2 + + assert rotary_dim <= dim, "Rotary dimension must be <= feature dimension" + assert cos.shape == sin.shape, "COS and SIN must have same shape" + assert cos.shape[-1] == rotary_dim, "Last dimension of COS/SIN must match rotary_dim" + + # Prepare output tensor + if in_place: + out = x + else: + out = torch.empty_like(x) + + # Handle max_seqlen for grid dimension + actual_max_seqlen = max_seqlen if max_seqlen is not None else seqlen + + # Prepare sequence offsets + if seqlen_offsets is None: + seqlen_offsets = torch.zeros(batch, dtype=torch.int64, device=x.device) + + # Determine IS_VARIABLE flag + IS_VARIABLE = cu_seqlens is not None + + # Grid configuration + grid = lambda META: ( + batch, + heads, + triton.cdiv(actual_max_seqlen, META["BLOCK_M"]) + ) + + # Configure block sizes (tune these parameters) + BLOCK_M = min(64, actual_max_seqlen) + BLOCK_K = min(64, rotary_dim) + + rotary_kernel[grid]( + out, x, cos, sin, cu_seqlens, seqlen_offsets, + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cos.stride(0) if cos.dim() > 1 else 0, cos.stride(-2) if cos.dim() > 1 else 0, cos.stride(-1), + sin.stride(0) if sin.dim() > 1 else 0, sin.stride(-2) if sin.dim() > 1 else 0, sin.stride(-1), + rotary_dim, rotary_half, conjugate, + HEADS=heads, SEQLEN=seqlen, DIM=dim, + IS_VARIABLE=IS_VARIABLE, INTERLEAVED=interleaved, + BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K + ) + + return out.view(batch, seqlen, dim) if dims == 3 else out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stdout new file mode 100644 index 0000000..23f5e26 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_44150.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_44150 due to at 27:24: + return + + seqlen_offset = 0 + if IS_VARIABLE: + seqlen_offset = tl.load(SEQLENS_OFFSETS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - tl.load(CU_SEQLENS + pid_batch) + if pid_m >= seq_len: + return + else: + if SEQLEN is not None and pid_m >= SEQLEN: + return + seqlen_offset = tl.load(SEQLENS_OFFSETS + pid_batch) if SEQLENS_OFFSETS else 0 + ^ +AssertionError('ternary expression with dynamic condition has inconsistent types int64 and int32') diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py b/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py new file mode 100644 index 0000000..ba8ab02 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py @@ -0,0 +1,307 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, # *float16/bfloat16/float32 + X, # *float16/bfloat16/float32 + COS, # *float16/bfloat16/float32 + SIN, # *float16/bfloat16/float32 + CU_SEQLENS, # *int32 + SEQLEN_OFFSETS, # *int32 when IS_SEQLEN_OFFSETS_TENSOR==True, otherwise scalar int + seqlen, # int32 + nheads, # int32 + rotary_dim, # int32 + seqlen_ro, # int32 + CACHE_KEY_SEQLEN, # int32 (unused in kernel; kept for signature match) + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K : tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR : tl.constexpr, + IS_VARLEN : tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE : tl.constexpr, + BLOCK_M : tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch= tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + # Determine batch / seqlen per sample + if IS_VARLEN == 0: + # fixed-seqlen; X, OUT already point at or part of contiguous + offset_b = pid_batch * stride_x_batch + offset_bo = pid_batch * stride_out_batch + current_seqlen = seqlen + else: + seqlen_start = tl.load(CU_SEQLENS + pid_batch).to(tl.int32) + seqlen_end = tl.load(CU_SEQLENS + pid_batch + 1).to(tl.int32) + current_seqlen = seqlen_end - seqlen_start + offset_b = seqlen_start * stride_x_seqlen + offset_bo = seqlen_start * stride_out_seqlen + + # Compute linears + X += offset_b + pid_head * stride_x_nheads + OUT += offset_bo + pid_head * stride_out_nheads + + # Return early for empty/tail blocks + if pid_m * BLOCK_M >= current_seqlen: + return + + # Row indices and validity mask + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < current_seqlen + + # sequence length offset + if IS_SEQLEN_OFFSETS_TENSOR: + seq_offset = tl.load(SEQLEN_OFFSETS + pid_batch).to(tl.int32) + else: + seq_offset = SEQLEN_OFFSETS # scalar integer captured at launch & constant in kernel + + # half-size dimension indices + rotary_dim_half = rotary_dim // 2 + rk_half = tl.arange(0, rotary_dim_half) + mask_half = rk_half < rotary_dim_half + + if INTERLEAVED == 0: + # Non-interleaved layout ------------------------------------------------- + base_pos = (rm[:, None] + seq_offset) * rotary_dim + rk_half[None, :] + cos_mask = ((rm[:, None] + seq_offset) < seqlen_ro) & mask_half[None, :] + sin_mask = cos_mask + + cos = tl.load(COS + base_pos, mask=cos_mask, other=1.0).to(tl.float32) + sin = tl.load(SIN + base_pos, mask=sin_mask, other=0.0).to(tl.float32) + + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim + x0 = tl.load(x0_ptr, mask=mask_m[:, None] & mask_half[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m[:, None] & mask_half[None, :], other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + y0 = x0 * cos - x1 * sin + y1 = x0 * sin + x1 * cos + + tl.store(OUT + x0_ptr - X + OUT, y0, + mask=mask_m[:, None] & mask_half[None, :]) + tl.store(OUT + x1_ptr - X + OUT, y1, + mask=mask_m[:, None] & mask_half[None, :]) + + # remainder pass-through + if rotary_dim < stride_x_headdim * stride_x_headdim or True: + rk_rem = tl.arange(rotary_dim, stride_x_headdim) + out_off = OUT + rm[:, None] * stride_out_seqlen + rk_rem[None, :] * stride_out_headdim + x_off = X + rm[:, None] * stride_x_seqlen + rk_rem[None, :] * stride_x_headdim + mask_rem = (rk_rem[None, :] < stride_x_headdim) & mask_m[:, None] + val_rem = tl.load(x_off, mask=mask_rem, other=0.0) + tl.store(out_off, val_rem, mask=mask_rem) + + else: + # Interleaved layout ---------------------------------------------------- + full_dim = rotary_dim + rk = tl.arange(0, full_dim) + mask_k = rk < full_dim + rk_half_idx = rk // 2 + + base_pos = (rm[:, None] + seq_offset) * full_dim + rk_half_idx[None, :] + mask_pos = ((rm[:, None] + seq_offset) < seqlen_ro) & mask_k[None, :] + + cos_val = tl.load(COS + base_pos, mask=mask_pos, other=1.0).to(tl.float32) + sin_val = tl.load(SIN + base_pos, mask=mask_pos, other=0.0).to(tl.float32) + + x_off = X + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x_val = tl.load(x_off, mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + + if CONJUGATE: + sin_val = -sin_val + + # flip sin when odd indices + sin_flipped = tl.where((rk[None, :] % 2) == 0, sin_val, -sin_val) + out_val = x_val * cos_val + sin_flipped * x_val.roll(-1, axis=1) + + tl.store(OUT + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, + out_val, mask=mask_m[:, None] & mask_k[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Top-level wrapper for RoPE Triton kernel (AMD ROCm).""" + if cu_seqlens is None: + batch, seqlen, nheads, headdim = x.shape + total_seqlen = batch * seqlen + stride_batch = x.stride(0) + else: + assert max_seqlen is not None + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.numel() - 1 + seqlen = max_seqlen + stride_batch = 0 # unused in varlen mode + + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype == x.dtype + assert rotary_dim % 2 == 0 + + x = x.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.dtype == torch.int32 + seqlen_offsets = seqlen_offsets.contiguous() + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.to(torch.int32).contiguous() + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = triton.next_power_of_2(rotary_dim) + + grid_m = lambda META: (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + BLOCK_M = 4 if interleaved else 8 + + rotary_kernel[grid_m]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + 0, # CACHE_KEY_SEQLEN (placeholder, unused) + *output.stride(), + *x.stride(), + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + cu_seqlens is not None, + interleaved, + conjugate, + BLOCK_M, + ) + + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stdout new file mode 100644 index 0000000..e932c5b --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_450091.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_450091 due to at 65:14: + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < current_seqlen + + # sequence length offset + if IS_SEQLEN_OFFSETS_TENSOR: + seq_offset = tl.load(SEQLEN_OFFSETS + pid_batch).to(tl.int32) + else: + seq_offset = SEQLEN_OFFSETS # scalar integer captured at launch & constant in kernel + + # half-size dimension indices + rotary_dim_half = rotary_dim // 2 + rk_half = tl.arange(0, rotary_dim_half) + ^ diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py b/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py new file mode 100644 index 0000000..6c6ce1e --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py @@ -0,0 +1,294 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + stride_cos_seqlen, + stride_cos_dim, + stride_sin_seqlen, + stride_sin_dim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + cur_seqlen = seqlen + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + seq_start = tl.load(CU_SEQLENS + pid_batch) + cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - seq_start + x_ptr = X + seq_start * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + seq_start * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= cur_seqlen: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk_half = tl.arange(0, BLOCK_K // 2) + + if IS_SEQLEN_OFFSETS_TENSOR: + offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + offset = SEQLEN_OFFSETS + rm_cs = rm + offset + + rm_cs = tl.where(rm_cs < seqlen_ro, rm_cs, seqlen_ro - 1) + + if not INTERLEAVED: + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + rk_half[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_dim + + mask_m = rm[:, None] < cur_seqlen + mask_k_half = rk_half[None, :] < rotary_dim_half + + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_k_half, other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_k_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_k_half) + else: + rk_even = 2 * tl.arange(0, rotary_dim_half) + rk_odd = 2 * tl.arange(0, rotary_dim_half) + 1 + + x0_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_even[None, :] * stride_x_headdim + x1_ptr = x_ptr + rm[:, None] * stride_x_seqlen + rk_odd[None, :] * stride_x_headdim + cos_ptr = COS + rm_cs[:, None] * stride_cos_seqlen + tl.arange(0, rotary_dim_half)[None, :] * stride_cos_dim + sin_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + tl.arange(0, rotary_dim_half)[None, :] * stride_sin_dim + + mask_m = rm[:, None] < cur_seqlen + mask_half = tl.arange(0, rotary_dim_half)[None, :] < rotary_dim_half + + cos = tl.load(cos_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=(rm_cs[:, None] < seqlen_ro) & mask_half, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_m & mask_half, other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_even[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_odd[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_half) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Apply rotary embedding to the input tensor x using Triton kernels optimized for AMD GPU ROCm.""" + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + if max_seqlen is None: + raise ValueError("max_seqlen must be provided if cu_seqlens is used") + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_dimhalf = cos.shape + rotary_dim = rotary_dimhalf * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + assert cos.dtype == sin.dtype == x.dtype + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + + BLOCK_K = ( + 32 if rotary_dim <= 32 else + 64 if rotary_dim <= 64 else + 128 if rotary_dim <= 128 else 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, + seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), output.stride(-2), output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), x.stride(-2), x.stride(-1), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stdout new file mode 100644 index 0000000..ed65b02 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_460195.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_460195 due to at 90:22: + if CONJUGATE: + sin = -sin + + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_m & mask_k_half) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_out_headdim, + o1, mask=mask_m & mask_k_half) + else: + rk_even = 2 * tl.arange(0, rotary_dim_half) + ^ diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py b/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py new file mode 100644 index 0000000..070d84e --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py @@ -0,0 +1,247 @@ + +import torch +import triton +import triton.language as tl +import math +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, + seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, + stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, + stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, + BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < seqlen + rm_cs = rm + if IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + else: + rm_cs = rm + SEQLEN_OFFSETS + + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k = rk_half < rotary_dim_half + x0 = tl.load(x_ptr + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + x1 = tl.load(x_ptr + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + cos = tl.load(COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=(rm_cs[:, None] < seqlen_ro) & mask_k[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=(rm_cs[:, None] < seqlen_ro) & mask_k[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_m[:, None] & mask_k[None, :]) + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, + o1, mask=mask_m[:, None] & mask_k[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + x = tl.load(x_ptr + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim, + mask=mask_m[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + rk_half = rk // 2 + mask_k_half = rk_half < rotary_dim_half + cos = tl.load(COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half[None, :], other=1.0).to(tl.float32) + sin = tl.load(SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=(rm_cs[:, None] < seqlen_ro) & mask_k_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + cos = tl.where(rk[None, :] % 2 == 0, cos, cos) + sin = tl.where(rk[None, :] % 2 == 0, sin, sin) + x0 = x + x1 = tl.roll(x, shifts=1, axis=1) + x1 = tl.where(rk[None, :] % 2 == 0, x1, -x1) + out = x0 * cos - x1 * sin + tl.store(out_ptr + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, + out, mask=mask_m[:, None] & mask_k[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + seqlen_offsets = int(seqlen_offsets) + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K_m = ( + 32 if rotary_dim_half <= 32 + else 64 if rotary_dim_half <= 64 + else 128 if rotary_dim_half <= 128 + else 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3) if x.dim() == 4 else output.stride(-2), + output.stride(-2) if x.dim() == 4 else output.stride(-1), + output.stride(-1) if x.dim() == 4 else 1, + x.stride(0) if not is_varlen else 0, + x.stride(-3) if x.dim() == 4 else x.stride(-2), + x.stride(-2) if x.dim() == 4 else x.stride(-1), + x.stride(-1) if x.dim() == 4 else 1, + BLOCK_K_m, isinstance(seqlen_offsets, torch.Tensor), + is_varlen, interleaved, conjugate, BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stdout new file mode 100644 index 0000000..04cdeae --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_527413.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_527413 due to module 'triton.language' has no attribute 'roll' diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py b/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py new file mode 100644 index 0000000..ebe22b7 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py @@ -0,0 +1,284 @@ + +import torch +import triton +import triton.language as tl +from typing import Union, Optional + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + + if IS_VARLEN: + start_idx = tl.load(CU_SEQLENS + pid_batch) + cur_seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_start = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_start = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + else: + cur_seqlen = seqlen + x_start = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_start = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= cur_seqlen: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk_half = tl.arange(0, BLOCK_K) + + if not INTERLEAVED: + cos_ptr = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ptr = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half) + cos = tl.load(cos_ptr, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_cs, other=0.0).to(tl.float32) + + left_ptr = x_start + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + right_ptr = x_start + (rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim) + mask_lr = (rm[:, None] < cur_seqlen) & (rk_half[None, :] < rotary_dim_half) + + x0 = tl.load(left_ptr, mask=mask_lr, other=0.0).to(tl.float32) + x1 = tl.load(right_ptr, mask=mask_lr, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + out0 = x0 * cos - x1 * sin + out1 = x0 * sin + x1 * cos + + tl.store( + out_start + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim), + out0, + mask=mask_lr, + ) + tl.store( + out_start + (rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim), + out1, + mask=mask_lr, + ) + else: + rk = tl.arange(0, 2 * BLOCK_K) + cos_ptr = COS + (rm_cs[:, None] * rotary_dim_half + (rk[None, :] // 2)) + sin_ptr = SIN + (rm_cs[:, None] * rotary_dim_half + (rk[None, :] // 2)) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim) + cos = tl.load(cos_ptr, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_cs, other=0.0).to(tl.float32) + + x_ptr = x_start + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + mask_x = (rm[:, None] < cur_seqlen) & (rk[None, :] < rotary_dim) + x0 = tl.load(x_ptr, mask=mask_x, other=0.0).to(tl.float32) + + x1_ptr = x_start + (rm[:, None] * stride_x_seqlen + (rk[None, :] ^ 1) * stride_x_headdim) + x1 = tl.load(x1_ptr, mask=mask_x, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + out = tl.where(rk[None, :] % 2 == 0, x0 * cos - x1 * sin, x0 * sin + x1 * cos) + tl.store( + out_start + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim), + out, + mask=mask_x, + ) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "max_seqlen required when cu_seqlens given" + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_half = cos.shape + rotary_dim = rotary_half * 2 + assert rotary_dim <= headdim + assert cos.dtype == sin.dtype == x.dtype + cos, sin = cos.contiguous(), sin.contiguous() + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = x if inplace else torch.empty_like(x) + if rotary_dim < headdim and not inplace: + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + + BLOCK_K = max(32, triton.next_power_of_2(rotary_half)) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_540784.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py b/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py new file mode 100644 index 0000000..14fa707 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py @@ -0,0 +1,317 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLENS, + max_seqlen, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cosm, + stride_cosk, + stride_sinm, + stride_sink, + stride_outb, + stride_outh, + stride_outm, + stride_outk, + TOTAL_TOKENS, + HEAD_NUM, + HEAD_DIM: tl.constexpr, + IS_INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + IS_INPLACE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + seq_start = 0 + seq_end = max_seqlen + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + else: + seq_start = pid_batch * max_seqlen + seq_end = (pid_batch + 1) * max_seqlen + + actual_seqlen = seq_end - seq_start + if pid_m * BLOCK_M >= actual_seqlen: + return + + if CU_SEQLENS is not None: + batch_offset = 0 + else: + batch_offset = pid_batch + + head_offset = pid_head + d_half = HEAD_DIM // 2 + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + mask_m = offs_m < actual_seqlen + + if IS_INTERLEAVED: + for il in range(0, HEAD_DIM // 2): + offs_k_cos_0 = il + offs_k_cos_1 = il + d_half + + if CU_SEQLENS is not None: + ptr_x_0 = X + seq_start * stride_xm + head_offset * stride_xh + offs_m[:, None] * stride_xm + offs_k_cos_0 * 2 * stride_xk + offs_n[None, :] * 2 + ptr_x_1 = X + seq_start * stride_xm + head_offset * stride_xh + offs_m[:, None] * stride_xm + offs_k_cos_0 * 2 * stride_xk + offs_n[None, :] * 2 + stride_xk + ptr_cos = COS + offs_m[:, None] * stride_cosm + offs_k_cos_0 * stride_cosk + ptr_sin = SIN + offs_m[:, None] * stride_sinm + offs_k_cos_0 * stride_sink + else: + ptr_x_0 = X + batch_offset * stride_xb + head_offset * stride_xh + offs_m[:, None] * stride_xm + offs_k_cos_0 * 2 * stride_xk + offs_n[None, :] * 2 + ptr_x_1 = X + batch_offset * stride_xb + head_offset * stride_xh + offs_m[:, None] * stride_xm + offs_k_cos_0 * 2 * stride_xk + offs_n[None, :] * 2 + stride_xk + ptr_cos = COS + offs_m[:, None] * stride_cosm + offs_k_cos_0 * stride_cosk + ptr_sin = SIN + offs_m[:, None] * stride_sinm + offs_k_cos_0 * stride_sink + + x0 = tl.load(ptr_x_0, mask=mask_m[:, None]) + x1 = tl.load(ptr_x_1, mask=mask_m[:, None]) + c = tl.load(ptr_cos, mask=mask_m[:, None]) + s = tl.load(ptr_sin, mask=mask_m[:, None]) + + if CONJUGATE: + tmp = x0 * c + x1 * s + x1 = x1 * c - x0 * s + x0 = tmp + else: + tmp = x0 * c - x1 * s + x1 = x0 * s + x1 * c + x0 = tmp + + if IS_INPLACE: + tl.store(ptr_x_0, x0.to(ptr_x_0.type.element_ty), mask=mask_m[:, None]) + tl.store(ptr_x_1, x1.to(ptr_x_1.type.element_ty), mask=mask_m[:, None]) + else: + if CU_SEQLENS is not None: + ptr_out_0 = OUT + seq_start * stride_outm + head_offset * stride_outh + offs_m[:, None] * stride_outm + offs_k_cos_0 * 2 * stride_outk + offs_n[None, :] * 2 + ptr_out_1 = OUT + seq_start * stride_outm + head_offset * stride_outh + offs_m[:, None] * stride_outm + offs_k_cos_0 * 2 * stride_outk + offs_n[None, :] * 2 + stride_outk + else: + ptr_out_0 = OUT + batch_offset * stride_outb + head_offset * stride_outh + offs_m[:, None] * stride_outm + offs_k_cos_0 * 2 * stride_outk + offs_n[None, :] * 2 + ptr_out_1 = OUT + batch_offset * stride_outb + head_offset * stride_outh + offs_m[:, None] * stride_outm + offs_k_cos_0 * 2 * stride_outk + offs_n[None, :] 2 + stride_outk + tl.store(ptr_out_0, x0.to(ptr_out_0.type.element_ty), mask=mask_m[:, None]) + tl.store(ptr_out_1, x1.to(ptr_out_1.type.element_ty), mask=mask_m[:, None]) + else: + for ih in range(0, 2): + if ih == 0: + offs_k_start = 0 + offs_k_end = d_half + offs_cos_k = 0 + else: + offs_k_start = d_half + offs_k_end = HEAD_DIM + offs_cos_k = 1 + + if CU_SEQLENS is not None: + ptr_x_base = X + seq_start * stride_xm + head_offset * stride_xh + ptr_cos_base = COS + offs_m[:, None] * stride_cosm + offs_cos_k * stride_cosk + ptr_sin_base = SIN + offs_m[:, None] * stride_sinm + offs_cos_k * stride_sink + ptr_out_base = OUT + seq_start * stride_outm + head_offset * stride_outh + else: + ptr_x_base = X + batch_offset * stride_xb + head_offset * stride_xh + ptr_cos_base = COS + offs_m[:, None] * stride_cosm + offs_cos_k * stride_cosk + ptr_sin_base = SIN + offs_m[:, None] * stride_sinm + offs_cos_k * stride_sink + ptr_out_base = OUT + batch_offset * stride_outb + head_offset * stride_outh + + x0 = tl.load(ptr_x_base + offs_m[:, None] * stride_xm + (tl.arange(offs_k_start, offs_k_end)[None, :]) * stride_xk, mask=mask_m[:, None]) + x1 = tl.load(ptr_x_base + offs_m[:, None] * stride_xm + (tl.arange(offs_k_start + d_half, offs_k_end + d_half)[None, :]) * stride_xk, mask=mask_m[:, None]) + + c = tl.load(ptr_cos_base) + s = tl.load(ptr_sin_base) + + if CONJUGATE: + tmp = x0 * c + x1 * s + x1 = x1 * c - x0 * s + x0 = tmp + else: + tmp = x0 * c - x1 * s + x1 = x0 * s + x1 * c + x0 = tmp + + if IS_INPLACE: + tl.store(ptr_x_base + offs_m[:, None] * stride_xm + (tl.arange(offs_k_start, offs_k_end)[None, :]) * stride_xk, x0.to(ptr_x_base.type.element_ty), mask=mask_m[:, None]) + tl.store(ptr_x_base + offs_m[:, None] * stride_xm + (tl.arange(offs_k_start + d_half, offs_k_end + d_half)[None, :]) * stride_xk, x1.to(ptr_x_base.type.element_ty), mask=mask_m[:, None]) + else: + tl.store(ptr_out_base + offs_m[:, None] * stride_outm + (tl.arange(offs_k_start, offs_k_end)[None, :]) * stride_outk, x0.to(ptr_out_base.type.element_ty), mask=mask_m[:, None]) + tl.store(ptr_out_base + offs_m[:, None] * stride_outm + (tl.arange(offs_k_start + d_half, offs_k_end + d_half)[None, :]) * stride_outk, x1.to(ptr_out_base.type.element_ty), mask=mask_m[:, None]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int = 0, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +): + assert x.dim() == 4 + batch, head_num, seqlen, headdim = x.shape + + if max_seqlen is None: + if cu_seqlens is None: + max_seqlen = seqlen + else: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + BLOCK_M = max(16, min(128, triton.next_power_of_2(max_seqlen))) + + if inplace: + out = x + else: + out = torch.empty_like(x) + + grid = ( + batch, + head_num, + triton.cdiv(max_seqlen, BLOCK_M), + ) + + rotary_kernel[grid]( + out, + x, + cos, + sin, + cu_seqlens, + None, + max_seqlen, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + cos.stride(0) if cos.dim() > 1 else 0, + cos.stride(1) if cos.dim() > 1 else cos.stride(0), + sin.stride(0) if sin.dim() > 1 else 0, + sin.stride(1) if sin.dim() > 1 else sin.stride(0), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + batch * seqlen, + head_num, + HEAD_DIM=headdim, + IS_INTERLEAVED=interleaved, + CONJUGATE=conjugate, + IS_INPLACE=inplace, + BLOCK_M=BLOCK_M, + BLOCK_N=1, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stdout new file mode 100644 index 0000000..e693b35 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_549779.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_549779 due to invalid syntax (rotary_transform.py_gen_triton_code_549779.py, line 105) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py b/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py new file mode 100644 index 0000000..96a53e3 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py @@ -0,0 +1,351 @@ + +import torch +import triton +import triton.language as tl +import math +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + rotary_dim_half = rotary_dim // 2 + + cu_seqlens_ptr = CU_SEQLENS + seqlen_offsets_ptr = SEQLEN_OFFSETS + + if not IS_VARLEN or CU_SEQLENS is None: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + cur_seqlen = seqlen + else: + seq_start = tl.load(cu_seqlens_ptr + pid_batch) + cur_seqlen = tl.load(cu_seqlens_ptr + pid_batch + 1) - seq_start + X = X + seq_start * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + seq_start * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= cur_seqlen: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(seqlen_offsets_ptr + pid_batch) + + if not INTERLEAVED: + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_x_headdim + + c_ptr = COS + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_headdim + s_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_headdim + + mask_m = rm[:, None] < cur_seqlen + mask_ro_k = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half) + mask_x_k = mask_m & (rk_half[None, :] < rotary_dim_half) + + c = tl.load(c_ptr, mask=mask_ro_k, other=1.0).to(tl.float32) + s = tl.load(s_ptr, mask=mask_ro_k, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_x_k, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_x_k, other=0.0).to(tl.float32) + + if CONJUGATE: + s = -s + + o0 = x0 * c - x1 * s + o1 = x0 * s + x1 * c + + out0_ptr = OUT + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim + out1_ptr = OUT + rm[:, None] * stride_out_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_out_headdim + + tl.store(out0_ptr, o0, mask=mask_x_k) + tl.store(out1_ptr, o1, mask=mask_x_k) + else: + offs_d = 2 * tl.arange(0, rotary_dim_half) + x0_ptr = X + rm[:, None] * stride_x_seqlen + offs_d[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + (offs_d + 1)[None, :] * stride_x_headdim + + c_ptr = COS + rm_cs[:, None] * stride_sin_seqlen + tl.arange(0, rotary_dim_half)[None, :] * stride_sin_headdim + s_ptr = SIN + rm_cs[:, None] * stride_sin_seqlen + tl.arange(0, rotary_dim_half)[None, :] * stride_sin_headdim + + mask_m = rm[:, None] < cur_seqlen + mask_ro_k = (rm_cs[:, None] < seqlen_ro) & (tl.arange(0, rotary_dim_half)[None, :] < rotary_dim // 2) + mask_x_k = mask_m & (tl.arange(0, rotary_dim_half)[None, :] < rotary_dim // 2) + + c = tl.load(c_ptr, mask=mask_ro_k, other=1.0).to(tl.float32) + s = tl.load(s_ptr, mask=mask_ro_k, other=0.0).to(tl.float32) + x0 = tl.load(x0_ptr, mask=mask_x_k, other=0.0).to(tl.float32) + x1 = tl.load(x1_ptr, mask=mask_x_k, other=0.0).to(tl.float32) + + if CONJUGATE: + s = -s + + o0 = x0 * c - x1 * s + o1 = x0 * s + x1 * c + + out0_ptr = OUT + rm[:, None] * stride_out_seqlen + offs_d[None, :] * stride_out_headdim + out1_ptr = OUT + rm[:, None] * stride_out_seqlen + (offs_d + 1)[None, :] * stride_out_headdim + + tl.store(out0_ptr, o0, mask=mask_x_k) + tl.store(out1_ptr, o1, mask=mask_x_k) + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """ + Apply rotary embedding to the input tensor x using Triton kernels optimized for AMD GPU ROCm. + + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int. Required if cu_seqlens is not None. + interleaved: Use interleaved layout (rotary_dim = headdim // 2) + inplace: Whether to perform the rotation in-place (x will be modified) + conjugate: Whether to negate the sine component + Returns: + y: (batch, seqlen, nheads, headdim) or (total_seqlen, nheads, headdim) same shape as x + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + if max_seqlen is None: + raise ValueError("max_seqlen must be provided if cu_seqlens is used") + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_dimhalf = cos.shape + rotary_dim = rotary_dimhalf * 2 + assert sin.shape == cos.shape + assert rotary_dim <= headdim, f"Rotary dimension={rotary_dim} must be <= head_dim={headdim}" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, f"seqlen_ro={seqlen_ro} must >= seqlen={seqlen}" + assert cos.dtype == sin.dtype + assert x.dtype == cos.dtype + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + + BLOCK_K = ( + 32 if rotary_dim <= 32 else + 64 if rotary_dim <= 64 else + 128 if rotary_dim <= 128 else 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + # Set strides correctly depending on tensor shape + if x.dim() == 4: + x_stride_b, x_stride_seqlen, x_stride_h, x_stride_d = ( + x.stride(0), x.stride(1), x.stride(2), x.stride(3) + ) + output_stride_b, output_stride_seqlen, output_stride_h, output_stride_d = ( + output.stride(0), output.stride(1), output.stride(2), output.stride(3) + ) + else: # x.dim() == 3 + x_stride_b, x_stride_seqlen, x_stride_h, x_stride_d = ( + 0, x.stride(0), x.stride(1), x.stride(2) + ) + output_stride_b, output_stride_seqlen, output_stride_h, output_stride_d = ( + 0, output.stride(0), output.stride(1), output.stride(2) + ) + + cos_stride_m = cos.stride(0) + cos_stride_n = cos.stride(1) + sin_stride_m = sin.stride(0) + sin_stride_n = sin.stride(1) + + global stride_sin_seqlen, stride_sin_headdim + stride_sin_seqlen = cos_stride_m + stride_sin_headdim = cos_stride_n + + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # cache key + output_stride_b, + output_stride_seqlen, + output_stride_h, + output_stride_d, + x_stride_b, + x_stride_seqlen, + x_stride_h, + x_stride_d, + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stdout new file mode 100644 index 0000000..610ffe2 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_555768.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_555768 due to at 66:39: + rk_half = tl.arange(0, BLOCK_K // 2) + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(seqlen_offsets_ptr + pid_batch) + + if not INTERLEAVED: + x0_ptr = X + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x1_ptr = X + rm[:, None] * stride_x_seqlen + (rk_half + rotary_dim_half)[None, :] * stride_x_headdim + + c_ptr = COS + rm_cs[:, None] * stride_sin_seqlen + rk_half[None, :] * stride_sin_headdim + ^ +NameError("Cannot access global variable stride_sin_seqlen from within @jit'ed function. Triton kernels can only access global variables that are annotated as constexpr (`x: triton.language.constexpr = 42` or `x = triton.language.constexpr(42)`). Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not promise to support this forever.") diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py b/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py new file mode 100644 index 0000000..9ae0b89 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py @@ -0,0 +1,256 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, + seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, + stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, + stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, + BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_base = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_base = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_base = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_base = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = rm < seqlen + + if IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + else: + rm_cs = rm + SEQLEN_OFFSETS + mask_cs = rm_cs < seqlen_ro + + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K) + mask_k = rk_half < rotary_dim_half + mask_x0 = mask_m[:, None] & mask_k[None, :] + mask_x1 = mask_m[:, None] & mask_k[None, :] + x0 = tl.load(x_base + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim, + mask=mask_x0, other=0.0).to(tl.float32) + x1 = tl.load(x_base + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim, + mask=mask_x1, other=0.0).to(tl.float32) + cos_v = tl.load(COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=mask_cs[:, None] & mask_k[None, :], other=1.0).to(tl.float32) + sin_v = tl.load(SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=mask_cs[:, None] & mask_k[None, :], other=0.0).to(tl.float32) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk < rotary_dim + mask_x = mask_m[:, None] & mask_k[None, :] + x = tl.load(x_base + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim, + mask=mask_x, other=0.0).to(tl.float32) + rk_half = rk // 2 + mask_k_half = rk_half < rotary_dim_half + cos_v = tl.load(COS + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=mask_cs[:, None] & mask_k_half[None, :], other=1.0).to(tl.float32) + sin_v = tl.load(SIN + rm_cs[:, None] * rotary_dim_half + rk_half[None, :], + mask=mask_cs[:, None] & mask_k_half[None, :], other=0.0).to(tl.float32) + + if CONJUGATE: + sin_v = -sin_v + + if not INTERLEAVED: + o0 = x0 * cos_v - x1 * sin_v + o1 = x0 * sin_v + x1 * cos_v + tl.store(out_base + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o0, mask=mask_x0) + tl.store(out_base + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, + o1, mask=mask_x1) + else: + rk_pair = (rk + (-1 if rk % 2 == 0 else 1)) % rotary_dim + mask_pair = mask_m[:, None] & mask_k[None, :] + x_pair = tl.load(x_base + rm[:, None] * stride_x_seqlen + rk_pair[None, :] * stride_x_headdim, + mask=mask_pair, other=0.0).to(tl.float32) + neg = tl.where(rk % 2 == 0, -1.0, 1.0) + o = x * cos_v + neg * x_pair * sin_v + tl.store(out_base + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, + o, mask=mask_x) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + seqlen_offsets = int(seqlen_offsets) + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 if rotary_dim_half <= 32 else + 64 if rotary_dim_half <= 64 else + 128 if rotary_dim_half <= 128 else + 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stderr new file mode 100644 index 0000000..d0d524f --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stderr @@ -0,0 +1,2 @@ +python3: /source/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::detail::TypedValue, From = mlir::Value]: Assertion `isa(Val) && "cast() argument of incompatible type!"' failed. +Aborted (core dumped) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_634902.py.stdout new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py b/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py new file mode 100644 index 0000000..f26e1e6 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py @@ -0,0 +1,265 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, + seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, + stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, + stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_base = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_base = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + seqlen_i = seqlen + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen_i = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_base = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_base = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen_i: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rm_full_mask = rm < seqlen_i + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rm_cs_mask = rm_cs < seqlen_ro + + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K // 2) + rk_mask = rk_half[None, :] < rotary_dim_half + + # Real part + off_x_real = x_base + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + x_real = tl.load(off_x_real, mask=rm_full_mask[:, None] & rk_mask, other=0.0).to(tl.float32) + # Imag part + off_x_imag = x_base + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim + x_imag = tl.load(off_x_imag, mask=rm_full_mask[:, None] & rk_mask, other=0.0).to(tl.float32) + + off_cos = rm_cs[:, None] * (rotary_dim // 2) + rk_half[None, :] + cos = tl.load(COS + off_cos, mask=rm_cs_mask[:, None] & rk_mask, other=1.0).to(tl.float32) + sin_val = tl.load(SIN + off_cos, mask=rm_cs_mask[:, None] & rk_mask, other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + + o_real = x_real * cos - x_imag * sin_val + o_imag = x_real * sin_val + x_imag * cos + + tl.store(out_base + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + o_real, mask=rm_full_mask[:, None] & rk_mask) + tl.store(out_base + rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, + o_imag, mask=rm_full_mask[:, None] & rk_mask) + else: + rk = tl.arange(0, BLOCK_K) + mask_k = rk[None, :] < rotary_dim + rk_half_idx = rk // 2 + mask_k_half = rk_half_idx[None, :] < rotary_dim_half + + off_x = x_base + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x_vals = tl.load(off_x, mask=rm_full_mask[:, None] & mask_k, other=0.0).to(tl.float32) + x0 = tl.where(rk[None, :] % 2 == 0, x_vals, 0) + x1 = tl.where(rk[None, :] % 2 == 1, x_vals, 0) + + off_cos_sin = rm_cs[:, None] * (rotary_dim // 2) + rk_half_idx[None, :] + cos = tl.load(COS + off_cos_sin, mask=rm_cs_mask[:, None] & mask_k_half, other=1.0).to(tl.float32) + sin_val = tl.load(SIN + off_cos_sin, mask=rm_cs_mask[:, None] & mask_k_half, other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + + out_even = x0 * cos - x1 * sin_val + out_odd = x0 * sin_val + x1 * cos + + out_vals = tl.where(rk[None, :] % 2 == 0, out_even, out_odd) + tl.store(out_base + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim, + out_vals, mask=rm_full_mask[:, None] & mask_k) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None + total_seqlen, nheads, headdim = x.shape + seqlen = max_seqlen + batch = cu_seqlens.shape[0] - 1 + + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + + seqlen_ro_needed = seqlen + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.to(torch.int32).contiguous() + seqlen_ro_needed += seqlen_offsets.max().item() + else: + seqlen_ro_needed += seqlen_offsets + assert seqlen_ro >= seqlen_ro_needed + + cos = cos.contiguous() + sin = sin.contiguous() + + output = x if inplace else torch.empty_like(x) + if not inplace and rotary_dim < headdim: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = 32 if rotary_dim <= 32 else \ + 64 if rotary_dim <= 64 else \ + 128 if rotary_dim <= 128 else 256 + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + cu_seqlens_ptr = cu_seqlens.int().contiguous() if is_varlen else None + + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens_ptr, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(1), + output.stride(2), + output.stride(3), + x.stride(0) if not is_varlen else 0, + x.stride(1), + x.stride(2), + x.stride(3), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stdout new file mode 100644 index 0000000..d5e1467 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_669031.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_669031 due to Dimension out of range (expected to be in range of [-3, 2], but got 3) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py b/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py new file mode 100644 index 0000000..4d77ff3 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py @@ -0,0 +1,252 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + OUT, + CU_SEQLENS, + SEQLENS, + stride_xbatch, + stride_xhead, + stride_xm, + stride_xk, + stride_cos_m, + stride_cos_k, + stride_sin_m, + stride_sin_k, + stride_obatch, + stride_ohead, + stride_om, + stride_ok, + TOTAL_TOKENS: tl.constexpr, + HEAD_DIM: tl.constexpr, + MAX_SEQLEN: tl.constexpr, + IS_VARIABLE_L: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_k = tl.arange(0, BLOCK_K) + + if IS_VARIABLE_L: + b_start = 0 if pid_batch == 0 else tl.load(CU_SEQLENS + pid_batch - 1) + b_end = tl.load(CU_SEQLENS + pid_batch) + seqlen = b_end - b_start + else: + b_start = pid_batch * MAX_SEQLEN + seqlen = MAX_SEQLEN + + mask_m = offsets_m < seqlen + mask_k_half = offsets_k < (HEAD_DIM // 2) + + full_offsets_m = b_start + offsets_m + full_mask_m = full_offsets_m < TOTAL_TOKENS + + cos_ptrs = COS + full_offsets_m * stride_cos_m + offsets_k * stride_cos_k + sin_ptrs = SIN + full_offsets_m * stride_sin_m + offsets_k * stride_sin_k + + if INTERLEAVED: + x_offsets_k = offsets_k * 2 + x_offsets_k2 = offsets_k * 2 + 1 + else: + x_offsets_k = offsets_k + x_offsets_k2 = offsets_k + (HEAD_DIM // 2) + + x_ptrs = X + full_offsets_m * stride_xm + pid_head * stride_xhead + x_offsets_k * stride_xk + x2_ptrs = X + full_offsets_m * stride_xm + pid_head * stride_xhead + x_offsets_k2 * stride_xk + + x1 = tl.load(x_ptrs, mask=full_mask_m[:, None] & mask_k_half[None, :]) + x2 = tl.load(x2_ptrs, mask=full_mask_m[:, None] & mask_k_half[None, :]) + + cos = tl.load(cos_ptrs, mask=full_mask_m[:, None] & mask_k_half[None, :]) + sin = tl.load(sin_ptrs, mask=full_mask_m[:, None] & mask_k_half[None, :]) + + if CONJUGATE: + x2_rot = -x2 + else: + x2_rot = x2 + + out1 = x1 * cos - x2_rot * sin + out2 = x1 * sin + x2 * cos + + out_ptrs = OUT + full_offsets_m * stride_om + pid_head * stride_ohead + tl.store(out_ptrs + x_offsets_k * stride_ok, out1, mask=full_mask_m[:, None] & mask_k_half[None, :]) + tl.store(out_ptrs + x_offsets_k2 * stride_ok, out2, mask=full_mask_m[:, None] & mask_k_half[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int = 0, + cu_seqlens: torch.Tensor = None, + *, + inplace: bool = False, + interleaved: bool = False, + conj: bool = False, +): + batch, seqlen, nheads, headdim = x.shape + assert headdim <= 1024 + assert cos.shape == sin.shape == (seqlen, headdim // 2) + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] + + if inplace: + out = x + else: + out = torch.empty_like(x) + + BLOCK_M = 1 + while BLOCK_M * nheads * 4 * x.element_size() < 32768 and BLOCK_M * 2 <= seqlen: + BLOCK_M *= 2 + BLOCK_K = min(triton.next_power_of_2(headdim // 2), 64) + + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, META["BLOCK_M"])) + + TOTAL_TOKENS = batch * seqlen + max_seqlen = seqlen + is_variable_l = cu_seqlens is not None + HEAD_DIM = headdim + + rotary_kernel[grid]( + x, + cos, + sin, + out, + cu_seqlens, + seqlen, + x.stride(0), + x.stride(2), + x.stride(1), + x.stride(3), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + out.stride(0), + out.stride(2), + out.stride(1), + out.stride(3), + TOTAL_TOKENS=TOTAL_TOKENS, + HEAD_DIM=HEAD_DIM, + MAX_SEQLEN=max_seqlen, + IS_VARIABLE_L=is_variable_l, + INTERLEAVED=interleaved, + CONJUGATE=conj, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stdout new file mode 100644 index 0000000..b9446b4 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_711258.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_711258 due to diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py b/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py new file mode 100644 index 0000000..7376f11 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py @@ -0,0 +1,289 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + OUT, + CU_SEQLENS, + seqlen_offsets, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cosb, + stride_cosh, + stride_cosm, + stride_cosk, + stride_sinb, + stride_sinh, + stride_sinm, + stride_sink, + stride_ob, + stride_oh, + stride_om, + stride_ok, + max_seqlen, + rotary_dim, + seqlen, + interleaved: tl.constexpr, + conjugate: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if CU_SEQLENS is not None: + cu_seqlens_start = tl.load(CU_SEQLENS + pid_batch) + cu_seqlens_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen = cu_seqlens_end - cu_seqlens_start + else: + cu_seqlens_start = 0 + + offset = tl.load(seqlen_offsets + pid_batch) if seqlen_offsets is not None else 0 + seqlen = seqlen - offset + if pid_m >= seqlen: + return + + rotary_dim_half = rotary_dim // 2 + cols = tl.arange(0, BLOCK_K) + mask = cols < rotary_dim_half + + offset_m = pid_m + offset + pos = offset_m.to(tl.int32) + + if interleaved: + cos_offset = pos * stride_cosm + (cols * 2) * stride_cosk + sin_offset = pos * stride_sinm + (cols * 2) * stride_sink + else: + cos_offset = pos * stride_cosm + cols * stride_cosk + sin_offset = pos * stride_sinm + cols * stride_sink + + cos = tl.load(COS + cos_offset, mask=mask, other=0.0) + sin = tl.load(SIN + sin_offset, mask=mask, other=0.0) + + x_offset = ( + pid_batch * stride_xb + + pid_head * stride_xh + + pid_m * stride_xm + ) + + if interleaved: + x_col0 = x_offset + (cols * 2) * stride_xk + x_col1 = x_offset + (cols * 2 + 1) * stride_xk + x0 = tl.load(X + x_col0, mask=mask, other=0.0) + x1 = tl.load(X + x_col1, mask=mask, other=0.0) + else: + x_col0 = x_offset + cols * stride_xk + x_col1 = x_offset + (cols + rotary_dim_half) * stride_xk + x0 = tl.load(X + x_col0, mask=mask, other=0.0) + x1 = tl.load(X + x_col1, mask=mask, other=0.0) + + if conjugate: + x1 = -x1 + + out0 = x0 * cos - x1 * sin + out1 = x0 * sin + x1 * cos + + out_offset = ( + pid_batch * stride_ob + + pid_head * stride_oh + + pid_m * stride_om + ) + + if interleaved: + tl.store(OUT + out_offset + (cols * 2) * stride_ok, out0, mask=mask) + tl.store(OUT + out_offset + (cols * 2 + 1) * stride_ok, out1, mask=mask) + else: + tl.store(OUT + out_offset + cols * stride_ok, out0, mask=mask) + tl.store(OUT + out_offset + (cols + rotary_dim_half) * stride_ok, out1, mask=mask) + + # Copy non-rotary dimensions + cols_rest_start = rotary_dim if not interleaved else rotary_dim * 2 + cols_rest_end = max_seqlen + cols_rest = cols_rest_start + tl.arange(0, BLOCK_K) + mask_rest = cols_rest < cols_rest_end + + if interleaved: + x_rest_offset = x_offset + cols_rest * stride_xk + out_rest_offset = out_offset + cols_rest * stride_ok + else: + x_rest_offset = x_offset + cols_rest * stride_xk + out_rest_offset = out_offset + cols_rest * stride_ok + + x_rest = tl.load(X + x_rest_offset, mask=mask_rest, other=0.0) + tl.store(OUT + out_rest_offset, x_rest, mask=mask_rest) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + interleaved: bool = False, + in_place: bool = False, + conjugate: bool = False, +): + batch, head, seqlen, dim = x.shape + rotary_dim = cos.shape[-1] + assert cos.shape == sin.shape + assert rotary_dim * 2 <= dim, "Rotary dim must be <= half of hidden size" + + if max_seqlen is None: + if cu_seqlens is None: + max_seqlen = seqlen + else: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + BLOCK_M = 1 + BLOCK_K = max(rotary_dim, 32) + + grid = (batch, head, seqlen) + + if not in_place: + out = torch.empty_like(x) + else: + out = x + + rotary_kernel[grid]( + x, + cos, + sin, + out, + cu_seqlens, + seqlen_offsets, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + cos.stride(0), + cos.stride(1), + cos.stride(2), + cos.stride(3), + sin.stride(0), + sin.stride(1), + sin.stride(2), + sin.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + max_seqlen, + rotary_dim, + seqlen, + interleaved, + conjugate, + BLOCK_M, + BLOCK_K, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stdout new file mode 100644 index 0000000..316b76b --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_816058.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_816058 due to Dimension out of range (expected to be in range of [-2, 1], but got 2) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py b/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py new file mode 100644 index 0000000..42e4071 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py @@ -0,0 +1,268 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, + seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, + stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, + stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, + BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 + rk_repeat = tl.arange(0, BLOCK_K) // 2 + X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 + ).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + rotary_kernel[grid]( + output, x, cos, sin, cu_seqlens, seqlen_offsets, + seqlen, nheads, rotary_dim, seqlen_ro, seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), output.stride(-2), output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), x.stride(-2), x.stride(-1), + BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), + is_varlen, interleaved, conjugate, BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_824557.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py b/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py new file mode 100644 index 0000000..9740ccb --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py @@ -0,0 +1,228 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, COS, SIN, CU_SEQLENS, SEQLENS, OUT, + stride_batch, stride_seqlen, stride_head, stride_dim, + rotary_dim, max_seqlen, total_seqlens, + nheads, seqlen_ro, interleaved, conj, BLOCK_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_batch: + return + if pid_head >= nheads: + return + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_start + else: + seq_start = pid_batch * max_seqlen + seqlen_i = tl.load(SEQLENS + pid_batch) if SEQLENS is not None else max_seqlen + + if pid_m >= seqlen_i: + return + + offset_m = seq_start + pid_m + + k_id = tl.arange(0, rotary_dim // 2) + rotary_dim_half = rotary_dim // 2 + + for k in range(0, rotary_dim_half, BLOCK_SIZE_M): + k_idx = k + tl.arange(0, BLOCK_SIZE_M) + mask = k_idx < rotary_dim_half + + pos_m = pid_m + + cos_idx = pos_m * stride_seqlen + k_idx * 2 + cos_offset = COS + cos_idx + cos0 = tl.load(cos_offset, mask=mask) + cos1 = tl.load(cos_offset + 1, mask=mask) + + sin_idx = pos_m * stride_seqlen + k_idx * 2 + sin_offset = SIN + sin_idx + sin0 = tl.load(sin_offset, mask=mask) + sin1 = tl.load(sin_offset + 1, mask=mask) + + if interleaved: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx * 2 + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx * 2 + 1 + x0 = tl.load(X + x_idx0, mask=mask) + x1 = tl.load(X + x_idx1, mask=mask) + if conj: + x1 = -x1 + out0 = x0 * cos0 - x1 * sin0 + out1 = x0 * sin1 + x1 * cos1 + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask) + else: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + rotary_dim_half + x0 = tl.load(X + x_idx0, mask=mask) + x1 = tl.load(X + x_idx1, mask=mask) + if conj: + x1 = -x1 + out0 = x0 * cos0 - x1 * sin0 + out1 = x0 * sin1 + x1 * cos1 + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask) + + if IS_EVEN_K: + return + + k_offset = rotary_dim + total_dim = stride_dim + for d in range(rotary_dim, total_dim): + idx = offset_m * stride_seqlen + pid_head * stride_head + d + val = tl.load(X + idx) + tl.store(OUT + idx, val) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seq_offset: int = 0, + cu_seqlens: torch.Tensor = None, + seqlens: torch.Tensor = None, + rotary_dim: int = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +): + batch, seqlen, nheads, headdim = x.shape + rotary_dim = rotary_dim or headdim + + assert rotary_dim % 2 == 0 + assert headdim >= rotary_dim + assert cos.shape == (seqlen, rotary_dim // 2 * 2) + assert sin.shape == (seqlen, rotary_dim // 2 * 2) + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + BLOCK_SIZE_M = min(max(triton.next_power_of_2(rotary_dim // 2), 16), 64) + + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, 1)) + + rotary_kernel[grid]( + x, cos, sin, cu_seqlens, seqlens, out, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + rotary_dim, seqlen, batch * seqlen, + nheads, seqlen, interleaved, conjugate, BLOCK_SIZE_M, + IS_EVEN_K=(headdim == rotary_dim), + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stdout new file mode 100644 index 0000000..b7a17ee --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_840463.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_840463 due to diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py b/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py new file mode 100644 index 0000000..8230c46 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py @@ -0,0 +1,237 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + X, COS, SIN, CU_SEQLENS, SEQLENS, OUT, + stride_batch, stride_seqlen, stride_head, stride_dim, + rotary_dim, max_seqlen, total_seqlens, + nheads, seqlen_ro, interleaved, conj, BLOCK_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= stride_batch: + return + if pid_head >= nheads: + return + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1) + seqlen_i = seq_end - seq_start + else: + seq_start = pid_batch * max_seqlen + seqlen_i = tl.load(SEQLENS + pid_batch) if SEQLENS is not None else max_seqlen + + if pid_m >= seqlen_i: + return + + offset_m = seq_start + pid_m + + rotary_dim_half = rotary_dim // 2 + BLOCK_K = tl.min(BLOCK_SIZE_M, rotary_dim_half) + for k in range(0, rotary_dim_half, BLOCK_K): + k_idx = k + tl.arange(0, BLOCK_K) + mask = k_idx < rotary_dim_half + + pos_m = pid_m + cos_idx = pos_m * rotary_dim + k_idx + cos_offset = COS + cos_idx + cos_val = tl.load(cos_offset, mask=mask).to(tl.float32) + + sin_idx = pos_m * rotary_dim + k_idx + sin_offset = SIN + sin_idx + sin_val = tl.load(sin_offset, mask=mask).to(tl.float32) + if conj: + sin_val = -sin_val + + if interleaved: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + 2 * k_idx + 1 + mask_2 = 2 * k_idx + 1 < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask_2).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_2).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask_2) + tl.store(OUT + x_idx1, out1, mask=mask_2) + else: + x_idx0 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + x_idx1 = offset_m * stride_seqlen + pid_head * stride_head + k_idx + rotary_dim_half + mask_half = k_idx + rotary_dim_half < rotary_dim + x0 = tl.load(X + x_idx0, mask=mask).to(tl.float32) + x1 = tl.load(X + x_idx1, mask=mask_half).to(tl.float32) + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + tl.store(OUT + x_idx0, out0, mask=mask) + tl.store(OUT + x_idx1, out1, mask=mask_half) + + if IS_EVEN_K: + return + + k_offset = rotary_dim + total_dim = stride_dim + for d in range(rotary_dim, total_dim): + idx = offset_m * stride_seqlen + pid_head * stride_head + d + val = tl.load(X + idx) + tl.store(OUT + idx, val.astype(X.dtype.element_ty)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seq_offset: int = 0, + cu_seqlens: Optional[torch.Tensor] = None, + seqlens: Optional[torch.Tensor] = None, + rotary_dim: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + batch, seqlen, nheads, headdim = x.shape + rotary_dim = rotary_dim or headdim + + assert rotary_dim % 2 == 0, "rotary_dim must be even" + assert headdim >= rotary_dim, "headdim must be >= rotary_dim" + assert cos.shape == (seqlen, rotary_dim), f"cos shape mismatch: {cos.shape} vs ({seqlen}, {rotary_dim})" + assert sin.shape == (seqlen, rotary_dim), f"sin shape mismatch: {sin.shape} vs ({seqlen}, {rotary_dim})" + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + BLOCK_SIZE_M = min(max(triton.next_power_of_2(rotary_dim // 2), 16), 64) + grid = lambda META: (batch, nheads, triton.cdiv(seqlen, META["BLOCK_M"])) + + rotary_kernel[grid]( + x, + cos, + sin, + cu_seqlens, + seqlens, + out, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + rotary_dim, + seqlen, + batch * seqlen, + nheads, + seqlen, + interleaved, + conjugate, + BLOCK_SIZE_M, + IS_EVEN_K=(headdim == rotary_dim), + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stdout new file mode 100644 index 0000000..37747e0 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_843724.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_843724 due to cos shape mismatch: torch.Size([128, 16]) vs (128, 64) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py b/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py new file mode 100644 index 0000000..e9cb11d --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py @@ -0,0 +1,295 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + OUT, + CU_SEQLENS, + seqlens, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cos_m, + stride_cos_k, + stride_sin_m, + stride_sin_k, + stride_ob, + stride_oh, + stride_om, + stride_ok, + num_heads, + rotary_dim, + seqlen_offsets, + interleaved: tl.constexpr, + conjugate: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if CU_SEQLENS is not None: + seqlen_start = tl.load(CU_SEQLENS + pid_batch) + seqlen_end = tl.load(CU_SEQLENS + pid_batch + 1) + seq_len = seqlen_end - seqlen_start + else: + seqlen_start = 0 + seq_len = tl.load(seqlens + pid_batch) + + if pid_m * BLOCK_M >= seq_len: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K // 2) + + if CU_SEQLENS is not None: + pos = seqlen_start + offs_m + else: + pos = seqlen_offsets + offs_m + + mask_m = offs_m < seq_len + mask_k = offs_k < rotary_dim // 2 + + if not interleaved: + x0_ptrs = ( + X + + pid_batch * stride_xb + + pid_head * stride_xh + + offs_m[:, None] * stride_xm + + offs_k[None, :] * 2 * stride_xk + ) + x1_ptrs = x0_ptrs + stride_xk + + cos_ptrs = COS + pos[:, None] * stride_cos_m + offs_k[None, :] * stride_cos_k + sin_ptrs = SIN + pos[:, None] * stride_sin_m + offs_k[None, :] * stride_sin_k + + x0 = tl.load(x0_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + x1 = tl.load(x1_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + cos = tl.load(cos_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + sin = tl.load(sin_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + + if conjugate: + sin = -sin + + out0 = x0 * cos - x1 * sin + out1 = x0 * sin + x1 * cos + + out0_ptrs = ( + OUT + + pid_batch * stride_ob + + pid_head * stride_oh + + offs_m[:, None] * stride_om + + offs_k[None, :] * 2 * stride_ok + ) + out1_ptrs = out0_ptrs + stride_ok + + tl.store(out0_ptrs, out0, mask=mask_m[:, None] & mask_k[None, :]) + tl.store(out1_ptrs, out1, mask=mask_m[:, None] & mask_k[None, :]) + else: + x_real_ptrs = ( + X + + pid_batch * stride_xb + + pid_head * stride_xh + + offs_m[:, None] * stride_xm + + offs_k[None, :] * stride_xk * 2 + ) + x_imag_ptrs = x_real_ptrs + stride_xk + + cos_ptrs = COS + pos[:, None] * stride_cos_m + offs_k[None, :] * stride_cos_k + sin_ptrs = SIN + pos[:, None] * stride_sin_m + offs_k[None, :] * stride_sin_k + + x_real = tl.load(x_real_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + x_imag = tl.load(x_imag_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + cos = tl.load(cos_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + sin = tl.load(sin_ptrs, mask=mask_m[:, None] & mask_k[None, :]) + + if conjugate: + x_imag = -x_imag + + out_real = x_real * cos - x_imag * sin + out_imag = x_real * sin + x_imag * cos + + out_real_ptrs = ( + OUT + + pid_batch * stride_ob + + pid_head * stride_oh + + offs_m[:, None] * stride_om + + offs_k[None, :] * stride_ok * 2 + ) + out_imag_ptrs = out_real_ptrs + stride_ok + + tl.store(out_real_ptrs, out_real, mask=mask_m[:, None] & mask_k[None, :]) + tl.store(out_imag_ptrs, out_imag, mask=mask_m[:, None] & mask_k[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int = 0, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + interleaved: bool = False, + conjugate: bool = False, + inplace: bool = False, +) -> torch.Tensor: + batch, seqlen, num_heads, head_dim = x.shape + rotary_dim = cos.shape[-1] + assert rotary_dim <= head_dim + assert rotary_dim % 2 == 0 + assert cos.shape == (seqlen, rotary_dim) + assert sin.shape == (seqlen, rotary_dim) + + BLOCK_K = 128 + BLOCK_M = 64 + + grid = (batch, num_heads, triton.cdiv(seqlen, BLOCK_M)) + + if cu_seqlens is not None: + assert cu_seqlens.dtype == torch.int32 + assert cu_seqlens.device == x.device + max_seqlen = cu_seqlens.diff().max().item() + + if inplace: + out = x + else: + out = torch.empty_like(x) + + rotary_kernel[grid]( + x, + cos, + sin, + out, + cu_seqlens, + None if cu_seqlens is not None else torch.tensor([seqlen], dtype=torch.int32, device=x.device), + x.stride(0), + x.stride(2), + x.stride(1), + x.stride(3), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + out.stride(0), + out.stride(2), + out.stride(1), + out.stride(3), + num_heads, + rotary_dim, + seqlen_offsets, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + ) + + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stdout new file mode 100644 index 0000000..243dcf9 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_893238.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_893238 due to not enough values to unpack (expected 4, got 3) diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py b/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py new file mode 100644 index 0000000..06a368d --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py @@ -0,0 +1,292 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + if not IS_VARLEN: + current_batch_offset = pid_batch * stride_x_batch + pid_head * stride_x_nheads + X_ptr = X + current_batch_offset + OUT_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + seq_len = seqlen + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seq_len: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = tl.arange(0, BLOCK_K) + rk_half = rk % (rotary_dim // 2) + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rotary_half = rotary_dim // 2 + + if not INTERLEAVED: + k0 = rk_half + k1 = k0 + rotary_half + mask_m = rm < seq_len + mask_m_cs = rm_cs < seqlen_ro + + # first half + offset0 = rm[:, None] * stride_x_seqlen + k0[None, :] * stride_x_headdim + x0 = tl.load(X_ptr + offset0, mask=mask_m[:, None] & (k0[None, :] < rotary_half)).to(tl.float32) + cos0 = tl.load(COS + rm_cs[:, None] * rotary_half + k0[None, :], + mask=mask_m_cs[:, None] & (k0[None, :] < rotary_half), other=1.0).to(tl.float32) + sin0 = tl.load(SIN + rm_cs[:, None] * rotary_half + k0[None, :], + mask=mask_m_cs[:, None] & (k0[None, :] < rotary_half), other=0.0).to(tl.float32) + + # second half + offset1 = rm[:, None] * stride_x_seqlen + k1[None, :] * stride_x_headdim + x1 = tl.load(X_ptr + offset1, mask=mask_m[:, None] & (k1[None, :] < rotary_dim)).to(tl.float32) + + if CONJUGATE: + sin0 = -sin0 + o0 = x0 * cos0 - x1 * sin0 + o1 = x0 * sin0 + x1 * cos0 + + tl.store(OUT_ptr + offset0, o0, mask=mask_m[:, None] & (k0[None, :] < rotary_half)) + tl.store(OUT_ptr + offset1, o1, mask=mask_m[:, None] & (k1[None, :] < rotary_dim)) + else: + rk_half = rk // 2 + mask_m = rm < seq_len + mask_m_cs = rm_cs < seqlen_ro + + x_offsets = rm[:, None] * stride_x_seqlen + rk[None, :] * stride_out_headdim + cos_sin_offsets = rm_cs[:, None] * rotary_half + rk_half[None, :] + + x = tl.load(X_ptr + x_offsets, mask=mask_m[:, None] & (rk[None, :] < rotary_dim)).to(tl.float32) + + cos = tl.load(COS + cos_sin_offsets, + mask=mask_m_cs[:, None] & (rk_half[None, :] < rotary_half), other=1.0).to(tl.float32) + sin = tl.load(SIN + cos_sin_offsets, + mask=mask_m_cs[:, None] & (rk_half[None, :] < rotary_half), other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + + x0 = tl.where((rk[None, :] % 2) == 0, x, 0) + x1 = tl.where((rk[None, :] % 2) == 1, x, 0) + + out = x0 * cos + x1 * sin + tl.store(OUT_ptr + x_offsets, out, mask=mask_m[:, None] & (rk[None, :] < rotary_dim)) + + +from typing import Union, Optional + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert headdim <= 256 + + if not isinstance(seqlen_offsets, torch.Tensor): + assert isinstance(seqlen_offsets, int) and seqlen_offsets + seqlen <= seqlen_ro + else: + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.to(torch.int32) + + cos = cos.contiguous() + sin = sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + seqlen_offsets = seqlen_offsets.contiguous() + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = max(32, triton.next_power_of_2(rotary_dim)) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + with torch.cuda.device(x.device.type): + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_915460.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py b/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py new file mode 100644 index 0000000..ca54784 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py @@ -0,0 +1,303 @@ + +import torch +import triton +import triton.language as tl +from typing import Union, Optional + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + o_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + o_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + else_seq_len = seqlen + current_seqlen = tl.where(IS_VARLEN, seq_len, else_seq_len) + + if pid_m * BLOCK_M >= current_seqlen: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk_half = tl.arange(0, BLOCK_K) + + if not INTERLEAVED: + cos_ptr = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ptr = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + mask_cos_sin = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half) + cos = tl.load(cos_ptr, mask=mask_cos_sin, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_cos_sin, other=0.0).to(tl.float32) + + left = x_ptr + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + right = x_ptr + (rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim) + mask_lr = (rm[:, None] < current_seqlen) & (rk_half[None, :] < rotary_dim_half) + + x0 = tl.load(left, mask=mask_lr, other=0.0).to(tl.float32) + x1 = tl.load(right, mask=mask_lr, other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + out0 = x0 * cos - x1 * sin + out1 = x0 * sin + x1 * cos + + tl.store( + o_ptr + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim), + out0, + mask=mask_lr, + ) + tl.store( + o_ptr + (rm[:, None] * stride_out_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim), + out1, + mask=mask_lr, + ) + else: + rk = tl.arange(0, 2 * BLOCK_K) + cos_ptr = COS + (rm_cs[:, None] * rotary_dim_half + (rk[None, :] // 2)) + sin_ptr = SIN + (rm_cs[:, None] * rotary_dim_half + (rk[None, :] // 2)) + mask_cos_sin = (rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim) + cos = tl.load(cos_ptr, mask=mask_cos_sin, other=1.0).to(tl.float32) + sin = tl.load(sin_ptr, mask=mask_cos_sin, other=0.0).to(tl.float32) + + idx = x_ptr + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + mask_idx = (rm[:, None] < current_seqlen) & (rk[None, :] < rotary_dim) + x0 = tl.load(idx, mask=mask_idx, other=0.0).to(tl.float32) + + idx1 = x_ptr + (rm[:, None] * stride_x_seqlen + (rk[None, :] ^ 1) * stride_x_headdim) + x1 = tl.load(idx1, mask=mask_idx, other=0.0).to(tl.float32) + + if CONJUGATE: + sin = -sin + out = tl.where(rk[None, :] % 2 == 0, x0 * cos - x1 * sin, x0 * sin + x1 * cos) + + tl.store( + o_ptr + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim), + out, + mask=mask_idx, + ) + # copy tail + if rotary_dim < x.shape[-1]: # headdim in signature is rotary_dim, actual head dim is X.shape[-1] + actual_headdim = x.shape[-1] + else: + actual_headdim = rotary_dim + if rotary_dim < actual_headdim: + cols_tail = rotary_dim + tl.arange(0, BLOCK_K) + mask_tail = (cols_tail < actual_headdim) & (rm[:, None] < current_seqlen) + tail_src = x_ptr + (rm[:, None] * stride_x_seqlen + cols_tail[None, :] * stride_x_headdim) + tail_val = tl.load(tail_src, mask=mask_tail) + tail_dst = o_ptr + (rm[:, None] * stride_out_seqlen + cols_tail[None, :] * stride_out_headdim) + tl.store(tail_dst, tail_val, mask=mask_tail) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "max_seqlen is required when cu_seqlens provided" + total_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + seqlen = max_seqlen + + seqlen_ro, rotary_half = cos.shape + rotary_dim = rotary_half * 2 + assert rotary_dim <= headdim + assert cos.dtype == sin.dtype == x.dtype + cos, sin = cos.contiguous(), sin.contiguous() + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = x if inplace else torch.empty_like(x) + if rotary_dim < headdim and not inplace: + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) # shape (B, S, H, D) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) # shape (ΣS, H, D) + + BLOCK_K = {32: 32, 64: 64, 128: 128, 256: 256}.get( + triton.next_power_of_2(rotary_half), 256 + ) + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, + output.stride(0) if not is_varlen else 0, + output.stride(-3), + output.stride(-2), + output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stdout new file mode 100644 index 0000000..0a3302c --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_925133.py.stdout @@ -0,0 +1,15 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_925133 due to at 42:41: + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + o_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seq_len = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ptr = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + o_ptr = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + else_seq_len = seqlen + current_seqlen = tl.where(IS_VARLEN, seq_len, else_seq_len) + ^ +NameError('seq_len is not defined') diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py b/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py new file mode 100644 index 0000000..340efbe --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py @@ -0,0 +1,279 @@ + +import torch +import triton +import triton.language as tl +from typing import Optional, Union + +@triton.jit +def rotary_kernel( + OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, + seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, + stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, + stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_batch = tl.program_id(1) + pid_head = tl.program_id(2) + + if not IS_VARLEN: + x_base = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_base = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + seqlen_i = seqlen + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen_i = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_base = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_base = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen_i: + return + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rm_mask = rm < seqlen_i + + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rm_cs_mask = rm_cs < seqlen_ro + + rotary_dim_half = rotary_dim // 2 + + if not INTERLEAVED: + rk_half = tl.arange(0, BLOCK_K // 2) + rk_mask = rk_half < rotary_dim_half + + offs_xr = x_base + rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim + xr = tl.load(offs_xr, mask=rm_mask[:, None] & rk_mask[None, :], other=0.0).to(tl.float32) + + offs_xi = x_base + rm[:, None] * stride_x_seqlen + (rk_half[None, :] + rotary_dim_half) * stride_x_headdim + xi = tl.load(offs_xi, mask=rm_mask[:, None] & rk_mask[None, :], other=0.0).to(tl.float32) + + offs_cs = rm_cs[:, None] * rotary_dim_half + rk_half[None, :] + cos = tl.load(COS + offs_cs, mask=rm_cs_mask[:, None] & rk_mask[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(SIN + offs_cs, mask=rm_cs_mask[:, None] & rk_mask[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + + or_ = xr * cos - xi * sin_val + oi = xr * sin_val + xi * cos + + tl.store(out_base + rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim, + or_, mask=rm_mask[:, None] & rk_mask[None, :]) + tl.store(out_base + rm[:, None] * stride_out_seqlen + + (rk_half[None, :] + rotary_dim_half) * stride_out_headdim, + oi, mask=rm_mask[:, None] & rk_mask[None, :]) + else: + rk = tl.arange(0, BLOCK_K) + rk_mask = rk < rotary_dim + rk_half_idx = rk // 2 + rk_mask_half = rk_half_idx < rotary_dim_half + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 + + offs_x0 = x_base + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim + x0 = tl.load(offs_x0, mask=rm_mask[:, None] & rk_mask[None, :], other=0.0).to(tl.float32) + + offs_x1 = x_base + rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim + x1 = tl.load(offs_x1, mask=rm_mask[:, None] & rk_swap[None, :] < rotary_dim, other=0.0).to(tl.float32) + + offs_cs = rm_cs[:, None] * rotary_dim_half + rk_half_idx[None, :] + cos = tl.load(COS + offs_cs, mask=rm_cs_mask[:, None] & rk_mask_half[None, :], other=1.0).to(tl.float32) + sin_val = tl.load(SIN + offs_cs, mask=rm_cs_mask[:, None] & rk_mask_half[None, :], other=0.0).to(tl.float32) + if CONJUGATE: + sin_val = -sin_val + + out_even = x0 * cos - x1 * sin_val + out_odd = x0 * sin_val + x1 * cos + + out_offs = out_base + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim + out_val = tl.where(rk[None, :] % 2 == 0, out_even, out_odd) + tl.store(out_offs, out_val, mask=rm_mask[:, None] & rk_mask[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "Must provide max_seqlen with cu_seqlens" + total_seqlen, nheads, headdim = x.shape + seqlen = max_seqlen + batch = cu_seqlens.shape[0] - 1 + + assert cos.shape == sin.shape + seqlen_ro, rotary_dim_half = cos.shape + rotary_dim = rotary_dim_half * 2 + assert rotary_dim <= headdim + assert seqlen_ro >= seqlen + (seqlen_offsets.max().item() + if isinstance(seqlen_offsets, torch.Tensor) + else seqlen_offsets) + assert x.dtype == cos.dtype == sin.dtype, "All tensors must share dtype" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,), "seqlen_offsets must be 1-D tensor of length batch" + seqlen_offsets = seqlen_offsets.int().contiguous() + else: + seqlen_offsets = int(seqlen_offsets) + + output = x if inplace else torch.empty_like(x) + if not inplace and rotary_dim < headdim: + if not is_varlen: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + else: + output[:, :, rotary_dim:].copy_(x[:, :, rotary_dim:]) + + BLOCK_K = 32 if rotary_dim <= 32 else \ + 64 if rotary_dim <= 64 else \ + 128 if rotary_dim <= 128 else 256 + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) + + cu_seqlens_ptr = cu_seqlens.int().contiguous() if is_varlen else None + + with torch.cuda.device(x.device): + rotary_kernel[grid]( + output, + x, + cos, + sin, + cu_seqlens_ptr, + seqlen_offsets, + seqlen, + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # dummy + output.stride(0) if not is_varlen else 0, + output.stride(-2 if is_varlen else -3), + output.stride(-1 if is_varlen else -2), + output.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-2 if is_varlen else -3), + x.stride(-1 if is_varlen else -2), + x.stride(-1), + BLOCK_K=BLOCK_K, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + BLOCK_M=BLOCK_M, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_939610.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py b/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py new file mode 100644 index 0000000..dd2d7e9 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py @@ -0,0 +1,284 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLENS, + stride_x_batch, + stride_x_head, + stride_x_m, + stride_x_n, + stride_cos_m, + stride_cos_n, + stride_sin_m, + stride_sin_n, + stride_out_batch, + stride_out_head, + stride_out_m, + stride_out_n, + batch_size, + head_num, + seq_len, + H, + D, + HID, + stride_h, + stride_d, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + DTYPE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if pid_batch >= batch_size or pid_head >= head_num: + return + + seq_start = 0 + cur_seq_len = seq_len + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch) + cur_seq_len = tl.load(SEQLENS + pid_batch) + elif seq_len > 0: + cur_seq_len = seq_len + else: + cur_seq_len = seq_len + + if pid_m * BLOCK_M >= cur_seq_len: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + mask_m = offs_m < cur_seq_len + + cos_ptrs = COS + (seq_start + offs_m[:, None]) * stride_cos_m + offs_n[None, :] * stride_cos_n + sin_ptrs = SIN + (seq_start + offs_m[:, None]) * stride_sin_m + offs_n[None, :] * stride_sin_n + + cos = tl.load(cos_ptrs, mask=mask_m[:, None], other=0.0) + sin = tl.load(sin_ptrs, mask=mask_m[:, None], other=0.0) + + x_base_ptr = X + pid_batch * stride_x_batch + pid_head * stride_x_head + out_base_ptr = OUT + pid_batch * stride_out_batch + pid_head * stride_out_head + + if INTERLEAVED: + offs_d = 2 * offs_n + x_ptr0 = x_base_ptr + offs_m[:, None] * stride_x_m + offs_d[None, :] * stride_x_n + x_ptr1 = x_base_ptr + offs_m[:, None] * stride_x_m + (offs_d + 1)[None, :] * stride_x_n + + x0 = tl.load(x_ptr0, mask=mask_m[:, None], other=0.0).to(DTYPE) + x1 = tl.load(x_ptr1, mask=mask_m[:, None], other=0.0).to(DTYPE) + + c = cos + s = sin if not CONJUGATE else -sin + y0 = x0 * c - x1 * s + y1 = x0 * s + x1 * c + + tl.store(out_base_ptr + offs_m[:, None] * stride_out_m + offs_d[None, :] * stride_out_n, y0, mask=mask_m[:, None]) + tl.store(out_base_ptr + offs_m[:, None] * stride_out_m + (offs_d + 1)[None, :] * stride_out_n, y1, mask=mask_m[:, None]) + else: + offs_d0 = offs_n + offs_d1 = offs_n + HID + + x_ptr0 = x_base_ptr + offs_m[:, None] * stride_x_m + offs_d0[None, :] * stride_x_n + x_ptr1 = x_base_ptr + offs_m[:, None] * stride_x_m + offs_d1[None, :] * stride_x_n + + x0 = tl.load(x_ptr0, mask=mask_m[:, None], other=0.0).to(DTYPE) + x1 = tl.load(x_ptr1, mask=mask_m[:, None], other=0.0).to(DTYPE) + + c = cos + s = sin if not CONJUGATE else -sin + y0 = x0 * c - x1 * s + y1 = x0 * s + x1 * c + + tl.store(out_base_ptr + offs_m[:, None] * stride_out_m + offs_d0[None, :] * stride_out_n, y0, mask=mask_m[:, None]) + tl.store(out_base_ptr + offs_m[:, None] * stride_out_m + offs_d1[None, :] * stride_out_n, y1, mask=mask_m[:, None]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = 0, +) -> torch.Tensor: + assert x.dim() == 4 + batch_size, head_num, seq_len, dim = x.shape + assert dim % 2 == 0 + assert cos.dim() == 4 and sin.dim() == 4 + interleaved = False + conjugate = False + + dtype = x.dtype + if dtype == torch.float16: + triton_dtype = tl.float16 + elif dtype == torch.float32: + triton_dtype = tl.float32 + else: + raise ValueError("Unsupported dtype") + + out = torch.empty_like(x) + + HID = dim // 2 + stride_x_batch = x.stride(0) + stride_x_head = x.stride(1) + stride_x_m = x.stride(2) + stride_x_n = x.stride(3) + stride_cos_m = cos.stride(2) + stride_cos_n = cos.stride(3) + stride_sin_m = sin.stride(2) + stride_sin_n = sin.stride(3) + stride_out_batch = out.stride(0) + stride_out_head = out.stride(1) + stride_out_m = out.stride(2) + stride_out_n = out.stride(3) + + BLOCK_M = 32 + BLOCK_N = HID + grid = (triton.cdiv(batch_size, 1), triton.cdiv(head_num, 1), triton.cdiv(seq_len, BLOCK_M)) + + rotary_kernel[grid]( + out, + x, + cos, + sin, + cu_seqlens, + None, + stride_x_batch, + stride_x_head, + stride_x_m, + stride_x_n, + stride_cos_m, + stride_cos_n, + stride_sin_m, + stride_sin_n, + stride_out_batch, + stride_out_head, + stride_out_m, + stride_out_n, + batch_size, + head_num, + seq_len, + None, + dim, + HID, + None, + None, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + INTERLEAVED=interleaved, + CONJUGATE=conjugate, + DTYPE=triton_dtype, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stdout new file mode 100644 index 0000000..6ab1474 --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_946209.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_946209 due to diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py b/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py new file mode 100644 index 0000000..b97081b --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py @@ -0,0 +1,279 @@ + +import torch +import triton +import triton.language as tl +from typing import Union, Optional + + +@triton.jit +def rotary_kernel( + X, + COS, + SIN, + OUT, + CU_SEQLENS, + SEQLEN_OFFSETS, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cosm, + stride_cosk, + stride_sinm, + stride_sink, + stride_ob, + stride_oh, + stride_om, + stride_ok, + max_seqlen, + rotary_dim, + seqlen, + interleaved: tl.constexpr, + conjugate: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_m = tl.program_id(2) + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_b) + seq_end = tl.load(CU_SEQLENS + pid_b + 1) + current_seqlen = seq_end - seq_start + else: + current_seqlen = seqlen + seq_start = 0 + + if SEQLEN_OFFSETS is not None: + offset = tl.load(SEQLEN_OFFSETS + pid_b).to(tl.int32) + else: + offset = 0 + + if pid_m >= current_seqlen: + return + + rotary_dim_half = rotary_dim // 2 + cols = tl.arange(0, BLOCK_K) + mask = cols < rotary_dim_half + + pos = seq_start + pid_m + offset + + # load cos/sin + cos_ptr = COS + pos * stride_cosm + sin_ptr = SIN + pos * stride_sinm + cos_val = tl.load(cos_ptr + cols * stride_cosk, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + cols * stride_sink, mask=mask, other=0.0).to(tl.float32) + + x_base = pid_b * stride_xb + pid_h * stride_xh + pid_m * stride_xm + out_base = pid_b * stride_ob + pid_h * stride_oh + pid_m * stride_om + + if interleaved: + even_ptrs = x_base + (cols * 2) * stride_xk + odd_ptrs = x_base + (cols * 2 + 1) * stride_xk + x0 = tl.load(even_ptrs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(odd_ptrs, mask=mask, other=0.0).to(tl.float32) + + if conjugate: + x1 = -x1 + + o0 = x0 * cos_val - x1 * sin_val + o1 = x0 * sin_val + x1 * cos_val + + tl.store(out_base + (cols * 2) * stride_ok, o0, mask=mask) + tl.store(out_base + (cols * 2 + 1) * stride_ok, o1, mask=mask) + else: + left_ptrs = x_base + cols * stride_xk + right_ptrs = x_base + (cols + rotary_dim_half) * stride_xk + x0 = tl.load(left_ptrs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(right_ptrs, mask=mask, other=0.0).to(tl.float32) + + if conjugate: + x1 = -x1 + + o0 = x0 * cos_val - x1 * sin_val + o1 = x0 * sin_val + x1 * cos_val + + tl.store(out_base + cols * stride_ok, o0, mask=mask) + tl.store(out_base + (cols + rotary_dim_half) * stride_ok, o1, mask=mask) + + cols_rest = rotary_dim + tl.arange(0, BLOCK_K) + mask_rest = cols_rest < max_seqlen + if mask_rest.any(): + src_rest = x_base + cols_rest * stride_xk + dst_rest = out_base + cols_rest * stride_ok + val_rest = tl.load(src_rest, mask=mask_rest, other=0.0) + tl.store(dst_rest, val_rest, mask=mask_rest) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + if cu_seqlens is None: + batch, nheads, seqlen, headdim = x.shape + else: + assert x.ndim == 3 + tot_seqlen, nheads, headdim = x.shape + batch = cu_seqlens.shape[0] - 1 + if max_seqlen is None: + max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max()) + seqlen = max_seqlen + + seqlen_ro, halfdim = cos.shape + assert sin.shape == cos.shape + rotary_dim = halfdim * 2 + assert rotary_dim <= headdim + assert x.dtype == cos.dtype == sin.dtype + + cos = cos.contiguous() + sin = sin.contiguous() + + if isinstance(seqlen_offsets, int): + seqlen_offsets_tensor = None + else: + assert seqlen_offsets.shape == (batch,) + seqlen_offsets = seqlen_offsets.contiguous() + seqlen_offsets_tensor = seqlen_offsets + + out = torch.empty_like(x) if not inplace else x + + # only copy non-rotary tails if not inplace + if rotary_dim < headdim and not inplace: + if cu_seqlens is None: + out[..., rotary_dim:] = x[..., rotary_dim:] + else: + out[:, :, rotary_dim:] = x[:, :, rotary_dim:] + + BLOCK_M = 1 + BLOCK_K = triton.next_power_of_2(rotary_dim) // 2 + if cu_seqlens is None: + grid = (batch, nheads, seqlen) + else: + grid = (batch, nheads, max_seqlen) + + rotary_kernel[grid]( + x, cos, sin, out, + cu_seqlens, + seqlen_offsets_tensor, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cos.stride(0), cos.stride(1), + sin.stride(0), sin.stride(1), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + headdim, + rotary_dim, + seqlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + ) + return out + + +################################################################################################################################################## + + + + + +import torch + + + +def test_apply_rotary(): + + results = {} + + + + # Test case 1: Basic test with fixed sequence length and no interleaving + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin) + + results['test_case_1'] = output.shape + + + + # Test case 2: Variable length sequences with interleaving + + total_seqlen, nheads, headdim = 256, 4, 64 + + batch = 3 + + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + + max_seqlen = 128 + + rotary_dim = 32 + + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + + results['test_case_2'] = output.shape + + + + # Test case 3: Conjugate flag enabled + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, conjugate=True) + + results['test_case_3'] = output.shape + + + + # Test case 4: Inplace operation + + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + + rotary_dim = 32 + + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + + output = apply_rotary(x, cos, sin, inplace=True) + + results['test_case_4'] = output.shape + + + + return results + + + +result_gold = test_apply_rotary() diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stderr b/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stdout b/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stdout new file mode 100644 index 0000000..4c4372d --- /dev/null +++ b/src/temp/gen/rotary_transform.py_gen_triton_code_99563.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module rotary_transform.py_gen_triton_code_99563 due to at 80:13: + if conjugate: + x1 = -x1 + + o0 = x0 * cos_val - x1 * sin_val + o1 = x0 * sin_val + x1 * cos_val + + tl.store(out_base + (cols * 2) * stride_ok, o0, mask=mask) + tl.store(out_base + (cols * 2 + 1) * stride_ok, o1, mask=mask) + else: + left_ptrs = x_base + cols * stride_xk + right_ptrs = x_base + (cols + rotary_dim_half) * stride_xk + x0 = tl.load(left_ptrs, mask=mask, other=0.0).to(tl.float32) + ^ diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py b/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py new file mode 100644 index 0000000..5c47073 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) # 1-D program launch grid + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements # protection against out-of-bounds memory access + x = tl.load(x_ptr + offsets, mask=mask) + out = tl.math.sin(x) + tl.store(output_ptr + offsets, out, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stdout new file mode 100644 index 0000000..4bd0c64 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_123151.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module sin_kernel.py_gen_triton_code_123151 due to call_kernel() missing 1 required positional argument: 'BLOCK_SIZE' diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py b/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py new file mode 100644 index 0000000..da32b28 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]), ) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=1024, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_179581.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py b/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py new file mode 100644 index 0000000..7d4dc17 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py @@ -0,0 +1,101 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor): + n_elements = x.numel() + output = torch.empty_like(x) + + BLOCK_SIZE = 1024 + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_370053.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py b/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py new file mode 100644 index 0000000..809fd65 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py @@ -0,0 +1,92 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + kernel_function[grid]( + x, output, n_elements, + BLOCK_SIZE=1024, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_473025.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py b/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py new file mode 100644 index 0000000..074cbe2 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + block_id = tl.program_id(0) + block_start = block_id * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 1024) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_502063.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py b/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py new file mode 100644 index 0000000..6b2ae34 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py @@ -0,0 +1,102 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + out = tl.math.sin(x) + tl.store(output_ptr + offsets, out, mask=mask) + + +def call_kernel( + x: torch.Tensor, + BLOCK_SIZE: int = 128, +) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_50482.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py b/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py new file mode 100644 index 0000000..97c0a6d --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + result = tl.math.sin(x) + tl.store(output_ptr + offsets, result, mask=mask) + +def call_kernel(x: torch.Tensor, output: torch.Tensor = None): + assert x.is_contiguous() + n_elements = x.numel() + output = torch.empty_like(x) if output is None else output + BLOCK_SIZE = 128 # fast baseline + n_programs = triton.cdiv(n_elements, BLOCK_SIZE) + kernel_function[(n_programs)]( + x_ptr = x, + output_ptr = output, + n_elements = n_elements, + BLOCK_SIZE = BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stdout new file mode 100644 index 0000000..61e9a0b --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_557502.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module sin_kernel.py_gen_triton_code_557502 due to object of type 'int' has no len() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py b/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py new file mode 100644 index 0000000..a75668d --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + result = tl.math.sin(x) + tl.store(output_ptr + offsets, result, mask=mask) + +def call_kernel(x: torch.Tensor, output: torch.Tensor = None): + assert x.is_contiguous() + n_elements = x.numel() + output = torch.empty_like(x) if output is None else output + BLOCK_SIZE = 128 + n_programs = triton.cdiv(n_elements, BLOCK_SIZE) + kernel_function[(n_programs,)]( + x_ptr=x, + output_ptr=output, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_560359.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py b/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py new file mode 100644 index 0000000..f90fd90 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py @@ -0,0 +1,95 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 256) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_794865.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py b/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py new file mode 100644 index 0000000..60d3a13 --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 1024): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_834634.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py b/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py new file mode 100644 index 0000000..b70d3df --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py @@ -0,0 +1,100 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) # program id of the 1D launch grid + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements # mask out-of-bounds elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 64): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + kernel_function[grid]( + x, + output, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +################################################################################################################################################## + + + + + +import torch + + + +# Function to test the Triton kernel + +def test_call_kernel(): + + results = {} + + + + # Test case 1: Small input tensor + + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + + output1 = call_kernel(x1) + + results['test_case_1'] = output1 + + + + # Test case 2: Larger input tensor + + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + + output2 = call_kernel(x2) + + results['test_case_2'] = output2 + + + + # Test case 3: Edge case with zero elements + + x3 = torch.tensor([], dtype=torch.float32).cuda() + + output3 = call_kernel(x3) + + results['test_case_3'] = output3 + + + + # Test case 4: Input tensor with negative values + + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + + output4 = call_kernel(x4) + + results['test_case_4'] = output4 + + + + return results + + + +# Run the test function + +result_gold = test_call_kernel() diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stderr b/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stdout b/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/sin_kernel.py_gen_triton_code_931009.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py b/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py new file mode 100644 index 0000000..00259dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py @@ -0,0 +1,129 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + IS_EVEN_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m, M - first_pid_m * BLOCK_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if IS_EVEN_K or k * BLOCK_SIZE_K + BLOCK_SIZE_K <= K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + k_remaining = K - k * BLOCK_SIZE_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + if a.dtype == torch.float16: + BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 64 + num_stages = 3 + num_warps = 4 + elif a.dtype == torch.float32: + BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 128 + num_stages = 3 + num_warps = 8 + else: + raise ValueError("Unsupported dtype") + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=32, + IS_EVEN_K=(K % 32 == 0), + num_warps=num_warps, + num_stages=num_stages, + ) + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_108037.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py b/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py new file mode 100644 index 0000000..9d9413c --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py @@ -0,0 +1,120 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr = 8, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < (K - k * BLOCK_SIZE_K), other=0.0) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(a_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor = None): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Input a must be contiguous" + assert b.is_contiguous(), "Input b must be contiguous" + + M, K = a.shape + K_, N = b.shape + assert K == K_, "K dimension mismatch" + + if c is None: + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=32, + ) + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_12912.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py b/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py new file mode 100644 index 0000000..c8f7671 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py @@ -0,0 +1,125 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + AXIS: tl.constexpr, +): + pid = tl.program_id(AXIS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + mask_a = (offs_am[:, None] < M) & (offs_k[None, :] < K) + mask_b = (offs_k[:, None] < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(a_ptr.dtype.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, c, mask=mask_c) + +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor=None, activation=None): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + if c is None: + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + if a.dtype == torch.float16: + BLOCK_SIZE_K = 64 if K > 64 else 32 + num_warps = 4 + num_stages = 4 if K < 512 else 3 + else: + BLOCK_SIZE_K = 128 + num_warps = 4 + num_stages = 3 + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + AXIS=0, + num_warps=num_warps, + num_stages=num_stages, + ) + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_186313.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py b/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py new file mode 100644 index 0000000..6e314ef --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py @@ -0,0 +1,124 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m - first_pid_m, num_pid_m) + pid_m = first_pid_m + (pid % num_pid_m) + pid_n = (pid % num_pid_in_group) - pid_m * num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_offs = k * BLOCK_SIZE_K + mask_a = offs_k[None, :] < (K - k_offs) + mask_b = offs_k[:, None] < (K - k_offs) + + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_warps = 8 + num_stages = 2 + + grid = lambda META: (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + num_warps=num_warps, + num_stages=num_stages + ) + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_284744.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py b/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py new file mode 100644 index 0000000..64b2f74 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py @@ -0,0 +1,127 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = num_pid_m * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * num_pid_m + group_size_m = min(num_pid_m, M - first_pid_m * BLOCK_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(a_ptr.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + if a.dtype == torch.float16: + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + num_stages = 2 + num_warps = 4 + elif a.dtype == torch.float32: + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 32 + num_stages = 2 + num_warps = 4 + else: + raise ValueError(f"Unsupported dtype: {a.dtype}") + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + num_stages=num_stages, num_warps=num_warps + ) + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_366643.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py b/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py new file mode 100644 index 0000000..ba491f5 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py @@ -0,0 +1,151 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + # fmt: off + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + AXIS_WM: tl.constexpr, AXIS_WN: tl.constexpr, + IS_EVEN_K: tl.constexpr, + DTYPE: tl.constexpr, + # fmt: on +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if not IS_EVEN_K: + k_mask = offs_k[None, :] < K - k * BLOCK_SIZE_K + a = tl.load(a_ptrs, mask=k_mask, other=0.0) + b = tl.load(b_ptrs, mask=k_mask.T, other=0.0) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + a = a.to(DTYPE) + b = b.to(DTYPE) + accumulator += tl.dot(a, b, allow_tf32=True) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + # assumes accum is always fp32 + c = accumulator.to(DTYPE) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor = None) -> torch.Tensor: + assert a.dim() == 2, f"Expected 2D input tensor a, got {a.dim()}D" + assert b.dim() == 2, f"Expected 2D input tensor b, got {b.dim()}D" + M, K = a.shape + K_, N = b.shape + assert K == K_, f"Incompatible inner dimension for matmul: {K} vs {K_}" + + if c is None: + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + else: + assert c.shape == (M, N), f"Output tensor has wrong shape: {c.shape}, expected {(M, N)}" + + # constants depending on dtype + if a.dtype == torch.float16: + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_stages = 2 + num_warps = 8 + else: + # float32 + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 32 + num_stages = 2 + num_warps = 4 + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + # fmt: off + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + AXIS_WM=0, AXIS_WN=1, + IS_EVEN_K=(K % BLOCK_SIZE_K == 0), + DTYPE=tl.float32 if a.dtype == torch.float32 else tl.float16, + num_stages=num_stages, num_warps=num_warps, + ) + # fmt: on + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_391924.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py b/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py new file mode 100644 index 0000000..6b43530 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py @@ -0,0 +1,120 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k * BLOCK_SIZE_K + a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < k_remaining) + b_mask = (offs_k[:, None] < k_remaining) & (offs_n[None, :] < N) + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator.to(c_ptr.type.element_ty), mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.device == b.device and a.dtype == b.dtype, "Input tensors must be on the same device and dtype" + assert a.shape[1] == b.shape[0], "Incompatible dimensions for matrix multiplication" + + M, K = a.shape + _, N = b.shape + c = torch.empty((M, N), dtype=a.dtype, device=a.device) + + block_size_m = 64 + block_size_n = 64 + block_size_k = 32 + + if a.dtype == torch.float16: + num_warps = 4 + num_stages = 3 + else: # torch.float32 or others + num_warps = 8 + num_stages = 2 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + ) + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_395140.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py b/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py new file mode 100644 index 0000000..53a9c2e --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py @@ -0,0 +1,124 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_start = k * BLOCK_SIZE_K + offs_k_curr = k_start + offs_k + a_mask = offs_k_curr[None, :] < K + b_mask = offs_k_curr[:, None] < K + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +def matmul(a: torch.Tensor, b: torch.Tensor, activation=""): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + GROUP_SIZE_M = 8 + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_417385.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py b/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py new file mode 100644 index 0000000..f15a5de --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py @@ -0,0 +1,122 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + IS_EVEN_K: tl.constexpr +): + pid = tl.program_id(0) + num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_in_group = group_size = num_pid_m + group_id = pid // group_size + first_pid_m = group_id * group_size + group_size_m = min(num_pid_m - first_pid_m, group_size) + pid_m = first_pid_m + (pid % group_size) + pid_n = (pid % num_pid_n) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if IS_EVEN_K or k * BLOCK_SIZE_K + offs_k[None, :] < K: + a = tl.load(a_ptrs, mask=None if IS_EVEN_K else k * BLOCK_SIZE_K + offs_k[None, :] < K, other=0.0) + b = tl.load(b_ptrs, mask=None if IS_EVEN_K else k * BLOCK_SIZE_K + offs_k[:, None] < K, other=0.0) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_warps = 8 + num_stages = 2 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + IS_EVEN_K=K % BLOCK_SIZE_K == 0, + num_warps=num_warps, + num_stages=num_stages + ) + + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stdout new file mode 100644 index 0000000..ebbed98 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_654780.py.stdout @@ -0,0 +1 @@ +False*#*#None*#*#None*#*#Could not execute module triton_matmul.py_gen_triton_code_654780 due to Simultaneous multiple assignment is not supported. diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py b/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py new file mode 100644 index 0000000..58aedb7 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py @@ -0,0 +1,121 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + IS_EVEN_K: tl.constexpr +): + pid = tl.program_id(0) + num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + group_size = num_pid_m + num_pid_in_group = num_pid_m + group_id = pid // group_size + first_pid_m = group_id * group_size + group_size_m = min(num_pid_m - first_pid_m, group_size) + pid_m = first_pid_m + (pid % group_size) + pid_n = (pid % num_pid_n) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if IS_EVEN_K or k * BLOCK_SIZE_K + offs_k[None, :] < K: + a = tl.load(a_ptrs, mask=None if IS_EVEN_K else k * BLOCK_SIZE_K + offs_k[None, :] < K, other=0.0) + b = tl.load(b_ptrs, mask=None if IS_EVEN_K else k * BLOCK_SIZE_K + offs_k[:, None] < K, other=0.0) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + BLOCK_SIZE_K = 64 + num_warps = 8 + num_stages = 2 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + IS_EVEN_K=K % BLOCK_SIZE_K == 0, + num_warps=num_warps, + num_stages=num_stages + ) + + return c + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stdout new file mode 100644 index 0000000..380d214 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_769893.py.stdout @@ -0,0 +1,14 @@ +False*#*#None*#*#None*#*#Could not execute module triton_matmul.py_gen_triton_code_769893 due to at 30:16: + pid_n = (pid % num_pid_n) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if IS_EVEN_K or k * BLOCK_SIZE_K + offs_k[None, :] < K: + a = tl.load(a_ptrs, mask=None if IS_EVEN_K else k * BLOCK_SIZE_K + offs_k[None, :] < K, other=0.0) + ^ diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py b/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py new file mode 100644 index 0000000..1271799 --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py @@ -0,0 +1,133 @@ + +import torch +import triton +import triton.language as tl + +# ------------------------------------------------------------------------------ +# Helper: fast, tiled matmul for fp16 / fp32 on AMD GPUs (ROCm-compatible) +# ------------------------------------------------------------------------------ +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + mask_a = (offs_am[:, None] < M) & ((BLOCK_SIZE_K * k + offs_k)[None, :] < K) + mask_b = ((BLOCK_SIZE_K * k + offs_k)[:, None] < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=mask) + + +def matmul(a: torch.Tensor, b: torch.Tensor, activation=None): + assert a.dtype == b.dtype + assert a.dim() == 2 and b.dim() == 2, "only 2-D tensors supported" + M, K = a.shape + K2, N = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + if a.dtype == torch.float16: + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + num_stages = 2 + num_warps = 4 + elif a.dtype == torch.float32: + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + num_stages = 4 + num_warps = 4 + else: + raise RuntimeError("Unsupported dtype for AMD Triton matmul") + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + num_stages=num_stages, + num_warps=num_warps, + ) + return c + + +################################################################################################################################################## + + + + + +import torch + + + +# Test for matmul + +def test_matmul(): + + results = {} + + M, K, N = 256, 128, 256 + + + + # Test case 1: torch.float16 + + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + + c = matmul(a, b) + + results['test_case_1'] = c + + + + return results + + + +# Run all tests + +result_gold = test_matmul() \ No newline at end of file diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stderr b/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stderr new file mode 100644 index 0000000..e69de29 diff --git a/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stdout b/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stdout new file mode 100644 index 0000000..113b5dd --- /dev/null +++ b/src/temp/gen/triton_matmul.py_gen_triton_code_993568.py.stdout @@ -0,0 +1 @@ +True*#*#True*#*#None*#*#None diff --git a/src/temp/int4_matmul.py b/src/temp/int4_matmul.py new file mode 100644 index 0000000..f9dae9c --- /dev/null +++ b/src/temp/int4_matmul.py @@ -0,0 +1,286 @@ + +import torch +import triton +import triton.language as tl + +# -------------------------------------------------------------------------------- +# Triton kernels for INT4 matrix multiplication (weight dequantized on the fly) +# -------------------------------------------------------------------------------- +@triton.autotune( + configs=[ + # M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_stages, num_warps + triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + ], + key = ["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # pointers to matmul operands + a_ptr, b_ptr, c_ptr, # a is fp16/bf16, b is quantized (int packed), c is output fp16/bf16 + # scales + zero points vectors + scales_ptr, zeros_ptr, # per-group fp16 + # strides + stride_am, stride_ak, + stride_bk, stride_bn, stride_b_packed, # b is (K/8, N) packed 8 int4 in one int32 + stride_cm, stride_cn, + stride_scales, # (num_groups) + stride_zeros, # (num_groups) + # dimension sizes + M, N, K, + groupsize: tl.constexpr, # dequantization group granularity + # block sizes for tiling + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z= tl.program_id(axis=1) # for SPLIT_K + + # tile identifiers + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if SPLIT_K > 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_tiles_k = tl.cdiv(K, BLOCK_SIZE_K) + pid_m = pid // (num_tiles_k * num_pid_n) + remaining = pid % (num_tiles_k * num_pid_n) + pid_n = remaining // num_tiles_k + pid_k_first = remaining % num_tiles_k + pid_k_last = pid_k_first + 1 + # NOTE: currently implement simple row/col tiling, so we set SPLIT_K always to 1 + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # offset block pointers + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + # adjust overlapping + offs_m = tl.where(offs_m < M, offs_m, M-1) + offs_n = tl.where(offs_n < N, offs_n, N-1) + offs_k = tl.where(offs_k < K, offs_k, K-1) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_b_packed + offs_n[None, :] * stride_bn) + + scales_ptrs = scales_ptr + ((offs_k[:, None] // groupsize) * stride_scales) + zeros_ptrs = zeros_ptr + ((offs_k[:, None] // groupsize) * stride_zeros) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + # edge masking + k_cur = k * BLOCK_SIZE_K + offs_k + mask_k = k_cur < K + + # load A tile (fp16) + a_tile = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0) + + # load packed INT4 B tile + b_int = tl.load(b_ptrs, mask=mask_k[:, None] & (offs_n[None, :] < N), other=0) + + # ---- dequantize ---- + # unpack each int32 into 8 int4 values (low nibble first) + scales = tl.load(scales_ptrs, mask=mask_k[:, None] & (offs_n[None, :] < N), other=1.0) + zeros = tl.load(zeros_ptrs, mask=mask_k[:, None] & (offs_n[None, :] < N), other=0.0) + + # split nibble from packed int8 + inner = (offs_k[:, None] % 8) * 4 + b_ext = (b_int >> inner) & 0xF # 0..15 + b_deint = b_ext.to(tl.float32) + + bq_f32 = scales * (b_deint - zeros) + + # emulated block-K reduction accumulation + accumulator += tl.dot(a_tile, bq_f32) + + # advance pointers + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K // 8) * stride_b_packed + scales_ptrs += (BLOCK_SIZE_K * SPLIT_K // groupsize) * stride_scales + zeros_ptrs += (BLOCK_SIZE_K * SPLIT_K // groupsize) * stride_zeros + + # write back + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_m = offs_cm < M + mask_n = offs_cn < N + c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + c = accumulator.to(c_ptr.type.element_ty) + tl.store(c_ptrs, c, mask=mask_m[:, None] & mask_n[None, :]) + + + +# -------------------------------------------------------------------------------- +# Python utility entry — int4 dequantized matrix multiply wrapper +# -------------------------------------------------------------------------------- +def matmul_dequantize_int4_s2( + x: torch.Tensor, # (M, K) fp16/fp32 + qweight: torch.Tensor, # (K//8, N) int32 each value holds 8 int4 + scales: torch.Tensor, # (num_groups, N) fp16/fp32 + zeros: torch.Tensor, # (num_groups, N) fp16/fp32 + groupsize: int = 128, +) -> torch.Tensor: + # Device check (ROCm friendly) + assert x.is_cuda or str(x.device).startswith("cuda") + M, K = x.shape + assert qweight.shape == (K//8, qweight.shape[1]) + N = qweight.shape[1] + + # alloc output + c = torch.empty((M, N), dtype=x.dtype, device=x.device) + + # prepare grid + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 1) + + matmul_kernel[grid]( + x, qweight, c, + scales, zeros, + x.stride(0), x.stride(1), + qweight.stride(0), qweight.stride(1), qweight.stride(0), # 3rd stride unused in kernel + c.stride(0), c.stride(1), + scales.stride(0), zeros.stride(0), + M, N, K, + groupsize, + ) + + return c + + + +# -------------------------------------------------------------------------------- +# INT4 quantize helper +# -------------------------------------------------------------------------------- +def quantize_int4(w: torch.Tensor, groupsize: int = 128) -> tuple: + """ + Quantize fp16/32 weights into INT4 with per-group scale & zero-point. + Returns: + qw (K//8, N) int32 -> 8 int4 per int32 + scales (num_groups, N) fp16 + zeros (num_groups, N) fp16 + """ + if w.dim() == 1: + w = w.unsqueeze(1) + shape = w.shape + K_orig, N = shape[-2], shape[-1] + w = w.view(-1, N) + + # pad to multiple of groupsize + K_pad = (K_orig + groupsize - 1) // groupsize * groupsize + if K_pad > K_orig: + w = torch.cat([w, torch.zeros(K_pad - K_orig, N, dtype=w.dtype, device=w.device)], dim=0) + + assert w.shape[0] % groupsize == 0 + num_groups = w.shape[0] // groupsize + + # Reshape to (num_groups, groupsize, N) + w = w.view(num_groups, groupsize, N) + + # compute scale & zero + w_min = torch.amin(w, dim=1) # (num_groups,N) + w_max = torch.amax(w, dim=1) + scale = (w_max - w_min) / 15.0 + scale = scale.clamp(min=1e-10) + zero = (torch.round(-w_min / scale)).clamp(0, 15) + + # quantize + w_int = torch.round(w / scale.unsqueeze(1) + zero.unsqueeze(1)).clamp(0, 15).to(torch.int32) + + # pack 8 INT4 -> 1 INT32 + packed = torch.zeros(num_groups * groupsize // 8, N, dtype=torch.int32, device=w.device) + for i in range(8): + mask = 0xF + packed |= (w_int[:, i::8, :] << (4 * i)) & mask + + packed = packed.view(K_pad // 8, N) + scale = scale.to(torch.float16) + zero = zero.to(torch.float16) + + return packed[: (K_orig + 7) // 8], scale, zero + + +# -------------------------------------------------------------------------------- +# Utility to unpack INT4 for testing only +# -------------------------------------------------------------------------------- +@triton.jit +def _unpack_int4_kernel( + qw_ptr, scales_ptr, zeros_ptr, out_ptr, + K, N, + stride_qw, stride_scales, stride_zeros, stride_out, + BLOCK_SIZE: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_k = tl.program_id(1) + + # indices + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_k = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + mask = (offs_k < K) & (offs_n < N) + scale_ptr = scales_ptr + offs_n * stride_scales + zero_ptr = zeros_ptr + offs_n * stride_zeros + scales = tl.load(scale_ptr, mask=offs_n < N, other=1.0) + zeros = tl.load(zero_ptr , mask=offs_n < N, other=0.0) + + # Each qw elt holds 8 values + offs_k_group = offs_k // 8 + offs_k_inner = offs_k % 8 + + qw_idx = offs_k_group * stride_qw + offs_n * 1 # contig along N + qw = tl.load(qw_ptr + qw_idx, mask=mask, other=0) + + val = (qw >> (4 * offs_k_inner)) & 0xF + fp_val = scales * (val.to(tl.float32) - zeros) + offs_out = offs_k * stride_out + offs_n * 1 + tl.store(out_ptr + offs_out, fp_val, mask=mask) + + +def unpack_int4(qw: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor: + K8, N = qw.shape + K = K8 * 8 + assert scales.shape == zeros.shape == (K // 128, N) # depends on groupsize 128 + out = torch.zeros(K, N, dtype=scales.dtype, device=qw.device) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), triton.cdiv(K, META['BLOCK_SIZE'])) + + _unpack_int4_kernel[grid]( + qw, + scales, + zeros, + out, + K, N, + qw.stride(0), scales.stride(0), zeros.stride(0), out.stride(0), + BLOCK_SIZE=64, + ) + + return out + +################################################################################################################################################## + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + group_size = 128 + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + # Test case + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + results = { + "test_case_1": triton_output + } + + return results + +result_gold = test_correct_int4_s2() diff --git a/src/temp/l2_norm_bwd.py b/src/temp/l2_norm_bwd.py new file mode 100644 index 0000000..8e3f62c --- /dev/null +++ b/src/temp/l2_norm_bwd.py @@ -0,0 +1,117 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_bwd_kernel( + X, + DY, + DX, + M, + N, + eps, + stride_x_row, + stride_dy_row, + stride_dx_row, + BLOCK_N: tl.constexpr +): + row = tl.program_id(0) + if row >= M: + return + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + x_ptr = X + row * stride_x_row + dy_ptr = DY + row * stride_dy_row + dx_ptr = DX + row * stride_dx_row + + x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=0) / N + rstd = tl.math.rsqrt(var + eps) + + term = dy * rstd - tl.sum(dy * x, axis=0) * (1.0 / (var + eps)) * rstd * x + dx = tl.where(mask, term, 0.0) + + tl.store(dx_ptr + cols, dx, mask=mask) + + +def _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-12): + original_shape = x.shape + x = x.view(-1, original_shape[-1]) + dy = dy.view(-1, original_shape[-1]) + + M, N = x.shape + if N == 0: + return torch.empty_like(x).view(*original_shape) + + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError( + f"Cannot normalize a row of size {N} larger than max BLOCK_N ({BLOCK_N})." + ) + + dx = torch.empty_like(x) + + if not x.is_contiguous(): + x = x.contiguous() + if not dy.is_contiguous(): + dy = dy.contiguous() + + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + M, + N, + eps, + x.stride(0), + dy.stride(0), + dx.stride(0), + BLOCK_N=BLOCK_N, + ) + + return dx.view(*original_shape) + +################################################################################################################################################## + + + +import torch + +# Test the backward L2 normalization +def test_l2_norm_bwd(): + results = {} + + # Test case 1: Default case + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_1'] = dx + + # Test case 2: Different shape + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_2'] = dx + + # Test case 3: Larger tensor + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_3'] = dx + + # Test case 4: Edge case with small tensor + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_4'] = dx + + return results + +# Run the tests +result_gold = test_l2_norm_bwd() diff --git a/src/temp/l2_norm_triton1.py b/src/temp/l2_norm_triton1.py new file mode 100644 index 0000000..8bf925b --- /dev/null +++ b/src/temp/l2_norm_triton1.py @@ -0,0 +1,97 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _l2_norm_fwd_1pass_kernel(X, Y, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + # program ids + row_id = tl.program_id(0) + + # offsets + offs_n = tl.arange(0, BLOCK_N) + + # compute normalized offset + row_start = X + row_id * stride_x_row + y_row_start = Y + row_id * stride_x_row + + # compute sum of squares + var = tl.zeros([], dtype=tl.float32) + masked_offs = offs_n < N + for i in range(0, N, BLOCK_N): + offs = i + offs_n + mask = masked_offs & (offs < N) + x_ptrs = row_start + offs # assuming the tensor has stride = 1 in the last dimension + x = tl.load(x_ptrs, mask=mask, other=0.0) + var += tl.sum(x.to(tl.float32) * x.to(tl.float32), axis=0) + + # Compute rstd + rstd = tl.rsqrt(var + eps) + + # normalize and store + for i in range(0, N, BLOCK_N): + offs = i + offs_n + mask = masked_offs & (offs < N) + x_ptrs = row_start + offs + y_ptrs = y_row_start + offs + x = tl.load(x_ptrs, mask=mask, other=0.0) + x_normed = x.to(tl.float32) * rstd + tl.store(y_ptrs, x_normed.to(Y.type.element_ty), mask=mask) + + +def _l2_norm_fwd(x: torch.Tensor, eps: float = 1e-5): + original_shape = x.shape + x = x.view(-1, x.shape[-1]) + M, N = x.shape + y = torch.empty(M, N, dtype=x.dtype, device=x.device) + + element_size = x.element_size() + max_block_size = 65536 // element_size + BLOCK_N = triton.next_power_of_2(N) + if BLOCK_N > max_block_size: + BLOCK_N = triton.next_power_of_2(triton.cdiv(max_block_size, 8)) + assert N <= BLOCK_N, "Feature dimension exceeds the max block size" + + _l2_norm_fwd_1pass_kernel[(M,)]( + x, y, + x.stride(0), + N, + eps, + BLOCK_N=BLOCK_N + ) + return y.view(original_shape) + +################################################################################################################################################## + + + +import torch + +# Test the forward L2 normalization +def test_l2_norm_fwd(): + results = {} + + # Test case 1 + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + y1 = _l2_norm_fwd(x1) + results['test_case_1'] = y1 + + # Test case 2: Different batch size + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + y2 = _l2_norm_fwd(x2) + results['test_case_2'] = y2 + + # Test case 3: Different feature size + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + y3 = _l2_norm_fwd(x3) + results['test_case_3'] = y3 + + # Test case 4: Larger tensor + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + y4 = _l2_norm_fwd(x4) + results['test_case_4'] = y4 + + return results + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/matrix_transpose.py b/src/temp/matrix_transpose.py new file mode 100644 index 0000000..01dfb65 --- /dev/null +++ b/src/temp/matrix_transpose.py @@ -0,0 +1,76 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey, SIZE_M, D_HEAD): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute base pointers for this block + offs_m = pid_m * 16 + tl.arange(0, 16) + offs_n = pid_n * 16 + tl.arange(0, 16) + + # Mask to prevent out-of-bounds access + mask = (offs_m[:, None] < SIZE_M) & (offs_n[None, :] < D_HEAD) + + # Compute memory addresses + in_ptrs = M + offs_m[:, None] * matrix_stridex + offs_n[None, :] * matrix_stridey + out_ptrs = Out + offs_n[:, None] * out_stridex + offs_m[None, :] * out_stridey + + # Load and transpose + data = tl.load(in_ptrs, mask=mask) + tl.store(out_ptrs, data, mask=mask) + + +def wrapper(matrix_stridex: int, matrix_stridey: int, out_stridex: int, out_stridey: int): + # Set dimensions + SIZE_M = 512 + D_HEAD = 256 + + # Initialize tensors on device + matrix = torch.randn(SIZE_M, D_HEAD, dtype=torch.float16, device='cuda') + out = torch.zeros(D_HEAD, SIZE_M, dtype=torch.float16, device='cuda') + + # Configure grid + grid = lambda META: ( + triton.cdiv(SIZE_M, 16), + triton.cdiv(D_HEAD, 16) + ) + + # Launch kernel + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + SIZE_M, D_HEAD + ) + + return out + +################################################################################################################################################## + + + +import torch + +def test_triton_vs_torch(): + results = {} + + # 测试用例 1: 基本矩阵转置 (小矩阵) + size_m, d_head = 16, 16 + out = wrapper(size_m, d_head) + results["test_case_1"] = out.clone() + + # 测试用例 2: 非方形矩阵 + size_m, d_head = 32, 64 + out = wrapper(size_m, d_head) + results["test_case_2"] = out.clone() + + return results + + +# 运行测试 +result_gold = test_triton_vs_torch() +# print(result_gold) \ No newline at end of file diff --git a/src/temp/matrix_vector_multip.py b/src/temp/matrix_vector_multip.py new file mode 100644 index 0000000..a1597eb --- /dev/null +++ b/src/temp/matrix_vector_multip.py @@ -0,0 +1,86 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel( + A, B, C, + N, M, + stride_a0, stride_a1, + stride_b0, + stride_c0, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_n = tl.program_id(0) + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + + accumulator = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for k in range(0, M, BLOCK_M): + offs_k = k + offs_m + mask_a = (offs_n[:, None] < N) & (offs_k[None, :] < M) + a_ptrs = A + offs_n[:, None] * stride_a0 + offs_k[None, :] * stride_a1 + a = tl.load(a_ptrs, mask=mask_a, other=0.0).to(tl.float32) + + mask_b = offs_k < M + b_ptrs = B + offs_k * stride_b0 + b = tl.load(b_ptrs, mask=mask_b, other=0.0).to(tl.float32) + + accumulator += tl.sum(a * b[None, :], axis=1) + + mask_c = offs_n < N + c_ptrs = C + offs_n * stride_c0 + tl.store(c_ptrs, accumulator.to(C.type.element_ty), mask=mask_c) + + +def mv(A: torch.Tensor, B: torch.Tensor): + assert A.dim() == 2 + assert B.dim() == 1 + N, M = A.shape + assert B.shape[0] == M + + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_N = 64 + BLOCK_M = 64 + + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_N']),) + + mv_kernel[grid]( + A, B, C, + N, M, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, + ) + + return C + +################################################################################################################################################## + + + +def test_mv(): + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + triton_result_2 = mv(A, B) + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + A = torch.randn(32, 16, device='cuda') + B = torch.randn(16, device='cuda') + triton_result_3 = mv(A, B) + + return { + "test_case_2": triton_result_2, + "test_case_3": triton_result_3, + } + +result_gold = test_mv() diff --git a/src/temp/rotary_transform.py b/src/temp/rotary_transform.py new file mode 100644 index 0000000..3705695 --- /dev/null +++ b/src/temp/rotary_transform.py @@ -0,0 +1,254 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLENS, + SEQLEN_OFFSETS, + max_seqlens, + stride_outb, + stride_outh, + stride_outm, + stride_outk, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cosb, + stride_coss, + stride_cosk, + stride_sinb, + stride_sins, + stride_sink, + rotary_dim, + seqlen_offsets_ptr, + conjugate: tl.constexpr, + interleaved: tl.constexpr, + seqlen_ro: tl.constexpr, + stride_outg: tl.constexpr, + stride_xg: tl.constexpr, + max_sequence_length: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if CU_SEQLENS is not None: + cu_seqlen_batch = pid_batch + cu_seqlen_prev = tl.load(CU_SEQLENS + cu_seqlen_batch) + cu_seqlen_curr = tl.load(CU_SEQLENS + cu_seqlen_batch + 1) + seqlen = cu_seqlen_curr - cu_seqlen_prev + offset_m_start = cu_seqlen_prev + pid_m * BLOCK_M + else: + seqlen_curr = tl.load(SEQLENS + pid_batch) + seqlen = seqlen_curr + offset_m_start = pid_m * BLOCK_M + + if seqlen <= 0: + return + + offset_k = tl.arange(0, BLOCK_K) + m_offset = offset_m_start + tl.arange(0, BLOCK_M) + m_mask = m_offset < seqlen + + if rotary_dim != -1: + k_mask = offset_k < rotary_dim + else: + k_mask = offset_k < stride_outk + + start_m = m_offset[:, None] + start_k = offset_k[None, :] + + if SEQLEN_OFFSETS is not None and seqlen_offsets_ptr: + seqlen_offset = tl.load(SEQLEN_OFFSETS + pid_batch) + else: + seqlen_offset = 0 + + pos_m = start_m + seqlen_offset + pos_cos = pos_m % max_seqlens + pos_sin = pos_m % max_seqlens + + cos_ptr = COS + pos_cos[:, None] * stride_cosb + start_k * stride_cosk + sin_ptr = SIN + pos_sin[:, None] * stride_sinb + start_k * stride_sink + + cos = tl.load(cos_ptr, mask=m_mask[:, None] & k_mask[None, :]) + sin = tl.load(sin_ptr, mask=m_mask[:, None] & k_mask[None, :]) + + x_ptr0 = X + pid_batch * stride_xb + pid_head * stride_xh + start_m * stride_xm + start_k * stride_xk + x_ptr1 = X + pid_batch * stride_xb + pid_head * stride_xh + start_m * stride_xm + (start_k + 1) * stride_xk + + x0 = tl.load(x_ptr0, mask=m_mask[:, None] & k_mask[None, :]) + x1 = tl.load(x_ptr1, mask=m_mask[:, None] & k_mask[None, :]) + + if interleaved: + o_real = x0 * cos - x1 * sin + o_imag = x1 * cos + x0 * sin + if conjugate: + o_imag = -o_imag + out_ptr0 = OUT + pid_batch * stride_outb + pid_head * stride_outh + start_m * stride_outm + start_k * stride_outk + out_ptr1 = OUT + pid_batch * stride_outb + pid_head * stride_outh + start_m * stride_outm + (start_k + 1) * stride_outk + tl.store(out_ptr0, o_real, mask=m_mask[:, None] & k_mask[None, :]) + tl.store(out_ptr1, o_imag, mask=m_mask[:, None] & k_mask[None, :]) + else: + cos_mask = start_k % 2 == 0 + sin_mask = start_k % 2 == 1 + x_even = tl.where(cos_mask, x0, 0.0) + x_odd = tl.where(sin_mask, x0, 0.0) + o_real = x_even * cos[None, :] - x_odd * sin[None, :] + if conjugate: + o_imag = x_odd * cos[None, :] + x_even * sin[None, :] + else: + o_imag = x_odd * cos[None, :] + x_even * sin[None, :] + out_ptr0 = OUT + pid_batch * stride_outb + pid_head * stride_outh + start_m * stride_outm + start_k * stride_outk + tl.store(out_ptr0, tl.where(cos_mask[None, :], o_real, o_imag), mask=m_mask[:, None] & k_mask[None, :]) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: int = 0, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + interleaved: bool = False, + in_place: bool = False, + conjugate: bool = False, + seqlen_ro: int = None, +) -> torch.Tensor: + assert x.dim() >= 3 + batch = x.shape[0] + seqlen = x.shape[-2] + head = x.shape[-3] if x.dim() >= 4 else 1 + dim = x.shape[-1] + rotary_dim = cos.shape[-1] if cos is not None else dim + + if max_seqlen is None: + max_seqlen = seqlen + assert cos is not None and sin is not None + assert cos.dim() == 3 and sin.dim() == 3 + cos = cos.view(-1, max_seqlen, rotary_dim) + sin = sin.view(-1, max_seqlen, rotary_dim) + + stride_outb = x.stride(0) if x.dim() >= 3 else 0 + stride_outh = x.stride(-3) if x.dim() >= 4 else 0 + stride_outm = x.stride(-2) + stride_outk = x.stride(-1) + stride_xb = x.stride(0) if x.dim() >= 3 else 0 + stride_xh = x.stride(-3) if x.dim() >= 4 else 0 + stride_xm = x.stride(-2) + stride_xk = x.stride(-1) + stride_cosb = cos.stride(0) + stride_coss = cos.stride(1) + stride_cosk = cos.stride(2) + stride_sinb = sin.stride(0) + stride_sins = sin.stride(1) + stride_sink = sin.stride(2) + + seqlen_offsets_tensor = torch.tensor([seqlen_offsets], dtype=torch.int32, device=x.device) if isinstance(seqlen_offsets, int) else seqlen_offsets + + if in_place: + out = x + else: + out = torch.empty_like(x) + + grid = (batch, head, (seqlen + 63) // 64) + + rotary_kernel[grid]( + out, + x, + cos, + sin, + cu_seqlens, + torch.tensor([seqlen], dtype=torch.int32, device=x.device) if cu_seqlens is None else None, + seqlen_offsets_tensor, + torch.tensor([max_seqlen], dtype=torch.int32, device=x.device) if max_seqlen is not None else torch.tensor([seqlen], dtype=torch.int32, device=x.device), + stride_outb, + stride_outh, + stride_outm, + stride_outk, + stride_xb, + stride_xh, + stride_xm, + stride_xk, + stride_cosb, + stride_coss, + stride_cosk, + stride_sinb, + stride_sins, + stride_sink, + rotary_dim, + seqlen_offsets_tensor is not None, + conjugate=conjugate, + interleaved=interleaved, + seqlen_ro=seqlen_ro if seqlen_ro is not None else seqlen, + stride_outg=1, + stride_xg=1, + max_sequence_length=max_seqlen if max_seqlen is not None else seqlen, + BLOCK_M=64, + BLOCK_K=rotary_dim, + ) + + if out.dim() == 2: + out = out.unsqueeze(0) + return out + +################################################################################################################################################## + + + +import torch + +def test_apply_rotary(): + results = {} + + # Test case 1: Basic test with fixed sequence length and no interleaving + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin) + results['test_case_1'] = output.shape + + # Test case 2: Variable length sequences with interleaving + total_seqlen, nheads, headdim = 256, 4, 64 + batch = 3 + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + max_seqlen = 128 + rotary_dim = 32 + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + results['test_case_2'] = output.shape + + # Test case 3: Conjugate flag enabled + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, conjugate=True) + results['test_case_3'] = output.shape + + # Test case 4: Inplace operation + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, inplace=True) + results['test_case_4'] = output.shape + + return results + +result_gold = test_apply_rotary() diff --git a/src/temp/sin_kernel.py b/src/temp/sin_kernel.py new file mode 100644 index 0000000..1563a97 --- /dev/null +++ b/src/temp/sin_kernel.py @@ -0,0 +1,58 @@ + +import torch +import triton +import triton.language as tl + + +@triton.jit +def kernel_function(x_ptr, output_ptr, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.math.sin(x.to(tl.float32)) + tl.store(output_ptr + offsets, y, mask=mask) + + +def call_kernel(x: torch.Tensor, BLOCK_SIZE: int = 1024): + n_elements = x.numel() + output = torch.empty_like(x) + grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) + kernel_function[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return output + +################################################################################################################################################## + + + +import torch + +# Function to test the Triton kernel +def test_call_kernel(): + results = {} + + # Test case 1: Small input tensor + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + output1 = call_kernel(x1) + results['test_case_1'] = output1 + + # Test case 2: Larger input tensor + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + output2 = call_kernel(x2) + results['test_case_2'] = output2 + + # Test case 3: Edge case with zero elements + x3 = torch.tensor([], dtype=torch.float32).cuda() + output3 = call_kernel(x3) + results['test_case_3'] = output3 + + # Test case 4: Input tensor with negative values + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + output4 = call_kernel(x4) + results['test_case_4'] = output4 + + return results + +# Run the test function +result_gold = test_call_kernel() diff --git a/src/temp/triton_matmul.py b/src/temp/triton_matmul.py new file mode 100644 index 0000000..a9ecde2 --- /dev/null +++ b/src/temp/triton_matmul.py @@ -0,0 +1,130 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, # pointers + M, N, K, # shape (M, K) @ (K, N) --> (M, N) + stride_am, stride_ak, # a row/col + stride_bk, stride_bn, # b row/col + stride_cm, stride_cn, # c row/col + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_EVEN_K: tl.constexpr = 0, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if IS_EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + mask_k = offs_k[None, :] < K - k * BLOCK_SIZE_K + a = tl.load(a_ptrs, mask=mask_k, other=0.0) + b = tl.load(b_ptrs, mask=mask_k[:, None], other=0.0) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + +def matmul(a: torch.Tensor, b: torch.Tensor): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 32 + num_stages = 4 + num_warps = 8 + + if str(a.dtype) == 'torch.float16': + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 32 + num_stages = 4 + num_warps = 8 + elif 'float8' in str(a.dtype): + BLOCK_M = 128 + BLOCK_N = 128 + BLOCK_K = 128 + num_stages = 3 + num_warps = 4 + else: + BLOCK_M = 64 + BLOCK_N = 64 + BLOCK_K = 32 + num_stages = 2 + num_warps = 4 + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_M, + BLOCK_SIZE_N=BLOCK_N, + BLOCK_SIZE_K=BLOCK_K, + GROUP_SIZE_M=8, + IS_EVEN_K=K % BLOCK_K == 0, + num_stages=num_stages, + num_warps=num_warps, + ) + return c + +################################################################################################################################################## + + + +import torch + +# Test for matmul +def test_matmul(): + results = {} + M, K, N = 256, 128, 256 + + # Test case 1: torch.float16 + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + c = matmul(a, b) + results['test_case_1'] = c + + return results + +# Run all tests +result_gold = test_matmul() \ No newline at end of file diff --git a/src/utils/__pycache__/utils.cpython-312.pyc b/src/utils/__pycache__/utils.cpython-312.pyc index 5240a44343db32ccd23713863f8a830c080e8631..7e3b0da91d0e92cbfd97d0210a8ce9eded91e7a2 100644 GIT binary patch delta 20 acmeAY?h@uc&CAQh00hh8S8e33=L7&Vss$(j delta 20 acmeAY?h@uc&CAQh00aebb2f6S{