Skip to content

Commit 6cca991

Browse files
committed
adding failing test for observables defined by parameter values
1 parent 1ce34fd commit 6cca991

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/test_interfaces.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,49 @@ def test_intervention_on_constant_param(
826826
torch.arange(start_time, end_time + logging_step_size, logging_step_size)
827827
)
828828
assert processed_result.shape[1] >= 2
829+
830+
831+
@pytest.mark.parametrize("sample_method", [sample])
832+
@pytest.mark.parametrize("model_fixture", MODELS)
833+
@pytest.mark.parametrize("end_time", END_TIMES)
834+
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
835+
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
836+
@pytest.mark.parametrize("start_time", START_TIMES)
837+
def test_observables_change_with_interventions(
838+
sample_method,
839+
model_fixture,
840+
end_time,
841+
logging_step_size,
842+
num_samples,
843+
start_time,
844+
):
845+
# Assert that sample returns expected result with intervention on constant parameter
846+
if "SIR_param" not in model_fixture.url:
847+
pytest.skip("Only test 'SIR_param_in_obs' model")
848+
else:
849+
processed_result = sample_method(
850+
model_fixture.url,
851+
end_time,
852+
logging_step_size,
853+
num_samples,
854+
start_time=start_time,
855+
static_parameter_interventions={
856+
torch.tensor(2.0): {"beta": torch.tensor(0.001)}
857+
},
858+
)["data"]
859+
860+
print(processed_result["beta_param_observable_state"][0])
861+
print(
862+
processed_result["beta_param_observable_state"][
863+
int(end_time / logging_step_size)
864+
]
865+
)
866+
print(int(end_time / logging_step_size))
867+
868+
# The test will fail if values before and after the intervention are the same
869+
assert (
870+
processed_result["beta_param_observable_state"][0]
871+
> processed_result["beta_param_observable_state"][
872+
int(end_time / logging_step_size)
873+
]
874+
)

0 commit comments

Comments
 (0)