-
Notifications
You must be signed in to change notification settings - Fork 0
/
calc_truth.py
35 lines (31 loc) · 1020 Bytes
/
calc_truth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys
sys.path.append('/home/vincent/graphrule/src')
from pandas import read_csv
from task.game24 import calc_exprs, ExpressionTreeBuilder, ASTNode
from graph.standard import Graph24PointI
from llm.tag import Tag
from tqdm import tqdm
from sympy import simplify
# Build "truth" graphs for the Game-of-24 puzzles listed in the task CSV.
#
# Each entry of the 'Puzzles' column is split on whitespace into integers.
# The script then:
#   1. enumerates that puzzle's accepted expressions with calc_exprs,
#   2. drops exact duplicates,
#   3. parses each expression into an AST,
#   4. assembles a Graph24PointI from those ASTs.
# Puzzles for which calc_exprs returns nothing are skipped without consuming
# an index, so `cnt` counts only puzzles that produced expressions.
path = "/home/vincent/graphrule/data/tasks/24.csv"
# NOTE(review): only referenced by the disabled save call at the end of the loop.
graph_path = "/home/vincent/graphrule/data/graph/truth"
tasks = read_csv(path)['Puzzles']
cnt = 0
for task in tqdm(tasks):
    data = [int(x) for x in task.split()]
    accs = calc_exprs(*data)
    if not accs:
        continue
    cnt1 = len(accs)  # count before de-duplication, reported below
    # De-duplicate while keeping first-seen order. The former set() -> list()
    # round-trip has no guaranteed order (for str elements it changes between
    # interpreter runs via hash randomization), which made the order of the
    # AST roots, and hence the graph built from them, non-reproducible.
    # An earlier experiment keyed on str(simplify(expr)) to also merge
    # algebraically equivalent expressions.
    accs = list(dict.fromkeys(accs))
    # tqdm.write prints the same "unique total" line without corrupting the
    # progress bar the way a bare print() inside a tqdm loop does.
    tqdm.write(f"{len(accs)} {cnt1}")
    # A fresh builder per expression, as before, in case the builder keeps
    # per-build state.
    roots = [ExpressionTreeBuilder().build(expr) for expr in accs]
    graph = Graph24PointI.from_ast(roots, data, cnt)
    # Saving is currently disabled; re-enable to write one JSON file per
    # puzzle that produced expressions.
    # graph.save_to_json(f"{graph_path}/{graph.name}_{cnt}.json")
    cnt += 1