Allow inference microbenchmark to time prefill only #610

Merged (1 commit) on Apr 22, 2024
MaxText/configs/base.yml: 1 change (1 addition & 0 deletions)
@@ -269,3 +269,4 @@ max_checkify: False
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
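
With this new key, the existing inference_microbenchmark_* settings can be combined to time prefill only and persist the results. An illustrative base.yml override (the log path below is a hypothetical value, not part of this change; the default "" leaves file output disabled):

inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill"  # drop "generate" to benchmark prefill only
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: "/tmp/inference_microbenchmark.json"  # hypothetical path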
MaxText/inference_microbenchmark.py: 197 changes (133 additions & 64 deletions)
@@ -28,50 +28,36 @@
import pyconfig


def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
_WARMUP_ITERS = 2


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
for i in range(iters):
slot = int(i % (jax.device_count() * config.per_device_batch_size))
for _ in range(iters):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=slot)
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)
jax.block_until_ready(prefill_result)
end = datetime.datetime.now()
max_utils.deactivate_profiler(config)
return (end - start).total_seconds(), decode_state
max_utils.delete_pytree(prefill_result)
return (end - start).total_seconds()


def prefill_benchmark(
config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
config, engine, params, tokens, true_length, num_model_params, iters
):
"""Handles init, warmup, running prefill benchmark, and printing results."""
if num_model_params is None:
num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")

prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=0)
max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
max_utils.delete_pytree(prefill_result)
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=0)
"""Handles warmup, running prefill benchmark, and printing results."""
for _ in range(_WARMUP_ITERS):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
jax.block_until_ready(prefill_result)
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)

print(f"Prefill results for length {tokens.size}:\n")

profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name
time_in_s, decode_state = prefill_benchmark_loop(
config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name
)
print(f"Prefill benchmark results for length {tokens.size}:\n")
time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters)
prefill_average_ms = 1000 * time_in_s / iters
prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
print(
f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n"
f"\tPrefill step average time: {prefill_average_ms:.3f} ms\n"
f"\tPrefill total TFLOPs/device: {prefill_tflops_per_device:.3f}\n"
f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
)
@@ -80,10 +66,50 @@ def prefill_benchmark(
"prefill_total_tflops_per_device": prefill_tflops_per_device,
"prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
}
return result_dict


def prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name
):
"""Inner loop for benchmarking prefill and insert step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
for i in range(iters):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)
end = datetime.datetime.now()
max_utils.deactivate_profiler(config)
return (end - start).total_seconds(), decode_state


def prefill_insert_benchmark(
config, engine, decode_state, params, total_slots, tokens, true_length, iters
):
"""Handles warmup, running insert benchmark, and printing results."""

for i in range(_WARMUP_ITERS):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)

print(f"Prefill and insert benchmark results for length {tokens.size}:\n")
time_in_s, decode_state = prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}")
prefill_insert_average_ms = time_in_s / iters * 1000.0
print(
f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
)
result_dict = {
"prefill_insert_time_in_ms": prefill_insert_average_ms
}
return result_dict, decode_state


def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""):
def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name):
"""Inner loop for benchmarking ar step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
@@ -95,32 +121,23 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
return (end - start).total_seconds(), decode_state


def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
"""Handles init, warmup, running ar benchmark, and printing results."""
if cache_size is None:
_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
if model_size is None:
_, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
global_batch_size = jax.device_count() * config.per_device_batch_size

# Warmup
decode_state, _ = engine.generate(params, decode_state)
jax.block_until_ready(decode_state)
decode_state, _ = engine.generate(params, decode_state)
def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
"""Handles warmup, running ar benchmark, and printing results."""
for _ in range(_WARMUP_ITERS):
decode_state, _ = engine.generate(params, decode_state)
jax.block_until_ready(decode_state)

profile_name = "autoregress" if profile_name == "" else profile_name
time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters)
time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
seconds_per_step = time_in_s / iters
ar_average_ms = seconds_per_step * 1000
total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step
total_throughput = global_batch_size / seconds_per_step

GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count()
bw_per_device = GB_per_step_per_device / seconds_per_step
print(
f"AutoRegressive results:\n"
f"\tAR step average time: {ar_average_ms:.3f}ms\n"
f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n"
f"\tAR step average time: {ar_average_ms:.3f} ms\n"
f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f} ms\n"
f"\tAR global batch size: {global_batch_size}\n"
f"\tAR throughput: {total_throughput:.3f} tokens/second\n"
f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n"
@@ -150,24 +167,54 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
return results


def write_results(results, filename=""):
def write_results(results, filename):
if filename != "":
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)


def print_results_for_analyze(results):
"""Print results."""
print("\nFor usage in analyze_sharegpt.py :")

if "Prefill" in results:
prefill_bucket_size_to_ms = {}
for k, v in results["Prefill"].items():
prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
print("\nFor usage in analyze_sharegpt.py :")
print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

if "Prefill_Insert" in results:
insert_bucket_size_to_ms = {}
for k, v in results["Prefill_Insert"].items():
insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")

if "AutoRegressive" in results:
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


def summarize_prefill_result(engine, params, tokens, true_length):
"""Summarize Prefill result."""
print(f"Prefill result of length {tokens.size}:\n")
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
jax.block_until_ready(prefill_result)
num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = (
max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
)
num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = (
max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
)
max_utils.delete_pytree(prefill_result)
return {
"num_prefill_logits_params": num_prefill_logits_params,
"total_prefill_logits_size": total_prefill_logits_size,
"avg_prefill_logits_param_size": avg_prefill_logits_param_size,
"num_prefill_cache_params": num_prefill_cache_params,
"total_prefill_cache_size": total_prefill_cache_size,
"avg_prefill_cache_param_size": avg_prefill_cache_param_size,
}


def main(config):
engine = maxengine.MaxEngine(config)
params = engine.load_params()
@@ -185,27 +232,49 @@ def main(config):

benchmark_results = {}
if "prefill" in stages_to_benchmark:

benchmark_results["Prefill_Result"] = {}
benchmark_results["Prefill"] = {}
benchmark_results["Prefill_Insert"] = {}
prefill_tokens = {}
prefill_true_lengths = {}

for prefill_length in prefill_lengths:
tokens, true_length = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length])
benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
config,
engine,
params,
decode_state,
tokens,
true_length,
iters=benchmark_loop_iters,
num_model_params=num_model_params
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
)
benchmark_results["Prefill_Result"]["prefill_length"] = summarize_prefill_result(
engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
)

for prefill_length in prefill_lengths:
benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
config,
engine,
params,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
num_model_params,
benchmark_loop_iters
)

benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
config,
engine,
decode_state,
params,
engine.max_concurrent_decodes,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
benchmark_loop_iters
)

if "generate" in stages_to_benchmark:
benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
write_results(results, filename="")
write_results(results, filename=config.inference_microbenchmark_log_file_path)
print_results_for_analyze(results)
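
Not part of the diff, but for context: write_results now dumps the collated results dict as JSON whenever inference_microbenchmark_log_file_path is non-empty, so a downstream script could load the prefill timings back along these lines (a minimal sketch; the path is hypothetical and must match the configured log file path):

import json

with open("/tmp/inference_microbenchmark.json", "r", encoding="utf-8") as f:
  results = json.load(f)

# Rebuild the same mapping that print_results_for_analyze prints,
# keyed by prefill bucket size in tokens.
prefill_bucket_size_to_ms = {
    int(length): round(metrics["prefill_time_in_ms"], 3)
    for length, metrics in results.get("Prefill", {}).items()
}
print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")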

