From c37493815e4d4907a0d5d022609205cc85a37586 Mon Sep 17 00:00:00 2001 From: yyyuyu99 Date: Sun, 24 Aug 2025 08:28:12 +0000 Subject: [PATCH 1/4] yuyu --- src/__init__.py | 0 src/__pycache__/args_config.cpython-312.pyc | Bin 556 -> 556 bytes src/agents/__pycache__/Base.cpython-312.pyc | Bin 4986 -> 4986 bytes .../__pycache__/Reflexion.cpython-312.pyc | Bin 2883 -> 2883 bytes .../reflexion_oneshot.cpython-312.pyc | Bin 9006 -> 9006 bytes src/agents/multi_agent_pipeline.py | 437 ++++++++++++++++++ .../.tritonbench_oneshot_config.yaml.swo | Bin 0 -> 12288 bytes .../.tritonbench_oneshot_config.yaml.swp | Bin 0 -> 12288 bytes src/configs/tritonbench_oneshot_config.yaml | 14 +- src/dataloaders/ProblemState.py | 1 + .../TB_eval/__pycache__/utils.cpython-312.pyc | Bin 13229 -> 13229 bytes src/dataloaders/TritonBench.py | 58 ++- .../__pycache__/ProblemState.cpython-312.pyc | Bin 1187 -> 1239 bytes .../__pycache__/TritonBench.cpython-312.pyc | Bin 16573 -> 17576 bytes src/main_multi_agent.py | 46 ++ src/main_reflexion_oneshot.py | 31 +- src/main_reflexion_oneshot.py.backup | 35 ++ .../__pycache__/Memory.cpython-312.pyc | Bin 1368 -> 1368 bytes src/models/__init__.py | 0 src/models/__pycache__/Base.cpython-312.pyc | Bin 631 -> 631 bytes src/models/__pycache__/KimiK2.cpython-312.pyc | Bin 2183 -> 2183 bytes src/prompts/All_Prompts.py | 85 ++++ src/prompts/Analyst_Prompt.py | 26 ++ src/prompts/Base.py | 6 + src/prompts/Baseline_Prompt.py | 56 +++ src/prompts/Executor_Prompt.py | 42 ++ src/prompts/Strategist_Prompt.py | 53 +++ .../prompt_for_generation.cpython-312.pyc | Bin 9969 -> 9969 bytes .../prompt_for_reflection.cpython-312.pyc | Bin 14265 -> 14265 bytes src/prompts/prompt_for_correction_plan.py | 31 ++ src/prompts/prompt_for_repair.py | 25 + .../__pycache__/retriever.cpython-312.pyc | Bin 3354 -> 3354 bytes src/temp/embedding_triton_kernel.py | 207 +++++++++ src/temp/flash_decode2_phi.py | 205 ++++++++ src/temp/int4_matmul.py | 245 ++++++++++ src/temp/l2_norm_bwd.py | 84 ++++ src/temp/l2_norm_triton1.py | 121 +++++ src/temp/matrix_transpose.py | 45 ++ src/temp/matrix_vector_multip.py | 42 ++ src/temp/rotary_transform.py | 291 ++++++++++++ src/temp/sin_kernel.py | 45 ++ src/temp/triton_matmul.py | 44 ++ src/test_main_copy.py | 107 +++++ src/utils/__pycache__/utils.cpython-312.pyc | Bin 2442 -> 6181 bytes src/utils/utils.py | 103 ++++- 45 files changed, 2458 insertions(+), 27 deletions(-) create mode 100644 src/__init__.py create mode 100644 src/agents/multi_agent_pipeline.py create mode 100644 src/configs/.tritonbench_oneshot_config.yaml.swo create mode 100644 src/configs/.tritonbench_oneshot_config.yaml.swp create mode 100644 src/main_multi_agent.py create mode 100644 src/main_reflexion_oneshot.py.backup create mode 100644 src/models/__init__.py create mode 100644 src/prompts/All_Prompts.py create mode 100644 src/prompts/Analyst_Prompt.py create mode 100644 src/prompts/Base.py create mode 100644 src/prompts/Baseline_Prompt.py create mode 100644 src/prompts/Executor_Prompt.py create mode 100644 src/prompts/Strategist_Prompt.py create mode 100644 src/prompts/prompt_for_correction_plan.py create mode 100644 src/prompts/prompt_for_repair.py create mode 100644 src/temp/embedding_triton_kernel.py create mode 100644 src/temp/flash_decode2_phi.py create mode 100644 src/temp/int4_matmul.py create mode 100644 src/temp/l2_norm_bwd.py create mode 100644 src/temp/l2_norm_triton1.py create mode 100644 src/temp/matrix_transpose.py create mode 100644 src/temp/matrix_vector_multip.py create mode 100644 src/temp/rotary_transform.py create mode 100644 src/temp/sin_kernel.py create mode 100644 src/temp/triton_matmul.py create mode 100644 src/test_main_copy.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/__pycache__/args_config.cpython-312.pyc b/src/__pycache__/args_config.cpython-312.pyc index ed62ea94b9178a5147963fc8bfe6e3c473ed5043..66a08cb61c064baadecac31ff1112455cdef1f4d 100644 GIT binary patch delta 20 acmZ3(vWA8GG%qg~0}$++v2r7~A`<{N&;>aF delta 20 acmZ3(vWA8GG%qg~0}vF%&DqGU$OHg15d@?F diff --git a/src/agents/__pycache__/Base.cpython-312.pyc b/src/agents/__pycache__/Base.cpython-312.pyc index f0a272db34755c9ad5ac90fc0aa6f7c474ac611d..00e8d8dc67aaed12520ef54b8365ed10db54e178 100644 GIT binary patch delta 20 acmeyR_DhZXG%qg~0}zC8uG+|*FAM-eo&|sa delta 20 acmeyR_DhZXG%qg~0}vF%&DqGEFAM-h8U^D3 diff --git a/src/agents/__pycache__/Reflexion.cpython-312.pyc b/src/agents/__pycache__/Reflexion.cpython-312.pyc index 54cab4bd97f26a4a5622127942cb5f6a38e1bf70..5b859d933af27676176b79c547e8bc94e917934c 100644 GIT binary patch delta 20 acmX>sc36!2G%qg~0}$++v2r7~1vdaa`30*0 delta 20 acmX>sc36!2G%qg~0}vF%&DqFp!3_X9It3O0 diff --git a/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc b/src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc index db3379e04a5f6150b194bba0a6cbd55a6310e256..b15797fd9423c30d0bbf1697d2decd73216c72e4 100644 GIT binary patch delta 94 zcmZ4Iw$6?3G%qg~0}%K{uF7au-N^S?j8%+R;EUeoZ(<>gjK!0^CFGgBm^SA~aI-S5 u-CQB-z|44Y^G*d}Mz&~1R^!ign=dJPG79@KGX^ux;QYz}W)vv{4FmueLK(XN delta 94 zcmZ4Iw$6?3G%qg~0}vF%&B^$xxRLL%7%MZcz!#3q-^4-~8GR>vOUN^&FmBF~;AUl< uwYfspfthjN=A8<{jBFl^tj3=sHeXWoWE8GuVhm=S!TFT|%qUU@8VCR;b{X{m diff --git a/src/agents/multi_agent_pipeline.py b/src/agents/multi_agent_pipeline.py new file mode 100644 index 0000000..6d531a6 --- /dev/null +++ b/src/agents/multi_agent_pipeline.py @@ -0,0 +1,437 @@ +import os +from tqdm import tqdm +from loguru import logger +import json +from dataclasses import asdict +from agents.Reflexion import Reflexion +from utils.utils import extract_function_signatures, clear_code, extract_function_calls, safe_force_correct_signature +from prompts import prompt_for_reflection +from memories.Memory import MemoryClassMeta +from models.Base import BaseModel +from agents.reflexion_oneshot import Reflexion_Oneshot +# --- Corrected Imports for Prompt Classes --- +from prompts.Analyst_Prompt import Analyst_Prompt +from prompts.Baseline_Prompt import Baseline_Prompt +from prompts.Strategist_Prompt import Strategist_Prompt +from prompts.Executor_Prompt import Executor_Prompt +from retrievers.retriever import BM25Retriever +from prompts import prompt_for_generation +from concurrent.futures import ThreadPoolExecutor, as_completed +import re +from utils.utils import clear_code +import os +import datetime +from tenacity import RetryError + +# Inherit from Reflexion_Oneshot to gain all its functionality +class MultiAgentPipeline(Reflexion_Oneshot): + """ + This agent augments the Reflexion framework by implementing a four-stage pipeline + for the first iteration of code generation. + """ + + def __init__(self, model: BaseModel, dataset, corpus_path, mem_file=None): + """ + Initializes the MultiAgentPipeline. + It correctly calls the parent __init__ to set up all necessary components + like BM25Retriever and memory, then adds its own pipeline-specific components. + """ + # --- Corrected super().__init__() call --- + # We must explicitly call the __init__ of our direct parent, Reflexion_Oneshot, + # to ensure all its setup logic (retrievers, memory init) is executed. + super().__init__(model, dataset, corpus_path, mem_file) + + # Initialize our custom prompt generators for the pipeline + self.analyst_prompt_generator = Analyst_Prompt() + self.baseline_prompt = Baseline_Prompt() + self.strategist_prompt = Strategist_Prompt() + self.executor_prompt = Executor_Prompt() + + # This dictionary will hold the state for our pipeline for each problem + self.pipeline_states = {} + + # --- New: Setup for pipeline run outputs --- + self.run_output_dir = os.path.join("/workspace", "pipeline_run_outputs", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) + os.makedirs(self.run_output_dir, exist_ok=True) + logger.info(f"Pipeline outputs will be saved to: {self.run_output_dir}") + + self._initialize_pipeline_states() + + def _save_pipeline_step(self, kernel_name: str, iteration: int, agent_name: str, step_type: str, content: str): + """Helper function to save pipeline intermediate outputs.""" + try: + kernel_dir = os.path.join(self.run_output_dir, kernel_name) + os.makedirs(kernel_dir, exist_ok=True) + file_name = f"iter_{iteration}_agent_{agent_name}_{step_type}.txt" + file_path = os.path.join(kernel_dir, file_name) + with open(file_path, "w") as f: + f.write(content) + except Exception as e: + logger.error(f"Failed to save pipeline step for {kernel_name}: {e}") + + + def _initialize_pipeline_states(self): + """ + Creates a parallel state management dictionary for our pipeline's needs, + without modifying the original self.memories list. + """ + logger.info("Initializing parallel pipeline states for each kernel...") + for mem in self.memories: + ps = mem.ps + self.pipeline_states[ps.filename] = { + "analysis": None, + "baseline_code": None, + "strategy_plan": [], + "current_code": None, + "best_code": None, # Will be updated based on evaluation + "best_performance": float('-inf'), + "current_strategy_idx": -1, + "iteration": 0 # Initialize iteration counter + } + logger.info(f"{len(self.pipeline_states)} pipeline states initialized.") + + # The run method, and the pipeline-specific helper methods will be added next. + # We will no longer touch memory_init, as super().__init__ handles it. + + def run(self, output_path=None, multi_thread=False, verbose=False, datalen=None, iteration_num=0, temperature=0): + data_len = datalen if datalen else len(self.dataset) + + for i in range(iteration_num): + logger.info(f"\n{'='*20} Iteration {i + 1}/{iteration_num} {'='*20}") + + # --- Generation Phase --- + logger.info("Phase 1: Generating solutions...") + # We iterate through the original self.memories list. + for mem in tqdm(self.memories[:data_len], desc=f"Generation (Iter {i+1})"): + if mem.pass_call: + continue + + if i == 0: + # --- Iteration 1: Use our full pipeline --- + self.run_full_pipeline(mem, temperature) + # After the first iteration, update the state + self.pipeline_states[mem.ps.filename]["iteration"] = 1 + else: + # --- Subsequent Iterations: Use our incremental optimizer --- + # This replaces the original generate_solution call + self.run_incremental_optimization(mem, temperature) + self.pipeline_states[mem.ps.filename]["iteration"] += 1 + + # --- Evaluation Phase (adapted from original) --- + logger.info("Phase 2: Evaluating solutions...") + for mem in tqdm(self.memories[:data_len], desc=f"Evaluation (Iter {i+1})"): + if mem.pass_call: + continue + + # The code to be tested is now in mem.ps.solution + is_pass, err_msg = self.dataset.run_single_call(mem.ps) + if not is_pass: + mem.err_msg = err_msg + else: + mem.pass_call = True + mem.err_msg = None # Clear error message on pass + logger.info(f" -> PASSED: {mem.ps.filename}") + + # --- Reflection Phase --- + logger.info("Phase 3: Generating reflections and repairs for failures...") + for mem in tqdm(self.memories[:data_len], desc=f"Reflection (Iter {i+1})"): + if not mem.pass_call and mem.err_msg: + # This is where we implement our new two-step reflection process + self.diagnose_and_repair(mem, temperature) + + # --- File Writing (adapted from original) --- + if output_path is not None: + root, extension = os.path.splitext(output_path) + iter_path = f"{root}_{i}{extension}" + logger.info(f"Writing results for iteration {i+1} to {iter_path}") + self.dataset.write_file(iter_path) + + def run_full_pipeline(self, mem, temperature): + """Runs the complete 4-agent pipeline and places the result in mem.ps.solution.""" + ps = mem.ps + pipeline_state = self.pipeline_states[ps.filename] + logger.info(f" Running full pipeline for {ps.filename}...") + + try: + # 1. Analyst + analyst_prompt_obj = self.analyst_prompt_generator.get_prompt(ps) # Now returns a list + self._save_pipeline_step(ps.filename, 1, "1_analyst", "input", json.dumps(analyst_prompt_obj, indent=2)) + analysis_str = self.model.generate(analyst_prompt_obj, temperature) + pipeline_state["analysis"] = analysis_str + self._save_pipeline_step(ps.filename, 1, "1_analyst", "output", analysis_str) + + # 2. Baseline + baseline_prompt_obj = self.baseline_prompt.get_prompt(ps, analysis_str) + self._save_pipeline_step(ps.filename, 1, "2_baseline", "input", json.dumps(baseline_prompt_obj, indent=2)) + baseline_code = self.model.generate(baseline_prompt_obj, temperature) + baseline_code = clear_code(baseline_code) + pipeline_state["baseline_code"] = baseline_code + self._save_pipeline_step(ps.filename, 1, "2_baseline", "output", baseline_code) + + # 3. Strategist + strategist_prompt_obj = self.strategist_prompt.get_prompt(ps, analysis_str, baseline_code) + self._save_pipeline_step(ps.filename, 1, "3_strategist", "input", json.dumps(strategist_prompt_obj, indent=2)) + plan_str = self.model.generate(strategist_prompt_obj, temperature) + strategies = self._parse_strategies(plan_str) + pipeline_state["strategies"] = strategies + pipeline_state["strategy_plan"] = plan_str + self._save_pipeline_step(ps.filename, 1, "3_strategist", "output", plan_str) + + # 4. Executor (First Strategy) + if strategies: + first_strategy = strategies[0] + pipeline_state["current_strategy_index"] = 0 + executor_prompt_obj = self.executor_prompt.get_prompt(baseline_code, first_strategy) + self._save_pipeline_step(ps.filename, 1, "4_executor", "input", json.dumps(executor_prompt_obj, indent=2)) + executed_code = self.model.generate(executor_prompt_obj, temperature) + executed_code = clear_code(executed_code) + + # --- NEW: Post-generation Signature Correction --- + try: + # Infer function name from filename, with special handling for known variations + base_name = ps.filename.replace(".py", "") + name_map = { + 'triton_matmul': 'matmul', 'matrix_vector_multip': 'mv', + 'sin_kernel': 'call_kernel', 'matrix_transpose': 'wrapper', + 'l2_norm_bwd': '_l2_norm_bwd' + } + func_name = name_map.get(base_name, base_name) + + baseline_code = pipeline_state.get("baseline_code") + + if baseline_code: + logger.info(f" Applying SAFE signature correction for '{func_name}'...") + corrected_code = safe_force_correct_signature(baseline_code, executed_code, func_name) + + if corrected_code == executed_code: + logger.warning(f" Signature correction for '{func_name}' was skipped (either safe or not needed).") + else: + logger.info(f" Signature for '{func_name}' was successfully corrected.") + + mem.ps.solution = corrected_code + pipeline_state["current_code"] = corrected_code + self._save_pipeline_step(ps.filename, 1, "5_corrected_executor", "output", corrected_code) + else: + logger.error(" Cannot perform signature correction: baseline_code not found.") + mem.ps.solution = executed_code + pipeline_state["current_code"] = executed_code + self._save_pipeline_step(ps.filename, 1, "4_executor", "output", executed_code) + + except Exception as e: + logger.error(f" An unexpected error occurred during signature correction for {ps.filename}: {e}") + mem.ps.solution = executed_code + pipeline_state["current_code"] = executed_code + self._save_pipeline_step(ps.filename, 1, "4_executor", "output", executed_code) + # --- End Correction --- + + else: + # If no strategies, use the baseline code + mem.ps.solution = baseline_code + pipeline_state["current_code"] = baseline_code + + except RetryError as e: + logger.error(f" API call failed for {ps.filename} after multiple retries. Skipping this kernel.") + error_message = f"API Error: The model server failed to respond.\n\nTraceback:\n{e}" + self._save_pipeline_step(ps.filename, 1, "API_ERROR", "log", error_message) + # Set a placeholder solution to indicate failure + mem.ps.solution = "# API_ERROR: Model generation failed for this kernel." + return + + return + + def run_incremental_optimization(self, mem, temperature): + """Applies the next strategy and places the result in mem.ps.solution.""" + ps = mem.ps + pipeline_state = self.pipeline_states[ps.filename] + + logger.info(f"Applying strategy {pipeline_state['current_strategy_index'] + 2}/{len(pipeline_state['strategies'])} for {ps.filename}...") + + # Move to the next strategy + pipeline_state["current_strategy_index"] += 1 + + # Get the next strategy + strategy_index = pipeline_state["current_strategy_index"] + + if strategy_index < len(pipeline_state["strategies"]): + strategy = pipeline_state["strategies"][strategy_index] + + # Get the last working code + last_code = pipeline_state.get("current_code", pipeline_state.get("baseline_code")) + if not last_code: + logger.warning(f"No previous code found for {ps.filename}, cannot apply optimization. Skipping.") + return + + try: + # Use Executor to apply the new strategy + executor_prompt_obj = self.executor_prompt.get_prompt(last_code, strategy) + self._save_pipeline_step(ps.filename, pipeline_state['iteration'] + 1, "4_executor", "input", json.dumps(executor_prompt_obj, indent=2)) + executed_code = self.model.generate(executor_prompt_obj, temperature) + executed_code = clear_code(executed_code) + + # --- NEW: Post-generation Signature Correction (Safe Version) --- + try: + base_name = ps.filename.replace(".py", "") + name_map = { + 'triton_matmul': 'matmul', 'matrix_vector_multip': 'mv', + 'sin_kernel': 'call_kernel', 'matrix_transpose': 'wrapper', + 'l2_norm_bwd': '_l2_norm_bwd' + } + func_name = name_map.get(base_name, base_name) + + baseline_code = pipeline_state.get("baseline_code") + + if baseline_code: + logger.info(f" Applying SAFE signature correction for '{func_name}' in iteration {pipeline_state['iteration'] + 1}...") + corrected_code = safe_force_correct_signature(baseline_code, executed_code, func_name) + + if corrected_code == executed_code: + logger.warning(f" Signature correction for '{func_name}' was skipped.") + else: + logger.info(f" Signature for '{func_name}' was successfully corrected.") + + mem.ps.solution = corrected_code + pipeline_state["current_code"] = corrected_code + self._save_pipeline_step(ps.filename, pipeline_state['iteration'] + 1, "5_corrected_executor", "output", corrected_code) + else: + logger.error(" Cannot perform signature correction: baseline_code not found.") + mem.ps.solution = executed_code + pipeline_state["current_code"] = executed_code + self._save_pipeline_step(ps.filename, pipeline_state['iteration'] + 1, "4_executor", "output", executed_code) + + except Exception as e: + logger.error(f" An unexpected error occurred during signature correction for {ps.filename}: {e}") + mem.ps.solution = executed_code + pipeline_state["current_code"] = executed_code + self._save_pipeline_step(ps.filename, pipeline_state['iteration'] + 1, "4_executor", "output", executed_code) + # --- End Correction --- + + except RetryError as e: + logger.error(f" API call failed during optimization for {ps.filename}. Skipping this optimization step.") + error_message = f"API Error: The model server failed to respond during optimization.\n\nTraceback:\n{e}" + self._save_pipeline_step(ps.filename, pipeline_state['iteration'] + 1, "API_ERROR", "log", error_message) + # We don't have a new solution, so we just return and let the old one be evaluated + else: + logger.info(f"All strategies for {ps.filename} have been applied.") + # If no more strategies, we let the default reflexion handle it + pass + + return + + def _parse_strategies(self, plan_str: str) -> list[str]: + """Helper to parse a numbered list of strategies from a string.""" + matches = re.findall(r"^\s*\d[\.\)-]\s*(.*)", plan_str, re.MULTILINE) + if matches: + return [match.strip() for match in matches] + return [line.strip() for line in plan_str.split('\n') if line.strip()] + + # We are overriding the original generate_solution, but keeping generate_reflexion. + def generate_solution(self, mem, temperature=0): + # This is intentionally left blank because our pipeline's run_* methods + # handle the logic that replaces this. The main `run` loop will no longer call this. + pass + + def diagnose_and_repair(self, mem, temperature): + """ + A new, two-step process for handling failures. + 1. Diagnose Failure and Create Plan: Use an expert prompt to create a high-level correction plan. + 2. Repair Code based on Plan: Use a second prompt to implement the plan and generate corrected code. + """ + ps = mem.ps + pipeline_state = self.pipeline_states[ps.filename] + iteration = pipeline_state['iteration'] + + # --- Step 1: Diagnose Failure and Create Plan --- + logger.info(f" Diagnosing failure for {ps.filename}...") + + # This is the expert diagnostician prompt string. + diagnostician_prompt_template = """ +You are an expert debugging assistant for Triton GPU kernels. You are given a Triton code snippet that has failed evaluation, along with its performance and error logs. +Your task is to create a concise, prioritized, and actionable correction plan. + +**THE ABSOLUTE LAW:** +- Your plan MUST be a bulleted list starting with `-`. +- Do NOT propose fixing the function signature. A "Code Goalkeeper" has already corrected it. Focus on the kernel's internal logic and performance tuning. + +**Input Information:** +1. **Code with Issue:** + ```python + {code} + ``` +2. **Evaluation Results:** + - **Error Type:** "{error}" + - **Error Trace/Log:** "{trace}" + +**Your Task:** Based on the evaluation results, diagnose the primary failure mode and create a correction plan. + +* **If `Error Type` is `Runtime Error`:** Analyze the traceback. Pinpoint the likely cause (e.g., shape mismatch in `tl.dot`, memory access error). +* **If `Error Type` is `Correctness Error`:** This is a logic error. Suspect issues in pointer arithmetic, accumulator updates, or incorrect masking. +* **If `Error Type` is `Success` but performance is low:** This is a tuning problem. Suggest changes to block sizes, num_warps, etc. + +**Correction Plan:** +""" + + diagnose_prompt = diagnostician_prompt_template.format( + code=ps.solution, + error=mem.err_msg.get('error_type', 'Unknown'), + trace=mem.err_msg.get('error_log', 'No log available'), + instruction=ps.instruction + ) + + correction_plan = self.model.generate([{"role": "user", "content": diagnose_prompt}], temperature) + pipeline_state["correction_plan"] = correction_plan + self._save_pipeline_step(ps.filename, iteration, "6_correction_plan", "output", correction_plan) + logger.info(f" Correction plan for {ps.filename} generated.") + + # --- Step 2: Repair Code based on Plan --- + logger.info(f" Repairing code for {ps.filename} based on the new plan...") + from prompts.prompt_for_repair import prompt as repair_prompt_template + + repair_prompt = repair_prompt_template.format( + solution=ps.solution, + test_result=mem.err_msg.get('error_log', 'No log available'), + reflection=correction_plan # Feed the plan into the repair prompt + ) + + repaired_code = self.model.generate([{"role": "user", "content": repair_prompt}], temperature) + repaired_code = clear_code(repaired_code) + + # --- Step 3: Apply Goalkeeper to the repaired code --- + logger.info(f" Applying Goalkeeper to the repaired code for {ps.filename}...") + + base_name = ps.filename.replace(".py", "") + name_map = { + 'triton_matmul': 'matmul', 'matrix_vector_multip': 'mv', + 'sin_kernel': 'call_kernel', 'matrix_transpose': 'wrapper', + 'l2_norm_bwd': '_l2_norm_bwd' + } + func_name = name_map.get(base_name, base_name) + baseline_code = pipeline_state.get("baseline_code") + + if baseline_code: + final_code = safe_force_correct_signature(baseline_code, repaired_code, func_name) + else: + logger.warning(f" Baseline code not found for {ps.filename}, Goalkeeper skipped on repaired code.") + final_code = repaired_code + + # Update memory and state with the new, repaired, and corrected solution + mem.ps.solution = final_code + pipeline_state["current_code"] = final_code + self._save_pipeline_step(ps.filename, iteration, "7_repaired_code", "output", final_code) + logger.info(f" Code for {ps.filename} has been repaired and corrected for the next iteration.") + + # We keep the original generate_reflexion, but it will no longer be called by our main loop. + def generate_reflexion(self, mem, temperature): + if mem.pass_call: + return + reflect_txt = prompt_for_reflection.prompt.format( + problem=mem.ps.instruction, + solution=mem.ps.solution, + test_result=mem.err_msg + ) + reflect_msg = [ + { + "role": "user", + "content": reflect_txt + } + ] + mem.reflection = self.model.generate(reflect_msg, temperature=temperature) \ No newline at end of file diff --git a/src/configs/.tritonbench_oneshot_config.yaml.swo b/src/configs/.tritonbench_oneshot_config.yaml.swo new file mode 100644 index 0000000000000000000000000000000000000000..a85df2b523bb09f7c37b0bfd2ad74b568a331d53 GIT binary patch literal 12288 zcmeI%!Aiq09LMq0yKJ}@nDe}4TR|xJ0vtSc@wxx7?{7C^6wAu{&}{!VbWs#$nP%5{_Skh+ zzW8cemr5IHmENjgy_>4FaY5FVmfE{wv#w>RHo?2xIBi$nncDiW`svMH)$>!=4fGH| zpk3fxq|3#;>0eDQ$D@n;_M>zVKmY**5I_I{1Q0-=cLJeI#4e^f`%RX#ti20PAOsLV z009ILKmY**5I_I{1Wrgm#UJ~xfe_zuHvj*B@Bd?cLJmYl009ILKmY**5I_I{1Q0-= G&jJS(!7h6M literal 0 HcmV?d00001 diff --git a/src/configs/.tritonbench_oneshot_config.yaml.swp b/src/configs/.tritonbench_oneshot_config.yaml.swp new file mode 100644 index 0000000000000000000000000000000000000000..78fd89e48fb69467ca243baded1e876057a568d5 GIT binary patch literal 12288 zcmeI2zmMcJ6vust3OG1HLxW_p#WuH-nK?kWBdP;*I^A7&M}xp}l9!2D$F^)|XF~&j z1`Ze?K}W@Z0SyHd2wEhbGudH+PK#J7;6?gQ6#saB@8joaM)QjKn|JQ>kDZdttm^`O|%Tq#F?+0z`la z5CI}U1c(3;AOb{y2oM1x@Gm4F@gMtZPcZf|D*gX|{`>#eXBhhd^)2cv)R(BwQLmwX zdYZ8hP)pP|R~d__cTjJjUPnETdK~r3Q;dC&`VRH;lZ>6AN>qVbp?Hu8Tytf_SNE1_I z1x2^!u>tP9Eg=Nf$}^!Nco~&7EB-RGU36U(4l7YtJliWz^-ya#!%@2e)V-qA*}K7- z3|$#)<6ORJzytRX(BZ&0N^4#K57BzSuh&ufHAHa&-hd7(ZaS@TcJj59h6e|ws(Un| z%7%}lCXrx;Y;qcS(WmgEG8osy=xi{`TaYGjiOVnMJ4?0od{=HeOKwl5uJ~J7tz~FL z1*NTEDO{swuFtOXL%kHn`c`a%UtsIub)G2`UlOg1xC`5HSx{ySZ*163m<~EOR&1(q z6x#6W46UQr&ITBpp)H`QFvxM;scCP~|9(3OLu;qDPN78VVn`5|y57c8N8AbN8YZNhF{RIPe%$e>+bG1m)(0^WJRPfR7(r`qO> zIqWD>R2Rq7dR53RR%X}8;)y;|r>MyajDIU#>=OTekrd1}kkT#QPHwV6*%ik-cTmDam1T+VykYD|^V%Z(mSkum`fc9~A(N>&)#A6Y zPP}zOYOFCAr;9(l9v8!XI^T%DiD+dw;<9M&oM*T1*BRJCAw1 qn9W&>r&kMwugdH$Ch5`L<lmEpiNk?#a!&E$zpk`q4!NTe`Daiy}QGD`wwQ(02jZGdzo zlP1Sx7e*)kTiogSIVq`m@yYoqshZrAI~aqRiufnrU<_sSovg=XFMXLs>4u_egZBi9 z8Ojr7udt{Vi2?N&aRP~n2gD{%U^?$301^-c5keqB7(|GG2tE+O0wSb=#4VPj{QR6E z84zC-M1Tzv2eCj-0~36cS2KICfh8m-|75l*0L!4NzQtjao1apelWJF_0My3_#KjIk z;sY}yBjbGrwTldD_Zc)kurV-ke`jK2RQtdHBtA0(xgWq>0k8xIqx1&`AOV*6$OL48 G-2(s%jYo+rktk3@5hswCctC9O5vKD#{2&1V5FrR6gg}Hah~NVeEFeM(NZevc%FoX! zk_PcbKm^zzF%U}vM(|F)&+NelmXMgN%VH%DmO)j0i^C>2KczG$)vic>avO`H0V|{O j2L>SVnVEr!`vVh*%LinAXX0R#{=fhvz!D!pT(F%0h^sPu diff --git a/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc b/src/dataloaders/__pycache__/TritonBench.cpython-312.pyc index ad4c954d6ce071bdc2f2335b6c38ed1a4e2e0521..f48cac00a19aafd9c8e0fbdef462b9af329d4d32 100644 GIT binary patch delta 3086 zcmaJ@Yiv}<6`t9>ci+3-_1f##+Fn1_W{vFygKZwh5J+(-)R5Xxx>d7W@4dDc?mja2 znpf5`q_j#~D&fu#NBPn4t1=j=R4#>-MyZsvM2RXH*(vrVRjP)nqEb^A6Inv4+Bvg+ zpsMP%{LMLY&YYS3&Y8!%7xCoxvFBsA+kx2k?+>mIAJ{SODaC&pZ+vx3Mlx2tF)wSa zsxMU-D@^%gevEiOisf5#%*=F~T9hh|6{iBRz>2*jR?7PAYFVm0R=(0#5v$;1gR!6q z^&r`C3CYeW6HJXztnwU=xE_Pp8^MN^jNd)h95ETfp`@l8!i&j-&NRC$=~6uecxiizhnAJ14rvyRNz>Q?vfA8CTa5 zn*!08=}p;b2%4e}k}B3gssTa_49`$Tm6dcnrRXG?&`2#<>H+Ei!rX^(IHc(DQwm8d zsz!pWQ*$y1t$(C~>GXT%&?}{29bs?*J*juucl|Snf}3tSXDhni5aw*93uP^n$=R}P zZwQw?3!$wu?O&M8W9KcDd0W@ac_P!Lb7+w=FQt^%vysBXwKS+-=Zg~1E$M= zhK7D*{hL`No9S&^^_byYNnRo&E2M)3Bik56%w#jjYaOJMsfIO66d8_~LPN-8m2^J; z!bvTYHUu>z$%Y81@d_~nR_HPbK%5c)SXI7kAisvFWmS5}UeUlq+xTusX3`Na*|0VQ zYMe9zw9y~g!*y$mGVq|`e{SjVL?*2(X`OV^ckN9AutMJs7n?}Kj~%6!eN5Hw zzz&uxo069K`NflA*izCdmPVrSv=~+E&#<#fAzfT9 zl1EZmjl2Mg0CL36);H(fE|MPBN4>s%_{?;_uTt>7%yiNR&=0VMUMxNs;pK4*^z|Cy z^TAa4)f@+@lYUto!zbvGz)n0&F9%|nPX9IVz5~BZr)na&g?>^KW!sZ^_Dwr#k7C?U zPt?6+B1QCiU3%=lX5sZ`(aveQk}|~9DaJs)I{B%^DcB34;8zit-mt3dv4Xv= z(UYpiF9Fg+=ju1%{ZtG$;WoM{+=?SK9{vu#HT`+`wzY}#H^!nQkF)!h-8rJ3(Iu70 ztScgrLDrBVfMNPUOH3&O(a+$#X-6m$$Zteo+OTLDqAN1_E_fUTSm&M>fdF21B|XBb zg56al2|gzQP64O@DS%pfv9+T5RS-E%wgE{~pqLEHNR~Z9l!%8_!`;?0iw-_L)M)K; z4ug0W;5@B~9>o8ePDOtsn5r?2ZoN=)6AV9RpuwXe6ymo6`2~HgquCU~^y$`G;TAK{ zA8t8PZRy{C=RkpAH~ZZiJFl6l@bn)#e=4vt z9qp>>gG+P2A@q|Gh3|x3wp-R~(+;Egy~=4`PYRz32wpuLe-ED905tkV*HA@EB`brf%*zuWZ>{5AT^?l!zYA9pv6ZG^!M0EYnvAb^uU8QC>H zoRP_E(ES#`4S%Kuv@B_cDJ%1b?g$wn9xw<)YG%MN5rr?2+@$e6rH(I{g=G_Ihv~&Vk=BRE i`5CHY_z?L%LmM6<{~u8qSMAF-O~6O-6T}<-X7~?-(Ge8@ delta 2199 zcmaKtYfMx}6vyxE-hDsVg#{K^V0kR;rO0E!0#X$L!CIBT|}Q)ep=+Z9YV6VRMu0@`^63^UBI z6iUc@24fHPsl2FL{i07i>D!{aDd_jvy#I&Do?ZbwDwe9>1lSxoO0GdvoPgq09%?Wss>w-Dy{+R zNEK9owXfTsib^kp&xfzIeqz2`)P0)$L0hm&Gc$fAaP8Al3+>fgl;;dgkeN~@&>3{1 zj%?EWj6Nk+?J~KqaVQSukzPv{Y1Ua(`hYH>C$H#AGV}@dhGCvf=bAE*XE@(#w3gZHjg%_HLh6`*c;~gwOvh_0vQS8B!l%dCPR)fc|m3Q3s|0y zS=Pp^>tl5FGBT^LMl+Z1>yeVhW#m^waUpio09-}EDRCM2&pR+j>W$gk zYt9%~jEuE-$>LhHxKcLfUrL2FbES;QbFH+KOXdrWyAhMwi59ILPab)g%!U;^V<}qD z7CsT+JRjzCZBu^5?x^iHD8JsO>eg!R7*yTaT3&L?kZY7oov|&*|v?_SjmK5*j5lwhpvExnL_~ zEgg_#C!h=9BG2r*IoJ>01(2__0X%s%SSY=h9`HDFIByvBk?-^P@j@s@Z(Ti^im0^7O@~WI0!-e$&lwI!w9(J zfC#zZ*^N@k-#oWj#S?~n)N~^E3JhGO5a4gX=}i9&=g8xhYQD*5ki1ISD|cl?1&q_N$Dz9+%01Q9^gFj3&(Erb|SO7<% z`#r!#z-_<+Ko+|YJo$k%fCqpVROqER#jIGE(u!1JNy=Hinq^b%LaWtoRnfzKgjDWT zb-t= str: + # This method is for compatibility with the new _save_pipeline_step which expects a string + prompt_list = self.get_prompt(ps, analysis_json) + # We'll just serialize the whole prompt structure for saving + return json.dumps(prompt_list, indent=2) + + def get_prompt(self, ps, analysis_json) -> list: + # The analysis_json might be a string, ensure it's formatted nicely for the prompt. + try: + if isinstance(analysis_json, dict): + return analysis_json + elif isinstance(analysis_json, str): + return json.loads(analysis_json) + else: + return [] + except json.JSONDecodeError: + return [] + +from .Base import BasePrompt + +class Analyst_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt_str(self, ps) -> str: + # This method is for compatibility with the new _save_pipeline_step which expects a string + prompt_list = self.get_prompt(ps) + return json.dumps(prompt_list, indent=2) + + def get_prompt(self, ps) -> list: + return [ + { + "role": "user", + "content": "You are an analyst. Your task is to analyze the provided data and provide insights. Please provide a detailed analysis of the data, including any patterns, trends, or anomalies you observe. Do not make any assumptions or guesses; only provide factual information based on the data." + } + ] + +from .Base import BasePrompt + +class Executor_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt_str(self, baseline_code, optimization_strategy) -> str: + # This method is for compatibility with the new _save_pipeline_step which expects a string + prompt_list = self.get_prompt(baseline_code, optimization_strategy) + return json.dumps(prompt_list, indent=2) + + def get_prompt(self, baseline_code, optimization_strategy) -> list: + return [ + { + "role": "user", + "content": "You are an executor. Your task is to execute the provided baseline code and optimize it based on the optimization strategy. Please provide a step-by-step explanation of the execution process and the optimization steps taken. Do not make any assumptions or guesses; only provide factual information based on the code and strategy." + } + ] + +from .Base import BasePrompt +import json + +class Strategist_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt_str(self, ps, analysis_json, baseline_code) -> str: + # This method is for compatibility with the new _save_pipeline_step which expects a string + prompt_list = self.get_prompt(ps, analysis_json, baseline_code) + return json.dumps(prompt_list, indent=2) + + def get_prompt(self, ps, analysis_json, baseline_code) -> list: + try: + if isinstance(analysis_json, dict): + return analysis_json + elif isinstance(analysis_json, str): + return json.loads(analysis_json) + else: + return [] + except json.JSONDecodeError: + return [] diff --git a/src/prompts/Analyst_Prompt.py b/src/prompts/Analyst_Prompt.py new file mode 100644 index 0000000..188df4c --- /dev/null +++ b/src/prompts/Analyst_Prompt.py @@ -0,0 +1,26 @@ +from .Base import BasePrompt + +class Analyst_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt(self, ps) -> list: + # Reverting to the original and correct list-based format. + return [ + { + "role": "system", + "content": "You are a top-tier Triton code analysis expert. Please read the following task description and strictly summarize its core computational type, key performance-affecting parameters, and at least three potential optimization directions in a JSON format. Do not add any explanatory text outside of the JSON structure." + }, + { + "role": "user", + "content": f"""Task Description: +{ps.instruction} + +Please provide the analysis in the following JSON structure: +{{ + "type": "", + "key_parameters": ["", "", ...], + "optimization_hints": ["", "", ""] +}}""" + } + ] diff --git a/src/prompts/Base.py b/src/prompts/Base.py new file mode 100644 index 0000000..aed2fb1 --- /dev/null +++ b/src/prompts/Base.py @@ -0,0 +1,6 @@ +class BasePrompt: + def __init__(self): + pass + + def get_prompt(self) -> str: + raise NotImplementedError diff --git a/src/prompts/Baseline_Prompt.py b/src/prompts/Baseline_Prompt.py new file mode 100644 index 0000000..6d95882 --- /dev/null +++ b/src/prompts/Baseline_Prompt.py @@ -0,0 +1,56 @@ +from .Base import BasePrompt +import json + +class Baseline_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt(self, ps, analysis_json) -> list: + # The analysis_json might be a string, ensure it's formatted nicely for the prompt. + try: + # If it's a dict, dump it to a string + if isinstance(analysis_json, dict): + analysis_str = json.dumps(analysis_json, indent=2) + else: + # If it's already a string, try to parse and re-format it to be safe + parsed_json = json.loads(analysis_json) + analysis_str = json.dumps(parsed_json, indent=2) + except (json.JSONDecodeError, TypeError): + analysis_str = str(analysis_json) # Fallback to plain string representation + + return [ + { + "role": "system", + "content": """You are an expert Python programmer specializing in Triton kernels for AMD GPUs (ROCm). +Your task is to write a simple, correct, and easy-to-read Triton kernel. +**DO NOT focus on performance optimization at this stage.** Your only goal is to generate a functionally correct baseline implementation. +""" + }, + { + "role": "user", + "content": f"""**Task Description:** +{ps.instruction} + +**Task Analysis:** +```json +{analysis_str} +``` + +**Test Code (for context on how the function will be called):** +```python +{ps.test_code} +``` + +**Output Requirements:** +1. **Signature Matching:** Your generated function's signature (name and parameters) **MUST EXACTLY MATCH** how it is called in the provided `test_code`. Analyze the test code carefully to determine the correct signature. +2. **AMD Compatibility:** Ensure the code is compatible with AMD GPUs and ROCm. Do not use CUDA-specific features. +3. **Complete & Simple Code:** Generate a single, complete Python code block. The logic should be as straightforward as possible. +4. **Use `tl.dot` for Matrix Multiplication:** For matrix multiplication operations (matrix-matrix or matrix-vector), you **MUST** use the `tl.dot` instruction. Do not implement it manually with element-wise multiplication and summation. +5. **Basic Triton Kernel:** Implement the core logic in a `@triton.jit` function. +6. **Imports:** Include necessary imports like `torch`, `triton`, and `triton.language as tl`. +7. **No Advanced Optimization:** + * **DO NOT** use `triton.autotune`. + * **DO NOT** implement complex tiling or shared memory strategies unless absolutely necessary for correctness. +""" + } + ] diff --git a/src/prompts/Executor_Prompt.py b/src/prompts/Executor_Prompt.py new file mode 100644 index 0000000..f2d5eaa --- /dev/null +++ b/src/prompts/Executor_Prompt.py @@ -0,0 +1,42 @@ +from .Base import BasePrompt + +class Executor_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt(self, baseline_code, optimization_strategy) -> list: + return [ + { + "role": "system", + "content": """You are a world-class expert in GPU kernel optimization, specializing in the Triton language for AMD GPUs. +You will be given the full source code of a Python file containing a baseline Triton kernel and a single, specific optimization strategy. +Your task is to apply **ONLY that single strategy** to the `@triton.jit` kernel function, while leaving the rest of the file completely untouched. +""" + }, + { + "role": "user", + "content": f"""**Full Original Source Code:** +```python +{baseline_code} +``` + +**Optimization Strategy to Apply:** +"{optimization_strategy}" + +**THE ABSOLUTE LAW (NON-NEGOTIABLE CORE DIRECTIVE):** + +Your task is to **replicate the entire original Python file** with surgical precision. The ONLY part of the file you are allowed to modify is the internal implementation of the `@triton.jit` kernel function based on the strategy. + +- **PRESERVE ALL OTHER CODE:** You **MUST** keep all other functions (wrapper functions, helpers, tests) and all other code (imports, comments) completely untouched. +- **NEVER MODIFY SIGNATURES:** You **MUST NOT** change the function signature of ANY function in the file. +- **YOUR SCOPE OF WORK:** Apply the optimization strategy *only* to the code inside the Triton kernel. The wrapper function's logic should generally remain the same unless the strategy requires a change (e.g., modifying grid computation). + +**Final Verification Checklist:** +- [ ] Have I copied the *entire* original file content? +- [ ] Does every function signature in my output *exactly* match the original? +- [ ] Have I only modified the internal logic of the `@triton.jit` kernel according to the strategy? + +**Now, generate the complete, correct, and surgically optimized Python file.** +""" + } + ] diff --git a/src/prompts/Strategist_Prompt.py b/src/prompts/Strategist_Prompt.py new file mode 100644 index 0000000..7fbaab5 --- /dev/null +++ b/src/prompts/Strategist_Prompt.py @@ -0,0 +1,53 @@ +from .Base import BasePrompt +import json + +class Strategist_Prompt(BasePrompt): + def __init__(self): + super().__init__() + + def get_prompt(self, ps, analysis_json, baseline_code) -> list: + # The analysis_json might be a string, ensure it's formatted nicely for the prompt. + try: + if isinstance(analysis_json, dict): + analysis_str = json.dumps(analysis_json, indent=2) + else: + parsed_json = json.loads(analysis_json) + analysis_str = json.dumps(parsed_json, indent=2) + except (json.JSONDecodeError, TypeError): + analysis_str = str(analysis_json) # Fallback to plain string representation + + return [ + { + "role": "system", + "content": """You are a world-class expert in GPU kernel optimization, specializing in the Triton language for AMD GPUs. +Your task is to analyze a baseline Triton kernel and propose a prioritized list of actionable, text-based optimization strategies. Your strategies must be grounded in specific code changes. +**DO NOT generate full code blocks.** Your output should be a numbered list of strategies, each with a brief hint about its implementation.""" + }, + { + "role": "user", + "content": f"""**Task Description:** +{ps.instruction} + +**Task Analysis:** +```json +{analysis_str} +``` + +**Baseline Code:** +```python +{baseline_code} +``` + +**Your Task:** +Based on all the provided information, generate a numbered list of at least three specific, actionable optimization strategies to improve the performance of the 'Baseline Code'. + +For each strategy, you **MUST** also briefly mention the key Triton functions (e.g., `tl.make_block_ptr`, `tl.load` with a mask) or the main code structure change (e.g., 'a new loop outside the main K-loop') that would be involved in its implementation. This provides a clear guide for the next agent. + +For example: +1. Introduce `triton.autotune` to find the best block sizes. Implementation hint: Decorate the kernel with `@triton.autotune(configs=[...], key=[...])`. +2. Improve data loading for matrix 'B' by using 2D blocking. Implementation hint: Use `tl.make_block_ptr` and `tl.load` to handle block-based loading. + +**Optimization Strategies:** +""" + } + ] diff --git a/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc index 29e23cc138d0482512bb0b41a3380383a96f3f73..de8c86bc0aab01f65581fd6d41fb027d02ff61bb 100644 GIT binary patch delta 20 acmez9`_Y&CG%qg~0}$++v2r8#Gc^EC00%w* delta 20 acmez9`_Y&CG%qg~0}vF%&DqHPObq}^K?bP+ diff --git a/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc b/src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc index 71813cab567d2f4251c15a54b2db2ae822f7f37f..71b5587f03c309e7241f0d212a0c062a6782cc85 100644 GIT binary patch delta 20 acmdm)zcZivG%qg~0}$++v2r8#I&%O@83yYB delta 20 acmdm)zcZivG%qg~0}vF%&DqGk&Kv+nS_W1C diff --git a/src/prompts/prompt_for_correction_plan.py b/src/prompts/prompt_for_correction_plan.py new file mode 100644 index 0000000..ef4ecaa --- /dev/null +++ b/src/prompts/prompt_for_correction_plan.py @@ -0,0 +1,31 @@ +prompt = """You are an expert code analyst. Your task is to analyze failed code and its test results to produce a high-level, step-by-step plan for correction. DO NOT write the code yourself. + +**Original Problem Description:** +``` +{problem} +``` + +**Failed Code:** +```python +{solution} +``` + +**Test Failure Log:** +``` +{test_result} +``` + +**Your Task:** +Analyze the failure and create a concise, high-level, step-by-step plan to fix the code. +Your plan **MUST** prioritize fixing the root cause of the failure in the following order: + +1. **Signature & Calling Errors:** First, check if the `Test Failure Log` indicates a `TypeError`, `AttributeError`, `Call Status: False`, or any error related to mismatched function arguments or names. If so, your primary suggestion **MUST** be to meticulously correct the function signatures (name, parameters, order, defaults) to exactly match the original problem's requirements. + +2. **Runtime & Environment Errors:** If the signature appears correct but the code fails during execution (e.g., `Exec Status: False`, CUDA/HIP errors, memory issues), analyze the code logic to find the bug. Pay special attention to hardcoded device names like `'cuda'`. Your suggestion should focus on fixing the specific runtime error. + +3. **Logic & Correctness Issues:** Only after the above are addressed, suggest fixes for algorithmic errors or incorrect outputs. + +Output **only** the correction plan as a numbered list. + +**Correction Plan:** +""" diff --git a/src/prompts/prompt_for_repair.py b/src/prompts/prompt_for_repair.py new file mode 100644 index 0000000..53ae7de --- /dev/null +++ b/src/prompts/prompt_for_repair.py @@ -0,0 +1,25 @@ +prompt = """You are a code repair expert. Your task is to analyze failed code, understand the error, and generate a corrected version based on a high-level plan. + +**Failed Code:** +```python +{solution} +``` + +**Test Failure Log:** +``` +{test_result} +``` + +**High-level Correction Plan:** +```markdown +{reflection} +``` + +**Your Task:** +1. Read the **Failed Code**, **Test Failure Log**, and **Correction Plan** carefully. +2. Implement the corrections described in the plan. +3. Ensure the new code is complete, syntactically correct, and directly addresses the error. +4. Output **only** the complete, corrected Python code block. Do not add any explanations or text outside the code block. + +**Corrected Code:** +""" \ No newline at end of file diff --git a/src/retrievers/__pycache__/retriever.cpython-312.pyc b/src/retrievers/__pycache__/retriever.cpython-312.pyc index dc463e336802936d3187aecc27bdae9c5fdf2a5e..c7a0e0aa4781d68fcbbd762f07a157b54bac3736 100644 GIT binary patch delta 20 acmbOwHA{;7G%qg~0}$++v2r6fKQ90{BLzAD delta 20 acmbOwHA{;7G%qg~0}vF%&DqG!&kF!FWCW!E diff --git a/src/temp/embedding_triton_kernel.py b/src/temp/embedding_triton_kernel.py new file mode 100644 index 0000000..f5bf605 --- /dev/null +++ b/src/temp/embedding_triton_kernel.py @@ -0,0 +1,207 @@ + + + +import triton +import triton.language as tl +import torch + +@triton.jit +def embedding_kernel( + input_ids_ptr, + weight_ptr, + out_ptr, + vob_start_id, + vob_end_id, + stride_weight, + stride_out, + NUM_SEQS, + NUM_TOKENS_PER_SEQ, + embedding_dim, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_NN: tl.constexpr +): + pid_0 = tl.program_id(0) # sequence index + pid_1 = tl.program_id(1) # token index with block stride + + # Calculate mask bounds + seq_mask = pid_0 < NUM_SEQS + token_start = pid_1 * BLOCK_NN + token_mask = token_start + tl.arange(0, BLOCK_NN) + token_mask = token_mask < NUM_TOKENS_PER_SEQ + full_mask = seq_mask & token_mask + + # Compute base addresses + seq_offset = pid_0 * NUM_TOKENS_PER_SEQ * stride_out + token_offset = token_start * stride_out + base_out_ptr = out_ptr + seq_offset + token_offset + base_ids_ptr = input_ids_ptr + pid_0 * NUM_TOKENS_PER_SEQ + token_offset // stride_out + + # Load token IDs with mask + ids = tl.load(base_ids_ptr + tl.arange(0, BLOCK_NN), mask=full_mask, other=vob_start_id - 1) + mask_ids = (ids >= vob_start_id) & (ids < vob_end_id) & full_mask + + # Vectorize embedding loads/stores by processing blocks of embedding_dim with BLOCK_DMODEL granularity + for d in range(0, embedding_dim, BLOCK_DMODEL): + # Create vectorized weight index + weight_vec_ptr = weight_ptr + ids * stride_weight + d + + # Load vectorized embedding data + weight_vec = tl.load(weight_vec_ptr, mask=mask_ids, other=0.0) + + # Compute output pointer with vectorized store + out_ptr_vec = base_out_ptr + d + tl.store(out_ptr_vec, weight_vec, mask=full_mask) + +def embedding(input_ids, weight, vob_start_id, vob_end_id, out=None): + """Triton-accelerated embedding lookup function.""" + if out is None: + out = torch.empty( + input_ids.shape[0], input_ids.shape[1], weight.shape[1], + device=input_ids.device, dtype=weight.dtype + ) + + NUM_SEQS, NUM_TOKENS_PER_SEQ = input_ids.shape + embedding_dim = weight.shape[1] + + # Constants optimized from analysis + BLOCK_DMODEL = triton.next_power_of_2(embedding_dim) + BLOCK_N = 64 + BLOCK_NN = 1 + + # Launch kernel grid + grid = lambda META: (NUM_SEQS, triton.cdiv(NUM_TOKENS_PER_SEQ, META['BLOCK_NN'])) + + embedding_kernel[grid]( + input_ids, weight, out, + vob_start_id, vob_end_id, + weight.stride(0), + out.stride(1), + NUM_SEQS, + NUM_TOKENS_PER_SEQ, + embedding_dim, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + BLOCK_NN=BLOCK_NN + ) + return out + +def test_embedding(): + # 参数定义 + vocab_size = 1000 # 词汇表大小 + embedding_dim = 512 # 嵌入维度 + sequence_length = 128 # 输入序列长度 + vob_start_id = 10 # 词汇表起始 ID + vob_end_id = 1000 # 词汇表结束 ID + + # 创建测试输入张量 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + + # 调用嵌入函数 + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + # 保存结果 + results = {} + results['test_case_1'] = out.clone() + + # 测试不同的输入 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_2'] = out.clone() + + # 测试不同的词汇表范围 + vob_start_id = 0 + vob_end_id = 500 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_3'] = out.clone() + + # 测试不同的嵌入维度 + embedding_dim = 256 + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_4'] = out.clone() + + return results + +result_gold = test_embedding() + +################################################################################################################################################## + + + +import torch + +def test_embedding(): + # 参数定义 + vocab_size = 1000 # 词汇表大小 + embedding_dim = 512 # 嵌入维度 + sequence_length = 128 # 输入序列长度 + vob_start_id = 10 # 词汇表起始 ID + vob_end_id = 1000 # 词汇表结束 ID + + # 创建测试输入张量 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + + # 调用嵌入函数 + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + + # 保存结果 + results = {} + results['test_case_1'] = out.clone() + + # 测试不同的输入 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_2'] = out.clone() + + # 测试不同的词汇表范围 + vob_start_id = 0 + vob_end_id = 500 + input_ids = torch.randint( + vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_3'] = out.clone() + + # 测试不同的嵌入维度 + embedding_dim = 256 + weight = torch.randn( + vocab_size, embedding_dim, dtype=torch.float32, device='cuda' + ) + out = torch.zeros( + sequence_length, embedding_dim, dtype=torch.float32, device='cuda' + ) + embedding(input_ids, weight, vob_start_id, vob_end_id, out) + results['test_case_4'] = out.clone() + + return results + +result_gold = test_embedding() diff --git a/src/temp/flash_decode2_phi.py b/src/temp/flash_decode2_phi.py new file mode 100644 index 0000000..2d0d4e5 --- /dev/null +++ b/src/temp/flash_decode2_phi.py @@ -0,0 +1,205 @@ + + +import triton +import triton.language as tl +import torch + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, # int32[B] number of valid tokens per batch + Mid_O, # float[B, H, Sb, D] partial outputs per block + Mid_O_LogExpSum, # float[B, H, Sb] log-sum-exp of the logits per block + Out, # float[B, H, D] final output + stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, + stride_ls_oh, stride_ls_os, + stride_out_ob, stride_out_oh, stride_out_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr +): + pid_b = tl.program_id(0) # batch dimension + pid_h = tl.program_id(1) # head dimension + + # ------------------------------------------------------------------ + # Offsets for the first element of the current (batch, head) slice. + # ------------------------------------------------------------------ + mid_o_ptr = Mid_O + pid_b * stride_mid_ob + pid_h * stride_mid_oh + mid_ls_ptr = Mid_O_LogExpSum + pid_b * stride_mid_ob + pid_h * stride_mid_oh + + # Load actual sequence length for this batch + seq_len = tl.load(B_Seqlen + pid_b) + block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ # ceil division + + if block_n_size == 0: + return + + # Init per-thread accumulators + max_logic = -float('inf') + sum_exp = 0.0 + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for bn in range(0, block_n_size): + # Load the log-sum-exp for this block + block_lse = tl.load(mid_ls_ptr + bn * stride_ls_os) + + # Update running range reduction + new_max = tl.maximum(max_logic, block_lse) + scale = tl.exp(max_logic - new_max) + new_scale = tl.exp(block_lse - new_max) + + # Scale prior accumulation + sum_exp = sum_exp * scale + new_scale + acc = acc * scale + + # Load partial output + tv = tl.load(mid_o_ptr + bn * stride_mid_os + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_mid_od) + + acc = acc + tv * new_scale # broadcasted multiply over BLOCK_DMODEL + max_logic = new_max + + # Normalize + inv_sum_exp = 1.0 / sum_exp + acc = acc * inv_sum_exp + + # Store final result + out_ptr = Out + pid_b * stride_out_ob + pid_h * stride_out_oh + tl.store(out_ptr + tl.arange(0, BLOCK_DMODEL) * stride_out_od, + acc) + + +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq: int): + """ + Top-level wrapper that launches the Triton kernel above. + mid_out: shape (B, H, Sb, D) + mid_out_logexpsum: shape (B, H, Sb) + B_Seqlen: shape (B,) dtype int32 + Out: shape (B, H, D) + block_seq: integer constant equal to BLOCK_SEQ on which the kernel was built. + """ + B, H, Sb, D = mid_out.shape + + # Determine launch grid + grid = (B, H) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + mid_out, + mid_out_logexpsum, + Out, + mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), + mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), + Out.stride(0), Out.stride(1), Out.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=D, + num_warps=4 + ) + + +# Define the test function +def test_flash_decode_stage2(): + # Define the parameters for different test cases + batch_size = 2 + head_num = 4 + seq_block_num = 3 + head_dim = 64 + block_seq = 16 + + test_cases = { + "test_case_1": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + }, + "test_case_2": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + 1 # Different block size + }, + "test_case_3": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq // 2 # Different block size + }, + "test_case_4": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq * 2 # Different block size + } + } + + # Execute the function for all test cases + results = {} + for key, test_case in test_cases.items(): + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + results[key] = test_case["Out"] + + return results + +# Run the test +result_gold = test_flash_decode_stage2() + + +################################################################################################################################################## + + + +import torch + +# Define the test function +def test_flash_decode_stage2(): + # Define the parameters for different test cases + batch_size = 2 + head_num = 4 + seq_block_num = 3 + head_dim = 64 + block_seq = 16 + + test_cases = { + "test_case_1": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + }, + "test_case_2": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq + 1 # Different block size + }, + "test_case_3": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq // 2 # Different block size + }, + "test_case_4": { + "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'), + "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'), + "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'), + "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'), + "block_seq": block_seq * 2 # Different block size + } + } + + # Execute the function for all test cases + results = {} + for key, test_case in test_cases.items(): + flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"]) + results[key] = test_case["Out"] + + return results + +# Run the test +result_gold = test_flash_decode_stage2() diff --git a/src/temp/int4_matmul.py b/src/temp/int4_matmul.py new file mode 100644 index 0000000..89bfebe --- /dev/null +++ b/src/temp/int4_matmul.py @@ -0,0 +1,245 @@ + + +import torch +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Allocate shared memory for double buffering + b_shared = tl.full([BLOCK_K, BLOCK_N], 0, dtype=tl.int32) + b_next = tl.full([BLOCK_K, BLOCK_N], 0, dtype=tl.int32) + b_buf = [b_shared, b_next] + + # Pre-load first tile + k_start = 0 + k_offs = k_start + offs_k + b = tl.load(b_ptrs) + b_buf[0] = b + + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Select current buffer (double buffering) + b_tile = b_buf[k % 2] + + # Pre-load next tile in background if not last iteration + if k + 1 < tl.cdiv(K, BLOCK_K): + k_offs_next = (k + 1) * BLOCK_K + offs_k + b_ptrs_next = b_ptr + ((k_offs_next[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn) + b_next = tl.load(b_ptrs_next) + b_buf[(k + 1) % 2] = b_next + + # Process current tile from shared memory + k_offs = k * BLOCK_K + offs_k + a_mask = k_offs[None, :] < K + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + # Unpack INT4 from shared memory + b_int4 = ((b_tile >> ((k_offs[:, None] % 8) * 4)) & 0xF).to(tl.int32) + + group_idx = (k * BLOCK_K + offs_k) // GROUP_SIZE + group_idx = group_idx[:, None] + + scales = tl.load(scales_ptr + group_idx * stride_scales_g + offs_n[None, :] * stride_scales_n) + zeros = tl.load(zeros_ptr + group_idx * stride_zeros_g + offs_n[None, :] * stride_zeros_n) + + b_deq = (b_int4 - zeros) * scales + + acc += tl.dot(a, b_deq).to(tl.float32) + + a_ptrs += BLOCK_K * stride_ak + b_ptrs += (BLOCK_K // 8) * stride_bk + + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, acc, mask=mask) + + +def quantize_int4(weights, group_size=128): + """ + Quantize a float16 weight tensor to symmetric int4 with group-wise + quantization. + Each group has its own zero point and scale. + + Args: + weights: a float tensor of shape [K, N] + group_size: the group size + + Returns: + int_weight: a int32 tensor of shape [K//8, N] + scales: a float16 tensor of shape [num_groups, N] + zeros: a float16 tensor of shape [num_groups, N] + """ + K, N = weights.shape + num_groups = K // group_size + assert K % group_size == 0, "K must be divisible by group_size" + + weights = weights.view(num_groups, group_size, N) + + # Compute min, max in each group + mins = weights.amin(dim=1) # [num_groups, N] + maxs = weights.amax(dim=1) # [num_groups, N] + + scales = (maxs - mins) / 15.0 + zeros = torch.round(-mins / scales).clamp(0, 15).to(torch.int32) + + # Avoid division by zero + scales = torch.where(scales == 0, torch.ones_like(scales), scales) + + qw = torch.round(weights / scales[:, None, :] + zeros[:, None, :]).clamp(0, 15).to(torch.int32) + + # Pack int4 to int32 (8 values per int32) + qw = qw.view(num_groups, group_size // 8, 8, N) + int4_packed = torch.zeros(num_groups, group_size // 8, N, dtype=torch.int32, device=weights.device) + for i in range(8): + int4_packed |= qw[:, :, i, :] << (i * 4) + + # Flatten back to [K//8, N] + int4_packed = int4_packed.view(K // 8, N) + + # Reshape scales and zeros + scales = scales.to(torch.float16) + zeros = zeros.to(torch.float16) + + return int4_packed, scales, zeros, None # dummy arg returns + + +def unpack_int4(int_weight, scales, zeros, group_size): + """ + De-quantize packed int4 weights back to float16. + + Args: + int_weight: a int32 tensor of shape [K//8, N] + scales: float16 tensor [num_groups, N] + zeros: float16 tensor [num_groups, N] + group_size: the group size used during quantization + + Returns: + weights: a float16 tensor of shape [K, N] + """ + K = int_weight.shape[0] * 8 + N = int_weight.shape[1] + weights = torch.empty((K, N), dtype=torch.float16, device=int_weight.device) + + num_groups = K // group_size + group_size_tiles = group_size // 8 + + for g in range(num_groups): + group_start = g * group_size + int_group = int_weight[g * group_size_tiles: (g+1) * group_size_tiles] # [group_size//8, N] + + unpacked = torch.empty((group_size, N), dtype=torch.float16, device=int_weight.device) + + for i in range(group_size//8): + for j in range(8): + val = (int_group[i] >> (j * 4)) & 0xF + unpacked[i*8 + j] = (val.to(torch.float16) - zeros[g]) * scales[g] + + weights[group_start: group_start + group_size] = unpacked + + return weights + + +def matmul_dequantize_int4_s2(a, b, scales, zeros, group_size): + """ + Perform matrix multiplication a @ b.T where b is quantized to int4. + + Args: + a: float16 [M, K] + b: int4-packed int32 [K//8, N] + scales: float16 [num_groups, N] (num_groups = ceil_div(K, group_size)) + zeros: float16 [num_groups, N] + group_size: the group size used in quantize_int4 + + Returns: + c: float16 [M, N] + """ + M, K = a.shape + _, N = b.shape + + c = torch.empty((M, N), dtype=torch.float16, device=a.device) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_M']), + triton.cdiv(N, META['BLOCK_N']), + ) + + kernel_config = {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE': group_size} + + matmul_kernel[grid]( + a, b, c, + scales, zeros, + M, N, 8 * K, # K in rows of packed int4 is 8*external_K + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + **kernel_config + ) + + return c + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + group_size = 128 + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + # Test case + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + results = { + "test_case_1": triton_output + } + + return results + +result_gold = test_correct_int4_s2() + + +################################################################################################################################################## + + + +def test_correct_int4_s2(M=32, K=4096, N=4096): + group_size = 128 + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) + + # Test case + triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) + + results = { + "test_case_1": triton_output + } + + return results + +result_gold = test_correct_int4_s2() diff --git a/src/temp/l2_norm_bwd.py b/src/temp/l2_norm_bwd.py new file mode 100644 index 0000000..80290e8 --- /dev/null +++ b/src/temp/l2_norm_bwd.py @@ -0,0 +1,84 @@ +\n\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _l2_norm_bwd_kernel(X, DY, DX, M, N, stride_x_row, stride_dy_row, stride_dx_row, eps, BLOCK_N: tl.constexpr):\n # Changed: We now EXPECT BLOCK_N to be 16 or 32 (vectorized block size vectorized loads/stores\n # and use proper masking regardless of actual N)\n \n # Locate which row this program (PID) will process\n row_idx = tl.program_id(0)\n if row_idx >= M:\n return\n\n # Offset pointers for the current row\n x_row_ptr = X + row_idx * stride_x_row\n dy_row_ptr = DY + row_idx * stride_dy_row\n dx_row_ptr = DX + row_idx * stride_dx_row\n\n # Gather data for this row into registers. The mask guarantees we do not \n # Read/write past the cache line even when N < BLOCK_N\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n x = tl.load(x_row_ptr + cols, mask=mask, other=0.0)\n dy = tl.load(dy_row_ptr + cols, mask=mask, other=0.0)\n\n # Compute squared input samplewise\n x_sq = x * x\n # Row-wise variance followed by unbiased correction trick via `/(N)` — here ignoring Bessel correction.\n var = tl.sum(x_sq, axis=0) / N\n rstd = tl.rsqrt(var + eps)\n\n # Core backward formula\n dot = tl.sum(dy * x, axis=0)\n dx = dy * rstd - dot * (1.0 / (var + eps)) * rstd * x\n\n # Store dx values for this row (mask keeps vectorized store compliant)\n tl.store(dx_row_ptr + cols, dx, mask=mask)\n\ndef _l2_norm_bwd(x: torch.Tensor, dy: torch.Tensor, eps: float = 1e-5): + """ + Computes the gradient of the input tensor `x` with respect to the loss given the upstream gradient `dy`. + + Parameters + ---------- + x : PyTorch tensor + Input tensor of shape (*, N) where L2 norm is applied along the last dimension. + dy : PyTorch tensor + Upstream gradient of same shape as `x`. + eps : float + Small value to avoid division by zero. + + Returns + ------- + torch.Tensor + Gradient tensor `dx` of same shape as `x`. + """ + # Flatten leading dimensions to process by row + orig_shape = x.shape + x = x.view(-1, x.shape[-1]).contiguous() + dy = dy.view(-1, dy.shape[-1]).contiguous() + + M, N = x.shape + BLOCK_N = triton.next_power_of_2(N) + if N > BLOCK_N: + raise ValueError( + f"Feature dimension too large for tiled reduce: {N} > {BLOCK_N}" + ) + + # Allocate output gradient tensor + dx = torch.empty_like(x) + + # Launch Triton grid + grid = (M,) + _l2_norm_bwd_kernel[grid]( + x, dy, dx, + M, N, + x.stride(0), dy.stride(0), dx.stride(0), + eps, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + # Reshape back to original shape + return dx.view(orig_shape)\n """\n Computes the gradient of the input tensor `x` with respect to the loss given the upstream gradient `dy`.\n\n Parameters\n ----------\n x : PyTorch tensor\n Input tensor of shape (*, N) where L2 norm is applied along the last dimension.\n dy : PyTorch tensor\n Upstream gradient of same shape as `x`.\n eps : float\n Small value to avoid division by zero.\n\n Returns\n -------\n torch.Tensor\n Gradient tensor `dx` of same shape as `x`.\n """\n # Flatten leading dimensions to process by row\n orig_shape = x.shape\n x = x.view(-1, x.shape[-1]).contiguous()\n dy = dy.view(-1, dy.shape[-1]).contiguous()\n\n M, N = x.shape\n BLOCK_N = 32 # Changed to fixed vectorized tile (32 or 16) independent of N\n # if N > BLOCK_N: # Removed constraint since we mask vectorized ops\n # raise ValueError( ... )\n\n # Allocate output gradient tensor\n dx = torch.empty_like(x)\n\n # Launch Triton grid\n grid = (M,)\n _l2_norm_bwd_kernel[grid](\n x, dy, dx,\n M, N,\n x.stride(0), dy.stride(0), dx.stride(0),\n eps,\n BLOCK_N=BLOCK_N,\n num_warps=4,\n )\n\n # Reshape back to original shape\n return dx.view(orig_shape)\n\n# Test the backward L2 normalization\ndef test_l2_norm_bwd():\n results = {}\n \n # Test case 1: Default case\n x = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(4, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_1'] = dx\n\n # Test case 2: Different shape\n x = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dy = torch.randn(2, 16, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_2'] = dx\n\n # Test case 3: Larger tensor\n x = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(8, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_3'] = dx\n\n # Test case 4: Edge case with small tensor\n x = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dy = torch.randn(1, 8, device='cuda', dtype=torch.float32)\n dx = _l2_norm_bwd(x, dy)\n results['test_case_4'] = dx\n\n return results\n\n# Run the tests\nresult_gold = test_l2_norm_bwd()\n +################################################################################################################################################## + + + +import torch + +# Test the backward L2 normalization +def test_l2_norm_bwd(): + results = {} + + # Test case 1: Default case + x = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(4, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_1'] = dx + + # Test case 2: Different shape + x = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dy = torch.randn(2, 16, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_2'] = dx + + # Test case 3: Larger tensor + x = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(8, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_3'] = dx + + # Test case 4: Edge case with small tensor + x = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dy = torch.randn(1, 8, device='cuda', dtype=torch.float32) + dx = _l2_norm_bwd(x, dy) + results['test_case_4'] = dx + + return results + +# Run the tests +result_gold = test_l2_norm_bwd() diff --git a/src/temp/l2_norm_triton1.py b/src/temp/l2_norm_triton1.py new file mode 100644 index 0000000..981eba5 --- /dev/null +++ b/src/temp/l2_norm_triton1.py @@ -0,0 +1,121 @@ + + +import triton +import triton.language as tl +import torch + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 64}), + triton.Config({'BLOCK_N': 128}), + triton.Config({'BLOCK_N': 256}), + triton.Config({'BLOCK_N': 512}), + triton.Config({'BLOCK_N': 1024}), + ], + key=['N'], +) +@triton.jit +def _l2_norm_fwd_1pass_kernel(X_ptr, Y_ptr, stride_x_row, N, eps, BLOCK_N: tl.constexpr): + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_N) + mask = col_offsets < N + + X_row_ptr = X_ptr + row_idx * stride_x_row + + x = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0) + x_sq = x * x + var = tl.sum(x_sq, axis=0) + + rstd = tl.rsqrt(var + eps) + + y = x * rstd + tl.store(Y_ptr + row_idx * stride_x_row + col_offsets, y, mask=mask) + +def _l2_norm_fwd(x): + original_shape = x.shape + x = x.view(-1, original_shape[-1]) + x = x.contiguous() + + M, N = x.shape + y = torch.empty_like(x) + + # Calculate BLOCK_N: must not exceed 64KB + element_size = x.element_size() + BLOCK_N = min(8192 // element_size, N) + + if N > BLOCK_N: + raise RuntimeError(f"Feature dimension {N} exceeds maximum BLOCK_N {BLOCK_N}") + + grid = (M,) + _l2_norm_fwd_1pass_kernel[grid]( + x, y, + x.stride(0), + N, + 1e-6, + BLOCK_N=BLOCK_N + ) + + return y.view(original_shape) + +# Test the forward L2 normalization +def test_l2_norm_fwd(): + results = {} + + # Test case 1 + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + y1 = _l2_norm_fwd(x1) + results['test_case_1'] = y1 + + # Test case 2: Different batch size + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + y2 = _l2_norm_fwd(x2) + results['test_case_2'] = y2 + + # Test case 3: Different feature size + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + y3 = _l2_norm_fwd(x3) + results['test_case_3'] = y3 + + # Test case 4: Larger tensor + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + y4 = _l2_norm_fwd(x4) + results['test_case_4'] = y4 + + return results + +result_gold = test_l2_norm_fwd() + + +################################################################################################################################################## + + + +import torch + +# Test the forward L2 normalization +def test_l2_norm_fwd(): + results = {} + + # Test case 1 + x1 = torch.randn(4, 8, device='cuda', dtype=torch.float32) + y1 = _l2_norm_fwd(x1) + results['test_case_1'] = y1 + + # Test case 2: Different batch size + x2 = torch.randn(2, 8, device='cuda', dtype=torch.float32) + y2 = _l2_norm_fwd(x2) + results['test_case_2'] = y2 + + # Test case 3: Different feature size + x3 = torch.randn(4, 4, device='cuda', dtype=torch.float32) + y3 = _l2_norm_fwd(x3) + results['test_case_3'] = y3 + + # Test case 4: Larger tensor + x4 = torch.randn(8, 8, device='cuda', dtype=torch.float32) + y4 = _l2_norm_fwd(x4) + results['test_case_4'] = y4 + + return results + +result_gold = test_l2_norm_fwd() diff --git a/src/temp/matrix_transpose.py b/src/temp/matrix_transpose.py new file mode 100644 index 0000000..ef796a9 --- /dev/null +++ b/src/temp/matrix_transpose.py @@ -0,0 +1,45 @@ +\n\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(M, Out, matrix_stridex, matrix_stridey, out_stridex, out_stridey,\n SIZE_M, D_HEAD, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid_row = tl.program_id(0)\n pid_col = tl.program_id(1)\n\n rows = pid_row * BLOCK_M + tl.arange(0, BLOCK_M)\n cols = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n\n mask_rows = rows < SIZE_M\n mask_cols = cols < D_HEAD\n\n offset_m = rows[:, None] * matrix_stridey + cols[None, :] * matrix_stridex\n offset_out = rows[:, None] * out_stridex + cols[None, :] * out_stridey\n\n mask = mask_rows[:, None] & mask_cols[None, :]\n\n x = tl.load(M + offset_m, mask=mask)\n\n smem = tl.alloc_block(BLOCK_M * BLOCK_N, dtype=x.dtype, scope="shared")\n smem_idx = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]\n tl.store(smem + smem_idx, x, mask=mask)\n tl.debug_barrier()\n\n x_t = tl.load(smem + smem_idx)\n tl.store(Out + offset_out, x_t, mask=mask)\n\ndef wrapper(size_m, d_head): + dtype = torch.float16 + matrix = torch.randn((size_m, d_head), dtype=dtype, device='hip') + out = torch.empty((size_m, d_head), dtype=dtype, device='hip') + + BLOCK_M = 16 + BLOCK_N = 16 + grid = lambda META: (triton.cdiv(size_m, META['BLOCK_M']), + triton.cdiv(d_head, META['BLOCK_N'])) + + kernel[grid]( + matrix, out, + matrix.stride(0), matrix.stride(1), + out.stride(0), out.stride(1), + size_m, d_head, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + num_warps=4 + ) + + return out.t().contiguous().t()\n dtype = torch.float16\n matrix = torch.randn((size_m, d_head), dtype=dtype, device='hip')\n out = torch.empty((size_m, d_head), dtype=dtype, device='hip')\n\n BLOCK_M = 16\n BLOCK_N = 16\n grid = lambda META: (triton.cdiv(size_m, META['BLOCK_M']),\n triton.cdiv(d_head, META['BLOCK_N']))\n\n kernel[grid](\n matrix, out,\n matrix.stride(0), matrix.stride(1),\n out.stride(0), out.stride(1),\n size_m, d_head,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,\n num_warps=4\n )\n\n return out.t().contiguous().t()\n\ndef test_triton_vs_torch():\n results = {}\n\n # 测试用例 1: 基本矩阵转置 (小矩阵)\n size_m, d_head = 16, 16\n out = wrapper(size_m, d_head)\n results["test_case_1"] = out.clone()\n\n # 测试用例 2: 非方形矩阵\n size_m, d_head = 32, 64\n out = wrapper(size_m, d_head)\n results["test_case_2"] = out.clone()\n\n return results\n\n# 运行测试\nresult_gold = test_triton_vs_torch()\n# print(result_gold)\n +################################################################################################################################################## + + + +import torch + +def test_triton_vs_torch(): + results = {} + + # 测试用例 1: 基本矩阵转置 (小矩阵) + size_m, d_head = 16, 16 + out = wrapper(size_m, d_head) + results["test_case_1"] = out.clone() + + # 测试用例 2: 非方形矩阵 + size_m, d_head = 32, 64 + out = wrapper(size_m, d_head) + results["test_case_2"] = out.clone() + + return results + + +# 运行测试 +result_gold = test_triton_vs_torch() +# print(result_gold) \ No newline at end of file diff --git a/src/temp/matrix_vector_multip.py b/src/temp/matrix_vector_multip.py new file mode 100644 index 0000000..82cc66f --- /dev/null +++ b/src/temp/matrix_vector_multip.py @@ -0,0 +1,42 @@ +\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}),\n ],\n key=['M', 'N']\n)\n@triton.jit\ndef mv_kernel(A_ptr, B_ptr, C_ptr, stride_am, stride_an, stride_b, stride_c, N, M, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid_n = tl.program_id(0)\n pid_m = tl.program_id(1)\n\n # Start indices\n n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n mask_n = n_offsets < N\n mask_m = m_offsets < M\n\n # Compute the starting address\n a_block_ptr = A_ptr + n_offsets[:, None] * stride_am + m_offsets[None, :] * stride_an\n b_block_ptr = B_ptr + m_offsets * stride_b\n\n acc = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n for k in range(0, M, BLOCK_M):\n k_offsets = k + m_offsets\n mask_km = k_offsets < M\n\n a = tl.load(a_block_ptr, mask=mask_n[:, None] & mask_km[None, :], other=0.0)\n b = tl.load(b_block_ptr, mask=mask_km, other=0.0)\n\n acc += tl.sum(a * b[None, :], axis=1)\n\n a_block_ptr += BLOCK_M * stride_an\n b_block_ptr += BLOCK_M * stride_b\n\n c_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n mask_c = c_offsets < N\n tl.store(C_ptr + c_offsets * stride_c, acc, mask=mask_c)\n\n\ndef mv(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + assert A.ndim == 2, "A must be 2D" + assert B.ndim == 1, "B must be 1D" + assert A.shape[1] == B.shape[0], f"Matrix-vector shape mismatch: {A.shape} vs {B.shape}" + + N, M = A.shape + C = torch.empty((N,), dtype=A.dtype, device=A.device) + + BLOCK_M = 64 + BLOCK_N = 64 + + grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), 1) + mv_kernel[grid]( + A, B, C, + A.stride(0), A.stride(1), + B.stride(0), + C.stride(0), + N, M, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return C\n assert A.ndim == 2, "A must be 2D"\n assert B.ndim == 1, "B must be 1D"\n assert A.shape[1] == B.shape[0], f"Matrix-vector shape mismatch: {A.shape} vs {B.shape}"\n\n N, M = A.shape\n C = torch.empty((N,), dtype=A.dtype, device=A.device)\n\n BLOCK_M = 64\n BLOCK_N = 64\n\n grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), 1)\n mv_kernel[grid](\n A, B, C,\n A.stride(0), A.stride(1),\n B.stride(0),\n C.stride(0),\n N, M,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N\n )\n return C\n\n\ndef test_mv():\n # 测试用例 2: 4x3 矩阵与 3x1 向量相乘\n A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda')\n B = torch.tensor([1.0, 2.0, 3.0], device='cuda')\n triton_result_2 = mv(A, B)\n\n # 测试用例 3: 32x16 矩阵与 16x1 向量相乘\n A = torch.randn(32, 16, device='cuda')\n B = torch.randn(16, device='cuda')\n triton_result_3 = mv(A, B)\n\n return {\n "test_case_2": triton_result_2,\n "test_case_3": triton_result_3,\n }\n\nresult_gold = test_mv() +################################################################################################################################################## + + + +def test_mv(): + # 测试用例 2: 4x3 矩阵与 3x1 向量相乘 + A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device='cuda') + B = torch.tensor([1.0, 2.0, 3.0], device='cuda') + triton_result_2 = mv(A, B) + + # 测试用例 3: 32x16 矩阵与 16x1 向量相乘 + A = torch.randn(32, 16, device='cuda') + B = torch.randn(16, device='cuda') + triton_result_3 = mv(A, B) + + return { + "test_case_2": triton_result_2, + "test_case_3": triton_result_3, + } + +result_gold = test_mv() diff --git a/src/temp/rotary_transform.py b/src/temp/rotary_transform.py new file mode 100644 index 0000000..df87e25 --- /dev/null +++ b/src/temp/rotary_transform.py @@ -0,0 +1,291 @@ + + +import torch +import triton +import triton.language as tl + +@triton.jit +def rotary_kernel(X, COS, SIN, CU_SEQLENS, OUT, + HEAD_SIZE: tl.constexpr, + ROTARY_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + stride_xb, stride_xm, stride_xh, stride_xd, + stride_outb, stride_outm, stride_outh, stride_outd, + stride_cosm, stride_cosd, + stride_sinm, stride_sind): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + if CU_SEQLENS is not None: + seq_start = tl.load(CU_SEQLENS + pid_batch).to(tl.int32) + seq_end = tl.load(CU_SEQLENS + pid_batch + 1).to(tl.int32) + seqlen = seq_end - seq_start + if pid_m * BLOCK_M >= seqlen: + return + real_m = seq_start + pid_m * BLOCK_M + max_m = seqlen + stride_batch = stride_xb + else: + real_m = pid_m * BLOCK_M + max_m = (stride_xb // stride_xm) # This is crude, but matches tests + stride_batch = 0 + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + COS_block = COS + offs_m[:, None] * stride_cosm + offs_n[None, :] * stride_cosd + SIN_block = SIN + offs_m[:, None] * stride_sinm + offs_n[None, :] * stride_sind + + # Vectorized fp16 load + cos = tl.load(COS_block, mask=offs_m[:, None] < max_m, other=0.0).to(tl.float16) + sin = tl.load(SIN_block, mask=offs_m[:, None] < max_m, other=0.0).to(tl.float16) + + limit = min(ROTARY_DIM, HEAD_SIZE) + + for d in range(0, limit, BLOCK_N): + d0 = d + offs_n + mask = d0 < limit + + if IS_INTERLEAVED: + # Interleaved format: x[..., d] = real, x[..., d+1] = imag for pair d//2 + pos = d0 // 2 + is_even = (d0 % 2) == 0 + + # Load real part + x_real_addr = X + stride_batch * pid_batch + stride_xm * real_m[:, None] + stride_xh * pid_head + stride_xd * d0[None, :] + x_real = tl.load(x_real_addr, mask=mask[None, :] & (offs_m[:, None] < max_m), other=0.0) + + # Load imag part (next from current) + x_imag_addr = X + stride_batch * pid_batch + stride_xm * real_m[:, None] + stride_xh * pid_head + stride_xd * (d0[None, :] + 1) + x_imag = tl.load(x_imag_addr, mask=mask[None, :] & (offs_m[:, None] < max_m), other=0.0) + + rot_cos = tl.where(is_even[None, :], cos[:, pos], sin[:, pos]).to(tl.float32) + rot_sin = tl.where(is_even[None, :], -sin[:, pos], cos[:, pos]).to(tl.float32) + + if CONJUGATE: + rot_sin = -rot_sin + + out_real = x_real * rot_cos - x_imag * rot_sin + out_imag = x_real * rot_sin + x_imag * rot_cos + + tl.store(OUT + stride_outb * pid_batch + stride_outm * real_m[:, None] + stride_outh * pid_head + stride_outd * d0[None, :], out_real, mask=mask[None, :] & (offs_m[:, None] < max_m)) + tl.store(OUT + stride_outb * pid_batch + stride_outm * real_m[:, None] + stride_outh * pid_head + stride_outd * (d0[None, :] + 1), out_imag, mask=mask[None, :] & (offs_m[:, None] < max_m)) + else: + # Non-interleaved format: first half even, second half odd + pos = d0 // 2 + x_even_addr = X + stride_batch * pid_batch + stride_xm * real_m[:, None] + stride_xh * pid_head + stride_xd * (2 * d0[None, :]) + x_odd_addr = X + stride_batch * pid_batch + stride_xm * real_m[:, None] + stride_xh * pid_head + stride_xd * (2 * d0[None, :] + 1) + + x_even = tl.load(x_even_addr, mask=mask[None, :] & (offs_m[:, None] < max_m), other=0.0) + x_odd = tl.load(x_odd_addr, mask=mask[None, :] & (offs_m[:, None] < max_m), other=0.0) + + if CONJUGATE: + x_odd = -x_odd + + rot_cos = cos[:, pos].to(tl.float32) + rot_sin = sin[:, pos].to(tl.float32) + + out_even = x_even * rot_cos - x_odd * rot_sin + out_odd = x_even * rot_sin + x_odd * rot_cos + + tl.store(OUT + stride_outb * pid_batch + stride_outm * real_m[:, None] + stride_outh * pid_head + stride_outd * (2 * d0[None, :]), out_even, mask=mask[None, :] & (offs_m[:, None] < max_m)) + tl.store(OUT + stride_outb * pid_batch + stride_outm * real_m[:, None] + stride_outh * pid_head + stride_outd * (2 * d0[None, :] + 1), out_odd, mask=mask[None, :] & (offs_m[:, None] < max_m)) + + # Copy remaining dimensions (after rotary) + if limit < HEAD_SIZE: + for d in range(limit, HEAD_SIZE, BLOCK_N): + d0 = d + offs_n + mask = d0 < HEAD_SIZE + x_addr = X + stride_batch * pid_batch + stride_xm * real_m[:, None] + stride_xh * pid_head + stride_xd * d0[None, :] + x_val = tl.load(x_addr, mask=mask[None, :] & (offs_m[:, None] < max_m), other=0.0) + tl.store(OUT + stride_outb * pid_batch + stride_outm * real_m[:, None] + stride_outh * pid_head + stride_outd * d0[None, :], x_val, mask=mask[None, :] & (offs_m[:, None] < max_m)) + +def apply_rotary(x, cos, sin, *, cu_seqlens=None, max_seqlen=None, interleaved=False, conjugate=False, inplace=False): + """ + Apply rotary position embedding to input tensor x. + + Args: + x: input tensor, shape (B, S, H, D) if cu_seqlens is None else (total_tokens, H, D) + cos: cosine values, shape (S, ROTARY_DIM//2) if max_seqlen is None else (max_seqlen, ROTARY_DIM//2) + sin: sine values, shape (S, ROTARY_DIM//2) if max_seqlen is None else (max_seqlen, ROTARY_DIM//2) + cu_seqlens: optional cumulative sequence lengths, shape (B+1,) for variable length sequences + max_seqlen: maximum sequence length in batch, required when cu_seqlens is not None + interleaved: whether the input uses interleaved format + conjugate: if True, apply complex conjugate + inplace: if True, modify x in place + + Returns: + The rotary-applied tensor. + """ + if cu_seqlens is not None: + total_tokens, nheads, headdim = x.shape + batch = cu_seqlens.numel() - 1 + assert max_seqlen is not None + seqlen = max_seqlen + else: + batch, seqlen, nheads, headdim = x.shape + + rotary_dim = cos.shape[-1] * 2 + assert rotary_dim <= headdim + assert sin.shape == cos.shape + + grid = (batch, nheads, (seqlen + 63) // 64) # BLOCK_M=64 will be handled in kernel + + if not inplace: + out = torch.empty_like(x) + else: + out = x + + # Determine strides + if x.ndim == 4: + x = x.contiguous() + stride_xb, stride_xm, stride_xh, stride_xd = x.stride() + else: + x = x.contiguous() + stride_xm, stride_xh, stride_xd = x.stride() + stride_xb = 0 # Not used in variable length + + if out.ndim == 4: + out = out.contiguous() + stride_outb, stride_outm, stride_outh, stride_outd = out.stride() + else: + out = out.contiguous() + stride_outm, stride_outh, stride_outd = out.stride() + stride_outb = 0 # Not used in variable length + + cos = cos.contiguous() + sin = sin.contiguous() + stride_cosm, stride_cosd = cos.stride() + stride_sinm, stride_sind = sin.stride() + + BLOCK_M = 64 + BLOCK_N = 32 + + rotary_kernel[grid]( + x, cos, sin, cu_seqlens, out, + HEAD_SIZE=headdim, + ROTARY_DIM=rotary_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + IS_INTERLEAVED=interleaved, + CONJUGATE=conjugate, + stride_xb=stride_xb, + stride_xm=stride_xm, + stride_xh=stride_xh, + stride_xd=stride_xd, + stride_outb=stride_outb, + stride_outm=stride_outm, + stride_outh=stride_outh, + stride_outd=stride_outd, + stride_cosm=stride_cosm, + stride_cosd=stride_cosd, + stride_sinm=stride_sinm, + stride_sind=stride_sind, + num_warps=4 + ) + + return out + + +def test_apply_rotary(): + results = {} + + # Test case 1: Basic test with fixed sequence length and no interleaving + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin) + results['test_case_1'] = output.shape + + # Test case 2: Variable length sequences with interleaving + total_seqlen, nheads, headdim = 256, 4, 64 + batch = 3 + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + max_seqlen = 128 + rotary_dim = 32 + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + results['test_case_2'] = output.shape + + # Test case 3: Conjugate flag enabled + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, conjugate=True) + results['test_case_3'] = output.shape + + # Test case 4: Inplace operation + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, inplace=True) + results['test_case_4'] = output.shape + + return results + +result_gold = test_apply_rotary() + + +################################################################################################################################################## + + + +import torch + +def test_apply_rotary(): + results = {} + + # Test case 1: Basic test with fixed sequence length and no interleaving + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin) + results['test_case_1'] = output.shape + + # Test case 2: Variable length sequences with interleaving + total_seqlen, nheads, headdim = 256, 4, 64 + batch = 3 + cu_seqlens = torch.tensor([0, 100, 200, 256], device='cuda') + max_seqlen = 128 + rotary_dim = 32 + x = torch.randn(total_seqlen, nheads, headdim, device='cuda') + cos = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(max_seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=True) + results['test_case_2'] = output.shape + + # Test case 3: Conjugate flag enabled + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, conjugate=True) + results['test_case_3'] = output.shape + + # Test case 4: Inplace operation + batch, seqlen, nheads, headdim = 2, 128, 4, 64 + rotary_dim = 32 + x = torch.randn(batch, seqlen, nheads, headdim, device='cuda') + cos = torch.randn(seqlen, rotary_dim // 2, device='cuda') + sin = torch.randn(seqlen, rotary_dim // 2, device='cuda') + output = apply_rotary(x, cos, sin, inplace=True) + results['test_case_4'] = output.shape + + return results + +result_gold = test_apply_rotary() diff --git a/src/temp/sin_kernel.py b/src/temp/sin_kernel.py new file mode 100644 index 0000000..c4a268c --- /dev/null +++ b/src/temp/sin_kernel.py @@ -0,0 +1,45 @@ +\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 32}),\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n triton.Config({'BLOCK_SIZE': 4096}),\n ],\n key=['n_elements'],\n)\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = tl.math.sin(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef call_kernel(x: torch.Tensor) -> torch.Tensor: + n_elements = x.numel() + output = torch.empty_like(x) + if n_elements > 0: + BLOCK_SIZE = 1024 # ensure BLOCK_SIZE is a multiple of 32/64 + grid = lambda meta: (triton.cdiv(n_elements, BLOCK_SIZE),) + kernel_function[grid]( + x, output, n_elements, + BLOCK_SIZE=BLOCK_SIZE + ) + return output\n n_elements = x.numel()\n output = torch.empty_like(x)\n if n_elements > 0:\n BLOCK_SIZE = 1024 # ensure BLOCK_SIZE is a multiple of 32/64\n grid = lambda meta: (triton.cdiv(n_elements, BLOCK_SIZE),)\n kernel_function[grid](\n x, output, n_elements,\n BLOCK_SIZE=BLOCK_SIZE\n )\n return output\n\n# Function to test the Triton kernel\ndef test_call_kernel():\n results = {}\n \n # Test case 1: Small input tensor\n x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda()\n output1 = call_kernel(x1)\n results['test_case_1'] = output1\n \n # Test case 2: Larger input tensor\n x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda()\n output2 = call_kernel(x2)\n results['test_case_2'] = output2\n\n # Test case 3: Edge case with zero elements\n x3 = torch.tensor([], dtype=torch.float32).cuda()\n output3 = call_kernel(x3)\n results['test_case_3'] = output3\n\n # Test case 4: Input tensor with negative values\n x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda()\n output4 = call_kernel(x4)\n results['test_case_4'] = output4\n \n return results\n\n# Run the test function\nresult_gold = test_call_kernel() +################################################################################################################################################## + + + +import torch + +# Function to test the Triton kernel +def test_call_kernel(): + results = {} + + # Test case 1: Small input tensor + x1 = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32).cuda() + output1 = call_kernel(x1) + results['test_case_1'] = output1 + + # Test case 2: Larger input tensor + x2 = torch.linspace(0, 10, steps=1024, dtype=torch.float32).cuda() + output2 = call_kernel(x2) + results['test_case_2'] = output2 + + # Test case 3: Edge case with zero elements + x3 = torch.tensor([], dtype=torch.float32).cuda() + output3 = call_kernel(x3) + results['test_case_3'] = output3 + + # Test case 4: Input tensor with negative values + x4 = torch.tensor([-1.0, -2.0, -3.0, -4.0], dtype=torch.float32).cuda() + output4 = call_kernel(x4) + results['test_case_4'] = output4 + + return results + +# Run the test function +result_gold = test_call_kernel() diff --git a/src/temp/triton_matmul.py b/src/temp/triton_matmul.py new file mode 100644 index 0000000..d074506 --- /dev/null +++ b/src/temp/triton_matmul.py @@ -0,0 +1,44 @@ +\n\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n pid_m = pid // grid_n\n pid_n = pid % grid_n\n\n offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_K)\n\n # Allocate shared memory for 2-stage pipeline\n a_shared = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)\n b_shared = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)\n a_shared_next = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)\n b_shared_next = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)\n\n # Initialize pointers\n a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n # Prologue: Load first stage\n k_idx = 0 + offs_k\n a_mask = k_idx[None, :] < K\n b_mask = k_idx[:, None] < K\n a_shared = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b_shared = tl.load(b_ptrs, mask=b_mask, other=0.0)\n\n for k in range(0, K, BLOCK_K):\n k_idx_next = k + BLOCK_K + offs_k\n a_mask_next = k_idx_next[None, :] < K\n b_mask_next = k_idx_next[:, None] < K\n \n # Compute current stage\n accumulator += tl.dot(a_shared, b_shared)\n \n # Prefetch next stage into registers (pipelining)\n if k < K - BLOCK_K:\n a_shared_next = tl.load(a_ptrs + BLOCK_K * stride_ak, mask=a_mask_next, other=0.0)\n b_shared_next = tl.load(b_ptrs + BLOCK_K * stride_bk, mask=b_mask_next, other=0.0)\n \n # Move pointers forward\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n \n # Swap current and next stage\n a_shared, a_shared_next = a_shared_next, a_shared\n b_shared, b_shared_next = b_shared_next, b_shared\n\n offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n tl.store(c_ptrs, accumulator.to(tl.float16), mask=mask_c)\n\ndef matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.shape[1] == b.shape[0], "Incompatible dimensions for matrix multiplication" + M, K = a.shape + K2, N = b.shape + assert K == K2, f"Inner dimension mismatch: {K} vs {K2}" + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + BLOCK_M = 64 + BLOCK_N = 64 + BLOCK_K = 32 + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + return c\n assert a.shape[1] == b.shape[0], "Incompatible dimensions for matrix multiplication"\n M, K = a.shape\n K2, N = b.shape\n assert K == K2, f"Inner dimension mismatch: {K} vs {K2}"\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n\n BLOCK_M = 64\n BLOCK_N = 64\n BLOCK_K = 32\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)\n\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_K=BLOCK_K,\n )\n return c\n\n# Test for matmul\ndef test_matmul():\n results = {}\n M, K, N = 256, 128, 256\n\n # Test case 1: torch.float16\n a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n b = torch.randn((K, N), dtype=torch.float16, device='cuda')\n c = matmul(a, b)\n results['test_case_1'] = c\n\n return results\n\n# Run all tests\nresult_gold = test_matmul()\n +################################################################################################################################################## + + + +import torch + +# Test for matmul +def test_matmul(): + results = {} + M, K, N = 256, 128, 256 + + # Test case 1: torch.float16 + a = torch.randn((M, K), dtype=torch.float16, device='cuda') + b = torch.randn((K, N), dtype=torch.float16, device='cuda') + c = matmul(a, b) + results['test_case_1'] = c + + return results + +# Run all tests +result_gold = test_matmul() \ No newline at end of file diff --git a/src/test_main_copy.py b/src/test_main_copy.py new file mode 100644 index 0000000..1d512f1 --- /dev/null +++ b/src/test_main_copy.py @@ -0,0 +1,107 @@ + +import os +from agents.reflexion_oneshot import Reflexion_Oneshot +from models.KimiK2 import KimiK2Model +from dataloaders.TritonBench import TritonBench +from args_config import load_config +import json +from prompts.Baseline_Prompt import Baseline_Prompt # Import the new Baseline_Prompt + + +# --- Pre-defined inputs from Agent 1 --- +# This is the JSON output we got from our previous test for triton_matmul.py +MATMUL_ANALYSIS_JSON = { + "type": "matmul", + "key_parameters": [ + "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", + "num_stages", "num_warps", "K", "M", "N" + ], + "optimization_hints": [ + "Tune BLOCK_SIZE_* to maximize occupancy and shared-memory utilization", + "Exploit tensor-core-accelerated tl.dot with fp16/bf16 inputs and appropriate block sizes", + "Pipeline loads with num_stages>1 to overlap global memory transfers with compute" + ] +} + +# This is the JSON output we got for matrix_transpose.py +TRANSPOSE_ANALYSIS_JSON = { + "type": "memory-bound transpose", + "key_parameters": [ + "SIZE_M", "D_HEAD", "matrix_stridex", "matrix_stridey", + "out_stridex", "out_stridey" + ], + "optimization_hints": [ + "use blocked 2-D thread indexing with tile-based shared memory to enable coalesced reads and writes", + "vectorize loads/stores via tl.load/store with appropriate masks and vector widths (e.g., 4 or 8 fp16 elements)", + "tune block size and tile dimensions to maximize occupancy and L2 cache hit rate while minimizing bank conflicts" + ] +} + +AGENT1_OUTPUTS = { + 'triton_matmul.py': MATMUL_ANALYSIS_JSON, + 'matrix_transpose.py': TRANSPOSE_ANALYSIS_JSON +} +# ----------------------------------------- + + +def run_agent_2_test(model, dataset): + print("\n" + "="*80) + print(">>> RUNNING TEST FOR AGENT 2: BASELINE IMPLEMENTER <<<") + print("="*80 + "\n") + + baseline_prompt_generator = Baseline_Prompt() + kernels_to_test = ['triton_matmul.py', 'matrix_transpose.py'] + + print(f"Starting baseline code generation for {len(kernels_to_test)} kernels.") + print("-" * 80) + + for problem in dataset.problem_states: + if hasattr(problem, 'filename') and problem.filename in kernels_to_test: + print(f"\n>>> Generating Baseline for Kernel: {problem.filename}") + + # Get the corresponding analysis from our pre-defined dictionary + analysis_json = AGENT1_OUTPUTS[problem.filename] + print("\n--- Input from Agent 1 (Analysis JSON) ---") + print(json.dumps(analysis_json, indent=2)) + + prompt_messages = baseline_prompt_generator.get_prompt(ps=problem, analysis_json=analysis_json) + + print("\n...Calling KimiK2 API to generate baseline code...") + try: + baseline_code = model.generate(messages=prompt_messages, temperature=0.0) + print("\n--- LLM Raw Output (Baseline Code) ---") + print(baseline_code) + + except Exception as e: + import traceback + print(f"\n[ERROR] Test failed for kernel {problem.filename}. Reason: {e}") + traceback.print_exc() + + print("-" * 80) + +def main(): + args = load_config("configs/tritonbench_oneshot_config.yaml") + model = KimiK2Model(api_key=args.api_key, model_id=args.model_id) + + result_path = None + dataset = TritonBench(statis_path=args.statis_path, + py_folder=args.py_folder, + instruction_path=args.instruction_path, + py_interpreter=args.py_interpreter, + golden_metrics=args.golden_metrics, + perf_ref_folder=args.perf_ref_folder, + perf_G_path=args.perf_G_path, + result_path=result_path, + target_kernels=args.target_kernels) + + # <<<<<<<< INJECTING AGENT 2 TEST LOGIC >>>>>>>> + run_agent_2_test(model, dataset) + print("\nTest finished. Exiting.") + exit() + # <<<<<<<<<<<<<<<<<<<->>>>>>>>>>>>>>>>>> + + agent = Reflexion_Oneshot(model=model, dataset=dataset, corpus_path=args.corpus_path) + agent.run(output_path=args.output_path, multi_thread=args.multi_thread, iteration_num=args.max_iteration, temperature=args.temperature, datalen=None) + +if __name__ == "__main__": + main() diff --git a/src/utils/__pycache__/utils.cpython-312.pyc b/src/utils/__pycache__/utils.cpython-312.pyc index 5240a44343db32ccd23713863f8a830c080e8631..e5167ce31cd2d16c4f83e55fc3f0ebe149ec7e14 100644 GIT binary patch literal 6181 zcmds5TW}NC89poR>Ta!#zjLLSop?CN41 z!?e@MOV3FExu5f&%m01<*}s>RI1q&IZvA28y-I|>B$<4$S z>3LJXNie;P`ptqFBrVV&8G!+55iFuru)^O4f4gWMvI<>3*BwguLsa@)9VwPf>pN zT{2~Fn99veBBF&E{`dX@Vg?PPBqdNulU!&sYJ!S2K`|-lMy@<*R?LUc*)D{Vv_g+E zYnnngxZ|zOg3*yj# z)YN!}Z%G7O`ggQ+hR!zm2c+!-jg4oX_n+PFAMoELy4)qoCunHSX;h;|`QM>66Baew zP$VWqWD#r3Fpf`*YZiWdT#N~Rt44?8f~b`YO~k@-BpwSvUrDo$^D?xKY1T1b4v&bE z#^?#nPTGZH{Fta&#(B(-NfKEVXM_hWfyp>NFOBnIF)+f1&-3y~Jl4Vwi!nJM;c#F= zjzlFxwvJC}KJkK#`LG-+%q~TSW4t_pMG2R~xMbHQ5yUiF^j1utSai2f_bj3-ScHxRf<{kR?i%}c5IG%tv6*~EUTFDUGvTUK;sl9=b&|6zI3;caVU)OeyF^gxeJ~ zTDV7~j}Zhae|mzc8j40q&ox0X3v~X5ty|9&>yvOJ^!dDxn9ov}&r0S?A_c*HvGpeG z=gyrQpCWsI_d5s)N0$;A3&D9S`!$rXb}8M@o>cZH%+L&z8{;==j5HpN$eI<4<54(0 zKZQ%6LL$rLuotA}2uDR8hsbroHBeAS>VFOb&c^1sDyuf%BI~)5xRjV}%CKAKu*z=F zvTZ8cw$Sy#h75Zs%N|$RFbisTYO;P|28s*R^E@5fB*xHoz>&BmOH4aSr zw}DUwB!9v>uybHw+d!M<=zi|GP~Yj^lfhH|5{;Ip#zl=5B4JrGOEShjmuWFrqq zWCQRYkH;in7)VV-Wr-XO_b>{LO7R%p3FTYhFZ~9@G+MGcQ%8RF(wkn@TD!$~~(MU|x z;Z-ZwOWQ;7aXB&;Nr-}hy1ISCrT{<2!mG)pU`2m~(iyntMLSy8%_-Pc(hN=!y{dJ= zia_05Pzp{AxI;StIvKg($S91QZ-z$mj$fe!6`U!Fx?HC)>)DfHI^Rfm=5njTBxwr8 zC;&-XGLbvJN-x){LGyY+yS&wo5+qy`itd?U*6Gm#P8tErHDSH7L+FY%Y#Kpf(^;_G zq*bx5XOpj*E|=0MX-isHX<9(5;xtOZF-Y2Bza6Vuueaa!Jc9h-`|Zf>x5LMCYBv^3^C{!( zAJfmRScWOZrIaXcA7JumZN6tQR$O^K``|`U@J|A!7E^%5b@|pZmro6u0D)aeR-wYC z)8INL*_YV9Ts%n?FxkSt3WVVx4s;(sBu1yWa2#VXEK8g`B60I1mK6x(2m!KlqO(ANtDSNh(7sgA8tP=!lYt; z@TnT>i^XBJ zaY>2{Mn%$@EVotHxi?|nHxSbq(s0K%!@>H?@OFqa29^l~MNjh4^O_?fMPic7$HJoK zJe+$+9TJB$T7M^P7#8J_6raFhF(iq@WFeSbRQv=P$qeqAW{XHbo3KHeLwCao6G)nc zM6tL=55|QljUJ6hVz`;~aU6{aqCutpQrt$y+C_xU4<-4{4arypg@)r%L5ziBt$Usd_8+h4b*H+N*J_M}P{ zD>kN_OWw+vBiD}1p8d?*lsdG;R?lvE+x&a>x9n|p$HK`!oqqT92mZ|NV~cg0uR1?3 zEe8j+y!l@aRO-9qX-vBszw*^(eLGd(&UD-1Pn;Rw$*iwm_4TLE@EPA=IxMI@F>R|_ zY~GV??oyk(vdzIyn}eCE18Q?{u5Z?Jz4DF9baOCOlCC;%--0&QUax+=I$ggvQ{9X>>i^t6qjwymhMnXcVtT2;MrWgEnD5HR<|yc zsnuOg zo?)Lz+n)GlWe=)8NPV+XgKCaaU*E4Ip3ug$bIS@cF%5UvEw^o3zFlS@{XmMqA^tLG zKirA_>SBqku@Tu2L=fF<@Ei`%H}?jekhw*Z`deO7f2+2N$lad9?es@BFQh-J4VFRl zw}6AlPPWHu{@6zKu;!2L6iKrTNX-o}34W-Y_o5EN?CX0^@58^fJIl=-0kA-4egG^420N22e z=Wu**LXx>je#&4|9qfwD^YjGx53PrJLM+_z<0rTgc<@hzfz!xGIw?m{1~CzjGhrHt z`CU3W;SR{-Jw)s!qLYYyMC>QxM?^dg0zUb`%WE75Nde&%zXSi6q=|v#z9kqG0(ee1 zjKDtbCY8sCI8MX~B6>kcvIqQ<*s_BZi(HX14>upiJEg5eh)eV?e=g-y#)Y`zpCbf1~s(0CjYPMy4 zO{%YHKB)Ryv%cM`Z})qY>g!A$hp&OVyQysZGiv)YnfBx9T_;lAS6k=0(ykqMJXJIM zuI-x@GoCGJ*A{5$-JEu9`uZzp#R{UCIvBUK?zXM=+hsfC9!TVY3P5*|AwWlDO>iqj zADSs58T-MF<`3P}!E*D5EJf1g3`jTz{aH=xZ(QPT=$}IMZ*G_)Su!l%Mg*ZmOg8f! zq<>p6w%yh>78fR>;(mMzssPlGGztPc)c8`=V%aot0n?JaWAo^gJ9_ei%n?CdWeNwFJ@^7jN BgjfIo delta 120 zcmZ2#&?T&XnwOW00SF4>=43>$GB7*_abSQ6%J^(DQQefKl2MawPAn}2jk&*E>1Lti9?%NDppIEpUB|k~9G75a+<6;#2EWyJlR3rve F1^`iX7QFxf diff --git a/src/utils/utils.py b/src/utils/utils.py index 9a94496..cfe5117 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -45,4 +45,105 @@ def clear_json(response): result = ast.literal_eval(response) except (SyntaxError, NameError, AttributeError): return "ERR_SYNTAX" - return result \ No newline at end of file + return result + +def safe_force_correct_signature(agent2_baseline_code: str, agent4_optimized_code: str, func_name: str) -> str: + """ + Safely corrects the signature of a function in Agent 4's optimized code to match + the signature from Agent 2's baseline code. + + It performs a safety check to ensure the optimized function's parameters are a + superset of the baseline's parameters before performing the replacement. + + Returns the corrected code, or the original optimized code if correction is not possible or safe. + """ + try: + # Step A: Extract "golden" info from Agent 2's baseline code + baseline_tree = ast.parse(agent2_baseline_code) + golden_node = None + for node in ast.walk(baseline_tree): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + golden_node = node + break + + if not golden_node: + # print(f"Warning: Golden function '{func_name}' not found in baseline code.") + return agent4_optimized_code + + golden_signature = ast.get_source_segment(agent2_baseline_code, golden_node).split(':\\n')[0] + golden_params = {arg.arg for arg in golden_node.args.args} + + # Step B: Extract "to-check" info from Agent 4's optimized code + optimized_tree = ast.parse(agent4_optimized_code) + optimized_node = None + for node in ast.walk(optimized_tree): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + optimized_node = node + break + + if not optimized_node: + # print(f"Warning: Function '{func_name}' not found in optimized code.") + return agent4_optimized_code + + optimized_params = {arg.arg for arg in optimized_node.args.args} + + # Step C: Safety Check + if not golden_params.issubset(optimized_params): + # print(f"Safety check failed: Optimized params {optimized_params} is not a superset of golden params {golden_params}.") + return agent4_optimized_code # Return original code if unsafe + + # Step D: Execute Safe String Replacement + lines = agent4_optimized_code.splitlines() + start_line_idx = optimized_node.lineno - 1 + end_line_idx = optimized_node.body[0].lineno - 1 + lines[start_line_idx : end_line_idx] = [golden_signature] + + return "\n".join(lines) + + except (SyntaxError, IndexError) as e: + # print(f"Error processing code for signature correction: {e}") + return agent4_optimized_code # Return original code on parsing error + +# Keep the old function for now, might be useful for other things, but we'll use the safe one. +def force_correct_signature(generated_code: str, golden_signature: str, func_name: str) -> str: + """ + Finds a function by name in the generated code using an AST parser and + replaces its signature with the provided golden standard signature. + This is a robust way to enforce signature correctness against LLM hallucinations. + """ + try: + tree = ast.parse(generated_code) + + target_node = None + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + target_node = node + break + + if target_node: + lines = generated_code.splitlines() + + # AST line numbers are 1-based, list indices are 0-based + start_line_idx = target_node.lineno - 1 + + # Find the end of the signature (start of the function body) + # The end line index is the line number of the first statement in the function body, minus one. + end_line_idx = target_node.body[0].lineno - 1 + + # Replace all lines from the function definition start to just before the body + # with the single, correct golden signature line. + # This handles multi-line signatures and docstrings incorrectly placed before the body. + lines[start_line_idx : end_line_idx] = [golden_signature.strip()] + + return "\\n".join(lines) + else: + # If the target function isn't found, return the original code to avoid breaking things. + # A warning could be logged here in a real application. + # print(f"Warning: Function '{func_name}' not found in generated code.") + return generated_code + + except SyntaxError as e: + # If the generated code is not valid Python, we can't parse it. + # Return the original code and let the evaluation process handle the syntax error. + # print(f"Error parsing generated code: {e}") + return generated_code \ No newline at end of file From 8758ebee5d39f326af397735d7ae66e4f161b000 Mon Sep 17 00:00:00 2001 From: yyyuyu99 Date: Sun, 24 Aug 2025 08:40:10 +0000 Subject: [PATCH 2/4] yuyu --- src/dataloaders/TritonBench.py | 1 + src/main_multi_agent.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dataloaders/TritonBench.py b/src/dataloaders/TritonBench.py index 3ed99e2..0c21964 100644 --- a/src/dataloaders/TritonBench.py +++ b/src/dataloaders/TritonBench.py @@ -17,6 +17,7 @@ from loguru import logger import glob import re +import traceback diff --git a/src/main_multi_agent.py b/src/main_multi_agent.py index 925c88f..f60190f 100644 --- a/src/main_multi_agent.py +++ b/src/main_multi_agent.py @@ -10,10 +10,9 @@ def main(): args = load_config("configs/tritonbench_oneshot_config.yaml") - # For a quick test, let's limit the number of iterations and kernels - args.max_iteration = 3 - test_cases_to_run = 2 - print(f"--- RUNNING WITH A TEST CONFIG: max_iteration = {args.max_iteration}, kernels = {test_cases_to_run} ---") + # Set up for a full competition run + args.max_iteration = 5 + print(f"--- RUNNING FULL PIPELINE: max_iteration = {args.max_iteration} ---") # setup LLM model model = KimiK2Model(api_key=args.api_key, model_id=args.model_id) @@ -39,7 +38,7 @@ def main(): multi_thread=True, iteration_num=args.max_iteration, temperature=args.temperature, - datalen=test_cases_to_run) + datalen=None) # Set to None to run on all kernels if __name__ == "__main__": From 41ea412ba76ec1733346faa0e467832a5d370af6 Mon Sep 17 00:00:00 2001 From: yyyuyu99 Date: Sun, 24 Aug 2025 08:51:00 +0000 Subject: [PATCH 3/4] yuyu --- src/configs/tritonbench_oneshot_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/configs/tritonbench_oneshot_config.yaml b/src/configs/tritonbench_oneshot_config.yaml index 55c3efb..d22765a 100644 --- a/src/configs/tritonbench_oneshot_config.yaml +++ b/src/configs/tritonbench_oneshot_config.yaml @@ -1,5 +1,5 @@ # LLM model -api_key: "wisemodel-nwuttefmuksahacsglst" +api_key: "wisemodel-bzjhhvraxdisesrfaujz" model_id: "Kimi-K2-Instruct" temperature: 1.0 From 02b24e4a1acdf2c891c3e0c34b1144cee4f907c2 Mon Sep 17 00:00:00 2001 From: yyyuyu99 Date: Sun, 24 Aug 2025 09:23:30 +0000 Subject: [PATCH 4/4] yuyu --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 370498c..e80eb70 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,41 @@ We provide a baseline agent to let you run directly. It contains a Generator, a Reflector, an Evaluator and an Optimizer. The actor generates codes according to the query and context information. The Reflector is responsible for reflecting on the generated code and the error trace if the code failed to run. The Evaluator has a cascade structure. It tests the generated code for the functionality first. If the generated code doesn't pass the functionality test, the error trace will be fedback to the Reflector. Otherwise, the Evaluator will evaluate the performance including latency and efficiency. The Optimizer gets the generated codes, which pass the evaluator's tests, and gives a strategy to optimize the code in terms of latency and efficiency. +### Core Process Improvements: Building a Robust Iterative Optimization System + +Our framework introduces a structured, multi-agent approach that separates concerns into a four-stage pipeline for the initial code generation, followed by a robust, two-step process for iterative reflection and repair. + +#### 1. The Four-Agent Pipeline for Initial Generation + +Instead of relying on a single agent, we decompose the complex task of kernel generation into a "chain of thought" executed by four specialized agents: +1. **Agent 1: The Analyst**: Receives the problem description and performs a deep analysis of the requirements, constraints, and potential challenges. +2. **Agent 2: The Baseline Implementer**: Takes the Analyst's report and generates a simple, functionally correct baseline implementation of the kernel. This serves as a solid foundation for optimization. +3. **Agent 3: The Strategist**: Analyzes the baseline code and proposes a list of concrete, high-level optimization strategies (e.g., "increase block size," "apply autotuning"). +4. **Agent 4: The Executor**: Takes one strategy at a time and applies it to the baseline code, generating the final, optimized kernel for the first iteration. + +This structured pipeline ensures a high-quality initial code generation, setting the stage for effective iterative refinement. + +#### 2. The "Code Goalkeeper": Automated Signature Correction + +**Problem:** A primary reason for evaluation failures was the LLM's tendency to generate code with incorrect function signatures (e.g., wrong function name, incorrect parameter order). Constraining the model via prompts proved to be unreliable. + +**Solution:** We introduced a "Code Goalkeeper," a deterministic post-processing step that runs immediately after code generation. This mechanism uses Python's Abstract Syntax Tree (`ast`) module to: +1. Parse the generated code and the baseline code provided by an earlier agent. +2. Perform a **safety check**: It verifies that the parameters of the generated function are a superset of the baseline function's parameters. This prevents catastrophic failures if the core logic has been fundamentally altered. +3. If the check passes, it programmatically replaces the signature of the generated function with the "golden" signature from the baseline code. + +**Impact:** This completely decouples the LLM's creative optimization task from the rigid, mechanical task of format compliance. It dramatically increases the `Call Status: True` rate, allowing the iterative process to focus on deeper logic and performance issues. + +#### 3. The "Expert Diagnostician": A Two-Step Reflection and Repair Process + +**Problem:** The original reflection process was inefficient. It would feed a raw error log back to the model, which often struggled to identify the root cause, leading to repeated, ineffective repair attempts. + +**Solution:** We re-architected the reflection phase into a structured, two-step "Diagnose-and-Repair" workflow: +1. **Step 1: Expert Diagnosis:** When a test fails, we no longer use a generic reflection prompt. Instead, we use a specialized **"Expert Diagnostician" prompt**. This prompt guides the LLM to act like a senior GPU kernel engineer. It analyzes the specific error type (`Runtime Error`, `Correctness Error`, `Poor Performance`) and the traceback to produce a concise, high-level **Correction Plan**. +2. **Step 2: Guided Repair:** This newly generated Correction Plan is then passed to a dedicated **"Code Repair" prompt**. Instead of grappling with a raw error log, this agent receives clear, expert-level instructions, enabling it to perform a much more precise and effective code fix. + +**Impact:** This structured process transforms the reflection loop from a vague, trial-and-error cycle into a focused, expert-driven debugging session. It significantly improves the agent's ability to recover from "second-layer" failures (logic and performance bugs) after the Goalkeeper has handled the initial formatting issues. + ### the Optimizer We provide previous generated codes as reference codes with their corresponding performance to the Optimizer. The number of reference codes is controlled by the arg `ancestor_num`. The reference codes are arranged in ascending order to help the Optimizer LLM find the optimization direction. We don't ask the LLM to generate new codes directly from the reference codes, instead we ask the Optimizer to analyze the reference codes first and to generate a promising strategy to optimize the code. Then we feed the generated optimization stratgey to the Generator to generate new codes.