From 89ffe04a00e4944cb620a5da5829e6423bdd0490 Mon Sep 17 00:00:00 2001 From: "sanghyeok.choi" Date: Fri, 24 Jan 2025 23:12:56 +0900 Subject: [PATCH 1/3] make reverse_backward_trajectories and LocalSearchSampler support continuous case --- src/gfn/containers/trajectories.py | 34 +++++++---------------- src/gfn/samplers.py | 16 +++++------ tutorials/examples/train_box.py | 44 ++++++++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index d010f503..68bfc5ad 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -432,17 +432,7 @@ def reverse_backward_trajectories( trajectories: Trajectories, debug: bool = False ) -> Trajectories: """Reverses a backward trajectory""" - # FIXME: This method is not compatible with continuous GFN. - assert trajectories.is_backward, "Trajectories must be backward." - new_actions = torch.full( - ( - trajectories.max_length + 1, - len(trajectories), - *trajectories.actions.action_shape, - ), - -1, - ) # env.sf should never be None unless something went wrong during class # instantiation. @@ -466,11 +456,9 @@ def reverse_backward_trajectories( ) # shape (max_len + 1, n_trajectories, *state_dim) # Initialize new actions and states - new_actions = torch.full( - (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1 - ).to( - actions - ) # shape (max_len + 1, n_trajectories, *action_dim) + new_actions = trajectories.env.dummy_action.repeat( + max_len + 1, len(trajectories), 1 + ).to(actions) # shape (max_len + 1, n_trajectories, *action_dim) new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to( states ) # shape (max_len + 2, n_trajectories, *state_dim) @@ -505,8 +493,8 @@ def reverse_backward_trajectories( # Assign reversed actions to new_actions new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]] new_actions[torch.arange(len(trajectories)), seq_lengths] = ( - trajectories.env.n_actions - 1 - ) # FIXME: This can be problematic if action_dim != 1 (e.g. continuous actions) + trajectories.env.exit_action + ) # Assign reversed states to new_states assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0" @@ -539,18 +527,16 @@ def reverse_backward_trajectories( # If `debug` is True (expected only when testing), compare the # vectorized approach's results (above) to the for-loop results (below). if debug: - _new_actions = torch.full( - (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1 - ).to(actions) + _new_actions = trajectories.env.dummy_action.repeat( + max_len + 1, len(trajectories), 1 + ).to(actions) # shape (max_len + 1, n_trajectories, *action_dim) _new_states = trajectories.env.sf.repeat( max_len + 2, len(trajectories), 1 - ).to( - states - ) # shape (max_len + 2, n_trajectories, *state_dim) + ).to(states) # shape (max_len + 2, n_trajectories, *state_dim) for i in range(len(trajectories)): _new_actions[trajectories.when_is_done[i], i] = ( - trajectories.env.n_actions - 1 + trajectories.env.exit_action ) _new_actions[ : trajectories.when_is_done[i], i diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 8223317d..b1701809 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -590,11 +590,11 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 # Prepare the new states and actions # Note that these are initialized in transposed shapes - new_trajectories_states_tsr = torch.full( - (bs, max_traj_len + 1, *state_shape), -1 + new_trajectories_states_tsr = prev_trajectories.env.sf.repeat( + bs, max_traj_len + 1, 1 ).to(prev_trajectories.states.tensor) - new_trajectories_actions_tsr = torch.full( - (bs, max_traj_len, *action_shape), -1 + new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat( + bs, max_traj_len, 1 ).to(prev_trajectories.actions.tensor) # Assign the first part (backtracked from backward policy) of the trajectory @@ -664,11 +664,11 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 # If `debug` is True (expected only when testing), compare the # vectorized approach's results (above) to the for-loop results (below). if debug: - _new_trajectories_states_tsr = torch.full( - (max_traj_len + 1, bs, *state_shape), -1 + _new_trajectories_states_tsr = prev_trajectories.env.sf.repeat( + max_traj_len + 1, bs, 1 ).to(prev_trajectories.states.tensor) - _new_trajectories_actions_tsr = torch.full( - (max_traj_len, bs, *action_shape), -1 + _new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat( + max_traj_len, bs, 1 ).to(prev_trajectories.actions.tensor) if save_logprobs: diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index b501dad8..65b25831 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -32,6 +32,7 @@ BoxStateFlowModule, ) from gfn.modules import ScalarEstimator +from gfn.samplers import Sampler, LocalSearchSampler from gfn.utils.common import set_seed DEFAULT_SEED = 4444 @@ -179,6 +180,18 @@ def main(args): # noqa: C901 ) assert gflownet is not None, f"No gflownet for loss {args.loss}" + gflownet = gflownet.to(device_str) + + if not args.use_local_search: + sampler = Sampler(estimator=pf_estimator) + local_search_params = {} + else: + sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator) + local_search_params = { + "n_local_search_loops": args.n_local_search_loops, + "back_ratio": args.back_ratio, + "use_metropolis_hastings": args.use_metropolis_hastings, + } # 3. Create the optimizer and scheduler @@ -226,13 +239,13 @@ def main(args): # noqa: C901 states_visited = 0 jsd = float("inf") - for iteration in trange(n_iterations): + for iteration in trange(n_iterations, dynamic_ncols=True): if iteration % 1000 == 0: print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}") # Sampling on-policy, so we save logprobs for faster computation. - trajectories = gflownet.sample_trajectories( - env, save_logprobs=True, n=args.batch_size + trajectories = sampler.sample_trajectories( + env, save_logprobs=True, n=args.batch_size, **local_search_params ) training_samples = gflownet.to_training_samples(trajectories) @@ -399,6 +412,31 @@ def main(args): # noqa: C901 help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", ) + parser.add_argument( + "--use_local_search", + action="store_true", + help="Use local search to sample the next state", + ) + + # Local search parameters. + parser.add_argument( + "--n_local_search_loops", + type=int, + default=2, + help="Number of local search loops", + ) + parser.add_argument( + "--back_ratio", + type=float, + default=0.5, + help="The ratio of the number of backward steps to the length of the trajectory", + ) + parser.add_argument( + "--use_metropolis_hastings", + action="store_true", + help="Use Metropolis-Hastings acceptance criterion", + ) + parser.add_argument( "--n_trajectories", type=int, From 25ac995ea2253980cf939a75074190da5c93d73b Mon Sep 17 00:00:00 2001 From: "sanghyeok.choi" Date: Fri, 24 Jan 2025 23:43:09 +0900 Subject: [PATCH 2/3] black applied --- src/gfn/containers/trajectories.py | 24 +++++++++++++++--------- src/gfn/samplers.py | 22 ++++++++++------------ tutorials/examples/train_box.py | 6 ++++-- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 68bfc5ad..fe6438e6 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -458,7 +458,9 @@ def reverse_backward_trajectories( # Initialize new actions and states new_actions = trajectories.env.dummy_action.repeat( max_len + 1, len(trajectories), 1 - ).to(actions) # shape (max_len + 1, n_trajectories, *action_dim) + ).to( + actions + ) # shape (max_len + 1, n_trajectories, *action_dim) new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to( states ) # shape (max_len + 2, n_trajectories, *state_dim) @@ -492,9 +494,9 @@ def reverse_backward_trajectories( # Assign reversed actions to new_actions new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]] - new_actions[torch.arange(len(trajectories)), seq_lengths] = ( - trajectories.env.exit_action - ) + new_actions[ + torch.arange(len(trajectories)), seq_lengths + ] = trajectories.env.exit_action # Assign reversed states to new_states assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0" @@ -529,15 +531,19 @@ def reverse_backward_trajectories( if debug: _new_actions = trajectories.env.dummy_action.repeat( max_len + 1, len(trajectories), 1 - ).to(actions) # shape (max_len + 1, n_trajectories, *action_dim) + ).to( + actions + ) # shape (max_len + 1, n_trajectories, *action_dim) _new_states = trajectories.env.sf.repeat( max_len + 2, len(trajectories), 1 - ).to(states) # shape (max_len + 2, n_trajectories, *state_dim) + ).to( + states + ) # shape (max_len + 2, n_trajectories, *state_dim) for i in range(len(trajectories)): - _new_actions[trajectories.when_is_done[i], i] = ( - trajectories.env.exit_action - ) + _new_actions[ + trajectories.when_is_done[i], i + ] = trajectories.env.exit_action _new_actions[ : trajectories.when_is_done[i], i ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip( diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index b1701809..f5e359b0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -554,8 +554,6 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 bs = prev_trajectories.n_trajectories device = prev_trajectories.states.device - state_shape = prev_trajectories.states.state_shape - action_shape = prev_trajectories.env.action_shape env = prev_trajectories.env # Obtain full trajectories by concatenating the backward and forward parts. @@ -590,12 +588,12 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 # Prepare the new states and actions # Note that these are initialized in transposed shapes - new_trajectories_states_tsr = prev_trajectories.env.sf.repeat( - bs, max_traj_len + 1, 1 - ).to(prev_trajectories.states.tensor) - new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat( - bs, max_traj_len, 1 - ).to(prev_trajectories.actions.tensor) + new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to( + prev_trajectories.states.tensor + ) + new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to( + prev_trajectories.actions.tensor + ) # Assign the first part (backtracked from backward policy) of the trajectory prev_mask_truc = prev_mask[:, :max_n_prev] @@ -664,10 +662,10 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 # If `debug` is True (expected only when testing), compare the # vectorized approach's results (above) to the for-loop results (below). if debug: - _new_trajectories_states_tsr = prev_trajectories.env.sf.repeat( - max_traj_len + 1, bs, 1 - ).to(prev_trajectories.states.tensor) - _new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat( + _new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to( + prev_trajectories.states.tensor + ) + _new_trajectories_actions_tsr = env.dummy_action.repeat( max_traj_len, bs, 1 ).to(prev_trajectories.actions.tensor) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 65b25831..dfcef4ac 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -32,7 +32,7 @@ BoxStateFlowModule, ) from gfn.modules import ScalarEstimator -from gfn.samplers import Sampler, LocalSearchSampler +from gfn.samplers import LocalSearchSampler, Sampler from gfn.utils.common import set_seed DEFAULT_SEED = 4444 @@ -186,7 +186,9 @@ def main(args): # noqa: C901 sampler = Sampler(estimator=pf_estimator) local_search_params = {} else: - sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator) + sampler = LocalSearchSampler( + pf_estimator=pf_estimator, pb_estimator=pb_estimator + ) local_search_params = { "n_local_search_loops": args.n_local_search_loops, "back_ratio": args.back_ratio, From 7d6e3a370b243890d805a518d6d71e15d9372f7b Mon Sep 17 00:00:00 2001 From: "sanghyeok.choi" Date: Sat, 25 Jan 2025 00:09:41 +0900 Subject: [PATCH 3/3] fix test_scripts --- tutorials/examples/test_scripts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index aff12b60..a405fa8d 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -60,6 +60,7 @@ class BoxArgs(CommonArgs): gamma_scheduler: float = 0.5 scheduler_milestone: int = 2500 lr_F: float = 1e-2 + use_local_search: bool = False @pytest.mark.parametrize("ndim", [2, 4])