Skip to content

Commit

Permalink
Merge pull request #89 from a-r-j/time_splits
Browse files Browse the repository at this point in the history
changed time_cutoff option
  • Loading branch information
amorehead authored Mar 26, 2024
2 parents d5fbab7 + 7a3875d commit a891666
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)
* Remove remaining errors from PDB dataset change
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88)
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) as well as time-based splits [#89](https://github.com/a-r-j/ProteinWorkshop/pull/89)

### Models

Expand Down
11 changes: 7 additions & 4 deletions proteinworkshop/config/dataset/pdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ datamodule:

pdb_dataset:
_target_: "proteinworkshop.datasets.pdb_dataset.PDBData"
fraction: 1.0 # Fraction of dataset to use
fraction: 0.01 # Fraction of dataset to use
molecule_type: "protein" # Type of molecule for which to select
experiment_types: ["diffraction", "NMR", "EM", "other"] # All experiment types
max_length: 1000 # Exclude polypeptides greater than length 1000
max_length: 150 # Exclude polypeptides greater than length 1000
min_length: 10 # Exclude peptides of length 10
oligomeric_min: 1 # Include only monomeric proteins
oligomeric_max: 5 # Include up to 5-meric proteins
Expand All @@ -24,6 +24,9 @@ datamodule:
remove_non_standard_residues: True # Include only proteins containing standard amino acid residues
remove_pdb_unavailable: True # Include only proteins that are available to download
train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random")
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other options are "random" and "time_cutoff"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type!="sequence_similarity")
overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten
split_time_frames: null # Time-cutoffs for train, val and test set (argument is ignored if split_type!="time_cutoff") - e.g., ["2020-01-01", "2021-01-01", "2023-03-01"]


26 changes: 20 additions & 6 deletions proteinworkshop/datasets/pdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hydra
import omegaconf
import numpy as np
import os
import pandas as pd
import pathlib
Expand Down Expand Up @@ -30,9 +31,11 @@ def __init__(
remove_non_standard_residues: bool,
remove_pdb_unavailable: bool,
train_val_test: List[float],
split_type: Literal["sequence_similarity", "random"],
split_sequence_similiarity: int,
overwrite_sequence_clusters: bool
split_type: Literal["sequence_similarity", "time_cutoff", "random"] = "random",
split_sequence_similiarity: Optional[int] = None,
overwrite_sequence_clusters: Optional[bool] = False,
split_time_frames: Optional[List[str]] = None,

):
self.fraction = fraction
self.molecule_type = molecule_type
Expand All @@ -52,6 +55,11 @@ def __init__(
self.split_type = split_type
self.split_sequence_similarity = split_sequence_similiarity
self.overwrite_sequence_clusters = overwrite_sequence_clusters
if self.split_type == "time_cutoff":
try:
self.split_time_frames = [np.datetime64(date) for date in split_time_frames]
except:
raise TypeError(f"{split_time_frames} does not contain valid dates for np.datetime64 format")
self.splits = ["train", "val", "test"]

def create_dataset(self):
Expand Down Expand Up @@ -128,9 +136,15 @@ def create_dataset(self):
elif self.split_type == "sequence_similarity":
log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...")
log.info(f"Using {self.split_sequence_similarity} sequence similarity for split")
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True)
splits = pdb_manager.split_clusters(
pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters)
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True,
overwrite = self.overwrite_sequence_clusters)
splits = pdb_manager.split_clusters(pdb_manager.df, update=True)

elif self.split_type == "time_cutoff":
log.info(f"Splitting dataset via time_cutoff split into {self.train_val_test}...")
log.info(f"Using {self.split_time_frames} dates for split")
pdb_manager.split_time_frames = self.split_time_frames
splits = pdb_manager.split_by_deposition_date(df=pdb_manager.df, update=True)

log.info(splits["train"])
return splits
Expand Down

0 comments on commit a891666

Please sign in to comment.