Refactor subsim_search
OlivierBeq committed Aug 20, 2022
1 parent 758759b commit b2a5d17
Showing 2 changed files with 65 additions and 42 deletions.
106 changes: 65 additions & 41 deletions src/papyrus_scripts/subsim_search.py
@@ -2,21 +2,19 @@

from __future__ import annotations


import time
import multiprocessing
import os
import time
import warnings
from collections import defaultdict
from io import BytesIO
from typing import Optional, Tuple, Union

import pandas as pd
import pystow
import rdkit
from rdkit import Chem
import pandas as pd
from tqdm import tqdm
from rdkit.Chem.rdSubstructLibrary import SubstructLibrary, PatternHolder, CachedMolHolder
from tqdm import tqdm

try:
import cupy
@@ -87,16 +85,21 @@ def create_from_papyrus(self,
"""
# Set version
self.version = process_data_version(version=version, root_folder=root_folder)
# Set 3D
self.is3d = is3d
# Determine default paths
if root_folder is not None:
os.environ['PYSTOW_HOME'] = os.path.abspath(root_folder)
source_path = pystow.join('papyrus', self.version, 'structures')
# Find the file
filenames = locate_file(source_path.as_posix(),
f'*.*_combined_{3 if is3d else 2}D_set_with{"out" if not is3d else ""}_stereochemistry.sd*')
sd_file = filenames[0]
total = total = get_num_rows_in_file(filetype='structures', is3D=is3d, version=self.version, root_folder=root_folder)
self.create(sd_file=sd_file, outfile=outfile, fingerprint=fingerprint, total=total, progress=progress, njobs=njobs)
f'*.*_combined_{3 if is3d else 2}D_set_'
f'with{"out" if not is3d else ""}_stereochemistry.sd*')
self.sd_file = filenames[0]
total = get_num_rows_in_file(filetype='structures', is3D=is3d, version=self.version, root_folder=root_folder)
self.h5_filename = outfile
self.create(sd_file=self.sd_file, outfile=outfile, fingerprint=fingerprint,
total=total, progress=progress, njobs=njobs)
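
For context, a hedged usage sketch of the builder this hunk refactors; the class name FPSubSim2 and the import path are assumptions inferred from the module and engine names (they do not appear in this hunk), and the parameters follow the signature and body shown above.

# Hypothetical usage -- FPSubSim2 and its import path are assumed,
# not shown in this diff.
from papyrus_scripts.subsim_search import FPSubSim2

fpss = FPSubSim2()
fpss.create_from_papyrus(version='latest',                # Papyrus dataset version
                         is3d=False,                      # use the 2D structure set
                         outfile='papyrus_fpsubsim2.h5',  # output HDF5 database
                         fingerprint=None,                # None -> all available fingerprints
                         njobs=-1)                        # all logical cores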

def create(self,
sd_file: str,
@@ -109,10 +112,9 @@ def create(self,
fingerprints and handle full substructure search (subgraph isomorphism)
and load it when finished.
:param papyrus_sd_file: papyrus sd file containing chemical structures
:param version: version of the Papyrus dataset
:param sd_file: sd file containing chemical structures
:param outfile: filename or filepath of output database
:param fingerprint: fingerprints to be calculated, if None uses all available
:param fingerprint: fingerprints to be calculated; if None, uses all available
:param progress: whether progress should be shown
:param total: number of molecules for progress display
:param njobs: number of concurrent processes (-1 for all available logical cores)
@@ -149,7 +151,7 @@
subst_group = h5file.create_group(
h5file.root, "substructure_info", "Infos for substructure search")
# Array containing processed binary of the substructure library
subst_table = h5file.create_earray(
_ = h5file.create_earray(
subst_group, 'substruct_lib', tb.UInt64Atom(), (0,), 'Substructure search library')
# Table for mapping indices to identifiers
h5file.create_table(
@@ -274,15 +276,18 @@ def _parallel_create(self, njobs=-1, fingerprint: Union[Fingerprint, List[Finger
n_cpus = njobs - 1
processes = []
# Start reader
reader = multiprocessing.Process(target=_reader_process, args=(self.sd_file, n_cpus, total, False, input_queue, output_queue))
reader = multiprocessing.Process(target=_reader_process, args=(self.sd_file, n_cpus,
total, False,
input_queue))
processes.append(reader)
reader.start()
# Start writer
writer = multiprocessing.Process(target=_writer_process, args=(self.h5_filename, output_queue, table_paths, total, progress))
writer = multiprocessing.Process(target=_writer_process, args=(self.h5_filename, output_queue,
table_paths, total, progress))
writer.start()
# Start workers
for i in range(n_cpus):
job = multiprocessing.Process(target=_worker_process, args=(fp_types, input_queue, output_queue, n_cpus))
job = multiprocessing.Process(target=_worker_process, args=(fp_types, input_queue, output_queue))
processes.append(job)
processes[-1].start()
# Joining workers
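
The hunk above drops queue arguments the reader and workers no longer use. A minimal, self-contained sketch of the same reader -> workers -> writer pattern, with stand-in payloads in place of the real molecule and fingerprint work:

# Minimal sketch of the fan-out pipeline above; payloads are stand-ins.
import multiprocessing

def reader(items, n_workers, input_queue):
    for item in items:
        input_queue.put(item)
    for _ in range(n_workers):         # one poison pill per worker
        input_queue.put(None)

def worker(input_queue, output_queue):
    while True:
        item = input_queue.get()
        if item is None:               # shutdown signal from the reader
            output_queue.put(None)     # forward it so the writer can count
            break
        output_queue.put(item * item)  # stand-in for fingerprint calculation

def writer(output_queue, n_workers):
    finished = 0
    while finished < n_workers:        # stop once every worker has signed off
        result = output_queue.get()
        if result is None:
            finished += 1
        # else: stand-in for the batched HDF5 append

if __name__ == '__main__':
    inq, outq = multiprocessing.Queue(), multiprocessing.Queue()
    n = 2
    procs = [multiprocessing.Process(target=reader, args=(range(100), n, inq)),
             multiprocessing.Process(target=writer, args=(outq, n))]
    procs += [multiprocessing.Process(target=worker, args=(inq, outq)) for _ in range(n)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()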
@@ -341,7 +346,10 @@ def get_similarity_lib(self, fp_signature: Optional[str] = None, cuda: bool = Fa
return FPSubSim2CudaEngine(self.h5_filename, fp_signature)
return FPSubSim2Engine(self.h5_filename, fp_signature)

def add_fingerprint(self, fingerprint: Fingerprint, papyrus_sd_file: str, progress: bool = True, total: Optional[int]= None):
def add_fingerprint(self, fingerprint: Fingerprint,
papyrus_sd_file: str,
progress: bool = True,
total: Optional[int] = None):
"""Add a similarity fingerprint to the FPSubSim2 database.
:param fingerprint: Fingerprint to be added
@@ -358,7 +366,7 @@ def add_fingerprint(self, fingerprint: Fingerprint, papyrus_sd_file: str, progre
backend.change_fp_for_append(fingerprint)
backend.append_fps(MolSupplier(source=papyrus_sd_file), total=total, progress=progress)

def add_molecules(self, papyrus_sd_file: str, progress: bool = True, total: Optional[int]= None):
def add_molecules(self, papyrus_sd_file: str, progress: bool = True, total: Optional[int] = None):
"""Add molecules to the FPSubSim2 database.
:param papyrus_sd_file: papyrus sd file containing new chemical structures
@@ -383,13 +391,14 @@ def add_molecules(self, papyrus_sd_file: str, progress: bool = True, total: Opti
with tb.open_file(self.h5_filename, mode="a") as h5file:
# Remove previous lib
h5file.remove_node(h5file.root.substructure_info.substruct_lib)
h5file.create_earray(h5file.root.substructure_info, 'substruct_lib', tb.UInt64Atom(), (0,), 'Substructure search library')
h5file.create_earray(h5file.root.substructure_info, 'substruct_lib',
tb.UInt64Atom(), (0,), 'Substructure search library')
h5file.root.substructure_info.substruct_lib.attrs.padding = padding
h5file.root.substructure_info.substruct_lib.append(lib_ints)
sort_db_file(self.h5_filename, verbose=progress)
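
The rebuild above stores the serialized substructure library as unsigned 64-bit words plus a padding attribute. A hedged sketch of that packing round trip; the pickle step assumes an RDKit build with boost serialization support, and only the earray side appears in this diff.

# Hedged sketch: pack a serialized SubstructLibrary into uint64 words
# (as the earray above stores it) and unpack it again.
import pickle
import numpy as np
from rdkit import Chem
from rdkit.Chem.rdSubstructLibrary import SubstructLibrary, PatternHolder, CachedMolHolder

lib = SubstructLibrary(CachedMolHolder(), PatternHolder())
lib.AddMol(Chem.MolFromSmiles('c1ccccc1O'))

raw = pickle.dumps(lib)                   # serialized library
padding = -len(raw) % 8                   # pad to whole 8-byte words
lib_ints = np.frombuffer(raw + b'\x00' * padding, dtype=np.uint64)

# Reverse: drop the padding before unpickling
buf = lib_ints.tobytes()
restored = pickle.loads(buf[:len(buf) - padding])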


def _reader_process(sd_file, n_workers, total, progress, input_queue, output_queue):
def _reader_process(sd_file, n_workers, total, progress, input_queue):
with MolSupplier(source=sd_file, total=total, show_progress=progress, start_id=1) as supplier:
count = 0
for mol_id, rdmol in supplier:
@@ -406,7 +415,7 @@ def _reader_process(sd_file, n_workers, total, progress, input_queue, output_que

def _writer_process(h5_filename, output_queue, table_paths, total, progress):
lib = SubstructLibrary(CachedMolHolder(), PatternHolder())
pbar = tqdm(total=total, smoothing = 0.0) if progress else {}
pbar = tqdm(total=total, smoothing=0.0) if progress else {}
mappings_insert = []
similarity_insert = defaultdict(list)
with tb.open_file(h5_filename, mode="r+") as h5file:
@@ -456,7 +465,7 @@ def _writer_process(h5_filename, output_queue, table_paths, total, progress)
return


def _worker_process(fp_types, input_queue, output_queue, n_workers):
def _worker_process(fp_types, input_queue, output_queue):
while True:
# while output_queue.qsize() > BATCH_WRITE_SIZE * n_workers / 2:
# time.sleep(0.5)
@@ -480,7 +489,7 @@ def _worker_process(fp_types, input_queue, output_queue, n_workers):
output_queue.put(('similarity', repr(fper), (mol_id, *fp)))


def sort_db_file(filename: str, verbose: bool=False) -> None:
def sort_db_file(filename: str, verbose: bool = False) -> None:
"""Sorts the FPs db file."""
if verbose:
print('Optimizing FPSubSim2 file.')
@@ -502,7 +511,8 @@ def sort_db_file(filename: str, verbose: bool=False) -> None:
with tb.open_file(tmp_filename, mode="r") as fp_file:
with tb.open_file(filename, mode="w") as sorted_fp_file:
# group to hold similarity tables
siminfo_group = sorted_fp_file.create_group(sorted_fp_file.root, "similarity_info", "Infos for similarity search")
siminfo_group = sorted_fp_file.create_group(sorted_fp_file.root, "similarity_info",
"Infos for similarity search")
simfp_groups = list(fp_file.walk_groups('/similarity_info/'))
i = 0
for simfp_group in simfp_groups:
@@ -535,7 +545,8 @@ def sort_db_file(filename: str, verbose: bool=False) -> None:

# update count ranges
popcnt_bins = calc_popcnt_bins_pytables(dst_fp_table, fp_table.attrs.length)
popcounts = sorted_fp_file.create_vlarray(dst_group, 'popcounts', tb.ObjectAtom(), f'Popcounts of {dst_group._v_name}')
popcounts = sorted_fp_file.create_vlarray(dst_group, 'popcounts', tb.ObjectAtom(),
f'Popcounts of {dst_group._v_name}')
for x in popcnt_bins:
popcounts.append(x)
# add other tables
@@ -545,7 +556,9 @@ def sort_db_file(filename: str, verbose: bool=False) -> None:
if isinstance(node, tb.group.Group):
if isinstance(node, tb.group.RootGroup) or 'similarity_info' in str(node):
continue
_ = node._f_copy(sorted_fp_file.root, node._v_name, overwrite=True, recursive=True, filters=filters, stats=stats)
_ = node._f_copy(sorted_fp_file.root, node._v_name,
overwrite=True, recursive=True,
filters=filters, stats=stats)
else:
_ = node.copy(sorted_fp_file.root, node._v_name, overwrite=True, stats=stats)
# remove unsorted file
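
Sorting fingerprints by popcount is what makes the FPSim2-style search fast: for a Tanimoto threshold t and a query with popcount p, only candidates whose popcount falls in [ceil(t*p), floor(p/t)] can reach the threshold, so whole popcount bins are skipped. A small sketch of that bound (standard FPSim2 reasoning, not code from this commit):

# Popcount bounds behind the popcount-sorted layout.
import math

def popcount_bounds(query_popcount, threshold):
    # Candidates outside [lo, hi] cannot reach the Tanimoto threshold.
    lo = math.ceil(threshold * query_popcount)
    hi = math.floor(query_popcount / threshold)
    return lo, hi

print(popcount_bounds(48, 0.7))  # -> (34, 68)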
@@ -564,8 +577,10 @@ def __init__(self, fp_filename: str, fp_signature: str, in_memory_fps: bool = Tr
for simfp_group in fp_file.walk_groups('/similarity_info/'):
if len(simfp_group._v_name):
fp_table = fp_file.get_node(simfp_group, 'fps', classname='Table')
self._fp_table_mappings[fp_table.attrs.fp_id] = [f'/similarity_info/{simfp_group._v_name}/fps',
f'/similarity_info/{simfp_group._v_name}/popcounts']
self._fp_table_mappings[fp_table.attrs.fp_id] = [f'/similarity_info/{simfp_group._v_name}'
'/fps',
f'/similarity_info/{simfp_group._v_name}'
'/popcounts']
if fp_signature not in self._fp_table_mappings.keys():
raise ValueError(f'fingerprint not available, must be one of {", ".join(self._fp_table_mappings.keys())}')
self._current_fp = fp_signature
@@ -644,7 +659,8 @@ def delete_fps(self, ids_list: List[int]) -> None:
]
fps_table.remove_row(to_delete[0])

def append_fps(self, supplier: MolSupplier, progress: bool=True, total: Optional[int]=None, sort: bool = True) -> None:
def append_fps(self, supplier: MolSupplier, progress: bool = True,
total: Optional[int] = None, sort: bool = True) -> None:
"""Appends FPs to the file for the fingerprint currently selected."""
with tb.open_file(self.fp_filename, mode="a") as fp_file:
fps_table = fp_file.get_node(self._current_fp_path)
@@ -679,14 +695,15 @@ def change_fp_for_append(self, fingerprint: Fingerprint):
# New table
particle = create_schema(fingerprint.length)
fp_table = fp_file.create_table(fp_group, 'fps', particle, 'Similarity FPs', expectedrows=1300000,
filters=filters)
filters=filters)
# New attributes
fp_table.attrs.fp_type = fingerprint.name
fp_table.attrs.fp_id = self._current_fp
fp_table.attrs.length = fingerprint.length
fp_table.attrs.fp_params = json.dumps(fingerprint.params)
# New Popcounts
popcounts = fp_file.create_vlarray(fp_group, 'popcounts', tb.ObjectAtom(), f'Popcounts of {fp_group._v_name}')
popcounts = fp_file.create_vlarray(fp_group, 'popcounts', tb.ObjectAtom(),
f'Popcounts of {fp_group._v_name}')
self._current_fp_path = f'/similarity_info/{fp_group._v_name}/fps'
self._current_popcounts_path = f'/similarity_info/{fp_group._v_name}/popcounts'
self.fp_type, self.fp_params, self.rdkit_ver = self.read_parameters()
@@ -792,7 +809,8 @@ def similarity(self, query_string: str, threshold: float, n_workers: int = 1) ->
"""
data = list(zip(*FPSim2Engine.similarity(self, query_string, threshold, n_workers)))
if not len(data):
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey', f'Tanimoto > {threshold} ({self.storage._current_fp})'])
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey',
f'Tanimoto > {threshold} ({self.storage._current_fp})'])
ids, similarities = data
ids, similarities = list(ids), list(similarities)
data = self._get_mapping(ids)
@@ -803,7 +821,7 @@ def similarity(self, query_string: str, threshold: float, n_workers: int = 1) ->
data[col] = data[col].apply(lambda x: x.decode('utf-8'))
return data
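
Continuing the earlier sketch, a hedged example of querying the wrapper above; get_similarity_lib and the returned columns follow this diff, while the SMILES, threshold, and worker count are arbitrary.

# Hedged usage; engine creation follows get_similarity_lib above.
engine = fpss.get_similarity_lib(cuda=False)
hits = engine.similarity('Cc1ccc2nc(N)sc2c1', threshold=0.7, n_workers=4)
# hits is a DataFrame with columns: idnumber, connectivity, InChIKey,
# and 'Tanimoto > 0.7 (<fingerprint signature>)'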

def on_disk_similarity(self, query_string: str, threshold: float, n_workers: int=1, chunk_size: int=0):
def on_disk_similarity(self, query_string: str, threshold: float, n_workers: int = 1, chunk_size: int = 0):
"""Perform Tanimoto similarity search on disk.
:param query_string:
@@ -814,7 +832,8 @@ def on_disk_similarity(self, query_string: str, threshold: float, n_workers: int
"""
data = list(zip(*FPSim2Engine.on_disk_similarity(self, query_string, threshold, n_workers, chunk_size)))
if not len(data):
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey', f'Tanimoto > {threshold} ({self.storage._current_fp})'])
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey',
f'Tanimoto > {threshold} ({self.storage._current_fp})'])
ids, similarities = data
ids, similarities = list(ids), list(similarities)
data = self._get_mapping(ids)
@@ -837,7 +856,8 @@ def tversky(self, query_string: str, threshold: float, a: float, b: float, n_wor
"""
data = list(zip(*FPSim2Engine.tversky(self, query_string, threshold, a, b, n_workers)))
if not len(data):
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey', f'Tversky > {threshold} ({self.storage._current_fp})'])
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey',
f'Tversky > {threshold} ({self.storage._current_fp})'])
ids, similarities = data
ids, similarities = list(ids), list(similarities)
data = self._get_mapping(ids)
@@ -848,7 +868,9 @@ def tversky(self, query_string: str, threshold: float, a: float, b: float, n_wor
data[col] = data[col].apply(lambda x: x.decode('utf-8'))
return data

def on_disk_tversky(self, query_string: str, threshold: float, a: float, b: float, n_workers: int = 1, chunk_size: int = None):
def on_disk_tversky(self, query_string: str, threshold: float,
a: float, b: float,
n_workers: int = 1, chunk_size: int = None):
"""Perform Tversky similarity search on disk.
:param query_string:
@@ -861,7 +883,8 @@ def on_disk_tversky(self, query_string: str, threshold: float, a: float, b: floa
"""
data = list(zip(*FPSim2Engine.on_disk_tversky(self, query_string, threshold, a, b, n_workers, chunk_size)))
if not len(data):
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey', f'Tversky > {threshold} ({self.storage._current_fp})'])
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey',
f'Tversky > {threshold} ({self.storage._current_fp})'])
ids, similarities = data
ids, similarities = list(ids), list(similarities)
data = self._get_mapping(ids)
Expand All @@ -887,7 +910,7 @@ def __init__(
fp_filename: str,
fp_signature: str,
storage_backend: str = "pytables",
kernel: str='raw'
kernel: str = 'raw'
) -> None:
"""FPSubSim2 class to run fast CPU similarity searches.
@@ -935,7 +958,8 @@ def similarity(self, query_string: str, threshold: float) -> pd.DataFrame:
"""Tanimoto similarity search."""
data = list(zip(*FPSim2CudaEngine.similarity(self, query_string, threshold)))
if not len(data):
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey', f'Tanimoto > {threshold} ({self.storage._current_fp})'])
return pd.DataFrame([], columns=['idnumber', 'connectivity', 'InChIKey',
f'Tanimoto > {threshold} ({self.storage._current_fp})'])
ids, similarities = data
ids, similarities = list(ids), list(similarities)
data = self._get_mapping(ids)
@@ -982,8 +1006,8 @@ def _get_mapping(self, ids: Union[List[int], int]):
data[col] = data[col].apply(lambda x: x.decode('utf-8'))
return data

def GetMatches(self, query: Union[str, Chem.Mol], recursionPossible: bool=True, useChirality: bool=True,
useQueryQueryMatches: bool=False, numThreads: int=-1, maxResults: int=-1):
def GetMatches(self, query: Union[str, Chem.Mol], recursionPossible: bool = True, useChirality: bool = True,
useQueryQueryMatches: bool = False, numThreads: int = -1, maxResults: int = -1):
if isinstance(query, str):
query = load_molecule(query)
ids = list(super(SubstructureLibrary, self).GetMatches(query=query,
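The wrapper above is cut off by the diff; it forwards to RDKit's SubstructLibrary.GetMatches and maps the hit indices back to identifiers via _get_mapping. A hedged usage sketch, where the accessor name get_substructure_lib is an assumption mirroring get_similarity_lib:

# Hedged usage; get_substructure_lib is assumed, not shown in this diff.
sublib = fpss.get_substructure_lib()
matches = sublib.GetMatches('c1ccc2[nH]cnc2c1',  # query parsed by load_molecule
                            numThreads=-1,        # all logical cores
                            maxResults=-1)        # no cap on the number of hits
# matches carries idnumber / connectivity / InChIKey per _get_mapping above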
1 change: 0 additions & 1 deletion src/papyrus_scripts/utils/UniprotMatch.py
@@ -6,7 +6,6 @@
import json
import time
import zlib
from io import StringIO
from typing import List, Union
from xml.etree import ElementTree
from urllib.parse import urlparse, parse_qs, urlencode
