Skip to content

Commit 28c7ed7

Browse files
committed
Replace numpy.vectorize with numba.vectorize in src/qha/grid_interpolation.py
wiht `calculate_eulerian_strain` & `from_eulerian_strain`
1 parent 70572c7 commit 28c7ed7

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/qha/grid_interpolation.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from typing import Optional, Tuple
1212

1313
import numpy as np
14+
from numba import float64, vectorize
1415

15-
from qha.fitting import polynomial_least_square_fitting, apply_finite_strain_fitting
16-
from qha.type_aliases import Vector, Matrix
16+
from qha.fitting import apply_finite_strain_fitting, polynomial_least_square_fitting
17+
from qha.type_aliases import Matrix, Vector
1718
from qha.unit_conversion import gpa_to_ry_b3
1819

1920
# ===================== What can be exported? =====================
@@ -25,7 +26,7 @@
2526
]
2627

2728

28-
@np.vectorize
29+
@vectorize([float64(float64, float64)], nopython=True, cache=True)
2930
def calculate_eulerian_strain(v0, vs):
3031
"""
3132
Calculate the Eulerian strain (:math:`f`s) of a given volume vector *vs* with respect to a reference volume *v0*,
@@ -42,7 +43,7 @@ def calculate_eulerian_strain(v0, vs):
4243
return 1 / 2 * ((v0 / vs) ** (2 / 3) - 1)
4344

4445

45-
@np.vectorize
46+
@vectorize([float64(float64, float64)], nopython=True, cache=True)
4647
def from_eulerian_strain(v0, fs):
4748
"""
4849
Calculate the corresponding volumes :math:`V`s from a vector of given Eulerian strains (*fs*)
@@ -157,9 +158,10 @@ class is created, unless user is clear what is being done.
157158
# r = v_upper / v_max = v_min / v_lower
158159
v_lower, v_upper = v_min / self._ratio, v_max * self._ratio
159160
# The *v_max* is a reference value here.
160-
s_upper, s_lower = calculate_eulerian_strain(
161-
v_max, v_lower
162-
), calculate_eulerian_strain(v_max, v_upper)
161+
s_upper, s_lower = (
162+
calculate_eulerian_strain(v_max, v_lower),
163+
calculate_eulerian_strain(v_max, v_upper),
164+
)
163165
self._strains = np.linspace(s_lower, s_upper, self._out_volumes_num)
164166
self._out_volumes = from_eulerian_strain(v_max, self._strains)
165167

0 commit comments

Comments
 (0)