From 3062409ee84ac7bfebab30f445dbf94fe935f24f Mon Sep 17 00:00:00 2001 From: Terry Zhuo Date: Thu, 18 Apr 2024 18:41:35 +1000 Subject: [PATCH] add: evaluation scripts --- script/openeval_chat_osmodel.py | 58 ++++++++---- script/openeval_code_osmodel.py | 153 ++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 18 deletions(-) create mode 100644 script/openeval_code_osmodel.py diff --git a/script/openeval_chat_osmodel.py b/script/openeval_chat_osmodel.py index 094b59f6..6acfea8b 100644 --- a/script/openeval_chat_osmodel.py +++ b/script/openeval_chat_osmodel.py @@ -17,10 +17,24 @@ } """ +EOS = ["\ndef", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "] + def get_prompt_base(doc): return "Complete the following function:\n" + doc["prompt"] - +def stop_at_stop_token(decoded_string, stop_tokens): + """ + Produces the prefix of decoded_string that ends at the first occurrence of + a stop_token. + WARNING: the decoded_string *must not* include the prompt, which may have stop tokens + itself. + """ + min_stop_index = len(decoded_string) + for stop_token in stop_tokens: + stop_index = decoded_string.find(stop_token) + if stop_index != -1 and stop_index < min_stop_index: + min_stop_index = stop_index + return decoded_string[:min_stop_index] class ParseError(Exception): pass @@ -46,12 +60,12 @@ def __call__(self, prompt: str, content: str, entry_point: str): matcher = CSequenceMatcher(None, prompt, content) tag, _, _, j1, j2 = matcher.get_opcodes()[-1] if tag == "insert": - return content[j1:j2] + return stop_at_stop_token(content[j1:j2], EOS) # second parse content with assumption that model wrote code without description for entry_point in self._entry_point_variations(entry_point): if entry_point in content: - content = content.split(entry_point)[-1] - return "".join(content.splitlines(keepends=True)[1:]) + content = content.split(entry_point)[1] + return stop_at_stop_token("".join(content.splitlines(keepends=True)[1:]), EOS) raise ParseError(f"Prompt is not in content:\n{content}") @@ -66,7 +80,7 @@ def __init__(self, model: str): ) self.tokenizer = AutoTokenizer.from_pretrained(model) - def __call__(self, prompt: str, n: int) -> str: + def __call__(self, prompt: str, max_new_tokens=1024, temperature=0, top_p=0.95, n=1) -> List[str]: messages = [ { "role": "user", @@ -80,17 +94,23 @@ def __call__(self, prompt: str, n: int) -> str: ) model_inputs = self.tokenizer([text], return_tensors="pt").to(device) - generated_ids = self.model.generate( - model_inputs.input_ids, - do_sample=True, - max_new_tokens=1024, - temperature=0.2, - top_p=0.95, - num_return_sequences=n, - ) - generated_ids = [ - output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) - ] + if not temperature: + generated_ids = self.model.generate( + model_inputs.input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + temperature=temperature, + num_return_sequences=n, + ) + else: + generated_ids = self.model.generate( + model_inputs.input_ids, + do_sample=True, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + num_return_sequences=n, + ) content_list = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) return content_list @@ -100,6 +120,8 @@ def __call__(self, prompt: str, n: int) -> str: TIMES = 1 VERBOSE = True MODEL = "Qwen/CodeQwen1.5-7B-Chat" + TEMPERATURE = 0 + input_file = sys.argv[1] # make test directory if not os.path.exists("results"): @@ -120,7 +142,7 @@ def __call__(self, prompt: str, n: int) -> str: if 
VERBOSE: print(f"Processing {sample['task_id']} ({idx + 1}/{len(samples)}))...") - sample["raw_generation"] = chat_wrapper(prompt, TIMES) + sample["raw_generation"] = chat_wrapper(prompt, temperature=TEMPERATURE, n=TIMES) try: sample["generation"] = [parser(prompt, generation_item, sample["task_id"]) for generation_item in sample["raw_generation"]] except ParseError as e: @@ -136,6 +158,6 @@ def __call__(self, prompt: str, n: int) -> str: if VERBOSE: print("parse error rate:", parse_errors / len(samples)) - results_filename = MODEL+f"_completions_"+input_file.split("/")[-1].split(".")[0]+".jsonl" + results_filename = MODEL.split("/")[-1]+f"_completions_"+input_file.split("/")[-1].split(".")[0]+".jsonl" with jsonlines.open("results/"+results_filename, "w") as writer: writer.write_all(samples) diff --git a/script/openeval_code_osmodel.py b/script/openeval_code_osmodel.py new file mode 100644 index 00000000..3c420691 --- /dev/null +++ b/script/openeval_code_osmodel.py @@ -0,0 +1,153 @@ +import os +from transformers import AutoModelForCausalLM, AutoTokenizer + +import termcolor +import jsonlines +import sys + +from cdifflib import CSequenceMatcher +from camel_converter import to_snake +# from datasets import load_dataset +from typing import List +from tqdm import tqdm +device = "cuda" # the device to load the model onto + +_CITATION = """ + +} +""" + +EOS = ["\ndef", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "] + +def get_prompt_base(doc): + return "Complete the following function:\n" + doc["prompt"] + +def stop_at_stop_token(decoded_string, stop_tokens): + """ + Produces the prefix of decoded_string that ends at the first occurrence of + a stop_token. + WARNING: the decoded_string *must not* include the prompt, which may have stop tokens + itself. + """ + min_stop_index = len(decoded_string) + for stop_token in stop_tokens: + stop_index = decoded_string.find(stop_token) + if stop_index != -1 and stop_index < min_stop_index: + min_stop_index = stop_index + return decoded_string[:min_stop_index] + +class ParseError(Exception): + pass + +class ContentParser: + + @staticmethod + def _entry_point_variations(entry_point: str) -> List[str]: + # NOTE: workaround dataset's bug with entry point naming + return [ + entry_point, + to_snake(entry_point), + entry_point[0].lower() + entry_point[1:], + ] + + def __call__(self, prompt: str, content: str, entry_point: str): + # NOTE: Model doesn't follow instructions directly: + # adds description of change and sometimes fixes + # typos, or other "bugs" in description. 
+ if "```" in content: + content = content.split("```")[1] + # first parse with assumption that content has description + matcher = CSequenceMatcher(None, prompt, content) + tag, _, _, j1, j2 = matcher.get_opcodes()[-1] + if tag == "insert": + return stop_at_stop_token(content[j1:j2], EOS) + # second parse content with assumption that model wrote code without description + for entry_point in self._entry_point_variations(entry_point): + if entry_point in content: + content = content.split(entry_point)[1] + return stop_at_stop_token("".join(content.splitlines(keepends=True)[1:]), EOS) + raise ParseError(f"Prompt is not in content:\n{content}") + + +class CodeWrapper: + + def __init__(self, model: str): + + self.model = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype="auto", + device_map="auto" + ) + self.tokenizer = AutoTokenizer.from_pretrained(model) + + def __call__(self, prompt: str, max_new_tokens=1024, temperature=0, top_p=0.95, n=1) -> List[str]: + + model_inputs = self.tokenizer([prompt], return_tensors="pt").to(device) + + if not temperature: + generated_ids = self.model.generate( + model_inputs.input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + temperature=temperature, + num_return_sequences=n, + ) + else: + generated_ids = self.model.generate( + model_inputs.input_ids, + do_sample=True, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + num_return_sequences=n, + ) + + content_list = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return content_list + + +if __name__ == '__main__': + TIMES = 1 + VERBOSE = True + MODEL = "bigcode/starcoder2-7b" + TEMPERATURE = 0 + + input_file = sys.argv[1] + # make test directory + if not os.path.exists("results"): + os.makedirs("results") + + # Load descriptions + + samples = [] + with jsonlines.open(input_file) as f: + for s in f: + samples.append(s) + + code_wrapper = CodeWrapper(MODEL) + parse_errors = 0 + parser = ContentParser() + for idx, sample in enumerate(tqdm(samples)): + prompt = get_prompt_base(sample) + + if VERBOSE: + print(f"Processing {sample['task_id']} ({idx + 1}/{len(samples)}))...") + sample["raw_generation"] = code_wrapper(prompt, temperature=TEMPERATURE, n=TIMES) + try: + sample["generation"] = [parser(prompt, generation_item, sample["task_id"]) for generation_item in sample["raw_generation"]] + except ParseError as e: + parse_errors += 1 + print("PARSE EXCEPTION:", e) + sample["generation"] = [""] + if VERBOSE: + for i in range(TIMES): + print(termcolor.colored(sample["task_id"], "yellow", attrs=["bold"])) + print(termcolor.colored(prompt, "yellow")) + print(termcolor.colored(sample["canonical_solution"], "red")) + print(termcolor.colored(sample["generation"][i], "green")+"\n\n") + if VERBOSE: + print("parse error rate:", parse_errors / len(samples)) + + results_filename = MODEL.split("/")[-1]+f"_completions_"+input_file.split("/")[-1].split(".")[0]+".jsonl" + with jsonlines.open("results/"+results_filename, "w") as writer: + writer.write_all(samples)
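
Both scripts share the same post-processing path: a raw generation is first truncated at the first top-level stop sequence by stop_at_stop_token, then stripped of any surrounding prose by ContentParser. Below is a minimal, self-contained sketch of the truncation step; the EOS list and stop_at_stop_token are copied verbatim from this patch, while the example completion string is made up purely for illustration.

    EOS = ["\ndef", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "]

    def stop_at_stop_token(decoded_string, stop_tokens):
        # Keep only the prefix of decoded_string that ends at the first stop token.
        # The input must not include the prompt, which may itself contain stop tokens.
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]

    # Hypothetical raw completion: the model finishes the target function body and
    # then starts an unrelated top-level definition.
    completion = "    return sorted(values)[-1]\n\ndef unrelated_helper():\n    pass\n"
    print(stop_at_stop_token(completion, EOS))
    # prints "    return sorted(values)[-1]" -- everything from the "\ndef" marker onward is dropped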
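
Usage note: each script takes a single positional argument, a JSONL file whose records provide at least task_id, prompt, and canonical_solution, and writes its outputs to results/<model name>_completions_<input stem>.jsonl. With TEMPERATURE = 0 (the default set here) the wrappers take the greedy branch (do_sample=False); any non-zero temperature switches to sampling with the given temperature and top_p. The chat variant additionally wraps the prompt with tokenizer.apply_chat_template before generation, whereas the code variant feeds the raw prompt directly to the base model.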