Skip to content

Commit

Permalink
tests arbitrary-precision
Browse files Browse the repository at this point in the history
  • Loading branch information
fobos123deimos committed Aug 26, 2024
1 parent 6f5eaa9 commit c28b245
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 9 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ numba==0.59.1
numpy==1.26.4
scipy==1.13.0
sympy==1.12
mpmath==1.3.0
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ numba==0.59.1
numpy==1.26.4
scipy==1.13.0
sympy==1.12
mpmath==1.3.0

[test_requires]
pytest = ^7.1.2
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"numpy==1.26.4",
"scipy==1.13.0",
"sympy==1.12",
"mpmath==1.3.0",
]

test_requires = [
Expand Down
12 changes: 6 additions & 6 deletions src/fast_wave/wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def wavefunction_smmd(n: np.uint64, x: np.ndarray[np.float64], more_fast: bool =
Examples
--------
```python
>>> wavefunction_smmd(0,(1.0,2.0))
>>> wavefunction_smmd(0,np.array([1.0, 2.0]))
array([0.45558067, 0.10165379])
>>> wavefunction_smmd(61,(1.0,2.0))
>>> wavefunction_smmd(61,np.array([1.0, 2.0]))
array([-0.23930492, -0.01677378])
```
Expand Down Expand Up @@ -331,9 +331,9 @@ def c_wavefunction_smmd(n: np.uint64, x: np.ndarray[np.complex128], more_fast: b
Examples
--------
```python
>>> c_wavefunction_smmd(0,(1.0 + 1.0j, 2.0 + 2.0j))
>>> c_wavefunction_smmd(0,np.array([1.0 + 1.0j, 2.0 + 2.0j]))
array([ 0.40583486-0.63205035j, -0.49096842+0.56845369j])
>>> c_wavefunction_smmd(61,(1.0 + 1.0j, 2.0 + 2.0j))
>>> c_wavefunction_smmd(61,np.array([1.0 + 1.0j, 2.0 + 2.0j]))
array([-7.56548941e+03+9.21498621e+02j, -1.64189542e+08-3.70892077e+08j])
```
Expand Down Expand Up @@ -475,7 +475,7 @@ def wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.float64]) -> np.ndarray[np.
Examples
--------
```python
>>> wavefunction_mmmd(1,(1.0 ,2.0))
>>> wavefunction_mmmd(1,np.array([1.0, 2.0]))
array([[0.45558067, 0.10165379],
[0.64428837, 0.28752033]])
```
Expand Down Expand Up @@ -518,7 +518,7 @@ def c_wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.complex128]) -> np.ndarra
Examples
--------
```python
>>> c_wavefunction_mmmd(1,(1.0 + 1.0j,2.0 + 2.0j))
>>> c_wavefunction_mmmd(1,np.array([1.0 + 1.0j, 2.0 + 2.0j]))
array([[ 0.40583486-0.63205035j, -0.49096842+0.56845369j],
[ 1.46779135-0.31991701j, -2.99649822+0.21916143j]])
```
Expand Down
44 changes: 44 additions & 0 deletions tests/test_wavefunction_arb_prec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from src.fast_wave.wavefunction_arb_prec import *

def test_wavefunction_computation():
"""
Tests the basic functionality of all wavefunction_arb_prec functions.
"""

wave_smod_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
wave_smmd_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
wave_mmod_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
wave_mmmd_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
c_wave_smod_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
c_wave_smmd_ap = wavefunction_arb_prec(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)
c_wave_mmod_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
c_wave_mmmd_ap = wavefunction_arb_prec(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)

# Testing basic functionality
test_output_odsm = wave_smod_ap(2, 10.0, 20)
assert isinstance(test_output_odsm, mpmath.ctx_mp_python.mpf)

test_output_odmm = wave_mmod_ap(2, 10.0, 20)
assert isinstance(test_output_odmm, mpmath.matrices.matrices._matrix)

test_output_mdsm = wave_smmd_ap(2, np.array([10.0, 4.5]), 20)
assert isinstance(test_output_mdsm, mpmath.matrices.matrices._matrix)

test_output_mdmm = wave_mmmd_ap(2, np.array([10.0, 4.5]), 20)
assert isinstance(test_output_mdmm, mpmath.matrices.matrices._matrix)

test_output_c_odsm = c_wave_smod_ap(2, 10.0 + 0.0j, 20)
assert isinstance(test_output_c_odsm, mpmath.ctx_mp_python.mpc)

test_output_c_odmm = c_wave_mmod_ap(2, 10.0 + 0.0j, 20)
assert isinstance(test_output_c_odmm, mpmath.matrices.matrices._matrix)

test_output_c_mdsm = c_wave_smmd_ap(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]), 20)
assert isinstance(test_output_c_mdsm, mpmath.matrices.matrices._matrix)

test_output_c_mdmm = c_wave_mmmd_ap(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]), 20)
assert isinstance(test_output_c_mdmm, mpmath.matrices.matrices._matrix)

print("All functionality tests passed.")

0 comments on commit c28b245

Please sign in to comment.