diff --git a/pymsis/msis.py b/pymsis/msis.py index ea1fee9..eb20c9c 100644 --- a/pymsis/msis.py +++ b/pymsis/msis.py @@ -9,6 +9,11 @@ from pymsis.utils import get_f107_ap +# Store the previous options to avoid reinitializing the model +# each iteration unless necessary +_previous_options: dict[str, list[float] | None] = {"0": None, "2.0": None, "2.1": None} + + def run( dates: npt.ArrayLike, lons: npt.ArrayLike, @@ -135,7 +140,9 @@ def run( # convert to string version version = str(version) if version in {"0", "00"}: - msis00f.pytselec(options) + if _previous_options["0"] != options: + msis00f.pytselec(options) + _previous_options["0"] = options output = msis00f.pygtd7d( input_data[:, 0], input_data[:, 1], @@ -154,10 +161,17 @@ def run( # Select the proper library. Default to version 2.1, unless explicitly # requested "2.0" via string - msis_lib = msis21f if version == "2.0": msis_lib = msis20f - msis_lib.pyinitswitch(options, parmpath=msis_path) + else: + version = "2.1" + msis_lib = msis21f + + # Only reinitialize the model if the options have changed + if _previous_options[version] != options: + msis_lib.pyinitswitch(options, parmpath=msis_path) + _previous_options[version] = options + output = msis_lib.pymsiscalc( input_data[:, 0], input_data[:, 1], diff --git a/src/wrappers/msis2.F90 b/src/wrappers/msis2.F90 index fc82854..e9528f4 100644 --- a/src/wrappers/msis2.F90 +++ b/src/wrappers/msis2.F90 @@ -1,14 +1,24 @@ subroutine pyinitswitch(switch_legacy, parmpath) + use msis_calc, only: msiscalc + use msis_constants, only: rp use msis_init, only: msisinit implicit none real(4), intent(in), optional :: switch_legacy(1:25) !Legacy switch array character(len=*), intent(in), optional :: parmpath !Path to parameter file + real(kind=rp) :: output = 0. + real(kind=rp) :: output_arr(1:11) = 0. call msisinit(switch_legacy=switch_legacy, parmpath=parmpath) + ! Artificially call msiscalc to reset the last variables as there is + ! a global cache on these and the parameters won't be updated if we + ! don't set them to something different. + ! See issue: gh-59 + call msiscalc(0., 0., -999., -999., -1., 0., 0., (/1., 1., 1., 1., 1., 1., 1./), output, output_arr) + return end subroutine pyinitswitch diff --git a/tests/test_msis.py b/tests/test_msis.py index 610c2a9..1cf8792 100644 --- a/tests/test_msis.py +++ b/tests/test_msis.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal @@ -42,6 +44,26 @@ def expected_output(): ) +@pytest.fixture +def expected_output_with_options(): + return np.array( + [ + 2.427699e-10, + 2.849738e15, + 1.364307e14, + 3.836351e15, + 9.207778e12, + 1.490457e11, + 2.554763e12, + 3.459567e13, + 8.023306e-04, + 1.326593e12, + 9.277623e02, + ], + dtype=np.float32, + ) + + @pytest.fixture def expected_output00(): return np.array( @@ -182,10 +204,13 @@ def test_create_input_multi_lon_lat(input_data, expected_input): assert_array_equal(data, [expected_input] * 5 * 5) -def test_run_options(input_data): +def test_run_options(input_data, expected_output): # Default options is all 1's, so make sure they are equivalent - assert_array_equal( - msis.run(*input_data, options=None), msis.run(*input_data, options=[1] * 25) + assert_allclose( + np.squeeze(msis.run(*input_data, options=None)), expected_output, rtol=1e-5 + ) + assert_allclose( + np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5 ) with pytest.raises(ValueError, match="options needs to be a list"): @@ -395,3 +420,48 @@ def test_keyword_argument_call(input_data, version, func): # NO missing from versions before 2.1 direct_output[:, -2] = np.nan assert_array_equal(run_output, direct_output) + + +def test_changing_options(input_data, expected_output, expected_output_with_options): + # Calling the function again while just changing options should + # also update the output data. There is global caching in MSIS, + # so we need to make sure that we are actually changing the model + # when the options change. + assert_allclose( + np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5 + ) + assert_allclose( + np.squeeze(msis.run(*input_data, options=[0] * 25)), + expected_output_with_options, + rtol=1e-5, + ) + + +def test_options_calls(input_data): + # Check that we don't call the initialization function unless + # our options have changed between calls. + # Reset the cache + for version in msis._previous_options: + msis._previous_options[version] = None + with patch("pymsis.msis21f.pyinitswitch") as mock_init: + msis.run(*input_data, options=[0] * 25) + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25) + # Called again shouldn't call the initialization function + mock_init.assert_called_once() + + # Our initialization function is different for MSIS00 and v2.0 + # This should still be called and not already set because + # we've already run v2.1 + with patch("pymsis.msis20f.pyinitswitch") as mock_init: + msis.run(*input_data, options=[0] * 25, version="2.0") + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25, version="2.0") + # Called again shouldn't call the initialization function + mock_init.assert_called_once() + with patch("pymsis.msis00f.pytselec") as mock_init: + msis.run(*input_data, options=[0] * 25, version=0) + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25, version=0) + # Called again shouldn't call the initialization function + mock_init.assert_called_once()