Skip to content

Commit

Permalink
close loggers + read_all_results in sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Mar 13, 2024
1 parent 71c0d0d commit 82e171b
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ linters (as well as tests) are run.

We use Github Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`.

The cascade of events is as follows:
- For `build-and-test`, `tox -> testenv:py310 -> pytest` is run.
- For `code-quality`, `tox -e style -> testenv:style -> pre-commit -> {isort, black, mypy, bandit, ruff, & others}`. All of these hooks, including the "others", are defined in `.pre-commit-config.yaml`; the "others" cover things like checking for secrets and trailing whitespace.

## Style Guide

On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
Expand Down
4 changes: 1 addition & 3 deletions src/gflownet/algo/graph_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def __init__(
self.correct_idempotent = correct_idempotent
self.pad_with_terminal_state = pad_with_terminal_state

def sample_from_model(
self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0
):
def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0):
"""Samples a model in a minibatch
Parameters
Expand Down
2 changes: 1 addition & 1 deletion src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def compute_batch_losses(
clip_log_R = torch.maximum(
log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev)
).float()
cond_info = getattr(batch, 'cond_info', None)
cond_info = getattr(batch, "cond_info", None)
invalid_mask = 1 - batch.is_valid

# This index says which trajectory each graph belongs to, so
Expand Down
3 changes: 2 additions & 1 deletion src/gflownet/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def create_batch(self, trajs, batch_info):
if "focus_dir" in trajs[0]:
batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs])

if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute
# TODO: Restore this during merge
if self.ctx.has_n() and False: # Does this go somewhere else? Require a flag? Might not be cheap to compute
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/gflownet/online_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def setup(self):

# Separate Z parameters from non-Z to allow for LR decay on the former
if hasattr(self.model, "logZ"):
Z_params = list(self.model.logZ.parameters())
Z_params = list(self.model._logZ.parameters())
non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)]
else:
Z_params = []
Expand Down
1 change: 0 additions & 1 deletion src/gflownet/tasks/seh_frag.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def set_default_hps(self, cfg: Config):

def setup_task(self):
self.task = SEHTask(
dataset=self.training_data,
cfg=self.cfg,
wrap_model=self._wrap_for_mp,
)
Expand Down
5 changes: 5 additions & 0 deletions src/gflownet/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gc
import logging
import os
import pathlib
import shutil
Expand Down Expand Up @@ -331,6 +332,10 @@ def run(self, logger=None):
del final_dl

def terminate(self):
logger = logging.getLogger("logger")
for handler in logger.handlers:
handler.close()

for hook in self.sampling_hooks:
if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate:
hook.terminate()
Expand Down
16 changes: 16 additions & 0 deletions src/gflownet/utils/sqlite_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,19 @@ def insert_many(self, rows, column_names):
cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec
cur.close()
self.db.commit()

def __del__(self):
    """Close the underlying database connection when the log object is collected.

    Python may run ``__del__`` on an instance whose ``__init__`` raised before
    ``db`` was assigned, so the attribute is looked up defensively instead of
    assuming it exists. Nothing happens when there is no open connection.
    """
    db = getattr(self, "db", None)
    if db is not None:
        db.close()


def read_all_results(path):
    """Read the ``results`` table of every per-worker SQLite log found in *path*.

    Databases are expected to be named ``generated_objs_<i>.db`` with ``i``
    running from 0 to the number of such files minus one.

    Args:
        path: Directory containing the ``generated_objs_<i>.db`` files.

    Returns:
        A pandas DataFrame with the rows of all databases concatenated, sorted
        by their original per-database index and then re-indexed from 0.

    Raises:
        ValueError: Propagated from ``pandas.concat`` when no database is found.
    """
    # E402: module level import not at top of file, but pandas is an optional dependency
    import pandas as pd  # noqa: E402
    from contextlib import closing

    # Count only the database files themselves; any other file sharing the
    # prefix (e.g. a SQLite side file) would otherwise inflate the worker count
    # and make us try to open a database that does not exist.
    num_workers = len([f for f in os.listdir(path) if f.startswith("generated_objs") and f.endswith(".db")])
    dfs = []
    for i in range(num_workers):
        # uri=True is required for the "file:...?mode=ro" string to be parsed
        # as a URI (read-only open) rather than taken as a literal filename.
        # closing() releases the connection as soon as the table has been read.
        with closing(sqlite3.connect(f"file:{path}/generated_objs_{i}.db?mode=ro", uri=True)) as conn:
            dfs.append(pd.read_sql_query("SELECT * FROM results", conn))
    return pd.concat(dfs).sort_index().reset_index(drop=True)

0 comments on commit 82e171b

Please sign in to comment.