diff --git a/config/gflownet/forwardlooking.yaml b/config/gflownet/forwardlooking.yaml
new file mode 100644
index 000000000..c2c641719
--- /dev/null
+++ b/config/gflownet/forwardlooking.yaml
@@ -0,0 +1,9 @@
+defaults:
+  - gflownet
+  - state_flow: mlp
+
+optimizer:
+  loss: forwardlooking
+  lr: 0.0001
+  lr_decay_period: 1000000
+  lr_decay_gamma: 0.5
diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml
index 22dd6dd10..c33e52eaf 100644
--- a/config/gflownet/gflownet.yaml
+++ b/config/gflownet/gflownet.yaml
@@ -34,6 +34,8 @@ optimizer:
   # From original implementation
   bootstrap_tau: 0.0
   clip_grad_norm: 0.0
+# State flow modelling
+state_flow: null
 # If True, compute rewards in batches
 batch_reward: True
 # Force zero probability of sampling invalid actions
diff --git a/config/state_flow/mlp.yaml b/config/gflownet/state_flow/mlp.yaml
similarity index 100%
rename from config/state_flow/mlp.yaml
rename to config/gflownet/state_flow/mlp.yaml
diff --git a/config/main.yaml b/config/main.yaml
index c550bfaf6..7ea98e735 100644
--- a/config/main.yaml
+++ b/config/main.yaml
@@ -3,7 +3,6 @@ defaults:
   - env: grid
   - gflownet: flowmatch
   - policy: mlp_${gflownet}
-  - state_flow: null
   - proxy: corners
   - logger: wandb
   - user: alex
diff --git a/config/policy/mlp_forwardlooking.yaml b/config/policy/mlp_forwardlooking.yaml
new file mode 100644
index 000000000..41f43231e
--- /dev/null
+++ b/config/policy/mlp_forwardlooking.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - mlp
+
+backward:
+  shared_weights: True
+  checkpoint: null
+  reload_ckpt: False
diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py
index 59fce6542..b66f3e91e 100644
--- a/gflownet/gflownet.py
+++ b/gflownet/gflownet.py
@@ -77,15 +77,12 @@ def __init__(
         if optimizer.loss in ["flowmatch", "flowmatching"]:
             self.loss = "flowmatch"
             self.logZ = None
-            self.non_terminal_rewards = False
         elif optimizer.loss in ["trajectorybalance", "tb"]:
             self.loss = "trajectorybalance"
             self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
-            self.non_terminal_rewards = False
         elif optimizer.loss in ["forwardlooking", "fl"]:
             self.loss = "forwardlooking"
             self.logZ = None
-            self.non_terminal_rewards = True
         else:
             print("Unkown loss. Using flowmatch as default")
             self.loss = "flowmatch"
@@ -428,7 +425,6 @@ def sample_batch(
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )

         # ON-POLICY FORWARD trajectories
@@ -438,7 +434,6 @@
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         while envs:
             # Sample actions
@@ -465,7 +460,6 @@
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         if n_train > 0 and self.buffer.train_pkl is not None:
             with open(self.buffer.train_pkl, "rb") as f:
@@ -501,7 +495,6 @@
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         if n_replay > 0 and self.buffer.replay_pkl is not None:
             with open(self.buffer.replay_pkl, "rb") as f:
@@ -754,8 +747,7 @@ def forwardlooking_loss(self, it, batch):
         # Can be optimised by reusing states_log_flflow and batch.get_parent_indices
         parents_log_flflow = self.state_flow(parents_policy)

-        assert batch.non_terminal_rewards
-        rewards_states = batch.get_rewards()
+        rewards_states = batch.get_rewards(do_non_terminating=True)
         rewards_parents = batch.get_rewards_parents()
         energies_states = -torch.log(rewards_states)
         energies_parents = -torch.log(rewards_parents)
@@ -869,7 +861,6 @@ def estimate_logprobs_data(
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         # Create an environment for each data point and trajectory and set the state
         envs = []
@@ -973,7 +964,6 @@ def train(self):
                 env=self.env,
                 device=self.device,
                 float_type=self.float,
-                non_terminal_rewards=self.non_terminal_rewards,
             )
             for j in range(self.sttr):
                 sub_batch, times = self.sample_batch(
@@ -1268,7 +1258,6 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None):
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         self.random_action_prob = 0
         t = time.time()
@@ -1296,7 +1285,6 @@
             env=self.env,
             device=self.device,
             float_type=self.float,
-            non_terminal_rewards=self.non_terminal_rewards,
         )
         self.random_action_prob = 1.0
         print("[test_top_k] Sampling at random...", end="\r")
diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py
index b6da02691..2ae8e0d30 100644
--- a/gflownet/utils/batch.py
+++ b/gflownet/utils/batch.py
@@ -39,7 +39,6 @@ def __init__(
         env: Optional[GFlowNetEnv] = None,
         device: Union[str, torch.device] = "cpu",
         float_type: Union[int, torch.dtype] = 32,
-        non_terminal_rewards: bool = False,
     ):
         """
         env : GFlowNetEnv
@@ -57,8 +56,6 @@
         self.device = set_device(device)
         # Float precision
         self.float = set_float_precision(float_type)
-        # Whether rewards should be computed for non-terminal states
-        self.non_terminal_rewards = non_terminal_rewards
         # Generic environment, properties and dictionary of state and forward mask of
         # source (as tensor)
         if env is not None:
@@ -531,11 +528,6 @@ def get_parents_indices(self):
             self._compute_parents()
         return self.parents_indices

-    def get_parent_is_source(self):
-        if self.parents_available is False:
-            self._compute_parents()
-        return self.parents_indices == -1
-
     def _compute_parents(self):
         """
         Obtains the parent (single parent for each state) of all states in the batch.
@@ -842,7 +834,9 @@ def _compute_masks_backward(self):
         self.masks_backward_available = True

     def get_rewards(
-        self, force_recompute: Optional[bool] = False
+        self,
+        force_recompute: Optional[bool] = False,
+        do_non_terminating: Optional[bool] = False,
     ) -> TensorType["n_states"]:
         """
         Returns the rewards of all states in the batch (including not done).
@@ -851,26 +845,37 @@ def get_rewards(
         ----
         force_recompute : bool
             If True, the rewards are recomputed even if they are available.
+
+        do_non_terminating : bool
+            If True, compute the rewards of the non-terminating states instead of
+            assigning reward 0.
         """
         if self.rewards_available is False or force_recompute is True:
-            self._compute_rewards()
+            self._compute_rewards(do_non_terminating)
         return self.rewards

-    def _compute_rewards(self):
+    def _compute_rewards(self, do_non_terminating: Optional[bool] = False):
         """
         Computes rewards for all self.states by first converting the states into proxy
         format. The result is stored in self.rewards as a torch.tensor
+
+        Args
+        ----
+        do_non_terminating : bool
+            If True, compute the rewards of the non-terminating states instead of
+            assigning reward 0.
         """

-        self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device)
-        done = self.get_done()
-        if self.non_terminal_rewards:
+        if do_non_terminating:
             self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy()))
-        elif len(done) > 0:
-            states_proxy_done = self.get_terminating_states(proxy=True)
-            self.rewards[done] = self.env.proxy2reward(
-                self.env.proxy(states_proxy_done)
-            )
+        else:
+            self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device)
+            done = self.get_done()
+            if len(done) > 0:
+                states_proxy_done = self.get_terminating_states(proxy=True)
+                self.rewards[done] = self.env.proxy2reward(
+                    self.env.proxy(states_proxy_done)
+                )
         self.rewards_available = True

     def get_rewards_parents(self) -> TensorType["n_states"]:
@@ -888,8 +893,8 @@ def _compute_rewards_parents(self):
         """
         state_rewards = self.get_rewards()
         self.rewards_parents = torch.zeros_like(state_rewards)
-        parent_is_source = self.get_parent_is_source()
         parent_indices = self.get_parents_indices()
+        parent_is_source = parent_indices == -1
         self.rewards_parents[~parent_is_source] = self.rewards[
             parent_indices[~parent_is_source]
         ]
diff --git a/main.py b/main.py
index 6e10cf771..d7aab0c5f 100644
--- a/main.py
+++ b/main.py
@@ -60,10 +60,10 @@ def main(config):
         float_precision=config.float_precision,
         base=forward_policy,
     )
-
-    if config.gflownet.optimizer.loss in ["forwardlooking", "fl"]:
+    # State flow
+    if config.gflownet.state_flow is not None:
         state_flow = hydra.utils.instantiate(
-            config.state_flow,
+            config.gflownet.state_flow,
             env=env,
             device=config.device,
             float_precision=config.float_precision,
@@ -71,7 +71,7 @@
         )
     else:
         state_flow = None
-
+    # GFlowNet Agent
     gflownet = hydra.utils.instantiate(
         config.gflownet,
         device=config.device,
@@ -83,6 +83,8 @@
         buffer=config.env.buffer,
         logger=logger,
     )
+
+    # Train GFlowNet
     gflownet.train()

     # Sample from trained GFlowNet
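For context on how these pieces fit together: forwardlooking_loss combines the state_flow outputs with batch.get_rewards(do_non_terminating=True) and batch.get_rewards_parents() into the forward-looking detailed-balance residual, using the convention E(s) = -log R(s) visible in the diff. The snippet below is a minimal illustrative sketch of that residual, not the repository's API; all argument names are hypothetical.

# Illustrative sketch only, not the repository's API. Argument names are
# hypothetical; the convention E(s) = -log R(s) follows the diff above.
import torch


def fl_db_residuals(
    log_flflow_parents: torch.Tensor,   # state-flow log F~(s) of each parent state s
    log_flflow_states: torch.Tensor,    # state-flow log F~(s') of each state s' (0 if s' is terminal)
    logprobs_forward: torch.Tensor,     # log P_F(s' | s) under the forward policy
    logprobs_backward: torch.Tensor,    # log P_B(s | s') under the backward policy
    rewards_parents: torch.Tensor,      # R(s), as in batch.get_rewards_parents()
    rewards_states: torch.Tensor,       # R(s'), as in batch.get_rewards(do_non_terminating=True)
) -> torch.Tensor:
    """Per-transition forward-looking detailed-balance residuals (illustrative only)."""
    energies_parents = -torch.log(rewards_parents)  # E(s)  = -log R(s)
    energies_states = -torch.log(rewards_states)    # E(s') = -log R(s')
    # Residual of: log F~(s) + log P_F(s'|s) + E(s') - E(s) = log F~(s') + log P_B(s|s')
    return (
        log_flflow_parents
        + logprobs_forward
        + energies_states
        - energies_parents
        - log_flflow_states
        - logprobs_backward
    )


# A loss in this spirit would be, e.g., fl_db_residuals(...).pow(2).mean()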