From 935c5e78151e66f20e962859e6a2c8fe557fc374 Mon Sep 17 00:00:00 2001 From: zhengp0 Date: Wed, 26 Jun 2024 16:23:54 -0700 Subject: [PATCH] add order prior parsing --- src/mrtool/core/cov_model.py | 22 +++++++++++++++++++++- tests/test_cat_covmodel.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/mrtool/core/cov_model.py b/src/mrtool/core/cov_model.py index 5e293b7..d9b130c 100644 --- a/src/mrtool/core/cov_model.py +++ b/src/mrtool/core/cov_model.py @@ -6,6 +6,7 @@ Covariates model for `mrtool`. """ +import itertools import warnings import numpy as np @@ -947,7 +948,8 @@ def __init__( ref_cov=None, ref_cat=None, use_re=False, - use_re_intercept=False, + use_re_intercept=True, + prior_order=None, prior_beta_gaussian=None, prior_beta_uniform=None, prior_beta_laplace=None, @@ -957,6 +959,13 @@ def __init__( ) -> None: self.ref_cat = ref_cat self.use_re_intercept = use_re_intercept + if prior_order is not None: + prior_order_raw, prior_order = prior_order, [] + for prior in prior_order_raw: + prior_order.extend(list(zip(prior, prior[1:]))) + prior_order = list(set(prior_order)) + prior_order.sort() + self.prior_order = prior_order super().__init__( alt_cov=alt_cov, name=name, @@ -1009,6 +1018,17 @@ def attach_data(self, data: MRData) -> None: f"ref_cat {self.ref_cat} is not in the categories." ) + # TODO: set the uniform prior for ref_cat to zero + + if self.prior_order is not None: + for cat in set( + list(itertools.chain.from_iterable(self.prior_order)) + ): + if cat not in unique_cats: + raise ValueError( + f"Order prior category {cat} is not in the categories." + ) + def has_data(self) -> bool: """Return if the data has been attached and categories has been parsed.""" return hasattr(self, "cats") diff --git a/tests/test_cat_covmodel.py b/tests/test_cat_covmodel.py index 3e771fe..d59ee2d 100644 --- a/tests/test_cat_covmodel.py +++ b/tests/test_cat_covmodel.py @@ -139,3 +139,31 @@ def test_create_design_mat(data): ] ), ) + + +def test_order_prior(data): + covmodel = CatCovModel( + alt_cov="alt_cat", + ref_cov="ref_cat", + ref_cat="A", + prior_order=[["A", "B"], ["B", "C"]], + ) + covmodel.attach_data(data) + assert covmodel.prior_order == [("A", "B"), ("B", "C")] + + covmodel = CatCovModel( + alt_cov="alt_cat", + ref_cov="ref_cat", + ref_cat="A", + prior_order=[["A", "B", "C"], ["B", "C"]], + ) + assert covmodel.prior_order == [("A", "B"), ("B", "C")] + + with pytest.raises(ValueError): + covmodel = CatCovModel( + alt_cov="alt_cat", + ref_cov="ref_cat", + ref_cat="A", + prior_order=[["A", "B"], ["B", "C", "E"]], + ) + covmodel.attach_data(data)