@@ -224,56 +224,81 @@ async def step(
224
224
self , action : ToolRequestMessage
225
225
) -> tuple [list [Message ], float , bool , bool ]:
226
226
self .state += 1 + random .uniform (- 0.01 , 0.01 )
227
- return [Message (content = str (self .state ))], 0.0 , self .state >= 3 , False
228
-
229
-
230
- @pytest .mark .asyncio
231
- async def test_tree_search ():
232
- agent = CountingAgent ()
233
- # Use a slightly stochastic env so we can distinguish branches
234
- env = NoisyCountingEnv ()
235
-
236
- callback = DummyCallback ()
237
- rollout_manager = TreeSearchRollout (
238
- agent ,
239
- branching_factor = 2 ,
240
- env_clone_fn = deepcopy ,
241
- concurrency_limit = 1 ,
242
- callbacks = [callback ],
243
- )
244
- tree = await rollout_manager .sample_tree (env , max_depth = 3 )
245
- trajs = tree .get_trajectories ()
246
- assert len (trajs ) == 8
227
+ return [Message (content = str (self .state ))], 1.0 , self .state >= 3 , False
228
+
229
+
230
+ class TestTreeSearch :
231
+ @pytest .mark .asyncio
232
+ async def test_tree_search (self ):
233
+ agent = CountingAgent ()
234
+ # Use a slightly stochastic env so we can distinguish branches
235
+ env = NoisyCountingEnv ()
236
+
237
+ callback = DummyCallback ()
238
+ rollout_manager = TreeSearchRollout (
239
+ agent ,
240
+ branching_factor = 2 ,
241
+ env_clone_fn = deepcopy ,
242
+ concurrency_limit = 1 ,
243
+ callbacks = [callback ],
244
+ )
245
+ tree = await rollout_manager .sample_tree (env , max_depth = 3 )
246
+ trajs = tree .get_trajectories ()
247
+ assert len (trajs ) == 8
247
248
248
- traj_ids_wo_root = {
249
- cast (str , traj .traj_id ).replace (tree .root_id , "" ).lstrip (":" ) for traj in trajs
250
- }
251
- # IDs should be 0:0:0, 0:0:1, ... 1:1:1 (order doesn't matter)
252
- assert traj_ids_wo_root == {":" .join (x ) for x in itertools .product ("01" , repeat = 3 )}
249
+ traj_ids_wo_root = {
250
+ cast (str , traj .traj_id ).replace (tree .root_id , "" ).lstrip (":" )
251
+ for traj in trajs
252
+ }
253
+ # IDs should be 0:0:0, 0:0:1, ... 1:1:1 (order doesn't matter)
254
+ assert traj_ids_wo_root == {
255
+ ":" .join (x ) for x in itertools .product ("01" , repeat = 3 )
256
+ }
253
257
254
- observations = {} # type: ignore[var-annotated]
255
- for traj in trajs :
256
- branch_path = tuple (cast (str , traj .traj_id ).split (":" )[1 :])
257
-
258
- prev_step : Transition | None = None
259
- for i_step , step in enumerate (traj .steps ):
260
- if prev_step is not None :
261
- # Check that the child node started at the state emitted at the parent node
262
- assert prev_step .next_agent_state == step .agent_state
263
-
264
- # Steps that started at the same node in the tree should have the same observation
265
- node_id = branch_path [: i_step + 1 ]
266
- if node_id in observations :
267
- assert observations [node_id ] == step .observation [0 ].content
268
- else :
269
- observations [node_id ] = step .observation [0 ].content
270
-
271
- prev_step = step
272
-
273
- # We expect sum_{i=1}^3 2^i = 2^4 - 2 = 14 transitions:
274
- # - branching factor = 2, depth = 3
275
- # - root node isn't sampled, so no i=0 term in sum
276
- assert all (v == 14 for v in callback .fn_invocations .values ())
258
+ observations = {} # type: ignore[var-annotated]
259
+ for traj in trajs :
260
+ branch_path = tuple (cast (str , traj .traj_id ).split (":" )[1 :])
261
+
262
+ prev_step : Transition | None = None
263
+ for i_step , step in enumerate (traj .steps ):
264
+ if prev_step is not None :
265
+ # Check that the child node started at the state emitted at the parent node
266
+ assert prev_step .next_agent_state == step .agent_state
267
+
268
+ # Steps that started at the same node in the tree should have the same observation
269
+ node_id = branch_path [: i_step + 1 ]
270
+ if node_id in observations :
271
+ assert observations [node_id ] == step .observation [0 ].content
272
+ else :
273
+ observations [node_id ] = step .observation [0 ].content
274
+
275
+ prev_step = step
276
+
277
+ # We expect sum_{i=1}^3 2^i = 2^4 - 2 = 14 transitions:
278
+ # - branching factor = 2, depth = 3
279
+ # - root node isn't sampled, so no i=0 term in sum
280
+ assert all (v == 14 for v in callback .fn_invocations .values ())
281
+
282
+ @pytest .mark .asyncio
283
+ async def test_early_stopping (self ):
284
+ agent = CountingAgent ()
285
+ # Use a slightly stochastic env so we can distinguish branches
286
+ env = NoisyCountingEnv ()
287
+
288
+ callback = DummyCallback ()
289
+ rollout_manager = TreeSearchRollout (
290
+ agent ,
291
+ branching_factor = 2 ,
292
+ env_clone_fn = deepcopy ,
293
+ concurrency_limit = 1 ,
294
+ callbacks = [callback ],
295
+ target_reward = 0.5 ,
296
+ )
297
+ trajs = (await rollout_manager .sample_tree (env , max_depth = 3 )).get_trajectories ()
298
+ assert len (trajs ) < 8 # should have exited early
299
+ for traj in trajs :
300
+ # should have hit target reward immediately
301
+ assert len (traj .steps ) == 1
277
302
278
303
279
304
def test_tree_mc_value ():
0 commit comments