From bd588a9cfe13da5228f49d0cf2b2b8c526324250 Mon Sep 17 00:00:00 2001
From: banma12956 <1718613239banma@gmail.com>
Date: Fri, 6 Dec 2024 14:42:58 +0800
Subject: [PATCH 1/2] add dots eval

---
 .../dots_prompts/MATH/cot_prompt.txt          |  1 +
 .../MATH/decomposition_prompt.txt             |  1 +
 .../MATH/explanation_generation_prompt.txt    | 11 ++++
 .../dots_prompts/MATH/pot_prompt.txt          |  2 +
 .../dots_prompts/MATH/rewriting_prompt.txt    |  4 ++
 .../MATH/self_verification_prompt.txt         |  1 +
 reason/evaluation/evaluate.py                 |  5 +-
 reason/evaluation/methods.py                  | 65 +++++++++++++++++++
 reason/evaluation/utils.py                    |  6 ++
 scripts/eval/dots.sh                          | 10 +++
 10 files changed, 105 insertions(+), 1 deletion(-)
 create mode 100644 reason/evaluation/dots_prompts/MATH/cot_prompt.txt
 create mode 100644 reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt
 create mode 100644 reason/evaluation/dots_prompts/MATH/explanation_generation_prompt.txt
 create mode 100644 reason/evaluation/dots_prompts/MATH/pot_prompt.txt
 create mode 100644 reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt
 create mode 100644 reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt
 create mode 100755 scripts/eval/dots.sh

diff --git a/reason/evaluation/dots_prompts/MATH/cot_prompt.txt b/reason/evaluation/dots_prompts/MATH/cot_prompt.txt
new file mode 100644
index 0000000..a956f7e
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/cot_prompt.txt
@@ -0,0 +1 @@
+In this step, you need to think step by step with words, solve the problem and get the answer.
diff --git a/reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt b/reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt
new file mode 100644
index 0000000..c135939
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt
@@ -0,0 +1 @@
+In this step, you need to reflect on the problem, and describe it in your own words. Analyze how you can decompose the problem into smaller, more manageable sub-tasks. Pay attention to small details, nuances, notes and examples in the problem description.
diff --git a/reason/evaluation/dots_prompts/MATH/explanation_generation_prompt.txt b/reason/evaluation/dots_prompts/MATH/explanation_generation_prompt.txt
new file mode 100644
index 0000000..c3df45f
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/explanation_generation_prompt.txt
@@ -0,0 +1,11 @@
+Action Categories:
+1. Understanding process: query rewriting: Rewrite the question and answer it. Decomposition: Decompose the questions into multiple subtasks to solve the sub-question. 2. Solving process: chain of thought: For step-by-step reasoning with language. programming: For programming solver. 3. Verification process: self-verification: To check the correctness of the solution.
+Task Instruction: For the given question, explain why the above Required actions are necessary.
+Example 1:
+Query: Find 2 · 5^{-1} + 8 · 11^{-1} (mod 56). Express your answer as an integer from 0 to 55, inclusive.
+Required Action: programming, self-verification
+Explanation: This is a modular arithmetic problem. The problem can be solved using straightforward Python code with the sympy library, particularly modular arithmetic. Besides, this type of problem is relatively easy to verify. After computing the result, one can check the calculations step by step to ensure correctness and verify that the final answer is within the given range (0 to 55 inclusive). A programming solver is more efficient and accurate for this type of calculation, and the verifier ensures the correctness of the result and adherence to the given constraints.
+... (multiple examples)
+Query: Given Query
+Required Action: Actions After Searching
+Explanation:
diff --git a/reason/evaluation/dots_prompts/MATH/pot_prompt.txt b/reason/evaluation/dots_prompts/MATH/pot_prompt.txt
new file mode 100644
index 0000000..576cc62
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/pot_prompt.txt
@@ -0,0 +1,2 @@
+In this step, you need to write Python code to solve the query. Use the simplest and most straightforward programming methods to solve the problem. For instance, if a query can be efficiently solved using a brute-force method, prefer it over heuristic or more complex methods. Utilize any available and commonly used libraries that can simplify the task or improve code maintainability. All the calculations must leverage code. Print out the results with the print() function. Before executing the program, you have no idea of the final answer; don’t show it in your comments or code. And don’t use the plot function.
+In this step, start with “# Now write Python codes to answer this question and use print() to print out the result”
diff --git a/reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt b/reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt
new file mode 100644
index 0000000..6b4f0ee
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt
@@ -0,0 +1,4 @@
+In this step, you need to reveal the Core Question with only a simple sentence and useful information. The output follows the format:
+core question:...
+Note: Please extract the question-solving information related to the problem, and list them one by one.
+useful information:...
diff --git a/reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt b/reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt
new file mode 100644
index 0000000..95642c5
--- /dev/null
+++ b/reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt
@@ -0,0 +1 @@
+In this step, you need to carefully verify the correctness of the previous thoughts in natural language. You need to formulate a verification question (not the same question as before) based on the final answer, and then verify the final answer you have. If the results are incorrect, the last line should end with “The answer is: incorrect”. Otherwise, the last line should end with “The answer is: correct”.
diff --git a/reason/evaluation/evaluate.py b/reason/evaluation/evaluate.py
index 17cf6b1..8da1a2e 100644
--- a/reason/evaluation/evaluate.py
+++ b/reason/evaluation/evaluate.py
@@ -169,7 +169,7 @@ def parallel_evaluate_test_dataset(
         print("Method: {}. Average result: {}".format(method_name, avg_res))
         return results
 
-    solver_fns = {"cot": cot, "best_of_n": best_of_n}
+    solver_fns = {"cot": cot, "best_of_n": best_of_n, "dots": dots}
 
     cfg_dict_record = dict()
     # XXX: qwen-2.5 requires add more stop words
@@ -192,6 +192,9 @@ def parallel_evaluate_test_dataset(
             config.task_name, num_sequence=config.num_sequence
         )
         solver_fn = partial(best_of_n, method_config, gen_config)
+    elif config.method == "dots":
+        method_config = DotsConfig(config.task_name)
+        solver_fn = partial(dots, method_config, gen_config)
     elif config.method == "beam_search":
         method_config = BeamSearchConfig(
             task_name=config.task_name,
diff --git a/reason/evaluation/methods.py b/reason/evaluation/methods.py
index bad3bd7..20b8f9e 100644
--- a/reason/evaluation/methods.py
+++ b/reason/evaluation/methods.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+import random
 import functools
 from typing import Dict
 from reason.inference.lm_call import LMCallingConfig, LanguageModelCallingFunction
@@ -6,6 +7,7 @@
 from reason.evaluation.evaluator import SolutionOutput, Task, TreeSearchSolutionOutput
 from reason.guided_search.tree import SearchTree
 from reason.guided_search.rstar import RstarSearchTree
+from reason.evaluation.utils import read_txt
 
 
 @dataclass
@@ -61,6 +63,69 @@ def best_of_n(
         completion_tokens=completion_tokens,
     )
 
+@dataclass # TODO: no need to load every time...right?
+class DotsConfig(BasicConfig):
+    depth: int = 5
+    num_sequence: int = 1
+
+    empty_prompt: str = ""
+    rewriting_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt")
+    decomposition_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt")
+    cot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/cot_prompt.txt")
+    pot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/pot_prompt.txt")
+    self_verification_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt")
+
+    analysis_layer_prompts = [empty_prompt, rewriting_prompt, decomposition_prompt]
+    solution_layer_prompts = [cot_prompt, pot_prompt]
+    verification_layer_prompts = [empty_prompt, self_verification_prompt]
+
+def dots(
+    config: DotsConfig,
+    gen_config: LMCallingConfig,
+    problem_inst: Dict[str, str],
+    lm_call: LanguageModelCallingFunction,
+    rm_call: RewardModelCallingFunction,
+) -> SolutionOutput:
+    if gen_config.max_new_tokens < 256:
+        print("Warning: max_new_tokens is less than 256")
+
+    gen_config.n = config.num_sequence
+    task = Task(task_name=config.task_name)
+    # prompt = task.prompt_fn(problem_inst["question"])
+    prompt = problem_inst["question"] + "\n"
+
+    # analysis layer
+    analysis_prompt = random.choice(config.analysis_layer_prompts)
+    if analysis_prompt != "":  # the EMPTY action means bypass
+        output = lm_call(prompt + analysis_prompt, gen_config)
+        prompt += (output.text[0] + "\n")
+
+    step_num = 0
+    completion_tokens = 0
+    while step_num <= config.depth:
+        step_num += 1
+
+        # solution layer
+        solution_prompt = random.choice(config.solution_layer_prompts)
+        output = lm_call(prompt + solution_prompt, gen_config)  # TODO: PoT run the code
+        completion_tokens += output.num_tokens[0]
+        prompt += (output.text[0] + "\n")
+
+        # verification layer
+        verification_prompt = random.choice(config.verification_layer_prompts)
+        if verification_prompt == "":  # the EMPTY action means bypass
+            continue
+        output = lm_call(prompt + verification_prompt, gen_config)
+        prompt += (output.text[0] + "\n")
+        completion_tokens += output.num_tokens[0]
+
+        if "The answer is: correct" in output.text[0]:
+            break
+
+    return SolutionOutput(
+        solutions=[prompt],
+        completion_tokens=[completion_tokens],
+    )
 
 @dataclass
 class TreeSearchConfig(BasicConfig):
diff --git a/reason/evaluation/utils.py b/reason/evaluation/utils.py
index 40955b3..2751677 100644
--- a/reason/evaluation/utils.py
+++ b/reason/evaluation/utils.py
@@ -32,3 +32,9 @@ def setup_seed(seed):
     random.seed(seed)
     os.environ["PYTHONHASHSEED"] = str(seed)
     torch.backends.cudnn.deterministic = True
+
+def read_txt(file_path):
+    assert str(file_path).endswith(".txt")
+    with open(file_path, "r", encoding="utf-8") as f:
+        data = f.read()
+    return data
diff --git a/scripts/eval/dots.sh b/scripts/eval/dots.sh
new file mode 100755
index 0000000..7c3a639
--- /dev/null
+++ b/scripts/eval/dots.sh
@@ -0,0 +1,10 @@
+python reason/evaluation/evaluate.py \
+    --LM deepseek-math-7b-instruct \
+    --task_name MATH \
+    --temperature 0.0 \
+    --max_new_tokens 2048 \
+    --save_dir results \
+    --method dots \
+    --num_worker 32 \
+    --controller_addr http://0.0.0.0:28777 \
+    --local
\ No newline at end of file

From 4d0378a6931ba55cbb22066e7a0e48a9a34aeb41 Mon Sep 17 00:00:00 2001
From: banma12956 <1718613239banma@gmail.com>
Date: Fri, 6 Dec 2024 15:16:09 +0800
Subject: [PATCH 2/2] add post_init function

---
 reason/evaluation/methods.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/reason/evaluation/methods.py b/reason/evaluation/methods.py
index 20b8f9e..2782509 100644
--- a/reason/evaluation/methods.py
+++ b/reason/evaluation/methods.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import random
 import functools
 from typing import Dict
@@ -68,16 +68,21 @@ class DotsConfig(BasicConfig):
     depth: int = 5
     num_sequence: int = 1
 
-    empty_prompt: str = ""
-    rewriting_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt")
-    decomposition_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt")
-    cot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/cot_prompt.txt")
-    pot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/pot_prompt.txt")
-    self_verification_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt")
+    analysis_layer_prompts: list[str] = field(default_factory=list)
+    solution_layer_prompts: list[str] = field(default_factory=list)
+    verification_layer_prompts: list[str] = field(default_factory=list)
 
-    analysis_layer_prompts = [empty_prompt, rewriting_prompt, decomposition_prompt]
-    solution_layer_prompts = [cot_prompt, pot_prompt]
-    verification_layer_prompts = [empty_prompt, self_verification_prompt]
+    def __post_init__(self):
+        empty_prompt: str = ""
+        rewriting_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt")
+        decomposition_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt")
+        cot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/cot_prompt.txt")
+        pot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/pot_prompt.txt")
+        self_verification_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt")
+
+        self.analysis_layer_prompts = [empty_prompt, rewriting_prompt, decomposition_prompt]
+        self.solution_layer_prompts = [cot_prompt, pot_prompt]
+        self.verification_layer_prompts = [empty_prompt, self_verification_prompt]
 
 def dots(
     config: DotsConfig,
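
For readers skimming the patch, here is a minimal, self-contained sketch of the analysis -> solution -> verification control flow that the dots() solver implements. The lm_call_stub and the short inline prompts below are illustrative stand-ins, not the repo's LanguageModelCallingFunction or the prompt files added by the first commit:

    # Sketch of the three-layer DoTS loop; stubbed model call, placeholder prompts.
    import random

    ANALYSIS_PROMPTS = ["", "rewrite the core question", "decompose the problem"]  # "" = bypass
    SOLUTION_PROMPTS = ["think step by step", "write Python code"]
    VERIFICATION_PROMPTS = ["", "verify the previous thoughts"]                    # "" = bypass

    def lm_call_stub(prompt: str) -> str:
        # Stand-in for the served model; always "verifies" successfully here.
        return "stub continuation. The answer is: correct"

    def dots_loop(question: str, depth: int = 5) -> str:
        prompt = question + "\n"
        # Analysis layer: one randomly chosen understanding action, or none.
        analysis = random.choice(ANALYSIS_PROMPTS)
        if analysis:  # the empty prompt means the layer is bypassed
            prompt += lm_call_stub(prompt + analysis) + "\n"
        # Solution/verification layers; `while step_num <= depth` in the
        # patch allows up to depth + 1 rounds.
        for _ in range(depth + 1):
            prompt += lm_call_stub(prompt + random.choice(SOLUTION_PROMPTS)) + "\n"
            verification = random.choice(VERIFICATION_PROMPTS)
            if not verification:  # bypassed verification: try another round
                continue
            verdict = lm_call_stub(prompt + verification)
            prompt += verdict + "\n"
            if "The answer is: correct" in verdict:
                break  # a verified solution ends the loop early
        return prompt

    print(dots_loop("Find 2 · 5^{-1} + 8 · 11^{-1} (mod 56)."))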
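
The evaluate.py hunk follows the file's existing dispatch convention: the method config and generation config are bound once with functools.partial, so every solver can be invoked the same way downstream. A generic sketch of that pattern, with placeholder dict configs standing in for the real DotsConfig/LMCallingConfig:

    # Sketch of config binding via functools.partial; dict configs are placeholders.
    from functools import partial

    def dots(method_config, gen_config, problem_inst):
        # In the repo the bound solver also receives lm_call/rm_call later;
        # only the configs are fixed up front.
        return f"dots(depth={method_config['depth']}) solving: {problem_inst}"

    solver_fn = partial(dots, {"depth": 5}, {"max_new_tokens": 2048})
    print(solver_fn("Find 2 · 5^{-1} + 8 · 11^{-1} (mod 56)."))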
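
The second commit's rationale, sketched: dataclass field defaults such as read_txt(...) are evaluated once, at import time, and the first commit's unannotated class-level lists are shared class attributes rather than per-instance fields; field(default_factory=list) plus __post_init__ addresses both. The placeholder strings below stand in for the loaded prompt files:

    # Sketch of the default_factory + __post_init__ pattern from the second commit.
    from dataclasses import dataclass, field

    @dataclass
    class LayerPrompts:
        # default_factory gives each instance its own list; a bare class-level
        # list would be one object shared by every instance, and a default
        # computed by calling read_txt(...) would run on module import.
        analysis_layer_prompts: list[str] = field(default_factory=list)

        def __post_init__(self):
            # Runs per instance, after field assignment, so file I/O is
            # deferred from import time to config construction.
            self.analysis_layer_prompts = ["", "placeholder rewriting prompt"]

    a, b = LayerPrompts(), LayerPrompts()
    assert a.analysis_layer_prompts is not b.analysis_layer_prompts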