-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclustering_step.py
128 lines (107 loc) · 4.13 KB
/
clustering_step.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
from rdkit import Chem
import argparse
import sys
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
def dihedrals_angles(mol, circular=True):
    """Measure one dihedral angle per rotatable bond of a molecule.

    Rotatable bonds are located with a SMARTS pattern matching four atoms
    ``k~i-j~l`` whose central ``i-j`` bond is single and acyclic. Only the
    first match found for each central bond is kept, regardless of the
    direction in which the bond was matched.

    Args:
        mol: RDKit molecule; the angles are read from its first conformer.
        circular: When true, each angle contributes two values,
            ``cos(angle)`` followed by ``sin(angle)``. When false, the raw
            angle in radians is returned instead.

    Returns:
        A ``(thetas, dih_points)`` tuple. ``thetas`` is the flat list of
        angle features and ``dih_points`` is the list of ``(k, i, j, l)``
        atom quadruples shifted to 1-based numbering.

    Raises:
        SystemExit: If the molecule has no rotatable bond (a message is
            printed first).
    """
    # Imported explicitly: ``from rdkit import Chem`` does not load the
    # rdMolTransforms submodule, so ``Chem.rdMolTransforms`` only resolves
    # if some other module happened to import it beforehand.
    from rdkit.Chem import rdMolTransforms

    pattern = Chem.MolFromSmarts("[!#1]~[!$(*#*)&!D1]-!@[!$(*#*)&!D1]~[!#1]")

    # A bond matched as (i, j) and as (j, i) is the same bond; a frozenset
    # key makes the de-duplication independent of the match direction.
    seen_bonds = set()
    rot_bonds = []
    for k, i, j, l in mol.GetSubstructMatches(pattern):
        bond = frozenset((i, j))
        if bond not in seen_bonds:
            seen_bonds.add(bond)
            rot_bonds.append((k, i, j, l))

    if not rot_bonds:
        print('Not rotatable bonds found. Exiting clustering protocol ...')
        sys.exit()

    conformer = mol.GetConformers()[0]
    thetas = []
    dih_points = []
    for k, i, j, l in rot_bonds:
        dih_points.append((k + 1, i + 1, j + 1, l + 1))
        dih = rdMolTransforms.GetDihedralRad(conformer, k, i, j, l)
        if circular:
            # cos/sin pair removes the discontinuity of the periodic angle.
            thetas.extend((np.cos(dih), np.sin(dih)))
        else:
            thetas.append(dih)
    return thetas, dih_points
def get_dPCA_mat(mols):
    """Project the conformers' dihedral features onto principal components.

    Each non-``None`` molecule is turned into a row of cos/sin dihedral
    features, the columns are standardised, and a PCA is fitted with
    ``n_components=0.75``.

    Args:
        mols: Iterable of RDKit molecules; ``None`` entries are skipped.

    Returns:
        The PCA-transformed matrix, one row per molecule that was not
        ``None``.

    Raises:
        ValueError: If ``mols`` contains no usable molecule.
    """
    dih_mat = [
        dihedrals_angles(mol, circular=True)[0]
        for mol in mols
        if mol is not None
    ]
    if not dih_mat:
        raise ValueError('No valid conformers to build the dihedral matrix from')

    X = StandardScaler().fit_transform(np.array(dih_mat))  # normalizing the features
    pca = PCA(n_components=0.75)
    X_pca = pca.fit_transform(X)
    print('N confs: {} and N features: {}'.format(X_pca.shape[0], X_pca.shape[1]))
    return X_pca
def ClusterIndices(clustNum, labels_array):
    """Return the positions in ``labels_array`` whose label equals ``clustNum``."""
    membership = labels_array == clustNum
    hits = np.where(membership)
    # np.where on a condition returns one index array per axis; keep the first.
    return hits[0]
def select_lowest_energy_centroid(model, energy, nclusters):
    """Pick the lowest-energy member of every cluster.

    Despite the name, no distance to the cluster centre is involved: the
    representative of a cluster is simply its member with the smallest
    energy (the first one in case of a tie).

    Args:
        model: Fitted clustering model exposing a ``labels_`` sequence with
            one cluster label per sample.
        energy: Sequence of energies indexed like ``model.labels_``.
        nclusters: Number of cluster labels to inspect (``0 .. nclusters-1``).

    Returns:
        List of sample indices (plain ``int``), one per non-empty cluster,
        in increasing cluster-label order. Clusters without members
        contribute nothing.
    """
    labels = np.asarray(model.labels_)
    min_clusters = []
    for cluster in range(nclusters):
        members = np.where(labels == cluster)[0]
        if members.size == 0:
            # Nothing to choose from; never reuse a previous cluster's pick.
            continue
        # min() needs no magic sentinel, so arbitrarily large energies work.
        lowest = min(members, key=lambda idx: energy[idx])
        min_clusters.append(int(lowest))
    return min_clusters
def get_representatives(dPCA_mat, energy, n_clusters=4, random_state=None):
    """Cluster the conformers with k-means and pick one per cluster.

    Args:
        dPCA_mat: Feature matrix with one row per conformer.
        energy: Sequence of conformer energies, indexed like the rows of
            ``dPCA_mat``.
        n_clusters: Requested number of clusters (default 4). It is reduced
            to the number of rows when fewer conformers are available,
            because k-means cannot form more clusters than samples.
        random_state: Forwarded to ``KMeans``; pass an integer for
            reproducible clusters.

    Returns:
        List of row indices, the lowest-energy conformer of each cluster as
        chosen by ``select_lowest_energy_centroid``. Empty when
        ``dPCA_mat`` has no rows.
    """
    n_samples = len(dPCA_mat)
    if n_samples == 0:
        return []
    n_clusters = min(n_clusters, n_samples)

    model = KMeans(n_clusters=n_clusters, random_state=random_state)
    model.fit(dPCA_mat)
    # The representative of a cluster is its lowest-energy member, not the
    # point closest to the centroid.
    return select_lowest_energy_centroid(model, energy, n_clusters)
def best_energy_id(mols):
    """Find the conformer with the lowest PharmScreen energy.

    The energy is read from the ``PHA_CONF_ENERGY`` property, taking the
    text before the first space as a float. ``None`` entries (records the
    SD reader could not parse) are ignored.

    Args:
        mols: Iterable of RDKit molecules carrying the ``PHA_CONF_ID`` and
            ``PHA_CONF_ENERGY`` properties.

    Returns:
        A ``(conf_id, mol)`` tuple for the lowest-energy conformer; the
        first one wins in case of a tie.

    Raises:
        SystemExit: If a molecule lacks the ``PHA_CONF_ID`` property (a
            message is printed and the exit status is non-zero).
        ValueError: If ``mols`` yields no usable conformer.
    """
    # Infinity instead of a finite magic number, so any real energy compares lower.
    en_min = float('inf')
    id_mol_min = None
    low_en_mol = None
    for mol in mols:
        if mol is None:
            continue
        if not mol.HasProp('PHA_CONF_ID'):
            print(" The input sdf has not include 'PHA_CONF_ID', be sure that you are using the proper Pharmscreen output file")
            # sys.exit rather than the interactive-only quit(); non-zero
            # status so callers can detect the failure.
            sys.exit(1)
        # energy provided by PharmScreen
        en = float(mol.GetProp('PHA_CONF_ENERGY').split(' ')[0])
        if en < en_min:
            en_min = en
            id_mol_min = mol.GetProp('PHA_CONF_ID')
            low_en_mol = mol
    if low_en_mol is None:
        raise ValueError('No conformer with a usable PHA_CONF_ENERGY was found')
    return id_mol_min, low_en_mol
def cluster_conformers(input, output):
    """Write a reduced conformer ensemble selected by dihedral clustering.

    The lowest-energy conformer of the input SD file is written first. The
    remaining conformers are clustered on their dihedral principal
    components and the representative of each cluster is appended to the
    output file.

    Args:
        input: Path of the multi-conformer SD file (hydrogens are kept).
        output: Path of the SD file to write.

    Returns:
        List with the ``PHA_CONF_ID`` of every cluster representative. The
        lowest-energy conformer is written to the file but its id is not
        part of this list. The list is empty when the input holds a single
        conformer.
    """
    # The supplier yields None for records it cannot parse; drop them once
    # so the energy list and the feature matrix stay aligned row by row.
    supplier = Chem.SDMolSupplier(input, removeHs=False)
    mols = [mol for mol in supplier if mol is not None]

    # Resolved before the writer is opened, so a bad input does not leave
    # an empty output file behind.
    best_conf_id, le_mol = best_energy_id(mols)

    ensemble_PHA_ids = []
    w = Chem.SDWriter(output)
    try:
        w.write(le_mol)
        print('Lowest energy conformer PHA id: ', best_conf_id)
        print("\n")

        mols2 = [mol for mol in mols if mol.GetProp('PHA_CONF_ID') != best_conf_id]
        if mols2:
            energy = [float(mol.GetProp('PHA_CONF_ENERGY').split(' ')[0]) for mol in mols2]
            dPCA_mat = get_dPCA_mat(mols2)
            representatives = get_representatives(dPCA_mat, energy)
            for conf_idx in representatives:
                ensemble_PHA_ids.append(mols2[conf_idx].GetProp('PHA_CONF_ID'))
                w.write(mols2[conf_idx])
    finally:
        # Closing flushes the buffered records even on an early exit.
        w.close()
    return ensemble_PHA_ids
def main():
    """Command-line entry point: cluster the input SD file and report the ids."""
    cli = argparse.ArgumentParser(
        description='Script to cluster conformers from multi-coformer pharmscreen sdf'
    )
    # Both paths are mandatory; argparse stores them as ``i`` and ``o``.
    for flag, text in (('-i', 'sdf input file'), ('-o', 'sdf output file')):
        cli.add_argument(flag, required=True, help=text)
    options = cli.parse_args()

    selected = cluster_conformers(options.i, options.o)
    print("Representatives ids: ", selected)


if __name__ == '__main__':
    main()