diff --git a/docs/contributing.md b/docs/contributing.md index a85e945c..b63b082f 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -28,6 +28,10 @@ linters (as well as tests) are run. We use Github Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`. +The cascade of events is as follows: +- For `build-and-test`, `tox -> testenv:py310 -> pytest` is run. +- For `code-quality`, `tox -e style -> testenv:style -> pre-commit -> {isort, black, mypy, bandit, ruff, & others}`. This and the "others" are defined in `.pre-commit-config.yaml` and include things like checking for secrets and trailing whitespace. + ## Style Guide On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index ebc7e48a..a6600e28 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -62,9 +62,7 @@ def __init__( self.correct_idempotent = correct_idempotent self.pad_with_terminal_state = pad_with_terminal_state - def sample_from_model( - self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0 - ): + def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0): """Samples a model in a minibatch Parameters diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index bdd38aaa..a75b00c6 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -366,7 +366,7 @@ def compute_batch_losses( clip_log_R = torch.maximum( log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev) ).float() - cond_info = getattr(batch, 'cond_info', None) + cond_info = getattr(batch, "cond_info", None) invalid_mask = 1 - batch.is_valid # This index says which trajectory each graph belongs to, so diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index f45c0000..9d7e5ab0 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -225,7 +225,8 @@ def create_batch(self, trajs, batch_info): if "focus_dir" in trajs[0]: batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) - if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute + # TODO: Restore this during merge + if self.ctx.has_n() and False: # Does this go somewhere else? Require a flag? Might not be cheap to compute log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 1dc41b86..d320db98 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -77,7 +77,7 @@ def setup(self): # Separate Z parameters from non-Z to allow for LR decay on the former if hasattr(self.model, "logZ"): - Z_params = list(self.model.logZ.parameters()) + Z_params = list(self.model._logZ.parameters()) non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] else: Z_params = [] diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index ad854ebd..162f681f 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -162,7 +162,6 @@ def set_default_hps(self, cfg: Config): def setup_task(self): self.task = SEHTask( - dataset=self.training_data, cfg=self.cfg, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index b2d4dc97..386c0494 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -1,4 +1,5 @@ import gc +import logging import os import pathlib import shutil @@ -331,6 +332,10 @@ def run(self, logger=None): del final_dl def terminate(self): + logger = logging.getLogger("logger") + for handler in logger.handlers: + handler.close() + for hook in self.sampling_hooks: if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate: hook.terminate() diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 79d06d28..ae544ec5 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -93,3 +93,19 @@ def insert_many(self, rows, column_names): cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec cur.close() self.db.commit() + + def __del__(self): + if self.db is not None: + self.db.close() + + +def read_all_results(path): + # E402: module level import not at top of file, but pandas is an optional dependency + import pandas as pd # noqa: E402 + + num_workers = len([f for f in os.listdir(path) if f.startswith("generated_objs")]) + dfs = [ + pd.read_sql_query("SELECT * FROM results", sqlite3.connect(f"file:{path}/generated_objs_{i}.db?mode=ro")) + for i in range(num_workers) + ] + return pd.concat(dfs).sort_index().reset_index(drop=True)