fix: clamp rewards before taking into log-space (rather than add a small epsilon)
julienroyd committed Mar 15, 2023
1 parent b3c1238 commit 48198cf
Showing 3 changed files with 5 additions and 5 deletions.
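
The change swaps log(reward + epsilon) for clamp(min=1e-30) followed by log. A minimal sketch of the numerical difference, with made-up reward and beta values (not taken from the repo):

import torch

# Hypothetical values for illustration only.
flat_reward = torch.tensor([0.0, 1e-12, 0.5])
beta = torch.tensor([32.0, 32.0, 32.0])

# Old behaviour: adding 1e-8 inside the log floors every log-reward at
# log(1e-8) ~ -18.4, so all rewards below ~1e-8 collapse to roughly the
# same value and nonzero rewards are slightly biased upward.
old_logreward = torch.log(flat_reward + 1e-8) * beta

# New behaviour: clamping at 1e-30 only guards against log(0) = -inf;
# rewards above 1e-30 keep their exact log, and near-zero rewards stay
# strongly penalised (log(1e-30) ~ -69) rather than being inflated.
new_logreward = flat_reward.clamp(min=1e-30).log() * beta

print(old_logreward)  # roughly tensor([-589.5, -589.5, -22.2])
print(new_logreward)  # roughly tensor([-2210.5, -884.2, -22.2])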
src/gflownet/tasks/qm9/qm9.py (2 additions & 2 deletions)
@@ -103,10 +103,10 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]:
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
         if isinstance(flat_reward, list):
             flat_reward = torch.tensor(flat_reward)
-        flat_reward = flat_reward.squeeze()
+        flat_reward = flat_reward.squeeze().clamp(min=1e-30).log()
         assert len(flat_reward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {flat_reward.shape} vs {cond_info['beta'].shape}"
-        return RewardScalar(torch.log(flat_reward + 1e-8) * cond_info['beta'])
+        return RewardScalar(flat_reward * cond_info['beta'])
 
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [mxmnet.mol2graph(i) for i in mols]  # type: ignore[attr-defined]
src/gflownet/tasks/seh_frag.py (2 additions & 2 deletions)
@@ -87,10 +87,10 @@ def sample_conditional_information(self, n: int) -> Dict[str, Tensor]:
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
         if isinstance(flat_reward, list):
             flat_reward = torch.tensor(flat_reward)
-        flat_reward = flat_reward.squeeze()
+        flat_reward = flat_reward.squeeze().clamp(min=1e-30).log()
         assert len(flat_reward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {flat_reward.shape} vs {cond_info['beta'].shape}"
-        return RewardScalar(torch.log(flat_reward + 1e-8) * cond_info['beta'])
+        return RewardScalar(flat_reward * cond_info['beta'])
 
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [bengio2021flow.mol2graph(i) for i in mols]
src/gflownet/tasks/seh_frag_moo.py (1 addition & 1 deletion)
@@ -106,7 +106,7 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat
                 flat_reward = torch.stack(flat_reward)
             else:
                 flat_reward = torch.tensor(flat_reward)
-        scalar_logreward = torch.log((flat_reward * cond_info['preferences']).sum(1) + 1e-8)
+        scalar_logreward = (flat_reward * cond_info['preferences']).sum(1).clamp(min=1e-30).log()
         assert len(scalar_logreward.shape) == len(cond_info['beta'].shape), \
             f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}"
         return RewardScalar(scalar_logreward * cond_info['beta'])
