Skip to content

Commit

Permalink
Update text_generation.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
TolyaTalamanov authored Jan 2, 2025
1 parent 28c37d4 commit 881a565
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tools/llm_bench/task/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,

def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index,
streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption):
set_seed(args['seed'])
input_text_list = [input_text] * args['batch_size']
if args["output_dir"] is not None and num == 0:
for bs_index, in_text in enumerate(input_text_list):
Expand Down Expand Up @@ -226,6 +225,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
log.info(out_str)
gen_config = model.get_generation_config()
gen_config.max_new_tokens = max_gen_tokens
gen_config.rng_seed= args["seed"]
gen_config.num_beams = args["num_beams"]
if args.get('draft_model', ''):
config_info = "Speculative decoding config: "
Expand Down Expand Up @@ -352,7 +352,6 @@ def token_printer():

def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list,
prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption):
set_seed(args['seed'])
input_text_list = [input_text] * args['batch_size']
if args["output_dir"] is not None and num == 0:
for bs_index, in_text in enumerate(input_text_list):
Expand All @@ -378,6 +377,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
streamer.reset()
gen_config = model.get_generation_config()
gen_config.rng_seed= args["seed"]
gen_config.max_new_tokens = max_gen_tokens
gen_config.num_beams = args["num_beams"]
if args.get('draft_model', ''):
Expand Down

0 comments on commit 881a565

Please sign in to comment.