diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py
index a9f6ac3f..28960399 100644
--- a/src/gflownet/tasks/config.py
+++ b/src/gflownet/tasks/config.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List
 
 
 @dataclass
@@ -13,22 +13,6 @@ class SEHMOOTaskConfig:
 
     Attributes
    ----------
-    use_steer_thermometer : bool
-        Whether to use a thermometer encoding for the steering.
-    preference_type : Optional[str]
-        The preference sampling distribution, defaults to "dirichlet".
-    focus_type : Union[list, str, None]
-        The type of focus distribtuion used, see SEHMOOTask.setup_focus_regions.
-    focus_cosim : float
-        The cosine similarity threshold for the focus distribution.
-    focus_limit_coef : float
-        The smoothing coefficient for the focus reward.
-    focus_model_training_limits : Optional[Tuple[int, int]]
-        The training limits for the focus sampling model (if used).
-    focus_model_state_space_res : Optional[int]
-        The state space resolution for the focus sampling model (if used).
-    max_train_it : Optional[int]
-        The maximum number of training iterations for the focus sampling model (if used).
     n_valid : int
         The number of valid cond_info tensors to sample
     n_valid_repeats : int
@@ -37,15 +21,6 @@ class SEHMOOTaskConfig:
         The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"].
     """
 
-    use_steer_thermometer: bool = False
-    preference_type: Optional[str] = "dirichlet"
-    focus_type: Optional[str] = None
-    focus_dirs_listed: Optional[List[List[float]]] = None
-    focus_cosim: float = 0.0
-    focus_limit_coef: float = 1.0
-    focus_model_training_limits: Optional[Tuple[int, int]] = None
-    focus_model_state_space_res: Optional[int] = None
-    max_train_it: Optional[int] = None
     n_valid: int = 15
     n_valid_repeats: int = 128
     objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"])
diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py
index b0db3047..866a7fac 100644
--- a/src/gflownet/tasks/qm9/qm9.py
+++ b/src/gflownet/tasks/qm9/qm9.py
@@ -17,6 +17,7 @@
 from gflownet.online_trainer import StandardOnlineTrainer
 from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
 from gflownet.utils.conditioning import TemperatureConditional
+from gflownet.utils.transforms import to_logreward
 
 
 class QM9GapTask(GFNTask):
@@ -81,7 +82,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
-        return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
+        return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))
 
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [mxmnet.mol2graph(i) for i in mols]  # type: ignore[attr-defined]
diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py
index 52c63b11..e916f732 100644
--- a/src/gflownet/tasks/seh_frag.py
+++ b/src/gflownet/tasks/seh_frag.py
@@ -18,6 +18,7 @@
 from gflownet.online_trainer import StandardOnlineTrainer
 from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
 from gflownet.utils.conditioning import TemperatureConditional
+from gflownet.utils.transforms import to_logreward
 
 
 class SEHTask(GFNTask):
@@ -59,7 +60,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
-        return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
+        return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))
 
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [bengio2021flow.mol2graph(i) for i in mols]
diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py
index 22dab75a..bd597c31 100644
--- a/src/gflownet/tasks/seh_frag_moo.py
+++ b/src/gflownet/tasks/seh_frag_moo.py
@@ -22,6 +22,7 @@
 from gflownet.utils import metrics, sascore
 from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences
 from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook
+from gflownet.utils.transforms import to_logreward
 
 
 class SEHMOOTask(SEHTask):
@@ -146,21 +147,26 @@ def relabel_condinfo_and_logrewards(
         return cond_info, log_rewards
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
+        """
+        Compute the logreward from the flat_reward and the conditional information
+        """
         if isinstance(flat_reward, list):
             if isinstance(flat_reward[0], Tensor):
                 flat_reward = torch.stack(flat_reward)
             else:
                 flat_reward = torch.tensor(flat_reward)
 
-        scalarized_logreward = self.pref_cond.transform(cond_info, flat_reward)
+        scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward)
+        scalarized_logrewards = to_logreward(scalarized_rewards)
         focused_logreward = (
-            self.focus_cond.transform(cond_info, flat_reward, scalarized_logreward)
+            self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards)
             if self.focus_cond is not None
-            else scalarized_logreward
+            else scalarized_logrewards
         )
-        # Temperature conditionals expect linear scalars and output log-scalars
-        tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward.exp())
-        return RewardScalar(tempered_logreward)
+        tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward)
+        clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward)
+
+        return RewardScalar(clamped_logreward)
 
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [bengio2021flow.mol2graph(i) for i in mols]
@@ -253,41 +259,42 @@ def setup(self):
                 self.cfg.log_dir,
                 compute_igd=True,
                 compute_pc_entropy=True,
-                compute_focus_accuracy=True if self.cfg.task.seh_moo.focus_type is not None else False,
-                focus_cosim=self.cfg.task.seh_moo.focus_cosim,
+                compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False,
+                focus_cosim=self.cfg.cond.focus_region.focus_cosim,
             )
         )
 
         # instantiate preference and focus conditioning vectors for validation
-        tcfg = self.cfg.task.seh_moo
-        n_obj = len(tcfg.objectives)
+        n_obj = len(self.cfg.task.seh_moo.objectives)
+        cond_cfg = self.cfg.cond
 
         # making sure hyperparameters for preferences and focus regions are consistent
         if not (
-            tcfg.focus_type is None
-            or tcfg.focus_type == "centered"
-            or (isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) == 1)
+            cond_cfg.focus_region.focus_type is None
+            or cond_cfg.focus_region.focus_type == "centered"
+            or (isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) == 1)
         ):
-            assert tcfg.preference_type is None, (
-                f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} "
-                f"and preference_type={tcfg.preference_type}"
+            assert cond_cfg.weighted_prefs.preference_type is None, (
+                f"Cannot use preferences with multiple focus regions, "
+                f"here focus_type={cond_cfg.focus_region.focus_type} "
+                f"and preference_type={cond_cfg.weighted_prefs.preference_type }"
             )
 
-        if isinstance(tcfg.focus_type, list) and len(tcfg.focus_type) > 1:
-            n_valid = len(tcfg.focus_type)
+        if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1:
+            n_valid = len(cond_cfg.focus_region.focus_type)
         else:
-            n_valid = tcfg.n_valid
+            n_valid = self.cfg.task.seh_moo.n_valid
 
         # preference vectors
-        if tcfg.preference_type is None:
+        if cond_cfg.weighted_prefs.preference_type is None:
             valid_preferences = np.ones((n_valid, n_obj))
-        elif tcfg.preference_type == "dirichlet":
+        elif cond_cfg.weighted_prefs.preference_type == "dirichlet":
             valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1")
-        elif tcfg.preference_type == "seeded_single":
+        elif cond_cfg.weighted_prefs.preference_type == "seeded_single":
             seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid)
             valid_preferences = seeded_prefs[0].reshape((1, n_obj))
             self.task.seeded_preference = valid_preferences[0]
-        elif tcfg.preference_type == "seeded_many":
+        elif cond_cfg.weighted_prefs.preference_type == "seeded_many":
             valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid)
         else:
             raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}")
@@ -311,8 +318,8 @@ def setup(self):
         else:
             valid_cond_vector = valid_preferences
 
-        self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid)
-        self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats)
+        self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid)
+        self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats)
         self.valid_sampling_hooks.append(self._top_k_hook)
 
         self.algo.task = self.task
@@ -341,6 +348,12 @@ def _save_state(self, it):
             self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir))
         return super()._save_state(it)
 
+    def run(self):
+        super().run()
+        for hook in self.sampling_hooks:
+            if hasattr(hook, "terminate"):
+                hook.terminate()
+
 
 class RepeatedCondInfoDataset:
     def __init__(self, cond_info_vectors, repeat):
diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py
index 8b38a100..7fe0f24b 100644
--- a/src/gflownet/tasks/toy_seq.py
+++ b/src/gflownet/tasks/toy_seq.py
@@ -13,6 +13,7 @@
 from gflownet.online_trainer import StandardOnlineTrainer
 from gflownet.trainer import FlatRewards, GFNTask, RewardScalar
 from gflownet.utils.conditioning import TemperatureConditional
+from gflownet.utils.transforms import to_logreward
 
 
 class ToySeqTask(GFNTask):
@@ -34,7 +35,7 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
         return self.temperature_conditional.sample(n)
 
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
-        return RewardScalar(self.temperature_conditional.transform(cond_info, flat_reward))
+        return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))
 
     def compute_flat_rewards(self, objs: List[str]) -> Tuple[FlatRewards, Tensor]:
         rs = torch.tensor([sum([s.count(p) for p in self.seqs]) for s in objs]).float() / self.norm
diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py
index 0279167c..acb7d12e 100644
--- a/src/gflownet/utils/conditioning.py
+++ b/src/gflownet/utils/conditioning.py
@@ -74,12 +74,11 @@ def sample(self, n):
         assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}"
         return {"beta": torch.tensor(beta), "encoding": beta_enc}
 
-    def transform(self, cond_info: Dict[str, Tensor], linear_reward: Tensor) -> Tensor:
-        scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log()
-        assert len(scalar_logreward.shape) == len(
+    def transform(self, cond_info: Dict[str, Tensor], logreward: Tensor) -> Tensor:
+        assert len(logreward.shape) == len(
             cond_info["beta"].shape
-        ), f"dangerous shape mismatch: {scalar_logreward.shape} vs {cond_info['beta'].shape}"
-        return scalar_logreward * cond_info["beta"]
+        ), f"dangerous shape mismatch: {logreward.shape} vs {cond_info['beta'].shape}"
+        return logreward * cond_info["beta"]
 
     def encode(self, conditional: Tensor) -> Tensor:
         cfg = self.cfg.cond.temperature
@@ -114,9 +113,9 @@ def sample(self, n):
         return {"preferences": preferences, "encoding": self.encode(preferences)}
 
     def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor:
-        scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log()
-        assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}"
-        return scalar_logreward
+        scalar_reward = (flat_reward * cond_info["preferences"]).sum(1)
+        assert len(scalar_reward.shape) == 1, f"scalar_reward should be a 1D array, got {scalar_reward.shape}"
+        return scalar_reward
 
     def encoding_size(self):
         return max(1, self.num_thermometer_dim * self.num_objectives)
@@ -227,11 +226,8 @@ def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_l
         focus_coef, in_focus_mask = metrics.compute_focus_coef(
             flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef
         )
-        if scalar_logreward is None:
-            scalar_logreward = torch.log(focus_coef)
-        else:
-            scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask])
-            scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward
+        scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask])
+        scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward
 
         return scalar_logreward
 
diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py
index 4862c6c7..4743efa0 100644
--- a/src/gflownet/utils/multiobjective_hooks.py
+++ b/src/gflownet/utils/multiobjective_hooks.py
@@ -197,6 +197,10 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info):
 
         return info
 
+    def terminate(self):
+        self.stop.set()
+        self.pareto_thread.join()
+
 
 class TopKHook:
     def __init__(self, k, repeats, num_preferences):
diff --git a/src/gflownet/utils/transforms.py b/src/gflownet/utils/transforms.py
index 20050e4f..f5428ddc 100644
--- a/src/gflownet/utils/transforms.py
+++ b/src/gflownet/utils/transforms.py
@@ -2,6 +2,10 @@
 from torch import Tensor
 
 
+def to_logreward(reward: Tensor) -> Tensor:
+    return reward.squeeze().clamp(min=1e-30).log()
+
+
 def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) -> Tensor:
     """Thermometer encoding of a scalar quantity.
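
Editor's note (not part of the patch): the snippet below is a minimal, self-contained sketch of the reward pipeline this diff settles on. Tasks now convert linear rewards to log-rewards once via to_logreward, and TemperatureConditional.transform only scales the already-log reward by beta. The helper body is copied from the transforms.py hunk above; the tensor values and variable names are hypothetical, for illustration only.

# Illustrative sketch only; mirrors to_logreward() from src/gflownet/utils/transforms.py
# and the beta scaling that TemperatureConditional.transform now applies to log-rewards.
import torch
from torch import Tensor


def to_logreward(reward: Tensor) -> Tensor:
    # Same body as the helper added above: clamp avoids log(0) for zero rewards.
    return reward.squeeze().clamp(min=1e-30).log()


flat_reward = torch.tensor([[0.2], [0.8], [0.0]])  # hypothetical linear rewards in [0, 1]
beta = torch.tensor([32.0, 32.0, 32.0])            # hypothetical sampled temperatures

logreward = to_logreward(flat_reward)   # tasks convert to log-space up front
tempered_logreward = logreward * beta   # what transform() does with the log-reward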