diff --git a/sccoda/model/scCODA_model.py b/sccoda/model/scCODA_model.py index bddb908..1f1be39 100644 --- a/sccoda/model/scCODA_model.py +++ b/sccoda/model/scCODA_model.py @@ -615,6 +615,7 @@ class scCODAModel(CompositionalModel): def __init__( self, reference_cell_type: int, + seed: Optional[int] = None, *args, **kwargs): """ @@ -708,6 +709,9 @@ def target_log_prob_fn(*argsl): self.target_log_prob_fn = target_log_prob_fn # MCMC starting values + if seed is not None: + tf.random.set_seed(seed) + self.init_params = [ tf.ones(sigma_size, name="init_sigma_d", dtype=dtype), tf.random.normal(beta_nobl_size, 0, 1, name='init_b_offset', dtype=dtype), diff --git a/sccoda/util/comp_ana.py b/sccoda/util/comp_ana.py index 02aa20a..e9ad1d9 100644 --- a/sccoda/util/comp_ana.py +++ b/sccoda/util/comp_ana.py @@ -36,6 +36,7 @@ def __new__( formula: str, reference_cell_type: Union[str, int] = "automatic", automatic_reference_absence_threshold: float = 0.05, + seed: Optional[int] = None, ) -> dm.scCODAModel: """ Builds count and covariate matrix, returns a CompositionalModel object @@ -100,6 +101,7 @@ def __new__( covariate_names=covariate_names, reference_cell_type=ref_index, formula=formula, + seed=seed, ) # Column name as reference cell type @@ -112,6 +114,7 @@ def __new__( covariate_names=covariate_names, reference_cell_type=num_index, formula=formula, + seed=seed, ) # Numeric reference cell type @@ -123,6 +126,7 @@ def __new__( covariate_names=covariate_names, reference_cell_type=reference_cell_type, formula=formula, + seed=seed, ) # None of the above: Throw error