Skip to content

Commit e82c540

Browse files
authored
Merge pull request #35 from Microsoft/parallel-quests-navigation
chaining: Re-generate actions when branching
2 parents 4e0d91b + dd90df1 commit e82c540

File tree

2 files changed

+112
-44
lines changed

2 files changed

+112
-44
lines changed

textworld/generator/chaining.py

Lines changed: 42 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,8 @@ class ChainingOptions:
7676
subquests:
7777
Whether to also return incomplete quests, which could be extended
7878
without reaching the depth or breadth limits.
79+
independent_chains:
80+
Whether to allow totally independent parallel chains.
7981
create_variables:
8082
Whether new variables may be created during chaining.
8183
fixed_mapping:
@@ -99,6 +101,7 @@ def __init__(self):
99101
self.min_breadth = 1
100102
self.max_breadth = 1
101103
self.subquests = False
104+
self.independent_chains = False
102105
self.create_variables = False
103106
self.fixed_mapping = data.get_types().constants_mapping
104107
self.rng = None
@@ -189,30 +192,17 @@ class _Node:
189192
A node in a chain being generated.
190193
191194
Each node is aware of its position (depth, breadth) in the dependency tree
192-
induced by the chain. For generating parallel quests, the backtracks field
193-
holds actions that can be used to go up the dependency tree and start a new
194-
chain.
195-
196-
For example, taking the action node.backtracks[i][j] will produce a new node
197-
at depth (i + 1) and breadth (self.breadth + 1). To avoid duplication, in
198-
trees like this:
199-
200-
root
201-
/ | \
202-
A B C
203-
| | |
204-
.......
205-
206-
A.backtracks[0] will be [B, C], B.backtracks[0] will be [C], and
207-
C.backtracks[0] will be [].
195+
induced by the chain. To avoid duplication when generating parallel chains,
196+
each node stores the actions that have already been used at that depth.
208197
"""
209198

210-
def __init__(self, parent, dep_parent, state, action, backtracks, depth, breadth):
199+
def __init__(self, parent, dep_parent, state, action, rules, used, depth, breadth):
211200
self.parent = parent
212201
self.dep_parent = dep_parent
213202
self.state = state
214203
self.action = action
215-
self.backtracks = backtracks
204+
self.rules = rules
205+
self.used = used
216206
self.depth = depth
217207
self.breadth = breadth
218208

@@ -235,7 +225,7 @@ def __init__(self, state, options):
235225

236226
def root(self) -> _Node:
237227
"""Create the root node for chaining."""
238-
return _Node(None, None, self.state, None, [], 0, 1)
228+
return _Node(None, None, self.state, None, [], set(), 0, 1)
239229

240230
def chain(self, node: _Node) -> Iterable[_Node]:
241231
"""
@@ -251,30 +241,21 @@ def chain(self, node: _Node) -> Iterable[_Node]:
251241
if self.rng:
252242
self.rng.shuffle(assignments)
253243

254-
partials = []
255-
actions = []
256-
states = []
244+
used = set()
257245
for partial in assignments:
258246
action = self.try_instantiate(node.state, partial)
259247
if not action:
260248
continue
261249

262-
if not self.check_action(node, action):
250+
if not self.check_action(node, node.state, action):
263251
continue
264252

265253
state = self.apply(node, action)
266254
if not state:
267255
continue
268256

269-
partials.append(partial)
270-
actions.append(action)
271-
states.append(state)
272-
273-
for i, action in enumerate(actions):
274-
# Only allow backtracking into later actions, to avoid duplication
275-
remaining = partials[i+1:]
276-
backtracks = node.backtracks + [remaining]
277-
yield _Node(node, node, states[i], action, backtracks, node.depth + 1, node.breadth)
257+
used = used | {action}
258+
yield _Node(node, node, state, action, rules, used, node.depth + 1, node.breadth)
278259

279260
def backtrack(self, node: _Node) -> Iterable[_Node]:
280261
"""
@@ -284,21 +265,39 @@ def backtrack(self, node: _Node) -> Iterable[_Node]:
284265
if node.breadth >= self.max_breadth:
285266
return
286267

287-
for i, partials in enumerate(node.backtracks):
288-
backtracks = node.backtracks[:i]
289-
290-
for j, partial in enumerate(partials):
268+
parent = node
269+
parents = []
270+
while parent.dep_parent:
271+
if parent.depth == 1 and not self.options.independent_chains:
272+
break
273+
parents.append(parent)
274+
parent = parent.dep_parent
275+
parents = parents[::-1]
276+
277+
for sibling in parents:
278+
parent = sibling.dep_parent
279+
rules = self.options.get_rules(parent.depth)
280+
assignments = self.all_assignments(node, rules)
281+
if self.rng:
282+
self.rng.shuffle(assignments)
283+
284+
for partial in assignments:
291285
action = self.try_instantiate(node.state, partial)
292286
if not action:
293287
continue
294288

289+
if action in sibling.used:
290+
continue
291+
292+
if not self.check_action(parent, node.state, action):
293+
continue
294+
295295
state = self.apply(node, action)
296296
if not state:
297297
continue
298298

299-
remaining = partials[j+1:]
300-
new_backtracks = backtracks + [remaining]
301-
yield _Node(node, partial.node, state, action, new_backtracks, i + 1, node.breadth + 1)
299+
used = sibling.used | {action}
300+
yield _Node(node, parent, state, action, rules, used, sibling.depth, node.breadth + 1)
302301

303302
def all_assignments(self, node: _Node, rules: Iterable[Rule]) -> Iterable[_PartialAction]:
304303
"""
@@ -359,7 +358,7 @@ def create_variable(self, state, ph, type_counts):
359358
type_counts[ph.type] += 1
360359
return var
361360

362-
def check_action(self, node: _Node, action: Action) -> bool:
361+
def check_action(self, node: _Node, state: State, action: Action) -> bool:
363362
# Find the last action before a navigation action
364363
# TODO: Fold this behaviour into ChainingOptions.check_action()
365364
nav_parent = node
@@ -387,7 +386,7 @@ def check_action(self, node: _Node, action: Action) -> bool:
387386
if len(recent.added & relevant) == 0 or len(pre_navigation.added & relevant) == 0:
388387
return False
389388

390-
return self.options.check_action(node.state, action)
389+
return self.options.check_action(state, action)
391390

392391
def _is_navigation(self, action):
393392
return action.name.startswith("go/")
@@ -405,8 +404,8 @@ def apply(self, node: _Node, action: Action) -> Optional[State]:
405404

406405
new_state.apply(action)
407406

408-
# Some debug checks
409-
assert self.check_state(new_state)
407+
if not self.check_state(new_state):
408+
return None
410409

411410
# Detect cycles
412411
state = new_state.copy()

textworld/generator/tests/test_chaining.py

Lines changed: 70 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -237,4 +237,73 @@ def test_parallel_quests():
237237
options.min_breadth = 1
238238
options.create_variables = True
239239
chains = list(get_chains(State(), options))
240-
assert len(chains) == 6
240+
assert len(chains) == 5
241+
242+
243+
def test_parallel_quests_navigation():
244+
logic = GameLogic.parse("""
245+
type P {
246+
}
247+
248+
type I {
249+
}
250+
251+
type r {
252+
rules {
253+
move :: at(P, r) & $free(r, r') -> at(P, r');
254+
}
255+
256+
constraints {
257+
atat :: at(P, r) & at(P, r') -> fail();
258+
}
259+
}
260+
261+
type o {
262+
rules {
263+
take :: $at(P, r) & at(o, r) -> in(o, I);
264+
}
265+
266+
constraints {
267+
inat :: in(o, I) & at(o, r) -> fail();
268+
}
269+
}
270+
271+
type flour : o {
272+
}
273+
274+
type eggs : o {
275+
}
276+
277+
type cake {
278+
rules {
279+
bake :: in(flour, I) & in(eggs, I) -> in(cake, I) & in(flour, cake) & in(eggs, cake);
280+
}
281+
282+
constraints {
283+
inincake :: in(o, I) & in(o, cake) -> fail();
284+
atincake :: at(o, r) & in(o, cake) -> fail();
285+
}
286+
}
287+
""")
288+
289+
state = State([
290+
Proposition.parse("at(P, r3: r)"),
291+
Proposition.parse("free(r2: r, r3: r)"),
292+
Proposition.parse("free(r1: r, r2: r)"),
293+
])
294+
295+
bake = [logic.rules["bake"]]
296+
non_bake = [r for r in logic.rules.values() if r.name != "bake"]
297+
298+
options = ChainingOptions()
299+
options.backward = True
300+
options.create_variables = True
301+
options.min_depth = 3
302+
options.max_depth = 3
303+
options.min_breadth = 2
304+
options.max_breadth = 2
305+
options.logic = logic
306+
options.rules_per_depth = [bake, non_bake, non_bake]
307+
options.restricted_types = {"P", "r"}
308+
chains = list(get_chains(state, options))
309+
assert len(chains) == 2

0 commit comments

Comments
 (0)