diff --git a/CHANGELOG.md b/CHANGELOG.md index 1feddd64..ebe1a315 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ * Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72) * Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86) * Remove remaining errors from PDB dataset change -* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) +* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) as well as time-based splits [#89](https://github.com/a-r-j/ProteinWorkshop/pull/89) ### Models diff --git a/proteinworkshop/config/dataset/pdb.yaml b/proteinworkshop/config/dataset/pdb.yaml index 991917d7..56540cc3 100644 --- a/proteinworkshop/config/dataset/pdb.yaml +++ b/proteinworkshop/config/dataset/pdb.yaml @@ -10,10 +10,10 @@ datamodule: pdb_dataset: _target_: "proteinworkshop.datasets.pdb_dataset.PDBData" - fraction: 1.0 # Fraction of dataset to use + fraction: 0.01 # Fraction of dataset to use molecule_type: "protein" # Type of molecule for which to select experiment_types: ["diffraction", "NMR", "EM", "other"] # All experiment types - max_length: 1000 # Exclude polypeptides greater than length 1000 + max_length: 150 # Exclude polypeptides greater than length 1000 min_length: 10 # Exclude peptides of length 10 oligomeric_min: 1 # Include only monomeric proteins oligomeric_max: 5 # Include up to 5-meric proteins @@ -24,6 +24,9 @@ datamodule: remove_non_standard_residues: True # Include only proteins containing standard amino acid residues remove_pdb_unavailable: True # Include only proteins that are available to download train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits - split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random" - split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random") + split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other options are "random" and "time_cutoff" + split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type!="sequence_similarity") overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten + split_time_frames: null # Time-cutoffs for train, val and test set (argument is ignored if split_type!="time_cutoff") - e.g., ["2020-01-01", "2021-01-01", "2023-03-01"] + + diff --git a/proteinworkshop/datasets/pdb_dataset.py b/proteinworkshop/datasets/pdb_dataset.py index 1de4abb7..0daa9d95 100644 --- a/proteinworkshop/datasets/pdb_dataset.py +++ b/proteinworkshop/datasets/pdb_dataset.py @@ -2,6 +2,7 @@ import hydra import omegaconf +import numpy as np import os import pandas as pd import pathlib @@ -30,9 +31,11 @@ def __init__( remove_non_standard_residues: bool, remove_pdb_unavailable: bool, train_val_test: List[float], - split_type: Literal["sequence_similarity", "random"], - split_sequence_similiarity: int, - overwrite_sequence_clusters: bool + split_type: Literal["sequence_similarity", "time_cutoff", "random"] = "random", + split_sequence_similiarity: Optional[int] = None, + overwrite_sequence_clusters: Optional[bool] = False, + split_time_frames: Optional[List[str]] = None, + ): self.fraction = fraction self.molecule_type = molecule_type @@ -52,6 +55,11 @@ def __init__( self.split_type = split_type self.split_sequence_similarity = split_sequence_similiarity self.overwrite_sequence_clusters = overwrite_sequence_clusters + if self.split_type == "time_cutoff": + try: + self.split_time_frames = [np.datetime64(date) for date in split_time_frames] + except: + raise TypeError(f"{split_time_frames} does not contain valid dates for np.datetime64 format") self.splits = ["train", "val", "test"] def create_dataset(self): @@ -128,9 +136,15 @@ def create_dataset(self): elif self.split_type == "sequence_similarity": log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...") log.info(f"Using {self.split_sequence_similarity} sequence similarity for split") - pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True) - splits = pdb_manager.split_clusters( - pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters) + pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True, + overwrite = self.overwrite_sequence_clusters) + splits = pdb_manager.split_clusters(pdb_manager.df, update=True) + + elif self.split_type == "time_cutoff": + log.info(f"Splitting dataset via time_cutoff split into {self.train_val_test}...") + log.info(f"Using {self.split_time_frames} dates for split") + pdb_manager.split_time_frames = self.split_time_frames + splits = pdb_manager.split_by_deposition_date(df=pdb_manager.df, update=True) log.info(splits["train"]) return splits