Skip to content

Commit

Permalink
Merge pull request #60 from greglucas/fix-options
Browse files Browse the repository at this point in the history
FIX: Handle updating options without changing other variables
  • Loading branch information
greglucas authored Nov 15, 2024
2 parents 2f6c237 + b197870 commit be2cb86
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 6 deletions.
20 changes: 17 additions & 3 deletions pymsis/msis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from pymsis.utils import get_f107_ap


# Store the previous options to avoid reinitializing the model
# each iteration unless necessary
_previous_options: dict[str, list[float] | None] = {"0": None, "2.0": None, "2.1": None}


def run(
dates: npt.ArrayLike,
lons: npt.ArrayLike,
Expand Down Expand Up @@ -135,7 +140,9 @@ def run(
# convert to string version
version = str(version)
if version in {"0", "00"}:
msis00f.pytselec(options)
if _previous_options["0"] != options:
msis00f.pytselec(options)
_previous_options["0"] = options
output = msis00f.pygtd7d(
input_data[:, 0],
input_data[:, 1],
Expand All @@ -154,10 +161,17 @@ def run(

# Select the proper library. Default to version 2.1, unless explicitly
# requested "2.0" via string
msis_lib = msis21f
if version == "2.0":
msis_lib = msis20f
msis_lib.pyinitswitch(options, parmpath=msis_path)
else:
version = "2.1"
msis_lib = msis21f

# Only reinitialize the model if the options have changed
if _previous_options[version] != options:
msis_lib.pyinitswitch(options, parmpath=msis_path)
_previous_options[version] = options

output = msis_lib.pymsiscalc(
input_data[:, 0],
input_data[:, 1],
Expand Down
10 changes: 10 additions & 0 deletions src/wrappers/msis2.F90
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@

subroutine pyinitswitch(switch_legacy, parmpath)
use msis_calc, only: msiscalc
use msis_constants, only: rp
use msis_init, only: msisinit

implicit none

real(4), intent(in), optional :: switch_legacy(1:25) !Legacy switch array
character(len=*), intent(in), optional :: parmpath !Path to parameter file
real(kind=rp) :: output = 0.
real(kind=rp) :: output_arr(1:11) = 0.

call msisinit(switch_legacy=switch_legacy, parmpath=parmpath)

! Artificially call msiscalc to reset the last variables as there is
! a global cache on these and the parameters won't be updated if we
! don't set them to something different.
! See issue: gh-59
call msiscalc(0., 0., -999., -999., -1., 0., 0., (/1., 1., 1., 1., 1., 1., 1./), output, output_arr)

return
end subroutine pyinitswitch

Expand Down
76 changes: 73 additions & 3 deletions tests/test_msis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
Expand Down Expand Up @@ -42,6 +44,26 @@ def expected_output():
)


@pytest.fixture
def expected_output_with_options():
return np.array(
[
2.427699e-10,
2.849738e15,
1.364307e14,
3.836351e15,
9.207778e12,
1.490457e11,
2.554763e12,
3.459567e13,
8.023306e-04,
1.326593e12,
9.277623e02,
],
dtype=np.float32,
)


@pytest.fixture
def expected_output00():
return np.array(
Expand Down Expand Up @@ -182,10 +204,13 @@ def test_create_input_multi_lon_lat(input_data, expected_input):
assert_array_equal(data, [expected_input] * 5 * 5)


def test_run_options(input_data):
def test_run_options(input_data, expected_output):
# Default options is all 1's, so make sure they are equivalent
assert_array_equal(
msis.run(*input_data, options=None), msis.run(*input_data, options=[1] * 25)
assert_allclose(
np.squeeze(msis.run(*input_data, options=None)), expected_output, rtol=1e-5
)
assert_allclose(
np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5
)

with pytest.raises(ValueError, match="options needs to be a list"):
Expand Down Expand Up @@ -395,3 +420,48 @@ def test_keyword_argument_call(input_data, version, func):
# NO missing from versions before 2.1
direct_output[:, -2] = np.nan
assert_array_equal(run_output, direct_output)


def test_changing_options(input_data, expected_output, expected_output_with_options):
# Calling the function again while just changing options should
# also update the output data. There is global caching in MSIS,
# so we need to make sure that we are actually changing the model
# when the options change.
assert_allclose(
np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5
)
assert_allclose(
np.squeeze(msis.run(*input_data, options=[0] * 25)),
expected_output_with_options,
rtol=1e-5,
)


def test_options_calls(input_data):
# Check that we don't call the initialization function unless
# our options have changed between calls.
# Reset the cache
for version in msis._previous_options:
msis._previous_options[version] = None
with patch("pymsis.msis21f.pyinitswitch") as mock_init:
msis.run(*input_data, options=[0] * 25)
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25)
# Called again shouldn't call the initialization function
mock_init.assert_called_once()

# Our initialization function is different for MSIS00 and v2.0
# This should still be called and not already set because
# we've already run v2.1
with patch("pymsis.msis20f.pyinitswitch") as mock_init:
msis.run(*input_data, options=[0] * 25, version="2.0")
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25, version="2.0")
# Called again shouldn't call the initialization function
mock_init.assert_called_once()
with patch("pymsis.msis00f.pytselec") as mock_init:
msis.run(*input_data, options=[0] * 25, version=0)
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25, version=0)
# Called again shouldn't call the initialization function
mock_init.assert_called_once()

0 comments on commit be2cb86

Please sign in to comment.