Removed a number of bugs from input pipeline
ardagoreci committed Nov 21, 2023
1 parent 78e34b0 commit 12da901
129 changes: 71 additions & 58 deletions src/data/protein_datamodule.py
@@ -3,18 +3,18 @@
import os
import torch
from lightning import LightningDataModule
-from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
import proteinflow
from torchvision import transforms


class TransformDataset(torch.utils.data.Dataset):
"""A convenience class that applies arbitrary transforms to torch.utils.data.Dataset objects."""

def __init__(
self,
-        dataset,
-        transform
-    ) -> None:
+        dataset: Dataset,
+        transform: Callable):
super().__init__()
self.dataset = dataset
self.transform = transform
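The wrapper's remaining methods are folded out of the diff view. For readers of this excerpt, a minimal sketch of the standard pattern such a wrapper follows (hypothetical: the committed bodies are collapsed above and may differ):

```python
# Sketch only: the actual __getitem__/__len__ are elided in the diff.
def __getitem__(self, index):
    # Fetch the raw entry, then run it through the supplied transform.
    return self.transform(self.dataset[index])

def __len__(self):
    return len(self.dataset)
```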
@@ -30,9 +30,12 @@ def __len__(self):
class Reorder(torch.nn.Module):
"""A transformation that reorders the 3D coordinates of backbone atoms
from N, C, Ca, O -> N, Ca, C, O."""

def forward(self, protein_dict):
-        raise NotImplementedError
+        if 'reordered' not in protein_dict.keys():
+            # If not already reordered, switch from N, C, Ca, O to N, Ca, C, O ordering.
+            reordered_X = protein_dict['X'].index_select(1, torch.tensor([0, 2, 1, 3]))
+            protein_dict['X'] = reordered_X
+            protein_dict['reordered'] = True  # mark the entry so a second pass is a no-op
+        return protein_dict
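Because the reorder is a single `index_select` along the atom dimension, it is easy to sanity-check in isolation. A minimal check, assuming `protein_dict['X']` has shape `[n_res, 4, 3]` with atoms ordered N, C, Ca, O:

```python
import torch

X = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # one residue, atoms N, C, Ca, O
reordered = X.index_select(1, torch.tensor([0, 2, 1, 3]))   # swap columns 1 (C) and 2 (Ca)
assert torch.equal(reordered[0, 1], X[0, 2])  # Ca now sits in slot 1
assert torch.equal(reordered[0, 2], X[0, 1])  # C now sits in slot 2
```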


class Cropper(torch.nn.Module):
@@ -59,23 +62,18 @@ def forward(self, protein_dict: dict):
used in `'chain_id'` and `'chain_encoding_all'` objects)
"""
n_res = protein_dict['residue_idx'].shape[0]
-        n = n_res - self.crop_size
+        n = max(n_res - self.crop_size, 1)  # randint needs high >= 1; entries shorter than crop_size crop from 0
crop_start = torch.randint(low=0, high=n, size=())
for key, value in protein_dict.items():
if key == "chain_id" or key == "chain_dict": # these are not Tensors, so skip
continue
protein_dict[key] = value[crop_start:crop_start + self.crop_size]
return protein_dict
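Together with `Reorder`, the cropper is chained through `torchvision.transforms.Compose` (see the `__init__` hunk below), which simply calls each module in turn on the sample. A toy run on a hypothetical 10-residue entry (the key names follow the fields used above; real proteinflow entries carry more keys):

```python
import torch
from torchvision import transforms

protein = {
    'X': torch.randn(10, 4, 3),      # backbone coordinates in N, C, Ca, O order
    'residue_idx': torch.arange(10),
    'chain_id': 'A',                 # non-tensor entry, skipped by Cropper
}
pipeline = transforms.Compose([Cropper(crop_size=4), Reorder()])
out = pipeline(protein)
print(out['X'].shape)  # torch.Size([4, 4, 3]): 4 residues survive the crop
```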


class ProteinDataModule(LightningDataModule):
"""`LightningDataModule` for the Protein Data Bank.
-    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
-    It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
-    fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
-    while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
-    technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of
-    mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
A `LightningDataModule` implements 7 key methods:
```python
@@ -116,6 +114,7 @@ def __init__(
data_dir: str = "./data/",
resolution_thr: float = 3.5,
min_seq_id: float = 0.3,
+        crop_size: int = 384,
max_length: int = 10_000,
use_fraction: float = 1.0,
entry_type: str = "chain",
@@ -135,6 +134,7 @@ def __init__(
:param resolution_thr: Resolution threshold for PDB structures
:param min_seq_id: Minimum sequence identity for MMSeq2 clustering
+        :param crop_size: The number of residues to crop the proteins to.
:param max_length: Entries with total length of chains larger than max_length will be disregarded.
:param use_fraction: the fraction of the clusters to use (first N in alphabetic order)
:param entry_type: {"biounit", "chain", "pair"} the type of entries to generate ("biounit" for biounit-level
@@ -163,24 +163,16 @@ def __init__(
self.save_hyperparameters(logger=False)

# data transformations
-        # self.transforms = transforms.Compose(
-        #     [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-        # )
+        self.transforms = transforms.Compose(
+            [Cropper(crop_size=crop_size), Reorder()]  # crop and reorder
+        )

self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None

self.batch_size_per_device = batch_size

-    @property
-    def num_classes(self) -> int:
-        """Get the number of classes.
-        :return: The number of MNIST classes (10).
-        """
-        return 10

def prepare_data(self) -> None:
"""Download data if needed. Lightning ensures that `self.prepare_data()` is called only
within a single process on CPU, so you can safely add your downloading logic within. In
@@ -189,11 +181,11 @@ def prepare_data(self) -> None:
Do not use it to assign state (self.x = y).
"""
-        # Download data, PDB cutoff date 27.02.23
+        # Download precomputed data, PDB cutoff date 27.02.23
# This is a dataset with min_res=3.5A, min_len=30, max_len=10_000, min_seq_id=0.3, train/val/test=90/5/5
os.system("proteinflow download --tag 20230102_stable")
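A possible refinement, not part of this commit: skip the shell call when the snapshot already exists locally. This sketch assumes the archive unpacks into the `proteinflow_20230102_stable` folder under the data directory, matching the paths `setup()` reads from:

```python
import os

snapshot = os.path.join("./data/", "proteinflow_20230102_stable")
if not os.path.isdir(snapshot):
    # Only download on the first run; later runs reuse the local copy.
    os.system("proteinflow download --tag 20230102_stable")
```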

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None, debug: bool = False) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
@@ -202,31 +194,21 @@ def setup(self, stage: Optional[str] = None) -> None:
`self.setup()` once the data is prepared and available for use.
:param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+        :param debug: if True, load only the test dataset (debugging mode)
"""
# Divide batch size by the number of devices.
if self.trainer is not None:
if self.hparams.batch_size % self.trainer.world_size != 0:
raise RuntimeError(
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size}). "
)
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

-        if not self.data_train and not self.data_val and not self.data_test:
-            train_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/train")
-            test_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/test")
-            val_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/valid")
-            train_dataset = proteinflow.ProteinDataset(train_folder,
-                                                       max_length=self.hparams.max_length,
-                                                       use_fraction=self.hparams.use_fraction,
-                                                       entry_type=self.hparams.entry_type,
-                                                       classes_to_exclude=self.hparams.classes_to_exclude,
-                                                       mask_residues=self.hparams.mask_residues,
-                                                       lower_limit=self.hparams.lower_limit,
-                                                       upper_limit=self.hparams.upper_limit,
-                                                       mask_frac=self.hparams.mask_frac,
-                                                       mask_sequential=self.hparams.mask_sequential,
-                                                       mask_whole_chains=self.hparams.mask_whole_chains,
-                                                       force_binding_sites_frac=self.hparams.force_binding_sites_frac)
+        train_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/train")
+        test_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/test")
+        val_folder = os.path.join(self.hparams.data_dir, "proteinflow_20230102_stable/valid")
+        if debug:
+            # only load the test dataset if in debug mode
            test_dataset = proteinflow.ProteinDataset(test_folder,
                                                      max_length=self.hparams.max_length,
                                                      use_fraction=self.hparams.use_fraction,
@@ -239,20 +221,51 @@ def setup(self, stage: Optional[str] = None) -> None:
                                                      mask_sequential=self.hparams.mask_sequential,
                                                      mask_whole_chains=self.hparams.mask_whole_chains,
                                                      force_binding_sites_frac=self.hparams.force_binding_sites_frac)
-            val_dataset = proteinflow.ProteinDataset(val_folder,
-                                                     max_length=self.hparams.max_length,
-                                                     use_fraction=self.hparams.use_fraction,
-                                                     entry_type=self.hparams.entry_type,
-                                                     classes_to_exclude=self.hparams.classes_to_exclude,
-                                                     mask_residues=self.hparams.mask_residues,
-                                                     lower_limit=self.hparams.lower_limit,
-                                                     upper_limit=self.hparams.upper_limit,
-                                                     mask_frac=self.hparams.mask_frac,
-                                                     mask_sequential=self.hparams.mask_sequential,
-                                                     mask_whole_chains=self.hparams.mask_whole_chains,
-                                                     force_binding_sites_frac=self.hparams.force_binding_sites_frac)
-            # TODO: Transforms
-            self.data_train, self.data_val, self.data_test = train_dataset, val_dataset, test_dataset
+            self.data_test = TransformDataset(test_dataset, transform=self.transforms)
+        else:
+            if not self.data_train and not self.data_val and not self.data_test:
+                train_dataset = proteinflow.ProteinDataset(train_folder,
+                                                           max_length=self.hparams.max_length,
+                                                           use_fraction=self.hparams.use_fraction,
+                                                           entry_type=self.hparams.entry_type,
+                                                           classes_to_exclude=self.hparams.classes_to_exclude,
+                                                           mask_residues=self.hparams.mask_residues,
+                                                           lower_limit=self.hparams.lower_limit,
+                                                           upper_limit=self.hparams.upper_limit,
+                                                           mask_frac=self.hparams.mask_frac,
+                                                           mask_sequential=self.hparams.mask_sequential,
+                                                           mask_whole_chains=self.hparams.mask_whole_chains,
+                                                           force_binding_sites_frac=self.hparams.force_binding_sites_frac)
+                test_dataset = proteinflow.ProteinDataset(test_folder,
+                                                          max_length=self.hparams.max_length,
+                                                          use_fraction=self.hparams.use_fraction,
+                                                          entry_type=self.hparams.entry_type,
+                                                          classes_to_exclude=self.hparams.classes_to_exclude,
+                                                          mask_residues=self.hparams.mask_residues,
+                                                          lower_limit=self.hparams.lower_limit,
+                                                          upper_limit=self.hparams.upper_limit,
+                                                          mask_frac=self.hparams.mask_frac,
+                                                          mask_sequential=self.hparams.mask_sequential,
+                                                          mask_whole_chains=self.hparams.mask_whole_chains,
+                                                          force_binding_sites_frac=self.hparams.force_binding_sites_frac)
+                val_dataset = proteinflow.ProteinDataset(val_folder,
+                                                         max_length=self.hparams.max_length,
+                                                         use_fraction=self.hparams.use_fraction,
+                                                         entry_type=self.hparams.entry_type,
+                                                         classes_to_exclude=self.hparams.classes_to_exclude,
+                                                         mask_residues=self.hparams.mask_residues,
+                                                         lower_limit=self.hparams.lower_limit,
+                                                         upper_limit=self.hparams.upper_limit,
+                                                         mask_frac=self.hparams.mask_frac,
+                                                         mask_sequential=self.hparams.mask_sequential,
+                                                         mask_whole_chains=self.hparams.mask_whole_chains,
+                                                         force_binding_sites_frac=self.hparams.force_binding_sites_frac)
+
+                # Apply transforms
+                self.data_train = TransformDataset(train_dataset, transform=self.transforms)
+                self.data_val = TransformDataset(val_dataset, transform=self.transforms)
+                self.data_test = TransformDataset(test_dataset, transform=self.transforms)
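End to end, the module follows the usual Lightning flow. A hypothetical smoke test (constructor arguments beyond those shown in the hunks above, such as `batch_size`, are assumed from the hparams referenced in `setup()`):

```python
dm = ProteinDataModule(data_dir="./data/", crop_size=384, batch_size=8)
dm.prepare_data()   # one-off download of the 20230102_stable snapshot
dm.setup()          # builds train/val/test datasets and applies the transforms
batch = next(iter(dm.train_dataloader()))
```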

def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
