From 881a565fe826dc7676f420fe4642d5d2c02cdaf3 Mon Sep 17 00:00:00 2001 From: Anatoliy Talamanov Date: Thu, 2 Jan 2025 21:00:28 +0000 Subject: [PATCH] Update text_generation.py --- tools/llm_bench/task/text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index d6aebdbc3e..ad1a55ef2f 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -198,7 +198,6 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption): - set_seed(args['seed']) input_text_list = [input_text] * args['batch_size'] if args["output_dir"] is not None and num == 0: for bs_index, in_text in enumerate(input_text_list): @@ -226,6 +225,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data log.info(out_str) gen_config = model.get_generation_config() gen_config.max_new_tokens = max_gen_tokens + gen_config.rng_seed = args["seed"] gen_config.num_beams = args["num_beams"] if args.get('draft_model', ''): config_info = "Speculative decoding config: " @@ -352,7 +352,6 @@ def token_printer(): def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption): - set_seed(args['seed']) input_text_list = [input_text] * args['batch_size'] if args["output_dir"] is not None and num == 0: for bs_index, in_text in enumerate(input_text_list): @@ -378,6 +377,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count'] streamer.reset() gen_config = 
model.get_generation_config() + gen_config.rng_seed = args["seed"] gen_config.max_new_tokens = max_gen_tokens gen_config.num_beams = args["num_beams"] if args.get('draft_model', ''):