|
1 | 1 | import contextlib
|
2 |
| -import inspect |
3 | 2 | from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
|
4 |
| -from unittest import mock |
5 | 3 |
|
6 | 4 | import numba
|
7 | 5 | import numpy as np
|
@@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
|
108 | 106 | def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
|
109 | 107 | """Evaluate the Numba implementation in pure Python for coverage purposes."""
|
110 | 108 |
|
111 |
| - def py_tuple_setitem(t, i, v): |
112 |
| - ll = list(t) |
113 |
| - ll[i] = v |
114 |
| - return tuple(ll) |
115 |
| - |
116 |
| - def py_to_scalar(x): |
117 |
| - if isinstance(x, np.ndarray): |
118 |
| - return x.item() |
119 |
| - else: |
120 |
| - return x |
121 |
| - |
122 |
| - def njit_noop(*args, **kwargs): |
123 |
| - if len(args) == 1 and callable(args[0]): |
124 |
| - return args[0] |
125 |
| - else: |
126 |
| - return lambda x: x |
127 |
| - |
128 |
| - def vectorize_noop(*args, **kwargs): |
129 |
| - def wrap(fn): |
130 |
| - # `numba.vectorize` allows an `out` positional argument. We need |
131 |
| - # to account for that |
132 |
| - sig = inspect.signature(fn) |
133 |
| - nparams = len(sig.parameters) |
134 |
| - |
135 |
| - def inner_vec(*args): |
136 |
| - if len(args) > nparams: |
137 |
| - # An `out` argument has been specified for an in-place |
138 |
| - # operation |
139 |
| - out = args[-1] |
140 |
| - out[...] = np.vectorize(fn)(*args[:nparams]) |
141 |
| - return out |
142 |
| - else: |
143 |
| - return np.vectorize(fn)(*args) |
144 |
| - |
145 |
| - return inner_vec |
146 |
| - |
147 |
| - if len(args) == 1 and callable(args[0]): |
148 |
| - return wrap(args[0], **kwargs) |
149 |
| - else: |
150 |
| - return wrap |
151 |
| - |
152 |
| - mocks = [ |
153 |
| - mock.patch("numba.njit", njit_noop), |
154 |
| - mock.patch("numba.vectorize", vectorize_noop), |
155 |
| - mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem), |
156 |
| - mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), |
157 |
| - mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), |
158 |
| - mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), |
159 |
| - mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar), |
160 |
| - mock.patch( |
161 |
| - "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", |
162 |
| - lambda dtype: dtype, |
163 |
| - ), |
164 |
| - mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), |
165 |
| - ] |
166 |
| - |
167 |
| - with contextlib.ExitStack() as stack: |
168 |
| - for ctx in mocks: |
169 |
| - stack.enter_context(ctx) |
170 |
| - |
171 |
| - aesara_numba_fn = function( |
172 |
| - fn_inputs, |
173 |
| - fn_outputs, |
174 |
| - mode=mode, |
175 |
| - accept_inplace=True, |
176 |
| - ) |
177 |
| - _ = aesara_numba_fn(*inputs) |
| 109 | + numba.config.DISABLE_JIT = True |
| 110 | + aesara_numba_fn = function( |
| 111 | + fn_inputs, |
| 112 | + fn_outputs, |
| 113 | + mode=mode, |
| 114 | + accept_inplace=True, |
| 115 | + ) |
| 116 | + _ = aesara_numba_fn(*inputs) |
| 117 | + numba.config.DISABLE_JIT = False |
178 | 118 |
|
179 | 119 |
|
180 | 120 | def compare_numba_and_py(
|
|
0 commit comments