Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Begin deprecation of implicit input conversion in FFT module #20818

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ Remember to align the itemized text with the first line of an item within a list
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
* The {mod}`jax.numpy.fft` module now issues a deprecation warning when
passing inputs that would require implicit conversion (e.g. `jnp.float32`
with `fft`, `jnp.int32` with `rftt`). Either manually onvert the inputs
to preserve old behavior, or use a more appropriate `fft` function.

## jaxlib 0.4.27

Expand Down
54 changes: 52 additions & 2 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,65 @@
from collections.abc import Sequence
import operator
import numpy as np
import warnings

from jax import dtypes
from jax import lax
from jax._src import dtypes
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact
from jax._src.numpy.util import (
check_arraylike, implements,
promote_dtypes_inexact, promote_dtypes_complex)
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike

Shape = Sequence[int]

NEEDS_COMPLEX_IN = {'fft', 'fftn', 'hfft', 'ifft', 'ifftn', 'irfft', 'irfftn'}
NEEDS_REAL_IN = {
# These are already handled in lax.fft
# 'rfft', 'rfftn', 'ihfft',
'fftshift', 'ifftshift'
}

# TODO(micky774): Promote warnings to ValueErrors when deprecation is completed
# and uncomment the portion of NEEDS_REAL_IN which currently defers type
# checking to lax.fft. Deprecation began 4-18-24.
def _check_input_fft(func_name: str, x: Array):
kind = x.dtype.kind
suggest_alternative_msg = (
" or consider using a more appropriate fft function if applicable."
)
if func_name in NEEDS_COMPLEX_IN and kind != "c":
warnings.warn(
f"Passing non-complex valued inputs to {func_name} is deprecated and "
"will soon raise a ValueError. Please explicitly convert the input to a "
f"complex dtype before passing to {func_name} in order to suppress this "
"warning," + suggest_alternative_msg,
DeprecationWarning, stacklevel=2
)
return promote_dtypes_complex(x)[0]
if func_name in NEEDS_REAL_IN:
if kind == "c":
warnings.warn(
f"Passing complex-valued inputs to {func_name} is deprecated and "
"will soon raise a ValueError. To suppress this warning, please convert "
"to real values first, such as by using jnp.real or jnp.imag to take "
"the real or imaginary components respectively," + suggest_alternative_msg,
DeprecationWarning, stacklevel=2
)
elif kind != "f":
warnings.warn(
f"Passing integral inputs to {func_name} is deprecated and "
"will soon raise a ValueError. Please convert to a real-valued "
"floating-point input first.",
DeprecationWarning, stacklevel=2
)
return promote_dtypes_inexact(x)
return x


def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
if norm == "backward":
return jnp.array(1)
Expand All @@ -50,6 +97,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
full_name = f"jax.numpy.fft.{func_name}"
check_arraylike(full_name, a)
arr = jnp.asarray(a)
arr = _check_input_fft(func_name, arr)

if s is not None:
s = tuple(map(operator.index, s))
Expand Down Expand Up @@ -300,6 +348,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("fftshift", x)
x = jnp.asarray(x)
arr = _check_input_fft("fftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))
Expand All @@ -316,6 +365,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("ifftshift", x)
x = jnp.asarray(x)
arr = _check_input_fft("ifftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/scipy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def dct(x: Array, type: int = 2, n: int | None = None,
for a in range(x.ndim)])

N = x.shape[axis]
v = _dct_interleave(x, axis)
v, = promote_dtypes_complex(_dct_interleave(x, axis))
V = jnp.fft.fft(v, axis=axis)
k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
out = V * _W4(N, k)
Expand All @@ -68,7 +68,7 @@ def dct(x: Array, type: int = 2, n: int | None = None,
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
N1, N2 = x.shape[axis1], x.shape[axis2]
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
v, = promote_dtypes_complex(_dct_interleave(_dct_interleave(x, axis1), axis2))
V = jnp.fft.fftn(v, axes=axes)
k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis1])
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/scipy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,11 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
result = detrend_func(result)

# Apply window by multiplication
if jnp.iscomplexobj(win):
result, = promote_dtypes_complex(result)
result = win.reshape((1,) * len(batch_shape) + (1, nperseg)) * result

# Perform the fft on last axis. Zero-pads automatically
if sides == 'twosided':
result, = promote_dtypes_complex(result)
return jax.numpy.fft.fft(result, n=nfft)
else:
return jax.numpy.fft.rfft(result.real, n=nfft)
Expand Down
47 changes: 46 additions & 1 deletion jax/experimental/array_api/_fft_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,89 @@
# limitations under the License.

import jax.numpy as jnp

from jax._src.numpy.fft import NEEDS_COMPLEX_IN, NEEDS_REAL_IN as _NEEDS_REAL_IN
from jax._src.numpy.util import check_arraylike

NEEDS_REAL_IN = _NEEDS_REAL_IN.union({'rfft', 'rfftn', 'ihfft'})

# TODO(micky774): Remove when jax.numpy.fft deprecation completes. Deprecation
# began 4-18-24.
def _check_input_fft(func_name: str, x):
check_arraylike('jax.experimental.array_api.' + func_name, x)
arr = jnp.asarray(x)
kind = arr.dtype.kind
suggest_alternative_msg = (
" or consider using a more appropriate fft function if applicable."
)
if func_name in NEEDS_COMPLEX_IN and kind != "c":
raise ValueError(
f"{func_name} requires complex-valued input, but received input with type "
f"{arr.dtype} instead. Please explicitly convert to a complex-valued input "
"first," + suggest_alternative_msg,
)
if func_name in NEEDS_REAL_IN:
needs_real_msg = (
f"{func_name} requires real-valued floating-point input, but received "
f"input with type {arr.dtype} instead. Please convert to a real-valued "
"floating-point input first"
)
if kind == "c":
raise ValueError(
needs_real_msg + ", such as by using jnp.real or jnp.imag to take the "
"real or imaginary components respectively," + suggest_alternative_msg,
)
elif kind != "f":
raise ValueError(needs_real_msg + '.')
return arr

def fft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform."""
_check_input_fft('fft', x)
return jnp.fft.fft(x, n=n, axis=axis, norm=norm)

def ifft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse discrete Fourier transform."""
_check_input_fft('ifft', x)
return jnp.fft.ifft(x, n=n, axis=axis, norm=norm)

def fftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional discrete Fourier transform."""
_check_input_fft('fftn', x)
return jnp.fft.fftn(x, s=s, axes=axes, norm=norm)

def ifftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional inverse discrete Fourier transform."""
_check_input_fft('ifftn', x)
return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm)

def rfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform for real-valued input."""
_check_input_fft('rfft', x)
return jnp.fft.rfft(x, n=n, axis=axis, norm=norm)

def irfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse of rfft for complex-valued input."""
_check_input_fft('irfft', x)
return jnp.fft.irfft(x, n=n, axis=axis, norm=norm)

def rfftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional discrete Fourier transform for real-valued input."""
_check_input_fft('rfftn', x)
return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm)

def irfftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional inverse of rfftn for complex-valued input."""
_check_input_fft('irfftn', x)
return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm)

def hfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry."""
_check_input_fft('hfft', x)
return jnp.fft.hfft(x, n=n, axis=axis, norm=norm)

def ihfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry."""
_check_input_fft('ihfft', x)
return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm)

def fftfreq(n, /, *, d=1.0, device=None):
Expand All @@ -65,8 +108,10 @@ def rfftfreq(n, /, *, d=1.0, device=None):

def fftshift(x, /, *, axes=None):
"""Shift the zero-frequency component to the center of the spectrum."""
_check_input_fft('fftshift', x)
return jnp.fft.fftshift(x, axes=axes)

def ifftshift(x, /, *, axes=None):
"""Inverse of fftshift."""
_check_input_fft('ifftshift', x)
return jnp.fft.ifftshift(x, axes=axes)
30 changes: 28 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api

from jax.experimental.array_api._fft_functions import (
NEEDS_COMPLEX_IN, NEEDS_REAL_IN)
config.parse_flags_with_absl()

MAIN_NAMESPACE = {
Expand Down Expand Up @@ -326,7 +327,7 @@ def test_dtypes_info(self, kind):
target_dict = control[kind]
assert info_dict == target_dict

class ArrayAPIErrors(absltest.TestCase):
class ArrayAPIErrors(jtu.JaxTestCase):
"""Test that our array API implementations raise errors where required"""

# TODO(micky774): Remove when jnp.clip deprecation is completed
Expand All @@ -347,6 +348,31 @@ def test_clip_complex(self):
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)

@jtu.sample_product(
[dict(dtype=dtype,func_name=func_name)
for real in [True, False]
for dtype in (jtu.dtypes.complex if real else jtu.dtypes.floating)
+ jtu.dtypes.integer + jtu.dtypes.boolean
for func_name in (NEEDS_REAL_IN if real else NEEDS_COMPLEX_IN)
])
def testFftWarnings(self, dtype, func_name):
shape = (2, 3, 4)
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
func = getattr(array_api.fft, func_name)

if func_name in NEEDS_COMPLEX_IN:
msg = "complex-valued input"
else:
msg = "real-valued"
if x.dtype.kind == 'c':
msg += ".*real or imaginary"
if x.dtype.kind in {'c', 'r'}:
msg += ".*or consider using a more"

with self.assertRaisesRegex(ValueError, expected_regex=msg):
func(x)


if __name__ == '__main__':
absltest.main()
Loading