Skip to content

Commit

Permalink
Minor updates (#19)
Browse files Browse the repository at this point in the history
* new function for finding viable parent

* update parameters and style pass

* update training saved objects

* expand some util functions

* add search cluster centroid exhaustive and style pass
  • Loading branch information
tony-kuo authored Nov 11, 2024
1 parent 0681cd4 commit 25f9cef
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 76 deletions.
32 changes: 13 additions & 19 deletions src/scimilarity/cell_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ def __init__(
self,
model_path: str,
use_gpu: bool = False,
parameters: Optional[dict] = None,
filenames: Optional[dict] = None,
residual: bool = False,
):
"""Constructor.
Expand All @@ -22,12 +20,8 @@ def __init__(
Path to the directory containing model files.
use_gpu: bool, default: False
Use GPU instead of CPU.
parameters: dict, optional, default: None
Use a dictionary of custom model parameters instead of inferring from model files.
filenames: dict, optional, default: None
Use a dictionary of custom filenames for model files instead default.
residual: bool, default: False
Use residual connections.
Use a dictionary of custom filenames for files instead of the defaults.
Examples
--------
Expand All @@ -39,17 +33,14 @@ def __init__(
super().__init__(
model_path=model_path,
use_gpu=use_gpu,
parameters=parameters,
filenames=filenames,
residual=residual,
)

if filenames is None:
filenames = {}

self.annotation_path = os.path.join(model_path, "annotation")
os.makedirs(self.annotation_path, exist_ok=True)

if filenames is None:
filenames = {}

self.filenames["knn"] = os.path.join(
self.annotation_path, filenames.get("knn", "labelled_kNN.bin")
)
Expand Down Expand Up @@ -220,10 +211,10 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]):
>>> ca.blocklist_celltypes(["T cell"])
"""

self.reset_knn()
self.blocklist = set(labels)
self.safelist = None

self.reset_knn()
for i, celltype_name in self.idx2label.items():
if celltype_name in self.blocklist:
self.knn.mark_deleted(i) # mark blocklist
Expand Down Expand Up @@ -321,9 +312,10 @@ def get_predictions_knn(
embeddings=embeddings, k=k, ef=ef
)
end_time = time.time()
print(
f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min"
)
if not disable_progress:
print(
f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min"
)
stats = {
"hits": [],
"hits_weighted": [],
Expand All @@ -345,8 +337,10 @@ def get_predictions_knn(
celltype = defaultdict(float)
celltype_weighted = defaultdict(float)
for neighbor, dist in zip(nns, d_nns):
celltype[self.idx2label[neighbor]] += 1
celltype_weighted[self.idx2label[neighbor]] += 1 / max(dist, 1e-6)
celltype[self.idx2label[neighbor]] += 1.0
celltype_weighted[self.idx2label[neighbor]] += 1.0 / float(
max(dist, 1e-6)
)
# predict based on consensus max occurrence
if weighting:
predictions.append(
Expand Down
18 changes: 7 additions & 11 deletions src/scimilarity/cell_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ def __init__(
self,
model_path: str,
use_gpu: bool = False,
parameters: Optional[dict] = None,
filenames: Optional[dict] = None,
metadata_tiledb_uri: str = "cell_metadata",
embedding_tiledb_uri: str = "cell_embedding",
residual: bool = False,
load_knn: bool = True,
):
"""Constructor.
Expand All @@ -25,16 +23,12 @@ def __init__(
Path to the model directory.
use_gpu: bool, default: False
Use GPU instead of CPU.
parameters: dict, optional, default: None
Use a dictionary of custom model parameters instead of inferring from model files.
filenames: dict, optional, default: None
Use a dictionary of custom filenames for model files instead of the defaults.
metadata_tiledb_uri: str, default: "cell_metadata"
Relative path to the directory containing the tiledb cell metadata storage.
embedding_tiledb_uri: str, default: "cell_embedding"
Relative path to the directory containing the tiledb cell embedding storage.
residual: bool, default: False
Use residual connections.
load_knn: bool, default: True
Load the knn index. Set to False if knn is not needed.
Expand All @@ -51,11 +45,10 @@ def __init__(
super().__init__(
model_path=model_path,
use_gpu=use_gpu,
parameters=parameters,
filenames=filenames,
residual=residual,
)

self.cellsearch_path = os.path.join(model_path, "cellsearch")
os.makedirs(self.cellsearch_path, exist_ok=True)

if filenames is None:
filenames = {}
Expand Down Expand Up @@ -105,7 +98,7 @@ def __init__(
cell_metadata[c] = cell_metadata[c].replace("NA", np.nan)
cell_metadata = cell_metadata.astype(convert_dict)
tiledb.from_pandas(metadata_tiledb_uri, cell_metadata)
self.cell_metadata = tiledb.open_dataframe(metadata_tiledb_uri)
self.cell_metadata = tiledb.open(metadata_tiledb_uri, "r").df[:]

# get cell embeddings: create tiledb storage if it does not exist
embedding_tiledb_uri = os.path.join(self.cellsearch_path, embedding_tiledb_uri)
Expand Down Expand Up @@ -599,6 +592,9 @@ def search_exhaustive(
import pandas as pd
from scipy.spatial.distance import cdist

if embeddings.ndim == 1:
embeddings = embeddings.reshape(1, -1)

nn_idxs = [[] for _ in range(embeddings.shape[0])]
nn_dists = [[] for _ in range(embeddings.shape[0])]
n_cells = self.cell_metadata.shape[0]
Expand All @@ -624,7 +620,7 @@ def search_exhaustive(

# sort by lowest distance
for row in range(len(nn_idxs)):
nn_idxs[row] = np.hstack(nn_idxs[0])
nn_idxs[row] = np.hstack(nn_idxs[row])
nn_dists[row] = np.hstack(nn_dists[row])
sorted_indices = np.argsort(nn_dists[row])
nn_idxs[row] = nn_idxs[row][sorted_indices]
Expand Down
12 changes: 8 additions & 4 deletions src/scimilarity/nn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ def load_state(self, filename: str, use_gpu: bool = False):
"""

if not use_gpu:
ckpt = torch.load(filename, map_location=torch.device("cpu"))
ckpt = torch.load(
filename, map_location=torch.device("cpu"), weights_only=False
)
else:
ckpt = torch.load(filename)
ckpt = torch.load(filename, weights_only=False)
self.load_state_dict(ckpt["state_dict"])


Expand Down Expand Up @@ -218,7 +220,9 @@ def load_state(self, filename: str, use_gpu: bool = False):
"""

if not use_gpu:
ckpt = torch.load(filename, map_location=torch.device("cpu"))
ckpt = torch.load(
filename, map_location=torch.device("cpu"), weights_only=False
)
else:
ckpt = torch.load(filename)
ckpt = torch.load(filename, weights_only=False)
self.load_state_dict(ckpt["state_dict"])
42 changes: 42 additions & 0 deletions src/scimilarity/ontologies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import networkx as nx
import obonet
import pandas as pd
from typing import Union, Tuple, List


Expand Down Expand Up @@ -327,6 +328,47 @@ def get_lowest_common_ancestor(graph, node1, node2) -> nx.DiGraph:
)


def find_most_viable_parent(graph, node, node_list):
    """Pick a single "most viable" parent of a node, restricted to node_list.

    Selection rules, in order:

    * Exactly one parent of ``node`` is in ``node_list``: return it.
    * No parent is in ``node_list``: if ``node`` has exactly one parent
      overall, and that parent has exactly one parent in ``node_list``,
      return that grandparent.
    * Several parents are in ``node_list``: return the first one for which
      ``get_all_ancestors``, restricted to those same parents, yields a
      non-empty (truthy) result.

    Parameters
    ----------
    graph: networkx.DiGraph
        Node graph.
    node: str
        ID of given node.
    node_list: list, set
        A restricted node list for filtering; forwarded to ``get_parents``.

    Returns
    -------
    object or None
        The selected parent (an element of what ``get_parents`` returned),
        or None when none of the rules above selects one.

    Examples
    --------
    >>> coarse_grained = find_most_viable_parent(onto, id, celltype_list)
    """

    candidates = get_parents(graph, node, node_list=node_list)

    # Unambiguous case: a single restricted parent.
    if len(candidates) == 1:
        (only_parent,) = candidates
        return only_parent

    # No restricted parent: try one level up, but only along a unique path.
    if len(candidates) == 0:
        unrestricted = list(get_parents(graph, node))
        if len(unrestricted) != 1:
            return None
        grandparents = get_parents(graph, unrestricted[0], node_list=node_list)
        if len(grandparents) != 1:
            return None
        (only_grandparent,) = grandparents
        return only_grandparent

    # Several restricted parents: take the first that has an ancestor among
    # the other candidates.
    for candidate in list(candidates):
        if get_all_ancestors(graph, candidate, node_list=pd.Index(candidates)):
            return candidate
    return None


def ontology_similarity(graph, node1, node2, restricted_set=None) -> int:
"""Get the ontology similarity of two terms based on the number of common ancestors.
Expand Down
4 changes: 2 additions & 2 deletions src/scimilarity/training_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def save_all(
meta_data["cell_tdb_uri"] = self.trainer.datamodule.cell_tdb_uri
meta_data["gene_tdb_uri"] = self.trainer.datamodule.gene_tdb_uri
meta_data["counts_tdb_uri"] = self.trainer.datamodule.counts_tdb_uri
self.trainer.datamodule.data_df.to_csv(
self.trainer.datamodule.train_df.to_csv(
os.path.join(model_path, "train_cells.csv")
)
if self.trainer.datamodule.val_df is not None:
Expand All @@ -669,7 +669,7 @@ def save_all(
with open(os.path.join(model_path, "reference_labels.tsv"), "w") as f:
f.write(
"\n".join(
self.trainer.datamodule.data_df["cellTypeName"].values.tolist()
self.trainer.datamodule.train_df["cellTypeName"].values.tolist()
)
)
with open(os.path.join(model_path, "metadata.json"), "w") as f:
Expand Down
Loading

0 comments on commit 25f9cef

Please sign in to comment.