@@ -826,3 +826,58 @@ def test_intervention_on_constant_param(
         torch.arange(start_time, end_time + logging_step_size, logging_step_size)
     )
     assert processed_result.shape[1] >= 2
+
+
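+# Run the test over the full cross product of the stacked parametrize values:
+# MODELS x END_TIMES x LOGGING_STEP_SIZES x NUM_SAMPLES x START_TIMES.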
+@pytest.mark.parametrize("sample_method", [sample])
+@pytest.mark.parametrize("model_fixture", MODELS)
+@pytest.mark.parametrize("end_time", END_TIMES)
+@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("start_time", START_TIMES)
+def test_observables_change_with_interventions(
+    sample_method,
+    model_fixture,
+    end_time,
+    logging_step_size,
+    num_samples,
+    start_time,
+):
+    # Assert that an intervention on a constant parameter changes the
+    # observable state that depends on it.
+    if "SIR_param" not in model_fixture.url:
+        pytest.skip("Only test 'SIR_param_in_obs' model")
+    else:
+        processed_result = sample_method(
+            model_fixture.url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            start_time=start_time,
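+            # Intervene at t=2.0, setting the constant parameter beta to 0.001;
+            # the mapping is {intervention time: {parameter name: new value}}.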
+            static_parameter_interventions={
+                torch.tensor(2.0): {"beta": torch.tensor(0.001)}
+            },
+        )["data"]
+
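+        # Log the first and last logged values of the observable; pytest shows
+        # captured stdout when the assertion below fails, which aids debugging.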
+        print(processed_result["beta_param_observable_state"][0])
+        print(
+            processed_result["beta_param_observable_state"][
+                int(end_time / logging_step_size)
+            ]
+        )
+        print(int(end_time / logging_step_size))
+
+        # Index 0 is the first logged time point (before the intervention at
+        # t=2.0) and int(end_time / logging_step_size) is the last (after it);
+        # the test fails unless the observable strictly decreases.
+        assert (
+            processed_result["beta_param_observable_state"][0]
+            > processed_result["beta_param_observable_state"][
+                int(end_time / logging_step_size)
+            ]
+        )