Skip to content

Commit

Permalink
remove Numba buggy caching for packaging
Browse files Browse the repository at this point in the history
  • Loading branch information
DM-Berger committed Sep 22, 2022
1 parent 699d186 commit 52ecc91
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
12 changes: 6 additions & 6 deletions empyricalRMT/observables/rigidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def delta_L(
return np.mean(delta_running), converged, k # type: ignore


@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, cache=False, fastmath=True)
def _delta_grid(
unfolded: ndarray, starts: ndarray, L: float, gridsize: int, use_simpson: bool
) -> f64:
Expand All @@ -320,7 +320,7 @@ def _delta_grid(
return np.mean(delta3s) # type: ignore


@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, cache=False, fastmath=True)
def _slope(x: ndarray, y: ndarray) -> f64:
"""Perform linear regression to compute the slope."""
x_mean = np.mean(x)
Expand All @@ -334,12 +334,12 @@ def _slope(x: ndarray, y: ndarray) -> f64:
return cov / var # type: ignore


@jit(nopython=True, cache=False, fastmath=True)
def _intercept(x: ndarray, y: ndarray, slope: f64) -> f64:
    """Return the y-intercept of the least-squares line with the given slope."""
    y_mean = np.mean(y)
    x_mean = np.mean(x)
    return y_mean - slope * x_mean  # type: ignore


@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, cache=False, fastmath=True)
def _integrate_fast(grid: ndarray, values: ndarray) -> f64:
"""scipy.integrate.trapz is excruciatingly slow and unusable for our purposes.
This tiny rewrite seems to result in a near 20x speedup. However, being trapezoidal
Expand All @@ -354,7 +354,7 @@ def _integrate_fast(grid: ndarray, values: ndarray) -> f64:

# NOTE: !!!! Very important *NOT* to use parallel=True here, since we parallelize
# the outer loops. Adding it inside *dramatically* slows performance.
@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, cache=False, fastmath=True)
def _sq_lin_deviation(eigs: ndarray, steps: ndarray, K: f64, w: f64, grid: fArr) -> fArr:
"""Compute the squared deviation of the staircase function of the best fitting
line, over the region in `grid`.
Expand Down Expand Up @@ -390,7 +390,7 @@ def _sq_lin_deviation(eigs: ndarray, steps: ndarray, K: f64, w: f64, grid: fArr)


# fmt: off
@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, cache=False, fastmath=True)
def _int_simps_nonunif(grid: fArr, vals: fArr) -> f64:
"""
Simpson rule for irregularly spaced data. Copied shamelessly from
Expand Down
4 changes: 2 additions & 2 deletions empyricalRMT/observables/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def step_values(eigs: fArr, x: Union[float, fArr]) -> Union[float, iArr]:
return cast(iArr, _step_function_fast(eigs, x))


@njit(fastmath=True, cache=True) # type: ignore[misc]
@njit(fastmath=True, cache=False) # type: ignore[misc]
def _step_function_fast(eigs: fArr, x: fArr) -> iArr:
"""optimized version that does not repeatedly call np.sum(eigs <= x), since
this function needed to be called extensively in rigidity calculation."""
Expand Down Expand Up @@ -70,7 +70,7 @@ def _step_function_fast(eigs: fArr, x: fArr) -> iArr:
return ret


@njit(fastmath=True, cache=True, parallel=True)
@njit(fastmath=True, cache=False, parallel=True)
def _step_function_correct(eigs: fArr, x: fArr) -> fArr:
"""Intended for testing _step_function_fast correctness, as this function
is for sure correct, just slow.
Expand Down
10 changes: 5 additions & 5 deletions empyricalRMT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def make_parent_directories(path: Path) -> None:
make_directory(path)


@jit(nopython=True, cache=True)
@jit(nopython=True)
def kahan_add(current_sum: f64, update: f64, carry_over: f64) -> Tuple[f64, f64]:
"""
Returns
Expand All @@ -103,7 +103,7 @@ def kahan_add(current_sum: f64, update: f64, carry_over: f64) -> Tuple[f64, f64]
return updated_sum, c


@jit(nopython=True, cache=True, fastmath=True)
@jit(nopython=True, fastmath=True)
def nd_find(arr: ndarray, value: Any) -> Optional[int]:
for i, val in np.ndenumerate(arr):
if val == value:
Expand Down Expand Up @@ -133,7 +133,7 @@ def flatten_4D(img4D: ndarray) -> np.ndarray:
return img4D.reshape((np.prod(img4D.shape[0:-1]),) + (img4D.shape[-1],))


@jit(nopython=True, fastmath=True, cache=True)
@jit(nopython=True, fastmath=True)
def slope(x: ndarray, y: ndarray) -> np.float64:
x_mean = np.mean(x)
y_mean = np.mean(y)
Expand All @@ -158,12 +158,12 @@ def variance(arr: ndarray) -> np.float64:
return np.float64(scale * summed)


@jit(nopython=True, fastmath=True)
def intercept(x: ndarray, y: ndarray, slope: np.float64) -> np.float64:
    """Return the y-intercept of the least-squares line with the given slope."""
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    return np.float64(y_mean - slope * x_mean)


@jit(nopython=True, fastmath=True, cache=True)
@jit(nopython=True, fastmath=True)
def fast_r(x: ndarray, y: ndarray) -> np.float64:
n = len(x)
num = x * y - n * np.mean(x) * np.mean(y)
Expand Down

0 comments on commit 52ecc91

Please sign in to comment.