Commit
Merge pull request #597 from google:mor--inference
PiperOrigin-RevId: 626410509
maxtext authors committed Apr 19, 2024
2 parents 6ec7556 + b46783c commit 0e1c078
Showing 3 changed files with 79 additions and 46 deletions.
5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
@@ -264,3 +264,8 @@ vertex_tensorboard_region: ""

# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
max_checkify: False

# Inference
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
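
These keys replace constants that were previously hard-coded in inference_microbenchmark.py. A minimal sketch of how main() in this diff consumes them (the literal strings here are just the defaults above):

    prefill_lengths = [int(l) for l in "64,128,256,512,1024".split(",")]  # -> [64, 128, 256, 512, 1024]
    stages_to_benchmark = "prefill,generate".split(",")                   # -> ["prefill", "generate"]
    benchmark_loop_iters = 10                                             # passed as iters= to both benchmarks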
93 changes: 47 additions & 46 deletions MaxText/inference_microbenchmark.py
@@ -21,26 +21,13 @@
import sys

from jetstream.engine import token_utils

import max_utils
import maxengine
import maxtext_utils
import pyconfig


def summarize_pytree_data(params, name="Params"):
  """Generate basic metrics of a given Pytree."""
  num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params)
  num_params_in_billions = num_params / 1e9
  total_param_size_in_gb = total_param_size / 1e9
  print(
      f"{name} stats: \n"
      f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
      f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
      f"\tAvg size: {avg_param_size:.3f} bytes\n"
  )
  return num_params, total_param_size, avg_param_size


def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
"""Inner loop for benchmarking prefill step."""
max_utils.activate_profiler(config, profile_name)
@@ -49,24 +36,29 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
    slot = int(i % (jax.device_count() * config.per_device_batch_size))
    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
    max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)
  end = datetime.datetime.now()
  max_utils.deactivate_profiler(config)
  return (end - start).total_seconds(), decode_state
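# Editorial sketch (not part of this commit): the total returned above is
# typically reduced to the per-call "prefill_time_in_ms" figure reported by
# print_results_for_analyze below.
#   time_in_s, decode_state = prefill_benchmark_loop(
#       config, engine, decode_state, params, tokens, true_length, iters=10)
#   prefill_time_in_ms = time_in_s / 10 * 1000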


def prefill_benchmark(
    config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
):
  """Handles init, warmup, running prefill benchmark, and printing results."""
  if num_model_params is None:
    num_model_params, _, _ = summarize_pytree_data(params, name="Params")
    num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")

  # Warmup: run one prefill/insert pair to trigger compilation and report result
  # stats, then a second pair so the timed loop measures steady-state behavior.
  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  jax.block_until_ready(decode_state)
  max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
  max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
  max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
  max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
  max_utils.delete_pytree(prefill_result)
  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)

  print(f"Prefill results for length {tokens.size}:\n")
@@ -106,9 +98,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""):
def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
"""Handles init, warmup, running ar benchmark, and printing results."""
if cache_size is None:
_, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
if model_size is None:
_, model_size, _ = summarize_pytree_data(params, name="Params")
_, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
global_batch_size = jax.device_count() * config.per_device_batch_size

# Warmup
@@ -165,43 +157,52 @@ def write_results(results, filename=""):


def print_results_for_analyze(results):
  prefill_bucket_size_to_ms = {}
  for k, v in results["Prefill"].items():
    prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
  print("\nFor usage in analyze_sharegpt.py :")
  print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
  print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
  if "Prefill" in results:
    prefill_bucket_size_to_ms = {}
    for k, v in results["Prefill"].items():
      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
    print("\nFor usage in analyze_sharegpt.py :")
    print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

  if "AutoRegressive" in results:
    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


def main(config):
  engine = maxengine.MaxEngine(config)
  params = engine.load_params()
  prefill_lengths = [64, 128, 256, 512, 1024]
  benchmark_loop_iters = 10
  prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
  stages_to_benchmark = config.inference_microbenchmark_stages.split(",")
  benchmark_loop_iters = config.inference_microbenchmark_loop_iters

  text = config.prompt
  metadata = engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()
  _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = summarize_pytree_data(params, name="Model")

benchmark_results = {"Prefill": {}}
benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size
)
for prefill_length in prefill_lengths:
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length])
benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
config,
engine,
params,
decode_state,
tokens,
true_length,
iters=benchmark_loop_iters,
num_model_params=num_model_params,
)
_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

  benchmark_results = {}
  if "prefill" in stages_to_benchmark:
    benchmark_results["Prefill"] = {}
    for prefill_length in prefill_lengths:
      tokens, true_length = token_utils.tokenize_and_pad(
          text, vocab, is_bos=True, prefill_lengths=[prefill_length])
      benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
          config,
          engine,
          params,
          decode_state,
          tokens,
          true_length,
          iters=benchmark_loop_iters,
          num_model_params=num_model_params
      )

  if "generate" in stages_to_benchmark:
    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
        config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)

  results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
  write_results(results, filename="")
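The remainder of inference_microbenchmark.py is collapsed in this view. As a minimal sketch (an assumption, based on the pyconfig pattern used by other MaxText scripts; the sample flags are likewise assumptions), the entry point would look like:

    if __name__ == "__main__":
      # e.g.: python MaxText/inference_microbenchmark.py MaxText/configs/base.yml \
      #       inference_microbenchmark_stages=prefill inference_microbenchmark_loop_iters=5
      pyconfig.initialize(sys.argv)
      main(pyconfig.config)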
27 changes: 27 additions & 0 deletions MaxText/max_utils.py
@@ -652,3 +652,30 @@ def get_project():
    max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project <project>'")
    return None
  return project_outputs[-1]


def delete_pytree(p):
  """Delete every jax.Array leaf of a pytree, freeing its device memory."""

  def delete_leaf(leaf):
    if isinstance(leaf, jax.Array):
      leaf.delete()
    del leaf

  jax.tree_map(delete_leaf, p)
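# Note (editorial): jax.Array.delete() releases the backing device buffers
# immediately instead of waiting for Python garbage collection, which is why
# the benchmark above deletes each prefill_result once it has been inserted
# into decode_state.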


def summarize_pytree_data(params, name="Params", raw=False):
"""Generate basic metrics of a given Pytree."""
num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params)
if not raw:
num_params_in_billions = num_params / 1e9
total_param_size_in_gb = total_param_size / 1e9
print(f"{name} stats: \n"
f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
f"\tAvg size: {avg_param_size:.3f} bytes\n")
else:
print(f"{name} stats: \n"
f"\tTotal number of params: {num_params:.3f} \n"
f"\tTotal memory usage: {total_param_size:.3f} bytes \n"
f"\tAvg size: {avg_param_size:.3f} bytes\n")
return num_params, total_param_size, avg_param_size
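
For reference, the two reporting modes as exercised by the call sites above:

    num_model_params, size_bytes, avg_bytes = summarize_pytree_data(params)               # prints billions of params / GB
    summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)  # prints raw counts / bytes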
