diff --git a/README.md b/README.md
index 3b7a143d..b2c0d31a 100644
--- a/README.md
+++ b/README.md
@@ -64,4 +64,5 @@ If package dependencies seem not to work, you may need to install the exact froz
 
 ## Developing & Contributing
 
-TODO: Write Contributing.md.
+External contributions are welcome. We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
+To ensure these checks pass, run `tox -e style` for the linters and `tox run` for the tests.
diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py
index c546964e..f14793e4 100644
--- a/src/gflownet/data/sampling_iterator.py
+++ b/src/gflownet/data/sampling_iterator.py
@@ -254,6 +254,7 @@ def __iter__(self):
                     {k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
                 )
             if num_online > 0:
+                extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item()
                 for hook in self.log_hooks:
                     extra_info.update(
                         hook(
diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py
index 5132e154..866a7fac 100644
--- a/src/gflownet/tasks/qm9/qm9.py
+++ b/src/gflownet/tasks/qm9/qm9.py
@@ -64,7 +64,16 @@ def inverse_flat_reward_transform(self, rp):
     def load_task_models(self, path):
         gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0))
         # TODO: this path should be part of the config?
-        state_dict = torch.load(path)
+        try:
+            state_dict = torch.load(path)
+        except Exception as e:
+            print(
+                "Could not load model.",
+                e,
+                "\nModel weights can be found at",
+                "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt",
+            )
+            raise  # re-raise; otherwise state_dict is undefined below
         gap_model.load_state_dict(state_dict)
         gap_model.cuda()
         gap_model, self.device = self._wrap_model(gap_model, send_to_device=True)
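
Note on the qm9.py hunk above: if the gap-model checkpoint is not present at the given path, the weights can be fetched from the URL printed in the error message. The snippet below is a minimal, illustrative sketch of doing that with torch.hub.load_state_dict_from_url; it is not part of the patch, and the gflownet.models.mxmnet import path and use of torch.hub's default cache directory are assumptions.

    # Illustrative sketch only; assumes the gflownet package's mxmnet module
    # is importable and that torch.hub's default cache directory is acceptable.
    import torch
    from gflownet.models import mxmnet  # import path is an assumption

    # Download (and cache) the pretrained MXMNet gap-model weights referenced in the patch.
    state_dict = torch.hub.load_state_dict_from_url(
        "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt",
        map_location="cpu",
    )

    # Same model configuration as in load_task_models() above.
    gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0))
    gap_model.load_state_dict(state_dict)
    gap_model.eval()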