diff --git a/spaudiopy/sph.py b/spaudiopy/sph.py index 3621551..a9c0246 100644 --- a/spaudiopy/sph.py +++ b/spaudiopy/sph.py @@ -1164,7 +1164,7 @@ def sph_filterbank_reconstruction_factor(w_nm, num_secs, mode=None): return beta -def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode): +def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, mode, sh_type='real'): """Design spherical filter bank analysis and reconstruction matrix. Parameters @@ -1177,9 +1177,9 @@ def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode): Sector zenith/colatitude steering directions. c_n : (N,) array_like SH Modal weights, describing (axisymmetric) pattern. - sh_type : 'real' or 'complex' - mode : 'perfect' or 'energy' + mode : 'perfect' or 'energy' or 'pinv' Design achieves perfect reconstruction or energy reconstruction. + sh_type : 'real' or 'complex' Raises ------ @@ -1231,25 +1231,30 @@ def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode): # Preservation property if mode.lower() == 'perfect': - pres = 'amplitude' + beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs, + mode='amplitude') elif mode.lower() == 'energy': - pres = 'energy' + beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs, + mode='energy') + elif mode.lower() == 'pinv': + pass else: raise ValueError("Mode not implemented: " + mode) - beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs, mode=pres) # Reconstruction matrix if mode.lower() == 'perfect': - B = beta * repeat_per_order(1/(c_n/c_n[0])) * \ - sh_matrix(N_sph, sec_azi, sec_zen, sh_type) + B = beta * (repeat_per_order(1/(c_n/c_n[0])) * \ + sh_matrix(N_sph, sec_azi, sec_zen, sh_type)).conj().T elif mode.lower() == 'energy': - B = np.sqrt(beta) * repeat_per_order(1/(c_n/c_n[0])) * \ - sh_matrix(N_sph, sec_azi, sec_zen, sh_type) + B = np.sqrt(beta) * (repeat_per_order(1/(c_n/c_n[0])) * \ + sh_matrix(N_sph, sec_azi, sec_zen, sh_type)).conj().T + elif mode.lower() == 'pinv': + B = np.linalg.pinv(A) else: raise ValueError("Mode not implemented: " + mode) - return A, B.conj().T + return A, B def sh_mult(a_nm, b_nm, sh_type): diff --git a/tests/test_algo.py b/tests/test_algo.py index 788d765..5f995b8 100644 --- a/tests/test_algo.py +++ b/tests/test_algo.py @@ -23,7 +23,7 @@ def test_sph_filter_bank(test_n_sph): sec_dirs = spa.utils.cart2sph(*spa.grids.load_t_design(2*N_sph).T) c_n = spa.sph.maxre_modal_weights(N_sph) [A, B] = spa.sph.design_sph_filterbank(N_sph, sec_dirs[0], sec_dirs[1], - c_n, 'real', 'perfect') + c_n, mode='perfect') # diffuse SH signal in_nm = np.random.randn((N_sph+1)**2, 1000)