From 5c64bdbef5df9caaf01ad8260931df547933f859 Mon Sep 17 00:00:00 2001 From: tony-kuo Date: Mon, 11 Nov 2024 14:38:11 -0800 Subject: [PATCH] parameter clean up --- src/scimilarity/cell_embedding.py | 49 ++++++++++-------------------- src/scimilarity/cell_search_knn.py | 19 +----------- 2 files changed, 17 insertions(+), 51 deletions(-) diff --git a/src/scimilarity/cell_embedding.py b/src/scimilarity/cell_embedding.py index 424f51d..caa836b 100644 --- a/src/scimilarity/cell_embedding.py +++ b/src/scimilarity/cell_embedding.py @@ -8,9 +8,6 @@ def __init__( self, model_path: str, use_gpu: bool = False, - parameters: Optional[dict] = None, - filenames: Optional[dict] = None, - residual: bool = False, ): """Constructor. @@ -20,12 +17,6 @@ def __init__( Path to the directory containing model files. use_gpu: bool, default: False Use GPU instead of CPU. - parameters: dict, optional, default: None - Use a dictionary of custom model parameters instead of infering from model files. - filenames: dict, optional, default: None - Use a dictionary of custom filenames for model files instead default. - residual: bool, default: False - Use residual connections. Examples -------- @@ -40,36 +31,28 @@ def __init__( self.model_path = model_path self.use_gpu = use_gpu - if filenames is None: - filenames = {} - self.filenames = { - "model": os.path.join( - self.model_path, filenames.get("model", "encoder.ckpt") - ), - "gene_order": os.path.join( - self.model_path, filenames.get("gene_order", "gene_order.tsv") - ), + "model": os.path.join(self.model_path, "encoder.ckpt"), + "gene_order": os.path.join(self.model_path, "gene_order.tsv"), } # get gene order with open(self.filenames["gene_order"], "r") as fh: self.gene_order = [line.strip() for line in fh] - # get neural network model - if parameters is None: # infer network size if not explicitly given - with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh: - layer_sizes = json.load(fh) - # keys: network.1.weight, network.2.weight, ..., network.n.weight - layers = [ - (key, layer_sizes[key]) - for key in sorted(list(layer_sizes.keys())) - if "weight" in key and len(layer_sizes[key]) > 1 - ] - parameters = { - "latent_dim": layers[-1][1][0], # last - "hidden_dim": [layer[1][0] for layer in layers][0:-1], # all but last - } + # get neural network model and infer network size + with open(os.path.join(self.model_path, "layer_sizes.json"), "r") as fh: + layer_sizes = json.load(fh) + # keys: network.1.weight, network.2.weight, ..., network.n.weight + layers = [ + (key, layer_sizes[key]) + for key in sorted(list(layer_sizes.keys())) + if "weight" in key and len(layer_sizes[key]) > 1 + ] + parameters = { + "latent_dim": layers[-1][1][0], # last + "hidden_dim": [layer[1][0] for layer in layers][0:-1], # all but last + } self.n_genes = len(self.gene_order) self.latent_dim = parameters["latent_dim"] @@ -77,7 +60,7 @@ def __init__( n_genes=self.n_genes, latent_dim=parameters["latent_dim"], hidden_dim=parameters["hidden_dim"], - residual=residual, + residual=False, ) if self.use_gpu is True: self.model.cuda() diff --git a/src/scimilarity/cell_search_knn.py b/src/scimilarity/cell_search_knn.py index c20dda9..524e309 100644 --- a/src/scimilarity/cell_search_knn.py +++ b/src/scimilarity/cell_search_knn.py @@ -10,9 +10,6 @@ def __init__( self, model_path: str, use_gpu: bool = False, - parameters: Optional[dict] = None, - filenames: Optional[dict] = None, - residual: bool = False, ): """Constructor. @@ -22,31 +19,17 @@ def __init__( Path to the directory containing model files. use_gpu: bool, default: False Use GPU instead of CPU. - parameters: dict, optional, default: None - Use a dictionary of custom model parameters instead of infering from model files. - filenames: dict, optional, default: None - Use a dictionary of custom filenames for model files instead default. - The kNN filenames also need to be specified here. - residual: bool, default: False - Use residual connections. Examples -------- - >>> filenames = {"knn": "knn.bin"} - >>> cs = CellSearch(model_path="/opt/data/model", filenames=filesnames) + >>> cs = CellSearchKNN(model_path="/opt/data/model") """ super().__init__( model_path=model_path, use_gpu=use_gpu, - parameters=parameters, - filenames=filenames, - residual=residual, ) - if filenames is None: - filenames = {} - self.knn = None self.safelist = None self.blocklist = None