Skip to content

Commit

Permalink
add order prior parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jun 26, 2024
1 parent 295ea52 commit 935c5e7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/mrtool/core/cov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Covariates model for `mrtool`.
"""

import itertools
import warnings

import numpy as np
Expand Down Expand Up @@ -947,7 +948,8 @@ def __init__(
ref_cov=None,
ref_cat=None,
use_re=False,
use_re_intercept=False,
use_re_intercept=True,
prior_order=None,
prior_beta_gaussian=None,
prior_beta_uniform=None,
prior_beta_laplace=None,
Expand All @@ -957,6 +959,13 @@ def __init__(
) -> None:
self.ref_cat = ref_cat
self.use_re_intercept = use_re_intercept
if prior_order is not None:
prior_order_raw, prior_order = prior_order, []
for prior in prior_order_raw:
prior_order.extend(list(zip(prior, prior[1:])))
prior_order = list(set(prior_order))
prior_order.sort()
self.prior_order = prior_order
super().__init__(
alt_cov=alt_cov,
name=name,
Expand Down Expand Up @@ -1009,6 +1018,17 @@ def attach_data(self, data: MRData) -> None:
f"ref_cat {self.ref_cat} is not in the categories."
)

# TODO: set the uniform prior for ref_cat to zero

if self.prior_order is not None:
for cat in set(
list(itertools.chain.from_iterable(self.prior_order))
):
if cat not in unique_cats:
raise ValueError(
f"Order prior category {cat} is not in the categories."
)

def has_data(self) -> bool:
"""Return if the data has been attached and categories has been parsed."""
return hasattr(self, "cats")
Expand Down
28 changes: 28 additions & 0 deletions tests/test_cat_covmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,31 @@ def test_create_design_mat(data):
]
),
)


def test_order_prior(data):
covmodel = CatCovModel(
alt_cov="alt_cat",
ref_cov="ref_cat",
ref_cat="A",
prior_order=[["A", "B"], ["B", "C"]],
)
covmodel.attach_data(data)
assert covmodel.prior_order == [("A", "B"), ("B", "C")]

covmodel = CatCovModel(
alt_cov="alt_cat",
ref_cov="ref_cat",
ref_cat="A",
prior_order=[["A", "B", "C"], ["B", "C"]],
)
assert covmodel.prior_order == [("A", "B"), ("B", "C")]

with pytest.raises(ValueError):
covmodel = CatCovModel(
alt_cov="alt_cat",
ref_cov="ref_cat",
ref_cat="A",
prior_order=[["A", "B"], ["B", "C", "E"]],
)
covmodel.attach_data(data)

0 comments on commit 935c5e7

Please sign in to comment.