diff --git a/requirements.txt b/requirements.txt index 426bb04..4f312fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,3 @@ numba==0.59.1 numpy==1.26.4 scipy==1.13.0 sympy==1.12 -mpmath==1.3.0 diff --git a/setup.cfg b/setup.cfg index fa8714d..5b9777e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,6 @@ numba==0.59.1 numpy==1.26.4 scipy==1.13.0 sympy==1.12 -mpmath==1.3.0 [test_requires] pytest = ^7.1.2 diff --git a/setup.py b/setup.py index 21ffb65..af8ec43 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ "numpy==1.26.4", "scipy==1.13.0", "sympy==1.12", - "mpmath==1.3.0", ] test_requires = [ diff --git a/src/fast_wave/wavefunction.py b/src/fast_wave/wavefunction.py index 347a118..d3176b9 100644 --- a/src/fast_wave/wavefunction.py +++ b/src/fast_wave/wavefunction.py @@ -264,9 +264,9 @@ def wavefunction_smmd(n: np.uint64, x: np.ndarray[np.float64], more_fast: bool = Examples -------- ```python - >>> wavefunction_smmd(0,(1.0,2.0)) + >>> wavefunction_smmd(0,np.array([1.0, 2.0])) array([0.45558067, 0.10165379]) - >>> wavefunction_smmd(61,(1.0,2.0)) + >>> wavefunction_smmd(61,np.array([1.0, 2.0])) array([-0.23930492, -0.01677378]) ``` @@ -331,9 +331,9 @@ def c_wavefunction_smmd(n: np.uint64, x: np.ndarray[np.complex128], more_fast: b Examples -------- ```python - >>> c_wavefunction_smmd(0,(1.0 + 1.0j, 2.0 + 2.0j)) + >>> c_wavefunction_smmd(0,np.array([1.0 + 1.0j, 2.0 + 2.0j])) array([ 0.40583486-0.63205035j, -0.49096842+0.56845369j]) - >>> c_wavefunction_smmd(61,(1.0 + 1.0j, 2.0 + 2.0j)) + >>> c_wavefunction_smmd(61,np.array([1.0 + 1.0j, 2.0 + 2.0j])) array([-7.56548941e+03+9.21498621e+02j, -1.64189542e+08-3.70892077e+08j]) ``` @@ -475,7 +475,7 @@ def wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.float64]) -> np.ndarray[np. Examples -------- ```python - >>> wavefunction_mmmd(1,(1.0 ,2.0)) + >>> wavefunction_mmmd(1,np.array([1.0, 2.0])) array([[0.45558067, 0.10165379], [0.64428837, 0.28752033]]) ``` @@ -518,7 +518,7 @@ def c_wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.complex128]) -> np.ndarra Examples -------- ```python - >>> c_wavefunction_mmmd(1,(1.0 + 1.0j,2.0 + 2.0j)) + >>> c_wavefunction_mmmd(1,np.array([1.0 + 1.0j, 2.0 + 2.0j])) array([[ 0.40583486-0.63205035j, -0.49096842+0.56845369j], [ 1.46779135-0.31991701j, -2.99649822+0.21916143j]]) ``` diff --git a/tests/test_wavefunction_arb_prec.py b/tests/test_wavefunction_arb_prec.py new file mode 100644 index 0000000..dbc5df0 --- /dev/null +++ b/tests/test_wavefunction_arb_prec.py @@ -0,0 +1,44 @@ +import numpy as np +from src.fast_wave.wavefunction_arb_prec import * + +def test_wavefunction_computation(): + """ + Tests the basic functionality of all wavefunction_arb_prec functions. + """ + + wave_smod_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) + wave_smmd_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) + wave_mmod_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) + wave_mmmd_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) + c_wave_smod_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) + c_wave_smmd_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) + c_wave_mmod_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) + c_wave_mmmd_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) + + # Testing basic functionality + test_output_odsm = wave_smod_ap(2, 10.0, 20) + assert isinstance(test_output_odsm, mpmath.ctx_mp_python.mpf) + + test_output_odmm = wave_mmod_ap(2, 10.0, 20) + assert isinstance(test_output_odmm, mpmath.matrices.matrices._matrix) + + test_output_mdsm = wave_smmd_ap(2, np.array([10.0, 4.5]), 20) + assert isinstance(test_output_mdsm, mpmath.matrices.matrices._matrix) + + test_output_mdmm = wave_mmmd_ap(2, np.array([10.0, 4.5]), 20) + assert isinstance(test_output_mdmm, mpmath.matrices.matrices._matrix) + + test_output_c_odsm = c_wave_smod_ap(2, 10.0 + 0.0j, 20) + assert isinstance(test_output_c_odsm, mpmath.ctx_mp_python.mpc) + + test_output_c_odmm = c_wave_mmod_ap(2, 10.0 + 0.0j, 20) + assert isinstance(test_output_c_odmm, mpmath.matrices.matrices._matrix) + + test_output_c_mdsm = c_wave_smmd_ap(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]), 20) + assert isinstance(test_output_c_mdsm, mpmath.matrices.matrices._matrix) + + test_output_c_mdmm = c_wave_mmmd_ap(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]), 20) + assert isinstance(test_output_c_mdmm, mpmath.matrices.matrices._matrix) + + print("All functionality tests passed.") +