-
Notifications
You must be signed in to change notification settings - Fork 31
/
mice.py
71 lines (59 loc) · 2.16 KB
/
mice.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
# Explicitly enable the experimental IterativeImputer (added in scikit-learn 0.21); this import is needed only for its side effect
from sklearn.experimental import enable_iterative_imputer # noqa F401
from sklearn.impute import IterativeImputer
from sklearn.linear_model import BayesianRidge
from ..datasets.dataset import Dataset
from ..datasets.variables import Variables
from .sk_learn_imputer import SKLearnImputer
class MICE(SKLearnImputer):
    """Multiple Imputation by Chained Equations (MICE) imputer.

    Wraps scikit-learn's ``IterativeImputer`` configured with a
    ``BayesianRidge`` estimator and ``sample_posterior=True``, so each
    ``transform`` call samples imputations from the estimator's predictive
    distribution instead of returning a point estimate. ``impute`` repeats
    that draw ``test_sample_count`` times and can average the draws.
    """

    def __init__(
        self,
        model_id: str,
        variables: Variables,
        save_dir,
        max_iter=10,
        initial_strategy="mean",
        random_seed=0,
        test_sample_count=50,
    ):
        """Build the underlying ``IterativeImputer`` and register it with the base class.

        Args:
            model_id: Forwarded to ``SKLearnImputer``.
            variables: Forwarded to ``SKLearnImputer``.
            save_dir: Forwarded to ``SKLearnImputer`` (presumably where the
                model is persisted — verify against the base class).
            max_iter: Maximum number of imputation rounds for ``IterativeImputer``.
            initial_strategy: ``IterativeImputer`` strategy used to initialise
                missing values before the first round (default ``"mean"``).
            random_seed: Passed as ``random_state`` to ``IterativeImputer``.
            test_sample_count: Number of posterior draws taken by ``impute``.
        """
        imputer = IterativeImputer(
            max_iter=max_iter,
            initial_strategy=initial_strategy,
            random_state=random_seed,
            estimator=BayesianRidge(),
            # Sample from the predictive posterior so repeated transforms
            # yield distinct draws (needed for multiple imputation).
            sample_posterior=True,
        )
        super().__init__(model_id, variables, save_dir, imputer)
        self._sample_count = test_sample_count

    @classmethod
    def name(cls) -> str:
        """Return the registry name of this imputer: ``"mice"``."""
        return "mice"

    def run_train(
        self,
        dataset: Dataset,
        train_config_dict: Optional[Dict[str, Any]] = None,
        report_progress_callback: Optional[Callable[[str, int, int], None]] = None,
    ) -> None:
        """Fit the iterative imputer on the dataset's training split.

        ``train_config_dict`` and ``report_progress_callback`` are accepted for
        interface compatibility and are not used.
        """
        data, mask = dataset.train_data_and_mask
        # fill_mask is defined in SKLearnImputer; presumably it marks the
        # unobserved entries (per ``mask``) as missing for IterativeImputer — verify.
        data = self.fill_mask(data, mask)
        self._imputer.fit(data)

    def impute(
        self,
        data: np.ndarray,
        mask: np.ndarray,
        impute_config_dict: Optional[Dict[str, int]] = None,
        vamp_prior_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        average: bool = True,
    ) -> np.ndarray:
        """Impute missing entries by repeatedly sampling the fitted posterior.

        Args:
            data: Input matrix; after ``fill_mask`` it must be 2-D,
                shape ``(row_count, feature_count)``.
            mask: Passed with ``data`` to ``fill_mask`` (its convention is
                defined by ``SKLearnImputer.fill_mask`` — verify).
            impute_config_dict: Unused; accepted for interface compatibility.
            vamp_prior_data: Unused; accepted for interface compatibility.
            average: If True, return the mean of the draws; otherwise return
                every draw stacked along a new leading axis.

        Returns:
            A float64 array of shape ``(row_count, feature_count)`` when
            ``average`` is True, else ``(test_sample_count, row_count,
            feature_count)``.

        Raises:
            ValueError: If ``average`` is True and fewer than one sample is
                configured (the mean of zero draws is undefined).
        """
        filled = self.fill_mask(data, mask)
        row_count, feature_count = filled.shape
        sample_count = self._sample_count

        if average:
            if sample_count < 1:
                raise ValueError(
                    f"Cannot average {sample_count} posterior samples; "
                    "test_sample_count must be at least 1."
                )
            # Accumulate a running sum rather than materialising every draw:
            # peak memory stays O(rows * features) instead of
            # O(sample_count * rows * features). Summation order matches
            # ndarray.mean(axis=0) over the stacked draws.
            total = np.zeros((row_count, feature_count))
            for _ in range(sample_count):
                total += self._imputer.transform(filled)
            return total / sample_count

        # Caller wants every draw: stack them along a leading sample axis.
        samples = np.zeros((sample_count, row_count, feature_count))
        for sample_idx in range(sample_count):
            samples[sample_idx, :, :] = self._imputer.transform(filled)
        return samples