Skip to content

Commit aa08c3e

Browse files
hotfixing which learner to incorperate new explanation choice and heuristic options
1 parent 14783ba commit aa08c3e

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

apprentice/learners/WhichLearner.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ def heuristic(self,state):
6464
s = p + n
6565
return (p / s if s > 0 else 0, s)
6666

67+
class WeightedProportionCorrect(TotalCorrect):
68+
def heuristic(self,state,w=2.0):
69+
p,n = self.num_correct, w*self.num_incorrect
70+
s = p + n
71+
return (p / s if s > 0 else 0, s)
72+
73+
class NonLinearProportionCorrect(TotalCorrect):
74+
def heuristic(self,state,a=1.0,b=1.0):
75+
p,n = self.num_correct, self.num_incorrect
76+
n = a*n + b*(n*n)
77+
s = p + n
78+
return (p / s if s > 0 else 0, s)
79+
6780
####---------------HOW CULL RULE------------########
6881

6982
def first(expl_iter):
@@ -73,6 +86,18 @@ def most_parsimonious(expl_iter):
7386
l = sorted(expl_iter,key=lambda x:x.get_how_depth())
7487
return l[:1]
7588

89+
def least_depth(expl_iter):
90+
expl_iter = list(expl_iter)
91+
shuffle(expl_iter)
92+
l = sorted(expl_iter,key=lambda x: getattr(x.rhs.input_rule,'depth',0))
93+
return l[:1]
94+
95+
def least_operations(expl_iter):
96+
expl_iter = list(expl_iter)
97+
shuffle(expl_iter)
98+
l = sorted(expl_iter,key=lambda x: getattr(x.rhs.input_rule,'num_ops',0))
99+
return l[:1]
100+
76101
def return_all(expl_iter):
77102
return [x for x in expl_iter]
78103

@@ -125,11 +150,15 @@ def get_which_learner(heuristic_learner,explanation_choice,**kwargs):
125150
WHICH_HEURISTIC_AGENTS = {
126151
'proportioncorrect': ProportionCorrect,
127152
'totalcorrect': TotalCorrect,
153+
'weightedproportioncorrect': WeightedProportionCorrect,
154+
'nonlinearproportioncorrect': NonLinearProportionCorrect
128155
}
129156

130157
CULL_HOW_RULES = {
131158
'first': first,
132-
'mostparsimonious': most_parsimonious,
159+
'mostparsimonious': most_parsimonious, #probably need to depricate
160+
'leastdepth': least_depth,
161+
'leastoperations': least_operations,
133162
'all': return_all,
134163
'random' : random,
135164
# 'closest': closest,

0 commit comments

Comments
 (0)