Skip to content

Commit

Permalink
Merge branch 'refactor/clustering' into 'dev'
Browse files Browse the repository at this point in the history
Refactor/clustering

See merge request cdd/QSPRpred!186
  • Loading branch information
martin-sicho committed Jun 7, 2024
2 parents e430e99 + 06adc8e commit 047ce53
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ From v3.0.2 to v3.1.0
cleared from the data set with the `clear` parameter of `dropDescriptorSets`.
- Added a proper API for parallelization backend selection and configuration (see
documentation of `ParallelGenerator` and `JITParallelGenerator` for more information).
- Clusters can now be added to a `MoleculeTable` with `addClusters` and retrieved with
`getClusters`, similar to scaffolds.

## Removed Features

Expand Down
79 changes: 68 additions & 11 deletions qsprpred/data/chem/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,47 @@
import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import Mol
from rdkit.SimDivFilters import rdSimDivPickers

from .scaffolds import BemisMurckoRDKit, Scaffold
from .. import MoleculeTable
from ..descriptors.fingerprints import Fingerprint, MorganFP
from ...logs import logger

from qsprpred.data.processing.mol_processor import MolProcessorWithID

class MoleculeClusters(ABC):

class MoleculeClusters(MolProcessorWithID, ABC):
"""
Abstract base class for clustering molecules.
Attributes:
nClusters (int): number of clusters
"""

def __call__(self, mols: list[str | Mol], props, *args, **kwargs):
    """Assign a cluster index to every molecule in ``mols``.

    Args:
        mols (list[str | Mol]): SMILES strings and/or RDKit molecules to
            cluster. RDKit molecules are converted to SMILES before
            clustering.
        props: per-molecule properties; ``props[self.idProp]`` supplies the
            index of the returned series.
        *args: unused.
        **kwargs: unused.

    Returns:
        pd.Series: cluster index of each molecule, in the order of ``mols``
            and indexed by ``props[self.idProp]``. Molecules that
            `get_clusters` does not assign to any cluster keep the value -1.
    """
    # Convert per element so that mixed input is handled and an empty list
    # does not fail on `mols[0]`.
    smiles = [
        Chem.MolToSmiles(mol) if isinstance(mol, Mol) else mol for mol in mols
    ]

    # -1 marks "not assigned"; a fixed integer dtype keeps the empty case
    # consistent with the non-empty one.
    output = np.full(len(smiles), -1, dtype=int)
    if smiles:
        # `get_clusters` maps cluster index -> list of molecule positions
        clusters = self.get_clusters(smiles)
        for cluster_idx, molecule_idxs in clusters.items():
            output[molecule_idxs] = cluster_idx

    return pd.Series(output, index=props[self.idProp])


@abstractmethod
def get_clusters(self, smiles_list: list[str]) -> dict:
Expand All @@ -41,6 +67,13 @@ def _set_nClusters(self, N: int) -> None:
f"Number of initial clusters is too small to combine them well,\
it has set to {self.nClusters}"
)

def supportsParallel(self) -> bool:
    """Report whether this processor supports parallel execution.

    Returns:
        bool: always ``False``.
    """
    # NOTE(review): presumably disabled because clustering needs to see all
    # molecules at once rather than independent chunks — confirm.
    return False

@abstractmethod
def __str__(self) -> str:
    """Return the name of this clustering method.

    Subclasses must implement this. The string is used by
    `MoleculeTable.addClusters` to build the ``Cluster_<name>`` column name.
    """
    pass


class RandomClusters(MoleculeClusters):
Expand All @@ -50,9 +83,13 @@ class RandomClusters(MoleculeClusters):
Attributes:
seed (int): random seed
nClusters (int): number of clusters
id_prop (str): name of the property to be used as ID
"""

def __init__(self, seed: int = 42, n_clusters: int | None = None):
def __init__(
self, seed: int = 42, n_clusters: int | None = None, id_prop: str | None = None
):
super().__init__(id_prop=id_prop)
self.seed = seed
self.nClusters = n_clusters

Expand All @@ -79,6 +116,9 @@ def get_clusters(self, smiles_list: list[str]) -> dict:
clusters[i % self.nClusters].append(index)

return clusters

def __str__(self) -> str:
    """Return the fixed name ``"RandomClusters"``.

    The name is constant; it does not vary with the seed or the number of
    clusters.
    """
    return "RandomClusters"


class ScaffoldClusters(MoleculeClusters):
Expand All @@ -87,10 +127,13 @@ class ScaffoldClusters(MoleculeClusters):
Attributes:
scaffold (Scaffold): scaffold generator
id_prop (str): name of the property to be used as ID
"""

def __init__(self, scaffold: Scaffold = BemisMurckoRDKit()):
super().__init__()
def __init__(
self, scaffold: Scaffold = BemisMurckoRDKit(), id_prop: str | None = None
):
super().__init__(id_prop=id_prop)
self.scaffold = scaffold

def get_clusters(self, smiles_list: list[str]) -> dict:
Expand Down Expand Up @@ -126,14 +169,18 @@ def get_clusters(self, smiles_list: list[str]) -> dict:
clusters[unique_scaffolds.index(scaffold)].append(i)

return clusters

def __str__(self) -> str:
    """Return ``"ScaffoldClusters_"`` followed by ``str(self.scaffold)``.

    The scaffold generator's own string form is embedded, so the name depends
    on which scaffold generator is used.
    """
    return f"ScaffoldClusters_{self.scaffold}"


class FPSimilarityClusters(MoleculeClusters):
def __init__(
self,
fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048),
id_prop: str | None = None,
) -> None:
super().__init__()
super().__init__(id_prop=id_prop)
self.fp_calculator = fp_calculator

def get_clusters(self, smiles_list: list[str]) -> dict:
Expand Down Expand Up @@ -187,6 +234,7 @@ class FPSimilarityMaxMinClusters(FPSimilarityClusters):
nClusters (int): number of clusters
seed (int): random seed
initialCentroids (list): list of indices of initial cluster centroids
id_prop (str): name of the property to be used as ID
"""

def __init__(
Expand All @@ -195,8 +243,9 @@ def __init__(
seed: int | None = None,
initial_centroids: list[str] | None = None,
fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048),
id_prop: str | None = None,
):
super().__init__(fp_calculator=fp_calculator)
super().__init__(fp_calculator=fp_calculator, id_prop=id_prop)
self.nClusters = n_clusters
self.seed = seed
self.initialCentroids = initial_centroids
Expand All @@ -213,15 +262,18 @@ def _get_centroids(self, fps: list) -> list:
"""
self._set_nClusters(len(fps))
picker = rdSimDivPickers.MaxMinPicker()
centroid_indices = picker.LazyBitVectorPick(
self.centroid_indices = picker.LazyBitVectorPick(
fps,
len(fps),
self.nClusters,
firstPicks=self.initialCentroids if self.initialCentroids else [],
seed=self.seed if self.seed is not None else -1,
)

return centroid_indices
return self.centroid_indices

def __str__(self) -> str:
    """Return the fixed name ``"FPSimilarityMaxMinClusters"``.

    The name is constant; it does not vary with the number of clusters, the
    seed, the initial centroids or the fingerprint calculator.
    """
    # NOTE(review): differently configured instances share this name, so they
    # would map to the same `Cluster_<name>` column — confirm this is intended.
    return "FPSimilarityMaxMinClusters"


class FPSimilarityLeaderPickerClusters(FPSimilarityClusters):
Expand All @@ -231,14 +283,16 @@ class FPSimilarityLeaderPickerClusters(FPSimilarityClusters):
Attributes:
fp_calculator (FingerprintSet): fingerprint calculator
similarity_threshold (float): similarity threshold
id_prop (str): name of the property to be used as ID
"""

def __init__(
self,
similarity_threshold: float = 0.7,
fp_calculator: Fingerprint = MorganFP(radius=3, nBits=2048),
id_prop: str | None = None,
):
super().__init__(fp_calculator=fp_calculator)
super().__init__(fp_calculator=fp_calculator, id_prop=id_prop)
self.similarityThreshold = similarity_threshold
self.fpCalculator = fp_calculator

Expand All @@ -247,8 +301,11 @@ def _get_centroids(self, fps: list) -> list:
Get cluster centroids with LeaderPicker algorithm.
"""
picker = rdSimDivPickers.LeaderPicker()
centroid_indices = picker.LazyBitVectorPick(
self.centroid_indices = picker.LazyBitVectorPick(
fps, len(fps), self.similarityThreshold
)

return centroid_indices
return self.centroid_indices

def __str__(self) -> str:
    """Return the fixed name ``"FPSimilarityLeaderPickerClusters"``.

    The name is constant; it does not vary with the similarity threshold or
    the fingerprint calculator.
    """
    # NOTE(review): differently configured instances share this name, so they
    # would map to the same `Cluster_<name>` column — confirm this is intended.
    return "FPSimilarityLeaderPickerClusters"
2 changes: 1 addition & 1 deletion qsprpred/data/chem/scaffolds.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self,
real_bemismurcko: bool = True,
use_csk: bool = False,
id_prop: bool | None = None,
id_prop: str | None = None,
):
"""
Initialize the scaffold generator.
Expand Down
26 changes: 26 additions & 0 deletions qsprpred/data/chem/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ... import TargetTasks
from ...data import QSPRDataset
from ...data.chem.scaffolds import BemisMurckoRDKit, BemisMurcko
from ...data.chem.clustering import RandomClusters, FPSimilarityMaxMinClusters, FPSimilarityLeaderPickerClusters, ScaffoldClusters
from ...utils.testing.base import QSPRTestCase
from ...utils.testing.path_mixins import DataSetsPathMixIn

Expand Down Expand Up @@ -40,6 +41,31 @@ def testScaffoldAdd(self, _, scaffold):
# for mol in scaffs[f"Scaffold_{scaffold}_RDMol"]:
# self.assertTrue(isinstance(mol, Chem.rdchem.Mol))

class TestClusters(DataSetsPathMixIn, QSPRTestCase):
    """Test calculation of clusters."""

    def setUp(self):
        """Create a test dataset."""
        super().setUp()
        self.setUpPaths()
        self.dataset = self.createLargeTestDataSet(self.__class__.__name__)

    @parameterized.expand(
        [
            ("Random", RandomClusters()),
            ("FPSimilarityMaxMin", FPSimilarityMaxMinClusters()),
            ("FPSimilarityLeaderPicker", FPSimilarityLeaderPickerClusters()),
            ("Scaffold", ScaffoldClusters(BemisMurckoRDKit())),
        ]
    )
    def testClusterAdd(self, _, cluster):
        """Test the adding and getting of clusters.

        Clusters are added once and then recalculated; after each step the
        data set must expose exactly one cluster column with one row per
        molecule.
        """
        expected_shape = (len(self.dataset), 1)

        self.dataset.addClusters([cluster])
        clusters = self.dataset.getClusters()
        self.assertEqual(clusters.shape, expected_shape)

        self.dataset.addClusters([cluster], recalculate=True)
        # Fetch again: asserting on the frame obtained before the
        # recalculation would not check its result at all.
        clusters = self.dataset.getClusters()
        self.assertEqual(clusters.shape, expected_shape)


class TestStandardizers(DataSetsPathMixIn, QSPRTestCase):
"""Test the standardizers."""
Expand Down
62 changes: 61 additions & 1 deletion qsprpred/data/tables/mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def createScaffoldGroups(self, mols_per_group: int = 10):
size.
Args:
mols_per_group (int): Number of molecules per scaffold group.
mols_per_group (int): number of molecules per scaffold group.
"""
scaffolds = self.getScaffolds(include_mols=False)
for scaffold in scaffolds.columns:
Expand Down Expand Up @@ -1028,6 +1028,66 @@ def hasScaffoldGroups(self):
> 0
)

def addClusters(
    self,
    clusters: list["MoleculeClusters"],
    recalculate: bool = False,
):
    """Add clusters to the data frame.

    For every calculator a column named ``Cluster_<str(calculator)>`` is
    created (or overwritten). Its values are whatever the calculator returns
    for each molecule when run through `processMols`.

    Args:
        clusters (list): list of `MoleculeClusters` calculators.
        recalculate (bool): Whether to recalculate clusters even if they are
            already present in the data frame.
    """
    for cluster in clusters:
        column = f"Cluster_{cluster}"
        if not recalculate and column in self.df.columns:
            continue
        # Use a distinct name for the per-chunk result so that the `clusters`
        # argument being iterated is not rebound inside its own loop.
        for assignments in self.processMols(cluster):
            self.df.loc[assignments.index, column] = assignments.values


def getClusterNames(
self, clusters: list["MoleculeClusters"] | None = None
):
"""Get the names of the clusters in the data frame.
Returns:
list: List of cluster names.
"""
all_names = [
col
for col in self.df.columns
if col.startswith("Cluster_")
]
if clusters:
wanted = [str(x) for x in clusters]
return [x for x in all_names if x.split("_", 1)[1] in wanted]
return all_names

def getClusters(
self, clusters: list["MoleculeClusters"] | None = None
):
"""Get the subset of the data frame that contains only clusters.
Returns:
pd.DataFrame: Data frame containing only clusters.
"""
names = self.getClusterNames(clusters)
return self.df[names]

@property
def hasClusters(self):
    """Check whether the data frame contains clusters.

    Returns:
        bool: ``True`` if `getClusterNames` reports at least one cluster
            column, ``False`` otherwise.
    """
    # `getClusterNames` returns a list, so its truthiness is "non-empty".
    return bool(self.getClusterNames())

def standardizeSmiles(self, smiles_standardizer, drop_invalid=True):
"""Apply smiles_standardizer to the compounds in parallel
Expand Down

0 comments on commit 047ce53

Please sign in to comment.