Skip to content

Commit

Permalink
Merge pull request #610 from google:mor--inference
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627082389
  • Loading branch information
maxtext authors committed Apr 22, 2024
2 parents 25adb3d + c2e5b5e commit c0bef1c
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 64 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,4 @@ max_checkify: False
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
197 changes: 133 additions & 64 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,36 @@
import pyconfig


def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
_WARMUP_ITERS = 2


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
for i in range(iters):
slot = int(i % (jax.device_count() * config.per_device_batch_size))
for _ in range(iters):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=slot)
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)
jax.block_until_ready(prefill_result)
end = datetime.datetime.now()
max_utils.deactivate_profiler(config)
return (end - start).total_seconds(), decode_state
max_utils.delete_pytree(prefill_result)
return (end - start).total_seconds()


def prefill_benchmark(
config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
config, engine, params, tokens, true_length, num_model_params, iters
):
"""Handles init, warmup, running prefill benchmark, and printing results."""
if num_model_params is None:
num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")

prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=0)
max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
max_utils.delete_pytree(prefill_result)
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, slot=0)
"""Handles warmup, running prefill benchmark, and printing results."""
for _ in range(_WARMUP_ITERS):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
jax.block_until_ready(prefill_result)
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)

print(f"Prefill results for length {tokens.size}:\n")

profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name
time_in_s, decode_state = prefill_benchmark_loop(
config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name
)
print(f"Prefill benchmark results for length {tokens.size}:\n")
time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters)
prefill_average_ms = 1000 * time_in_s / iters
prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
print(
f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n"
f"\tPrefill step average time: {prefill_average_ms:.3f} ms\n"
f"\tPrefill total TFLOPs/device: {prefill_tflops_per_device:.3f}\n"
f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
)
Expand All @@ -80,10 +66,50 @@ def prefill_benchmark(
"prefill_total_tflops_per_device": prefill_tflops_per_device,
"prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
}
return result_dict


def prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name
):
"""Inner loop for benchmarking prefill and insert step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
for i in range(iters):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)
end = datetime.datetime.now()
max_utils.deactivate_profiler(config)
return (end - start).total_seconds(), decode_state


def prefill_insert_benchmark(
config, engine, decode_state, params, total_slots, tokens, true_length, iters
):
"""Handles warmup, running insert benchmark, and printing results."""

for i in range(_WARMUP_ITERS):
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
max_utils.delete_pytree(prefill_result)
jax.block_until_ready(decode_state)

print(f"Prefill and insert benchmark results for length {tokens.size}:\n")
time_in_s, decode_state = prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}")
prefill_insert_average_ms = time_in_s / iters * 1000.0
print(
f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
)
result_dict = {
"prefill_insert_time_in_ms": prefill_insert_average_ms
}
return result_dict, decode_state


def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""):
def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name):
"""Inner loop for benchmarking ar step."""
max_utils.activate_profiler(config, profile_name)
start = datetime.datetime.now()
Expand All @@ -95,32 +121,23 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
return (end - start).total_seconds(), decode_state


def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
"""Handles init, warmup, running ar benchmark, and printing results."""
if cache_size is None:
_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
if model_size is None:
_, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
global_batch_size = jax.device_count() * config.per_device_batch_size

# Warmup
decode_state, _ = engine.generate(params, decode_state)
jax.block_until_ready(decode_state)
decode_state, _ = engine.generate(params, decode_state)
def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
"""Handles warmup, running ar benchmark, and printing results."""
for _ in range(_WARMUP_ITERS):
decode_state, _ = engine.generate(params, decode_state)
jax.block_until_ready(decode_state)

profile_name = "autoregress" if profile_name == "" else profile_name
time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters)
time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
seconds_per_step = time_in_s / iters
ar_average_ms = seconds_per_step * 1000
total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step
total_throughput = global_batch_size / seconds_per_step

GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count()
bw_per_device = GB_per_step_per_device / seconds_per_step
print(
f"AutoRegressive results:\n"
f"\tAR step average time: {ar_average_ms:.3f}ms\n"
f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n"
f"\tAR step average time: {ar_average_ms:.3f} ms\n"
f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f} ms\n"
f"\tAR global batch size: {global_batch_size}\n"
f"\tAR throughput: {total_throughput:.3f} tokens/second\n"
f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n"
Expand Down Expand Up @@ -150,24 +167,54 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
return results


def write_results(results, filename=""):
def write_results(results, filename):
if filename != "":
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)


def print_results_for_analyze(results):
"""Print results."""
print("\nFor usage in analyze_sharegpt.py :")

if "Prefill" in results:
prefill_bucket_size_to_ms = {}
for k, v in results["Prefill"].items():
prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
print("\nFor usage in analyze_sharegpt.py :")
print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

if "Prefill_Insert" in results:
insert_bucket_size_to_ms = {}
for k, v in results["Prefill_Insert"].items():
insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")

if "AutoRegressive" in results:
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


def summarize_prefill_result(engine, params, tokens, true_length):
"""Summarize Prefill result."""
print(f"Prefill result of length {tokens.size}:\n")
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
jax.block_until_ready(prefill_result)
num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = (
max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
)
num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = (
max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
)
max_utils.delete_pytree(prefill_result)
return {
"num_prefill_logits_params": num_prefill_logits_params,
"total_prefill_logits_size": total_prefill_logits_size,
"avg_prefill_logits_param_size": avg_prefill_logits_param_size,
"num_prefill_cache_params": num_prefill_cache_params,
"total_prefill_cache_size": total_prefill_cache_size,
"avg_prefill_cache_param_size": avg_prefill_cache_param_size,
}


def main(config):
engine = maxengine.MaxEngine(config)
params = engine.load_params()
Expand All @@ -185,27 +232,49 @@ def main(config):

benchmark_results = {}
if "prefill" in stages_to_benchmark:

benchmark_results["Prefill_Result"] = {}
benchmark_results["Prefill"] = {}
benchmark_results["Prefill_Insert"] = {}
prefill_tokens = {}
prefill_true_lengths = {}

for prefill_length in prefill_lengths:
tokens, true_length = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length])
benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
config,
engine,
params,
decode_state,
tokens,
true_length,
iters=benchmark_loop_iters,
num_model_params=num_model_params
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
)
benchmark_results["Prefill_Result"]["prefill_length"] = summarize_prefill_result(
engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
)

for prefill_length in prefill_lengths:
benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
config,
engine,
params,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
num_model_params,
benchmark_loop_iters
)

benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
config,
engine,
decode_state,
params,
engine.max_concurrent_decodes,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
benchmark_loop_iters
)

if "generate" in stages_to_benchmark:
benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
write_results(results, filename="")
write_results(results, filename=config.inference_microbenchmark_log_file_path)
print_results_for_analyze(results)


Expand Down

0 comments on commit c0bef1c

Please sign in to comment.