diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index d010f503..fe6438e6 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -432,17 +432,7 @@ def reverse_backward_trajectories(
     trajectories: Trajectories, debug: bool = False
 ) -> Trajectories:
     """Reverses a backward trajectory"""
-    # FIXME: This method is not compatible with continuous GFN.
-
     assert trajectories.is_backward, "Trajectories must be backward."
-    new_actions = torch.full(
-        (
-            trajectories.max_length + 1,
-            len(trajectories),
-            *trajectories.actions.action_shape,
-        ),
-        -1,
-    )
 
     # env.sf should never be None unless something went wrong during class
     # instantiation.
@@ -466,8 +456,8 @@ def reverse_backward_trajectories(
     )  # shape (max_len + 1, n_trajectories, *state_dim)
 
     # Initialize new actions and states
-    new_actions = torch.full(
-        (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
+    new_actions = trajectories.env.dummy_action.repeat(
+        max_len + 1, len(trajectories), 1
     ).to(
         actions
     )  # shape (max_len + 1, n_trajectories, *action_dim)
@@ -504,9 +494,9 @@ def reverse_backward_trajectories(
 
     # Assign reversed actions to new_actions
     new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
-    new_actions[torch.arange(len(trajectories)), seq_lengths] = (
-        trajectories.env.n_actions - 1
-    )  # FIXME: This can be problematic if action_dim != 1 (e.g. continuous actions)
+    new_actions[
+        torch.arange(len(trajectories)), seq_lengths
+    ] = trajectories.env.exit_action
 
     # Assign reversed states to new_states
     assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
@@ -539,9 +529,11 @@ def reverse_backward_trajectories(
     # If `debug` is True (expected only when testing), compare the
     # vectorized approach's results (above) to the for-loop results (below).
     if debug:
-        _new_actions = torch.full(
-            (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
-        ).to(actions)
+        _new_actions = trajectories.env.dummy_action.repeat(
+            max_len + 1, len(trajectories), 1
+        ).to(
+            actions
+        )  # shape (max_len + 1, n_trajectories, *action_dim)
         _new_states = trajectories.env.sf.repeat(
             max_len + 2, len(trajectories), 1
         ).to(
@@ -549,9 +541,9 @@ def reverse_backward_trajectories(
         )  # shape (max_len + 2, n_trajectories, *state_dim)
 
         for i in range(len(trajectories)):
-            _new_actions[trajectories.when_is_done[i], i] = (
-                trajectories.env.n_actions - 1
-            )
+            _new_actions[
+                trajectories.when_is_done[i], i
+            ] = trajectories.env.exit_action
             _new_actions[
                 : trajectories.when_is_done[i], i
             ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 8223317d..f5e359b0 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -554,8 +554,6 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
 
     bs = prev_trajectories.n_trajectories
     device = prev_trajectories.states.device
-    state_shape = prev_trajectories.states.state_shape
-    action_shape = prev_trajectories.env.action_shape
     env = prev_trajectories.env
 
     # Obtain full trajectories by concatenating the backward and forward parts.
@@ -590,12 +588,12 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
 
     # Prepare the new states and actions
     # Note that these are initialized in transposed shapes
-    new_trajectories_states_tsr = torch.full(
-        (bs, max_traj_len + 1, *state_shape), -1
-    ).to(prev_trajectories.states.tensor)
-    new_trajectories_actions_tsr = torch.full(
-        (bs, max_traj_len, *action_shape), -1
-    ).to(prev_trajectories.actions.tensor)
+    new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to(
+        prev_trajectories.states.tensor
+    )
+    new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to(
+        prev_trajectories.actions.tensor
+    )
 
     # Assign the first part (backtracked from backward policy) of the trajectory
     prev_mask_truc = prev_mask[:, :max_n_prev]
@@ -664,11 +662,11 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
     # If `debug` is True (expected only when testing), compare the
     # vectorized approach's results (above) to the for-loop results (below).
     if debug:
-        _new_trajectories_states_tsr = torch.full(
-            (max_traj_len + 1, bs, *state_shape), -1
-        ).to(prev_trajectories.states.tensor)
-        _new_trajectories_actions_tsr = torch.full(
-            (max_traj_len, bs, *action_shape), -1
+        _new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to(
+            prev_trajectories.states.tensor
+        )
+        _new_trajectories_actions_tsr = env.dummy_action.repeat(
+            max_traj_len, bs, 1
         ).to(prev_trajectories.actions.tensor)
 
         if save_logprobs:
diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py
index aff12b60..a405fa8d 100644
--- a/tutorials/examples/test_scripts.py
+++ b/tutorials/examples/test_scripts.py
@@ -60,6 +60,7 @@ class BoxArgs(CommonArgs):
     gamma_scheduler: float = 0.5
     scheduler_milestone: int = 2500
     lr_F: float = 1e-2
+    use_local_search: bool = False
 
 
 @pytest.mark.parametrize("ndim", [2, 4])
diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index b501dad8..dfcef4ac 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -32,6 +32,7 @@
     BoxStateFlowModule,
 )
 from gfn.modules import ScalarEstimator
+from gfn.samplers import LocalSearchSampler, Sampler
 from gfn.utils.common import set_seed
 
 DEFAULT_SEED = 4444
@@ -179,6 +180,20 @@ def main(args):  # noqa: C901
         )
 
     assert gflownet is not None, f"No gflownet for loss {args.loss}"
+    gflownet = gflownet.to(device_str)
+
+    if not args.use_local_search:
+        sampler = Sampler(estimator=pf_estimator)
+        local_search_params = {}
+    else:
+        sampler = LocalSearchSampler(
+            pf_estimator=pf_estimator, pb_estimator=pb_estimator
+        )
+        local_search_params = {
+            "n_local_search_loops": args.n_local_search_loops,
+            "back_ratio": args.back_ratio,
+            "use_metropolis_hastings": args.use_metropolis_hastings,
+        }
 
     # 3. Create the optimizer and scheduler
 
@@ -226,13 +241,13 @@ def main(args):  # noqa: C901
     states_visited = 0
     jsd = float("inf")
 
-    for iteration in trange(n_iterations):
+    for iteration in trange(n_iterations, dynamic_ncols=True):
        if iteration % 1000 == 0:
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
         # Sampling on-policy, so we save logprobs for faster computation.
-        trajectories = gflownet.sample_trajectories(
-            env, save_logprobs=True, n=args.batch_size
+        trajectories = sampler.sample_trajectories(
+            env, save_logprobs=True, n=args.batch_size, **local_search_params
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
@@ -399,6 +414,31 @@ def main(args):  # noqa: C901
         help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler",
     )
 
+    parser.add_argument(
+        "--use_local_search",
+        action="store_true",
+        help="Use local search to sample the next state",
+    )
+
+    # Local search parameters.
+    parser.add_argument(
+        "--n_local_search_loops",
+        type=int,
+        default=2,
+        help="Number of local search loops",
+    )
+    parser.add_argument(
+        "--back_ratio",
+        type=float,
+        default=0.5,
+        help="The ratio of the number of backward steps to the length of the trajectory",
+    )
+    parser.add_argument(
+        "--use_metropolis_hastings",
+        action="store_true",
+        help="Use Metropolis-Hastings acceptance criterion",
+    )
+
     parser.add_argument(
         "--n_trajectories",
         type=int,
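
Note on the padding change: the hunks above replace the hard-coded -1 fills and the `n_actions - 1` exit action with the environment-defined `env.dummy_action` and `env.exit_action`, so trajectory reversal and local search also work when `action_shape` is not `(1,)` (e.g. the continuous Box environment). Below is a minimal sketch of that pattern only; the `action_shape` and the placeholder values chosen for the two special actions are made up for illustration (the real tensors come from the environment class), and just the repeat/assign pattern mirrors the patch.

    import torch

    # Toy stand-ins for env.dummy_action / env.exit_action (values here are
    # illustrative placeholders, not the library's actual definitions).
    action_shape = (2,)  # e.g. a 2-D continuous action
    dummy_action = torch.full(action_shape, float("-inf"))
    exit_action = torch.full(action_shape, float("inf"))

    max_len, n_trajectories = 5, 3

    # Old pattern: torch.full((max_len + 1, n_trajectories, *action_shape), -1)
    # only encodes "no action" sensibly for scalar discrete actions.
    # New pattern: broadcast the environment-defined dummy action instead.
    new_actions = dummy_action.repeat(max_len + 1, n_trajectories, 1)
    assert new_actions.shape == (max_len + 1, n_trajectories, *action_shape)

    # Write the exit action at each trajectory's terminating step, mirroring
    # new_actions[torch.arange(len(trajectories)), seq_lengths] = env.exit_action.
    seq_lengths = torch.tensor([2, 4, 3])
    new_actions[seq_lengths, torch.arange(n_trajectories)] = exit_action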
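Usage note for the training-script changes: with the arguments added above, local search in tutorials/examples/train_box.py is opt-in, e.g. `python tutorials/examples/train_box.py --use_local_search --n_local_search_loops 2 --back_ratio 0.5 --use_metropolis_hastings` (flag names and defaults taken from the argparse additions in this patch). Without `--use_local_search`, the script falls back to the plain on-policy `Sampler`.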