From 5a135d9cae55eeedcffcef0a16ec1f651d7fc725 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 16:44:20 -0500 Subject: [PATCH 1/4] Implementation of the detailed balance loss and the necessary configuration files --- config/gflownet/detailedbalance.yaml | 9 ++++ config/policy/mlp_detailedbalance.yaml | 7 +++ gflownet/gflownet.py | 66 +++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 config/gflownet/detailedbalance.yaml create mode 100644 config/policy/mlp_detailedbalance.yaml diff --git a/config/gflownet/detailedbalance.yaml b/config/gflownet/detailedbalance.yaml new file mode 100644 index 000000000..073ae99f4 --- /dev/null +++ b/config/gflownet/detailedbalance.yaml @@ -0,0 +1,9 @@ +defaults: + - gflownet + - state_flow: mlp + +optimizer: + loss: detailedbalance + lr: 0.0001 + lr_decay_period: 1000000 + lr_decay_gamma: 0.5 diff --git a/config/policy/mlp_detailedbalance.yaml b/config/policy/mlp_detailedbalance.yaml new file mode 100644 index 000000000..41f43231e --- /dev/null +++ b/config/policy/mlp_detailedbalance.yaml @@ -0,0 +1,7 @@ +defaults: + - mlp + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 7372a215b..09934949d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -80,6 +80,9 @@ def __init__( elif optimizer.loss in ["trajectorybalance", "tb"]: self.loss = "trajectorybalance" self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64) + elif optimizer.loss in ["detailedbalance", "db"]: + self.loss = "detailedbalance" + self.logZ = None elif optimizer.loss in ["forwardlooking", "fl"]: self.loss = "forwardlooking" self.logZ = None @@ -198,7 +201,7 @@ def parameters(self): raise ValueError("Backward Policy cannot be a model in flowmatch.") parameters += list(self.backward_policy.model.parameters()) if self.state_flow is not None: - if self.loss != "forwardlooking": + if self.loss not in ["detailedbalance", "forwardlooking"]: raise ValueError(f"State flow cannot be trained with {self.loss} loss.") parameters += list(self.state_flow.model.parameters()) return parameters @@ -679,6 +682,65 @@ def trajectorybalance_loss(self, it, batch): ) return loss, loss, loss + def detailedbalance_loss(self, it, batch): + """ + Computes the Detailed Balance GFlowNet loss of a batch + Reference : https://arxiv.org/pdf/2201.13259.pdf (eq 11) + + Args + ---- + it : int + Iteration + + batch : Batch + A batch of data, containing all the states in the trajectories. + + + Returns + ------- + loss : float + + term_loss : float + Loss of the terminal nodes only + + nonterm_loss : float + Loss of the intermediate nodes only + """ + + assert batch.is_valid() + # Get necessary tensors from batch + states = batch.get_states(policy=False) + states_policy = batch.get_states(policy=True) + actions = batch.get_actions() + parents = batch.get_parents(policy=False) + parents_policy = batch.get_parents(policy=True) + done = batch.get_done() + rewards = batch.get_terminating_rewards(sort_by="insertion") + + # Get logprobs + masks_f = batch.get_masks_forward(of_parents=True) + policy_output_f = self.forward_policy(parents_policy) + logprobs_f = self.env.get_logprobs( + policy_output_f, actions, masks_f, parents, is_backward=False + ) + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_b = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) + + # Get logflows + logflow_states = self.state_flow(states_policy).squeeze() + logflow_states[done.eq(1)] = rewards + # TODO: Optimise by reusing logflow_states and batch.get_parent_indices + logflow_parents = self.state_flow(parents_policy).squeeze() + + # Detailed balance loss + loss = ( + (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2).mean() + ) + return loss, loss, loss + def forwardlooking_loss(self, it, batch): """ Computes the Forward-Looking GFlowNet loss of a batch @@ -957,6 +1019,8 @@ def train(self): losses = self.trajectorybalance_loss( it * self.ttsr + j, batch ) # returns (opt loss, *metrics) + elif self.loss == "detailedbalance": + losses = self.detailedbalance_loss(it * self.ttsr + j, batch) elif self.loss == "forwardlooking": losses = self.forwardlooking_loss(it * self.ttsr + j, batch) else: From f27a8262ea8f17dcf1ebf582de86330be0d26842 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 17:37:41 -0500 Subject: [PATCH 2/4] Fix: logflows of terminating is log(rewards) --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 09934949d..abbd82053 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -731,7 +731,7 @@ def detailedbalance_loss(self, it, batch): # Get logflows logflow_states = self.state_flow(states_policy).squeeze() - logflow_states[done.eq(1)] = rewards + logflow_states[done.eq(1)] = torch.log(rewards) # TODO: Optimise by reusing logflow_states and batch.get_parent_indices logflow_parents = self.state_flow(parents_policy).squeeze() From 293010f29fb754404587f0d8f60091bd2d0c43c0 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:08:14 -0500 Subject: [PATCH 3/4] Compute DB loss on terminating and intermediate states --- gflownet/gflownet.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index abbd82053..b8ed7d33e 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -736,10 +736,11 @@ def detailedbalance_loss(self, it, batch): logflow_parents = self.state_flow(parents_policy).squeeze() # Detailed balance loss - loss = ( - (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2).mean() - ) - return loss, loss, loss + loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2) + loss = loss_all.mean() + loss_terminating = loss_all[done].mean() + loss_intermediate = loss_all[~done].mean() + return loss, loss_terminating, loss_intermediate def forwardlooking_loss(self, it, batch): """ From 798b5eff3ab7eff23018d0941fb0951a35faf935 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:09:26 -0500 Subject: [PATCH 4/4] Remove squeeze in DB loss --- gflownet/gflownet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index b8ed7d33e..844b560ea 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -730,10 +730,10 @@ def detailedbalance_loss(self, it, batch): ) # Get logflows - logflow_states = self.state_flow(states_policy).squeeze() + logflow_states = self.state_flow(states_policy) logflow_states[done.eq(1)] = torch.log(rewards) # TODO: Optimise by reusing logflow_states and batch.get_parent_indices - logflow_parents = self.state_flow(parents_policy).squeeze() + logflow_parents = self.state_flow(parents_policy) # Detailed balance loss loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2)