diff --git a/tests/test_model.py b/tests/test_model.py index 2babc1f..eace89e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -6,21 +6,17 @@ def test_VAEmodel(): n_latent = 5 adata = synthetic_iid() - VAEModel.setup_anndata(adata, batch_key="batch", labels_key="labels", adata_obs="labels") + adata = VAEModel.use_obs(adata, labels_key="labels_key", adata_obs=["batch", "labels"]) + adata_manager, adata = VAEModel.setup_anndata(adata, labels_key="labels_key", batch_key="batch") print("Model Initiated..") + model = VAEModel(adata, n_latent=n_latent, n_layers=10) model.train(max_epochs=5) - model.get_elbo() - model.get_latent_representation() - model.get_marginal_ll(n_mc_samples=5) - model.get_reconstruction_error() - print("ELBO:", model.get_elbo()) - print("Latent Representation:", model.get_latent_representation) # tests __repr__ print(model) print("\nSuccess!..") -test_VAEmodel() +adata = test_VAEmodel()