diff --git a/CHANGELOG.md b/CHANGELOG.md
index 12a4b581..cab90057 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 ### 0.2.6 (UNRELEASED)
 
+### Datasets
+* Adds stage-based conditions to `setup` in `ProteinDataModule` [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
+* Improves support for datamodules with multiple test sets, generalising this to the GO and FOLD datasets. Also adds multiple sequence identity-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
+
 ### Models
 * Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
diff --git a/README.md b/README.md
index 02432738..f51bec7a 100644
--- a/README.md
+++ b/README.md
@@ -557,3 +557,18 @@ To build a local version of the project's Sphinx documentation web pages:
 pip install -r docs/.docs.requirements # one-time only
 rm -rf docs/build/ && sphinx-build docs/source/ docs/build/ # NOTE: errors can safely be ignored
 ```
+
+## Citing `ProteinWorkshop`
+
+Please consider citing `proteinworkshop` if it proves useful in your work.
+
+```bibtex
+@inproceedings{
+  jamasb2024evaluating,
+  title={Evaluating Representation Learning on the Protein Structure Universe},
+  author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
+  booktitle={The Twelfth International Conference on Learning Representations},
+  year={2024},
+}
+
+```
diff --git a/citation.bib b/citation.bib
new file mode 100644
index 00000000..40102774
--- /dev/null
+++ b/citation.bib
@@ -0,0 +1,7 @@
+@inproceedings{
+jamasb2024evaluating,
+title={Evaluating Representation Learning on the Protein Structure Universe},
+author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
+booktitle={The Twelfth International Conference on Learning Representations},
+year={2024},
+}
diff --git a/proteinworkshop/datasets/base.py b/proteinworkshop/datasets/base.py
index 1f58dbfa..63034e5d 100644
--- a/proteinworkshop/datasets/base.py
+++ b/proteinworkshop/datasets/base.py
@@ -1,4 +1,5 @@
 """Base classes for protein structure datamodules and datasets."""
+import copy
 import os
 import pathlib
 from abc import ABC, abstractmethod
@@ -81,12 +82,23 @@ def download(self):
 
     def setup(self, stage: Optional[str] = None):
         self.download()
-        logger.info("Preprocessing training data")
-        self.train_ds = self.train_dataset()
-        logger.info("Preprocessing validation data")
-        self.val_ds = self.val_dataset()
-        logger.info("Preprocessing test data")
-        self.test_ds = self.test_dataset()
+
+        if stage == "fit" or stage is None:
+            logger.info("Preprocessing training data")
+            self.train_ds = self.train_dataset()
+            logger.info("Preprocessing validation data")
+            self.val_ds = self.val_dataset()
+        elif stage == "test":
+            logger.info("Preprocessing test data")
+            if hasattr(self, "test_dataset_names"):
+                for split in self.test_dataset_names:
+                    setattr(self, f"{split}_ds", self.test_dataset(split))
+            else:
+                self.test_ds = self.test_dataset()
+        elif stage in {"validate", "lazy_init"}:
+            logger.info("Preprocessing validation data")
+            self.val_ds = self.val_dataset()
+
         # self.class_weights = self.get_class_weights()
 
     @property
@@ -518,7 +530,7 @@ def get(self, idx: int) -> Data:
 
         :return: PyTorch Geometric Data object.
""" if self.in_memory: - return self._batch_format(self.data[idx]) + return self._batch_format(copy.deepcopy(self.data[idx])) if self.out_names is not None: fname = f"{self.out_names[idx]}.pt" diff --git a/proteinworkshop/datasets/fold_classification.py b/proteinworkshop/datasets/fold_classification.py index 403da603..f1d384ec 100644 --- a/proteinworkshop/datasets/fold_classification.py +++ b/proteinworkshop/datasets/fold_classification.py @@ -1,7 +1,7 @@ import os import pathlib import tarfile -from typing import Callable, Dict, Iterable, Optional +from typing import Callable, Dict, Iterable, List, Literal, Optional import omegaconf import pandas as pd @@ -72,6 +72,11 @@ def __init__( else: self.transform = None + @property + def test_dataset_names(self) -> List[str]: + """Provides a list of test set split names.""" + return ["fold", "family", "superfamily"] + def download(self): self.download_data_files() self.download_structures() @@ -152,16 +157,12 @@ def parse_class_map(self) -> Dict[str, str]: ) return dict(class_map.values) - def setup(self, stage: Optional[str] = None): - self.download_data_files() - self.download_structures() - self.train_ds = self.train_dataset() - self.val_ds = self.val_dataset() - self.test_ds = self.test_dataset() - def _get_dataset(self, split: str) -> ProteinDataset: + if hasattr(self, f"{split}_ds"): + return getattr(self, f"{split}_ds") + df = self.parse_dataset(split) - return ProteinDataset( + ds = ProteinDataset( root=str(self.data_dir), pdb_dir=str(self.structure_dir), pdb_codes=list(df.id), @@ -171,6 +172,8 @@ def _get_dataset(self, split: str) -> ProteinDataset: transform=self.transform, in_memory=self.in_memory, ) + setattr(self, f"{split}_ds", ds) + return ds def train_dataset(self) -> ProteinDataset: return self._get_dataset("training") @@ -178,8 +181,10 @@ def train_dataset(self) -> ProteinDataset: def val_dataset(self) -> ProteinDataset: return self._get_dataset("validation") - def test_dataset(self) -> ProteinDataset: - return self._get_dataset(f"test_{self.split}") + def test_dataset( + self, split: Literal["fold", "family", "superfamily"] + ) -> ProteinDataset: + return self._get_dataset(f"test_{split}") def train_dataloader(self) -> ProteinDataLoader: self.train_ds = self.train_dataset() @@ -201,8 +206,10 @@ def val_dataloader(self) -> ProteinDataLoader: num_workers=self.num_workers, ) - def test_dataloader(self) -> ProteinDataLoader: - self.test_ds = self.test_dataset() + def test_dataloader( + self, split: Literal["fold", "family", "superfamily"] + ) -> ProteinDataLoader: + self.test_ds = self.test_dataset(split) return ProteinDataLoader( self.test_ds, batch_size=self.batch_size, @@ -211,17 +218,6 @@ def test_dataloader(self) -> ProteinDataLoader: num_workers=self.num_workers, ) - def get_test_loader(self, split: str) -> ProteinDataLoader: - log.info(f"Getting test loader: {split}") - test_ds = self._get_dataset(f"test_{split}") - return ProteinDataLoader( - test_ds, - batch_size=self.batch_size, - shuffle=False, - pin_memory=self.pin_memory, - num_workers=self.num_workers, - ) - def parse_dataset(self, split: str) -> pd.DataFrame: """ Parses the raw dataset files to Pandas DataFrames. 
diff --git a/proteinworkshop/datasets/go.py b/proteinworkshop/datasets/go.py
index 43004b48..c4fb0623 100644
--- a/proteinworkshop/datasets/go.py
+++ b/proteinworkshop/datasets/go.py
@@ -2,7 +2,7 @@
 import zipfile
 from functools import lru_cache
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Literal, Optional
+from typing import Callable, Dict, Iterable, List, Literal, Optional
 
 import omegaconf
 import pandas as pd
@@ -70,6 +70,14 @@ def __init__(
 
         self.shuffle_labels = shuffle_labels
 
+        self.test_seq_similarity_cutoffs: List[float] = [
+            0.3,
+            0.4,
+            0.5,
+            0.7,
+            0.95,
+        ]
+
         if transforms is not None:
             self.transform = self.compose_transforms(
                 omegaconf.OmegaConf.to_container(transforms, resolve=True)
@@ -79,7 +87,7 @@
 
         self.train_fname = self.data_dir / "nrPDB-GO_train.txt"
         self.val_fname = self.data_dir / "nrPDB-GO_valid.txt"
-        self.test_fname = self.data_dir / "nrPDB-GO_test.txt"
+        self.test_fname = self.data_dir / "nrPDB-GO_test.csv"
         self.label_fname = self.data_dir / "nrPDB-GO_annot.tsv"
         self.url = "https://zenodo.org/record/6622158/files/GeneOntology.zip"
 
@@ -87,6 +95,11 @@
             f"Setting up Gene Ontology dataset. Fraction {self.dataset_fraction}"
         )
 
+    @property
+    def test_dataset_names(self) -> List[str]:
+        """Provides one test split name per sequence similarity cutoff."""
+        return [f"test_{cutoff}" for cutoff in self.test_seq_similarity_cutoffs]
+
     @lru_cache
     def parse_labels(self) -> Dict[str, torch.Tensor]:
         """
@@ -130,11 +143,23 @@ def parse_labels(self) -> Dict[str, torch.Tensor]:
         return labels
 
     def _get_dataset(
-        self, split: Literal["training", "validation", "testing"]
+        self,
+        split: Literal[
+            "training",
+            "validation",
+            "test_0.3",
+            "test_0.4",
+            "test_0.5",
+            "test_0.7",
+            "test_0.95",
+        ],
     ) -> ProteinDataset:
+        if hasattr(self, f"{split}_ds"):
+            return getattr(self, f"{split}_ds")
+
         df = self.parse_dataset(split)
         log.info("Initialising Graphein dataset...")
-        return ProteinDataset(
+        ds = ProteinDataset(
             root=str(self.data_dir),
             pdb_dir=str(self.pdb_dir),
             pdb_codes=list(df.pdb),
@@ -147,6 +172,8 @@ def _get_dataset(
             format=self.format,
             in_memory=self.in_memory,
         )
+        setattr(self, f"{split}_ds", ds)
+        return ds
 
     def train_dataset(self) -> ProteinDataset:
         return self._get_dataset("training")
@@ -154,8 +181,13 @@ def train_dataset(self) -> ProteinDataset:
     def val_dataset(self) -> ProteinDataset:
         return self._get_dataset("validation")
 
-    def test_dataset(self) -> ProteinDataset:
-        return self._get_dataset("testing")
+    def test_dataset(
+        self,
+        split: Literal[
+            "test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
+        ],
+    ) -> ProteinDataset:
+        return self._get_dataset(split)
 
     def train_dataloader(self) -> ProteinDataLoader:
         return ProteinDataLoader(
@@ -175,9 +207,14 @@ def val_dataloader(self) -> ProteinDataLoader:
             num_workers=self.num_workers,
         )
 
-    def test_dataloader(self) -> ProteinDataLoader:
+    def test_dataloader(
+        self,
+        split: Literal[
+            "test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
+        ],
+    ) -> ProteinDataLoader:
         return ProteinDataLoader(
-            self.test_dataset(),
+            self.test_dataset(split),
             batch_size=self.batch_size,
             shuffle=False,
             pin_memory=self.pin_memory,
             num_workers=self.num_workers,
@@ -205,7 +242,16 @@ def exclude_pdbs(self):
         pass
 
     def parse_dataset(
-        self, split: Literal["training", "validation", "testing"]
+        self,
+        split: Literal[
+            "training",
+            "validation",
+            "test_0.3",
+            "test_0.4",
+            "test_0.5",
+            "test_0.7",
+            "test_0.95",
+        ],
     ) -> pd.DataFrame:
         # sourcery skip: remove-unnecessary-else, swap-if-else-branches, switch
         """
@@ -221,8 +267,11 @@ def parse_dataset(
             data = data.sample(frac=self.dataset_fraction)
         elif split == "validation":
             data = pd.read_csv(self.val_fname, sep="\t", header=None)
-        elif split == "testing":
-            data = pd.read_csv(self.test_fname, sep="\t", header=None)
+        elif split.startswith("test_"):
+            cutoff = round(float(split.split("_")[1]) * 100)  # e.g. "test_0.3" -> 30
+            data = pd.read_csv(self.test_fname, sep=",")
+            data = data.loc[data[f"<{cutoff}%"] == 1]
+            data = pd.DataFrame(data["PDB-chain"].values)
         else:
             raise ValueError(f"Unknown split: {split}")
 
@@ -304,16 +353,7 @@ def __call__(self, data: Protein) -> Protein:
     cfg.datamodule.transforms = []
     log.info("Loaded config")
 
-    ds = hydra.utils.instantiate(cfg)
-    print(ds)
-    # labels = ds["datamodule"].parse_labels()
-    ds.datamodule.setup()
-    dl = ds["datamodule"].train_dataloader()
-    for batch in dl:
-        print(batch)
-    dl = ds["datamodule"].val_dataloader()
-    for batch in dl:
-        print(batch)
-    dl = ds["datamodule"].test_dataloader()
-    for batch in dl:
-        print(batch)
+    ds = hydra.utils.instantiate(cfg)["datamodule"]
+    # Smoke-test parsing of the new sequence identity-based test splits
+    ds.parse_dataset("test_0.3")
+    ds.parse_dataset("test_0.95")
diff --git a/proteinworkshop/train.py b/proteinworkshop/train.py
index 74aeaebf..6559e3cd 100644
--- a/proteinworkshop/train.py
+++ b/proteinworkshop/train.py
@@ -144,7 +144,7 @@ def train_model(
 
     log.info("Initializing lazy layers...")
     with torch.no_grad():
-        datamodule.setup()  # type: ignore
+        datamodule.setup(stage="lazy_init")  # type: ignore
         batch = next(iter(datamodule.val_dataloader()))
         log.info(f"Unfeaturized batch: {batch}")
         batch = model.featurise(batch)
@@ -185,16 +185,13 @@ def train_model(
 
     if cfg.get("test"):
         log.info("Starting testing!")
-        # Run test on all splits if using fold_classification dataset
-        if (
-            cfg.dataset.datamodule._target_
-            == "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
-        ):
-            splits = ["fold", "family", "superfamily"]
+        if hasattr(datamodule, "test_dataset_names"):
+            splits = datamodule.test_dataset_names
             wandb_logger = copy.deepcopy(trainer.logger)
-            for split in splits:
-                dataloader = datamodule.get_test_loader(split)
+            for i, split in enumerate(splits):
+                dataloader = datamodule.test_dataloader(split)
                 trainer.logger = False
+                log.info(f"Testing on {split} ({i+1}/{len(splits)})...")
                 results = trainer.test(
                     model=model, dataloaders=dataloader, ckpt_path="best"
                 )[0]
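
---

For reviewers: the multi-test-set protocol this PR establishes, condensed into a minimal sketch. It assumes `trainer`, `model`, and an instantiated `datamodule` already exist (in `train.py` they are built from the Hydra config); the single-test-set fallback branch is illustrative rather than a verbatim excerpt of `train_model`.

```python
# Minimal sketch of the new multi-test-set evaluation flow (not a verbatim
# excerpt). `trainer`, `model`, and `datamodule` are assumed to be already
# instantiated, e.g. via hydra.utils.instantiate as in train.py.

datamodule.setup(stage="test")  # populates one `{split}_ds` attribute per test split

if hasattr(datamodule, "test_dataset_names"):
    # Datamodules with several held-out sets (FOLD: fold/family/superfamily,
    # GO: test_0.3 ... test_0.95) expose their split names and serve a
    # dedicated dataloader per split.
    for split in datamodule.test_dataset_names:
        dataloader = datamodule.test_dataloader(split)
        results = trainer.test(
            model=model, dataloaders=dataloader, ckpt_path="best"
        )[0]
        print(split, results)
else:
    # Single-test-set datamodules keep the original behaviour.
    results = trainer.test(model=model, datamodule=datamodule, ckpt_path="best")[0]
    print(results)
```

Because `_get_dataset` now caches each split as a `{split}_ds` attribute, repeated calls to `test_dataloader(split)` reuse the parsed dataset rather than re-reading it from disk.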