Skip to content

Commit

Permalink
Merge branch 'trunk' into qm9
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP authored Feb 8, 2024
2 parents a4d467b + ef5f2cb commit 1491bcf
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 70 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,5 @@ If package dependencies seem not to work, you may need to install the exact froz

## Developing & Contributing

TODO: Write Contributing.md.
External contributions are welcome. We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively.
1 change: 1 addition & 0 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def __iter__(self):
{k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
)
if num_online > 0:
extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item()
for hook in self.log_hooks:
extra_info.update(
hook(
Expand Down
27 changes: 1 addition & 26 deletions src/gflownet/tasks/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from typing import List


@dataclass
Expand All @@ -13,22 +13,6 @@ class SEHMOOTaskConfig:
Attributes
----------
use_steer_thermometer : bool
Whether to use a thermometer encoding for the steering.
preference_type : Optional[str]
The preference sampling distribution, defaults to "dirichlet".
focus_type : Union[list, str, None]
The type of focus distribution used, see SEHMOOTask.setup_focus_regions.
focus_cosim : float
The cosine similarity threshold for the focus distribution.
focus_limit_coef : float
The smoothing coefficient for the focus reward.
focus_model_training_limits : Optional[Tuple[int, int]]
The training limits for the focus sampling model (if used).
focus_model_state_space_res : Optional[int]
The state space resolution for the focus sampling model (if used).
max_train_it : Optional[int]
The maximum number of training iterations for the focus sampling model (if used).
n_valid : int
The number of valid cond_info tensors to sample
n_valid_repeats : int
Expand All @@ -37,15 +21,6 @@ class SEHMOOTaskConfig:
The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"].
"""

use_steer_thermometer: bool = False
preference_type: Optional[str] = "dirichlet"
focus_type: Optional[str] = None
focus_dirs_listed: Optional[List[List[float]]] = None
focus_cosim: float = 0.0
focus_limit_coef: float = 1.0
focus_model_training_limits: Optional[Tuple[int, int]] = None
focus_model_state_space_res: Optional[int] = None
max_train_it: Optional[int] = None
n_valid: int = 15
n_valid_repeats: int = 128
objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
Expand Down
13 changes: 11 additions & 2 deletions src/gflownet/tasks/qm9/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
from gflownet.utils.conditioning import TemperatureConditional
from gflownet.utils.transforms import to_logreward


class QM9GapTask(GFNTask):
Expand Down Expand Up @@ -64,7 +65,15 @@ def inverse_flat_reward_transform(self, rp):
def load_task_models(self, path, device):
gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0))
# TODO: this path should be part of the config?
state_dict = torch.load(path, map_location=device)
try:
state_dict = torch.load(path)
except Exception as e:
print(
"Could not load model.",
e,
"\nModel weights can be found at",
"https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt",
)
gap_model.load_state_dict(state_dict)
gap_model.to(device)
gap_model, self.device = self._wrap_model(gap_model, send_to_device=True)
Expand All @@ -74,7 +83,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined]
Expand Down
3 changes: 2 additions & 1 deletion src/gflownet/tasks/seh_frag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
from gflownet.utils.conditioning import TemperatureConditional
from gflownet.utils.transforms import to_logreward


class SEHTask(GFNTask):
Expand Down Expand Up @@ -59,7 +60,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
graphs = [bengio2021flow.mol2graph(i) for i in mols]
Expand Down
64 changes: 39 additions & 25 deletions src/gflownet/tasks/seh_frag_moo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from gflownet.utils import metrics, sascore
from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences
from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook
from gflownet.utils.transforms import to_logreward


class SEHMOOTask(SEHTask):
Expand Down Expand Up @@ -146,20 +147,26 @@ def relabel_condinfo_and_logrewards(
return cond_info, log_rewards

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
"""
Compute the logreward from the flat_reward and the conditional information
"""
if isinstance(flat_reward, list):
if isinstance(flat_reward[0], Tensor):
flat_reward = torch.stack(flat_reward)
else:
flat_reward = torch.tensor(flat_reward)

scalarized_reward = self.pref_cond.transform(cond_info, flat_reward)
focused_reward = (
self.focus_cond.transform(cond_info, flat_reward, scalarized_reward)
scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward)
scalarized_logrewards = to_logreward(scalarized_rewards)
focused_logreward = (
self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards)
if self.focus_cond is not None
else scalarized_reward
else scalarized_logrewards
)
tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward)
return RewardScalar(tempered_reward)
tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward)
clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward)

return RewardScalar(clamped_logreward)

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
graphs = [bengio2021flow.mol2graph(i) for i in mols]
Expand Down Expand Up @@ -260,41 +267,42 @@ def setup(self):
self.cfg.log_dir,
compute_igd=True,
compute_pc_entropy=True,
compute_focus_accuracy=True if self.cfg.task.seh_moo.focus_type is not None else False,
focus_cosim=self.cfg.task.seh_moo.focus_cosim,
compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False,
focus_cosim=self.cfg.cond.focus_region.focus_cosim,
)
)
# instantiate preference and focus conditioning vectors for validation

tcfg = self.cfg.task.seh_moo
n_obj = len(tcfg.objectives)
n_obj = len(self.cfg.task.seh_moo.objectives)
cond_cfg = self.cfg.cond

# making sure hyperparameters for preferences and focus regions are consistent
if not (
tcfg.focus_type is None
or tcfg.focus_type == "centered"
or (isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) == 1)
cond_cfg.focus_region.focus_type is None
or cond_cfg.focus_region.focus_type == "centered"
or (isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) == 1)
):
assert tcfg.preference_type is None, (
f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} "
f"and preference_type={tcfg.preference_type}"
assert cond_cfg.weighted_prefs.preference_type is None, (
f"Cannot use preferences with multiple focus regions, "
f"here focus_type={cond_cfg.focus_region.focus_type} "
f"and preference_type={cond_cfg.weighted_prefs.preference_type }"
)

if isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) > 1:
n_valid = len(tcfg.focus_type)
if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1:
n_valid = len(cond_cfg.focus_region.focus_type)
else:
n_valid = tcfg.n_valid
n_valid = self.cfg.task.seh_moo.n_valid

# preference vectors
if tcfg.preference_type is None:
if cond_cfg.weighted_prefs.preference_type is None:
valid_preferences = np.ones((n_valid, n_obj))
elif tcfg.preference_type == "dirichlet":
elif cond_cfg.weighted_prefs.preference_type == "dirichlet":
valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1")
elif tcfg.preference_type == "seeded_single":
elif cond_cfg.weighted_prefs.preference_type == "seeded_single":
seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid)
valid_preferences = seeded_prefs[0].reshape((1, n_obj))
self.task.seeded_preference = valid_preferences[0]
elif tcfg.preference_type == "seeded_many":
elif cond_cfg.weighted_prefs.preference_type == "seeded_many":
valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid)
else:
raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}")
Expand All @@ -318,8 +326,8 @@ def setup(self):
else:
valid_cond_vector = valid_preferences

self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid)
self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats)
self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid)
self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats)
self.valid_sampling_hooks.append(self._top_k_hook)

self.algo.task = self.task
Expand Down Expand Up @@ -348,6 +356,12 @@ def _save_state(self, it):
self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir))
return super()._save_state(it)

def run(self):
    """Run the parent trainer loop, then shut down sampling hooks.

    Delegates to the parent class's ``run`` and afterwards calls
    ``terminate()`` on every entry of ``self.sampling_hooks`` that exposes
    such a method.

    The cleanup lives in a ``finally`` clause so that hooks are terminated
    even when the parent ``run`` raises (or is interrupted); otherwise a hook
    that owns a background worker could be left running after a failed run.
    Any exception raised by the parent ``run`` still propagates to the caller.
    """
    try:
        super().run()
    finally:
        for hook in self.sampling_hooks:
            # Only some hooks need explicit shutdown; the rest are skipped.
            if hasattr(hook, "terminate"):
                hook.terminate()


class RepeatedCondInfoDataset:
def __init__(self, cond_info_vectors, repeat):
Expand Down
3 changes: 2 additions & 1 deletion src/gflownet/tasks/toy_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from gflownet.online_trainer import StandardOnlineTrainer
from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
from gflownet.utils.conditioning import TemperatureConditional
from gflownet.utils.transforms import to_logreward


class ToySeqTask(GFNTask):
Expand All @@ -34,7 +35,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
return self.temperature_conditional.sample(n)

def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))

def compute_flat_rewards(self, objs: List[str]) -> Tuple[FlatRewards, Tensor]:
rs = torch.tensor([sum([s.count(p) for p in self.seqs]) for s in objs]).float() / self.norm
Expand Down
23 changes: 9 additions & 14 deletions src/gflownet/utils/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ def sample(self, n):
assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}"
return {"beta": torch.tensor(beta), "encoding": beta_enc}

def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor:
scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log()
assert len(scalar_logreward.shape) == len(
def transform(self, cond_info: Dict[str, Tensor], logreward: Tensor) -> Tensor:
assert len(logreward.shape) == len(
cond_info["beta"].shape
), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}"
return scalar_logreward * cond_info["beta"]
), f"dangerous shape mismatch: {logreward.shape} vs {cond_info['beta'].shape}"
return logreward * cond_info["beta"]

def encode(self, conditional: Tensor) -> Tensor:
cfg = self.cfg.cond.temperature
Expand Down Expand Up @@ -116,10 +115,9 @@ def sample(self, n):
return {"preferences": preferences, "encoding": self.encode(preferences)}

def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor:
# NO LOG NO LOG NO LOG NO LOG NO LOG NO LOG
scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30)
assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}"
return scalar_logreward
scalar_reward = (flat_reward * cond_info["preferences"]).sum(1)
assert len(scalar_reward.shape) == 1, f"scalar_reward should be a 1D array, got {scalar_reward.shape}"
return scalar_reward

def encoding_size(self):
return max(1, self.num_thermometer_dim * self.num_objectives)
Expand Down Expand Up @@ -230,11 +228,8 @@ def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_l
focus_coef, in_focus_mask = metrics.compute_focus_coef(
flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef
)
if scalar_logreward is None:
scalar_logreward = torch.log(focus_coef)
else:
scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask])
scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward
scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask])
scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward

return scalar_logreward

Expand Down
4 changes: 4 additions & 0 deletions src/gflownet/utils/multiobjective_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info):

return info

def terminate(self):
    """Ask the background worker to stop and wait until it has exited.

    Calls ``set()`` on ``self.stop`` and then ``join()`` on
    ``self.pareto_thread``, so this call blocks until the join returns.
    NOTE(review): ``self.stop`` is presumably a ``threading.Event`` polled by
    the worker thread -- confirm against the constructor.
    """
    stop_signal = self.stop
    stop_signal.set()

    worker = self.pareto_thread
    worker.join()


class TopKHook:
def __init__(self, k, repeats, num_preferences):
Expand Down
4 changes: 4 additions & 0 deletions src/gflownet/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from torch import Tensor


def to_logreward(reward: Tensor) -> Tensor:
    """Convert a (linear) reward tensor into a log-reward tensor.

    Size-1 dimensions are squeezed away, values are clamped from below at
    ``1e-30`` so that zero or negative rewards give a finite result, and the
    natural logarithm is taken.

    Parameters
    ----------
    reward : Tensor
        The reward values, e.g. of shape ``(n,)`` or ``(n, 1)``.

    Returns
    -------
    Tensor
        ``log(max(reward, 1e-30))`` with size-1 dimensions removed. A 0-d
        input stays 0-d; any other input keeps at least one dimension.
    """
    squeezed = reward.squeeze()
    if squeezed.dim() == 0 and reward.dim() > 0:
        # A bare squeeze() turns a batch of one (shape (1,) or (1, 1)) into a
        # 0-d tensor, which then fails rank checks against 1-D conditioning
        # tensors and cannot be indexed with a 1-D boolean mask. Keep one axis.
        squeezed = squeezed.reshape(1)
    return squeezed.clamp(min=1e-30).log()


def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) -> Tensor:
"""Thermometer encoding of a scalar quantity.
Expand Down

0 comments on commit 1491bcf

Please sign in to comment.