Skip to content

Commit

Permalink
onlyO mode can now be computed without hydrogens
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Oct 8, 2023
1 parent 33c4e42 commit bc39a4c
Showing 1 changed file with 45 additions and 24 deletions.
69 changes: 45 additions & 24 deletions ConservedWaterSearch/water_clustering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from shutil import which
from typing import TYPE_CHECKING

try:
Expand Down Expand Up @@ -394,14 +395,14 @@ def __check_and_setup_MSRC(
def __scan_clustering_params(
self,
Odata,
H1,
H2,
clustering_algorithm,
minsamps,
lxis,
whichH,
allow_single,
restart: bool = True,
H1=None,
H2=None,
):
found: bool = False if len(Odata) < self.nsnaps else True
for wt in whichH:
Expand Down Expand Up @@ -466,7 +467,10 @@ def __scan_clustering_params(
found = True
if clustering_algorithm == "HDBSCAN" and allow_single:
allow_single = False
Odata, H1, H2 = self._delete_data(idcs, Odata, H1, H2)
if wt == "onlyO":
Odata = self._delete_data(idcs, Odata)
else:
Odata, H1, H2 = self._delete_data(idcs, Odata, H1, H2)
self._add_water_solutions(waters)
if self.save_intermediate_results:
self.__save_intermediate_results()
Expand Down Expand Up @@ -511,11 +515,17 @@ def __save_intermediate_results(self):
"unable to overwrite temp save files. Restarting might not work properly"
)

def __check_data(self, Odata, H1, H2, whichH):
    """Validate that the supplied coordinate arrays are usable together.

    Hydrogen data may be omitted (``None``) only for an oxygen-only
    search, i.e. when ``whichH`` equals ``["onlyO"]``.

    Args:
        Odata: Oxygen coordinates.
        H1: Hydrogen 1 orientations, or None.
        H2: Hydrogen 2 orientations, or None.
        whichH: Requested water types.

    Raises:
        ValueError: If hydrogen data is missing for a search that is not
            oxygen-only, or if the provided arrays differ in length.
    """
    if H1 is None or H2 is None:
        # Explicit nesting: `a or b and c` binds as `a or (b and c)`, which
        # wrongly rejected H1=None even for an oxygen-only search.
        if whichH != ["onlyO"]:
            raise ValueError(
                "H1 and H2 have to be provided for non oxygen only search."
            )
        # Oxygen-only search: there is no hydrogen data to compare lengths
        # against, and calling len(None) would raise TypeError.
        return
    if not (len(Odata) == len(H1) == len(H2)):
        raise ValueError("Odata, H1 and H2 have to be of same length")

def multi_stage_reclustering(
self,
Odata: np.ndarray,
H1: np.ndarray,
H2: np.ndarray,
H1: np.ndarray | None,
H2: np.ndarray | None,
clustering_algorithm: str = "OPTICS",
lower_minsamp_pct: float = 0.25,
every_minsamp: int = 1,
Expand Down Expand Up @@ -544,8 +554,10 @@ def multi_stage_reclustering(
Args:
Odata (np.ndarray): Oxygen coordinates.
H1 (np.ndarray): Hydrogen 1 orientations.
H2 (np.ndarray): Hydrogen 2 orientations.
H1 (np.ndarray | None): Hydrogen 1 orientations. If None ``whichH``
must be "onlyO".
H2 (np.ndarray | None): Hydrogen 2 orientations. If None ``whichH``
must be "onlyO".
clustering_algorithm (str, optional): Options are "OPTICS"
or "HDBSCAN". OPTICS provides slightly better results, but
is also slightly slower. Defaults to "OPTICS".
Expand All @@ -565,19 +577,20 @@ def multi_stage_reclustering(
allowed, or "onlyO" for oxygen clustering only.
Defaults to ["FCW", "HCW", "WCW"].
"""
self.__check_data(Odata, H1, H2, whichH)
self.__check_cls_alg_and_whichH(clustering_algorithm, whichH)
minsamps, lxis, allow_single = self.__check_and_setup_MSRC(
lower_minsamp_pct, every_minsamp, xis, whichH, clustering_algorithm
)
self.__scan_clustering_params(
Odata, H1, H2, clustering_algorithm, minsamps, lxis, whichH, allow_single
Odata, clustering_algorithm, minsamps, lxis, whichH, allow_single, H1, H2
)

def quick_multi_stage_reclustering(
self,
Odata: np.ndarray,
H1: np.ndarray,
H2: np.ndarray,
H1: np.ndarray | None,
H2: np.ndarray | None,
clustering_algorithm: str = "OPTICS",
lower_minsamp_pct: float = 0.25,
every_minsamp: int = 1,
Expand Down Expand Up @@ -606,8 +619,10 @@ def quick_multi_stage_reclustering(
Args:
Odata (np.ndarray): Oxygen coordinates.
H1 (np.ndarray): Hydrogen 1 orientations.
H2 (np.ndarray): Hydrogen 2 orientations.
H1 (np.ndarray | None): Hydrogen 1 orientations. If None ``whichH``
must be "onlyO".
H2 (np.ndarray | None): Hydrogen 2 orientations. If None ``whichH``
must be "onlyO".
clustering_algorithm (str, optional): Options are "OPTICS"
or "HDBSCAN". OPTICS provides slightly better results, but
is also slightly slower. Defaults to "OPTICS".
Expand All @@ -628,27 +643,28 @@ def quick_multi_stage_reclustering(
allowed, or "onlyO" for oxygen clustering only.
Defaults to ["FCW", "HCW", "WCW"].
"""
self.__check_data(Odata, H1, H2, whichH)
self.__check_cls_alg_and_whichH(clustering_algorithm, whichH)
minsamps, lxis, allow_single = self.__check_and_setup_MSRC(
lower_minsamp_pct, every_minsamp, xis, whichH, clustering_algorithm
)
self.__scan_clustering_params(
Odata,
H1,
H2,
clustering_algorithm,
minsamps,
lxis,
whichH,
allow_single,
restart=False,
False,
H1,
H2,
)

def single_clustering(
self,
Odata: np.ndarray,
H1: np.ndarray,
H2: np.ndarray,
H1: np.ndarray | None,
H2: np.ndarray | None,
clustering_algorithm: str = "OPTICS",
minsamp: int | None = None,
xi: float | None = None,
Expand All @@ -662,8 +678,10 @@ def single_clustering(
Args:
Odata (np.ndarray): Oxygen coordinates.
H1 (np.ndarray): Hydrogen 1 orientations.
H2 (np.ndarray): Hydrogen 2 orientations.
H1 (np.ndarray | None): Hydrogen 1 orientations. If None ``whichH``
must be "onlyO".
H2 (np.ndarray | None): Hydrogen 2 orientations. If None ``whichH``
must be "onlyO".
clustering_algorithm (str, optional): Options are "OPTICS"
or "HDBSCAN". OPTICS provides slightly better results, but
is also slightly slower. Defaults to "OPTICS".
Expand All @@ -679,6 +697,7 @@ def single_clustering(
allowed, or "onlyO" for oxygen clustering only.
Defaults to ["FCW", "HCW", "WCW"].
"""
self.__check_data(Odata, H1, H2, whichH)
self.__check_cls_alg_and_whichH(clustering_algorithm, whichH)
minsamp, xi = self.__check_and_setup_single(
xi, whichH, clustering_algorithm, minsamp
Expand Down Expand Up @@ -739,8 +758,8 @@ def single_clustering(
def _analyze_oxygen_clustering(
self,
Odata: np.ndarray,
H1: np.ndarray,
H2: np.ndarray,
H1: np.ndarray | None,
H2: np.ndarray | None,
clusters: np.ndarray,
stop_after_frist_water_found: bool,
whichH: list[str],
Expand All @@ -755,8 +774,10 @@ def _analyze_oxygen_clustering(
Args:
Odata (np.ndarray): Oxygen coordinates
H1 (np.ndarray): Hydrogen 1 orientations.
H2 (np.ndarray): Hydrogen 2 orientations.
H1 (np.ndarray | None): Hydrogen 1 orientations. If None ``whichH``
must be "onlyO".
H2 (np.ndarray | None): Hydrogen 2 orientations. If None ``whichH``
must be "onlyO".
clusters (np.ndarray): Output of clustering
results from OPTICS or HDBSCAN.
stop_after_frist_water_found (bool): If True, the procedure
Expand Down Expand Up @@ -785,7 +806,7 @@ def _analyze_oxygen_clustering(
idcs = np.array([], dtype=int)
# Loop over all oxygen clusters (-1 is non cluster)
for k in cluster_ids:
mask = (clusters == k)
mask = clusters == k
# Number of elements in oxygen cluster
neioc = np.count_nonzero(mask)
# If number of elements in oxygen cluster is Nsnap*0.85<Nelem<Nsnap*1.15 then ignore
Expand Down

0 comments on commit bc39a4c

Please sign in to comment.