QM9 doc + lint + sampled_avg_reward (#114)
* seh_frag_moo fix and lint

* Fix tox.ini

* trigger build
bengioe authored Feb 7, 2024
1 parent defd8fc commit ef5f2cb
Showing 3 changed files with 12 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -64,4 +64,5 @@ If package dependencies seem not to work, you may need to install the exact froz

## Developing & Contributing

-TODO: Write Contributing.md.
+External contributions are welcome. We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
+To ensure these checks pass, run `tox -e style` for linting and `tox run` for tests.
1 change: 1 addition & 0 deletions src/gflownet/data/sampling_iterator.py
@@ -254,6 +254,7 @@ def __iter__(self):
                 {k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
             )
         if num_online > 0:
+            extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item()
         for hook in self.log_hooks:
             extra_info.update(
                 hook(
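The added line averages reward over only the online-sampled tail of the batch; judging by the variable names, the first `num_offline` entries are presumably the offline/replay portion. A minimal sketch of the slicing, with a plain Python list standing in for the reward tensor (values are illustrative):

```python
# Plain-list stand-in for rewards[num_offline:].mean().item().
# The first num_offline entries play the role of offline samples,
# the rest are the online-sampled trajectories being averaged.
num_offline = 2
rewards = [0.1, 0.2, 0.5, 0.7]  # illustrative values, not real rewards

online_rewards = rewards[num_offline:]  # [0.5, 0.7]
sampled_reward_avg = sum(online_rewards) / len(online_rewards)
print(sampled_reward_avg)  # → 0.6
```

With a torch tensor the same slice works unchanged, and `.mean().item()` converts the scalar tensor back to a Python float for logging.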
10 changes: 9 additions & 1 deletion src/gflownet/tasks/qm9/qm9.py
@@ -64,7 +64,15 @@ def inverse_flat_reward_transform(self, rp):
     def load_task_models(self, path):
         gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0))
         # TODO: this path should be part of the config?
-        state_dict = torch.load(path)
+        try:
+            state_dict = torch.load(path)
+        except Exception as e:
+            print(
+                "Could not load model.",
+                e,
+                "\nModel weights can be found at",
+                "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt",
+            )
         gap_model.load_state_dict(state_dict)
         gap_model.cuda()
         gap_model, self.device = self._wrap_model(gap_model, send_to_device=True)
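One caveat with the pattern in this hunk: if `torch.load` fails, the `except` branch prints the download hint but then falls through, so the subsequent `gap_model.load_state_dict(state_dict)` would raise a `NameError` (`state_dict` was never assigned). A hedged sketch of an alternative, with a hypothetical helper `load_with_hint` that is not part of the repository, keeps the helpful message while re-raising the original error:

```python
# Hypothetical helper, not the repository's code: print the download hint
# from the diff above on failure, then re-raise instead of falling through
# to an undefined state_dict.
WEIGHTS_URL = "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt"

def load_with_hint(load_fn, path):
    try:
        return load_fn(path)
    except Exception as e:
        print("Could not load model.", e, "\nModel weights can be found at", WEIGHTS_URL)
        raise  # propagate: callers should not proceed without weights

# Usage sketch (torch.load would be the real load_fn):
# state_dict = load_with_hint(torch.load, path)
# gap_model.load_state_dict(state_dict)
```

Re-raising turns the later, confusing `NameError` into the original loading error, with the hint already printed.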
