diff --git a/src/mud/funs.py b/src/mud/funs.py index 8cd657d..982b6b5 100644 --- a/src/mud/funs.py +++ b/src/mud/funs.py @@ -266,7 +266,7 @@ def iterate(A, b, y, initial_mean, initial_cov, return chain -def mud_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None): +def mud_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None, split=None): """ Wrapper around mud problem, takes in raw qoi + synthetic data and performs WME transformation, instantiates solver object @@ -285,10 +285,22 @@ def mud_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None): elif num_obs > dim_output: raise ValueError("num_obs must be <= dim(qoi)") - # this is our data processing step. - data = qoi_true[0:num_obs] + np.random.randn(num_obs) * sd - q = wme(qoi[:, 0:num_obs], data, sd).reshape(-1, 1) - + # TODO: handle empty sd -> take it from the data. + # TODO: swap for data + leave noise generation separate. no randomness in method. + noise = np.random.randn(num_obs) * sd + if split is None: + # this is our data processing step. + data = qoi_true[0:num_obs] + noise + q = wme(qoi[:, 0:num_obs], data, sd).reshape(-1, 1) + else: + q = [] + for qoi_indices in split: + _q = qoi_indices[qoi_indices < num_obs] + _qoi = qoi[:, _q] + _data = np.array(qoi_true)[_q] + noise[_q] + _newqoi = wme(_qoi, _data, sd) + q.append(_newqoi) + q = np.vstack(q).T # this implements density-based solutions, mud point method d = DensityProblem(lam, q, domain) return d @@ -322,5 +334,6 @@ def map_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None): b.set_likelihood(likelihood) return b + if __name__ == "__main__": run()