Corrected latency statistics when batchsize is greater than 1 #336

Merged

llm_bench/python/benchmark.py (15 changes: 9 additions & 6 deletions)
@@ -132,11 +132,11 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
         result_md5_list.append(hashlib.md5(result_text.encode()).hexdigest())
     if num == 0:
         warmup_md5[prompt_index] = result_md5_list
-    per_token_time = generation_time * 1000 / num_tokens
+    per_token_time = generation_time * 1000 / (num_tokens / args['batch_size'])
     iter_data = gen_iterate_data(
         num,
         input_token_size * args['batch_size'],
-        max_output_token_size * args['batch_size'],
+        max_output_token_size,
         num_tokens,
         generation_time,
         per_token_time,
@@ -157,7 +157,8 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
         warm_up=(num == 0),
         max_rss_mem=max_rss_mem_consumption,
         max_shared_mem=max_shared_mem_consumption,
-        tokenization_time=(tok_encode_time, tok_decode_time)
+        tokenization_time=(tok_encode_time, tok_decode_time),
+        batch_size=args['batch_size']
     )
     if num > 0:
         warmup_md5_list = warmup_md5[prompt_index]
@@ -183,6 +184,7 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters

     # if num_iters == 0, just output warm-up data
     proc_id = os.getpid()
+    prompt_idx_list = [prompt_idx for prompt_idx, input_text in enumerate(input_text_list)]
     if args['subsequent'] is False:
         for num in range(num_iters + 1):
             for prompt_idx, input_text in enumerate(input_text_list):
@@ -196,7 +198,7 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters
                     log.info(f'[warm-up] Input text: {input_text}')
                 run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, warmup_md5, prompt_idx, bench_hook, model_precision, proc_id)

-    utils.metrics_print.print_average(iter_data_list)
+    utils.metrics_print.print_average(iter_data_list, prompt_idx_list, args['batch_size'], True)
     return iter_data_list, pretrain_time


@@ -277,6 +279,7 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter

     # if num_iters == 0, just output warm-up data
     proc_id = os.getpid()
+    prompt_idx_list = [image_id for image_id, image_param in enumerate(input_image_list)]
     if args['subsequent'] is False:
         for num in range(num_iters + 1):
             for image_id, image_param in enumerate(input_image_list):
@@ -286,7 +289,7 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter
             for num in range(num_iters + 1):
                 run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, proc_id)

-    utils.metrics_print.print_average(iter_data_list)
+    utils.metrics_print.print_average(iter_data_list, prompt_idx_list, args['batch_size'], False)
     return iter_data_list, pretrain_time


@@ -397,7 +400,7 @@ def run_ldm_super_resolution_benchmark(model_path, framework, device, args, num_
             run_ldm_super_resolution(img, num, pipe, args, framework, iter_data_list, image_id, tm_list, proc_id)
         tm_list.clear()
         image_id = image_id + 1
-    utils.metrics_print.print_average(iter_data_list)
+    utils.metrics_print.print_average(iter_data_list, [], 0, False)

     return iter_data_list, pretrain_time

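The core of the benchmark.py change is the per-token latency formula. With batching, one decoding step of `model.generate()` emits `batch_size` tokens at once, so dividing the generation time by the total token count across the whole batch understates the real step latency; dividing by the tokens generated per sequence reports milliseconds per step instead. The callers also forward `batch_size` (and the prompt index list) to the reporting helpers so the logs can label the unit accordingly. A minimal sketch with made-up numbers, not part of this PR, just to illustrate the difference:

```python
# Hypothetical values, only to show why the divisor changes when batch_size > 1.
generation_time = 2.0                # seconds spent in the whole generate() call
batch_size = 4
num_tokens = 128 * batch_size        # total tokens produced across all sequences in the batch

# Old formula: averages over every token in the batch, so a larger batch
# makes the reported latency look smaller than it really is.
old_latency = generation_time * 1000 / num_tokens                   # ~3.91 ms/token

# Corrected formula: averages over tokens per sequence, i.e. per decoding
# step, where one step emits batch_size tokens at once.
new_latency = generation_time * 1000 / (num_tokens / batch_size)    # ~15.62 ms per step

print(f'old: {old_latency:.2f} ms/token, corrected: {new_latency:.2f} ms/{batch_size}tokens')
```

The same `batch_size` value reaches `print_metrics` and `print_average` in metrics_print.py below, which is why the reported unit switches from `ms/token` to `ms/{batch_size}tokens` when the batch is larger than one.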

llm_bench/python/utils/metrics_print.py (55 changes: 42 additions & 13 deletions)
@@ -6,7 +6,7 @@

 def print_metrics(
         iter_num, iter_data, tms=None, tms_infer=None, warm_up=False, max_rss_mem=-1, max_shared_mem=-1,
-        stable_diffusion=None, tokenization_time=None
+        stable_diffusion=None, tokenization_time=None, batch_size=1
 ):
     if tms is None:
         tms = []
@@ -16,6 +16,9 @@ def print_metrics(
     if warm_up:
         iter_str = 'warm-up'
     output_str = ''
+    latency_unit = 'token'
+    if batch_size > 1:
+        latency_unit = '{}tokens'.format(batch_size)
     if iter_data['input_size'] != '':
         output_str += 'Input token size: {}, '.format(iter_data['input_size'])
     if iter_data['output_size'] != '':
@@ -29,16 +32,16 @@ def print_metrics(
     if iter_data['generation_time'] != '':
         output_str += 'Generation Time: {:.2f}s, '.format(iter_data['generation_time'])
     if iter_data['latency'] != '':
-        output_str += 'Latency: {:.2f} ms/token'.format(iter_data['latency'])
+        output_str += 'Latency: {:.2f} ms/{}'.format(iter_data['latency'], latency_unit)
     if output_str != '':
         output_str = ' '.join(['[{}]'.format(iter_str), output_str])
         log.info(output_str)
     if len(tms) > 0:
         iter_data['first_token_latency'] = tms[0] * 1000 if len(tms) > 0 else -1
         iter_data['other_tokens_avg_latency'] = sum(tms[1:]) / (len(tms) - 1) * 1000 if len(tms) > 1 else -1
         log.info(
-            f"[{iter_str}] First token latency: {iter_data['first_token_latency']:.2f} ms/token, "
-            f"other tokens latency: {iter_data['other_tokens_avg_latency']:.2f} ms/token, len of tokens: {len(tms)}",
+            f"[{iter_str}] First token latency: {iter_data['first_token_latency']:.2f} ms/{latency_unit}, "
+            f"other tokens latency: {iter_data['other_tokens_avg_latency']:.2f} ms/{latency_unit}, len of tokens: {len(tms)} * {batch_size}",
         )
     if len(tms_infer) > 0:
         iter_data['first_token_infer_latency'] = tms_infer[0] * 1000 if len(tms_infer) > 0 else -1
@@ -103,7 +106,36 @@ def print_ldm_unet_vqvae_infer_latency(iter_num, iter_data, tms=None, warm_up=Fa
             f"vqvae decoder step count: 1",)


-def print_average(iter_data_list):
+def output_avg_statis_tokens(prompt_dict, prompt_idx_list, iter_data_list, batch_size):
+    for p_idx in prompt_idx_list:
+        avg_1st_token_latency = 0
+        avg_2nd_tokens_latency = 0
+        avg_2nd_token_tput = 0
+        avg_input_size = 0
+        index_num = 0
+        for iter_data in iter_data_list:
+            # Exclude the warm-up iteration
+            if iter_data['iteration'] == 0:
+                continue
+            if iter_data['prompt_idx'] == p_idx:
+                avg_1st_token_latency += iter_data['first_token_latency']
+                avg_2nd_tokens_latency += iter_data['other_tokens_avg_latency']
+                avg_input_size += iter_data['input_size']
+                index_num = index_num + 1
+        if index_num > 0:
+            avg_1st_token_latency = avg_1st_token_latency / index_num
+            avg_2nd_tokens_latency = avg_2nd_tokens_latency / index_num
+            avg_input_size = int(avg_input_size / index_num)
+            avg_2nd_token_tput = (1 / avg_2nd_tokens_latency) * batch_size * 1000
+            latency_unit = 'token'
+            if batch_size > 1:
+                latency_unit = '{}tokens'.format(batch_size)
+            prompt_dict[p_idx] = '\n[ INFO ] [Average] Prompt[{}] Input token size: {}, 1st token lantency: {:.2f} ms/{}, ' \
+                '2nd tokens latency: {:.2f} ms/{}, 2nd tokens throughput: {:.2f} tokens/s' \
+                .format(p_idx, avg_input_size, avg_1st_token_latency, latency_unit, avg_2nd_tokens_latency, latency_unit, avg_2nd_token_tput)
+
+
+def print_average(iter_data_list, prompt_idx_list, batch_size, is_text_gen=False):
     if len(iter_data_list) <= 1:
         # 1st iteration is the warm-up iteration
         return
@@ -123,14 +155,11 @@ def print_average(iter_data_list):
     total_iters = len(iter_data_list) - warm_up_iters

     if total_iters > 0:
+        prompt_dict = {}
+        if is_text_gen is True:
+            output_avg_statis_tokens(prompt_dict, prompt_idx_list, iter_data_list, batch_size)
         log.info('<<< Warm-up iteration is excluded. >>>')
         out_str = '[Total] Iterations: {}'.format(total_iters)
-        if total_num_tokens > 0:
-            out_str += ', Output size: {} tokens'.format(total_num_tokens)
-        if total_generation_time > 0:
-            avg_per_iter_time = total_generation_time / total_iters
-            out_str += '\n[ INFO ] [Average] Iteration time: {:.2f}s'.format(avg_per_iter_time)
-            if total_num_tokens > 0:
-                avg_per_token_time = total_generation_time * 1000 / total_num_tokens
-                out_str += ', Latency: {:.2f} ms/token'.format(avg_per_token_time)
+        for prompt_key in prompt_dict:
+            out_str += prompt_dict[prompt_key]
         log.info(out_str)
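
The new `output_avg_statis_tokens` helper averages the first-token and subsequent-token latencies per prompt over all non-warm-up iterations and derives a second-token throughput as `(1 / avg_latency_ms) * batch_size * 1000` tokens/s, since one decoding step takes `avg_latency_ms` milliseconds and yields `batch_size` tokens. A small worked example with made-up numbers, not part of this PR:

```python
# Hypothetical per-iteration measurements for a single prompt (warm-up already excluded).
# 'other_tokens_avg_latency' is in ms per decoding step; each step yields batch_size tokens.
batch_size = 4
iterations = [
    {'first_token_latency': 120.0, 'other_tokens_avg_latency': 42.0},
    {'first_token_latency': 110.0, 'other_tokens_avg_latency': 38.0},
]

avg_1st = sum(it['first_token_latency'] for it in iterations) / len(iterations)       # 115.00 ms
avg_2nd = sum(it['other_tokens_avg_latency'] for it in iterations) / len(iterations)  # 40.00 ms per step

# One step takes avg_2nd ms and produces batch_size tokens, hence:
tput = (1 / avg_2nd) * batch_size * 1000                                              # 100.00 tokens/s

latency_unit = 'token' if batch_size == 1 else '{}tokens'.format(batch_size)
print(f'1st token latency: {avg_1st:.2f} ms/{latency_unit}, '
      f'2nd tokens latency: {avg_2nd:.2f} ms/{latency_unit}, '
      f'2nd tokens throughput: {tput:.2f} tokens/s')
```

With `batch_size == 1` the throughput formula reduces to the usual `1000 / avg_latency_ms` tokens/s, so single-batch reports are unaffected by this PR.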