@@ -64,6 +64,19 @@ def heuristic(self,state):
6464 s = p + n
6565 return (p / s if s > 0 else 0 , s )
6666
67+ class WeightedProportionCorrect (TotalCorrect ):
68+ def heuristic (self ,state ,w = 2.0 ):
69+ p ,n = self .num_correct , w * self .num_incorrect
70+ s = p + n
71+ return (p / s if s > 0 else 0 , s )
72+
73+ class NonLinearProportionCorrect (TotalCorrect ):
74+ def heuristic (self ,state ,a = 1.0 ,b = 1.0 ):
75+ p ,n = self .num_correct , self .num_incorrect
76+ n = a * n + b * (n * n )
77+ s = p + n
78+ return (p / s if s > 0 else 0 , s )
79+
6780####---------------HOW CULL RULE------------########
6881
6982def first (expl_iter ):
@@ -73,6 +86,18 @@ def most_parsimonious(expl_iter):
7386 l = sorted (expl_iter ,key = lambda x :x .get_how_depth ())
7487 return l [:1 ]
7588
89+ def least_depth (expl_iter ):
90+ expl_iter = list (expl_iter )
91+ shuffle (expl_iter )
92+ l = sorted (expl_iter ,key = lambda x : getattr (x .rhs .input_rule ,'depth' ,0 ))
93+ return l [:1 ]
94+
95+ def least_operations (expl_iter ):
96+ expl_iter = list (expl_iter )
97+ shuffle (expl_iter )
98+ l = sorted (expl_iter ,key = lambda x : getattr (x .rhs .input_rule ,'num_ops' ,0 ))
99+ return l [:1 ]
100+
76101def return_all (expl_iter ):
77102 return [x for x in expl_iter ]
78103
@@ -125,11 +150,15 @@ def get_which_learner(heuristic_learner,explanation_choice,**kwargs):
125150WHICH_HEURISTIC_AGENTS = {
126151 'proportioncorrect' : ProportionCorrect ,
127152 'totalcorrect' : TotalCorrect ,
153+ 'weightedproportioncorrect' : WeightedProportionCorrect ,
154+ 'nonlinearproportioncorrect' : NonLinearProportionCorrect
128155}
129156
130157CULL_HOW_RULES = {
131158 'first' : first ,
132- 'mostparsimonious' : most_parsimonious ,
159+ 'mostparsimonious' : most_parsimonious , #probably need to depricate
160+ 'leastdepth' : least_depth ,
161+ 'leastoperations' : least_operations ,
133162 'all' : return_all ,
134163 'random' : random ,
135164 # 'closest': closest,
0 commit comments