refactor multi test set datasets; add seq id test splits to GO #72

Merged · 8 commits · Feb 9, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
### 0.2.6 (UNRELEASED)

### Datasets
* Add stage-based conditions to `setup` in `ProteinDataModule` [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple sequence-identity-based test splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)

### Models

* Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
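For context on the stage-based `setup` entry above, a minimal sketch of how the stages are exercised (Lightning passes `stage` automatically during `fit` and `test`; `"lazy_init"` is the manual stage used in `train.py` below; the datamodule instance `dm` is hypothetical):

```python
# Hypothetical `dm`: any concrete ProteinDataModule subclass.
dm.setup(stage="fit")        # preprocesses train_ds and val_ds only
dm.setup(stage="test")       # one test set per entry in test_dataset_names,
                             # falling back to a single test_ds otherwise
dm.setup(stage="lazy_init")  # val_ds only, used to initialise lazy layers
```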
15 changes: 15 additions & 0 deletions README.md
@@ -557,3 +557,18 @@ To build a local version of the project's Sphinx documentation web pages:
pip install -r docs/.docs.requirements # one-time only
rm -rf docs/build/ && sphinx-build docs/source/ docs/build/ # NOTE: errors can safely be ignored
```

## Citing `ProteinWorkshop`

Please consider citing `proteinworkshop` if it proves useful in your work.

```bibtex
@inproceedings{
jamasb2024evaluating,
title={Evaluating Representation Learning on the Protein Structure Universe},
author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}

```
7 changes: 7 additions & 0 deletions citation.bib
@@ -0,0 +1,7 @@
@inproceedings{
jamasb2024evaluating,
title={Evaluating Representation Learning on the Protein Structure Universe},
author={Arian R. Jamasb and Alex Morehead and Zuobai Zhang and Chaitanya K. Joshi and Kieran Didi and Simon V. Mathis and Charles Harris and Jian Tang and Jianlin Cheng and Pietro Lio and Tom L. Blundell},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}
26 changes: 19 additions & 7 deletions proteinworkshop/datasets/base.py
@@ -1,4 +1,5 @@
"""Base classes for protein structure datamodules and datasets."""
import copy
import os
import pathlib
from abc import ABC, abstractmethod
@@ -81,12 +82,23 @@ def download(self):

def setup(self, stage: Optional[str] = None):
self.download()
logger.info("Preprocessing training data")
self.train_ds = self.train_dataset()
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()
logger.info("Preprocessing test data")
self.test_ds = self.test_dataset()

if stage == "fit" or stage is None:
logger.info("Preprocessing training data")
self.train_ds = self.train_dataset()
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()
elif stage == "test":
logger.info("Preprocessing test data")
if hasattr(self, "test_dataset_names"):
for split in self.test_dataset_names:
setattr(self, f"{split}_ds", self.test_dataset(split))
else:
self.test_ds = self.test_dataset()
elif stage == "lazy_init":
logger.info("Preprocessing validation data")
self.val_ds = self.val_dataset()

# self.class_weights = self.get_class_weights()

@property
@@ -518,7 +530,7 @@ def get(self, idx: int) -> Data:
:return: PyTorch Geometric Data object.
"""
if self.in_memory:
return self._batch_format(self.data[idx])
return self._batch_format(copy.deepcopy(self.data[idx]))
[Review thread on this change]
Collaborator: Does this lead to one of the performance enhancements we discussed previously?
Author: No noticeable delta, but it makes me feel better about some of the gotchas here: https://alecstashevsky.com/post/on-the-fly-augmentation-with-pytorch-geometric-and-lightning-what-tutorials-dont-teach/
Collaborator: Nice. SGTM

if self.out_names is not None:
fname = f"{self.out_names[idx]}.pt"
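On the `copy.deepcopy` change discussed in the thread above, a minimal self-contained sketch of the gotcha the linked post covers (names are illustrative, not from the repo): with an in-memory cache, an in-place transform mutates the cached object, so repeated accesses compound the transform.

```python
import copy

import torch
from torch_geometric.data import Data

cache = [Data(x=torch.zeros(3, 1))]  # stand-in for an in-memory dataset

def noisy_transform(data: Data) -> Data:
    data.x += torch.randn_like(data.x)  # mutates the tensor in place
    return data

# Returning the cached object directly lets noise accumulate across epochs:
_ = noisy_transform(cache[0])
_ = noisy_transform(cache[0])  # second access starts from already-noised data

# Returning a deep copy keeps the cached original untouched from here on:
fresh = noisy_transform(copy.deepcopy(cache[0]))
```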
44 changes: 20 additions & 24 deletions proteinworkshop/datasets/fold_classification.py
@@ -1,7 +1,7 @@
import os
import pathlib
import tarfile
from typing import Callable, Dict, Iterable, Optional
from typing import Callable, Dict, Iterable, List, Literal, Optional

import omegaconf
import pandas as pd
@@ -72,6 +72,11 @@ def __init__(
else:
self.transform = None

@property
def test_dataset_names(self) -> List[str]:
"""Provides a list of test set split names."""
return ["fold", "family", "superfamily"]

def download(self):
self.download_data_files()
self.download_structures()
@@ -152,16 +157,12 @@ def parse_class_map(self) -> Dict[str, str]:
)
return dict(class_map.values)

def setup(self, stage: Optional[str] = None):
self.download_data_files()
self.download_structures()
self.train_ds = self.train_dataset()
self.val_ds = self.val_dataset()
self.test_ds = self.test_dataset()

def _get_dataset(self, split: str) -> ProteinDataset:
if hasattr(self, f"{split}_ds"):
return getattr(self, f"{split}_ds")

df = self.parse_dataset(split)
return ProteinDataset(
ds = ProteinDataset(
root=str(self.data_dir),
pdb_dir=str(self.structure_dir),
pdb_codes=list(df.id),
@@ -171,15 +172,19 @@ def _get_dataset(self, split: str) -> ProteinDataset:
transform=self.transform,
in_memory=self.in_memory,
)
setattr(self, f"{split}_ds", ds)
return ds

def train_dataset(self) -> ProteinDataset:
return self._get_dataset("training")

def val_dataset(self) -> ProteinDataset:
return self._get_dataset("validation")

def test_dataset(self) -> ProteinDataset:
return self._get_dataset(f"test_{self.split}")
def test_dataset(
self, split: Literal["fold", "family", "superfamily"]
) -> ProteinDataset:
return self._get_dataset(f"test_{split}")

def train_dataloader(self) -> ProteinDataLoader:
self.train_ds = self.train_dataset()
@@ -201,8 +206,10 @@ def val_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def test_dataloader(self) -> ProteinDataLoader:
self.test_ds = self.test_dataset()
def test_dataloader(
self, split: Literal["fold", "family", "superfamily"]
) -> ProteinDataLoader:
self.test_ds = self.test_dataset(split)
return ProteinDataLoader(
self.test_ds,
batch_size=self.batch_size,
@@ -211,17 +218,6 @@ def test_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def get_test_loader(self, split: str) -> ProteinDataLoader:
log.info(f"Getting test loader: {split}")
test_ds = self._get_dataset(f"test_{split}")
return ProteinDataLoader(
test_ds,
batch_size=self.batch_size,
shuffle=False,
pin_memory=self.pin_memory,
num_workers=self.num_workers,
)

def parse_dataset(self, split: str) -> pd.DataFrame:
"""
Parses the raw dataset files to Pandas DataFrames.
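A hypothetical usage sketch of the refactored interface, assuming `dm` is an already-constructed `FoldClassificationDataModule`:

```python
dm.setup(stage="test")  # builds the fold, family and superfamily test sets
for split in dm.test_dataset_names:  # ["fold", "family", "superfamily"]
    loader = dm.test_dataloader(split)
    # ... evaluate on `loader` ...
```

The `setattr(self, f"{split}_ds", ds)` caching in `_get_dataset` means each split is parsed and constructed at most once per run, so repeated dataloader calls are cheap.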
99 changes: 75 additions & 24 deletions proteinworkshop/datasets/go.py
@@ -2,7 +2,7 @@
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Callable, Dict, Iterable, Literal, Optional
from typing import Callable, Dict, Iterable, List, Literal, Optional

import omegaconf
import pandas as pd
@@ -70,6 +70,14 @@ def __init__(

self.shuffle_labels = shuffle_labels

self.test_seq_similarity_cutoffs: List[float] = [
0.3,
0.4,
0.5,
0.7,
0.95,
]

if transforms is not None:
self.transform = self.compose_transforms(
omegaconf.OmegaConf.to_container(transforms, resolve=True)
@@ -79,14 +87,19 @@

self.train_fname = self.data_dir / "nrPDB-GO_train.txt"
self.val_fname = self.data_dir / "nrPDB-GO_valid.txt"
self.test_fname = self.data_dir / "nrPDB-GO_test.txt"
self.test_fname = self.data_dir / "nrPDB-GO_test.csv"
self.label_fname = self.data_dir / "nrPDB-GO_annot.tsv"
self.url = "https://zenodo.org/record/6622158/files/GeneOntology.zip"

log.info(
f"Setting up Gene Ontology dataset. Fraction {self.dataset_fraction}"
)

@property
def test_dataset_names(self) -> List[str]:
"""Provides a list of test set split names."""
return ["test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"]

@lru_cache
def parse_labels(self) -> Dict[str, torch.Tensor]:
"""
@@ -130,11 +143,23 @@ def parse_labels(self) -> Dict[str, torch.Tensor]:
return labels

def _get_dataset(
self, split: Literal["training", "validation", "testing"]
self,
split: Literal[
"training",
"validation",
"test_0.3",
"test_0.4",
"test_0.5",
"test_0.7",
"test_0.95",
],
) -> ProteinDataset:
if hasattr(self, f"{split}_ds"):
return getattr(self, f"{split}_ds")

df = self.parse_dataset(split)
log.info("Initialising Graphein dataset...")
return ProteinDataset(
ds = ProteinDataset(
root=str(self.data_dir),
pdb_dir=str(self.pdb_dir),
pdb_codes=list(df.pdb),
@@ -147,15 +172,22 @@ def _get_dataset(
format=self.format,
in_memory=self.in_memory,
)
setattr(self, f"{split}_ds", ds)
return ds

def train_dataset(self) -> ProteinDataset:
return self._get_dataset("training")

def val_dataset(self) -> ProteinDataset:
return self._get_dataset("validation")

def test_dataset(self) -> ProteinDataset:
return self._get_dataset("testing")
def test_dataset(
self,
split: Literal[
"test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
],
) -> ProteinDataset:
return self._get_dataset(split)

def train_dataloader(self) -> ProteinDataLoader:
return ProteinDataLoader(
@@ -175,9 +207,14 @@ def val_dataloader(self) -> ProteinDataLoader:
num_workers=self.num_workers,
)

def test_dataloader(self) -> ProteinDataLoader:
def test_dataloader(
self,
split: Literal[
"test_0.3", "test_0.4", "test_0.5", "test_0.7", "test_0.95"
],
) -> ProteinDataLoader:
return ProteinDataLoader(
self.test_dataset(),
self.test_dataset(split),
batch_size=self.batch_size,
shuffle=False,
pin_memory=self.pin_memory,
@@ -205,7 +242,16 @@ def exclude_pdbs(self):
pass

def parse_dataset(
self, split: Literal["training", "validation", "testing"]
self,
split: Literal[
"training",
"validation",
"test_0.3",
"test_0.4",
"test_0.5",
"test_0.7",
"test_0.95",
],
) -> pd.DataFrame:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches, switch
"""
@@ -221,8 +267,11 @@
data = data.sample(frac=self.dataset_fraction)
elif split == "validation":
data = pd.read_csv(self.val_fname, sep="\t", header=None)
elif split == "testing":
data = pd.read_csv(self.test_fname, sep="\t", header=None)
elif split.startswith("test_"):
cutoff = int(float(split.split("_")[1]) * 100)
data = pd.read_csv(self.test_fname, sep=",")
data = data.loc[data[f"<{cutoff}%"] == 1]
data = pd.DataFrame(data["PDB-chain"].values)
else:
raise ValueError(f"Unknown split: {split}")

@@ -304,16 +353,18 @@ def __call__(self, data: Protein) -> Protein:
cfg.datamodule.transforms = []
log.info("Loaded config")

ds = hydra.utils.instantiate(cfg)
print(ds)
# labels = ds["datamodule"].parse_labels()
ds.datamodule.setup()
dl = ds["datamodule"].train_dataloader()
for batch in dl:
print(batch)
dl = ds["datamodule"].val_dataloader()
for batch in dl:
print(batch)
dl = ds["datamodule"].test_dataloader()
for batch in dl:
print(batch)
ds = hydra.utils.instantiate(cfg)["datamodule"]
ds.parse_dataset("test_0.3")
ds.parse_dataset("test_0.95")
# print(ds)
## labels = ds["datamodule"].parse_labels()
# ds.datamodule.setup()
# dl = ds["datamodule"].train_dataloader()
# for batch in dl:
# print(batch)
# dl = ds["datamodule"].val_dataloader()
# for batch in dl:
# print(batch)
# dl = ds["datamodule"].test_dataloader()
# for batch in dl:
# print(batch)
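To make the new GO split naming concrete, a small standalone sketch of how a split name maps onto the sequence-identity columns of `nrPDB-GO_test.csv` (column names as in the diff above; the file path is assumed):

```python
import pandas as pd

split = "test_0.3"
cutoff = int(float(split.split("_")[1]) * 100)  # "test_0.3" -> 30

df = pd.read_csv("nrPDB-GO_test.csv")
# Keep chains whose sequence identity to the training set is below the cutoff:
df = df.loc[df[f"<{cutoff}%"] == 1]             # e.g. the "<30%" indicator column
chains = pd.DataFrame(df["PDB-chain"].values)   # matches parse_dataset's output shape
```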
15 changes: 6 additions & 9 deletions proteinworkshop/train.py
@@ -144,7 +144,7 @@ def train_model(

log.info("Initializing lazy layers...")
with torch.no_grad():
datamodule.setup() # type: ignore
datamodule.setup(stage="lazy_init") # type: ignore
batch = next(iter(datamodule.val_dataloader()))
log.info(f"Unfeaturized batch: {batch}")
batch = model.featurise(batch)
@@ -185,16 +185,13 @@ def train_model(

if cfg.get("test"):
log.info("Starting testing!")
# Run test on all splits if using fold_classification dataset
if (
cfg.dataset.datamodule._target_
== "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
):
splits = ["fold", "family", "superfamily"]
if hasattr(datamodule, "test_dataset_names"):
splits = datamodule.test_dataset_names
wandb_logger = copy.deepcopy(trainer.logger)
for split in splits:
dataloader = datamodule.get_test_loader(split)
for i, split in enumerate(splits):
dataloader = datamodule.test_dataloader(split)
trainer.logger = False
log.info(f"Testing on {split} ({i+1} / {len(splits)})...")
results = trainer.test(
model=model, dataloaders=dataloader, ckpt_path="best"
)[0]
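A hedged sketch of the logging pattern in the loop above (the aggregation shown after the loop is an assumption; the diff ends before the results are consumed): `trainer.logger` is disabled so each `trainer.test` call does not overwrite an earlier split's metrics, and the deep-copied wandb logger reports everything once at the end under split-prefixed keys.

```python
trainer.logger = False  # keep trainer.test from logging each split separately
all_results = {}
for split in splits:
    results = trainer.test(
        model=model, dataloaders=datamodule.test_dataloader(split), ckpt_path="best"
    )[0]
    # Namespace each split's metrics so they do not collide:
    all_results.update({f"{split}/{k}": v for k, v in results.items()})
wandb_logger.log_metrics(all_results)  # log once via the preserved logger copy
```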