Skip to content

Commit 4970ef2

Browse files
authored
More explicit tree search early stopping (#21)
1 parent 2bd6745 commit 4970ef2

File tree

2 files changed

+79
-53
lines changed

2 files changed

+79
-53
lines changed

ldp/alg/tree_search.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -128,18 +128,19 @@ async def inner_descend(idx: int) -> None:
128128

129129
tree.add_transition(step_id, step)
130130

131+
cumulative_reward = prev_cumulative_reward + step.reward
132+
if cumulative_reward >= self.target_reward:
133+
# signal other descents to stop too
134+
self.target_reward_hit.add(tree.root_id)
135+
return
136+
131137
if step.done:
132138
return
133139

134140
if timestep + 1 >= max_depth:
135141
step.truncated = True
136142
return
137143

138-
cumulative_reward = prev_cumulative_reward + step.reward
139-
if cumulative_reward >= self.target_reward:
140-
# signal other descents to stop too
141-
self.target_reward_hit.add(tree.root_id)
142-
143144
# Recurse
144145
await self._descend(
145146
tree=tree,

tests/test_rollouts.py

Lines changed: 73 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -224,56 +224,81 @@ async def step(
224224
self, action: ToolRequestMessage
225225
) -> tuple[list[Message], float, bool, bool]:
226226
self.state += 1 + random.uniform(-0.01, 0.01)
227-
return [Message(content=str(self.state))], 0.0, self.state >= 3, False
228-
229-
230-
@pytest.mark.asyncio
231-
async def test_tree_search():
232-
agent = CountingAgent()
233-
# Use a slightly stochastic env so we can distinguish branches
234-
env = NoisyCountingEnv()
235-
236-
callback = DummyCallback()
237-
rollout_manager = TreeSearchRollout(
238-
agent,
239-
branching_factor=2,
240-
env_clone_fn=deepcopy,
241-
concurrency_limit=1,
242-
callbacks=[callback],
243-
)
244-
tree = await rollout_manager.sample_tree(env, max_depth=3)
245-
trajs = tree.get_trajectories()
246-
assert len(trajs) == 8
227+
return [Message(content=str(self.state))], 1.0, self.state >= 3, False
228+
229+
230+
class TestTreeSearch:
231+
@pytest.mark.asyncio
232+
async def test_tree_search(self):
233+
agent = CountingAgent()
234+
# Use a slightly stochastic env so we can distinguish branches
235+
env = NoisyCountingEnv()
236+
237+
callback = DummyCallback()
238+
rollout_manager = TreeSearchRollout(
239+
agent,
240+
branching_factor=2,
241+
env_clone_fn=deepcopy,
242+
concurrency_limit=1,
243+
callbacks=[callback],
244+
)
245+
tree = await rollout_manager.sample_tree(env, max_depth=3)
246+
trajs = tree.get_trajectories()
247+
assert len(trajs) == 8
247248

248-
traj_ids_wo_root = {
249-
cast(str, traj.traj_id).replace(tree.root_id, "").lstrip(":") for traj in trajs
250-
}
251-
# IDs should be 0:0:0, 0:0:1, ... 1:1:1 (order doesn't matter)
252-
assert traj_ids_wo_root == {":".join(x) for x in itertools.product("01", repeat=3)}
249+
traj_ids_wo_root = {
250+
cast(str, traj.traj_id).replace(tree.root_id, "").lstrip(":")
251+
for traj in trajs
252+
}
253+
# IDs should be 0:0:0, 0:0:1, ... 1:1:1 (order doesn't matter)
254+
assert traj_ids_wo_root == {
255+
":".join(x) for x in itertools.product("01", repeat=3)
256+
}
253257

254-
observations = {} # type: ignore[var-annotated]
255-
for traj in trajs:
256-
branch_path = tuple(cast(str, traj.traj_id).split(":")[1:])
257-
258-
prev_step: Transition | None = None
259-
for i_step, step in enumerate(traj.steps):
260-
if prev_step is not None:
261-
# Check that the child node started at the state emitted at the parent node
262-
assert prev_step.next_agent_state == step.agent_state
263-
264-
# Steps that started at the same node in the tree should have the same observation
265-
node_id = branch_path[: i_step + 1]
266-
if node_id in observations:
267-
assert observations[node_id] == step.observation[0].content
268-
else:
269-
observations[node_id] = step.observation[0].content
270-
271-
prev_step = step
272-
273-
# We expect sum_{i=1}^3 2^i = 2^4 - 2 = 14 transitions:
274-
# - branching factor = 2, depth = 3
275-
# - root node isn't sampled, so no i=0 term in sum
276-
assert all(v == 14 for v in callback.fn_invocations.values())
258+
observations = {} # type: ignore[var-annotated]
259+
for traj in trajs:
260+
branch_path = tuple(cast(str, traj.traj_id).split(":")[1:])
261+
262+
prev_step: Transition | None = None
263+
for i_step, step in enumerate(traj.steps):
264+
if prev_step is not None:
265+
# Check that the child node started at the state emitted at the parent node
266+
assert prev_step.next_agent_state == step.agent_state
267+
268+
# Steps that started at the same node in the tree should have the same observation
269+
node_id = branch_path[: i_step + 1]
270+
if node_id in observations:
271+
assert observations[node_id] == step.observation[0].content
272+
else:
273+
observations[node_id] = step.observation[0].content
274+
275+
prev_step = step
276+
277+
# We expect sum_{i=1}^3 2^i = 2^4 - 2 = 14 transitions:
278+
# - branching factor = 2, depth = 3
279+
# - root node isn't sampled, so no i=0 term in sum
280+
assert all(v == 14 for v in callback.fn_invocations.values())
281+
282+
@pytest.mark.asyncio
283+
async def test_early_stopping(self):
284+
agent = CountingAgent()
285+
# Use a slightly stochastic env so we can distinguish branches
286+
env = NoisyCountingEnv()
287+
288+
callback = DummyCallback()
289+
rollout_manager = TreeSearchRollout(
290+
agent,
291+
branching_factor=2,
292+
env_clone_fn=deepcopy,
293+
concurrency_limit=1,
294+
callbacks=[callback],
295+
target_reward=0.5,
296+
)
297+
trajs = (await rollout_manager.sample_tree(env, max_depth=3)).get_trajectories()
298+
assert len(trajs) < 8 # should have exited early
299+
for traj in trajs:
300+
# should have hit target reward immediately
301+
assert len(traj.steps) == 1
277302

278303

279304
def test_tree_mc_value():

0 commit comments

Comments
 (0)