@@ -44,23 +44,19 @@ def _check_norm(norm):
4444        )
4545
4646
47- def  _check_shapes_for_direct (xs , shape , axes ):
47+ def  _check_shapes_for_direct (s , shape , axes ):
4848    if  len (axes ) >  7 :  # Intel MKL supports up to 7D 
4949        return  False 
50-     if  not  ( len (xs )  ==  len (shape ) ):
51-         # full-dimensional transform 
50+     if  len (s )  !=  len (shape ):
51+         # not a full-dimensional transform
5252        return  False 
53-     if  not  ( len (set (axes )) ==  len (axes ) ):
53+     if  len (set (axes )) !=  len (axes ):
5454        # repeated axes 
5555        return  False 
56-     for  xsi , ai  in  zip (xs , axes ):
57-         try :
58-             sh_ai  =  shape [ai ]
59-         except  IndexError :
60-             raise  ValueError ("Invalid axis (%d) specified"  %  ai )
61- 
62-         if  not  (xsi  ==  sh_ai ):
63-             return  False 
56+     new_shape  =  tuple (shape [ax ] for  ax  in  axes )
57+     if  tuple (s ) !=  new_shape :
58+         # trimming or padding is needed 
59+         return  False 
6460    return  True 
6561
6662
@@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape):
7874        return  np .sqrt (fsc )
7975
8076
81- def  _cook_nd_args (a , s = None , axes = None , invreal = False ):
82-     if  s  is  None :
83-         shapeless  =  True 
84-         if  axes  is  None :
85-             s  =  list (a .shape )
86-         else :
87-             try :
88-                 s  =  [a .shape [i ] for  i  in  axes ]
89-             except  IndexError :
90-                 # fake s designed to trip the ValueError further down 
91-                 s  =  range (len (axes ) +  1 )
92-                 pass 
93-     else :
94-         shapeless  =  False 
95-     s  =  list (s )
96-     if  axes  is  None :
97-         axes  =  list (range (- len (s ), 0 ))
98-     if  len (s ) !=  len (axes ):
99-         raise  ValueError ("Shape and axes have different lengths." )
100-     if  invreal  and  shapeless :
101-         s [- 1 ] =  (a .shape [axes [- 1 ]] -  1 ) *  2 
102-     return  s , axes 
103- 
104- 
10577# copied from scipy.fft module 
10678# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py 
10779def  _datacopied (arr , original ):
@@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape):
129101    return  m_ind 
130102
131103
132- # copied from scipy.fftpack.helper 
133- def  _init_nd_shape_and_axes (x , shape , axes ):
134-     """Handle shape and axes arguments for n-dimensional transforms. 
135-     Returns the shape and axes in a standard form, taking into account negative 
136-     values and checking for various potential errors. 
137-     Parameters 
138-     ---------- 
139-     x : array_like 
140-         The input array. 
141-     shape : int or array_like of ints or None 
142-         The shape of the result.  If both `shape` and `axes` (see below) are 
143-         None, `shape` is ``x.shape``; if `shape` is None but `axes` is 
144-         not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. 
145-         If `shape` is -1, the size of the corresponding dimension of `x` is 
146-         used. 
147-     axes : int or array_like of ints or None 
148-         Axes along which the calculation is computed. 
149-         The default is over all axes. 
150-         Negative indices are automatically converted to their positive 
151-         counterpart. 
152-     Returns 
153-     ------- 
154-     shape : array 
155-         The shape of the result. It is a 1D integer array. 
156-     axes : array 
157-         The shape of the result. It is a 1D integer array. 
158-     """ 
159-     x  =  np .asarray (x )
160-     noshape  =  shape  is  None 
161-     noaxes  =  axes  is  None 
162- 
163-     if  noaxes :
164-         axes  =  np .arange (x .ndim , dtype = np .intc )
165-     else :
166-         axes  =  np .atleast_1d (axes )
167- 
168-     if  axes .size  ==  0 :
169-         axes  =  axes .astype (np .intc )
170- 
171-     if  not  axes .ndim  ==  1 :
172-         raise  ValueError ("when given, axes values must be a scalar or vector" )
173-     if  not  np .issubdtype (axes .dtype , np .integer ):
174-         raise  ValueError ("when given, axes values must be integers" )
175- 
176-     axes  =  np .where (axes  <  0 , axes  +  x .ndim , axes )
177- 
178-     if  axes .size  !=  0  and  (axes .max () >=  x .ndim  or  axes .min () <  0 ):
179-         raise  ValueError ("axes exceeds dimensionality of input" )
180-     if  axes .size  !=  0  and  np .unique (axes ).shape  !=  axes .shape :
181-         raise  ValueError ("all axes must be unique" )
182- 
183-     if  not  noshape :
184-         shape  =  np .atleast_1d (shape )
185-     elif  np .isscalar (x ):
186-         shape  =  np .array ([], dtype = np .intc )
187-     elif  noaxes :
188-         shape  =  np .array (x .shape , dtype = np .intc )
189-     else :
190-         shape  =  np .take (x .shape , axes )
191- 
192-     if  shape .size  ==  0 :
193-         shape  =  shape .astype (np .intc )
194- 
195-     if  shape .ndim  !=  1 :
196-         raise  ValueError ("when given, shape values must be a scalar or vector" )
197-     if  not  np .issubdtype (shape .dtype , np .integer ):
198-         raise  ValueError ("when given, shape values must be integers" )
199-     if  axes .shape  !=  shape .shape :
200-         raise  ValueError (
201-             "when given, axes and shape arguments have to be of the same length" 
202-         )
203- 
204-     shape  =  np .where (shape  ==  - 1 , np .array (x .shape )[axes ], shape )
205-     if  shape .size  !=  0  and  (shape  <  1 ).any ():
206-         raise  ValueError (f"invalid number of data points ({ shape }  )
207- 
208-     return  shape , axes 
209- 
210- 
211104def  _iter_complementary (x , axes , func , kwargs , result ):
212-     if  axes  is  None :
213-         # s and axes are None, direct N-D FFT 
214-         return  func (x , ** kwargs , out = result )
215105    x_shape  =  x .shape 
216106    nd  =  x .ndim 
217107    r  =  list (range (nd ))
@@ -260,9 +150,6 @@ def _iter_fftnd(
260150    direction = + 1 ,
261151    scale_function = lambda  ind : 1.0 ,
262152):
263-     a  =  np .asarray (a )
264-     s , axes  =  _init_nd_shape_and_axes (a , s , axes )
265- 
266153    # Combine the two, but in reverse, to end with the first axis given. 
267154    axes_and_s  =  list (zip (axes , s ))[::- 1 ]
268155    # We try to use in-place calculations where possible, which is 
@@ -309,13 +196,14 @@ def _output_dtype(dt):
309196def  _pad_array (arr , s , axes ):
310197    """Pads array arr with zeros to attain shape s associated with axes""" 
311198    arr_shape  =  arr .shape 
199+     new_shape  =  tuple (arr_shape [ax ] for  ax  in  axes )
200+     if  tuple (s ) ==  new_shape :
201+         return  arr 
202+ 
312203    no_padding  =  True 
313204    pad_widths  =  [(0 , 0 )] *  len (arr_shape )
314205    for  si , ai  in  zip (s , axes ):
315-         try :
316-             shp_i  =  arr_shape [ai ]
317-         except  IndexError :
318-             raise  ValueError (f"Invalid axis { ai }  )
206+         shp_i  =  arr_shape [ai ]
319207        if  si  >  shp_i :
320208            no_padding  =  False 
321209            pad_widths [ai ] =  (0 , si  -  shp_i )
@@ -345,14 +233,14 @@ def _trim_array(arr, s, axes):
345233    """ 
346234
347235    arr_shape  =  arr .shape 
236+     new_shape  =  tuple (arr_shape [ax ] for  ax  in  axes )
237+     if  tuple (s ) ==  new_shape :
238+         return  arr 
239+ 
348240    no_trim  =  True 
349241    ind  =  [slice (None , None , None )] *  len (arr_shape )
350242    for  si , ai  in  zip (s , axes ):
351-         try :
352-             shp_i  =  arr_shape [ai ]
353-         except  IndexError :
354-             raise  ValueError (f"Invalid axis { ai }  )
355-         if  si  <  shp_i :
243+         if  si  <  arr_shape [ai ]:
356244            no_trim  =  False 
357245            ind [ai ] =  slice (None , si , None )
358246    if  no_trim :
@@ -383,16 +271,11 @@ def _c2c_fftnd_impl(
383271    if  direction  not  in - 1 , + 1 ]:
384272        raise  ValueError ("Direction of FFT should +1 or -1" )
385273
274+     x  =  np .asarray (x )
386275    valid_dtypes  =  [np .complex64 , np .complex128 , np .float32 , np .float64 ]
387276    # _direct_fftnd requires complex type, and full-dimensional transform 
388-     if  isinstance (x , np .ndarray ) and  x .size  !=  0  and  x .ndim  >  1 :
389-         _direct  =  s  is  None  and  axes  is  None 
390-         if  _direct :
391-             _direct  =  x .ndim  <=  7   # Intel MKL only supports FFT up to 7D 
392-         if  not  _direct :
393-             xs , xa  =  _cook_nd_args (x , s , axes )
394-             if  _check_shapes_for_direct (xs , x .shape , xa ):
395-                 _direct  =  True 
277+     if  x .size  !=  0  and  x .ndim  >  1 :
278+         _direct  =  _check_shapes_for_direct (s , x .shape , axes )
396279        _direct  =  _direct  and  x .dtype  in  valid_dtypes 
397280    else :
398281        _direct  =  False 
@@ -405,14 +288,23 @@ def _c2c_fftnd_impl(
405288            out = out ,
406289        )
407290    else :
408-         if  s  is  None  and  x .dtype  in  valid_dtypes :
409-             x  =  np .asarray (x )
291+         new_shape  =  tuple (x .shape [ax ] for  ax  in  axes )
292+         if  (
293+             tuple (s ) ==  new_shape 
294+             and  x .dtype  in  valid_dtypes 
295+             and  len (set (axes )) ==  len (axes )
296+         ):
410297            if  out  is  None :
411298                res  =  np .empty_like (x , dtype = _output_dtype (x .dtype ))
412299            else :
413300                _validate_out_array (out , x , _output_dtype (x .dtype ))
414301                res  =  out 
415302
303+             # MKL can perform batched N-D FFTs natively, so manually looping
304+             # over the batches, as _iter_complementary does, is unnecessary;
305+             # that loop is the cause of the poor performance reported in
306+             # gh-issue-#67. TODO: implement a batch N-D FFT using MKL.
307+             # Until then, _iter_complementary performs batches of N-D FFT.
416308            return  _iter_complementary (
417309                x ,
418310                axes ,
@@ -434,14 +326,9 @@ def _c2c_fftnd_impl(
434326
435327def  _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
436328    a  =  np .asarray (x )
437-     no_trim  =  (s  is  None ) and  (axes  is  None )
438-     s , axes  =  _cook_nd_args (a , s , axes )
439-     axes  =  [ax  +  a .ndim  if  ax  <  0  else  ax  for  ax  in  axes ]
440329    la  =  axes [- 1 ]
441- 
442330    # trim array, so that rfft avoids doing unnecessary computations 
443-     if  not  no_trim :
444-         a  =  _trim_array (a , s , axes )
331+     a  =  _trim_array (a , s , axes )
445332
446333    # last axis is not included since we calculate r2c FFT separately 
447334    # and not in the loop 
@@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
453340    a  =  _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
454341    res  =  a 
455342    if  len (s ) >  1 :
456- 
457343        len_axes  =  len (axes )
458344        if  len (set (axes )) ==  len_axes  and  len_axes  ==  a .ndim  and  len_axes  >  2 :
459-             if  not  no_trim :
460-                 ss  =  list (s )
461-                 ss [- 1 ] =  a .shape [la ]
462-                 a  =  _pad_array (a , tuple (ss ), axes )
345+             ss  =  list (s )
346+             ss [- 1 ] =  a .shape [la ]
347+             a  =  _pad_array (a , tuple (ss ), axes )
463348            # a series of ND c2c FFTs along last axis 
464349            ss , aa  =  _remove_axis (s , axes , - 1 )
465350            ind  =  [slice (None , None , 1 )] *  len (s )
@@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
494379
495380def  _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
496381    a  =  np .asarray (x )
497-     no_trim  =  (s  is  None ) and  (axes  is  None )
498-     s , axes  =  _cook_nd_args (a , s , axes , invreal = True )
499-     axes  =  [ax  +  a .ndim  if  ax  <  0  else  ax  for  ax  in  axes ]
500382    la  =  axes [- 1 ]
501-     if  not  no_trim :
502-         a  =  _trim_array (a , s , axes )
503383    if  len (s ) >  1 :
504384        len_axes  =  len (axes )
505385        if  len (set (axes )) ==  len_axes  and  len_axes  ==  a .ndim  and  len_axes  >  2 :
506-             if   not   no_trim : 
507-                  a  =  _pad_array (a , s , axes )
386+             a   =   _trim_array ( a ,  s ,  axes ) 
387+             a  =  _pad_array (a , s , axes )
508388            # a series of ND c2c FFTs along last axis 
509389            # due to need to write into a, we must copy 
510390            a  =  a  if  _datacopied (a , x ) else  a .copy ()
@@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
521401                tind  =  tuple (ind )
522402                a_inp  =  a [tind ]
523403                # out has real dtype and cannot be used in intermediate steps 
524-                 # ss and aa are reversed since np.irfftn uses forward order but  
525-                 # np .ifftn uses reverse order see numpy-gh-28950 
404+                 # ss and aa are reversed since np.fft.irfftn uses forward order
405+                 # but np.fft.ifftn uses reverse order; see numpy-gh-28950
526406                _  =  _c2c_fftnd_impl (
527407                    a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1 
528408                )
0 commit comments