move code to load datasets outside of the train dataset loader
curtischong committed Jun 20, 2024
1 parent 02da2ff commit 3fa5775
Showing 2 changed files with 43 additions and 48 deletions.
7 changes: 1 addition & 6 deletions docs/core/inference.md
@@ -69,7 +69,7 @@ checkpoint_path
 ```
 
-We have to update our configuration yml file with the dataset. It is necessary to specify the train and test set for some reason.
+We have to update our configuration yml file with the test dataset.
 
 ```{code-cell} ipython3
 from fairchem.core.common.tutorial_utils import generate_yml_config
@@ -81,11 +81,6 @@ yml = generate_yml_config(checkpoint_path, 'config.yml',
           'task.dataset': 'ase_db',
           'task.prediction_dtype': 'float32',
          'logger':'tensorboard', # don't use wandb!
-          # Train data
-          'dataset.train.src': 'data.db',
-          'dataset.train.a2g_args.r_energy': False,
-          'dataset.train.a2g_args.r_forces': False,
-          'dataset.train.select_args.selection': 'natoms>5,xc=PBE',
          # Test data - prediction only so no regression
          'dataset.test.src': 'data.db',
          'dataset.test.a2g_args.r_energy': False,
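For context, the dotted keys passed above expand into a nested `dataset` section in the generated `config.yml`. Below is a minimal sketch of that expansion using plain dictionaries; the `expand_dotted` helper is a hypothetical illustration, not a fairchem API.

```python
# Hypothetical illustration: expand dotted keys such as 'dataset.test.src' into
# the nested mapping a config.yml file would contain. Keys are taken from the diff above.
def expand_dotted(flat: dict) -> dict:
    nested: dict = {}
    for dotted_key, value in flat.items():
        node = nested
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested

test_only = {
    "dataset.test.src": "data.db",
    "dataset.test.a2g_args.r_energy": False,  # prediction only, so no energy target
}
print(expand_dotted(test_only))
# {'dataset': {'test': {'src': 'data.db', 'a2g_args': {'r_energy': False}}}}
```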
84 changes: 42 additions & 42 deletions src/fairchem/core/trainers/base_trainer.py
@@ -298,49 +298,49 @@ def load_datasets(self) -> None:
                 self.train_sampler,
             )
 
-            if self.config.get("val_dataset", None):
-                if self.config["val_dataset"].get("use_train_settings", True):
-                    val_config = self.config["dataset"].copy()
-                    val_config.update(self.config["val_dataset"])
-                else:
-                    val_config = self.config["val_dataset"]
-
-                self.val_dataset = registry.get_dataset_class(
-                    val_config.get("format", "lmdb")
-                )(val_config)
-                self.val_sampler = self.get_sampler(
-                    self.val_dataset,
-                    self.config["optim"].get(
-                        "eval_batch_size", self.config["optim"]["batch_size"]
-                    ),
-                    shuffle=False,
-                )
-                self.val_loader = self.get_dataloader(
-                    self.val_dataset,
-                    self.val_sampler,
-                )
+        if self.config.get("val_dataset", None):
+            if self.config["val_dataset"].get("use_train_settings", True):
+                val_config = self.config["dataset"].copy()
+                val_config.update(self.config["val_dataset"])
+            else:
+                val_config = self.config["val_dataset"]
 
-            if self.config.get("test_dataset", None):
-                if self.config["test_dataset"].get("use_train_settings", True):
-                    test_config = self.config["dataset"].copy()
-                    test_config.update(self.config["test_dataset"])
-                else:
-                    test_config = self.config["test_dataset"]
-
-                self.test_dataset = registry.get_dataset_class(
-                    test_config.get("format", "lmdb")
-                )(test_config)
-                self.test_sampler = self.get_sampler(
-                    self.test_dataset,
-                    self.config["optim"].get(
-                        "eval_batch_size", self.config["optim"]["batch_size"]
-                    ),
-                    shuffle=False,
-                )
-                self.test_loader = self.get_dataloader(
-                    self.test_dataset,
-                    self.test_sampler,
-                )
+            self.val_dataset = registry.get_dataset_class(
+                val_config.get("format", "lmdb")
+            )(val_config)
+            self.val_sampler = self.get_sampler(
+                self.val_dataset,
+                self.config["optim"].get(
+                    "eval_batch_size", self.config["optim"]["batch_size"]
+                ),
+                shuffle=False,
+            )
+            self.val_loader = self.get_dataloader(
+                self.val_dataset,
+                self.val_sampler,
+            )
+
+        if self.config.get("test_dataset", None):
+            if self.config["test_dataset"].get("use_train_settings", True):
+                test_config = self.config["dataset"].copy()
+                test_config.update(self.config["test_dataset"])
+            else:
+                test_config = self.config["test_dataset"]
+
+            self.test_dataset = registry.get_dataset_class(
+                test_config.get("format", "lmdb")
+            )(test_config)
+            self.test_sampler = self.get_sampler(
+                self.test_dataset,
+                self.config["optim"].get(
+                    "eval_batch_size", self.config["optim"]["batch_size"]
+                ),
+                shuffle=False,
+            )
+            self.test_loader = self.get_dataloader(
+                self.test_dataset,
+                self.test_sampler,
+            )
 
         # load relaxation dataset
         if "relax_dataset" in self.config["task"]:
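One behaviour shared by the val and test blocks above is worth spelling out: unless `use_train_settings` is explicitly `False`, the split-specific config is layered on top of a copy of the train `dataset` config. Below is a self-contained sketch of that merge with plain dictionaries; the example keys and values are hypothetical, and only the merge logic mirrors the trainer code above.

```python
# Sketch of the config merge performed in load_datasets above, outside the trainer class.
# The concrete keys/values here are made-up examples, not repository defaults.
config = {
    "dataset": {  # train dataset settings
        "format": "ase_db",
        "src": "train.db",
        "a2g_args": {"r_energy": True, "r_forces": True},
    },
    "val_dataset": {  # only override what differs from the train settings
        "src": "val.db",
        "use_train_settings": True,  # also the default when the key is absent
    },
}

if config.get("val_dataset", None):
    if config["val_dataset"].get("use_train_settings", True):
        val_config = config["dataset"].copy()     # start from the train settings
        val_config.update(config["val_dataset"])  # apply the val-specific overrides
    else:
        val_config = config["val_dataset"]        # use the val settings verbatim

print(val_config["src"], val_config["format"], val_config["a2g_args"])
# val.db ase_db {'r_energy': True, 'r_forces': True}
```

Note that `dict.copy()` is shallow, so nested settings such as `a2g_args` are shared with the train config rather than duplicated.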
