From 4af3fabdb83294a508f3f1cbd6f21a402b4f663d Mon Sep 17 00:00:00 2001 From: August <30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 12:57:48 -0700 Subject: [PATCH 1/5] added rtol, atol, to interface --- pyciemss/interfaces.py | 41 ++++++++++++++++++++++++++++++++++++---- pyciemss/ouu/ouu.py | 6 +++++- tests/test_interfaces.py | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 7ade53a6..93c83cfc 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -57,6 +57,8 @@ def ensemble_sample( time_unit: Optional[str] = None, alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS, stacking_order: str = "timepoints", + rtol: float = 1e-7, + atol: float = 1e-9, ) -> Dict[str, Any]: """ Load a collection of models from files, compile them into an ensemble probabilistic program, @@ -107,6 +109,10 @@ def ensemble_sample( stacking_order: Optional[str] - The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state. - Options: "timepoints" or "quantiles" + rtol: float + - The relative tolerance for the solver. + atol: float + - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -138,7 +144,7 @@ def ensemble_sample( raise ValueError("num_samples must be a positive integer") def wrapped_model(): - with TorchDiffEq(method=solver_method, options=solver_options): + with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): solution = model( torch.as_tensor(start_time), torch.as_tensor(end_time), @@ -195,6 +201,8 @@ def ensemble_calibrate( num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, + rtol: float = 1e-7, + atol: float = 1e-9 ) -> Dict[str, Any]: """ Infer parameters for an ensemble of DynamicalSystem models conditional on data. @@ -254,6 +262,10 @@ def ensemble_calibrate( - This is called at the beginning of each iteration. - By default, this is a no-op. - This can be used to implement custom progress bars. + rtol: float + - The relative tolerance for the solver. + atol: float + - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -314,7 +326,7 @@ def autoguide(model): def wrapped_model(): obs = condition(data=_data)(_noise_model) - with TorchDiffEq(method=solver_method, options=solver_options): + with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): solution = model( torch.as_tensor(start_time), torch.as_tensor(data_timepoints[-1]), @@ -366,6 +378,9 @@ def sample( Dict[str, Intervention], ] = {}, alpha: float = 0.95, + rtol: float = 1e-7, + atol: float = 1e-9, + ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and sample from it. @@ -427,6 +442,10 @@ def sample( :func:`~chirho.interventional.ops.intervene`, including functions. alpha: float - Risk level for alpha-superquantile outputs in the results dictionary. + rtol: float + - The relative tolerance for the solver. + atol: float + - The absolute tolerance for the solver. 
Returns: result: Dict[str, Any] @@ -492,7 +511,7 @@ def sample( def wrapped_model(): with ParameterInterventionTracer(): - with TorchDiffEq(method=solver_method, options=solver_options): + with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): with contextlib.ExitStack() as stack: for handler in intervention_handlers: stack.enter_context(handler) @@ -576,6 +595,8 @@ def calibrate( num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, + rtol: float = 1e-7, + atol: float = 1e-9 ) -> Dict[str, Any]: """ Infer parameters for a DynamicalSystem model conditional on data. @@ -655,6 +676,10 @@ def calibrate( - This is called at the beginning of each iteration. - By default, this is a no-op. - This can be used to implement custom progress bars. + - rtol: float + - The relative tolerance for the solver. + - atol: float + - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -740,7 +765,7 @@ def wrapped_model(): obs = condition(data=_data)(_noise_model) with StaticBatchObservation(data_timepoints, observation=obs): - with TorchDiffEq(method=solver_method, options=solver_options): + with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): with contextlib.ExitStack() as stack: for handler in intervention_handlers: stack.enter_context(handler) @@ -793,6 +818,8 @@ def optimize( verbose: bool = False, roundup_decimal: int = 4, progress_hook: Callable[[torch.Tensor], None] = lambda x: None, + rtol: float = 1e-7, + atol: float = 1e-9 ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based @@ -863,6 +890,10 @@ def optimize( - A callback function that takes in the current parameter vector as a tensor. If the function returns StopIteration, the minimization will terminate. - This can be used to implement custom progress bars and/or early stopping criteria. + rtol: float + - The relative tolerance for the solver. + atol: float + - The absolute tolerance for the solver. 
Returns: result: Dict[str, Any] @@ -906,6 +937,8 @@ def optimize( solver_options=solver_options, u_bounds=bounds_np, risk_bound=risk_bound, + rtol=rtol, + atol=atol ) # Run one sample to estimate model evaluation time diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index 09f1cd50..7c740ed0 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -78,6 +78,8 @@ def __init__( solver_options: Dict[str, Any] = {}, u_bounds: np.ndarray = np.atleast_2d([[0], [1]]), risk_bound: List[float] = [0.0], + rtol: float = 1e-7, + atol: float = 1e-9 ): self.model = model self.interventions = interventions @@ -97,6 +99,8 @@ def __init__( self.u_bounds = u_bounds self.risk_bound = risk_bound # used for defining penalty warnings.simplefilter("always", UserWarning) + self.rtol = rtol + self.atol = atol def __call__(self, x): if np.any(x - self.u_bounds[0, :] < 0.0) or np.any( @@ -144,7 +148,7 @@ def propagate_uncertainty(self, x): def wrapped_model(): with ParameterInterventionTracer(): with TorchDiffEq( - method=self.solver_method, options=self.solver_options + rtol=self.rtol, atol=self.atol, method=self.solver_method, options=self.solver_options ): with contextlib.ExitStack() as stack: for handler in static_parameter_intervention_handlers: diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 1f653b06..8fbe2ea2 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -87,6 +87,8 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size): "num_iterations": 2, } +RTOL = [1e-6, 1e-5, 1e-4] +ATOL = [1e-8, 1e-7, 1e-6] @pytest.mark.parametrize("sample_method", SAMPLE_METHODS) @pytest.mark.parametrize("model", MODELS) @@ -124,6 +126,44 @@ def test_sample_no_interventions( assert "total_state" in result1.keys() +@pytest.mark.parametrize("sample_method", SAMPLE_METHODS) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +@pytest.mark.parametrize("rtol", RTOL) +@pytest.mark.parametrize("atol", ATOL) +def test_sample_with_tolerance( + sample_method, model, start_time, end_time, logging_step_size, num_samples, rtol, atol +): + model_url = model.url + + with pyro.poutine.seed(rng_seed=0): + result1 = sample_method( + model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + )["unprocessed_result"] + with pyro.poutine.seed(rng_seed=0): + result2 = sample_method( + model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + )["unprocessed_result"] + + result3 = sample_method( + model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + )["unprocessed_result"] + + for result in [result1, result2, result3]: + assert isinstance(result, dict) + check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) + + check_states_match(result1, result2) + if model.has_distributional_parameters: + check_states_match_in_all_but_values(result1, result3) + + if sample_method.__name__ == "dummy_ensemble_sample": + assert "total_state" in result1.keys() + + @pytest.mark.parametrize("sample_method", SAMPLE_METHODS) @pytest.mark.parametrize("model_url", MODEL_URLS) @pytest.mark.parametrize("start_time", START_TIMES) From 1680f7024442e8260e7c5449f5adb2ded147fb16 Mon Sep 17 00:00:00 2001 From: August 
<30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:02:05 -0700 Subject: [PATCH 2/5] added tests for optimize and calibration --- tests/test_interfaces.py | 57 +++++++++++----------------------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 8fbe2ea2..b29de8cf 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -87,43 +87,8 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size): "num_iterations": 2, } -RTOL = [1e-6, 1e-5, 1e-4] -ATOL = [1e-8, 1e-7, 1e-6] - -@pytest.mark.parametrize("sample_method", SAMPLE_METHODS) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("start_time", START_TIMES) -@pytest.mark.parametrize("end_time", END_TIMES) -@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) -@pytest.mark.parametrize("num_samples", NUM_SAMPLES) -def test_sample_no_interventions( - sample_method, model, start_time, end_time, logging_step_size, num_samples -): - model_url = model.url - - with pyro.poutine.seed(rng_seed=0): - result1 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - with pyro.poutine.seed(rng_seed=0): - result2 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - - result3 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - - for result in [result1, result2, result3]: - assert isinstance(result, dict) - check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) - - check_states_match(result1, result2) - if model.has_distributional_parameters: - check_states_match_in_all_but_values(result1, result3) - - if sample_method.__name__ == "dummy_ensemble_sample": - assert "total_state" in result1.keys() +RTOL = [1e-6, 1e-4] +ATOL = [1e-8, 1e-6] @pytest.mark.parametrize("sample_method", SAMPLE_METHODS) @@ -134,7 +99,7 @@ def test_sample_no_interventions( @pytest.mark.parametrize("num_samples", NUM_SAMPLES) @pytest.mark.parametrize("rtol", RTOL) @pytest.mark.parametrize("atol", ATOL) -def test_sample_with_tolerance( +def test_sample_no_interventions( sample_method, model, start_time, end_time, logging_step_size, num_samples, rtol, atol ): model_url = model.url @@ -404,8 +369,10 @@ def test_calibrate_no_kwargs( @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("rtol", RTOL) +@pytest.mark.parametrize("atol", ATOL) def test_calibrate_deterministic( - model_fixture, start_time, end_time, logging_step_size + model_fixture, start_time, end_time, logging_step_size, rtol, atol ): model_url = model_fixture.url ( @@ -421,6 +388,8 @@ def test_calibrate_deterministic( "data_mapping": model_fixture.data_mapping, "start_time": start_time, "deterministic_learnable_parameters": deterministic_learnable_parameters, + "rtol": rtol, + "atol": atol, **CALIBRATE_KWARGS, } @@ -440,7 +409,7 @@ def test_calibrate_deterministic( assert torch.allclose(param_value, param_sample_2[param_name]) result = sample( - *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters + *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, rtol=rtol, atol=atol )["unprocessed_result"] check_result_sizes(result, start_time, end_time, logging_step_size, 1) @@ 
-603,7 +572,9 @@ def test_output_format( @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) @pytest.mark.parametrize("num_samples", NUM_SAMPLES) -def test_optimize(model_fixture, start_time, end_time, num_samples): +@pytest.mark.parametrize("rtol", RTOL) +@pytest.mark.parametrize("atol", ATOL) +def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol): logging_step_size = 1.0 model_url = model_fixture.url @@ -627,6 +598,8 @@ def __call__(self, x): "maxiter": 1, "maxfeval": 2, "progress_hook": progress_hook, + "rtol": rtol, + "atol": atol } bounds_interventions = optimize_kwargs["bounds_interventions"] opt_result = optimize( @@ -665,6 +638,8 @@ def __call__(self, x): static_parameter_interventions=opt_intervention, solver_method=optimize_kwargs["solver_method"], solver_options=optimize_kwargs["solver_options"], + rtol=rtol, + atol=atol )["unprocessed_result"] intervened_result_subset = { From d076a2ec04dc5609ae3adadfe21145b45644afcc Mon Sep 17 00:00:00 2001 From: August <30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:39:51 -0700 Subject: [PATCH 3/5] passes linter --- pyciemss/interfaces.py | 43 ++++++++++++++++++++--------------- pyciemss/ouu/ouu.py | 7 ++++-- tests/test_interfaces.py | 49 ++++++++++++++++++++++++++++++++-------- 3 files changed, 69 insertions(+), 30 deletions(-) diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 93c83cfc..e3160ab0 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -58,7 +58,7 @@ def ensemble_sample( alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS, stacking_order: str = "timepoints", rtol: float = 1e-7, - atol: float = 1e-9, + atol: float = 1e-9, ) -> Dict[str, Any]: """ Load a collection of models from files, compile them into an ensemble probabilistic program, @@ -144,7 +144,9 @@ def ensemble_sample( raise ValueError("num_samples must be a positive integer") def wrapped_model(): - with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): + with TorchDiffEq( + rtol=rtol, atol=atol, method=solver_method, options=solver_options + ): solution = model( torch.as_tensor(start_time), torch.as_tensor(end_time), @@ -202,7 +204,7 @@ def ensemble_calibrate( deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, rtol: float = 1e-7, - atol: float = 1e-9 + atol: float = 1e-9, ) -> Dict[str, Any]: """ Infer parameters for an ensemble of DynamicalSystem models conditional on data. @@ -263,9 +265,9 @@ def ensemble_calibrate( - By default, this is a no-op. - This can be used to implement custom progress bars. rtol: float - - The relative tolerance for the solver. + - The relative tolerance for the solver. atol: float - - The absolute tolerance for the solver. + - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -326,7 +328,9 @@ def autoguide(model): def wrapped_model(): obs = condition(data=_data)(_noise_model) - with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): + with TorchDiffEq( + rtol=rtol, atol=atol, method=solver_method, options=solver_options + ): solution = model( torch.as_tensor(start_time), torch.as_tensor(data_timepoints[-1]), @@ -379,8 +383,7 @@ def sample( ] = {}, alpha: float = 0.95, rtol: float = 1e-7, - atol: float = 1e-9, - + atol: float = 1e-9, ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and sample from it. 
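A note on the API at this stage of the series: through patches 1-3, rtol and atol are separate keyword arguments on each interface, forwarded to the TorchDiffEq handler; patch 4 below moves them into solver_options. A minimal, hypothetical call against this interim API (the model path and tolerance values are placeholders for illustration):

    from pyciemss.interfaces import sample

    result = sample(
        "model.json",    # placeholder AMR model path or URL
        100.0,           # end_time
        1.0,             # logging_step_size
        10,              # num_samples
        start_time=0.0,
        rtol=1e-6,       # relative tolerance (interface default: 1e-7)
        atol=1e-8,       # absolute tolerance (interface default: 1e-9)
    )["unprocessed_result"]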
@@ -511,7 +514,9 @@ def sample( def wrapped_model(): with ParameterInterventionTracer(): - with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): + with TorchDiffEq( + rtol=rtol, atol=atol, method=solver_method, options=solver_options + ): with contextlib.ExitStack() as stack: for handler in intervention_handlers: stack.enter_context(handler) @@ -596,7 +601,7 @@ def calibrate( deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, rtol: float = 1e-7, - atol: float = 1e-9 + atol: float = 1e-9, ) -> Dict[str, Any]: """ Infer parameters for a DynamicalSystem model conditional on data. @@ -677,7 +682,7 @@ def calibrate( - By default, this is a no-op. - This can be used to implement custom progress bars. - rtol: float - - The relative tolerance for the solver. + - The relative tolerance for the solver. - atol: float - The absolute tolerance for the solver. @@ -765,7 +770,9 @@ def wrapped_model(): obs = condition(data=_data)(_noise_model) with StaticBatchObservation(data_timepoints, observation=obs): - with TorchDiffEq(rtol=rtol, atol=atol, method=solver_method, options=solver_options): + with TorchDiffEq( + rtol=rtol, atol=atol, method=solver_method, options=solver_options + ): with contextlib.ExitStack() as stack: for handler in intervention_handlers: stack.enter_context(handler) @@ -818,8 +825,8 @@ def optimize( verbose: bool = False, roundup_decimal: int = 4, progress_hook: Callable[[torch.Tensor], None] = lambda x: None, - rtol: float = 1e-7, - atol: float = 1e-9 + rtol: float = 1e-7, + atol: float = 1e-9, ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based @@ -891,9 +898,9 @@ def optimize( If the function returns StopIteration, the minimization will terminate. - This can be used to implement custom progress bars and/or early stopping criteria. rtol: float - - The relative tolerance for the solver. + - The relative tolerance for the solver. atol: float - - The absolute tolerance for the solver. + - The absolute tolerance for the solver. 
Returns: result: Dict[str, Any] @@ -937,8 +944,8 @@ def optimize( solver_options=solver_options, u_bounds=bounds_np, risk_bound=risk_bound, - rtol=rtol, - atol=atol + rtol=rtol, + atol=atol, ) # Run one sample to estimate model evaluation time diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index 7c740ed0..363a86ae 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -78,7 +78,7 @@ def __init__( solver_options: Dict[str, Any] = {}, u_bounds: np.ndarray = np.atleast_2d([[0], [1]]), risk_bound: List[float] = [0.0], - rtol: float = 1e-7, + rtol: float = 1e-7, atol: float = 1e-9 ): self.model = model @@ -148,7 +148,10 @@ def propagate_uncertainty(self, x): def wrapped_model(): with ParameterInterventionTracer(): with TorchDiffEq( - rtol=self.rtol, atol=self.atol, method=self.solver_method, options=self.solver_options + rtol=self.rtol, + atol=self.atol, + method=self.solver_method, + options=self.solver_options, ): with contextlib.ExitStack() as stack: for handler in static_parameter_intervention_handlers: diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index b29de8cf..2811b5e6 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -100,21 +100,46 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size): @pytest.mark.parametrize("rtol", RTOL) @pytest.mark.parametrize("atol", ATOL) def test_sample_no_interventions( - sample_method, model, start_time, end_time, logging_step_size, num_samples, rtol, atol + sample_method, + model, + start_time, + end_time, + logging_step_size, + num_samples, + rtol, + atol, ): model_url = model.url with pyro.poutine.seed(rng_seed=0): result1 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + model_url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + rtol=rtol, + atol=atol, )["unprocessed_result"] with pyro.poutine.seed(rng_seed=0): result2 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + model_url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + rtol=rtol, + atol=atol, )["unprocessed_result"] result3 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time, rtol=rtol, atol=atol + model_url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + rtol=rtol, + atol=atol, )["unprocessed_result"] for result in [result1, result2, result3]: @@ -388,7 +413,7 @@ def test_calibrate_deterministic( "data_mapping": model_fixture.data_mapping, "start_time": start_time, "deterministic_learnable_parameters": deterministic_learnable_parameters, - "rtol": rtol, + "rtol": rtol, "atol": atol, **CALIBRATE_KWARGS, } @@ -409,7 +434,11 @@ def test_calibrate_deterministic( assert torch.allclose(param_value, param_sample_2[param_name]) result = sample( - *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, rtol=rtol, atol=atol + *sample_args, + **sample_kwargs, + inferred_parameters=inferred_parameters, + rtol=rtol, + atol=atol, )["unprocessed_result"] check_result_sizes(result, start_time, end_time, logging_step_size, 1) @@ -598,8 +627,8 @@ def __call__(self, x): "maxiter": 1, "maxfeval": 2, "progress_hook": progress_hook, - "rtol": rtol, - "atol": atol + "rtol": rtol, + "atol": atol, } bounds_interventions = optimize_kwargs["bounds_interventions"] opt_result = optimize( @@ -638,8 +667,8 @@ def __call__(self, x): 
static_parameter_interventions=opt_intervention, solver_method=optimize_kwargs["solver_method"], solver_options=optimize_kwargs["solver_options"], - rtol=rtol, - atol=atol + rtol=rtol, + atol=atol, )["unprocessed_result"] intervened_result_subset = { From e8929fa7ec234c1b7068212d1c0f5034c6b48f84 Mon Sep 17 00:00:00 2001 From: August <30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:37:55 -0700 Subject: [PATCH 4/5] atol and rtol moved to inside solver_options dict, updated tests --- pyciemss/interfaces.py | 58 +++++++++++++++------------------------- pyciemss/ouu/ouu.py | 6 ++--- tests/test_interfaces.py | 21 +++++---------- 3 files changed, 29 insertions(+), 56 deletions(-) diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index e3160ab0..6cd8b703 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -57,8 +57,6 @@ def ensemble_sample( time_unit: Optional[str] = None, alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS, stacking_order: str = "timepoints", - rtol: float = 1e-7, - atol: float = 1e-9, ) -> Dict[str, Any]: """ Load a collection of models from files, compile them into an ensemble probabilistic program, @@ -95,7 +93,7 @@ def ensemble_sample( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -109,10 +107,6 @@ def ensemble_sample( stacking_order: Optional[str] - The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state. - Options: "timepoints" or "quantiles" - rtol: float - - The relative tolerance for the solver. - atol: float - - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -127,6 +121,10 @@ def ensemble_sample( """ check_solver(solver_method, solver_options) + # Get tolerances for solver + rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 + atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + with torch.no_grad(): if dirichlet_alpha is None: dirichlet_alpha = torch.ones(len(model_paths_or_jsons)) @@ -203,8 +201,6 @@ def ensemble_calibrate( num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, - rtol: float = 1e-7, - atol: float = 1e-9, ) -> Dict[str, Any]: """ Infer parameters for an ensemble of DynamicalSystem models conditional on data. @@ -243,7 +239,7 @@ def ensemble_calibrate( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -264,10 +260,6 @@ def ensemble_calibrate( - This is called at the beginning of each iteration. - By default, this is a no-op. - This can be used to implement custom progress bars. 
- rtol: float - - The relative tolerance for the solver. - atol: float - - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -293,6 +285,10 @@ def ensemble_calibrate( # Check that num_iterations is a positive integer if not (isinstance(num_iterations, int) and num_iterations > 0): raise ValueError("num_iterations must be a positive integer") + + # Get tolerances for solver + rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 + atol = solver_options.pop('atol', 1e-9) # default = 1e-9 def autoguide(model): guide = pyro.infer.autoguide.AutoGuideList(model) @@ -382,8 +378,6 @@ def sample( Dict[str, Intervention], ] = {}, alpha: float = 0.95, - rtol: float = 1e-7, - atol: float = 1e-9, ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and sample from it. @@ -402,7 +396,7 @@ def sample( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -445,10 +439,6 @@ def sample( :func:`~chirho.interventional.ops.intervene`, including functions. alpha: float - Risk level for alpha-superquantile outputs in the results dictionary. - rtol: float - - The relative tolerance for the solver. - atol: float - - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -471,6 +461,10 @@ def sample( check_solver(solver_method, solver_options) + # Get tolerances for solver + rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 + atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + with torch.no_grad(): model = CompiledDynamics.load(model_path_or_json) @@ -600,8 +594,6 @@ def calibrate( num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, - rtol: float = 1e-7, - atol: float = 1e-9, ) -> Dict[str, Any]: """ Infer parameters for a DynamicalSystem model conditional on data. @@ -628,7 +620,7 @@ def calibrate( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. - solver_options: Dict[str, Any] - - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. - start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -681,10 +673,6 @@ def calibrate( - This is called at the beginning of each iteration. - By default, this is a no-op. - This can be used to implement custom progress bars. - - rtol: float - - The relative tolerance for the solver. - - atol: float - - The absolute tolerance for the solver. 
Returns: result: Dict[str, Any] @@ -698,6 +686,10 @@ def calibrate( check_solver(solver_method, solver_options) + # Get tolerances for solver + rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 + atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + pyro.clear_param_store() model = CompiledDynamics.load(model_path_or_json) @@ -825,8 +817,6 @@ def optimize( verbose: bool = False, roundup_decimal: int = 4, progress_hook: Callable[[torch.Tensor], None] = lambda x: None, - rtol: float = 1e-7, - atol: float = 1e-9, ) -> Dict[str, Any]: r""" Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based @@ -868,7 +858,7 @@ def optimize( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -897,10 +887,6 @@ def optimize( - A callback function that takes in the current parameter vector as a tensor. If the function returns StopIteration, the minimization will terminate. - This can be used to implement custom progress bars and/or early stopping criteria. - rtol: float - - The relative tolerance for the solver. - atol: float - - The absolute tolerance for the solver. Returns: result: Dict[str, Any] @@ -944,8 +930,6 @@ def optimize( solver_options=solver_options, u_bounds=bounds_np, risk_bound=risk_bound, - rtol=rtol, - atol=atol, ) # Run one sample to estimate model evaluation time diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index 363a86ae..e68e0a5b 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -78,8 +78,6 @@ def __init__( solver_options: Dict[str, Any] = {}, u_bounds: np.ndarray = np.atleast_2d([[0], [1]]), risk_bound: List[float] = [0.0], - rtol: float = 1e-7, - atol: float = 1e-9 ): self.model = model self.interventions = interventions @@ -99,8 +97,8 @@ def __init__( self.u_bounds = u_bounds self.risk_bound = risk_bound # used for defining penalty warnings.simplefilter("always", UserWarning) - self.rtol = rtol - self.atol = atol + self.rtol = self.solver_options.pop('rtol', 1e-7) # default = 1e-7 + self.atol = self.solver_options.pop('atol', 1e-9) # default = 1e-9 def __call__(self, x): if np.any(x - self.u_bounds[0, :] < 0.0) or np.any( diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 2811b5e6..03200b1f 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -118,8 +118,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - rtol=rtol, - atol=atol, + solver_options = {'rtol':rtol, 'atol':atol} )["unprocessed_result"] with pyro.poutine.seed(rng_seed=0): result2 = sample_method( @@ -128,8 +127,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - rtol=rtol, - atol=atol, + solver_options = {'rtol':rtol, 'atol':atol} )["unprocessed_result"] result3 = sample_method( @@ -138,8 +136,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - rtol=rtol, - atol=atol, + solver_options = {'rtol':rtol, 'atol':atol} )["unprocessed_result"] for result in [result1, result2, result3]: @@ -413,8 
+410,7 @@ def test_calibrate_deterministic( "data_mapping": model_fixture.data_mapping, "start_time": start_time, "deterministic_learnable_parameters": deterministic_learnable_parameters, - "rtol": rtol, - "atol": atol, + "solver_options": {'rtol':rtol, 'atol':atol}, **CALIBRATE_KWARGS, } @@ -437,8 +433,7 @@ def test_calibrate_deterministic( *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, - rtol=rtol, - atol=atol, + solver_options={'rtol':rtol, 'atol':atol} )["unprocessed_result"] check_result_sizes(result, start_time, end_time, logging_step_size, 1) @@ -621,14 +616,12 @@ def __call__(self, x): optimize_kwargs = { **model_fixture.optimize_kwargs, "solver_method": "euler", - "solver_options": {"step_size": 0.1}, + "solver_options": {"step_size": 0.1, "rtol": rtol, "atol": atol}, "start_time": start_time, "n_samples_ouu": int(2), "maxiter": 1, "maxfeval": 2, "progress_hook": progress_hook, - "rtol": rtol, - "atol": atol, } bounds_interventions = optimize_kwargs["bounds_interventions"] opt_result = optimize( @@ -667,8 +660,6 @@ def __call__(self, x): static_parameter_interventions=opt_intervention, solver_method=optimize_kwargs["solver_method"], solver_options=optimize_kwargs["solver_options"], - rtol=rtol, - atol=atol, )["unprocessed_result"] intervened_result_subset = { From 710acdad0329a0e2b653ea3f6b62114e497b814d Mon Sep 17 00:00:00 2001 From: August <30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:41:35 -0700 Subject: [PATCH 5/5] formatting and linting passing --- pyciemss/interfaces.py | 29 ++++++++++++++++------------- pyciemss/ouu/ouu.py | 4 ++-- tests/test_interfaces.py | 10 +++++----- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 6cd8b703..9ffe13f2 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -122,9 +122,9 @@ def ensemble_sample( check_solver(solver_method, solver_options) # Get tolerances for solver - rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 - atol = solver_options.pop('atol', 1e-9) # default = 1e-9 - + rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7 + atol = solver_options.pop("atol", 1e-9) # default = 1e-9 + with torch.no_grad(): if dirichlet_alpha is None: dirichlet_alpha = torch.ones(len(model_paths_or_jsons)) @@ -285,10 +285,10 @@ def ensemble_calibrate( # Check that num_iterations is a positive integer if not (isinstance(num_iterations, int) and num_iterations > 0): raise ValueError("num_iterations must be a positive integer") - + # Get tolerances for solver - rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 - atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7 + atol = solver_options.pop("atol", 1e-9) # default = 1e-9 def autoguide(model): guide = pyro.infer.autoguide.AutoGuideList(model) @@ -396,7 +396,8 @@ def sample( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). + See torchdiffeq's `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
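The extraction pattern introduced in patch 4 deserves a gloss: dict.pop with a default both reads and deletes the key, so rtol and atol are consumed before solver_options is forwarded to the underlying solver. A minimal sketch of that behavior, with illustrative values:

    # Mirrors the pop-with-default pattern from the diff above.
    solver_options = {"rtol": 1e-6, "atol": 1e-8, "step_size": 0.1}
    rtol = solver_options.pop("rtol", 1e-7)  # caller's value, else 1e-7
    atol = solver_options.pop("atol", 1e-9)  # caller's value, else 1e-9
    assert (rtol, atol) == (1e-6, 1e-8)
    assert solver_options == {"step_size": 0.1}  # tolerance keys consumed

One side effect worth noting: because pop mutates the dictionary the caller passed in, a solver_options dict reused across calls loses its rtol/atol keys after the first call.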
@@ -462,8 +463,8 @@ def sample( check_solver(solver_method, solver_options) # Get tolerances for solver - rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 - atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7 + atol = solver_options.pop("atol", 1e-9) # default = 1e-9 with torch.no_grad(): model = CompiledDynamics.load(model_path_or_json) @@ -620,7 +621,8 @@ def calibrate( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. - solver_options: Dict[str, Any] - - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). + See torchdiffeq's `odeint` method for more details. - start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. @@ -687,8 +689,8 @@ def calibrate( check_solver(solver_method, solver_options) # Get tolerances for solver - rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7 - atol = solver_options.pop('atol', 1e-9) # default = 1e-9 + rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7 + atol = solver_options.pop("atol", 1e-9) # default = 1e-9 pyro.clear_param_store() @@ -858,7 +860,8 @@ def optimize( - If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff. solver_options: Dict[str, Any] - - Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details. + - Options to pass to the solver (including atol and rtol). + See torchdiffeq's `odeint` method for more details. start_time: float - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
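With patches 4 and 5 applied, callers select tolerances through solver_options, exactly as the updated tests do. A hypothetical end-state call (placeholder model path and values):

    from pyciemss.interfaces import sample

    result = sample(
        "model.json",    # placeholder AMR model path or URL
        100.0,           # end_time
        1.0,             # logging_step_size
        10,              # num_samples
        start_time=0.0,
        solver_options={"rtol": 1e-6, "atol": 1e-8},
    )["unprocessed_result"]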
diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index e68e0a5b..28629d4f 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -97,8 +97,8 @@ def __init__( self.u_bounds = u_bounds self.risk_bound = risk_bound # used for defining penalty warnings.simplefilter("always", UserWarning) - self.rtol = self.solver_options.pop('rtol', 1e-7) # default = 1e-7 - self.atol = self.solver_options.pop('atol', 1e-9) # default = 1e-9 + self.rtol = self.solver_options.pop("rtol", 1e-7) # default = 1e-7 + self.atol = self.solver_options.pop("atol", 1e-9) # default = 1e-9 def __call__(self, x): if np.any(x - self.u_bounds[0, :] < 0.0) or np.any( diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 03200b1f..e4563b8b 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -118,7 +118,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - solver_options = {'rtol':rtol, 'atol':atol} + solver_options={"rtol": rtol, "atol": atol}, )["unprocessed_result"] with pyro.poutine.seed(rng_seed=0): result2 = sample_method( @@ -127,7 +127,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - solver_options = {'rtol':rtol, 'atol':atol} + solver_options={"rtol": rtol, "atol": atol}, )["unprocessed_result"] result3 = sample_method( @@ -136,7 +136,7 @@ def test_sample_no_interventions( logging_step_size, num_samples, start_time=start_time, - solver_options = {'rtol':rtol, 'atol':atol} + solver_options={"rtol": rtol, "atol": atol}, )["unprocessed_result"] for result in [result1, result2, result3]: @@ -410,7 +410,7 @@ def test_calibrate_deterministic( "data_mapping": model_fixture.data_mapping, "start_time": start_time, "deterministic_learnable_parameters": deterministic_learnable_parameters, - "solver_options": {'rtol':rtol, 'atol':atol}, + "solver_options": {"rtol": rtol, "atol": atol}, **CALIBRATE_KWARGS, } @@ -433,7 +433,7 @@ def test_calibrate_deterministic( *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, - solver_options={'rtol':rtol, 'atol':atol} + solver_options={"rtol": rtol, "atol": atol}, )["unprocessed_result"] check_result_sizes(result, start_time, end_time, logging_step_size, 1)
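For context, the tolerances collected by these wrappers ultimately parameterize torchdiffeq's odeint, whose own defaults (rtol=1e-7, atol=1e-9) match the defaults adopted in this PR. A self-contained sketch, independent of pyciemss:

    import torch
    from torchdiffeq import odeint

    def decay(t, y):
        return -0.5 * y  # simple test problem: dy/dt = -0.5 * y

    y0 = torch.tensor([1.0])
    ts = torch.linspace(0.0, 10.0, 11)
    # Tighter tolerances cost more solver steps but reduce local error.
    ys = odeint(decay, y0, ts, rtol=1e-6, atol=1e-8, method="dopri5")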