Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

610 atol and rtol cannot be specified in solver options #612

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def ensemble_sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq's `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -121,6 +121,10 @@ def ensemble_sample(
"""
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
if dirichlet_alpha is None:
dirichlet_alpha = torch.ones(len(model_paths_or_jsons))
Expand All @@ -138,7 +142,9 @@ def ensemble_sample(
raise ValueError("num_samples must be a positive integer")

def wrapped_model():
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
solution = model(
torch.as_tensor(start_time),
torch.as_tensor(end_time),
Expand Down Expand Up @@ -233,7 +239,7 @@ def ensemble_calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq's `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -280,6 +286,10 @@ def ensemble_calibrate(
if not (isinstance(num_iterations, int) and num_iterations > 0):
raise ValueError("num_iterations must be a positive integer")

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

def autoguide(model):
guide = pyro.infer.autoguide.AutoGuideList(model)
guide.append(
Expand Down Expand Up @@ -314,7 +324,9 @@ def autoguide(model):
def wrapped_model():
obs = condition(data=_data)(_noise_model)

with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
solution = model(
torch.as_tensor(start_time),
torch.as_tensor(data_timepoints[-1]),
Expand Down Expand Up @@ -384,7 +396,8 @@ def sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq's `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -449,6 +462,10 @@ def sample(

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
model = CompiledDynamics.load(model_path_or_json)

Expand Down Expand Up @@ -492,7 +509,9 @@ def sample(

def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
with contextlib.ExitStack() as stack:
for handler in intervention_handlers:
stack.enter_context(handler)
Expand Down Expand Up @@ -602,7 +621,8 @@ def calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
- solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq's `odeint` method for more details.
- start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -668,6 +688,10 @@ def calibrate(

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

pyro.clear_param_store()

model = CompiledDynamics.load(model_path_or_json)
Expand Down Expand Up @@ -740,7 +764,9 @@ def wrapped_model():
obs = condition(data=_data)(_noise_model)

with StaticBatchObservation(data_timepoints, observation=obs):
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
with contextlib.ExitStack() as stack:
for handler in intervention_handlers:
stack.enter_context(handler)
Expand Down Expand Up @@ -834,7 +860,8 @@ def optimize(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq's `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down
7 changes: 6 additions & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def __init__(
self.u_bounds = u_bounds
self.risk_bound = risk_bound # used for defining penalty
warnings.simplefilter("always", UserWarning)
self.rtol = self.solver_options.pop("rtol", 1e-7) # default = 1e-7
self.atol = self.solver_options.pop("atol", 1e-9) # default = 1e-9

def __call__(self, x):
if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
Expand Down Expand Up @@ -144,7 +146,10 @@ def propagate_uncertainty(self, x):
def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(
method=self.solver_method, options=self.solver_options
rtol=self.rtol,
atol=self.atol,
method=self.solver_method,
options=self.solver_options,
):
with contextlib.ExitStack() as stack:
for handler in static_parameter_intervention_handlers:
Expand Down
51 changes: 43 additions & 8 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,56 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
"num_iterations": 2,
}

RTOL = [1e-6, 1e-4]
ATOL = [1e-8, 1e-6]


@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_sample_no_interventions(
sample_method, model, start_time, end_time, logging_step_size, num_samples
sample_method,
model,
start_time,
end_time,
logging_step_size,
num_samples,
rtol,
atol,
):
model_url = model.url

with pyro.poutine.seed(rng_seed=0):
result1 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]
with pyro.poutine.seed(rng_seed=0):
result2 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

result3 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

for result in [result1, result2, result3]:
Expand Down Expand Up @@ -364,8 +391,10 @@ def test_calibrate_no_kwargs(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_calibrate_deterministic(
model_fixture, start_time, end_time, logging_step_size
model_fixture, start_time, end_time, logging_step_size, rtol, atol
):
model_url = model_fixture.url
(
Expand All @@ -381,6 +410,7 @@ def test_calibrate_deterministic(
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"deterministic_learnable_parameters": deterministic_learnable_parameters,
"solver_options": {"rtol": rtol, "atol": atol},
**CALIBRATE_KWARGS,
}

Expand All @@ -400,7 +430,10 @@ def test_calibrate_deterministic(
assert torch.allclose(param_value, param_sample_2[param_name])

result = sample(
*sample_args, **sample_kwargs, inferred_parameters=inferred_parameters
*sample_args,
**sample_kwargs,
inferred_parameters=inferred_parameters,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

check_result_sizes(result, start_time, end_time, logging_step_size, 1)
Expand Down Expand Up @@ -563,7 +596,9 @@ def test_output_format(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_optimize(model_fixture, start_time, end_time, num_samples):
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol):
logging_step_size = 1.0
model_url = model_fixture.url

Expand All @@ -581,7 +616,7 @@ def __call__(self, x):
optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
"solver_options": {"step_size": 0.1},
"solver_options": {"step_size": 0.1, "rtol": rtol, "atol": atol},
"start_time": start_time,
"n_samples_ouu": int(2),
"maxiter": 1,
Expand Down
Loading