From f2de13b077e6df0cc5a93d7168fd7f38194c621d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 18 Apr 2024 18:02:12 +0000 Subject: [PATCH] Begin deprecation of implicit input conversion in FFT module --- CHANGELOG.md | 4 ++ jax/_src/numpy/fft.py | 54 ++++++++++++++++++- jax/_src/scipy/fft.py | 4 +- jax/experimental/array_api/_fft_functions.py | 47 +++++++++++++++- tests/array_api_test.py | 30 ++++++++++- tests/fft_test.py | 57 +++++++++++++++----- tests/jet_test.py | 2 +- 7 files changed, 176 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a30f333d961e..af7f86909d4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,10 @@ Remember to align the itemized text with the first line of an item within a list * The {func}`jax.numpy.hypot` function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed. + * The {mod}`jax.numpy.fft` module now issues a deprecation warning when + passing inputs that would require implicit conversion (e.g. `jnp.float32` + with `fft`, `jnp.int32` with `rftt`). Either manually onvert the inputs + to preserve old behavior, or use a more appropriate `fft` function. ## jaxlib 0.4.27 diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 03e468fa99a9..1d0c4c36580a 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -17,18 +17,65 @@ from collections.abc import Sequence import operator import numpy as np +import warnings -from jax import dtypes from jax import lax +from jax._src import dtypes from jax._src.lib import xla_client from jax._src.util import safe_zip -from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact +from jax._src.numpy.util import ( + check_arraylike, implements, + promote_dtypes_inexact, promote_dtypes_complex) from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import ufuncs, reductions from jax._src.typing import Array, ArrayLike Shape = Sequence[int] +NEEDS_COMPLEX_IN = {'fft', 'fftn', 'hfft', 'ifft', 'ifftn', 'irfft', 'irfftn'} +NEEDS_REAL_IN = { + # These are already handled in lax.fft + # 'rfft', 'rfftn', 'ihfft', + 'fftshift', 'ifftshift' + } + +# TODO(micky774): Promote warnings to ValueErrors when deprecation is completed +# and uncomment the portion of NEEDS_REAL_IN which currently defers type +# checking to lax.fft. Deprecation began 4-18-24. +def _check_input_fft(func_name: str, x: Array): + kind = x.dtype.kind + suggest_alternative_msg = ( + " or consider using a more appropriate fft function if applicable." + ) + if func_name in NEEDS_COMPLEX_IN and kind != "c": + warnings.warn( + f"Passing non-complex valued inputs to {func_name} is deprecated and " + "will soon raise a ValueError. Please explicitly convert the input to a " + f"complex dtype before passing to {func_name} in order to suppress this " + "warning," + suggest_alternative_msg, + DeprecationWarning, stacklevel=2 + ) + return promote_dtypes_complex(x)[0] + if func_name in NEEDS_REAL_IN: + if kind == "c": + warnings.warn( + f"Passing complex-valued inputs to {func_name} is deprecated and " + "will soon raise a ValueError. To suppress this warning, please convert " + "to real values first, such as by using jnp.real or jnp.imag to take " + "the real or imaginary components respectively," + suggest_alternative_msg, + DeprecationWarning, stacklevel=2 + ) + elif kind != "f": + warnings.warn( + f"Passing integral inputs to {func_name} is deprecated and " + "will soon raise a ValueError. Please convert to a real-valued " + "floating-point input first.", + DeprecationWarning, stacklevel=2 + ) + return promote_dtypes_inexact(x) + return x + + def _fft_norm(s: Array, func_name: str, norm: str) -> Array: if norm == "backward": return jnp.array(1) @@ -50,6 +97,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, full_name = f"jax.numpy.fft.{func_name}" check_arraylike(full_name, a) arr = jnp.asarray(a) + arr = _check_input_fft(func_name, arr) if s is not None: s = tuple(map(operator.index, s)) @@ -300,6 +348,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("fftshift", x) x = jnp.asarray(x) + arr = _check_input_fft("fftshift", x) shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) @@ -316,6 +365,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("ifftshift", x) x = jnp.asarray(x) + arr = _check_input_fft("ifftshift", x) shift: int | Sequence[int] if axes is None: axes = tuple(range(x.ndim)) diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index 0db98ddc0f40..70f9908c00be 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -55,7 +55,7 @@ def dct(x: Array, type: int = 2, n: int | None = None, for a in range(x.ndim)]) N = x.shape[axis] - v = _dct_interleave(x, axis) + v, = promote_dtypes_complex(_dct_interleave(x, axis)) V = jnp.fft.fft(v, axis=axis) k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis]) out = V * _W4(N, k) @@ -68,7 +68,7 @@ def dct(x: Array, type: int = 2, n: int | None = None, def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array: axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes) N1, N2 = x.shape[axis1], x.shape[axis2] - v = _dct_interleave(_dct_interleave(x, axis1), axis2) + v, = promote_dtypes_complex(_dct_interleave(_dct_interleave(x, axis1), axis2)) V = jnp.fft.fftn(v, axes=axes) k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype), [a for a in range(x.ndim) if a != axis1]) diff --git a/jax/experimental/array_api/_fft_functions.py b/jax/experimental/array_api/_fft_functions.py index d1e737a424ac..a956820b5f3e 100644 --- a/jax/experimental/array_api/_fft_functions.py +++ b/jax/experimental/array_api/_fft_functions.py @@ -13,46 +13,89 @@ # limitations under the License. import jax.numpy as jnp - +from jax._src.numpy.fft import NEEDS_COMPLEX_IN, NEEDS_REAL_IN as _NEEDS_REAL_IN +from jax._src.numpy.util import check_arraylike + +NEEDS_REAL_IN = _NEEDS_REAL_IN.union({'rfft', 'rfftn', 'ihfft'}) + +# TODO(micky774): Remove when jax.numpy.fft deprecation completes. Deprecation +# began 4-18-24. +def _check_input_fft(func_name: str, x): + check_arraylike('jax.experimental.array_api.' + func_name, x) + arr = jnp.asarray(x) + kind = arr.dtype.kind + suggest_alternative_msg = ( + " or consider using a more appropriate fft function if applicable." + ) + if func_name in NEEDS_COMPLEX_IN and kind != "c": + raise ValueError( + f"{func_name} requires complex-valued input, but received input with type " + f"{arr.dtype} instead. Please explicitly convert to a complex-valued input " + "first," + suggest_alternative_msg, + ) + if func_name in NEEDS_REAL_IN: + needs_real_msg = ( + f"{func_name} requires real-valued floating-point input, but received " + f"input with type {arr.dtype} instead. Please convert to a real-valued " + "floating-point input first" + ) + if kind == "c": + raise ValueError( + needs_real_msg + ", such as by using jnp.real or jnp.imag to take the " + "real or imaginary components respectively," + suggest_alternative_msg, + ) + elif kind != "f": + raise ValueError(needs_real_msg + '.') + return arr def fft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional discrete Fourier transform.""" + _check_input_fft('fft', x) return jnp.fft.fft(x, n=n, axis=axis, norm=norm) def ifft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional inverse discrete Fourier transform.""" + _check_input_fft('ifft', x) return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) def fftn(x, /, *, s=None, axes=None, norm='backward'): """Computes the n-dimensional discrete Fourier transform.""" + _check_input_fft('fftn', x) return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) def ifftn(x, /, *, s=None, axes=None, norm='backward'): """Computes the n-dimensional inverse discrete Fourier transform.""" + _check_input_fft('ifftn', x) return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) def rfft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional discrete Fourier transform for real-valued input.""" + _check_input_fft('rfft', x) return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) def irfft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional inverse of rfft for complex-valued input.""" + _check_input_fft('irfft', x) return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) def rfftn(x, /, *, s=None, axes=None, norm='backward'): """Computes the n-dimensional discrete Fourier transform for real-valued input.""" + _check_input_fft('rfftn', x) return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) def irfftn(x, /, *, s=None, axes=None, norm='backward'): """Computes the n-dimensional inverse of rfftn for complex-valued input.""" + _check_input_fft('irfftn', x) return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) def hfft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry.""" + _check_input_fft('hfft', x) return jnp.fft.hfft(x, n=n, axis=axis, norm=norm) def ihfft(x, /, *, n=None, axis=-1, norm='backward'): """Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry.""" + _check_input_fft('ihfft', x) return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) def fftfreq(n, /, *, d=1.0, device=None): @@ -65,8 +108,10 @@ def rfftfreq(n, /, *, d=1.0, device=None): def fftshift(x, /, *, axes=None): """Shift the zero-frequency component to the center of the spectrum.""" + _check_input_fft('fftshift', x) return jnp.fft.fftshift(x, axes=axes) def ifftshift(x, /, *, axes=None): """Inverse of fftshift.""" + _check_input_fft('ifftshift', x) return jnp.fft.ifftshift(x, axes=axes) diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 5667c3459dad..c69df6a65487 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -27,7 +27,8 @@ from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype from jax.experimental import array_api - +from jax.experimental.array_api._fft_functions import ( + NEEDS_COMPLEX_IN, NEEDS_REAL_IN) config.parse_flags_with_absl() MAIN_NAMESPACE = { @@ -326,7 +327,7 @@ def test_dtypes_info(self, kind): target_dict = control[kind] assert info_dict == target_dict -class ArrayAPIErrors(absltest.TestCase): +class ArrayAPIErrors(jtu.JaxTestCase): """Test that our array API implementations raise errors where required""" # TODO(micky774): Remove when jnp.clip deprecation is completed @@ -347,6 +348,31 @@ def test_clip_complex(self): with self.assertRaisesRegex(ValueError, complex_msg): array_api.clip(x, max=-1+5j) + @jtu.sample_product( + [dict(dtype=dtype,func_name=func_name) + for real in [True, False] + for dtype in (jtu.dtypes.complex if real else jtu.dtypes.floating) + + jtu.dtypes.integer + jtu.dtypes.boolean + for func_name in (NEEDS_REAL_IN if real else NEEDS_COMPLEX_IN) + ]) + def testFftWarnings(self, dtype, func_name): + shape = (2, 3, 4) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + func = getattr(array_api.fft, func_name) + + if func_name in NEEDS_COMPLEX_IN: + msg = "complex-valued input" + else: + msg = "real-valued" + if x.dtype.kind == 'c': + msg += ".*real or imaginary" + if x.dtype.kind in {'c', 'r'}: + msg += ".*or consider using a more" + + with self.assertRaisesRegex(ValueError, expected_regex=msg): + func(x) + if __name__ == '__main__': absltest.main() diff --git a/tests/fft_test.py b/tests/fft_test.py index ce7455fdb4f8..717abebe3865 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -27,7 +27,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.numpy.util import promote_dtypes_complex -from jax._src.numpy.fft import _fft_norm +from jax._src.numpy.fft import _fft_norm, NEEDS_COMPLEX_IN, NEEDS_REAL_IN config.parse_flags_with_absl() @@ -37,7 +37,8 @@ float_dtypes = jtu.dtypes.floating inexact_dtypes = jtu.dtypes.inexact real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean -all_dtypes = real_dtypes + jtu.dtypes.complex +complex_dtypes = jtu.dtypes.complex +all_dtypes = real_dtypes + complex_dtypes def _get_fftn_test_axes(shape): @@ -89,7 +90,7 @@ def _zero_for_irfft(z, axes): else: parts = [lax.slice_in_dim(z.real, 0, 1, axis=axis).real, lax.slice_in_dim(z.real, 1, size, axis=axis)] - return jnp.concatenate(parts, axis=axis) + return jnp.concatenate(parts, axis=axis, dtype=z.dtype) class FftTest(jtu.JaxTestCase): @@ -142,7 +143,7 @@ def testLaxIrfftDoesNotMutateInputs(self, dtype): [dict(inverse=inverse, real=real, dtype=dtype) for inverse in [False, True] for real in [False, True] - for dtype in (real_dtypes if real and not inverse else all_dtypes) + for dtype in (float_dtypes if real and not inverse else complex_dtypes) ], [dict(shape=shape, axes=axes, s=s) for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)] @@ -202,20 +203,21 @@ def testFftnErrors(self, inverse, real): name = 'r' + name if inverse: name = 'i' + name + dtype = np.float64 if (real and not inverse) else np.complex64 func = _get_fftn_func(jnp.fft, inverse, real) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. " "Got axes None with input rank 4.".format(name), - lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None)) + lambda: func(rng([2, 3, 4, 5], dtype=dtype), axes=None)) self.assertRaisesRegex( ValueError, f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].", - lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1])) + lambda: func(rng([2, 3], dtype=dtype), axes=[1, 1])) self.assertRaises( - ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2])) + ValueError, lambda: func(rng([2, 3], dtype=dtype), axes=[2])) self.assertRaises( - ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3])) + ValueError, lambda: func(rng([2, 3], dtype=dtype), axes=[-3])) def testFftEmpty(self): out = jnp.fft.fft(jnp.zeros((0,), jnp.complex64)).block_until_ready() @@ -225,9 +227,11 @@ def testFftEmpty(self): [dict(inverse=inverse, real=real, hermitian=hermitian, dtype=dtype) for inverse in [False, True] for real in [False, True] - for hermitian in [False, True] - for dtype in (real_dtypes if (real and not inverse) or (hermitian and inverse) - else all_dtypes) + for hermitian in [False] + ([] if real else [True]) + for dtype in ( + float_dtypes if (real and not inverse) or (hermitian and inverse) + else complex_dtypes + ) ], shape=[(10,)], n=[None, 1, 7, 13, 20], @@ -289,7 +293,7 @@ def testFftErrors(self, inverse, real, hermitian): [dict(inverse=inverse, real=real, dtype=dtype) for inverse in [False, True] for real in [False, True] - for dtype in (real_dtypes if real and not inverse else all_dtypes) + for dtype in (float_dtypes if real and not inverse else complex_dtypes) ], shape=[(16, 8, 4, 8), (16, 8, 4, 8, 4)], axes=[(-2, -1), (0, 1), (1, 3), (-1, 2)], @@ -423,7 +427,7 @@ def testRfftfreqErrors(self, n): for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape) ], - dtype=all_dtypes, + dtype=float_dtypes, ) def testFftshift(self, shape, dtype, axes): rng = jtu.rand_default(self.rng()) @@ -437,7 +441,7 @@ def testFftshift(self, shape, dtype, axes): for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape) ], - dtype=all_dtypes, + dtype=float_dtypes, ) def testIfftshift(self, shape, dtype, axes): rng = jtu.rand_default(self.rng()) @@ -463,5 +467,30 @@ def testFftnormOverflow(self, norm, func_name, dtype): np_norm = np.reciprocal(np_norm) self.assertArraysAllClose(jax_norm, np_norm, rtol=3e-8, check_dtypes=False) + + @jtu.sample_product( + [dict(dtype=dtype,func_name=func_name) + for real in [True, False] + for dtype in (complex_dtypes if real else float_dtypes) + + jtu.dtypes.integer + jtu.dtypes.boolean + for func_name in (NEEDS_REAL_IN if real else NEEDS_COMPLEX_IN) + ]) + def testFftWarnings(self, dtype, func_name): + shape = (2, 3, 4) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + func = getattr(jnp.fft, func_name) + + if func_name in NEEDS_COMPLEX_IN: + msg = "non-complex" + else: + msg = "complex-valued" if x.dtype.kind == 'c' else "integral" + if x.dtype.kind in {'c', 'r'}: + msg += ".*or consider using a more" + + with self.assertWarnsRegex(DeprecationWarning, expected_regex=msg): + func(x) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jet_test.py b/tests/jet_test.py index 5661197509ec..2945d91f6af2 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -267,7 +267,7 @@ def test_stopgrad(self): self.unary_check(lax.stop_gradient) @jtu.skip_on_devices("tpu") def test_abs(self): self.unary_check(jnp.abs) @jtu.skip_on_devices("tpu") - def test_fft(self): self.unary_check(jnp.fft.fft) + def test_fft(self): self.unary_check(jnp.fft.fft, dtype=jnp.complex64) @jtu.skip_on_devices("tpu") def test_log1p(self): self.unary_check(jnp.log1p, lims=[0, 4.]) @jtu.skip_on_devices("tpu")