Skip to content

Commit

Permalink
Merge pull request #57 from ihmeuw-msca/feature/cat-covmodel
Browse files Browse the repository at this point in the history
Add categorical covariate model
  • Loading branch information
zhengp0 authored Oct 7, 2024
2 parents 0942ac9 + 01fc880 commit 2ee00fa
Show file tree
Hide file tree
Showing 6 changed files with 637 additions and 20 deletions.
12 changes: 11 additions & 1 deletion src/mrtool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,25 @@
"""

from .core import utils
from .core.cov_model import CovModel, LinearCovModel, LogCovModel
from .core.cov_model import (
CatCovModel,
CovModel,
LinearCatCovModel,
LinearCovModel,
LogCatCovModel,
LogCovModel,
)
from .core.data import MRData
from .core.model import MRBRT, MRBeRT
from .cov_selection.covfinder import CovFinder

__all__ = [
"MRData",
"CatCovModel",
"CovModel",
"LinearCatCovModel",
"LinearCovModel",
"LogCatCovModel",
"LogCovModel",
"MRBRT",
"MRBeRT",
Expand Down
241 changes: 237 additions & 4 deletions src/mrtool/core/cov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
Covariates model for `mrtool`.
"""

import itertools
import warnings
from typing import Callable

import numpy as np
import pandas as pd
import xspline
from numpy.typing import NDArray

Expand Down Expand Up @@ -451,7 +456,7 @@ def create_spline(
Returns
-------
xspline.XSpline
XSpline
The spline object.
"""
Expand Down Expand Up @@ -535,7 +540,7 @@ def create_design_mat(self, data) -> tuple[NDArray, NDArray]:
Returns
-------
tuple[numpy.ndarray, numpy.ndarray]
tuple[NDArray, NDArray]
Return the design matrix for linear cov or spline.
"""
Expand Down Expand Up @@ -832,7 +837,7 @@ def create_z_mat(self, data):
Returns
-------
numpy.ndarray
NDArray
Design matrix for random effects.
"""
Expand Down Expand Up @@ -884,7 +889,7 @@ def create_z_mat(self, data):
Returns
-------
numpy.ndarray
NDArray
Design matrix for random effects.
"""
Expand Down Expand Up @@ -929,3 +934,231 @@ def num_constraints(self):
@property
def num_z_vars(self):
return int(self.use_re)


class CatCovModel(CovModel):
"""Categorical covariate model."""

def __init__(
self,
alt_cov,
name=None,
ref_cov=None,
ref_cat=None,
use_re=False,
use_re_intercept=True,
prior_order=None,
prior_beta_gaussian=None,
prior_beta_uniform=None,
prior_beta_laplace=None,
prior_gamma_gaussian=None,
prior_gamma_uniform=None,
prior_gamma_laplace=None,
) -> None:
self.ref_cat = ref_cat
self.use_re_intercept = use_re_intercept
if prior_order is not None:
prior_order_raw, prior_order = prior_order, []
for prior in prior_order_raw:
prior_order.extend(list(zip(prior, prior[1:])))
prior_order = list(set(prior_order))
prior_order.sort()
self.prior_order = prior_order
super().__init__(
alt_cov=alt_cov,
name=name,
ref_cov=ref_cov,
use_re=use_re,
prior_beta_gaussian=prior_beta_gaussian,
prior_beta_uniform=prior_beta_uniform,
prior_beta_laplace=prior_beta_laplace,
prior_gamma_gaussian=prior_gamma_gaussian,
prior_gamma_uniform=prior_gamma_uniform,
prior_gamma_laplace=prior_gamma_laplace,
)

if len(self.alt_cov) != 1:
raise ValueError("alt_cov should be a single column.")
if len(self.ref_cov) > 1:
raise ValueError("ref_cov should be nothing or a single column.")
if len(self.ref_cov) == 1 and self.ref_cat is None:
warnings.warn(
"ref_cat is not provided for a comparison covmodel, it will be "
"inferenced as the most common categories when attaching data."
)
if len(self.ref_cov) == 0 and self.ref_cat is not None:
raise ValueError(
"Cannot set ref_cat when this is not a comparison model."
)

self.cats: pd.Series

def attach_data(self, data: MRData) -> None:
"""Attach data and parse the categories. Number of variables will be
determined here and priors will be processed and if ref_cov is not set
before, and this is a comparison model, ref_cov will be inferred as the
most common category.
"""
alt_cov = data.get_covs(self.alt_cov)
ref_cov = data.get_covs(self.ref_cov)
unique_cats, counts = np.unique(
np.hstack([alt_cov, ref_cov]), return_counts=True
)
self.cats = pd.Series(unique_cats, name="cats")
self._process_priors()

if len(self.ref_cov) == 1:
if self.ref_cat is None:
self.ref_cat = unique_cats[counts.argmax()]
if self.ref_cat not in unique_cats:
raise ValueError(
f"ref_cat {self.ref_cat} is not in the categories."
)

if self.ref_cat is not None:
ref_index = dict(zip(self.cats, self.cats.index))[self.ref_cat]
ref_beta_uprior = self.prior_beta_uniform[:, ref_index]
if not (
np.isinf(ref_beta_uprior).all()
or np.allclose(ref_beta_uprior, 0.0)
):
warnings.warn(
f"Reset ref_cat beta uniform prior from {ref_beta_uprior} to (0, 0)"
)
self.prior_beta_uniform[:, ref_index] = 0.0
if self.use_re and (not self.use_re_intercept):
ref_gamma_uprior = self.prior_gamma_uniform[:, ref_index]
if not (
np.isinf(ref_gamma_uprior[1]).all()
or np.allclose(ref_gamma_uprior, 0.0)
):
warnings.warn(
f"Reset ref_cat gamma uniform prior from {ref_gamma_uprior} to (0, 0)"
)
self.prior_gamma_uniform[:, ref_index] = 0.0

if self.prior_order is not None:
for cat in set(
list(itertools.chain.from_iterable(self.prior_order))
):
if cat not in unique_cats:
raise ValueError(
f"Order prior category {cat} is not in the categories."
)

def has_data(self) -> bool:
"""Return if the data has been attached and categories has been parsed."""
return hasattr(self, "cats")

def encode(self, x: NDArray) -> NDArray:
"""Encode the provided categories into dummy variables."""
col = pd.merge(
pd.Series(x, name="cats"), self.cats.reset_index(), how="left"
)["index"]
if np.isnan(col).any():
raise ValueError("Categories not found")
mat = np.zeros((len(x), self.num_x_vars))
mat[range(len(x)), col] = 1.0
return mat

def create_design_mat(self, data: MRData) -> tuple[NDArray, NDArray]:
"""Create design matrix for alternative and reference categories."""
alt_cov = data.get_covs(self.alt_cov).ravel()
ref_cov = data.get_covs(self.ref_cov).ravel()

alt_mat = self.encode(alt_cov)
if ref_cov.size == 0:
ref_mat = np.empty((len(alt_cov), 0))
else:
ref_mat = self.encode(ref_cov)
return alt_mat, ref_mat

def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
c_mat, c_val = super().create_constraint_mat()
if not self.prior_order:
return c_mat, c_val

c_val = np.hstack(
[
c_val,
np.repeat(
np.array([[-np.inf], [0.0]]), len(self.prior_order), axis=1
),
]
)

mats = []
for alt_cat, ref_cat in self.prior_order:
alt_mat = self.encode([alt_cat])
ref_mat = self.encode([ref_cat])
mats.append(alt_mat - ref_mat)
c_mat = np.vstack([c_mat] + mats)
return c_mat, c_val

@property
def num_x_vars(self) -> int:
"""Number of the fixed effects. Returns 0 if data is not attached
otherwise it will return the number of categories.
"""
if not hasattr(self, "cats"):
return 0
return len(self.cats)

@property
def num_z_vars(self) -> int:
"""Number of the random effects. When use_re_intercept is set to True,
it will use a single intercept random effect. Otherwise, it will use
each category will have its own random effect.
"""
if not self.use_re:
return 0
if self.use_re_intercept:
return 1
return self.num_x_vars

@property
def num_constraints(self) -> int:
num = super().num_constraints
if self.prior_order:
num += len(self.prior_order)
return num

def create_z_mat(self, data: MRData) -> NDArray:
if not self.use_re:
return np.empty((data.num_obs, 0))

if self.use_re_intercept:
alt_mat = np.ones((data.num_obs, 1))
ref_mat = np.empty((data.num_obs, 0))
else:
alt_mat, ref_mat = self.create_design_mat(data)

z_mat = alt_mat if ref_mat.size == 0 else alt_mat - ref_mat
return z_mat


class LinearCatCovModel(CatCovModel):
def create_x_fun(self, data: MRData) -> Callable:
alt_mat, ref_mat = self.create_design_mat(data)
return utils.mat_to_fun(alt_mat, ref_mat=ref_mat)


class LogCatCovModel(CatCovModel):
def attach_data(self, data: MRData) -> None:
super().attach_data(data)

# add positive constraints to each category
# Currently we hard-code the offset value
offset = 1e-6
shift = 0.0 if self.ref_cat is None else 1.0
lb = -shift + offset

self.prior_beta_uniform = np.maximum(lb, self.prior_beta_uniform)

def create_x_fun(self, data: MRData) -> Callable:
alt_mat, ref_mat = self.create_design_mat(data)
add_one = self.ref_cat is not None
return utils.mat_to_log_fun(alt_mat, ref_mat=ref_mat, add_one=add_one)
16 changes: 9 additions & 7 deletions src/mrtool/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _check_attr_type(self):
assert isinstance(self.covs, dict)
for cov in self.covs.values():
assert isinstance(cov, np.ndarray)
assert is_numeric_array(cov)
# assert is_numeric_array(cov)

def _get_cov_scales(self):
"""Compute the covariate scale."""
Expand All @@ -103,6 +103,7 @@ def _get_cov_scales(self):
self.cov_scales = {
cov_name: np.max(np.abs(cov))
for cov_name, cov in self.covs.items()
if is_numeric_array(cov)
}
zero_covs = [
cov_name
Expand Down Expand Up @@ -159,12 +160,13 @@ def _remove_nan_in_covs(self):
if not self.is_empty():
index = np.full(self.num_obs, False)
for cov_name, cov in self.covs.items():
cov_index = np.isnan(cov)
if cov_index.any():
warnings.warn(
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
)
index = index | cov_index
if is_numeric_array(cov):
cov_index = np.isnan(cov)
if cov_index.any():
warnings.warn(
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
)
index = index | cov_index
self._remove_data(index)

def _remove_data(self, index: NDArray):
Expand Down
18 changes: 10 additions & 8 deletions src/mrtool/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def __init__(
self.cov_names.extend(cov_model.covs)
self.num_covs = len(self.cov_names)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def _infer_shape(self) -> None:
# add random effects
if not any([cov_model.use_re for cov_model in self.cov_models]):
self.cov_models[0].use_re = True
Expand Down Expand Up @@ -83,14 +92,6 @@ def __init__(
[cov_model.num_regularizations for cov_model in self.cov_models]
)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def attach_data(self, data=None):
"""Attach data to cov_model."""
data = self.data if data is None else data
Expand Down Expand Up @@ -239,6 +240,7 @@ def fit_model(self, **fit_options):
"""
if not all([cov_model.has_data() for cov_model in self.cov_models]):
self.attach_data()
self._infer_shape()

# dimensions
n = self.data.study_sizes
Expand Down
2 changes: 2 additions & 0 deletions src/mrtool/cov_selection/covfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def create_model(
model = MRBRT(
self.data, cov_models=cov_models, inlier_pct=self.inlier_pct
)
model.attach_data()
model._infer_shape()
return model

def fit_gaussian_model(self, covs: list[str]) -> MRBRT:
Expand Down
Loading

0 comments on commit 2ee00fa

Please sign in to comment.