From b197870e03d44cefacaf74f39ec881c733f6ff83 Mon Sep 17 00:00:00 2001 From: Greg Lucas Date: Thu, 14 Nov 2024 16:32:38 -0700 Subject: [PATCH] FIX: Handle updating options without changing other variables There was an issue in MSIS2 where some steps are cached in the Fortran code if the input values (lat, lon, ...) are unchanged between subsequent runs. This causes issues if one wants to compare what the differences are between the options switches. (MSIS00 didn't have this same cache) Additionally, add a check to the Python side to only call the init() function when options have changed, not every time we enter the run() call. --- pymsis/msis.py | 20 +++++++++-- src/wrappers/msis2.F90 | 10 ++++++ tests/test_msis.py | 76 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 100 insertions(+), 6 deletions(-) diff --git a/pymsis/msis.py b/pymsis/msis.py index ea1fee9..eb20c9c 100644 --- a/pymsis/msis.py +++ b/pymsis/msis.py @@ -9,6 +9,11 @@ from pymsis.utils import get_f107_ap +# Store the previous options to avoid reinitializing the model +# each iteration unless necessary +_previous_options: dict[str, list[float] | None] = {"0": None, "2.0": None, "2.1": None} + + def run( dates: npt.ArrayLike, lons: npt.ArrayLike, @@ -135,7 +140,9 @@ def run( # convert to string version version = str(version) if version in {"0", "00"}: - msis00f.pytselec(options) + if _previous_options["0"] != options: + msis00f.pytselec(options) + _previous_options["0"] = options output = msis00f.pygtd7d( input_data[:, 0], input_data[:, 1], @@ -154,10 +161,17 @@ def run( # Select the proper library. Default to version 2.1, unless explicitly # requested "2.0" via string - msis_lib = msis21f if version == "2.0": msis_lib = msis20f - msis_lib.pyinitswitch(options, parmpath=msis_path) + else: + version = "2.1" + msis_lib = msis21f + + # Only reinitialize the model if the options have changed + if _previous_options[version] != options: + msis_lib.pyinitswitch(options, parmpath=msis_path) + _previous_options[version] = options + output = msis_lib.pymsiscalc( input_data[:, 0], input_data[:, 1], diff --git a/src/wrappers/msis2.F90 b/src/wrappers/msis2.F90 index fc82854..e9528f4 100644 --- a/src/wrappers/msis2.F90 +++ b/src/wrappers/msis2.F90 @@ -1,14 +1,24 @@ subroutine pyinitswitch(switch_legacy, parmpath) + use msis_calc, only: msiscalc + use msis_constants, only: rp use msis_init, only: msisinit implicit none real(4), intent(in), optional :: switch_legacy(1:25) !Legacy switch array character(len=*), intent(in), optional :: parmpath !Path to parameter file + real(kind=rp) :: output = 0. + real(kind=rp) :: output_arr(1:11) = 0. call msisinit(switch_legacy=switch_legacy, parmpath=parmpath) + ! Artificially call msiscalc to reset the last variables as there is + ! a global cache on these and the parameters won't be updated if we + ! don't set them to something different. + ! See issue: gh-59 + call msiscalc(0., 0., -999., -999., -1., 0., 0., (/1., 1., 1., 1., 1., 1., 1./), output, output_arr) + return end subroutine pyinitswitch diff --git a/tests/test_msis.py b/tests/test_msis.py index 610c2a9..1cf8792 100644 --- a/tests/test_msis.py +++ b/tests/test_msis.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal @@ -42,6 +44,26 @@ def expected_output(): ) +@pytest.fixture +def expected_output_with_options(): + return np.array( + [ + 2.427699e-10, + 2.849738e15, + 1.364307e14, + 3.836351e15, + 9.207778e12, + 1.490457e11, + 2.554763e12, + 3.459567e13, + 8.023306e-04, + 1.326593e12, + 9.277623e02, + ], + dtype=np.float32, + ) + + @pytest.fixture def expected_output00(): return np.array( @@ -182,10 +204,13 @@ def test_create_input_multi_lon_lat(input_data, expected_input): assert_array_equal(data, [expected_input] * 5 * 5) -def test_run_options(input_data): +def test_run_options(input_data, expected_output): # Default options is all 1's, so make sure they are equivalent - assert_array_equal( - msis.run(*input_data, options=None), msis.run(*input_data, options=[1] * 25) + assert_allclose( + np.squeeze(msis.run(*input_data, options=None)), expected_output, rtol=1e-5 + ) + assert_allclose( + np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5 ) with pytest.raises(ValueError, match="options needs to be a list"): @@ -395,3 +420,48 @@ def test_keyword_argument_call(input_data, version, func): # NO missing from versions before 2.1 direct_output[:, -2] = np.nan assert_array_equal(run_output, direct_output) + + +def test_changing_options(input_data, expected_output, expected_output_with_options): + # Calling the function again while just changing options should + # also update the output data. There is global caching in MSIS, + # so we need to make sure that we are actually changing the model + # when the options change. + assert_allclose( + np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5 + ) + assert_allclose( + np.squeeze(msis.run(*input_data, options=[0] * 25)), + expected_output_with_options, + rtol=1e-5, + ) + + +def test_options_calls(input_data): + # Check that we don't call the initialization function unless + # our options have changed between calls. + # Reset the cache + for version in msis._previous_options: + msis._previous_options[version] = None + with patch("pymsis.msis21f.pyinitswitch") as mock_init: + msis.run(*input_data, options=[0] * 25) + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25) + # Called again shouldn't call the initialization function + mock_init.assert_called_once() + + # Our initialization function is different for MSIS00 and v2.0 + # This should still be called and not already set because + # we've already run v2.1 + with patch("pymsis.msis20f.pyinitswitch") as mock_init: + msis.run(*input_data, options=[0] * 25, version="2.0") + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25, version="2.0") + # Called again shouldn't call the initialization function + mock_init.assert_called_once() + with patch("pymsis.msis00f.pytselec") as mock_init: + msis.run(*input_data, options=[0] * 25, version=0) + mock_init.assert_called_once() + msis.run(*input_data, options=[0] * 25, version=0) + # Called again shouldn't call the initialization function + mock_init.assert_called_once()