Skip to content

Commit b2a4937

Browse files
author
Chris MacLellan
committed
updated wherewhenhownofoa so it can be run with fractions again.
1 parent 1857c54 commit b2a4937

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

apprentice/agents/WhereWhenHowNoFoa.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from concept_formation.structure_mapper import rename_flat
88

99
from apprentice.agents.base import BaseAgent
10-
from apprentice.learners.WhenLearner import get_when_learner
11-
from apprentice.learners.WhereLearner import get_where_learner
10+
# from apprentice.learners.WhenLearner import get_when_learner
11+
from apprentice.learners.WhenLearner import WHEN_CLASSIFIERS
12+
# from apprentice.learners.WhereLearner import get_where_learner
13+
from apprentice.learners.WhereLearner import MostSpecific
1214
from apprentice.planners.fo_planner import FoPlanner, execute_functions, unify, subst
1315
# from ilp.fo_planner import Operator
1416

15-
# from planners.rulesets import function_sets
16-
# from planners.rulesets import feature_sets
17+
from apprentice.working_memory.fo_planner_operators import functionsets
18+
from apprentice.working_memory.fo_planner_operators import featuresets
1719

1820
# search_depth = 1
1921
# epsilon = .9
@@ -166,14 +168,15 @@ class WhereWhenHowNoFoa(BaseAgent):
166168
This is the basis for the 2 mechanism model.
167169
"""
168170
def __init__(self, feature_set, function_set,
169-
when_learner='trestle', where_learner='MostSpecific',
171+
when_learner='decisiontree', # where_learner='MostSpecific',
170172
search_depth=1, numerical_epsilon=0.0):
171-
self.where = get_where_learner(where_learner)
172-
self.when = get_when_learner(when_learner)
173+
self.where = MostSpecific
174+
# self.when = get_when_learner(when_learner)
175+
self.when = WHEN_CLASSIFIERS[when_learner]
173176
self.skills = {}
174177
self.examples = {}
175-
self.feature_set = feature_set
176-
self.function_set = function_set
178+
self.feature_set = featuresets[feature_set]
179+
self.function_set = functionsets[function_set]
177180
self.search_depth = search_depth
178181
self.epsilon = numerical_epsilon
179182

@@ -452,8 +455,8 @@ def explanations_from_how_search(self,state,sai,input_args):
452455
return explanations
453456

454457

455-
def train(self, state, selection, action, inputs, reward, skill_label,
456-
foci_of_attention):
458+
def train(self, state, sai, reward,
459+
skill_label, foci_of_attention, **kw):
457460
"""
458461
Doc String
459462
"""
@@ -471,6 +474,11 @@ def train(self, state, selection, action, inputs, reward, skill_label,
471474
# label = 'math'
472475

473476
# create example dict
477+
478+
selection = sai.selection
479+
action = sai.action
480+
inputs = sai.inputs
481+
474482
example = {}
475483
example['state'] = state
476484
example['skill_label'] = skill_label

apprentice/learners/WhereLearner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ def ground_example(self, x):
278278
return grounded
279279

280280
def check_match(self, t, x):
281-
x = x.get_view("flat_ungrounded")
281+
if hasattr(x, 'get_view'):
282+
x = x.get_view("flat_ungrounded")
282283
# print("CHECK MATCHES T", t)
283284

284285
t = tuple(ground(ele) for ele in t)
@@ -301,7 +302,8 @@ def check_match(self, t, x):
301302
return False
302303

303304
def get_matches(self, x, epsilon=0.0):
304-
x = x.get_view("flat_ungrounded")
305+
if hasattr(x, 'get_view'):
306+
x = x.get_view("flat_ungrounded")
305307

306308
grounded = self.ground_example(x)
307309

apprentice/working_memory/fo_planner_operators.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,13 @@ def int3_float_add_then_tens(x, y, w):
559559
# arith_rules = [add_rule, sub_rule, mult_rule, div_rule, update_rule,
560560
# done_rule]
561561
# arith_rules = [add_rule, mult_rule, update_rule, done_rule]
562+
fraction_fun = [add_rule, mult_rule]
563+
fraction_feat = [equal_rule, editable_rule]
562564

563565
functionsets = {'tutor knowledge': arith_rules,
564566
'stoichiometry': stoichiometry_rules,
565-
'rumbleblocks': rb_rules, 'article selection': []}
567+
'rumbleblocks': rb_rules, 'article selection': [],
568+
'fraction arith': fraction_fun}
566569

567570
featuresets = {'tutor knowledge': [equal_rule,
568571
grammar_parser_rule,
@@ -571,7 +574,9 @@ def int3_float_add_then_tens(x, y, w):
571574
'rumbleblocks': [], 'article selection': [unigram_rule,
572575
bigram_rule,
573576
equal_rule,
574-
editable_rule]}
577+
editable_rule],
578+
'fraction arith': fraction_feat}
579+
575580

576581
if __name__ == "__main__":
577582

0 commit comments

Comments
 (0)