Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Binary file modified src/__pycache__/args_config.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/Base.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/Reflexion.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc
Binary file not shown.
80 changes: 40 additions & 40 deletions src/agents/reflexion_oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,58 +78,55 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps",
def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0):
data_len = datalen if datalen else len(self.dataset)
for iter in range(iteration_num):
# Filter only failed kernels for this iteration (correctness check)
failed_memories = [mem for mem in self.memories[:data_len] if not mem.pass_call]

if not failed_memories:
logger.info(f"\n=== All kernels passed, stopping at iteration {iter} ===")
break

logger.info(f"\n=== Iteration {iter} ===")
logger.info(f"Processing {len(failed_memories)} failed kernels out of {data_len} total")

if output_path is not None:
root, extension = os.path.splitext(output_path)
iter_path = f"{root}_{iter}{extension}"

if multi_thread:
thread_num = 3

# generate solution
logger.info(f"\ngenerate solution")
with tqdm(total=data_len) as pbar:
# generate solution for failed kernels only
logger.info(f"\ngenerate solution for failed kernels")
with tqdm(total=len(failed_memories)) as pbar:
if multi_thread:

with ThreadPoolExecutor(max_workers=thread_num) as executor:
futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in self.memories[:data_len]}
futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in failed_memories}
for future in as_completed(futures):
pbar.update(1)
else:
for mem in self.memories[:data_len]:
for mem in failed_memories:
self.generate_solution(mem, temperature=temperature)
pbar.update(1)

"""
Run the scripts to verify whether the generated kernels can execute without errors.
To check for correctness against expected outputs, use the test_opt_correctness method from TritonBench:

if self.config.agent.output_path is not None:
root, extension = os.path.splitext(self.config.agent.output_path)
tmp_dir = f"{root}_tmp_{n}"
exe_dir = f"{root}_pass_exe_{n}"
perf_result_dir = f"{root}_perf_results_{n}"
perf_log_dir = f"{root}_perf_logs_{n}"

else:
tmp_dir = f"tmp_{n}"
exe_dir = f"pass_exe_{n}"
perf_result_dir = f"perf_results_{n}"
perf_log_dir = f"perf_logs_{n}"

for fn, mems in tqdm(current_memories.items()):
mem = mems[n]
try:
pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.code, mem.ps.filename, tmp_dir, exe_dir=exe_dir)

"""
logger.info(f"\nrun scripts on gpu")
for mem in tqdm(self.memories[:data_len]):
if mem.pass_call:
logger.info(f"\nrun correctness tests on gpu")
for mem in tqdm(failed_memories):
try:
pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(
mem.ps.solution, mem.ps.filename, "temp", exe_dir="pass_exe"
)
except Exception as e:
logger.info(f"failed to test the code due to : {e}")
mem.err_msg = f"failed to test the code due to: {e}"
continue
is_pass, err_msg = self.dataset.run_single_call(mem.ps)
if not is_pass:
mem.err_msg = err_msg

if not pass_call:
mem.err_msg = call_stderr
elif not pass_exe:
mem.err_msg = exe_stderr
else:
# Both call and execution passed - mark as successful
mem.pass_call = True
mem.err_msg = None # Clear previous error
"""
To measure kernel latency, follow these steps:

Expand All @@ -152,16 +149,17 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,

"""

# generate reflections
logger.info(f"\ngenerate reflections")
with tqdm(total=data_len) as pbar:
# generate reflections for failed kernels only
logger.info(f"\ngenerate reflections for failed kernels")
still_failed_memories = [mem for mem in failed_memories if not mem.pass_call]
with tqdm(total=len(still_failed_memories)) as pbar:
if multi_thread:
with ThreadPoolExecutor(max_workers=thread_num) as executor:
futures = {executor.submit(self.generate_reflexion, mem, temperature): mem for mem in self.memories[:data_len]}
futures = {executor.submit(self.generate_reflexion, mem, temperature): mem for mem in still_failed_memories}
for future in as_completed(futures):
pbar.update(1)
else:
for mem in self.memories[:data_len]:
for mem in still_failed_memories:
self.generate_reflexion(mem, temperature=temperature)
pbar.update(1)

Expand All @@ -172,6 +170,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,

def generate_solution(self, mem, temperature=0):
if mem.pass_call:
logger.debug(f"Skipping {mem.ps.filename} - already passed")
return

# tab = "\n"
Expand Down Expand Up @@ -208,6 +207,7 @@ def generate_solution(self, mem, temperature=0):

def generate_reflexion(self, mem, temperature):
if mem.pass_call:
logger.debug(f"Skipping reflection for {mem.ps.filename} - already passed")
return
reflect_txt = prompt_for_reflection.prompt.format(
problem=mem.ps.instruction,
Expand Down
40 changes: 23 additions & 17 deletions src/agents/reflexion_oneshot_ROCm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,29 @@ class Memory(metaclass=MemoryClassMeta, field_names=["ps",
def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0):
data_len = datalen if datalen else len(self.dataset)
for iter in range(iteration_num):
# Filter only failed kernels for this iteration
failed_memories = [mem for mem in self.memories[:data_len] if not mem.pass_exe]

if not failed_memories:
logger.info(f"\n=== All kernels passed, stopping at iteration {iter} ===")
break

logger.info(f"\n=== Iteration {iter} ===")
if output_path is not None:
root, extension = os.path.splitext(output_path)
iter_path = f"{root}_{iter}{extension}"
logger.info(f"Processing {len(failed_memories)} failed kernels out of {data_len} total")

if multi_thread:
thread_num = 3

# generate solution
logger.info(f"\ngenerate solution")
with tqdm(total=data_len) as pbar:
# generate solution for failed kernels only
logger.info(f"\ngenerate solution for failed kernels")
with tqdm(total=len(failed_memories)) as pbar:
if multi_thread:

with ThreadPoolExecutor(max_workers=thread_num) as executor:
futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in self.memories[:data_len]}
futures = {executor.submit(self.generate_solution, mem, temperature): mem for mem in failed_memories}
for future in as_completed(futures):
pbar.update(1)
else:
for mem in self.memories[:data_len]:
for mem in failed_memories:
self.generate_solution(mem, temperature=temperature)
pbar.update(1)

Expand All @@ -115,9 +119,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
exe_dir = f"{root}_pass_exe"
perf_result_dir = f"{root}_perf_results"

for mem in tqdm(self.memories[:data_len]):
if mem.pass_exe:
continue
for mem in tqdm(failed_memories):
try:
pass_call, pass_exe, call_stdout, call_stderr, exe_stdout, exe_stderr = self.dataset.test_opt_correctness(mem.ps.solution, mem.ps.filename, tmp_dir, exe_dir=exe_dir)

Expand All @@ -131,6 +133,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
mem.err_msg = exe_stderr
else:
mem.pass_exe = True
mem.err_msg = None # Clear previous error
"""
To measure kernel speedup, follow these steps:

Expand Down Expand Up @@ -161,16 +164,17 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
"""


# generate reflections
logger.info(f"\ngenerate reflections")
with tqdm(total=data_len) as pbar:
# generate reflections for failed kernels only
logger.info(f"\ngenerate reflections for failed kernels")
still_failed_memories = [mem for mem in failed_memories if not mem.pass_exe]
with tqdm(total=len(still_failed_memories)) as pbar:
if multi_thread:
with ThreadPoolExecutor(max_workers=thread_num) as executor:
futures = {executor.submit(self.generate_reflexion, mem, temperature): mem for mem in self.memories[:data_len]}
futures = {executor.submit(self.generate_reflexion, mem, temperature): mem for mem in still_failed_memories}
for future in as_completed(futures):
pbar.update(1)
else:
for mem in self.memories[:data_len]:
for mem in still_failed_memories:
self.generate_reflexion(mem, temperature=temperature)
pbar.update(1)

Expand All @@ -181,6 +185,7 @@ def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,

def generate_solution(self, mem, temperature=0):
if mem.pass_exe:
logger.debug(f"Skipping {mem.ps.filename} - already passed")
return

tab = "\n"
Expand Down Expand Up @@ -217,6 +222,7 @@ def generate_solution(self, mem, temperature=0):

def generate_reflexion(self, mem, temperature):
if mem.pass_exe:
logger.debug(f"Skipping reflection for {mem.ps.filename} - already passed")
return
reflect_txt = prompt_for_reflection.prompt.format(
problem=mem.ps.instruction,
Expand Down
5 changes: 3 additions & 2 deletions src/configs/tritonbench_oneshot_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# LLM model
api_key: ""
api_key: "wisemodel-lpjwbkzybasaizealiwx"
# api_key: "wisemodel-vzelpgxleuvotybtfeqh"
model_id: "Kimi-K2-Instruct"
temperature: 1.0
temperature: 1

# TritonBench
statis_path: "/hackathon-agent/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json"
Expand Down
Binary file modified src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc
Binary file not shown.
Binary file modified src/dataloaders/__pycache__/ProblemState.cpython-312.pyc
Binary file not shown.
Binary file modified src/dataloaders/__pycache__/TritonBench.cpython-312.pyc
Binary file not shown.
143 changes: 143 additions & 0 deletions src/good/flash_decode2_phi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import torch
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_ob, stride_oh, stride_od, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr):
    # Stage-2 reduction of flash decoding: one program instance per
    # (batch, head) pair merges the per-sequence-block partial outputs in
    # Mid_O, weighting each block by its log-exp-sum statistic from
    # Mid_O_LogExpSum, using a numerically stable streaming-softmax update.
    # Mid_O is addressed as [batch, head, block, dim] via the stride args.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    # Number of stage-1 blocks that cover this batch element's sequence
    # (ceiling division of its sequence length by BLOCK_SEQ).
    seq_len = tl.load(B_Seqlen + cur_batch)
    block_n_size = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ
    # Running softmax state: normalizer, running max logit, and the
    # max-rescaled weighted accumulator over the head dimension.
    sum_exp = 0.0
    max_logic = -float('inf')
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    for block_id in range(0, block_n_size):
        # Partial output vector produced by stage 1 for this block.
        ptr_tv = Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_id * stride_mid_os + offs_d * stride_mid_od
        tv = tl.load(ptr_tv)
        # Scalar log-exp-sum statistic for the same block.
        ptr_tlogic = Mid_O_LogExpSum + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + block_id * stride_mid_o_es
        tlogic = tl.load(ptr_tlogic)
        # Standard online-softmax merge: rescale the old state to the new
        # running max before adding this block's contribution. On the first
        # iteration max_prev is -inf, so exp(max_prev - max_logic) == 0 and
        # the state is simply initialized from this block.
        max_prev = max_logic
        max_logic = tl.maximum(max_prev, tlogic)
        sum_exp = sum_exp * tl.exp(max_prev - max_logic) + tl.exp(tlogic - max_logic)
        acc = acc * tl.exp(max_prev - max_logic) + tv * tl.exp(tlogic - max_logic)
    # Final normalization; the 1e-06 epsilon guards against a zero
    # normalizer (e.g. when block_n_size is 0 and the loop never runs).
    result = acc / (sum_exp + 1e-06)
    ptr_out = Out + cur_batch * stride_ob + cur_head * stride_oh + offs_d * stride_od
    tl.store(ptr_out, result.to(ptr_out.dtype.element_ty))

@torch.no_grad()
def flash_decode_stage2(Mid_O: torch.Tensor, Mid_O_LogExpSum: torch.Tensor, B_Seqlen: torch.Tensor, Out: torch.Tensor, block_seq: int):
    """Launch the stage-2 flash-decode reduction kernel.

    Combines the per-block partial outputs ``Mid_O`` (shape
    ``[batch, head, seq_blocks, head_dim]``) with their log-exp-sum
    statistics ``Mid_O_LogExpSum`` and writes the merged result into
    ``Out`` in place. One kernel program is launched per (batch, head).

    ``block_seq`` is the stage-1 sequence-block size used to derive how
    many blocks cover each sequence in ``B_Seqlen``. Returns ``None``.
    """
    num_batch, num_head, _num_seq_blocks, head_dim = Mid_O.shape
    launch_grid = (num_batch, num_head)
    _fwd_kernel_flash_decode_stage2[launch_grid](
        B_Seqlen,
        Mid_O,
        Mid_O_LogExpSum,
        Out,
        Mid_O.stride(0), Mid_O.stride(1), Mid_O.stride(2), Mid_O.stride(3),
        Mid_O_LogExpSum.stride(0), Mid_O_LogExpSum.stride(1), Mid_O_LogExpSum.stride(2),
        Out.stride(0), Out.stride(1), Out.stride(2),
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=head_dim,
        num_warps=4,
        num_stages=2,
    )
    return

##################################################################################################################################################





import torch



# Define the test function

def test_flash_decode_stage2():
    """Smoke-test flash_decode_stage2 over several ``block_seq`` values.

    The four original test cases were near-identical 10-line dict literals
    that differed only in the ``block_seq`` argument; they are deduplicated
    into a variant table driving one construction loop. Fresh random inputs
    are drawn per case, matching the original behavior.

    Returns:
        dict: test-case name -> ``Out`` tensor filled in place by the kernel.
    """
    batch_size = 2
    head_num = 4
    seq_block_num = 3
    head_dim = 64
    block_seq = 16

    # The cases differ only in the block size handed to the kernel.
    block_seq_variants = {
        "test_case_1": block_seq,
        "test_case_2": block_seq + 1,   # non-power-of-two block size
        "test_case_3": block_seq // 2,  # smaller block size
        "test_case_4": block_seq * 2,   # larger block size
    }

    results = {}
    for name, case_block_seq in block_seq_variants.items():
        # Random sequence lengths in [1, seq_block_num * block_seq).
        B_Seqlen = torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda')
        mid_out = torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda')
        mid_out_logexpsum = torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda')
        out = torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda')

        # The kernel writes its merged result into `out` in place.
        flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, case_block_seq)
        results[name] = out

    return results



# Runs the CUDA smoke test at import time and keeps the kernel outputs as
# the reference ("gold") results. NOTE(review): this executes on module
# import and requires a CUDA device — confirm that is intentional.
result_gold = test_flash_decode_stage2()
Loading