Skip to content

Commit

Permalink
fix MSRC
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Oct 8, 2023
1 parent bc39a4c commit 663a650
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions ConservedWaterSearch/water_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def save_clustering_options(
print(self.nsnaps, file=f)
print(clustering_type, file=f)
print(clustering_algorithm, file=f)
if type(options) != list:
if type(options) is not list:
raise Exception("option has to be a list")
for i in options:
if type(i) != list and type(i) != np.ndarray:
if type(i) is not list and type(i) is not np.ndarray:
print(i, file=f)
elif type(i) == np.ndarray:
elif isinstance(i, np.ndarray):
print(*list(i), file=f)
else:
print(*i, file=f)
Expand Down Expand Up @@ -317,6 +317,7 @@ def _add_water_solutions(
oxygens and two hydrogens and water classification.
"""
for i in waters:
print(i)
self._waterO.append(i[0])
if len(i) > 2:
self._waterH1.append(i[1])
Expand All @@ -335,13 +336,13 @@ def __check_cls_alg_and_whichH(self, clustering_algorithm, whichH):
def __check_and_setup_single(self, xis, whichH, clustering_algorithm, minsamp):
if minsamp is None:
minsamp = int(self.numbpct_oxygen * self.nsnaps)
elif type(minsamp) != int:
elif type(minsamp) is not int:
raise Exception("minsamp must be an int")
elif minsamp > self.nsnaps or minsamp <= 0:
raise Exception("minsamp must be between 0 and nsnaps")
if xis is None:
xis = 0.05
elif type(xis) != float:
elif type(xis) is not float:
raise Exception("xi must be a float")
elif xis < 0 or xis > 1:
raise Exception("xis should be between 0 and 1")
Expand All @@ -358,13 +359,13 @@ def __check_and_setup_MSRC(
self, lower_minsamp_pct, every_minsamp, xis, whichH, clustering_algorithm
):
for i in xis:
if type(i) != float:
if type(i) is not float:
raise Exception("xis must contain floats")
if i > 1 or i < 0:
raise Exception("xis should be between 0 and 1")
if lower_minsamp_pct > 1.0000001 or lower_minsamp_pct < 0:
raise Exception("lower_misamp_pct must be between 0 and 1")
if type(every_minsamp) != int:
if type(every_minsamp) is not int:
raise Exception("every_minsamp must be integer")
if every_minsamp <= 0 or every_minsamp > self.nsnaps:
raise Exception("every_minsamp must be 0<every_minsamp<=nsnaps")
Expand All @@ -379,18 +380,16 @@ def __check_and_setup_MSRC(
)
if clustering_algorithm == "HDBSCAN":
lxis: list[float] = [0.0]
allow_single = False
elif clustering_algorithm == "OPTICS":
lxis = xis
allow_single = None
if self.save_intermediate_results:
self.save_clustering_options(
"multi_stage_reclustering",
clustering_algorithm,
[lower_minsamp_pct, every_minsamp, lxis],
whichH,
)
return minsamps, lxis, allow_single
return minsamps, lxis

def __scan_clustering_params(
self,
Expand All @@ -399,18 +398,19 @@ def __scan_clustering_params(
minsamps,
lxis,
whichH,
allow_single,
restart: bool = True,
H1=None,
H2=None,
):
found: bool = False if len(Odata) < self.nsnaps else True
for wt in whichH:
wta = [wt]
print(wt)
found: bool = False if len(Odata) < self.nsnaps else True
while found:
found = False
# loop over minsamps- from N(snapshots) to 0.75*N(snapshots)
for i in minsamps:
print(i)
if clustering_algorithm == "OPTICS":
clust: OPTICS | HDBSCAN = OPTICS(min_samples=int(i), n_jobs=self.njobs) # type: ignore
clust.fit(Odata)
Expand All @@ -426,7 +426,9 @@ def __scan_clustering_params(
),
cluster_selection_method="eom",
n_jobs=self.njobs,
allow_single_cluster=allow_single, # type: ignore
allow_single_cluster=True
if len(Odata) < self.nsnaps * 2
else False,
)
clust.fit(Odata)
clusters: np.ndarray = clust.labels_
Expand Down Expand Up @@ -464,9 +466,8 @@ def __scan_clustering_params(
plt = __check_mpl_installation()
plt.close(ff)
if len(waters) > 0:
print("waters found")
found = True
if clustering_algorithm == "HDBSCAN" and allow_single:
allow_single = False
if wt == "onlyO":
Odata = self._delete_data(idcs, Odata)
else:
Expand All @@ -475,19 +476,15 @@ def __scan_clustering_params(
if self.save_intermediate_results:
self.__save_intermediate_results()
i = i - 1
print("breaking inner for loop")
break
print("to break from for loop: found, restart ", found, restart)
if (found and restart) or len(Odata) < self.nsnaps:
break
# check if size of remaining data set is bigger then number of snapshots
if (
clustering_algorithm == "HDBSCAN"
and found is False
and allow_single is False
):
found = True
allow_single = True
if len(Odata) < self.nsnaps:
found = False
if len(Odata) < self.nsnaps or restart is False:
break
print("end of while loop", found)
if (self.debugH == 1 or self.debugO == 1) and self.plotend:
plt = __check_mpl_installation()
plt.show()
Expand Down Expand Up @@ -516,10 +513,13 @@ def __save_intermediate_results(self):
)

def __check_data(self, Odata, H1, H2, whichH):
if H1 is None or H2 is None and whichH != ["onlyO"]:
raise Exception("H1 and H2 have to be provided for non oxygen only search.")
if len(Odata) != len(H1) or len(Odata) != len(H2) or len(H1) != len(H2):
raise Exception("Odata, H1 and H2 have to be of same length")
if (H1 is None or H2 is None) and "onlyO" not in whichH:
raise Exception(
f"H1 and H2 have to be provided for non oxygen only search. Run type {whichH}"
)
if H1 is not None and H2 is not None:
if len(Odata) != len(H1) or len(Odata) != len(H2) or len(H1) != len(H2):
raise Exception("Odata, H1 and H2 have to be of same length")

def multi_stage_reclustering(
self,
Expand Down Expand Up @@ -579,11 +579,18 @@ def multi_stage_reclustering(
"""
self.__check_data(Odata, H1, H2, whichH)
self.__check_cls_alg_and_whichH(clustering_algorithm, whichH)
minsamps, lxis, allow_single = self.__check_and_setup_MSRC(
minsamps, lxis = self.__check_and_setup_MSRC(
lower_minsamp_pct, every_minsamp, xis, whichH, clustering_algorithm
)
self.__scan_clustering_params(
Odata, clustering_algorithm, minsamps, lxis, whichH, allow_single, H1, H2
Odata,
clustering_algorithm,
minsamps,
lxis,
whichH,
True,
H1,
H2,
)

def quick_multi_stage_reclustering(
Expand Down Expand Up @@ -645,7 +652,7 @@ def quick_multi_stage_reclustering(
"""
self.__check_data(Odata, H1, H2, whichH)
self.__check_cls_alg_and_whichH(clustering_algorithm, whichH)
minsamps, lxis, allow_single = self.__check_and_setup_MSRC(
minsamps, lxis = self.__check_and_setup_MSRC(
lower_minsamp_pct, every_minsamp, xis, whichH, clustering_algorithm
)
self.__scan_clustering_params(
Expand All @@ -654,7 +661,6 @@ def quick_multi_stage_reclustering(
minsamps,
lxis,
whichH,
allow_single,
False,
H1,
H2,
Expand Down Expand Up @@ -815,7 +821,7 @@ def _analyze_oxygen_clustering(
print(f"O clust {k}, size {len(clusters[clusters==k])}\n")
O_center = np.mean(Odata[mask], axis=0)
water = [O_center]
if not (whichH[0] == "onlyO"):
if "onlyO" not in whichH:
# Construct array of hydrogen orientations
orientations = np.vstack([H1[mask], H2[mask]])
# Analyse clustering with hydrogen orientation analysis and more debug stuff
Expand Down Expand Up @@ -845,6 +851,7 @@ def _analyze_oxygen_clustering(
plt = __check_mpl_installation()
plt.show()
if len(hyd) > 0:
print(hyd)
# add water atoms for pymol visualisation
for i in hyd:
water.append(O_center + i[0])
Expand All @@ -860,6 +867,7 @@ def _analyze_oxygen_clustering(
):
plt = __check_mpl_installation()
plt.show()
print(waters)
if stop_after_frist_water_found:
return waters, idcs
else:
Expand Down Expand Up @@ -978,6 +986,9 @@ def restart_cluster(
options,
whichH,
) = self.read_water_clust_options(options_file=options_file)
if "onlyO" in whichH:
H1 = None
H2 = None
if clustering_type == "multi_stage_reclustering":
self.multi_stage_reclustering(
Odata, # type: ignore
Expand Down Expand Up @@ -1304,7 +1315,7 @@ def __oxygen_clustering_plot(
For debuging oxygen clustering. Not ment for general usage.
"""
if type(cc) != OPTICS:
if type(cc) is not OPTICS:
plotreach = False
if debugO > 0:
plt = __check_mpl_installation()
Expand Down

0 comments on commit 663a650

Please sign in to comment.