Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Handle updating options without changing other variables #60

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pymsis/msis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from pymsis.utils import get_f107_ap


# Store the previous options to avoid reinitializing the model
# each iteration unless necessary
_previous_options: dict[str, list[float] | None] = {"0": None, "2.0": None, "2.1": None}


def run(
dates: npt.ArrayLike,
lons: npt.ArrayLike,
Expand Down Expand Up @@ -135,7 +140,9 @@ def run(
# convert to string version
version = str(version)
if version in {"0", "00"}:
msis00f.pytselec(options)
if _previous_options["0"] != options:
msis00f.pytselec(options)
_previous_options["0"] = options
output = msis00f.pygtd7d(
input_data[:, 0],
input_data[:, 1],
Expand All @@ -154,10 +161,17 @@ def run(

# Select the proper library. Default to version 2.1, unless explicitly
# requested "2.0" via string
msis_lib = msis21f
if version == "2.0":
msis_lib = msis20f
msis_lib.pyinitswitch(options, parmpath=msis_path)
else:
version = "2.1"
msis_lib = msis21f

# Only reinitialize the model if the options have changed
if _previous_options[version] != options:
msis_lib.pyinitswitch(options, parmpath=msis_path)
_previous_options[version] = options

output = msis_lib.pymsiscalc(
input_data[:, 0],
input_data[:, 1],
Expand Down
10 changes: 10 additions & 0 deletions src/wrappers/msis2.F90
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@

subroutine pyinitswitch(switch_legacy, parmpath)
use msis_calc, only: msiscalc
use msis_constants, only: rp
use msis_init, only: msisinit

implicit none

real(4), intent(in), optional :: switch_legacy(1:25) !Legacy switch array
character(len=*), intent(in), optional :: parmpath !Path to parameter file
real(kind=rp) :: output = 0.
real(kind=rp) :: output_arr(1:11) = 0.

call msisinit(switch_legacy=switch_legacy, parmpath=parmpath)

! Artificially call msiscalc to reset the last variables as there is
! a global cache on these and the parameters won't be updated if we
! don't set them to something different.
! See issue: gh-59
call msiscalc(0., 0., -999., -999., -1., 0., 0., (/1., 1., 1., 1., 1., 1., 1./), output, output_arr)

return
end subroutine pyinitswitch

Expand Down
76 changes: 73 additions & 3 deletions tests/test_msis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
Expand Down Expand Up @@ -42,6 +44,26 @@ def expected_output():
)


@pytest.fixture
def expected_output_with_options():
return np.array(
[
2.427699e-10,
2.849738e15,
1.364307e14,
3.836351e15,
9.207778e12,
1.490457e11,
2.554763e12,
3.459567e13,
8.023306e-04,
1.326593e12,
9.277623e02,
],
dtype=np.float32,
)


@pytest.fixture
def expected_output00():
return np.array(
Expand Down Expand Up @@ -182,10 +204,13 @@ def test_create_input_multi_lon_lat(input_data, expected_input):
assert_array_equal(data, [expected_input] * 5 * 5)


def test_run_options(input_data):
def test_run_options(input_data, expected_output):
# Default options is all 1's, so make sure they are equivalent
assert_array_equal(
msis.run(*input_data, options=None), msis.run(*input_data, options=[1] * 25)
assert_allclose(
np.squeeze(msis.run(*input_data, options=None)), expected_output, rtol=1e-5
)
assert_allclose(
np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5
)

with pytest.raises(ValueError, match="options needs to be a list"):
Expand Down Expand Up @@ -395,3 +420,48 @@ def test_keyword_argument_call(input_data, version, func):
# NO missing from versions before 2.1
direct_output[:, -2] = np.nan
assert_array_equal(run_output, direct_output)


def test_changing_options(input_data, expected_output, expected_output_with_options):
# Calling the function again while just changing options should
# also update the output data. There is global caching in MSIS,
# so we need to make sure that we are actually changing the model
# when the options change.
assert_allclose(
np.squeeze(msis.run(*input_data, options=[1] * 25)), expected_output, rtol=1e-5
)
assert_allclose(
np.squeeze(msis.run(*input_data, options=[0] * 25)),
expected_output_with_options,
rtol=1e-5,
)


def test_options_calls(input_data):
# Check that we don't call the initialization function unless
# our options have changed between calls.
# Reset the cache
for version in msis._previous_options:
msis._previous_options[version] = None
with patch("pymsis.msis21f.pyinitswitch") as mock_init:
msis.run(*input_data, options=[0] * 25)
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25)
# Called again shouldn't call the initialization function
mock_init.assert_called_once()

# Our initialization function is different for MSIS00 and v2.0
# This should still be called and not already set because
# we've already run v2.1
with patch("pymsis.msis20f.pyinitswitch") as mock_init:
msis.run(*input_data, options=[0] * 25, version="2.0")
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25, version="2.0")
# Called again shouldn't call the initialization function
mock_init.assert_called_once()
with patch("pymsis.msis00f.pytselec") as mock_init:
msis.run(*input_data, options=[0] * 25, version=0)
mock_init.assert_called_once()
msis.run(*input_data, options=[0] * 25, version=0)
# Called again shouldn't call the initialization function
mock_init.assert_called_once()
Loading