From defd8fc1bc41e880a2b8b783e6621921ff1bfb64 Mon Sep 17 00:00:00 2001 From: Julien Roy Date: Tue, 6 Feb 2024 15:24:42 -0500 Subject: [PATCH] Fixes double-log, multithreading termination and duplicated config (#117) * fix: synchronize threads after exit by adding MultiObjectiveStatsHook.terminate() method * fix: removed double-log in reward computation * fix: removed duplicate configs from SEHMOOTaskConfig * chore: tox --- src/gflownet/envs/graph_building_env.py | 9 ++- src/gflownet/models/bengio2021flow.py | 1 + src/gflownet/tasks/config.py | 27 +-------- src/gflownet/tasks/qm9/qm9.py | 3 +- src/gflownet/tasks/seh_frag.py | 3 +- src/gflownet/tasks/seh_frag_moo.py | 64 +++++++++++++--------- src/gflownet/tasks/toy_seq.py | 3 +- src/gflownet/utils/conditioning.py | 22 +++----- src/gflownet/utils/focus_model.py | 12 ++-- src/gflownet/utils/multiobjective_hooks.py | 4 ++ src/gflownet/utils/transforms.py | 4 ++ 11 files changed, 76 insertions(+), 76 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 2cba54cb..74ba3e4f 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -536,9 +536,12 @@ def __init__( # This generalizes to edges and non-edges. # Append '_batch' to keys except for 'x', since TG has a special case (done by default for 'x') self.batch = [ - getattr(graphs, f"{k}_batch" if k != "x" else "batch") if k is not None - # None signals a global logit rather than a per-instance logit - else torch.arange(graphs.num_graphs, device=dev) + ( + getattr(graphs, f"{k}_batch" if k != "x" else "batch") + if k is not None + # None signals a global logit rather than a per-instance logit + else torch.arange(graphs.num_graphs, device=dev) + ) for k in keys ] # This is the cumulative sum (prefixed by 0) of N[i]s diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index 9797e15f..d975dfc9 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -7,6 +7,7 @@ In particular, this model class allows us to compare to the same target proxy used in that paper (sEH binding affinity prediction). """ + import gzip import os import pickle # nosec diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index a9f6ac3f..28960399 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Tuple +from typing import List @dataclass @@ -13,22 +13,6 @@ class SEHMOOTaskConfig: Attributes ---------- - use_steer_thermometer : bool - Whether to use a thermometer encoding for the steering. - preference_type : Optional[str] - The preference sampling distribution, defaults to "dirichlet". - focus_type : Union[list, str, None] - The type of focus distribtuion used, see SEHMOOTask.setup_focus_regions. - focus_cosim : float - The cosine similarity threshold for the focus distribution. - focus_limit_coef : float - The smoothing coefficient for the focus reward. - focus_model_training_limits : Optional[Tuple[int, int]] - The training limits for the focus sampling model (if used). - focus_model_state_space_res : Optional[int] - The state space resolution for the focus sampling model (if used). - max_train_it : Optional[int] - The maximum number of training iterations for the focus sampling model (if used). 
n_valid : int The number of valid cond_info tensors to sample n_valid_repeats : int @@ -37,15 +21,6 @@ class SEHMOOTaskConfig: The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. """ - use_steer_thermometer: bool = False - preference_type: Optional[str] = "dirichlet" - focus_type: Optional[str] = None - focus_dirs_listed: Optional[List[List[float]]] = None - focus_cosim: float = 0.0 - focus_limit_coef: float = 1.0 - focus_model_training_limits: Optional[Tuple[int, int]] = None - focus_model_state_space_res: Optional[int] = None - max_train_it: Optional[int] = None n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index e5b1d29a..5132e154 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -17,6 +17,7 @@ from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.transforms import to_logreward class QM9GapTask(GFNTask): @@ -73,7 +74,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) + return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 52c63b11..e916f732 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -18,6 +18,7 @@ from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.transforms import to_logreward class SEHTask(GFNTask): @@ -59,7 +60,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) + return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index f8b432a1..bd597c31 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -22,6 +22,7 @@ from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.transforms import to_logreward class SEHMOOTask(SEHTask): @@ -146,20 +147,26 @@ def relabel_condinfo_and_logrewards( return cond_info, log_rewards def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> 
RewardScalar: + """ + Compute the logreward from the flat_reward and the conditional information + """ if isinstance(flat_reward, list): if isinstance(flat_reward[0], Tensor): flat_reward = torch.stack(flat_reward) else: flat_reward = torch.tensor(flat_reward) - scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) - focused_reward = ( - self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward) + scalarized_logrewards = to_logreward(scalarized_rewards) + focused_logreward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards) if self.focus_cond is not None - else scalarized_reward + else scalarized_logrewards ) - tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) - return RewardScalar(tempered_reward) + tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward) + clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward) + + return RewardScalar(clamped_logreward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -252,41 +259,42 @@ def setup(self): self.cfg.log_dir, compute_igd=True, compute_pc_entropy=True, - compute_focus_accuracy=True if self.cfg.task.seh_moo.focus_type is not None else False, - focus_cosim=self.cfg.task.seh_moo.focus_cosim, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, ) ) # instantiate preference and focus conditioning vectors for validation - tcfg = self.cfg.task.seh_moo - n_obj = len(tcfg.objectives) + n_obj = len(self.cfg.task.seh_moo.objectives) + cond_cfg = self.cfg.cond # making sure hyperparameters for preferences and focus regions are consistent if not ( - tcfg.focus_type is None - or tcfg.focus_type == "centered" - or (isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) == 1) + cond_cfg.focus_region.focus_type is None + or cond_cfg.focus_region.focus_type == "centered" + or (isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) == 1) ): - assert tcfg.preference_type is None, ( - f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} " - f"and preference_type={tcfg.preference_type}" + assert cond_cfg.weighted_prefs.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, " + f"here focus_type={cond_cfg.focus_region.focus_type} " + f"and preference_type={cond_cfg.weighted_prefs.preference_type }" ) - if isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) > 1: - n_valid = len(tcfg.focus_type) + if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1: + n_valid = len(cond_cfg.focus_region.focus_type) else: - n_valid = tcfg.n_valid + n_valid = self.cfg.task.seh_moo.n_valid # preference vectors - if tcfg.preference_type is None: + if cond_cfg.weighted_prefs.preference_type is None: valid_preferences = np.ones((n_valid, n_obj)) - elif tcfg.preference_type == "dirichlet": + elif cond_cfg.weighted_prefs.preference_type == "dirichlet": valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - elif tcfg.preference_type == "seeded_single": + elif cond_cfg.weighted_prefs.preference_type == "seeded_single": seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * 
n_obj, n_valid) valid_preferences = seeded_prefs[0].reshape((1, n_obj)) self.task.seeded_preference = valid_preferences[0] - elif tcfg.preference_type == "seeded_many": + elif cond_cfg.weighted_prefs.preference_type == "seeded_many": valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) else: raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}") @@ -310,8 +318,8 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) - self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) + self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -340,6 +348,12 @@ def _save_state(self, it): self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) + def run(self): + super().run() + for hook in self.sampling_hooks: + if hasattr(hook, "terminate"): + hook.terminate() + class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 8b38a100..7fe0f24b 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -13,6 +13,7 @@ from gflownet.online_trainer import StandardOnlineTrainer from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.transforms import to_logreward class ToySeqTask(GFNTask): @@ -34,7 +35,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten return self.temperature_conditional.sample(n) def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward)) + return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) def compute_flat_rewards(self, objs: List[str]) -> Tuple[FlatRewards, Tensor]: rs = torch.tensor([sum([s.count(p) for p in self.seqs]) for s in objs]).float() / self.norm diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py index 0279167c..acb7d12e 100644 --- a/src/gflownet/utils/conditioning.py +++ b/src/gflownet/utils/conditioning.py @@ -74,12 +74,11 @@ def sample(self, n): assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" return {"beta": torch.tensor(beta), "encoding": beta_enc} - def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor: - scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == len( + def transform(self, cond_info: Dict[str, Tensor], logreward: Tensor) -> Tensor: + assert len(logreward.shape) == len( cond_info["beta"].shape - ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}" - return scalar_logreward * cond_info["beta"] + ), f"dangerous shape mismatch: {logreward.shape} vs {cond_info['beta'].shape}" + return logreward * cond_info["beta"] def encode(self, conditional: Tensor) -> Tensor: cfg = self.cfg.cond.temperature @@ -114,9 +113,9 @@ def sample(self, n): return {"preferences": preferences, "encoding": self.encode(preferences)} def transform(self, 
cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: - scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() - assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}" - return scalar_logreward + scalar_reward = (flat_reward * cond_info["preferences"]).sum(1) + assert len(scalar_reward.shape) == 1, f"scalar_reward should be a 1D array, got {scalar_reward.shape}" + return scalar_reward def encoding_size(self): return max(1, self.num_thermometer_dim * self.num_objectives) @@ -227,11 +226,8 @@ def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_l focus_coef, in_focus_mask = metrics.compute_focus_coef( flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef ) - if scalar_logreward is None: - scalar_logreward = torch.log(focus_coef) - else: - scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) - scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward + scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) + scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward return scalar_logreward diff --git a/src/gflownet/utils/focus_model.py b/src/gflownet/utils/focus_model.py index 14bf6c71..70cb4950 100644 --- a/src/gflownet/utils/focus_model.py +++ b/src/gflownet/utils/focus_model.py @@ -89,12 +89,12 @@ def sample_focus_directions(self, n: int): """ sampling_likelihoods = torch.zeros_like(self.focus_dir_count).float().to(self.device) sampling_likelihoods[self.focus_dir_count == 0] = self.feasible_flow - sampling_likelihoods[ - torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0) - ] = self.feasible_flow - sampling_likelihoods[ - torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0) - ] = self.infeasible_flow + sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0)] = ( + self.feasible_flow + ) + sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0)] = ( + self.infeasible_flow + ) focus_dir_indices = torch.multinomial(sampling_likelihoods, n, replacement=True) return self.focus_dir_dataset[focus_dir_indices].to("cpu") diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 4862c6c7..4743efa0 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -197,6 +197,10 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): return info + def terminate(self): + self.stop.set() + self.pareto_thread.join() + class TopKHook: def __init__(self, k, repeats, num_preferences): diff --git a/src/gflownet/utils/transforms.py b/src/gflownet/utils/transforms.py index 20050e4f..f5428ddc 100644 --- a/src/gflownet/utils/transforms.py +++ b/src/gflownet/utils/transforms.py @@ -2,6 +2,10 @@ from torch import Tensor +def to_logreward(reward: Tensor) -> Tensor: + return reward.squeeze().clamp(min=1e-30).log() + + def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) -> Tensor: """Thermometer encoding of a scalar quantity.
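
The thread-termination fix adds MultiObjectiveStatsHook.terminate() (set the stop event, then join the Pareto worker thread) and has the MOO trainer's run() override call terminate() on any sampling hook that exposes it once training finishes. Below is a standalone sketch of that shutdown pattern; StatsHook and Trainer are placeholder names rather than gflownet classes, and the worker body is a stand-in for the real Pareto-front bookkeeping.

# Placeholder sketch of the hook-shutdown pattern; only terminate() mirrors the patch.
import queue
import threading


class StatsHook:
    def __init__(self):
        self.stop = threading.Event()
        self.pareto_queue: queue.Queue = queue.Queue()
        self.pareto_thread = threading.Thread(target=self._worker, daemon=True)
        self.pareto_thread.start()

    def _worker(self):
        # Stand-in for the background Pareto-front updates done by the real hook.
        while not self.stop.is_set():
            try:
                self.pareto_queue.get(timeout=1)
            except queue.Empty:
                continue

    def terminate(self):
        # Same two steps as MultiObjectiveStatsHook.terminate(): signal, then join.
        self.stop.set()
        self.pareto_thread.join()


class Trainer:
    def __init__(self, hooks):
        self.sampling_hooks = hooks

    def run(self):
        # Training loop would go here; afterwards, clean up hooks the same
        # duck-typed way the MOO trainer's run() override does in this patch.
        for hook in self.sampling_hooks:
            if hasattr(hook, "terminate"):
                hook.terminate()


Trainer([StatsHook()]).run()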
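
The double-log fix addresses the fact that TemperatureConditional.transform used to log the reward itself while the MOO task had already logged it upstream; the patch centralizes the log in gflownet.utils.transforms.to_logreward and makes the conditionals operate on log-rewards. The torch-only sketch below shows the corrected ordering (scalarize in linear space, log exactly once, apply temperature in log space, clamp once at the end); the focus conditional would slot in between the log and temperature steps, as in cond_info_to_logreward. Shapes, the beta value and the illegal_action_logreward constant are illustrative, not repository defaults.

# Torch-only sketch of the corrected reward pipeline; values are illustrative.
import torch


def to_logreward(reward: torch.Tensor) -> torch.Tensor:
    # Matches the new helper in gflownet/utils/transforms.py: clamp, then log, once.
    return reward.squeeze().clamp(min=1e-30).log()


flat_reward = torch.rand(8, 4)                  # 8 objects, 4 objectives
prefs = torch.softmax(torch.rand(8, 4), dim=1)  # preference weights per object
beta = torch.full((8,), 32.0)                   # sampled temperatures
illegal_action_logreward = -75.0                # placeholder for cfg.algo.illegal_action_logreward

scalarized_reward = (flat_reward * prefs).sum(1)         # linear space
scalarized_logreward = to_logreward(scalarized_reward)   # log space, exactly once
tempered_logreward = scalarized_logreward * beta         # temperature applied to the log-reward
clamped_logreward = tempered_logreward.clamp(min=illegal_action_logreward)
print(clamped_logreward.shape)  # torch.Size([8])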
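
The config de-duplication leaves SEHMOOTaskConfig with only n_valid, n_valid_repeats and objectives; the focus and preference settings are now read from the conditionals' own config, cfg.cond.focus_region.* and cfg.cond.weighted_prefs.*, as the updated seh_frag_moo.py setup() shows. The toy mock below uses types.SimpleNamespace instead of the real dataclasses; only the attribute paths and the SEHMOOTaskConfig defaults come from the diff, the focus values are illustrative.

# Toy stand-in for the config tree (SimpleNamespace, not gflownet's dataclasses).
from types import SimpleNamespace

cfg = SimpleNamespace(
    task=SimpleNamespace(
        seh_moo=SimpleNamespace(objectives=["seh", "qed", "sa", "mw"], n_valid=15, n_valid_repeats=128),
    ),
    cond=SimpleNamespace(
        focus_region=SimpleNamespace(focus_type="centered", focus_cosim=0.98, focus_limit_coef=0.1),
        weighted_prefs=SimpleNamespace(preference_type="dirichlet"),
    ),
)

# Previously duplicated on cfg.task.seh_moo; now read from cfg.cond only:
print(cfg.cond.focus_region.focus_type, cfg.cond.weighted_prefs.preference_type)
print(cfg.task.seh_moo.objectives, cfg.task.seh_moo.n_valid)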