diff --git a/CHANGELOG.md b/CHANGELOG.md index 31dd4772..c4f9c89b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,8 @@ From v3.0.2 to v3.1.0 cleared from the data set with the `clear` parameter of `dropDescriptorSets`. - Added a proper API for parallelization backend selection and configuration (see documentation of `ParallelGenerator` and `JITParallelGenerator` for more information). +- Clusters can now be added to a `MoleculeTable` with `addClusters` and retrieved with + `getClusters`, similar to scaffolds. ## Removed Features diff --git a/qsprpred/data/chem/clustering.py b/qsprpred/data/chem/clustering.py index cad3c663..eba0bd9f 100644 --- a/qsprpred/data/chem/clustering.py +++ b/qsprpred/data/chem/clustering.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd from rdkit import Chem, DataStructs +from rdkit.Chem import Mol from rdkit.SimDivFilters import rdSimDivPickers from .scaffolds import BemisMurckoRDKit, Scaffold @@ -11,14 +12,39 @@ from ..descriptors.fingerprints import Fingerprint, MorganFP from ...logs import logger +from qsprpred.data.processing.mol_processor import MolProcessorWithID -class MoleculeClusters(ABC): + +class MoleculeClusters(MolProcessorWithID, ABC): """ Abstract base class for clustering molecules. Attributes: nClusters (int): number of clusters """ + + def __call__(self, mols: list[str | Mol], props, *args, **kwargs): + """ + Calculate the clusters for a list of molecules. + + Args: + mol (str | Mol): SMILES or RDKit molecule to calculate the cluster for. + + Returns: + list of cluster index for each molecule + """ + if isinstance(mols[0], Mol): + mols = [Chem.MolToSmiles(mol) for mol in mols] + + clusters = self.get_clusters(mols) + + # map clusters to molecules + output = np.array([-1]*len(mols)) + for cluster_idx, molecule_idxs in clusters.items(): + output[molecule_idxs] = cluster_idx + + return pd.Series(output, index=props[self.idProp]) + @abstractmethod def get_clusters(self, smiles_list: list[str]) -> dict: @@ -41,6 +67,13 @@ def _set_nClusters(self, N: int) -> None: f"Number of initial clusters is too small to combine them well,\ it has set to {self.nClusters}" ) + + def supportsParallel(self) -> bool: + return False + + @abstractmethod + def __str__(self): + pass class RandomClusters(MoleculeClusters): @@ -50,9 +83,13 @@ class RandomClusters(MoleculeClusters): Attributes: seed (int): random seed nClusters (int): number of clusters + id_prop (str): name of the property to be used as ID """ - def __init__(self, seed: int = 42, n_clusters: int | None = None): + def __init__( + self, seed: int = 42, n_clusters: int | None = None, id_prop: str | None = None + ): + super().__init__(id_prop=id_prop) self.seed = seed self.nClusters = n_clusters @@ -79,6 +116,9 @@ def get_clusters(self, smiles_list: list[str]) -> dict: clusters[i % self.nClusters].append(index) return clusters + + def __str__(self): + return "RandomClusters" class ScaffoldClusters(MoleculeClusters): @@ -87,10 +127,13 @@ class ScaffoldClusters(MoleculeClusters): Attributes: scaffold (Scaffold): scaffold generator + id_prop (str): name of the property to be used as ID """ - def __init__(self, scaffold: Scaffold = BemisMurckoRDKit()): - super().__init__() + def __init__( + self, scaffold: Scaffold = BemisMurckoRDKit(), id_prop: str | None = None + ): + super().__init__(id_prop=id_prop) self.scaffold = scaffold def get_clusters(self, smiles_list: list[str]) -> dict: @@ -126,14 +169,18 @@ def get_clusters(self, smiles_list: list[str]) -> dict: clusters[unique_scaffolds.index(scaffold)].append(i) return clusters + + def __str__(self): + return f"ScaffoldClusters_{self.scaffold}" class FPSimilarityClusters(MoleculeClusters): def __init__( self, fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048), + id_prop: str | None = None, ) -> None: - super().__init__() + super().__init__(id_prop=id_prop) self.fp_calculator = fp_calculator def get_clusters(self, smiles_list: list[str]) -> dict: @@ -187,6 +234,7 @@ class FPSimilarityMaxMinClusters(FPSimilarityClusters): nClusters (int): number of clusters seed (int): random seed initialCentroids (list): list of indices of initial cluster centroids + id_prop (str): name of the property to be used as ID """ def __init__( @@ -195,8 +243,9 @@ def __init__( seed: int | None = None, initial_centroids: list[str] | None = None, fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048), + id_prop: str | None = None, ): - super().__init__(fp_calculator=fp_calculator) + super().__init__(fp_calculator=fp_calculator, id_prop=id_prop) self.nClusters = n_clusters self.seed = seed self.initialCentroids = initial_centroids @@ -213,7 +262,7 @@ def _get_centroids(self, fps: list) -> list: """ self._set_nClusters(len(fps)) picker = rdSimDivPickers.MaxMinPicker() - centroid_indices = picker.LazyBitVectorPick( + self.centroid_indices = picker.LazyBitVectorPick( fps, len(fps), self.nClusters, @@ -221,7 +270,10 @@ def _get_centroids(self, fps: list) -> list: seed=self.seed if self.seed is not None else -1, ) - return centroid_indices + return self.centroid_indices + + def __str__(self): + return "FPSimilarityMaxMinClusters" class FPSimilarityLeaderPickerClusters(FPSimilarityClusters): @@ -231,14 +283,16 @@ class FPSimilarityLeaderPickerClusters(FPSimilarityClusters): Attributes: fp_calculator (FingerprintSet): fingerprint calculator similarity_threshold (float): similarity threshold + id_prop (str): name of the property to be used as ID """ def __init__( self, similarity_threshold: float = 0.7, fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048), + id_prop: str | None = None, ): - super().__init__(fp_calculator=fp_calculator) + super().__init__(fp_calculator=fp_calculator, id_prop=id_prop) self.similarityThreshold = similarity_threshold self.fpCalculator = fp_calculator @@ -247,8 +301,11 @@ def _get_centroids(self, fps: list) -> list: Get cluster centroids with LeaderPicker algorithm. """ picker = rdSimDivPickers.LeaderPicker() - centroid_indices = picker.LazyBitVectorPick( + self.centroid_indices = picker.LazyBitVectorPick( fps, len(fps), self.similarityThreshold ) - return centroid_indices + return self.centroid_indices + + def __str__(self): + return "FPSimilarityLeaderPickerClusters" diff --git a/qsprpred/data/chem/scaffolds.py b/qsprpred/data/chem/scaffolds.py index e27b0533..9b1817e0 100644 --- a/qsprpred/data/chem/scaffolds.py +++ b/qsprpred/data/chem/scaffolds.py @@ -83,7 +83,7 @@ def __init__( self, real_bemismurcko: bool = True, use_csk: bool = False, - id_prop: bool | None = None, + id_prop: str | None = None, ): """ Initialize the scaffold generator. diff --git a/qsprpred/data/chem/tests.py b/qsprpred/data/chem/tests.py index f1e05d4d..d35bd676 100644 --- a/qsprpred/data/chem/tests.py +++ b/qsprpred/data/chem/tests.py @@ -3,6 +3,7 @@ from ... import TargetTasks from ...data import QSPRDataset from ...data.chem.scaffolds import BemisMurckoRDKit, BemisMurcko +from ...data.chem.clustering import RandomClusters, FPSimilarityMaxMinClusters, FPSimilarityLeaderPickerClusters, ScaffoldClusters from ...utils.testing.base import QSPRTestCase from ...utils.testing.path_mixins import DataSetsPathMixIn @@ -40,6 +41,31 @@ def testScaffoldAdd(self, _, scaffold): # for mol in scaffs[f"Scaffold_{scaffold}_RDMol"]: # self.assertTrue(isinstance(mol, Chem.rdchem.Mol)) +class TestClusters(DataSetsPathMixIn, QSPRTestCase): + """Test calculation of clusters.""" + + def setUp(self): + """Create a test dataset.""" + super().setUp() + self.setUpPaths() + self.dataset = self.createLargeTestDataSet(self.__class__.__name__) + + @parameterized.expand( + [ + ("Random", RandomClusters()), + ("FPSimilarityMaxMin", FPSimilarityMaxMinClusters()), + ("FPSimilarityLeaderPicker", FPSimilarityLeaderPickerClusters()), + ("Scaffold", ScaffoldClusters(BemisMurckoRDKit())), + ] + ) + def testClusterAdd(self, _, cluster): + """Test the adding and getting of clusters.""" + self.dataset.addClusters([cluster]) + clusters = self.dataset.getClusters() + self.assertEqual(clusters.shape, (len(self.dataset), 1)) + self.dataset.addClusters([cluster], recalculate=True) + self.assertEqual(clusters.shape, (len(self.dataset), 1)) + class TestStandardizers(DataSetsPathMixIn, QSPRTestCase): """Test the standardizers.""" diff --git a/qsprpred/data/tables/mol.py b/qsprpred/data/tables/mol.py index 4e210c64..47a1e803 100644 --- a/qsprpred/data/tables/mol.py +++ b/qsprpred/data/tables/mol.py @@ -982,7 +982,7 @@ def createScaffoldGroups(self, mols_per_group: int = 10): size. Args: - mols_per_group (int): Number of molecules per scaffold group. + mols_per_group (int): number of molecules per scaffold group. """ scaffolds = self.getScaffolds(include_mols=False) for scaffold in scaffolds.columns: @@ -1028,6 +1028,66 @@ def hasScaffoldGroups(self): > 0 ) + def addClusters( + self, + clusters: list["MoleculeClusters"], + recalculate: bool = False, + ): + """Add clusters to the data frame. + + A new column is created that contains the identifier of the corresponding + cluster calculator. + + Args: + clusters (list): list of `MoleculeClusters` calculators. + recalculate (bool): Whether to recalculate clusters even if they are + already present in the data frame. + """ + for cluster in clusters: + if not recalculate and f"Cluster_{cluster}" in self.df.columns: + continue + for clusters in self.processMols(cluster): + self.df.loc[clusters.index, f"Cluster_{cluster}"] = clusters.values + + + def getClusterNames( + self, clusters: list["MoleculeClusters"] | None = None + ): + """Get the names of the clusters in the data frame. + + Returns: + list: List of cluster names. + """ + all_names = [ + col + for col in self.df.columns + if col.startswith("Cluster_") + ] + if clusters: + wanted = [str(x) for x in clusters] + return [x for x in all_names if x.split("_", 1)[1] in wanted] + return all_names + + def getClusters( + self, clusters: list["MoleculeClusters"] | None = None + ): + """Get the subset of the data frame that contains only clusters. + + Returns: + pd.DataFrame: Data frame containing only clusters. + """ + names = self.getClusterNames(clusters) + return self.df[names] + + @property + def hasClusters(self): + """Check whether the data frame contains clusters. + + Returns: + bool: Whether the data frame contains clusters. + """ + return len(self.getClusterNames()) > 0 + def standardizeSmiles(self, smiles_standardizer, drop_invalid=True): """Apply smiles_standardizer to the compounds in parallel