Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions astromodels/functions/functions_1D/absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numba as nb
import numpy as np
from astropy.io import fits
from interpolation import interp

from astromodels.functions.function import Function1D, FunctionMeta
from astromodels.utils import _get_data_file_path
Expand Down Expand Up @@ -190,6 +189,10 @@ def _init_xsect(self):

self.xsect_ene, self.xsect_val = phabs.xsect_table

assert np.all(
np.diff(self.xsect_ene) > 0
), "xsec_ene must be strictly increasing otherwise interpolation fails"

@property
def abundance_table_info(self):
print(phabs.info)
Expand All @@ -207,7 +210,7 @@ def evaluate(self, x, NH, redshift):
_redshift = redshift
_x = x

xsect_interp = interp(self.xsect_ene, self.xsect_val, _x * (1 + _redshift))
xsect_interp = np.interp(_x * (1 + _redshift), self.xsect_ene, self.xsect_val)

# evaluate the exponential with numba

Expand Down Expand Up @@ -280,6 +283,9 @@ def _init_xsect(self):
tbabs.set_table(self.abundance_table.value)

self.xsect_ene, self.xsect_val = tbabs.xsect_table
assert np.all(
np.diff(self.xsect_ene) > 0
), "xsec_ene must be strictly increasing otherwise interpolation fails"

log.debug(f"updated the TbAbs table to {self.abundance_table.value}")

Expand All @@ -300,7 +306,7 @@ def evaluate(self, x, NH, redshift):
_redshift = redshift
_x = x

xsect_interp = interp(self.xsect_ene, self.xsect_val, _x * (1 + _redshift))
xsect_interp = np.interp(_x * (1 + _redshift), self.xsect_ene, self.xsect_val)

spec = _numba_eval(NH, xsect_interp) * _y_unit

Expand Down Expand Up @@ -355,6 +361,10 @@ def _init_xsect(self):

self.xsect_ene, self.xsect_val = wabs.xsect_table

assert np.all(
np.diff(self.xsect_ene) > 0
), "xsec_ene must be strictly increasing otherwise interpolation fails"

@property
def abundance_table_info(self):
print(wabs.info)
Expand All @@ -372,7 +382,7 @@ def evaluate(self, x, NH, redshift):
_redshift = redshift
_x = x

xsect_interp = interp(self.xsect_ene, self.xsect_val, _x * (1 + _redshift))
xsect_interp = np.interp(_x * (1 + _redshift), self.xsect_ene, self.xsect_val)

spec = _numba_eval(NH, xsect_interp) * _y_unit

Expand Down
9 changes: 5 additions & 4 deletions astromodels/functions/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import h5py
import numpy as np
import scipy.interpolate
from interpolation import interp
from interpolation.splines import eval_linear

from astromodels.core.parameter import Parameter
from astromodels.functions.function import Function1D, FunctionMeta
Expand Down Expand Up @@ -63,7 +61,7 @@ def __init__(self, grid, values):
self._values = np.ascontiguousarray(values)

def __call__(self, v):
return eval_linear(self._grid, self._values, v)
return scipy.interpolate.interpn(self._grid, self._values, v)


class UnivariateSpline(object):
Expand All @@ -72,7 +70,10 @@ def __init__(self, x, y):
self._y = y

def __call__(self, v):
return interp(self._x, self._y, v)
if isinstance(self._x, np.ndarray) and isinstance(self._y, np.ndarray):
return np.interp(v, self._x, self._y.reshape(self._x.shape))
else:
return np.interp(v, self._x, self._y)


class TemplateModelFactory(object):
Expand Down
15 changes: 14 additions & 1 deletion astromodels/tests/test_template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
TemplateModelFactory,
XSPECTableModel,
)
from astromodels.functions.template_model import convert_old_table_model
from astromodels.functions.template_model import (
convert_old_table_model,
UnivariateSpline,
)
from astromodels.utils import _get_data_file_path
from astromodels.utils.logging import update_logging_level

Expand Down Expand Up @@ -291,3 +294,13 @@ def test_table_conversion():
npt.assert_almost_equal(test(xx), old_table(xx))

p.unlink()


def test_univariate_spline():
univ = UnivariateSpline([0, 1, 2], [3, 4, 5])
res = univ(1.5)
assert np.isclose(res, 4.5)

univ_np = UnivariateSpline(np.array([0, 1, 2]), np.array([3, 4, 5]))
res_np = univ_np(1.5)
assert np.isclose(res_np, 4.5)
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
"astropy>=1.2",
"dill",
"future",
"legacy-cgi; python_version >= '3.13'",
"interpolation>=2.2.3",
"numba",
"h5py",
"pandas",
Expand Down
Loading