Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 124 additions & 79 deletions ucm/store/test/e2e/nfsstore_embed_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# SOFTWARE.
#
import csv
import math
import os
import secrets
import time
Expand All @@ -35,7 +36,7 @@


def setup(
storage_backends, block_size, device_id, io_size, transferStreamNumber
storage_backends, block_size, device_id, io_size, transferStreamNumber, useDirect
) -> UcmKVStoreBase:
config = {
"storage_backends": storage_backends,
Expand All @@ -44,19 +45,41 @@ def setup(
"device": device_id,
"io_size": io_size,
"transferStreamNumber": transferStreamNumber,
"useDirect": useDirect,
}
return UcmNfsStore(config)


def make_aligned_tensor(shape, dtype, device, alignment=4096):
    """Allocate a tensor of `shape`/`dtype` on `device` whose data pointer
    is aligned to `alignment` bytes (required for O_DIRECT / GDS transfers).

    A byte buffer over-allocated by `alignment` bytes is created, then a view
    starting at the first aligned address inside it is reinterpreted as
    `dtype` and reshaped to `shape`.

    Args:
        shape: sequence of ints, the desired tensor shape.
        dtype: torch dtype of the returned tensor.
        device: torch device (string or object) for the allocation.
        alignment: required byte alignment of the result's data_ptr().

    Returns:
        Tensor of the requested shape/dtype with data_ptr() % alignment == 0.
    """
    numel = math.prod(shape)
    dtype_size = torch.tensor(1, dtype=dtype).element_size()
    total_bytes = numel * dtype_size

    # Over-allocate by one full `alignment` so an aligned start always exists.
    # Allocate directly on the target device (the original used the deprecated
    # torch.ByteTensor on CPU and then copied with .to(device)).
    padded_bytes = total_bytes + alignment
    storage = torch.empty(padded_bytes, dtype=torch.uint8, device=device)

    ptr = storage.data_ptr()
    offset = ptr % alignment
    skip = 0 if offset == 0 else alignment - offset

    # NOTE(review): `skip` is a multiple of the allocator's base alignment,
    # which in practice is >= dtype_size, so the reinterpreting view below is
    # valid — confirm for exotic dtypes/allocators.
    aligned_storage = storage[skip:].view(dtype)
    tensor = aligned_storage[:numel].view(shape)
    # Belt-and-braces keep-alive of the backing buffer (the view already
    # references the same storage, but callers may rely on this attribute).
    tensor.storage_ref = storage
    return tensor


def make_buffers(
    block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
):
    """Create random block hashes plus per-layer 4 KiB-aligned KV-cache tensors.

    Args:
        block_number: number of KV blocks to allocate per layer.
        device_id: CUDA device index the caches live on.
        batch_size: unused here; kept for call-site compatibility.
        head_dim: per-head hidden dimension.
        block_len: tokens per block.
        block_layer: number of transformer layers (one cache tensor each).
        num_head: number of KV heads.
        kv: leading dimension (1 for MLA, 2 for separate K and V).

    Returns:
        (hashes, kv_caches): a list of `block_number` random hex ids and a
        dict mapping layer index -> float16 tensor of shape
        [kv, block_number, block_len, num_head, head_dim].
    """
    hashes = [secrets.token_hex(16) for _ in range(block_number)]
    kv_caches = {}
    for layer in range(block_layer):
        # Aligned allocation so the store can use direct I/O on these buffers.
        kv_caches[layer] = make_aligned_tensor(
            [kv, block_number, block_len, num_head, head_dim],
            dtype=torch.float16,
            device=f"cuda:{device_id}",
        )
    return hashes, kv_caches
Expand All @@ -69,6 +92,14 @@ def store_all_hashes(hashes: List[str]):
f.write(h + "\n")


def load_hashes_from_file() -> List[str]:
    """Read back the block hashes persisted next to this script.

    Returns:
        The stripped lines of kvcache_block_hashes.txt, or an empty list
        when the file has not been written yet.
    """
    path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as fp:
        return [entry.strip() for entry in fp]


def embed(
store: UcmKVStoreBase,
hashes: List[str],
Expand Down Expand Up @@ -177,6 +208,8 @@ def run(
block_elem_size: int,
kv: int,
mla: bool,
useDirect: bool,
operation_mode: str = "both", # "write_only", "read_only", or "both"
) -> Tuple[float, float, float, float, float, float]:
"""
Run a single test with given parameters and return performance metrics.
Expand All @@ -196,87 +229,99 @@ def run(
w_size_sum, r_size_sum = 0.0, 0.0

store = setup(
storage_backends, block_size, device_id, io_size, transferStreamNumber
storage_backends,
block_size,
device_id,
io_size,
transferStreamNumber,
useDirect,
)

for r in range(repeat):
print(f"\n--- Round {r+1} ---")

hashes, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

results = store.create(hashes[:batch_size])
assert sum(results) == 0, "Create operation failed"

w_size, w_time, w_bw = embed(
store,
hashes[:batch_size],
kvcaches,
mla,
)
store.commit(hashes[:batch_size], True)

store_all_hashes(hashes[:batch_size])

r_size, r_time, r_bw = fetch(
store,
hashes[:batch_size],
kvcaches,
mla,
)

w_bw_list.append(w_bw)
r_bw_list.append(r_bw)
w_time_list.append(w_time)
r_time_list.append(r_time)
w_size_sum += w_size
r_size_sum += r_size

# Clean up resources
del kvcaches, hashes
torch.cuda.empty_cache()
if operation_mode in ["write_only", "both"]:
hashes, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

results = store.create(hashes[:batch_size])
assert sum(results) == 0, "Create operation failed"

w_size, w_time, w_bw = embed(
store,
hashes[:batch_size],
kvcaches,
mla,
)
store.commit(hashes[:batch_size], True)

if r == 0:
store_all_hashes(hashes[:batch_size])

w_bw_list.append(w_bw)
w_time_list.append(w_time)
w_size_sum += w_size

if operation_mode == "write_only":
del kvcaches, hashes
torch.cuda.empty_cache()

if operation_mode in ["read_only", "both"]:
if operation_mode == "read_only":
saved_hashes = load_hashes_from_file()
if not saved_hashes:
raise RuntimeError("No saved hashes found for read operation")

_, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

r_size, r_time, r_bw = fetch(
store,
saved_hashes[:batch_size],
kvcaches,
mla,
)
else:
r_size, r_time, r_bw = fetch(
store,
hashes[:batch_size],
kvcaches,
mla,
)

r_bw_list.append(r_bw)
r_time_list.append(r_time)
r_size_sum += r_size

if operation_mode == "read_only":
del kvcaches
torch.cuda.empty_cache()
else:
del kvcaches, hashes
torch.cuda.empty_cache()

del store
avg_w_bw = sum(w_bw_list) / repeat
avg_r_bw = sum(r_bw_list) / repeat
avg_w_time = sum(w_time_list) / repeat
avg_r_time = sum(r_time_list) / repeat
avg_w_size = w_size_sum / (1024**3) / repeat
avg_r_size = r_size_sum / (1024**3) / repeat
avg_w_bw = sum(w_bw_list) / len(w_bw_list) if w_bw_list else 0.0
avg_r_bw = sum(r_bw_list) / len(r_bw_list) if r_bw_list else 0.0
avg_w_time = sum(w_time_list) / len(w_time_list) if w_time_list else 0.0
avg_r_time = sum(r_time_list) / len(r_time_list) if r_time_list else 0.0
avg_w_size = w_size_sum / (1024**3) / len(w_time_list) if w_time_list else 0.0
avg_r_size = r_size_sum / (1024**3) / len(r_time_list) if r_time_list else 0.0

return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size


if __name__ == "__main__":
    # Enable debug-level logging from the UCM store for this e2e run.
    os.environ["UC_LOGGER_LEVEL"] = "debug"

    try:
        result = run(
            storage_backends="/home/nfs/zht_data",
            device_id=1,
            repeat=1,
            num_head=1,
            block_len=128,
            transferStreamNumber=32,
            num_tokens=4096,
            block_layer=61,
            head_size=576,
            block_elem_size=2,
            kv=1,
            mla=True,
            # Bug fix: run() requires `useDirect` (no default). Omitting it
            # raised a TypeError that the broad except below silently printed,
            # so the test never actually ran.
            useDirect=True,
        )

        avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result

    except Exception as e:
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
41 changes: 37 additions & 4 deletions ucm/store/test/e2e/nfsstore_embed_fetch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,41 @@ def run_wrapper(result_queue, *args):
result_queue.put(("error", str(e)))


def get_user_input(prompt, default=None):
    """Prompt on stdin; fall back to `default` when the user just hits enter.

    Args:
        prompt: text shown to the user (default value appended when given).
        default: value returned for an empty answer; None means no default.

    Returns:
        The stripped user input, or `default` when one was given and the
        answer was empty.
    """
    if default is None:
        return input(f"{prompt}: ").strip()
    answer = input(f"{prompt} (default: {default}): ").strip()
    return answer if answer else default


def main():
storage_backends = "."
device_id = 1
mla = False
repeat = 3
num_tokens_list = [2048, 4096, 8192, 16384, 32768]
transferStreamNumbers = [32, 64, 128]

print("1. Model Selection:")
print(" 1 - QwQ-32B")
print(" 2 - deepseek-v3")
model_choice = get_user_input("Please select model", "1")
mla = True if model_choice == "2" else False

print("\n2. GDS Transfer:")
print(" 1 - Enable GDS (default)")
print(" 2 - Disable GDS")
useDirect = get_user_input("Please select Direct IO mode", "1")
useDirect = False if useDirect == "2" else True

print("\n3. Operation Mode:")
print(" 1 - Read and Write Test (default)")
print(" 2 - Write Only Test")
print(" 3 - Read Only Test")
op_choice = get_user_input("Please select operation mode", "1")
operation_mode_map = {"1": "both", "2": "write_only", "3": "read_only"}
operation_mode = operation_mode_map.get(op_choice, "both")

if mla:
block_lens = [64, 128]
block_layer = 61
Expand All @@ -64,9 +91,11 @@ def main():
num_head_list = [1, 2, 4, 8]

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_result.csv")
csv_file = os.path.join(SCRIPT_DIR, "embed_fetch_all.csv")
need_header = not os.path.exists(csv_file)

os.makedirs(SCRIPT_DIR, exist_ok=True)

with open(csv_file, "a", newline="", encoding="utf-8") as csv_fp:
writer = csv.writer(csv_fp)

Expand Down Expand Up @@ -107,7 +136,6 @@ def main():
batch_size = int(num_tokens / block_len)
io_num = int(num_tokens / block_len * block_layer)

# Run test and get results
result_queue = multiprocessing.Queue()

process = multiprocessing.Process(
Expand All @@ -126,6 +154,8 @@ def main():
block_elem_size,
kv,
mla,
useDirect,
operation_mode,
),
)

Expand Down Expand Up @@ -165,9 +195,12 @@ def main():
f"{avg_r_bw:.4f}",
]
)

csv_fp.flush()

print(
f"WRITE COMPLETE for num_head={num_head}, num_tokens={num_tokens}"
)

print("\n" + "=" * 60 + "\n= All combinations tested =\n" + "=" * 60 + "\n")


Expand Down