Skip to content

Commit 402a38f

Browse files
committed
Use disable numba JIT
1 parent 984ee55 commit 402a38f

File tree

1 file changed

+9
-69
lines changed

1 file changed

+9
-69
lines changed

tests/link/numba/test_basic.py

+9-69
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import contextlib
2-
import inspect
32
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
4-
from unittest import mock
53

64
import numba
75
import numpy as np
@@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
108106
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
109107
"""Evaluate the Numba implementation in pure Python for coverage purposes."""
110108

111-
def py_tuple_setitem(t, i, v):
112-
ll = list(t)
113-
ll[i] = v
114-
return tuple(ll)
115-
116-
def py_to_scalar(x):
117-
if isinstance(x, np.ndarray):
118-
return x.item()
119-
else:
120-
return x
121-
122-
def njit_noop(*args, **kwargs):
123-
if len(args) == 1 and callable(args[0]):
124-
return args[0]
125-
else:
126-
return lambda x: x
127-
128-
def vectorize_noop(*args, **kwargs):
129-
def wrap(fn):
130-
# `numba.vectorize` allows an `out` positional argument. We need
131-
# to account for that
132-
sig = inspect.signature(fn)
133-
nparams = len(sig.parameters)
134-
135-
def inner_vec(*args):
136-
if len(args) > nparams:
137-
# An `out` argument has been specified for an in-place
138-
# operation
139-
out = args[-1]
140-
out[...] = np.vectorize(fn)(*args[:nparams])
141-
return out
142-
else:
143-
return np.vectorize(fn)(*args)
144-
145-
return inner_vec
146-
147-
if len(args) == 1 and callable(args[0]):
148-
return wrap(args[0], **kwargs)
149-
else:
150-
return wrap
151-
152-
mocks = [
153-
mock.patch("numba.njit", njit_noop),
154-
mock.patch("numba.vectorize", vectorize_noop),
155-
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
156-
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
157-
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
158-
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
159-
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
160-
mock.patch(
161-
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
162-
lambda dtype: dtype,
163-
),
164-
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
165-
]
166-
167-
with contextlib.ExitStack() as stack:
168-
for ctx in mocks:
169-
stack.enter_context(ctx)
170-
171-
aesara_numba_fn = function(
172-
fn_inputs,
173-
fn_outputs,
174-
mode=mode,
175-
accept_inplace=True,
176-
)
177-
_ = aesara_numba_fn(*inputs)
109+
numba.config.DISABLE_JIT = True
110+
aesara_numba_fn = function(
111+
fn_inputs,
112+
fn_outputs,
113+
mode=mode,
114+
accept_inplace=True,
115+
)
116+
_ = aesara_numba_fn(*inputs)
117+
numba.config.DISABLE_JIT = False
178118

179119

180120
def compare_numba_and_py(

0 commit comments

Comments
 (0)