Skip to content

Commit

Permalink
sph:design_sph_filterbank introduce pinv and default to 'real'
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-hld committed Nov 5, 2024
1 parent 5f8af85 commit 7191384
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
27 changes: 16 additions & 11 deletions spaudiopy/sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ def sph_filterbank_reconstruction_factor(w_nm, num_secs, mode=None):
return beta


def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode):
def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, mode, sh_type='real'):
"""Design spherical filter bank analysis and reconstruction matrix.
Parameters
Expand All @@ -1177,9 +1177,9 @@ def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode):
Sector zenith/colatitude steering directions.
c_n : (N,) array_like
SH Modal weights, describing (axisymmetric) pattern.
sh_type : 'real' or 'complex'
mode : 'perfect' or 'energy'
mode : 'perfect' or 'energy' or 'pinv'
Design achieves perfect reconstruction or energy reconstruction.
sh_type : 'real' or 'complex'
Raises
------
Expand Down Expand Up @@ -1231,25 +1231,30 @@ def design_sph_filterbank(N_sph, sec_azi, sec_zen, c_n, sh_type, mode):

# Preservation property
if mode.lower() == 'perfect':
pres = 'amplitude'
beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs,
mode='amplitude')
elif mode.lower() == 'energy':
pres = 'energy'
beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs,
mode='energy')
elif mode.lower() == 'pinv':
pass
else:
raise ValueError("Mode not implemented: " + mode)

beta = sph_filterbank_reconstruction_factor(A[0, :], num_secs, mode=pres)

# Reconstruction matrix
if mode.lower() == 'perfect':
B = beta * repeat_per_order(1/(c_n/c_n[0])) * \
sh_matrix(N_sph, sec_azi, sec_zen, sh_type)
B = beta * (repeat_per_order(1/(c_n/c_n[0])) * \
sh_matrix(N_sph, sec_azi, sec_zen, sh_type)).conj().T
elif mode.lower() == 'energy':
B = np.sqrt(beta) * repeat_per_order(1/(c_n/c_n[0])) * \
sh_matrix(N_sph, sec_azi, sec_zen, sh_type)
B = np.sqrt(beta) * (repeat_per_order(1/(c_n/c_n[0])) * \
sh_matrix(N_sph, sec_azi, sec_zen, sh_type)).conj().T
elif mode.lower() == 'pinv':
B = np.linalg.pinv(A)
else:
raise ValueError("Mode not implemented: " + mode)

return A, B.conj().T
return A, B


def sh_mult(a_nm, b_nm, sh_type):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_sph_filter_bank(test_n_sph):
sec_dirs = spa.utils.cart2sph(*spa.grids.load_t_design(2*N_sph).T)
c_n = spa.sph.maxre_modal_weights(N_sph)
[A, B] = spa.sph.design_sph_filterbank(N_sph, sec_dirs[0], sec_dirs[1],
c_n, 'real', 'perfect')
c_n, mode='perfect')

# diffuse SH signal
in_nm = np.random.randn((N_sph+1)**2, 1000)
Expand Down

0 comments on commit 7191384

Please sign in to comment.