Skip to content

Commit

Permalink
parameter clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
tony-kuo committed Nov 11, 2024
1 parent 25f9cef commit 5c64bdb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 51 deletions.
49 changes: 16 additions & 33 deletions src/scimilarity/cell_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ def __init__(
self,
model_path: str,
use_gpu: bool = False,
parameters: Optional[dict] = None,
filenames: Optional[dict] = None,
residual: bool = False,
):
"""Constructor.
Expand All @@ -20,12 +17,6 @@ def __init__(
Path to the directory containing model files.
use_gpu: bool, default: False
Use GPU instead of CPU.
parameters: dict, optional, default: None
Use a dictionary of custom model parameters instead of infering from model files.
filenames: dict, optional, default: None
Use a dictionary of custom filenames for model files instead default.
residual: bool, default: False
Use residual connections.
Examples
--------
Expand All @@ -40,44 +31,36 @@ def __init__(
self.model_path = model_path
self.use_gpu = use_gpu

if filenames is None:
filenames = {}

self.filenames = {
"model": os.path.join(
self.model_path, filenames.get("model", "encoder.ckpt")
),
"gene_order": os.path.join(
self.model_path, filenames.get("gene_order", "gene_order.tsv")
),
"model": os.path.join(self.model_path, "encoder.ckpt"),
"gene_order": os.path.join(self.model_path, "gene_order.tsv"),
}

# get gene order
with open(self.filenames["gene_order"], "r") as fh:
self.gene_order = [line.strip() for line in fh]

# get neural network model
if parameters is None: # infer network size if not explicitly given
with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh:
layer_sizes = json.load(fh)
# keys: network.1.weight, network.2.weight, ..., network.n.weight
layers = [
(key, layer_sizes[key])
for key in sorted(list(layer_sizes.keys()))
if "weight" in key and len(layer_sizes[key]) > 1
]
parameters = {
"latent_dim": layers[-1][1][0], # last
"hidden_dim": [layer[1][0] for layer in layers][0:-1], # all but last
}
# get neural network model and infer network size
with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh:
layer_sizes = json.load(fh)
# keys: network.1.weight, network.2.weight, ..., network.n.weight
layers = [
(key, layer_sizes[key])
for key in sorted(list(layer_sizes.keys()))
if "weight" in key and len(layer_sizes[key]) > 1
]
parameters = {
"latent_dim": layers[-1][1][0], # last
"hidden_dim": [layer[1][0] for layer in layers][0:-1], # all but last
}

self.n_genes = len(self.gene_order)
self.latent_dim = parameters["latent_dim"]
self.model = Encoder(
n_genes=self.n_genes,
latent_dim=parameters["latent_dim"],
hidden_dim=parameters["hidden_dim"],
residual=residual,
residual=False,
)
if self.use_gpu is True:
self.model.cuda()
Expand Down
19 changes: 1 addition & 18 deletions src/scimilarity/cell_search_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ def __init__(
self,
model_path: str,
use_gpu: bool = False,
parameters: Optional[dict] = None,
filenames: Optional[dict] = None,
residual: bool = False,
):
"""Constructor.
Expand All @@ -22,31 +19,17 @@ def __init__(
Path to the directory containing model files.
use_gpu: bool, default: False
Use GPU instead of CPU.
parameters: dict, optional, default: None
Use a dictionary of custom model parameters instead of infering from model files.
filenames: dict, optional, default: None
Use a dictionary of custom filenames for model files instead default.
The kNN filenames also need to be specified here.
residual: bool, default: False
Use residual connections.
Examples
--------
>>> filenames = {"knn": "knn.bin"}
>>> cs = CellSearch(model_path="/opt/data/model", filenames=filesnames)
>>> cs = CellSearchKNN(model_path="/opt/data/model")
"""

super().__init__(
model_path=model_path,
use_gpu=use_gpu,
parameters=parameters,
filenames=filenames,
residual=residual,
)

if filenames is None:
filenames = {}

self.knn = None
self.safelist = None
self.blocklist = None
Expand Down

0 comments on commit 5c64bdb

Please sign in to comment.