From 48198cf2e6bae8dc4a31e5d85abe11d31a75b1a6 Mon Sep 17 00:00:00 2001
From: Julien Roy
Date: Wed, 15 Mar 2023 09:46:30 -0600
Subject: [PATCH] fix: clamp rewards before taking into log-space (rather than add a small epsilon)

---
 src/gflownet/tasks/qm9/qm9.py      | 4 ++--
 src/gflownet/tasks/seh_frag.py     | 4 ++--
 src/gflownet/tasks/seh_frag_moo.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py
index cc629507..2d80b899 100644
--- a/src/gflownet/tasks/qm9/qm9.py
+++ b/src/gflownet/tasks/qm9/qm9.py
@@ -103,10 +103,10 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]:
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
         if isinstance(flat_reward, list):
             flat_reward = torch.tensor(flat_reward)
-        flat_reward = flat_reward.squeeze()
+        flat_reward = flat_reward.squeeze().clamp(min=1e-30).log()
         assert len(flat_reward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {flat_reward.shape} vs {cond_info['beta'].shape}"
-        return RewardScalar(torch.log(flat_reward + 1e-8) * cond_info['beta'])
+        return RewardScalar(flat_reward * cond_info['beta'])

     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [mxmnet.mol2graph(i) for i in mols]  # type: ignore[attr-defined]
diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py
index 2161b809..a727364b 100644
--- a/src/gflownet/tasks/seh_frag.py
+++ b/src/gflownet/tasks/seh_frag.py
@@ -87,10 +87,10 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]:
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
         if isinstance(flat_reward, list):
             flat_reward = torch.tensor(flat_reward)
-        flat_reward = flat_reward.squeeze()
+        flat_reward = flat_reward.squeeze().clamp(min=1e-30).log()
         assert len(flat_reward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {flat_reward.shape} vs {cond_info['beta'].shape}"
-        return RewardScalar(torch.log(flat_reward + 1e-8) * cond_info['beta'])
+        return RewardScalar(flat_reward * cond_info['beta'])

     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [bengio2021flow.mol2graph(i) for i in mols]
diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py
index 742bd5f8..94eddbf4 100644
--- a/src/gflownet/tasks/seh_frag_moo.py
+++ b/src/gflownet/tasks/seh_frag_moo.py
@@ -106,7 +106,7 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat
                 flat_reward = torch.stack(flat_reward)
             else:
                 flat_reward = torch.tensor(flat_reward)
-        scalar_logreward = torch.log((flat_reward * cond_info['preferences']).sum(1) + 1e-8)
+        scalar_logreward = (flat_reward * cond_info['preferences']).sum(1).clamp(min=1e-30).log()
         assert len(scalar_logreward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}"
         return RewardScalar(scalar_logreward * cond_info['beta'])
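
Note on the change itself: the minimal standalone sketch below (plain PyTorch; the helper names logreward_eps and logreward_clamp are hypothetical and not part of the gflownet codebase) contrasts the old epsilon-shift with the new clamp. Adding 1e-8 before the log biases the log-reward of small-but-nonzero rewards and floors every log-reward at log(1e-8) ~= -18.4, whereas clamping at 1e-30 leaves any reward >= 1e-30 with its exact log and only prevents log(0) = -inf.

import torch


def logreward_eps(flat_reward: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Old behaviour: shift every reward by eps before taking the log.
    # Avoids log(0) = -inf, but log(r + eps) != log(r) for small r, and the
    # log-reward can never go below log(eps) ~= -18.4.
    return torch.log(flat_reward + eps)


def logreward_clamp(flat_reward: torch.Tensor, min_reward: float = 1e-30) -> torch.Tensor:
    # New behaviour: clamp only (near-)zero rewards, then take the log.
    # Rewards >= min_reward keep their exact log-reward; zeros map to
    # log(1e-30) ~= -69.1 instead of -inf.
    return flat_reward.clamp(min=min_reward).log()


if __name__ == "__main__":
    r = torch.tensor([0.0, 1e-9, 0.5])
    print(logreward_eps(r))    # approx [-18.4207, -18.3254, -0.6931]
    print(logreward_clamp(r))  # approx [-69.0776, -20.7233, -0.6931]

The 1e-9 entry shows the practical difference: with the epsilon shift its log-reward is dominated by the epsilon, while with the clamp it keeps its true value of log(1e-9).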