add dots eval #75

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions reason/evaluation/dots_prompts/MATH/cot_prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
In this step, you need to think step by step with words, solve the problem and get the answer.
1 change: 1 addition & 0 deletions reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt
@@ -0,0 +1 @@
In this step, you need to reflect on the problem, and describe it in your own words. Analyze how you can decompose the problem into smaller, more manageable sub-tasks. Pay attention to small details, nuances, notes and examples in the problem description.
@@ -0,0 +1,11 @@
Action Categories:
1. Understanding process: query rewriting (rewrite the question and answer it); decomposition (decompose the question into multiple subtasks and solve each sub-question).
2. Solving process: chain of thought (step-by-step reasoning in natural language); programming (a programming solver).
3. Verification process: self-verification (check the correctness of the solution).
Task Instruction: For the given question, explain why the above required actions are necessary.
Example 1:
Query: Find 2 · 5^{-1} + 8 · 11^{-1} (mod 56). Express your answer as an integer from 0 to 55, inclusive.
Required Action: programming, self-verification
Explanation: This is a modular arithmetic problem. It can be solved with straightforward Python code using the sympy library, particularly its modular arithmetic. This type of problem is also relatively easy to verify: after computing the result, one can check the calculations step by step to ensure correctness and confirm that the final answer lies within the given range (0 to 55, inclusive). A programming solver is more efficient and accurate for this kind of calculation, and the verifier ensures the correctness of the result and adherence to the given constraints.
... (multiple examples)
Query: Given Query
Required Action: Actions After Searching
Explanation:
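Taken together, the categories above define a small layered action space; since the understanding and verification layers are optional, an empty action can stand for "skip this layer" (the same role the empty prompt plays elsewhere in this diff). A quick illustrative sketch (names hypothetical) of how many distinct action trajectories that yields:

```python
from itertools import product

# Action space from the prompt above, grouped by process layer.
# "" stands for skipping an optional layer.
ACTION_SPACE = {
    "understanding": ["", "query_rewriting", "decomposition"],
    "solving": ["chain_of_thought", "programming"],
    "verification": ["", "self_verification"],
}

# Every trajectory picks exactly one action per layer: 3 * 2 * 2 combinations.
trajectories = list(product(*ACTION_SPACE.values()))
print(len(trajectories))  # → 12
```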
2 changes: 2 additions & 0 deletions reason/evaluation/dots_prompts/MATH/pot_prompt.txt
@@ -0,0 +1,2 @@
In this step, you need to write Python code to solve the query. Use the simplest and most straightforward programming methods: for instance, if a query can be solved efficiently by brute force, prefer that over heuristic or more complex methods. Use any available, commonly used libraries that simplify the task or improve code maintainability. All calculations must be done in code, and the results must be printed with the print() function. Before executing the program you do not know the final answer, so do not reveal one in your comments or code, and do not use plotting functions.
In this step, start with “# Now write Python codes to answer this question and use print() to print out the result”
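As a concrete illustration of the style this prompt asks for, Example 1's query from the action prompt above (2 · 5^{-1} + 8 · 11^{-1} mod 56) can be solved like this. This is an editorial sketch, not part of the diff; Python's built-in `pow(b, -1, m)` computes a modular inverse:

```python
# Now write Python codes to answer this question and use print() to print out the result
inv5 = pow(5, -1, 56)    # modular inverse of 5 modulo 56 (5 * 45 = 225 ≡ 1 mod 56)
inv11 = pow(11, -1, 56)  # modular inverse of 11 modulo 56 (11 * 51 = 561 ≡ 1 mod 56)
result = (2 * inv5 + 8 * inv11) % 56
print(result)  # → 50
```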
4 changes: 4 additions & 0 deletions reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt
@@ -0,0 +1,4 @@
In this step, you need to reveal the core question in a single simple sentence, together with the useful information. The output follows the format:
core question: ...
useful information: ...
Note: extract the question-solving information related to the problem and list the items one by one.
1 change: 1 addition & 0 deletions reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt
@@ -0,0 +1 @@
In this step, you need to carefully verify the correctness of the previous thoughts in natural language. Formulate a verification question (not the same question as before) based on the final answer, and then verify the final answer you have. If the result is incorrect, the last line should end with “The answer is: incorrect”. Otherwise, the last line should end with “The answer is: correct”.
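This prompt fixes a sentinel line that downstream code can key on; `dots()` in this diff uses a plain substring check for "The answer is: correct". A slightly stricter parsing sketch (a hypothetical helper, not part of the diff) that only trusts the last line, as the prompt specifies:

```python
def parse_verdict(verification_text: str) -> bool:
    """Return True if the verifier's last line accepts the answer."""
    last_line = verification_text.strip().splitlines()[-1]
    return last_line.endswith("The answer is: correct")

print(parse_verdict("Check: 50 lies in [0, 55].\nThe answer is: correct"))   # → True
print(parse_verdict("Recomputing gives a mismatch.\nThe answer is: incorrect"))  # → False
```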
5 changes: 4 additions & 1 deletion reason/evaluation/evaluate.py
@@ -169,7 +169,7 @@ def parallel_evaluate_test_dataset(
    print("Method: {}. Average result: {}".format(method_name, avg_res))
    return results

solver_fns = {"cot": cot, "best_of_n": best_of_n}
solver_fns = {"cot": cot, "best_of_n": best_of_n, "dots": dots}

cfg_dict_record = dict()
# XXX: qwen-2.5 requires add more stop words
@@ -192,6 +192,9 @@ def parallel_evaluate_test_dataset(
            config.task_name, num_sequence=config.num_sequence
        )
        solver_fn = partial(best_of_n, method_config, gen_config)
    elif config.method == "dots":
        method_config = DotsConfig(config.task_name)
        solver_fn = partial(dots, method_config, gen_config)
    elif config.method == "beam_search":
        method_config = BeamSearchConfig(
            task_name=config.task_name,
72 changes: 71 additions & 1 deletion reason/evaluation/methods.py
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
import random
import functools
from typing import Dict
from reason.inference.lm_call import LMCallingConfig, LanguageModelCallingFunction
from reason.inference.rm_call import RewardModelCallingFunction
from reason.evaluation.evaluator import SolutionOutput, Task, TreeSearchSolutionOutput
from reason.guided_search.tree import SearchTree
from reason.guided_search.rstar import RstarSearchTree
from reason.evaluation.utils import read_txt


@dataclass
@@ -61,6 +63,74 @@ def best_of_n(
        completion_tokens=completion_tokens,
    )

@dataclass
class DotsConfig(BasicConfig):
    # TODO: the prompt files are re-read for every config instance; consider caching.
    depth: int = 5
    num_sequence: int = 1

    analysis_layer_prompts: list[str] = field(default_factory=list)
    solution_layer_prompts: list[str] = field(default_factory=list)
    verification_layer_prompts: list[str] = field(default_factory=list)

    def __post_init__(self):
        empty_prompt: str = ""
        rewriting_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/rewriting_prompt.txt")
        decomposition_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/decomposition_prompt.txt")
        cot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/cot_prompt.txt")
        pot_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/pot_prompt.txt")
        self_verification_prompt: str = read_txt("reason/evaluation/dots_prompts/MATH/self_verification_prompt.txt")

        # The empty prompt acts as a "bypass" action for the optional layers.
        self.analysis_layer_prompts = [empty_prompt, rewriting_prompt, decomposition_prompt]
        self.solution_layer_prompts = [cot_prompt, pot_prompt]
        self.verification_layer_prompts = [empty_prompt, self_verification_prompt]


def dots(
    config: DotsConfig,
    gen_config: LMCallingConfig,
    problem_inst: Dict[str, str],
    lm_call: LanguageModelCallingFunction,
    rm_call: RewardModelCallingFunction,
) -> SolutionOutput:
    if gen_config.max_new_tokens < 256:
        print("Warning: max_new_tokens is less than 256")

    gen_config.n = config.num_sequence
    task = Task(task_name=config.task_name)
    # prompt = task.prompt_fn(problem_inst["question"])
    prompt = problem_inst["question"] + "\n"
    completion_tokens = 0

    # analysis layer
    analysis_prompt = random.choice(config.analysis_layer_prompts)
    if analysis_prompt != "":  # the empty action means bypass this layer
        output = lm_call(prompt + analysis_prompt, gen_config)
        completion_tokens += output.num_tokens[0]
        prompt += output.text[0] + "\n"

    step_num = 0
    while step_num <= config.depth:
        step_num += 1

        # solution layer
        solution_prompt = random.choice(config.solution_layer_prompts)
        output = lm_call(prompt + solution_prompt, gen_config)  # TODO: for PoT, execute the generated code
        completion_tokens += output.num_tokens[0]
        prompt += output.text[0] + "\n"

        # verification layer
        verification_prompt = random.choice(config.verification_layer_prompts)
        if verification_prompt == "":  # the empty action means bypass this layer
            continue
        output = lm_call(prompt + verification_prompt, gen_config)
        completion_tokens += output.num_tokens[0]
        prompt += output.text[0] + "\n"

        # stop once the self-verification step accepts the answer
        if "The answer is: correct" in output.text[0]:
            break

    return SolutionOutput(
        solutions=[prompt],
        completion_tokens=[completion_tokens],
    )
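The layered control flow of `dots()` can be exercised end to end with a toy stand-in for the model call. Everything below (`stub_lm_call`, the literal action strings, the fixed depth of 5) is hypothetical scaffolding for illustration, not code from this diff:

```python
import random

# Toy stand-in for the layered loop in dots(): analysis, then up to `depth`
# rounds of solution + verification, where "" is the "bypass" action.
def stub_lm_call(prompt: str) -> str:
    # Hypothetical model: accepts the answer whenever asked to verify.
    if "verify" in prompt:
        return "Checked the result step by step.\nThe answer is: correct"
    return "One more reasoning step."

analysis_layer = ["", "rewrite the query:", "decompose the query:"]
solution_layer = ["think step by step:", "write Python code:"]
verification_layer = ["", "verify the previous thoughts:"]

random.seed(0)
prompt = "What is 2 + 2?\n"
action = random.choice(analysis_layer)
if action != "":
    prompt += stub_lm_call(prompt + action) + "\n"

for _ in range(5):  # plays the role of config.depth
    prompt += stub_lm_call(prompt + random.choice(solution_layer)) + "\n"
    action = random.choice(verification_layer)
    if action == "":
        continue  # bypass verification this round
    out = stub_lm_call(prompt + action)
    prompt += out + "\n"
    if "The answer is: correct" in out:
        break  # verifier accepted the answer

print(prompt.count("\n") >= 3)  # → True
```

The accumulated `prompt` string is the whole trajectory, mirroring how `dots()` returns it as the single solution.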

@dataclass
class TreeSearchConfig(BasicConfig):
6 changes: 6 additions & 0 deletions reason/evaluation/utils.py
@@ -32,3 +32,9 @@ def setup_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True

def read_txt(file_path):
    assert str(file_path).endswith(".txt")
    with open(file_path, "r", encoding="utf-8") as f:
        data = f.read()
    return data
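A minimal usage sketch of `read_txt`, round-tripping through a temporary file (the sample prompt text is illustrative):

```python
import os
import tempfile

def read_txt(file_path):
    # mirrors reason/evaluation/utils.py: only .txt prompt files are expected
    assert str(file_path).endswith(".txt")
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()

# write a throwaway prompt file and read it back
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("In this step, think step by step.")
    path = f.name

content = read_txt(path)
print(content)  # → In this step, think step by step.
os.remove(path)
```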
10 changes: 10 additions & 0 deletions scripts/eval/dots.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
python reason/evaluation/evaluate.py \
--LM deepseek-math-7b-instruct \
--task_name MATH \
--temperature 0.0 \
--max_new_tokens 2048 \
--save_dir results \
--method dots \
--num_worker 32 \
--controller_addr http://0.0.0.0:28777 \
--local