diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py
index 844b560ea..6acc0ddfe 100644
--- a/gflownet/gflownet.py
+++ b/gflownet/gflownet.py
@@ -730,13 +730,13 @@ def detailedbalance_loss(self, it, batch):
         )
 
         # Get logflows
-        logflow_states = self.state_flow(states_policy)
-        logflow_states[done.eq(1)] = torch.log(rewards)
-        # TODO: Optimise by reusing logflow_states and batch.get_parent_indices
-        logflow_parents = self.state_flow(parents_policy)
+        logflows_states = self.state_flow(states_policy)
+        logflows_states[done.eq(1)] = torch.log(rewards)
+        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
+        logflows_parents = self.state_flow(parents_policy)
 
         # Detailed balance loss
-        loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2)
+        loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)
         loss = loss_all.mean()
         loss_terminating = loss_all[done].mean()
         loss_intermediate = loss_all[~done].mean()
@@ -769,50 +769,49 @@ def forwardlooking_loss(self, it, batch):
         assert batch.is_valid()
 
         # Get necessary tensors from batch
-        states_policy = batch.get_states(policy=True)
         states = batch.get_states(policy=False)
+        states_policy = batch.get_states(policy=True)
         actions = batch.get_actions()
-        parents_policy = batch.get_parents(policy=True)
         parents = batch.get_parents(policy=False)
-        traj_indices = batch.get_trajectory_indices(consecutive=True)
+        parents_policy = batch.get_parents(policy=True)
+        rewards_states = batch.get_rewards(do_non_terminating=True)
+        rewards_parents = batch.get_rewards_parents()
         done = batch.get_done()
 
-        masks_b = batch.get_masks_backward()
-        policy_output_b = self.backward_policy(states_policy)
-        logprobs_bkw = self.env.get_logprobs(
-            policy_output_b, actions, masks_b, states, is_backward=True
-        )
+        # Get logprobs
         masks_f = batch.get_masks_forward(of_parents=True)
         policy_output_f = self.forward_policy(parents_policy)
-        logprobs_fwd = self.env.get_logprobs(
+        logprobs_f = self.env.get_logprobs(
             policy_output_f, actions, masks_f, parents, is_backward=False
         )
+        masks_b = batch.get_masks_backward()
+        policy_output_b = self.backward_policy(states_policy)
+        logprobs_b = self.env.get_logprobs(
+            policy_output_b, actions, masks_b, states, is_backward=True
+        )
 
-        states_log_flflow = self.state_flow(states_policy)
-        # forward-looking flow is 1 in the terminal states
-        states_log_flflow[done.eq(1)] = 0.0
-        # Can be optimised by reusing states_log_flflow and batch.get_parent_indices
-        parents_log_flflow = self.state_flow(parents_policy)
-
-        rewards_states = batch.get_rewards(do_non_terminating=True)
-        rewards_parents = batch.get_rewards_parents()
-        energies_states = -torch.log(rewards_states)
-        energies_parents = -torch.log(rewards_parents)
-
-        per_node_loss = (
-            parents_log_flflow
-            - states_log_flflow
-            + logprobs_fwd
-            - logprobs_bkw
-            + energies_states
-            - energies_parents
+        # Get FL logflows
+        logflflows_states = self.state_flow(states_policy)
+        # Log FL flow of terminal states is 0 (eq. 9 of paper)
+        logflflows_states[done.eq(1)] = 0.0
+        # TODO: Optimise by reusing logflflows_states and batch.get_parent_indices
+        logflflows_parents = self.state_flow(parents_policy)
+
+        # Get energies transitions
+        energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states)
+
+        # Forward-looking loss
+        loss_all = (
+            logflflows_parents
+            - logflflows_states
+            + logprobs_f
+            - logprobs_b
+            + energies_transitions
         ).pow(2)
-
-        term_loss = per_node_loss[done].mean()
-        nonterm_loss = per_node_loss[~done].mean()
-        loss = per_node_loss.mean()
-
-        return loss, term_loss, nonterm_loss
+        loss = loss_all.mean()
+        loss_terminating = loss_all[done].mean()
+        loss_intermediate = loss_all[~done].mean()
+        return loss, loss_terminating, loss_intermediate
 
     @torch.no_grad()
     def estimate_logprobs_data(
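Note: below is a minimal, self-contained sketch of the per-transition forward-looking detailed-balance loss that the new forwardlooking_loss computes, i.e. (log F~(parent) + log P_F(state|parent) - log F~(state) - log P_B(parent|state) + E(state) - E(parent))^2 with E = -log R. The function name and the flat per-transition tensors it takes are illustrative assumptions for clarity, not the Batch/GFlowNetAgent API used in the patch.

import torch


def fl_db_loss_sketch(
    logflflows_parents,  # log FL flow of parent states, shape (T,)
    logflflows_states,   # log FL flow of child states, shape (T,)
    logprobs_f,          # log P_F(state | parent), shape (T,)
    logprobs_b,          # log P_B(parent | state), shape (T,)
    logrewards_parents,  # log R(parent), shape (T,)
    logrewards_states,   # log R(state), shape (T,)
    done,                # boolean mask of terminating transitions, shape (T,)
):
    # The FL log-flow of terminal states is fixed to 0
    logflflows_states = torch.where(
        done, torch.zeros_like(logflflows_states), logflflows_states
    )
    # Energy difference of the transition: E(state) - E(parent) = log R(parent) - log R(state)
    energies_transitions = logrewards_parents - logrewards_states
    # Squared violation of the forward-looking detailed-balance condition
    loss_all = (
        logflflows_parents
        - logflflows_states
        + logprobs_f
        - logprobs_b
        + energies_transitions
    ).pow(2)
    # Mean over all transitions, terminating ones and intermediate ones
    return loss_all.mean(), loss_all[done].mean(), loss_all[~done].mean()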