diff --git a/tools/who_what_benchmark/whowhatbench/model_loaders.py b/tools/who_what_benchmark/whowhatbench/model_loaders.py
index 8a00c70852..c792a3c0b2 100644
--- a/tools/who_what_benchmark/whowhatbench/model_loaders.py
+++ b/tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -41,8 +41,19 @@ def load_text_genai_pipeline(model_dir, device="CPU", ov_config=None):
     return GenAIModelWrapper(openvino_genai.LLMPipeline(model_dir, device=device, **ov_config), model_dir, "text")
 
 
+def load_text_llamacpp_pipeline(model_dir):
+    try:
+        from llama_cpp import Llama
+    except ImportError:
+        logger.error(
+            "Failed to import llama_cpp package. Please install llama-cpp-python.")
+        exit(-1)
+    model = Llama(model_dir)
+    return model
+
+
 def load_text_model(
-    model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
+    model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False,
 ):
     if use_hf:
         logger.info("Using HF Transformers API")
@@ -53,6 +64,9 @@ def load_text_model(
     elif use_genai:
         logger.info("Using OpenVINO GenAI API")
         model = load_text_genai_pipeline(model_id, device, ov_config)
+    elif use_llamacpp:
+        logger.info("Using llama.cpp API")
+        model = load_text_llamacpp_pipeline(model_id)
     else:
         logger.info("Using Optimum API")
         from optimum.intel.openvino import OVModelForCausalLM
@@ -276,7 +290,7 @@ def load_inpainting_model(
 
 
 def load_model(
-    model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
+    model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False
 ):
     if model_id is None:
         return None
@@ -288,7 +302,7 @@ def load_model(
         ov_options = {}
 
     if model_type == "text":
-        return load_text_model(model_id, device, ov_options, use_hf, use_genai)
+        return load_text_model(model_id, device, ov_options, use_hf, use_genai, use_llamacpp)
     elif model_type == "text-to-image":
         return load_text2image_model(
             model_id, device, ov_options, use_hf, use_genai
diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 50ce224def..433521a186 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -108,6 +108,7 @@ def __init__(
         generation_config=None,
         generation_config_base=None,
         seqs_per_request=None,
+        use_chat_template=None,
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -123,6 +124,7 @@ def __init__(
         self.generation_config_base = generation_config
         self.seqs_per_request = seqs_per_request
         self.generation_fn = gen_answer_fn
+        self.use_chat_template = use_chat_template
         if self.generation_config is not None:
             assert self.seqs_per_request is not None
 
@@ -202,15 +204,21 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
         return res
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
-        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
-            inputs = self.tokenizer(prompt, return_tensors="pt")
-
-            tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-
-            if crop_question:
-                tokens = tokens[:, inputs["input_ids"].shape[-1] :]
-
-            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            if use_chat_template:
+                message = [{"role": "user", "content": prompt}]
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs.shape[-1]:]
+                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
+                return res
+            else:
+                inputs = self.tokenizer(prompt, return_tensors="pt")
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if crop_question:
+                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
+                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
 
         gen_answer_fn = gen_answer_fn or default_gen_answer
 
@@ -250,6 +258,7 @@ def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question):
                         p,
                         self.max_new_tokens,
                         self._crop_question,
+                        self.use_chat_template
                     )
                 )
         else:
diff --git a/tools/who_what_benchmark/whowhatbench/wwb.py b/tools/who_what_benchmark/whowhatbench/wwb.py
index 7acf3cf5aa..7d4354f846 100644
--- a/tools/who_what_benchmark/whowhatbench/wwb.py
+++ b/tools/who_what_benchmark/whowhatbench/wwb.py
@@ -40,6 +41,11 @@ def parse_args():
         default=None,
         help="Tokenizer for divergency metric. If not provided, it will be load from base_model or target_model.",
     )
+    parser.add_argument(
+        "--chat-template",
+        action="store_true",
+        help="Whether to apply the default chat template.",
+    )
     parser.add_argument(
         "--gt-data",
         default=None,
@@ -137,6 +142,11 @@ def parse_args():
         action="store_true",
         help="Use LLMPipeline from transformers library to instantiate the model.",
     )
+    parser.add_argument(
+        "--llamacpp",
+        action="store_true",
+        help="Use llama-cpp-python to instantiate the model.",
+    )
     parser.add_argument(
         "--image-size",
         type=int,
@@ -190,9 +200,13 @@ def load_prompts(args):
 def load_tokenizer(args):
     tokenizer = None
     if args.tokenizer is not None:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer, trust_remote_code=True
-        )
+        if args.llamacpp:
+            from llama_cpp.llama_tokenizer import LlamaHFTokenizer
+            tokenizer = LlamaHFTokenizer.from_pretrained(args.tokenizer)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                args.tokenizer, trust_remote_code=True
+            )
     elif args.base_model is not None:
         tokenizer = AutoTokenizer.from_pretrained(
             args.base_model, trust_remote_code=True
@@ -246,8 +260,29 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)
 
 
-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question):
-    return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        model.start_chat()
+        result = model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+        model.finish_chat()
+        return result
+    else:
+        return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
+
+
+def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+    if use_chat_template:
+        output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
+        text = output["choices"][0]["message"]["content"]
+        # create_chat_completion returns only the assistant reply (no prompt echo),
+        # so there is nothing to crop even when skip_question is set
+        return text
+    else:
+        output = model(question, max_tokens=max_new_tokens, echo=True, temperature=0.0)
+        text = output["choices"][0]["text"]
+        if skip_question:
+            text = text[len(question):]
+        return text
 
 
 def genai_gen_image(model, prompt, num_inference_steps, generator=None):
@@ -322,7 +357,15 @@ def create_evaluator(base_model, args):
     prompts = load_prompts(args)
 
     if task == "text":
-        tokenizer = load_tokenizer(args)
+        tokenizer = load_tokenizer(args) if not args.llamacpp else None
+
+        if args.genai:
+            gen_answer_fn = genai_gen_text
+        elif args.llamacpp:
+            gen_answer_fn = llamacpp_gen_text
+        else:
+            gen_answer_fn = None
+
         return EvaluatorCLS(
             base_model=base_model,
             gt_data=args.gt_data,
@@ -331,7 +374,8 @@
             similarity_model_id=args.data_encoder,
             num_samples=args.num_samples,
             language=args.language,
-            gen_answer_fn=genai_gen_text if args.genai else None,
+            gen_answer_fn=gen_answer_fn,
+            use_chat_template=args.chat_template,
         )
     elif task == "text-to-image":
         return EvaluatorCLS(
@@ -467,10 +511,11 @@ def main():
         args.ov_config,
         args.hf,
         args.genai,
+        args.llamacpp
     )
     all_metrics_per_question, all_metrics = evaluator.score(
         target_model,
-        evaluator.get_generation_fn() if args.genai else None,
+        evaluator.get_generation_fn() if args.genai or args.llamacpp else None,
         output_dir=args.output
     )
     logger.info("Metrics for model: %s", args.target_model)
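Below is a minimal, standalone sketch (not part of the patch) of the two llama-cpp-python code paths that the new `llamacpp_gen_text` helper relies on. The GGUF path and the question are placeholders; on the wwb command line these paths correspond to passing `--llamacpp`, optionally together with `--chat-template`.

```python
# Sketch only: assumes llama-cpp-python is installed and "model.gguf" is a
# placeholder path to a local GGUF model.
from llama_cpp import Llama

model = Llama("model.gguf")
question = "What is OpenVINO?"

# Plain completion path (no --chat-template): echo=True returns the prompt
# followed by the continuation, which is why wwb crops len(question)
# characters from the output when skip_question is set.
output = model(question, max_tokens=128, echo=True, temperature=0.0)
print(output["choices"][0]["text"][len(question):])

# Chat path (--chat-template): the model's built-in chat template is applied
# and only the assistant reply is returned, so no cropping is needed.
output = model.create_chat_completion(
    messages=[{"role": "user", "content": question}],
    max_tokens=128,
    temperature=0.0,
)
print(output["choices"][0]["message"]["content"])
```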