Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions reflexion_oneshot_tritonbench_2.json

Large diffs are not rendered by default.

Binary file modified src/__pycache__/args_config.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/Base.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/Reflexion.cpython-312.pyc
Binary file not shown.
Binary file modified src/agents/__pycache__/reflexion_oneshot.cpython-312.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion src/configs/tritonbench_oneshot_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# LLM model
api_key: ""
model_id: "Kimi-K2-Instruct"
temperature: 1.0
temperature: 0.1

# TritonBench
statis_path: "/hackathon-agent/src/dataloaders/TB_eval/data/TritonBench_G_comp_alpac_v1_hackathon.json"
Expand Down
Binary file modified src/dataloaders/TB_eval/__pycache__/utils.cpython-312.pyc
Binary file not shown.

Large diffs are not rendered by default.

Binary file modified src/dataloaders/__pycache__/ProblemState.cpython-312.pyc
Binary file not shown.
Binary file modified src/dataloaders/__pycache__/TritonBench.cpython-312.pyc
Binary file not shown.
Binary file modified src/memories/__pycache__/Memory.cpython-312.pyc
Binary file not shown.
Binary file modified src/models/__pycache__/Base.cpython-312.pyc
Binary file not shown.
Binary file modified src/models/__pycache__/KimiK2.cpython-312.pyc
Binary file not shown.
Binary file modified src/prompts/__pycache__/prompt_for_generation.cpython-312.pyc
Binary file not shown.
Binary file modified src/prompts/__pycache__/prompt_for_reflection.cpython-312.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion src/prompts/prompt_for_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
* **`tl.arange`:** Arguments `start` and `end` **must be `tl.constexpr`**.
* **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`.
8. **Triton Version:** Assume Triton version 3.1.0 or later.

9. If the input is float or double, convert to bf16 precision for the calculation; if the input is int, use int8 precision for the calculation.
**FINAL VERIFICATION:**
Before completing, verify:
1. ALL functions defined in the code have EXACT signatures matching the required function signatures above.
Expand Down
4 changes: 4 additions & 0 deletions src/prompts/prompt_for_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
**Important Instructions:**
- Think before writing the reflection and no more explanation is required after the reflection.
- You should not suggest changes to the name of the function.
- Check all variable names carefully; in particular, do not write `Mid_O` when the variable is named `MID_O` — the case must match exactly.
- generate the reflection wrapped in a code block with the tag `reflection`, e.g.
"```markdown<your reflections>```"

Expand Down Expand Up @@ -47,6 +48,7 @@
**Important Instructions:**
- Think before writing the reflection and no more explanation is required after the reflection.
- You should not suggest changes to the name of the function.
- Check all variable names carefully; in particular, do not write `Mid_O` when the variable is named `MID_O` — the case must match exactly.
- generate the reflection wrapped in a code block with the tag `reflection`, e.g.
"```markdown<your reflections>```"

Expand Down Expand Up @@ -100,6 +102,7 @@
**Important Instructions:**
- Think before writing the reflection and no more explanation is required after the reflection.
- You should not suggest changes to the name of the function.
- Check all variable names carefully; in particular, do not write `Mid_O` when the variable is named `MID_O` — the case must match exactly.
- generate the reflection wrapped in a code block with the tag `reflection`, e.g.
"```markdown<your reflections>```"

Expand Down Expand Up @@ -239,6 +242,7 @@ def grid(args: dict[str, Any]) -> tuple[int]:
**Important Instructions:**
- Think before writing the optimization and no more explanation is required after the reflection.
- You should not suggest changes to the name of the function and parameter names, counts, or order.
- Check all variable names carefully; in particular, do not write `Mid_O` when the variable is named `MID_O` — the case must match exactly.
- generate the reflection wrapped in a code block with the tag `reflection`, e.g.
"```markdown<your reflections>```"

Expand Down
Binary file modified src/retrievers/__pycache__/retriever.cpython-312.pyc
Binary file not shown.
131 changes: 131 additions & 0 deletions src/temp/embedding_triton_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import torch
import triton
import triton.language as tl


@triton.jit
def embedding_kernel(
    tokens_ptr,          # int32*/int64*  token ids, one row per batch element
    out_ptr,             # weight.dtype*  output embeddings (bsz, seq, dim)
    weight_ptr,          # weight.dtype*  embedding table (vocab, dim)
    seq_len,
    vocab_size,
    n_dim,
    stride_tokens,
    stride_out_n,
    stride_out_d,
    stride_weight_vocab,
    stride_weight_dim,
    BLOCK_N: tl.constexpr,
    BLOCK_NN: tl.constexpr,
):
    """Gather rows of the embedding table: out[b, s, :] = weight[tokens[b, s], :].

    One program per batch row; the sequence dimension is swept in chunks of
    BLOCK_NN positions per iteration.
    """
    pid_seq = tl.program_id(0)  # one program per batch element
    # NOTE(review): tl.arange requires constexpr bounds, but n_dim is a runtime
    # argument here — confirm the Triton frontend specializes it, or pass the
    # embedding dim as a tl.constexpr meta-parameter instead.
    offs_d = tl.arange(0, n_dim)

    for block_start in range(0, seq_len, BLOCK_NN):
        cur_block_size = tl.minimum(BLOCK_NN, seq_len - block_start)
        block_token_offs = pid_seq * stride_tokens + block_start + tl.arange(0, BLOCK_N)
        # lanes past the current chunk read a harmless token id of 0
        mask_n = tl.arange(0, BLOCK_N) < cur_block_size
        block_tokens = tl.load(tokens_ptr + block_token_offs, mask=mask_n, other=0)

        offs_n = block_start + tl.arange(0, BLOCK_N)[:, None]  # [BLOCK_N, 1]
        # gather addresses into the weight table: [BLOCK_N, n_dim]
        offs_w = block_tokens[:, None] * stride_weight_vocab + offs_d[None, :] * stride_weight_dim

        # BUG FIX: offs_n is already 2-D ([BLOCK_N, 1]); the original applied
        # an extra [:, None] to (offs_n < seq_len), producing a 3-D mask that
        # cannot broadcast against the 2-D load/store tensors.
        mask_2d = (offs_n < seq_len) & (offs_d[None, :] < n_dim)
        w_vec = tl.load(weight_ptr + offs_w, mask=mask_2d)

        # assumes the output's last dimension is contiguous (stride 1) — holds
        # for the tensor allocated in the host wrapper below; TODO confirm for
        # any other caller.
        offs_out = pid_seq * stride_out_n + offs_n * stride_out_d + offs_d[None, :]
        tl.store(out_ptr + offs_out, w_vec, mask=mask_2d)


def embedding(tokens: torch.Tensor,
              weight: torch.Tensor,
              vob_start_id: int = None,
              vob_end_id: int = None,
              out: torch.Tensor = None) -> torch.Tensor:
    """Embedding lookup via the Triton kernel: result[..., s, :] = weight[tokens[..., s], :].

    Args:
        tokens: integer token ids, shape (batch, seq) or (seq,); a 1-D input is
            treated as a batch of one (this matches the test harness below,
            which passes (seq,) ids and an (seq, dim) output buffer).
        weight: embedding table, shape (vocab_size, n_dim).
        vob_start_id, vob_end_id: optional valid-id range, accepted for API
            compatibility with callers that pass it; ids are assumed to already
            lie in range, so the values are not used for masking here.
        out: optional pre-allocated output tensor; when given, the result is
            copied into it in place and `out` is returned.

    Returns:
        The embedding tensor, shape tokens.shape + (n_dim,), in weight's dtype.
    """
    squeeze_batch = tokens.dim() == 1
    if squeeze_batch:
        # BUG FIX / generalization: the original hard-required 2-D tokens, but
        # the gold test calls with 1-D ids — promote to a batch of one.
        tokens = tokens.unsqueeze(0)
    assert tokens.dim() == 2, "Expected tokens shape (batch, seq)"
    assert tokens.dtype in (torch.int32, torch.int64), "tokens must be int32 or int64"
    bsz, seq_len = tokens.shape
    vocab_size, n_dim = weight.shape
    output = torch.empty((bsz, seq_len, n_dim), dtype=weight.dtype, device=weight.device)

    BLOCK_N = 64
    BLOCK_NN = BLOCK_N  # chunk step equals tile width; the kernel relies on this
    grid = (bsz,)       # one program per batch row
    embedding_kernel[grid](
        tokens,
        output,
        weight,
        seq_len,
        vocab_size,
        n_dim,
        tokens.stride(0),
        output.stride(0),
        output.stride(1),
        weight.stride(0),
        weight.stride(1),
        BLOCK_N=BLOCK_N,
        BLOCK_NN=BLOCK_NN,
    )
    if squeeze_batch:
        output = output.squeeze(0)
    if out is not None:
        # in-place variant used by the test harness
        out.copy_(output.view_as(out))
        return out
    return output
##################################################################################################################################################



import torch

def test_embedding():
    """Exercise `embedding` on several configurations and collect outputs.

    Runs on CUDA only. NOTE(review): every call passes five positional
    arguments (input_ids, weight, vob_start_id, vob_end_id, out) with 1-D
    input_ids — confirm the `embedding` wrapper defined above accepts this
    signature.
    """
    # parameter definitions
    vocab_size = 1000        # vocabulary size
    embedding_dim = 512      # embedding dimension
    sequence_length = 128    # input sequence length
    vob_start_id = 10        # first valid vocabulary id
    vob_end_id = 1000        # one past the last valid vocabulary id

    # create the test input tensors
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    weight = torch.randn(
        vocab_size, embedding_dim, dtype=torch.float32, device='cuda'
    )
    out = torch.zeros(
        sequence_length, embedding_dim, dtype=torch.float32, device='cuda'
    )

    # call the embedding function (writes into `out` in place)
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)

    # save the result
    results = {}
    results['test_case_1'] = out.clone()

    # test with a different random input
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_2'] = out.clone()

    # test a different vocabulary-id range
    vob_start_id = 0
    vob_end_id = 500
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_3'] = out.clone()

    # test a different embedding dimension
    embedding_dim = 256
    weight = torch.randn(
        vocab_size, embedding_dim, dtype=torch.float32, device='cuda'
    )
    out = torch.zeros(
        sequence_length, embedding_dim, dtype=torch.float32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_4'] = out.clone()

    return results

# Module-level side effect: runs the whole test suite at import time (needs CUDA).
result_gold = test_embedding()
147 changes: 147 additions & 0 deletions src/temp/flash_decode2_phi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@

import torch
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel_flash_decode_stage2(
    B_Seqlen,          # int*   (batch,) per-batch sequence lengths
    Mid_O,             # float* (batch, head, seq_block, head_dim) stage-1 partial outputs
    Mid_O_LogExpSum,   # float* (batch, head, seq_block) stage-1 per-block log-sum-exp logits
    Out,               # (batch, head, head_dim) final merged output
    stride_mid_o_b,
    stride_mid_o_h,
    stride_mid_o_s,
    stride_mid_o_d,
    stride_mid_lse_b,
    stride_mid_lse_h,
    stride_mid_lse_s,
    stride_out_b,
    stride_out_h,
    stride_out_d,
    BLOCK_SEQ: tl.constexpr,     # sequence positions covered by each stage-1 block
    BLOCK_DMODEL: tl.constexpr,  # head dim; must be a power of two for tl.arange
):
    """Stage-2 flash-decoding reduction: merge per-block partial results.

    One program per (batch, head). Walks the stage-1 blocks keeping a running
    maximum logit, rescaling the accumulator and the exponential sum whenever
    a larger logit appears (numerically stable online-softmax merge). The
    statement order in the loop is load-bearing — do not reorder.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_seq_len = tl.load(B_Seqlen + cur_batch)
    # number of stage-1 blocks for this sequence (ceil division)
    block_n_size = (cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ

    sum_exp = tl.full([], 0.0, dtype=tl.float32)              # running sum of exp(logit - max)
    max_logic = tl.full([], -float("inf"), dtype=tl.float32)  # running maximum logit
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)          # weighted partial-output accumulator

    for block_id in range(0, block_n_size):
        offs_d = tl.arange(0, BLOCK_DMODEL)
        # partial output vector of this block: Mid_O[batch, head, block_id, :]
        ptr_mid = Mid_O + cur_batch * stride_mid_o_b + cur_head * stride_mid_o_h + block_id * stride_mid_o_s + offs_d * stride_mid_o_d
        tv = tl.load(ptr_mid).to(tl.float32)

        # scalar log-sum-exp logit of this block
        ptr_lse = Mid_O_LogExpSum + cur_batch * stride_mid_lse_b + cur_head * stride_mid_lse_h + block_id * stride_mid_lse_s
        tlogic = tl.load(ptr_lse).to(tl.float32)

        # online-softmax update: rescale old state to the new running max,
        # then fold in this block's contribution
        new_max = tl.maximum(max_logic, tlogic)
        scale = tl.exp(max_logic - new_max)
        acc = acc * scale
        sum_exp = sum_exp * scale
        exp_di = tl.exp(tlogic - new_max)
        sum_exp += exp_di
        acc += tv * exp_di
        max_logic = new_max

    # final normalization; sum_exp > 0 whenever block_n_size >= 1
    acc_norm = acc / sum_exp

    offs_out_d = tl.arange(0, BLOCK_DMODEL)
    ptr_out = Out + cur_batch * stride_out_b + cur_head * stride_out_h + offs_out_d * stride_out_d
    tl.store(ptr_out, acc_norm.to(Out.type.element_ty))


@torch.no_grad()
def flash_decode_stage2(
    Mid_O: torch.Tensor,            # (batch, head, seq_block_num, head_dim) stage-1 partial outputs
    Mid_O_LogExpSum: torch.Tensor,  # (batch, head, seq_block_num) stage-1 log-sum-exp logits
    B_Seqlen: torch.Tensor,         # (batch,) integer sequence lengths
    Out: torch.Tensor,              # (batch, head, head_dim) final output, written in place
    BLOCK_SEQ: int,                 # sequence positions covered by each stage-1 block
):
    """Launch the stage-2 flash-decoding reduction kernel.

    Merges the per-block partial attention results in Mid_O/Mid_O_LogExpSum
    into Out, one kernel program per (batch, head).

    BUG FIX: parameters reordered to (Mid_O, Mid_O_LogExpSum, B_Seqlen, Out,
    BLOCK_SEQ) to match the call sites in this file; the previous order
    (B_Seqlen first) made callers pass the float Mid_O tensor where the kernel
    expects the integer B_Seqlen pointer, producing garbage block counts.
    """
    BLOCK_DMODEL = Mid_O.shape[-1]  # head dim; must be a power of two for tl.arange
    batch = B_Seqlen.shape[0]
    head_num = Mid_O.shape[1]
    grid = (batch, head_num)  # one program per (batch, head)
    _fwd_kernel_flash_decode_stage2[grid](
        B_Seqlen,
        Mid_O,
        Mid_O_LogExpSum,
        Out,
        Mid_O.stride(0),
        Mid_O.stride(1),
        Mid_O.stride(2),
        Mid_O.stride(3),
        Mid_O_LogExpSum.stride(0),
        Mid_O_LogExpSum.stride(1),
        Mid_O_LogExpSum.stride(2),
        Out.stride(0),
        Out.stride(1),
        Out.stride(2),
        BLOCK_SEQ=BLOCK_SEQ,
        BLOCK_DMODEL=BLOCK_DMODEL,
        num_warps=4,
        num_stages=1,
    )

##################################################################################################################################################



import torch

# Define the test function
def test_flash_decode_stage2():
    """Run flash_decode_stage2 over several block sizes and collect outputs.

    Runs on CUDA only. NOTE(review): arguments are passed positionally as
    (mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq) — confirm this
    matches the parameter order of the flash_decode_stage2 wrapper above.
    """
    # Define the parameters for different test cases
    batch_size = 2
    head_num = 4
    seq_block_num = 3
    head_dim = 64
    block_seq = 16

    test_cases = {
        "test_case_1": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq
        },
        "test_case_2": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq + 1  # Different block size
        },
        "test_case_3": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq // 2  # Different block size
        },
        "test_case_4": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq * 2  # Different block size
        }
    }

    # Execute the function for all test cases
    results = {}
    for key, test_case in test_cases.items():
        # the wrapper writes its result into test_case["Out"] in place
        flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"])
        results[key] = test_case["Out"]

    return results

# Run the test
# Module-level side effect: executes the test suite at import time (needs CUDA).
result_gold = test_flash_decode_stage2()
Loading