diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index df3ece6bb..c5a8f9098 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -269,3 +269,4 @@ max_checkify: False
 inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
 inference_microbenchmark_stages: "prefill,generate"
 inference_microbenchmark_loop_iters: 10
+inference_microbenchmark_log_file_path: ""
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index fde947bb5..3e04c6c6e 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -28,50 +28,36 @@ import pyconfig


-def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
+_WARMUP_ITERS = 2
+
+
+def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
   """Inner loop for benchmarking prefill step."""
-  max_utils.activate_profiler(config, profile_name)
   start = datetime.datetime.now()
-  for i in range(iters):
-    slot = int(i % (jax.device_count() * config.per_device_batch_size))
+  for _ in range(iters):
     prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
-    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
-    max_utils.delete_pytree(prefill_result)
-    jax.block_until_ready(decode_state)
+    jax.block_until_ready(prefill_result)
   end = datetime.datetime.now()
-  max_utils.deactivate_profiler(config)
-  return (end - start).total_seconds(), decode_state
+  max_utils.delete_pytree(prefill_result)
+  return (end - start).total_seconds()
+

 def prefill_benchmark(
-  config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
+  config, engine, params, tokens, true_length, num_model_params, iters
 ):
-  """Handles init, warmup, running prefill benchmark, and printing results."""
-  if num_model_params is None:
-    num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")
-
-  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
-  decode_state = engine.insert(prefill_result, decode_state, slot=0)
-  max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
-  max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
-  max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
-  max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
-  max_utils.delete_pytree(prefill_result)
-  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
-  decode_state = engine.insert(prefill_result, decode_state, slot=0)
+  """Handles warmup, running prefill benchmark, and printing results."""
+  for _ in range(_WARMUP_ITERS):
+    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
+    jax.block_until_ready(prefill_result)
   max_utils.delete_pytree(prefill_result)
-  jax.block_until_ready(decode_state)
-
-  print(f"Prefill results for length {tokens.size}:\n")
-  profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name
-  time_in_s, decode_state = prefill_benchmark_loop(
-    config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name
-  )
+  print(f"Prefill benchmark results for length {tokens.size}:\n")
+  time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters)
   prefill_average_ms = 1000 * time_in_s / iters
   prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
   tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
   print(
-    f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n"
+    f"\tPrefill step average time: {prefill_average_ms:.3f} ms\n"
     f"\tPrefill total TFLOPs/device: {prefill_tflops_per_device:.3f}\n"
     f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
   )
@@ -80,10 +66,50 @@ def prefill_benchmark(
     "prefill_total_tflops_per_device": prefill_tflops_per_device,
     "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
   }
+  return result_dict
+
+
+def prefill_insert_benchmark_loop(
+    config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name
+  ):
+  """Inner loop for benchmarking prefill and insert step."""
+  max_utils.activate_profiler(config, profile_name)
+  start = datetime.datetime.now()
+  for i in range(iters):
+    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
+    decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
+    max_utils.delete_pytree(prefill_result)
+    jax.block_until_ready(decode_state)
+  end = datetime.datetime.now()
+  max_utils.deactivate_profiler(config)
+  return (end - start).total_seconds(), decode_state
+
+
+def prefill_insert_benchmark(
+    config, engine, decode_state, params, total_slots, tokens, true_length, iters
+  ):
+  """Handles warmup, running insert benchmark, and printing results."""
+
+  for i in range(_WARMUP_ITERS):
+    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
+    decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
+    max_utils.delete_pytree(prefill_result)
+  jax.block_until_ready(decode_state)
+
+  print(f"Prefill and insert benchmark results for length {tokens.size}:\n")
+  time_in_s, decode_state = prefill_insert_benchmark_loop(
+    config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}")
+  prefill_insert_average_ms = time_in_s / iters * 1000.0
+  print(
+    f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
+  )
+  result_dict = {
+    "prefill_insert_time_in_ms": prefill_insert_average_ms
+  }
   return result_dict, decode_state


-def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""):
+def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name):
   """Inner loop for benchmarking ar step."""
   max_utils.activate_profiler(config, profile_name)
   start = datetime.datetime.now()
@@ -95,32 +121,23 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
   return (end - start).total_seconds(), decode_state


-def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
-  """Handles init, warmup, running ar benchmark, and printing results."""
-  if cache_size is None:
-    _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
-  if model_size is None:
-    _, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
-  global_batch_size = jax.device_count() * config.per_device_batch_size
-
-  # Warmup
-  decode_state, _ = engine.generate(params, decode_state)
-  jax.block_until_ready(decode_state)
-  decode_state, _ = engine.generate(params, decode_state)
+def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
+  """Handles warmup, running ar benchmark, and printing results."""
+  for _ in range(_WARMUP_ITERS):
+    decode_state, _ = engine.generate(params, decode_state)
   jax.block_until_ready(decode_state)

-  profile_name = "autoregress" if profile_name == "" else profile_name
-  time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters)
+  time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
   seconds_per_step = time_in_s / iters
   ar_average_ms = seconds_per_step * 1000
-  total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step
+  total_throughput = global_batch_size / seconds_per_step
   GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count()
   bw_per_device = GB_per_step_per_device / seconds_per_step
   print(
     f"AutoRegressive results:\n"
-    f"\tAR step average time: {ar_average_ms:.3f}ms\n"
-    f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n"
+    f"\tAR step average time: {ar_average_ms:.3f} ms\n"
+    f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f} ms\n"
     f"\tAR global batch size: {global_batch_size}\n"
     f"\tAR throughput: {total_throughput:.3f} tokens/second\n"
     f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n"
   )
@@ -150,24 +167,54 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
   return results


-def write_results(results, filename=""):
+def write_results(results, filename):
   if filename != "":
     with open(filename, "w", encoding="utf-8") as f:
       json.dump(results, f, indent=2)


 def print_results_for_analyze(results):
+  """Print results."""
+  print("\nFor usage in analyze_sharegpt.py :")
+  if "Prefill" in results:
     prefill_bucket_size_to_ms = {}
     for k, v in results["Prefill"].items():
       prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
-  print("\nFor usage in analyze_sharegpt.py :")
     print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
+  if "Prefill_Insert" in results:
+    insert_bucket_size_to_ms = {}
+    for k, v in results["Prefill_Insert"].items():
+      insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
+    print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")
+  if "AutoRegressive" in results:
     print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


+def summarize_prefill_result(engine, params, tokens, true_length):
+  """Summarize Prefill result."""
+  print(f"Prefill result of length {tokens.size}:\n")
+  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
+  jax.block_until_ready(prefill_result)
+  num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = (
+    max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
+  )
+  num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = (
+    max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
+  )
+  max_utils.delete_pytree(prefill_result)
+  return {
+    "num_prefill_logits_params": num_prefill_logits_params,
+    "total_prefill_logits_size": total_prefill_logits_size,
+    "avg_prefill_logits_param_size": avg_prefill_logits_param_size,
+    "num_prefill_cache_params": num_prefill_cache_params,
+    "total_prefill_cache_size": total_prefill_cache_size,
+    "avg_prefill_cache_param_size": avg_prefill_cache_param_size,
+  }
+
+
 def main(config):
   engine = maxengine.MaxEngine(config)
   params = engine.load_params()
@@ -185,27 +232,49 @@ def main(config):

   benchmark_results = {}
   if "prefill" in stages_to_benchmark:
+
+    benchmark_results["Prefill_Result"] = {}
     benchmark_results["Prefill"] = {}
+    benchmark_results["Prefill_Insert"] = {}
+    prefill_tokens = {}
+    prefill_true_lengths = {}
+
     for prefill_length in prefill_lengths:
-      tokens, true_length = token_utils.tokenize_and_pad(
-        text, vocab, is_bos=True, prefill_lengths=[prefill_length])
-      benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
-        config,
-        engine,
-        params,
-        decode_state,
-        tokens,
-        true_length,
-        iters=benchmark_loop_iters,
-        num_model_params=num_model_params
+      prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
+        text, vocab, is_bos=True, prefill_lengths=[prefill_length]
+      )
+      benchmark_results["Prefill_Result"][prefill_length] = summarize_prefill_result(
+        engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
+      )
+
+    for prefill_length in prefill_lengths:
+      benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
+        config,
+        engine,
+        params,
+        prefill_tokens[prefill_length],
+        prefill_true_lengths[prefill_length],
+        num_model_params,
+        benchmark_loop_iters
+      )
+
+      benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
+        config,
+        engine,
+        decode_state,
+        params,
+        engine.max_concurrent_decodes,
+        prefill_tokens[prefill_length],
+        prefill_true_lengths[prefill_length],
+        benchmark_loop_iters
      )

   if "generate" in stages_to_benchmark:
     benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
-      config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
+      config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

   results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
-  write_results(results, filename="")
+  write_results(results, filename=config.inference_microbenchmark_log_file_path)
   print_results_for_analyze(results)
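Note (outside the diff): with the new inference_microbenchmark_log_file_path config entry, write_results dumps the collated results to a JSON file in addition to the printed summary. The snippet below is a minimal sketch of consuming that file, assuming the collated dictionary keeps the keys read by print_results_for_analyze above ("Prefill", "Prefill_Insert", "AutoRegressive"); the file path is illustrative only, not part of this change.

import json

# Illustrative path; use whatever inference_microbenchmark_log_file_path was set to.
with open("/tmp/microbenchmark_results.json", "r", encoding="utf-8") as f:
  results = json.load(f)

# Rebuild the analyze_sharegpt.py-style dictionaries, mirroring print_results_for_analyze.
prefill_bucket_size_to_ms = {int(k): round(v["prefill_time_in_ms"], 3) for k, v in results.get("Prefill", {}).items()}
insert_bucket_size_to_ms = {int(k): round(v["prefill_insert_time_in_ms"], 3) for k, v in results.get("Prefill_Insert", {}).items()}
print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")
if "AutoRegressive" in results:
  print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")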